In [None]:
import sys
import torch
import math
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pingouin as pg

from connectivity_functions import *
from plotting_functions import *

# Global figure style: Arial for text, and mathtext drawn in the regular (upright) font.
plt.rcParams["font.family"] = "Arial"
plt.rcParams.update({'mathtext.default': 'regular'})

# Make the model code under ../models (virtual_physiology and the `models` package) importable.
sys.path.append("../models")

from virtual_physiology.VirtualNetworkPhysiology import VirtualPhysiology

# All network classes are imported through the same `models.<module>` path.
# NOTE(review): the temporal-prediction import used to read `models.models.<module>`,
# which is inconsistent with its three siblings and cannot resolve against the same
# `models` package. Confirm the module sits alongside the other network modules.
from models.network_hierarchical_recurrent_temporal_prediction import NetworkHierarchicalRecurrentTemporalPrediction
from models.network_hierarchical_recurrent_denoise import NetworkHierarchicalRecurrentDenoise
from models.network_hierarchical_recurrent import NetworkHierarchicalRecurrent
from models.network_hierarchical_recurrent_autoencoder import NetworkHierarchicalRecurrentAutoencoder


In [None]:
# Display name for each model objective; the embedded '\n' keeps x-tick labels narrow.
# The key order here fixes the order in which models are loaded and plotted below.
label = dict(
    temporal_prediction='Temporal\nprediction',
    inpainting='Inpainting',
    denoise='Denoise',
    autoencoder='Sparse\nautoencoder',
)

# Matplotlib colour used for each model's bar.
color = dict(
    temporal_prediction='tab:red',
    inpainting='tab:purple',
    denoise='tab:orange',
    autoencoder='tab:pink',
)

# model key -> [network class, saved-model path, virtual-physiology data path].
# Both paths are empty-string placeholders here; set them before running the analysis.
models = dict(
    temporal_prediction=[NetworkHierarchicalRecurrentTemporalPrediction, '', ''],
    inpainting=[NetworkHierarchicalRecurrent, '', ''],
    denoise=[NetworkHierarchicalRecurrentDenoise, '', ''],
    autoencoder=[NetworkHierarchicalRecurrentAutoencoder, '', ''],
)

# Unit statistics

In [None]:
# Units are kept only if their Gabor fit passes these thresholds.
# NOTE(review): 'gabor_r' is presumably the fit correlation, and gabor_params[4:6] are
# presumably the envelope widths (sx, sy). Confirm against VirtualPhysiology.
GABOR_R_MIN = 0.7
GABOR_SIGMA_MIN = 0.5


def _selectivity_proportions(osi_values, dsi_values, osi_thresh, dsi_thresh):
    """Split paired OSI/DSI values into (orientation-selective, direction-selective, other) proportions.

    A unit is direction-selective if DSI > dsi_thresh. It is orientation-selective if
    OSI > osi_thresh and DSI <= dsi_thresh. Everything else counts as "other", so the
    three proportions sum to 1.

    Raises:
        ValueError: if the inputs have different lengths or are empty.
    """
    osi_values = list(osi_values)
    dsi_values = list(dsi_values)
    if len(osi_values) != len(dsi_values):
        raise ValueError(f'OSI/DSI length mismatch: {len(osi_values)} vs {len(dsi_values)}')
    n_units = len(osi_values)
    if n_units == 0:
        raise ValueError('no units to classify')

    n_os = sum(1 for osi, dsi in zip(osi_values, dsi_values) if osi > osi_thresh and dsi <= dsi_thresh)
    n_ds = sum(1 for dsi in dsi_values if dsi > dsi_thresh)
    os_prop = n_os / n_units
    ds_prop = n_ds / n_units
    return os_prop, ds_prop, 1 - os_prop - ds_prop


OS_prop_arr = []
DS_prop_arr = []
other_prop_arr = []

labels = []
colors = []  # one entry per model only; Mouse V1 (added below) has no colour

# Selectivity thresholds are taken from the loaded vphys objects. If they differed
# between models, the last model's values are used for the V1 data.
osi_thresh = dsi_thresh = None

for model_key, (Network, model_path, vphys_path) in models.items():
    print('Starting', model_key)

    model, hyperparameters, _ = Network.load(model_path, 'cpu')
    print('\tLoaded model')

    vphys = VirtualPhysiology.load(
        data_path=vphys_path,
        model=model,
        hyperparameters=hyperparameters,
        frame_shape=(36, 36),
        hidden_units=[2592],
        device='cpu'
    )
    print('\tLoaded vphys')

    osi_thresh, dsi_thresh = vphys.osi_thresh, vphys.dsi_thresh

    all_units = []
    for unit in vphys.data[0]:
        sx, sy = unit['gabor_params'][4:6]
        if unit['gabor_r'] > GABOR_R_MIN and sx > GABOR_SIGMA_MIN and sy > GABOR_SIGMA_MIN:
            all_units.append(unit)
    if not all_units:
        raise ValueError(f'{model_key}: no units pass the Gabor-fit filter')

    OS_prop, DS_prop, other_prop = _selectivity_proportions(
        [u['OSI'] for u in all_units],
        [u['DSI'] for u in all_units],
        osi_thresh,
        dsi_thresh,
    )

    OS_prop_arr.append(OS_prop)
    DS_prop_arr.append(DS_prop)
    other_prop_arr.append(other_prop)

    labels.append(label[model_key])
    colors.append(color[model_key])

if osi_thresh is None:
    raise ValueError('`models` is empty: no selectivity thresholds available for the V1 data')

# allow_pickle is needed to read the stored dict; only use it on this trusted local file.
v1_data = np.load('./v1_data/drifting_grating_tuning.npy', allow_pickle=True).item()
OS_prop, DS_prop, other_prop = _selectivity_proportions(
    v1_data['OSI'], v1_data['DSI'], osi_thresh, dsi_thresh
)

# Mouse V1 is shown as the first bar, ahead of the models.
OS_prop_arr.insert(0, OS_prop)
DS_prop_arr.insert(0, DS_prop)
other_prop_arr.insert(0, other_prop)
labels.insert(0, 'Mouse V1')

In [None]:
width = 0.5

# Proportion of units in each class, per bar. The dict order is the stacking order
# (bottom to top), paired below with black / grey / white fills.
bar_counts = {
    'Orientation-selective': OS_prop_arr,
    'Direction-selective': DS_prop_arr,
    'Non-selective': other_prop_arr
}

fig, ax = plt.subplots()
# Running baseline, one entry per bar. It is sized from the data: the previous
# hard-coded np.zeros(6) did not match the 5 bars (Mouse V1 + 4 models) and
# failed with a shape mismatch.
bottom = np.zeros(len(labels))

for face_color, (category, proportions) in zip(['black', 'tab:gray', 'white'], bar_counts.items()):
    ax.bar(labels, proportions, width, label=category, bottom=bottom,
           facecolor=face_color, edgecolor='black')
    bottom += np.asarray(proportions)

plt.ylabel('Proportion units')
format_plot(fontsize=18)
# NOTE(review): assumes format_plot creates the legend; get_legend() returns None otherwise.
plt.gca().get_legend().set_bbox_to_anchor((1.05, 1.05))
fig.set_size_inches(8, 4)
save_plot(4, 'unit_properties')
plt.show()


# Connectivity score

In [None]:
scores_rw = []
scores_mn = []
labels    = []
colors    = []

connectivity_data = get_connectivity_data()

for k in label:
    # NOTE(review): index 0 is plotted as the mean correlation with V1, and index 2
    # is stored in scores_rw but not plotted here. The layout of connectivity_data[k]
    # is defined in connectivity_functions — confirm.
    scores_mn.append(connectivity_data[k][0])
    scores_rw.append(connectivity_data[k][2])
    labels.append(label[k])
    colors.append(color[k])

# One bar per model. This previously used the undefined name `scores` (NameError).
x = np.arange(len(scores_mn))

fig = plt.figure()
plt.bar(x, scores_mn, color=colors)
plt.xticks(x, labels)
plt.ylabel('Mean correlation with V1')
format_plot(fontsize=18)
fig.set_size_inches(6, 4)
save_plot(4, 'connectivity_performance')
plt.show()