In [None]:
import pickle, sys
import torch
import math
import numpy as np
import pandas as pd
import statsmodels.api as sm
import scipy.stats as stats
import matplotlib.pyplot as plt

import scipy.ndimage as ndimage

plt.rcParams["font.family"] = "Arial"

sys.path.append("../")
from models.network_hierarchical_recurrent import NetworkHierarchicalRecurrent
from virtual_physiology.VirtualNetworkPhysiology import VirtualPhysiology
from plotting_functions import *
from connectivity_functions import *
np.random.seed(0)

# Recurrent population layout: hidden-unit indices below EXCITATORY_UNITS are treated
# as excitatory, the rest as inhibitory (80 / 20 split).
TOTAL_UNITS      = 800
EXCITATORY_UNITS = int(TOTAL_UNITS * 0.8)
INHIBITORY_UNITS = TOTAL_UNITS - EXCITATORY_UNITS

# Locations of the trained checkpoint and of the saved virtual-physiology data.
# NOTE(review): empty in this copy — set before running.
MODEL_PATH = ''
VPHYS_PATH = ''

# Trained network checkpoint, loaded on CPU.
model, hyperparameters, _ = NetworkHierarchicalRecurrent.load(
    model_path=MODEL_PATH,
    device='cpu',
    plot_loss_history=True,
)

# Virtual-physiology results for this model, loaded from VPHYS_PATH.
vphys = VirtualPhysiology.load(
    data_path=VPHYS_PATH,
    model=model,
    hyperparameters=hyperparameters,
    frame_shape=(20, 40),
    hidden_units=[800, 800, 800],
    device='cpu',
)

# Split the characterised units in vphys.data[0] (presumably the first hidden layer)
# by population, using the index convention above.
excitatory_units = [
    characterised_unit
    for characterised_unit in vphys.data[0]
    if characterised_unit['hidden_unit_index'] < EXCITATORY_UNITS
]
inhibitory_units = [
    characterised_unit
    for characterised_unit in vphys.data[0]
    if characterised_unit['hidden_unit_index'] >= EXCITATORY_UNITS
]

# Unit connectivity

# Ho

In [None]:
def get_cochrane_armitage (data_a, data_b):
    """Linear trend test (Cochran–Armitage style) across ordered bins.

    Builds a 2 x K contingency table and runs statsmodels'
    ``Table.test_ordinal_association`` on it. Row 0 is ``data_a`` and row 1 is
    ``data_b``, with one column per bin in the given order. The row scores are
    (0, 1) and the column scores are 0..K-1.

    NOTE(review): the two rows are treated as two disjoint outcome categories per bin
    (e.g. connected vs. NOT connected pairs). Passing (connected, total) would
    double-count.

    Returns the result object of ``test_ordinal_association``.
    """
    columns = {}
    for bin_idx, (count_a, count_b) in enumerate(zip(data_a, data_b)):
        columns[bin_idx] = [count_a, count_b]

    table = sm.stats.Table(pd.DataFrame(columns, index=[0, 1]))

    return table.test_ordinal_association(
        row_scores=np.array([0, 1]),
        col_scores=np.arange(len(data_a)),
    )

In [None]:
# Short-range connectivity: probability that a pair of selective excitatory units is
# connected, as a function of orientation / direction difference, plotted against
# the experimental counts listed in ALL_MODES.

# Excitatory -> excitatory block of the first-layer recurrent weight matrix.
total_units   = 20*40
excit_units_n = int(total_units*0.8)
weight_matrix = model.rnn.weight_hh_l0.detach().numpy()
weight_matrix = weight_matrix[:excit_units_n, :excit_units_n]

# 95th percentile of |w| over the E->E block. Further down, a pair counts as
# connected when the |weight| returned by get_is_connected exceeds it.
threshold = np.percentile(np.abs(weight_matrix.reshape(-1)), 95)

# Keep excitatory units with a good Gabor fit (gabor_r > 0.7) and both
# gabor_params[4:6] values > 0.5.
# NOTE(review): those two params are presumably the Gabor sigmas — confirm.
all_units     = []
for unit in vphys.data[0]:
    if unit['hidden_unit_index'] < excit_units_n:
        sx, sy = unit['gabor_params'][4:6]

        if unit['gabor_r'] > 0.7 and sx>0.5 and sy>0.5:
            all_units.append(unit)

# Each entry: [distance mode, axis mode, orientation/direction mode,
#              experimental numerator counts per bin, experimental denominator counts per bin]
ALL_MODES = [
    ['short_range', 'all', 'orientation', [10, 11, 4], [26, 44, 24]],
    ['short_range', 'all', 'direction',   [6, 3, 4, 4, 2], [13, 16, 22, 16, 5]],
]

for CONNECTION_DISTANCE_MODE, CONNECTION_AXIS_MODE, ORIENTATION_MODE, NEURAL_DATA_A, NEURAL_DATA_B in ALL_MODES:
    # Experimental connection probability per bin (grey bars next to the model).
    NEURAL_DATA = np.array(NEURAL_DATA_A)/np.array(NEURAL_DATA_B)

    all_orientation_differences = []
    all_post_unit_orientations  = []
    all_pre_unit_orientations   = []
    all_weights = []

    for post_unit_idx, post_unit in enumerate(all_units):
        if (post_unit_idx % 200) == 0:
            print('Starting unit', post_unit_idx, '/',  len(all_units))

        # Candidate partners are gathered regardless of connectivity. Unconnected
        # pairs form the denominator of the connection probability, and the
        # connected/unconnected split is made later from `all_weights`.
        all_pre_units     = []

        for pre_unit_idx, pre_unit in enumerate(all_units):
            is_in_range       = get_is_in_range(pre_unit, post_unit, CONNECTION_DISTANCE_MODE)
            is_different_unit = pre_unit_idx != post_unit_idx
            is_selective      = is_orientation_or_direction_selective(pre_unit, post_unit, ORIENTATION_MODE)

            if is_in_range and is_different_unit and is_selective:
                all_pre_units.append(pre_unit)

        # Apply the post unit's translation, then rotation, to the candidates, and keep
        # those passing the axis filter.
        translation_matrix           = get_translation_matrix(post_unit)
        rotation_matrix              = get_rotation_matrix(post_unit)
        transformation               = rotation_matrix@translation_matrix
        all_pre_units_transformed    = apply_matrix_transformation(all_pre_units, transformation)
        all_pre_units_filtered       = apply_axis_filter(all_pre_units_transformed, CONNECTION_AXIS_MODE)
        orientation_differences      = get_orientation_differences(post_unit, all_pre_units_filtered, ORIENTATION_MODE)

        all_orientation_differences += orientation_differences
        all_weights                 += [abs(get_is_connected(u, post_unit, weight_matrix, threshold)[1]) for u in all_pre_units_filtered]

        for u in all_pre_units_filtered:
            all_post_unit_orientations.append(post_unit['preferred_orientation'])
            all_pre_unit_orientations.append(u['preferred_orientation'])

    all_orientation_differences      = np.array(all_orientation_differences)
    all_weights                      = np.array(all_weights)
    all_post_unit_orientations       = np.array(all_post_unit_orientations)
    all_pre_unit_orientations        = np.array(all_pre_unit_orientations)

    bins        = [0, 22.5, 67.5, 112.5, 157.5, 180] if ORIENTATION_MODE == 'direction' else [0, 22.5, 67.5, 90]
    bin_centres = [0, 45, 90, 135, 180] if ORIENTATION_MODE == 'direction' else [0, 45, 90]
    x_ticks     = np.arange(len(bin_centres))

    # Connected and tested pair counts per bin. Connected pairs are a subset of all
    # pairs, so connected <= all in every bin. An empty bin yields NaN in true_dist.
    connected_binned_data, _  = np.histogram(all_orientation_differences[all_weights>threshold], bins=bins)
    all_binned_data, _        = np.histogram(all_orientation_differences, bins=bins)
    true_dist                 = connected_binned_data/all_binned_data

    # The trend test needs two disjoint outcome rows (connected vs. NOT connected).
    # Passing the totals as the second row would double-count connected pairs.
    unconnected_binned_data = all_binned_data - connected_binned_data

    if ORIENTATION_MODE == 'orientation':
        print(get_cochrane_armitage(connected_binned_data, unconnected_binned_data))
    else:
        # Direction: test the 0-90 and 90-180 degree halves separately.
        print(get_cochrane_armitage(connected_binned_data[:3], unconnected_binned_data[:3]))
        print(get_cochrane_armitage(connected_binned_data[2:], unconnected_binned_data[2:]))

    fig = plt.figure()
    plt.bar(x_ticks-1/3, true_dist, facecolor='black', width=1/3)
    plt.bar(x_ticks, NEURAL_DATA, facecolor='gray', width=1/3)
    plt.xticks(x_ticks-1/6, bin_centres)
    plt.xlabel(f"{'Orientation' if ORIENTATION_MODE == 'orientation' else 'Direction'} difference (°)")
    plt.ylabel('Connection probability')
    plt.ylim(0, 0.5)
    format_plot(fontsize=20)
    fig.set_size_inches(4,4)
    save_plot('connectivity', f'COMPLEX_{CONNECTION_DISTANCE_MODE}_{CONNECTION_AXIS_MODE}_{ORIENTATION_MODE}')
    plt.show()

# Iacaruso et al. (2017)

In [None]:
def get_shuffled_binned_data_iacoruso (post_unit_orientations, pre_unit_orientations, bins, total_units, iters=10000):
    """Null distribution of binned orientation differences under shuffled pairings.

    On each iteration the presynaptic orientations are randomly permuted against the
    fixed postsynaptic orientations. The permutation uses the global ``np.random``
    state and is applied cumulatively to one working copy. For each pair the axial
    orientation difference in [0, 90] is computed. These differences are
    histogrammed into ``bins`` and divided by ``total_units``.

    Parameters
    ----------
    post_unit_orientations, pre_unit_orientations : sequence of float
        Preferred orientations (degrees) of the post- and presynaptic unit of each
        pair. Only the first ``min(len(post), len(pre))`` pairs are used.
        ``pre_unit_orientations`` must support ``.copy()``; the caller's object is
        not modified.
    bins : sequence of float
        Histogram bin edges passed to ``np.histogram``.
    total_units : number
        Normaliser applied to every bin count.
    iters : int, default 10000
        Number of shuffles.

    Returns
    -------
    np.ndarray of shape (len(bins) - 1, iters)
        Column ``i`` holds the normalised histogram of shuffle ``i``.
    """
    shuffled_results = np.zeros((len(bins)-1, iters))

    # zip() semantics: only as many pairs as the shorter input.
    n_pairs  = min(len(post_unit_orientations), len(pre_unit_orientations))
    post_mod = np.asarray(post_unit_orientations)[:n_pairs] % 180

    # Shuffle a copy so the caller's data is untouched. Re-shuffling the same copy
    # each iteration consumes the random stream exactly as an element-wise loop would.
    pre_unit_orientations_copy = pre_unit_orientations.copy()

    for i in range(iters):
        if i % 1000 == 0:
            print('Iteration', i)

        np.random.shuffle(pre_unit_orientations_copy)
        pre_mod = np.asarray(pre_unit_orientations_copy)[:n_pairs] % 180

        # Orientations are 180-periodic, so fold |difference| into [0, 90].
        orientation_differences = np.abs(post_mod - pre_mod)
        orientation_differences = np.minimum(orientation_differences, 180 - orientation_differences)

        binned_data, _ = np.histogram(orientation_differences, bins=bins)

        shuffled_results[:, i] = binned_data/total_units

    return shuffled_results


## Bar plots

In [None]:
# Long-range analysis: among CONNECTED presynaptic partners, the fraction at each
# orientation difference, split into co-axial vs. orthogonal positions.

total_units   = 800
excit_units_n = int(total_units*0.8)
weight_matrix = model.rnn.weight_hh_l0[:excit_units_n, :excit_units_n].detach().numpy()

# NOTE(review): this is the 95th percentile of the SIGNED E->E weights, whereas the
# short-range cell uses |w|. Confirm this matches how get_is_connected compares
# weights against the threshold.
threshold = np.percentile(weight_matrix.reshape(-1), 95)

# Same unit selection as the other cells: good Gabor fit and both gabor_params[4:6] > 0.5.
all_units     = []
for unit in vphys.data[0]:
    if unit['hidden_unit_index'] < excit_units_n:
        sx, sy = unit['gabor_params'][4:6]

        if unit['gabor_r'] > 0.7 and sx>0.5 and sy>0.5:
            all_units.append(unit)


ALL_MODES = [
    ['long_range', 'coaxial', 'orientation'],
    ['long_range', 'orthogonal', 'orientation']
]

# Every list holds one entry per ALL_MODES entry, in the same order. The experimental
# denominator 159 is the sum of all six neural_data_a counts (97 + 62), so the
# neural fractions are of all inputs pooled across BOTH axes.
axial_results = {
    'orientation_differences': [],
    'post_unit_orientations' : [],
    'pre_unit_orientations'  : [],
    'neural_data_a'          : [[48, 33, 16], [20, 19, 23]],
    'neural_data_b'          : [[159, 159, 159], [159, 159, 159]],
    'save_string'            : ['long_range_coaxial_orientation', 'long_range_orthogonal_orientation']
}

for CONNECTION_DISTANCE_MODE, CONNECTION_AXIS_MODE, ORIENTATION_MODE in ALL_MODES:
    all_orientation_differences = []
    all_post_unit_orientations  = []
    all_pre_unit_orientations   = []

    for post_unit_idx, post_unit in enumerate(all_units):
        if (post_unit_idx % 200) == 0:
            print('Starting unit', post_unit_idx, '/',  len(all_units))

        # Unlike the short-range cell, only connected partners are kept here.
        all_pre_units     = []

        for pre_unit_idx, pre_unit in enumerate(all_units):
            is_in_range       = get_is_in_range(pre_unit, post_unit, CONNECTION_DISTANCE_MODE)
            is_connected, _   = get_is_connected(pre_unit, post_unit, weight_matrix, threshold)
            is_different_unit = pre_unit_idx != post_unit_idx
            is_selective      = is_orientation_or_direction_selective(pre_unit, post_unit, ORIENTATION_MODE)

            if is_in_range and is_different_unit and is_selective and is_connected:
                all_pre_units.append(pre_unit)

        if not len(all_pre_units):
            continue

        # Apply the post unit's translation, then rotation, and keep partners on the
        # requested axis (coaxial / orthogonal).
        translation_matrix           = get_translation_matrix(post_unit)
        rotation_matrix              = get_rotation_matrix(post_unit)
        transformation               = rotation_matrix@translation_matrix
        all_pre_units_transformed    = apply_matrix_transformation(all_pre_units, transformation)
        all_pre_units_filtered       = apply_axis_filter(all_pre_units_transformed, CONNECTION_AXIS_MODE)
        orientation_differences      = get_orientation_differences(post_unit, all_pre_units_filtered, ORIENTATION_MODE)

        all_orientation_differences += orientation_differences

        for u in all_pre_units_filtered:
            all_post_unit_orientations.append(post_unit['preferred_orientation'])
            all_pre_unit_orientations.append(u['preferred_orientation'])

    all_orientation_differences      = np.array(all_orientation_differences)
    all_post_unit_orientations       = np.array(all_post_unit_orientations)
    all_pre_unit_orientations        = np.array(all_pre_unit_orientations)

    axial_results['orientation_differences'].append(all_orientation_differences)
    axial_results['post_unit_orientations'].append(all_post_unit_orientations)
    axial_results['pre_unit_orientations'].append(all_pre_unit_orientations)

# Normaliser: connected pairs pooled over BOTH axes, mirroring the shared neural
# denominator above. (Note: this reuses the name `total_units`.)
total_units = np.concatenate(axial_results['orientation_differences']).shape[0]

bins        = [0, 22.5, 67.5, 90]
bin_centres = [0, 45, 90]
x_ticks     = np.arange(len(bin_centres))

# Index results explicitly and take the orientation mode per entry. This avoids
# relying on the loop-leaked ORIENTATION_MODE and on dict value order.
for mode_idx, (_, _, orientation_mode) in enumerate(ALL_MODES):
    orientation_differences = axial_results['orientation_differences'][mode_idx]
    post_unit_orientations  = axial_results['post_unit_orientations'][mode_idx]
    pre_unit_orientations   = axial_results['pre_unit_orientations'][mode_idx]
    save_string             = axial_results['save_string'][mode_idx]
    neural_data = (np.array(axial_results['neural_data_a'][mode_idx])
                   / np.array(axial_results['neural_data_b'][mode_idx]))

    connected_binned_data, _  = np.histogram(orientation_differences, bins=bins)
    true_fraction             = connected_binned_data/total_units
    null_dist                 = get_shuffled_binned_data_iacoruso(
        post_unit_orientations, pre_unit_orientations, bins, total_units
    )

    print(true_fraction)
    print(np.percentile(null_dist, 2.5, axis=1))
    print(np.percentile(null_dist, 97.5, axis=1))
    p_vals = [get_p_val(true_fraction[i], null_dist[i]) for i in range(len(connected_binned_data))]
    print(p_vals)

    fig = plt.figure()
    plt.bar(x_ticks-1/3, true_fraction, facecolor='black', width=1/3)
    plt.bar(x_ticks, neural_data, facecolor='gray', width=1/3)

    plt.xticks(x_ticks-1/6, bin_centres)
    plt.xlabel(f"{'Orientation' if orientation_mode == 'orientation' else 'Direction'} difference (°)")
    plt.ylabel('Presynaptic fraction')
    plt.ylim(0, 0.5)
    format_plot(fontsize=20)
    fig.set_size_inches(4,4)
    save_plot('connectivity', f'{save_string}')
    plt.show()


## Heatmap

In [None]:
def plot_heatmaps (m, nan_mask, cmap_str):
    """Plot smoothed heat maps side by side on a shared colour scale.

    Each map in ``m`` is processed in three steps. Non-finite values are replaced
    (NaN -> 0). The map is Gaussian-smoothed with sigma=2. Finally, the flat positions
    listed in ``nan_mask`` are set back to NaN so they render blank. All panels share
    one vmin/vmax, and a single colour bar is attached to the right of the last panel.
    The number of panels is ``len(m)``; no global state is read.

    Parameters
    ----------
    m : sequence of 2-D np.ndarray (or a 3-D array iterated along axis 0)
        Float heat maps, one panel each (presumably one per orientation-difference bin).
    nan_mask : array of int
        Indices into each *flattened* map to blank out after smoothing.
    cmap_str : str
        Matplotlib colormap name.
    """
    n_maps = len(m)
    # squeeze=False keeps `axs` 2-D even for a single panel, so zip() below works.
    fig, axs = plt.subplots(nrows=1, ncols=n_maps, dpi=100, figsize=[6, 4], squeeze=False)
    axs = axs[0]

    final_ims = []
    for h in m:
        # `nan` must be passed by keyword: nan_to_num's 2nd positional argument is `copy`.
        filt = ndimage.gaussian_filter(np.nan_to_num(h, nan=0.0), 2).reshape(-1)
        filt[nan_mask] = np.nan
        final_ims.append(filt.reshape(h.shape))

    vmax = np.nanmax([np.nanmax(h) for h in final_ims])
    vmin = np.nanmin([np.nanmin(h) for h in final_ims])

    for h_idx, (h, ax) in enumerate(zip(final_ims, axs)):
        # NOTE(review): interpolation=None means "use the rcParams default", not 'none'.
        im = ax.imshow(h, origin='lower', cmap=cmap_str, vmax=vmax, vmin=vmin, interpolation=None)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.spines['top'].set_visible(True)
        if h_idx == n_maps - 1:
            # The top tick is nudged down slightly so its label stays inside the bar.
            ticks = [round(vmin, 2), round((vmin+vmax)/2, 2), round(vmax-0.01, 2)]
            cbar = plt.colorbar(im, cax=fig.add_axes([1, 0.275, 0.03, 0.45]), ticks=ticks)
            cbar.ax.tick_params(labelsize=20)
    plt.tight_layout()

In [None]:
# Spatial heat maps of long-range connection probability, one per orientation-difference
# bin, with presynaptic positions transformed by each post unit's translation + rotation.

total_units   = 800
excit_units_n = int(total_units*0.8)
weight_matrix = model.rnn.weight_hh_l0[:excit_units_n, :excit_units_n].detach().numpy()

# NOTE(review): percentile of the SIGNED E->E weights (the short-range cell uses |w|).
threshold = np.percentile(weight_matrix.reshape(-1), 95)

# Same unit selection as the other cells: good Gabor fit and both gabor_params[4:6] > 0.5.
all_units     = []
for unit in vphys.data[0]:
    if unit['hidden_unit_index'] < excit_units_n:
        sx, sy = unit['gabor_params'][4:6]

        if unit['gabor_r'] > 0.7 and sx>0.5 and sy>0.5:
            all_units.append(unit)


ALL_MODES = [
    ['long_range', 'all', 'orientation'],
]

for CONNECTION_DISTANCE_MODE, CONNECTION_AXIS_MODE, ORIENTATION_MODE in ALL_MODES:
    all_orientation_differences = []
    all_weights = []
    all_coords = []

    for post_unit_idx, post_unit in enumerate(all_units):
        if (post_unit_idx % 200) == 0:
            print('Starting unit', post_unit_idx, '/',  len(all_units))

        # Candidates are gathered regardless of connectivity. Weights are looked up
        # only for the axis-filtered partners below.
        all_pre_units     = []

        for pre_unit_idx, pre_unit in enumerate(all_units):
            is_in_range       = get_is_in_range(pre_unit, post_unit, CONNECTION_DISTANCE_MODE)
            is_different_unit = pre_unit_idx != post_unit_idx
            is_selective      = is_orientation_or_direction_selective(pre_unit, post_unit, ORIENTATION_MODE)

            if is_in_range and is_different_unit and is_selective:
                all_pre_units.append(pre_unit)

        if not len(all_pre_units):
            continue

        translation_matrix           = get_translation_matrix(post_unit)
        rotation_matrix              = get_rotation_matrix(post_unit)
        transformation               = rotation_matrix@translation_matrix
        all_pre_units_transformed    = apply_matrix_transformation(all_pre_units, transformation)
        all_pre_units_filtered       = apply_axis_filter(all_pre_units_transformed, CONNECTION_AXIS_MODE)
        orientation_differences      = get_orientation_differences(post_unit, all_pre_units_filtered, ORIENTATION_MODE)

        all_orientation_differences += orientation_differences
        # Signed weights here (no abs(), unlike the short-range cell).
        all_weights                 += [get_is_connected(u, post_unit, weight_matrix, threshold)[1] for u in all_pre_units_filtered]
        all_coords                  += [{'x': u['gabor_x'], 'y': u['gabor_y']} for u in all_pre_units_filtered]

    all_orientation_differences      = np.array(all_orientation_differences)
    all_weights                      = np.array(all_weights)
    all_coords                       = np.array(all_coords)

    # NOTE(review): presumably get_heat_maps counts pairs whose weight exceeds the given
    # threshold. With signed weights, the threshold-0 "all" map may then exclude
    # negative-weight pairs from the denominator — confirm.
    bins               = [0, 22.5, 67.5, 90]
    true_heat_maps     = get_heat_maps(all_orientation_differences, all_coords, all_weights, bins, threshold)
    all_heat_maps      = get_heat_maps(all_orientation_differences, all_coords, all_weights, bins, 0)

    # Positions with no tested pairs give 0/0 = NaN. Those positions (taken from the
    # first map) are blanked in every panel after smoothing.
    normalized_heat_maps = true_heat_maps/all_heat_maps

    nan_mask = np.argwhere(np.isnan(normalized_heat_maps[0].reshape(-1)))[:, 0]

    plot_heatmaps(normalized_heat_maps, nan_mask, cmap_str='viridis')
    save_plot('connectivity', 'COMPLEX_heatmap')
    plt.show()