In [None]:
import pickle, sys
import seaborn as sns
import torch
import numpy as np
import scipy.stats as stats
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Arial"

# Import custom modules
sys.path.append("../models")
from models.network_hierarchical_recurrent_temporal_prediction import NetworkHierarchicalRecurrentTemporalPrediction as Network

from virtual_physiology.VirtualNetworkPhysiology import VirtualPhysiology
from plotting_functions import *
from connectivity_functions import *

np.random.seed(0)

# Add model and vphys paths here
MODEL_PATH = ''
VPHYS_PATH = ''

# Hidden-layer size (2 * 36 * 36). The first 90% of hidden-unit indices are
# treated as excitatory and the remainder as inhibitory throughout the notebook.
TOTAL_UNITS      = 2*36*36
EXCITATORY_UNITS = int(TOTAL_UNITS*0.9)
INHIBITORY_UNITS = TOTAL_UNITS-EXCITATORY_UNITS

# Fail early with a clear message instead of an obscure error inside `load`.
if not MODEL_PATH or not VPHYS_PATH:
    raise ValueError('Set MODEL_PATH and VPHYS_PATH before running this notebook.')

# Load network checkpoint. The class is imported under the alias `Network`
# in the imports cell; its full class name is not bound in this namespace.
model, hyperparameters, _ = Network.load(
    model_path=MODEL_PATH, device='cpu', plot_loss_history=True
)

# Load the saved VirtualPhysiology data for this model
vphys = VirtualPhysiology.load(
    data_path=VPHYS_PATH,
    model=model,
    hyperparameters=hyperparameters,
    frame_shape=(36,36),
    hidden_units=[TOTAL_UNITS],
    device='cpu'
)

# Split the characterised units by hidden-unit index into E / I groups.
excitatory_units = [u for u in vphys.data[0] if u['hidden_unit_index'] <  EXCITATORY_UNITS]
inhibitory_units = [u for u in vphys.data[0] if u['hidden_unit_index'] >= EXCITATORY_UNITS]

In [None]:
DSI_MODE = 'DSI'    # 'DSI': keep post units with DSI >= 0.8; 'nonDSI': keep DSI < 0.8
CONN_PCT = 95       # weight percentile used as the connection threshold
LOCAL    = 'local'  # 'local': mode 'all' limited to short/long-range radius; 'all': no distance limit


def get_is_in_range(pre_unit, post_unit, mode, local=None,
                    short_max=2.5, long_min=5, long_max=9.166):
    """Return whether two units' Gabor centres are within the range selected by `mode`.

    Distance is the Euclidean distance between (`gabor_x`, `gabor_y`) of the
    two unit dicts.

    Args:
        pre_unit, post_unit: unit dicts with 'gabor_x' and 'gabor_y' entries.
        mode: 'short_range' (dist < short_max), 'long_range'
            (long_min < dist < long_max) or 'all'.
        local: only used when mode == 'all'. 'local' restricts to
            dist < long_max; 'all' accepts every pair. Defaults to the
            module-level LOCAL setting.
        short_max, long_min, long_max: distance cut-offs (same units as the
            Gabor centre coordinates).

    Returns:
        bool

    Raises:
        NotImplementedError: unknown `mode`.
        ValueError: mode == 'all' with a `local` other than 'local' / 'all'
            (previously every pair was silently rejected).
    """
    if mode not in ('short_range', 'long_range', 'all'):
        raise NotImplementedError(f'Unknown range mode: {mode!r}')
    if local is None:
        local = LOCAL

    dx = pre_unit['gabor_x'] - post_unit['gabor_x']
    dy = pre_unit['gabor_y'] - post_unit['gabor_y']
    dist = (dx**2 + dy**2)**0.5

    if mode == 'short_range':
        return bool(dist < short_max)
    if mode == 'long_range':
        return bool(long_min < dist < long_max)

    # mode == 'all'
    if local == 'local':
        return bool(dist < long_max)
    if local == 'all':
        return True
    raise ValueError(f"LOCAL must be 'local' or 'all', got {local!r}")

# Select pre- and postsynaptic units

In [None]:
# Recurrent weights between the analysed hidden units (same sizes as the config cell).
# NOTE(review): assumes `model.rnn.weight_hh_l0` follows PyTorch's nn.RNN convention
# (h_t = ... + W_hh @ h_{t-1}), i.e. rows index the receiving (post) unit and
# columns the sending (pre) unit — confirm for this model.
total_units   = TOTAL_UNITS
excit_units_n = EXCITATORY_UNITS
weight_matrix = model.rnn.weight_hh_l0[:total_units, :total_units].detach().numpy()

# CONN_PCT-percentile connection thresholds computed on:
#   inhib_threshold: |weights| in the block [excitatory rows, inhibitory columns]
#   excit_threshold: signed weights in the block [excitatory rows, excitatory columns]
#   all_threshold:   all signed weights (not used in the visible cells)
inhib_threshold = np.percentile(np.abs(weight_matrix[:excit_units_n, excit_units_n:total_units]).ravel(), CONN_PCT)
excit_threshold = np.percentile(weight_matrix[:excit_units_n, :excit_units_n].ravel(), CONN_PCT)
all_threshold   = np.percentile(weight_matrix.ravel(), CONN_PCT)

# Later cells pass this absolute-valued matrix to get_is_connected.
weight_matrix = np.abs(weight_matrix)

# Gabor-fit quality criteria for including a unit in the analysis.
MIN_GABOR_R = 0.7   # minimum Gabor fit r
MIN_GABOR_S = 0.5   # minimum of gabor_params[4:6] (sx, sy; presumably envelope widths)
FRAME_SIZE  = 36    # Gabor centre must lie strictly inside (0, FRAME_SIZE) on both axes

excit_units     = []
inhib_units     = []

for unit in vphys.data[0]:
    unit_index = unit['hidden_unit_index']
    if unit_index >= total_units:
        continue

    sx, sy = unit['gabor_params'][4:6]
    is_good_fit = (
        unit['gabor_r'] > MIN_GABOR_R
        and 0 < unit['gabor_x'] < FRAME_SIZE
        and 0 < unit['gabor_y'] < FRAME_SIZE
        and sx > MIN_GABOR_S and sy > MIN_GABOR_S
    )
    if not is_good_fit:
        continue

    if unit_index < excit_units_n:
        excit_units.append(unit)
    else:
        inhib_units.append(unit)

CONNECTION_AXIS_MODE = 'all'
CONNECTION_DISTANCE_MODE = 'all'
ORIENTATION_MODE = 'direction'

# Average tuning curve of selected excitatory units

In [None]:
# Average the orientation tuning curves of the selected excitatory units after
# aligning each curve so its preferred orientation sits at the centre index.
tuning_curve_shifted_arr = []

orientations_list = vphys.orientations.tolist()
# Centre index for the alignment (18 for 36 orientations, the value previously hard-coded).
centre_idx = len(orientations_list) // 2

for u in excit_units:
    if DSI_MODE == 'DSI' and u['DSI'] < 0.8:
        continue
    # `>=` so the DSI / nonDSI selections partition the units, matching the
    # filters used in the pooled-density cells.
    elif DSI_MODE == 'nonDSI' and u['DSI'] >= 0.8:
        continue

    pref_ori_idx = orientations_list.index(u['preferred_orientation'])
    tuning_curve = vphys.get_orientation_tuning_curve(u)
    tuning_curve_shifted = np.roll(tuning_curve, -(pref_ori_idx - centre_idx)).tolist()
    # Repeat the first sample so the polar outline closes.
    tuning_curve_shifted.append(tuning_curve_shifted[0])

    tuning_curve_shifted_arr.append(tuning_curve_shifted)

if not tuning_curve_shifted_arr:
    raise ValueError(f'No excitatory units pass the {DSI_MODE} selection; nothing to plot.')

# Polar angles in radians; the closing angle must also be converted to radians
# (previously it was appended in degrees).
theta = np.deg2rad([*vphys.orientations, vphys.orientations[0]])
mn = np.mean(tuning_curve_shifted_arr, axis=0)
# Standard error of the mean (not plotted; kept for inspection).
er = np.std(tuning_curve_shifted_arr, axis=0)/(len(tuning_curve_shifted_arr)**0.5)

ax = plt.subplot(111, projection='polar')
ax.fill(theta, mn, 'tab:gray')
ax.set_theta_zero_location("N")
ax.set_theta_direction(-1)
ax.set_yticks([])
format_plot()
save_plot(3, f'{DSI_MODE}_tuning_curve')
plt.show()

# Pooled presynaptic density

In [None]:
# For each selected excitatory post unit, measure how its weighted presynaptic
# density splits between the two halves of the aligned visual axis: columns
# [:half] of the density map are counted as "opposite", the rest as "ahead".
# Outer list index 0 = inhibitory presynaptic partners, 1 = excitatory.
excit_inhib_all_density_proportion_opposite = []
excit_inhib_all_density_proportion_ahead = []

for PRE_UNITS, THRESH in zip([inhib_units, excit_units], [inhib_threshold, excit_threshold]):
    all_density_proportion_opposite = []
    all_density_proportion_ahead    = []

    for post_unit_idx, post_unit in enumerate(excit_units):
        if (post_unit_idx % 200) == 0:
            print('Starting unit', post_unit_idx, '/',  len(excit_units))

        # Restrict post units by direction selectivity according to DSI_MODE.
        if (DSI_MODE == 'DSI') and (post_unit['DSI'] < 0.8):
            continue
        elif (DSI_MODE == 'nonDSI') and (post_unit['DSI'] >= 0.8):
            continue

        # Only post units whose Gabor centre lies in [10, 26] on both axes.
        if post_unit['gabor_x'] > 26 or post_unit['gabor_x'] < 10:
            continue
        if post_unit['gabor_y'] > 26 or post_unit['gabor_y'] < 10:
            continue

        # Presynaptic partners: in range, above threshold, and not the post unit
        # itself. Self-connections are excluded by hidden-unit identity; comparing
        # list positions would wrongly drop an unrelated unit when PRE_UNITS is
        # inhib_units (a different list from the one post_unit_idx enumerates).
        all_pre_units = []
        for pre_unit in PRE_UNITS:
            if pre_unit['hidden_unit_index'] == post_unit['hidden_unit_index']:
                continue
            is_in_range     = get_is_in_range(pre_unit, post_unit, CONNECTION_DISTANCE_MODE)
            is_connected, _ = get_is_connected(pre_unit, post_unit, weight_matrix, THRESH)
            if is_in_range and is_connected:
                all_pre_units.append(pre_unit)

        if not all_pre_units:
            continue

        # Express the presynaptic ensemble in the post unit's translated + rotated frame.
        translation_matrix        = get_translation_matrix(post_unit)
        rotation_matrix           = get_rotation_matrix(post_unit)
        transformation            = rotation_matrix @ translation_matrix
        all_pre_units_transformed = apply_matrix_transformation(all_pre_units, transformation)
        all_pre_units_filtered    = apply_axis_filter(all_pre_units_transformed, CONNECTION_AXIS_MODE)
        if len(all_pre_units_filtered) == 0:
            continue
        orientation_differences   = get_orientation_differences(post_unit, all_pre_units_filtered, ORIENTATION_MODE)

        weights = [abs(get_is_connected(u, post_unit, weight_matrix, THRESH)[1]) for u in all_pre_units_filtered]
        coords  = [{'x': u['gabor_x'], 'y': u['gabor_y']} for u in all_pre_units_filtered]
        bins    = [0, 180]

        heat_map = get_heat_maps(np.array(orientation_differences), np.array(coords), np.array(weights), bins, THRESH, full=True)[0]

        # Column-wise density profile along the aligned axis. The ratio below is
        # independent of any overall normalisation of the heat map.
        curve = heat_map.sum(axis=0)
        total = np.sum(curve)
        if total == 0:
            # Empty density map: proportion undefined, skip instead of producing NaN.
            continue
        half = len(curve) // 2
        proportion_opposite = np.sum(curve[:half]) / total
        all_density_proportion_opposite.append(proportion_opposite)
        all_density_proportion_ahead.append(1 - proportion_opposite)

    excit_inhib_all_density_proportion_opposite.append(all_density_proportion_opposite)
    excit_inhib_all_density_proportion_ahead.append(all_density_proportion_ahead)

# Mean and standard error of the mean (std / sqrt(n)) across post units.
mn_opp = [np.mean(v) for v in excit_inhib_all_density_proportion_opposite]
er_opp = [np.std(v)/(len(v)**0.5) for v in excit_inhib_all_density_proportion_opposite]

mn_ahead = [np.mean(v) for v in excit_inhib_all_density_proportion_ahead]
er_ahead = [np.std(v)/(len(v)**0.5) for v in excit_inhib_all_density_proportion_ahead]

fig = plt.figure()
plt.bar([0, 3], [mn_opp[0], mn_ahead[0]], yerr=[er_opp[0], er_ahead[0]], facecolor='tab:blue')
plt.bar([1, 4], [mn_opp[1], mn_ahead[1]], yerr=[er_opp[1], er_ahead[1]], facecolor='tab:red')
plt.plot([0, 4], [0.5, 0.5], '--', c='black')
plt.xticks([0.5, 3.5], ['Opposite', 'Ahead'])
plt.ylabel('Presynaptic density')
format_plot(fontsize=20)
fig.set_size_inches(4,4)
save_plot(3, f'{DSI_MODE}_pooled_density_bar_{LOCAL}')
plt.show()

# NOTE(review): `pg` (presumably pingouin) is not imported in the visible cells —
# confirm it is provided by one of the star imports, or add `import pingouin as pg`
# to the imports cell.
print('Inhibitory')
print(pg.ttest(excit_inhib_all_density_proportion_opposite[0], 0.5))

print('Excitatory')
print(pg.ttest(excit_inhib_all_density_proportion_opposite[1], 0.5))

print('Inhibitory vs excitatory')
print(pg.ttest(excit_inhib_all_density_proportion_opposite[0], excit_inhib_all_density_proportion_opposite[1]))

In [None]:
# Pool presynaptic ensembles over all selected post units and build one
# normalised density map per presynaptic type.
# List index 0 = inhibitory presynaptic partners, 1 = excitatory.
curves = []
heat_maps = []

for PRE_UNITS, THRESH in zip([inhib_units, excit_units], [inhib_threshold, excit_threshold]):
    all_orientation_differences = []
    all_weights = []
    all_coords = []

    for post_unit_idx, post_unit in enumerate(excit_units):
        if (post_unit_idx % 200) == 0:
            print('Starting unit', post_unit_idx, '/',  len(excit_units))

        # Restrict post units by direction selectivity according to DSI_MODE.
        if (DSI_MODE == 'DSI') and (post_unit['DSI'] < 0.8):
            continue
        elif (DSI_MODE == 'nonDSI') and (post_unit['DSI'] >= 0.8):
            continue

        # Only post units whose Gabor centre lies in [10, 26] on both axes.
        if post_unit['gabor_x'] > 26 or post_unit['gabor_x'] < 10:
            continue
        if post_unit['gabor_y'] > 26 or post_unit['gabor_y'] < 10:
            continue

        # Presynaptic partners; self-connections excluded by hidden-unit identity
        # (list positions differ between inhib_units and excit_units).
        all_pre_units = []
        for pre_unit in PRE_UNITS:
            if pre_unit['hidden_unit_index'] == post_unit['hidden_unit_index']:
                continue
            is_in_range     = get_is_in_range(pre_unit, post_unit, CONNECTION_DISTANCE_MODE)
            is_connected, _ = get_is_connected(pre_unit, post_unit, weight_matrix, THRESH)
            if is_in_range and is_connected:
                all_pre_units.append(pre_unit)

        if not all_pre_units:
            continue

        # Express the presynaptic ensemble in the post unit's translated + rotated frame.
        translation_matrix        = get_translation_matrix(post_unit)
        rotation_matrix           = get_rotation_matrix(post_unit)
        transformation            = rotation_matrix @ translation_matrix
        all_pre_units_transformed = apply_matrix_transformation(all_pre_units, transformation)
        all_pre_units_filtered    = apply_axis_filter(all_pre_units_transformed, CONNECTION_AXIS_MODE)
        orientation_differences   = get_orientation_differences(post_unit, all_pre_units_filtered, ORIENTATION_MODE)

        all_orientation_differences += orientation_differences
        all_weights                 += [abs(get_is_connected(u, post_unit, weight_matrix, THRESH)[1]) for u in all_pre_units_filtered]
        all_coords                  += [{'x': u['gabor_x'], 'y': u['gabor_y']} for u in all_pre_units_filtered]

    if len(all_weights) == 0:
        raise ValueError('No connected presynaptic units found; cannot build a pooled density map.')

    bins = [0, 180]
    heat_map = get_heat_maps(np.array(all_orientation_differences), np.array(all_coords),
                             np.array(all_weights), bins, THRESH, full=True)[0]
    heat_map = heat_map / heat_map.sum()

    heat_maps.append(heat_map)
    curves.append(heat_map.sum(axis=0))


def _value_range(*arrays):
    """Return (max, min) over all values of the given 1-D arrays."""
    values = np.concatenate(arrays)
    return values.max(), values.min()


# Raw column-wise density profiles.
exc_curve = curves[1]
inh_curve = curves[0]
diff_curve = exc_curve - inh_curve
max_v, min_v = _value_range(exc_curve, inh_curve, diff_curve)

fig = plt.figure()
plt.plot([len(inh_curve)//2, len(inh_curve)//2], [min_v, max_v], '--', c='gray', linewidth=2)
plt.plot(inh_curve, label='Inhibitory', c='tab:blue')
plt.plot(exc_curve, label='Excitatory', c='tab:red')
plt.plot(diff_curve, c='black', label='Exc-Inh')
plt.xticks([])
plt.xlabel('Visual position')
plt.ylabel('Presynaptic density (normed)')
format_plot(fontsize=18)
save_plot(3, f'{DSI_MODE}_pooled_density_line_raw_{LOCAL}')
plt.show()

# 5-sample moving average (zero-padded at the edges by mode='same').
movmean = lambda x: np.convolve(x, np.ones(5), mode='same')/5

exc_curve = movmean(curves[1])
inh_curve = movmean(curves[0])
diff_curve = exc_curve - inh_curve
max_v, min_v = _value_range(exc_curve, inh_curve, diff_curve)

fig = plt.figure()
# Positions centred on the post unit (-51..50 for the 102-column maps).
half = len(exc_curve) // 2
x = np.arange(-half, len(exc_curve) - half)
plt.plot([0, 0], [min_v, max_v], '--', c='gray', linewidth=2)
plt.plot(x, inh_curve, label='Inhibitory', c='tab:blue')
plt.plot(x, exc_curve, label='Excitatory', c='tab:red')
plt.plot(x, diff_curve, c='black', label='Exc-Inh')
plt.xticks()
plt.xlim(-16, 16)
plt.xlabel('Visual position (pixels)')
plt.ylabel('Presynaptic density\n(Normalized)')
format_plot(fontsize=20)
fig.set_size_inches(4,4)
# Nothing visible here guarantees a legend exists; create one if needed
# instead of failing on None.
legend = plt.gca().get_legend()
if legend is None:
    legend = plt.legend()
legend.set_bbox_to_anchor((1, 1))
save_plot(3, f'{DSI_MODE}_pooled_density_line_smoothed_{LOCAL}')
plt.show()

In [None]:
# Outline the densest region of each pooled density map (index 0 inhibitory,
# 1 excitatory): smooth, keep the top 1% of pixels, take the largest connected
# component and trace its boundary.
ims = []
coords = []

for heat_map in heat_maps:
    smoothed = ndimage.gaussian_filter(heat_map, 2)
    smoothed_thresh = np.percentile(smoothed.ravel(), 99)
    binary_mask = np.where(smoothed > smoothed_thresh, 1, 0)

    # Connected components with 8-connectivity (the default of
    # skimage.measure.label); uses the already-imported scipy.ndimage instead
    # of a `label` name that is not imported in the visible cells.
    labels, n_labels = ndimage.label(binary_mask, structure=np.ones((3, 3)))
    if n_labels == 0:
        raise ValueError('No pixels above the 99th-percentile threshold; cannot outline the density.')
    largest_label = np.bincount(labels.ravel())[1:].argmax() + 1
    largest_area_binary_mask = (labels == largest_label).astype(int)

    ims.append(largest_area_binary_mask)

    # NOTE(review): `cv2` (OpenCV) is not imported in the visible cells — confirm
    # it is provided by one of the star imports.
    # The mask is a single component, so its outer boundary is the only contour
    # we need (RETR_EXTERNAL); a uint8 mask is already valid input, so no colour
    # conversion round trip is required.
    contours, _ = cv2.findContours(np.uint8(largest_area_binary_mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    coords.append(contours[0])

# Centre column of the density maps (all maps share one shape).
centre = heat_maps[-1].shape[1] // 2

fig = plt.figure()
plt.gca().set_aspect(1)
for color, contour in zip(['tab:blue', 'tab:red'], coords):
    # OpenCV contour points are stored as [[x, y]]; repeat the first to close the outline.
    xs = [point[0][0] for point in contour]
    ys = [point[0][1] for point in contour]
    xs.append(xs[0])
    ys.append(ys[0])

    plt.plot(xs, ys, c=color)
plt.xlim(centre-10, centre+10)
plt.ylim(centre-10, centre+10)
plt.xticks([])
plt.yticks([])
plt.plot([centre, centre], [0, 100], '--', c='black')
format_plot()
plt.gca().spines['top'].set_visible(True)
plt.gca().spines['right'].set_visible(True)
save_plot(3, f'{DSI_MODE}_pooled_density_outline_{LOCAL}')
plt.show()

# Exemplar presynaptic ensembles

In [None]:
# Exemplar plots: for each centrally located, direction-selective (DSI >= 0.8)
# excitatory unit, scatter its excitatory (red) and inhibitory (blue)
# presynaptic partners in the post unit's aligned coordinate frame.
for post_unit_idx, post_unit in enumerate(excit_units):
    if (post_unit_idx % 200) == 0:
        print('Starting unit', post_unit_idx, '/',  len(excit_units))

    if post_unit['DSI'] < 0.8:
        continue

    # Only post units whose Gabor centre lies in [10, 26] on both axes.
    if post_unit['gabor_x'] > 26 or post_unit['gabor_x'] < 10:
        continue
    if post_unit['gabor_y'] > 26 or post_unit['gabor_y'] < 10:
        continue

    # Index 0 = inhibitory ensemble, 1 = excitatory ensemble (when both exist).
    pre_unit_ensembles = []

    for PRE_UNITS, THRESH in zip([inhib_units, excit_units], [inhib_threshold, excit_threshold]):
        # Self-connections excluded by hidden-unit identity (list positions
        # differ between inhib_units and excit_units).
        all_pre_units = []
        for pre_unit in PRE_UNITS:
            if pre_unit['hidden_unit_index'] == post_unit['hidden_unit_index']:
                continue
            is_in_range     = get_is_in_range(pre_unit, post_unit, CONNECTION_DISTANCE_MODE)
            is_connected, _ = get_is_connected(pre_unit, post_unit, weight_matrix, THRESH)
            if is_in_range and is_connected:
                all_pre_units.append(pre_unit)

        if not all_pre_units:
            continue

        # Express the presynaptic ensemble in the post unit's translated + rotated frame.
        translation_matrix        = get_translation_matrix(post_unit)
        rotation_matrix           = get_rotation_matrix(post_unit)
        transformation            = rotation_matrix @ translation_matrix
        all_pre_units_transformed = apply_matrix_transformation(all_pre_units, transformation)
        all_pre_units_filtered    = apply_axis_filter(all_pre_units_transformed, CONNECTION_AXIS_MODE)

        if len(all_pre_units_filtered):
            pre_unit_ensembles.append(all_pre_units_filtered)

    # Only plot when both presynaptic types contributed an ensemble.
    if len(pre_unit_ensembles) != 2:
        continue

    fig = plt.figure(dpi=100, figsize=[4,4])
    ax = plt.gca()
    ax.set_aspect(1)
    ax.spines['top'].set_visible(True)
    ax.spines['right'].set_visible(True)
    plt.xticks([])
    plt.yticks([])

    plt.plot([0, 0], [30, -30], '--', c='black')
    plt.plot([0], [0], marker='>', c='black')

    # Excitatory partners (red) first, then inhibitory (blue) drawn on top.
    for pre_unit_ensemble, color in zip([pre_unit_ensembles[1], pre_unit_ensembles[0]], ['red', 'blue']):
        plt.scatter([u['gabor_x'] for u in pre_unit_ensemble],
                    [u['gabor_y'] for u in pre_unit_ensemble],
                    c=color, alpha=0.5)
    plt.xlim(-30, 30)
    plt.ylim(-30, 30)
    save_plot(3, f'exemplar_ensembles_{post_unit_idx}')
    plt.show()
    # Release the figure; this loop can create many of them.
    plt.close(fig)