In [None]:
import pickle, sys
import seaborn as sns
import torch
import numpy as np
import scipy.stats as stats
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Arial"

# Import custom modules
sys.path.append("../models")
from models.network_hierarchical_recurrent_temporal_prediction import NetworkHierarchicalRecurrentTemporalPrediction as Network

from virtual_physiology.VirtualNetworkPhysiology import VirtualPhysiology
from plotting_functions import *
from connectivity_functions import *

# Seed numpy's global RNG so random draws further down are repeatable
np.random.seed(0)

# Add model and vphys paths here
MODEL_PATH = ''
VPHYS_PATH = ''

# Units with hidden_unit_index below EXCITATORY_UNITS (the first 90%) are
# treated as excitatory, the remainder as inhibitory.
TOTAL_UNITS      = 2*36*36
EXCITATORY_UNITS = int(TOTAL_UNITS*0.9)
INHIBITORY_UNITS = TOTAL_UNITS-EXCITATORY_UNITS

# Load network checkpoint.
# The class is imported above under the alias `Network`; the un-aliased class
# name is not bound in this namespace, so the alias must be used here.
model, hyperparameters, _ = Network.load(
    model_path=MODEL_PATH, device='cpu', plot_loss_history=True
)

# Load the VirtualPhysiology data for this model
vphys = VirtualPhysiology.load(
    data_path=VPHYS_PATH,
    model=model,
    hyperparameters=hyperparameters,
    frame_shape=(36,36),
    hidden_units=[TOTAL_UNITS],
    device='cpu'
)

# Split the per-unit records by hidden-unit index
excitatory_units = [u for u in vphys.data[0] if u['hidden_unit_index'] <  EXCITATORY_UNITS]
inhibitory_units = [u for u in vphys.data[0] if u['hidden_unit_index'] >= EXCITATORY_UNITS]

# Exemplar unit RFs

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=3, dpi=100, figsize=[6, 4], sharex=True, sharey=True)

# Positions (within excitatory_units) of the hand-picked candidate units
chosen_units = [60, 129, 130, 249, 261, 328, 363, 475, 510, 539, 574, 932, 960, 1010, 1011, 1023, 1115]

# Collect the candidates in their original order, then draw six of them
# at random (without replacement) to fill the 2x3 grid.
exemplar_units = [
    candidate
    for position, candidate in enumerate(excitatory_units)
    if position in chosen_units
]
exemplar_units = np.random.choice(exemplar_units, 6, replace=False)

for ax, unit in zip(axs.ravel(), exemplar_units):
    rf = unit['response_weighted_average'].reshape(36, 36)

    # Row/column of the pixel with the largest squared value
    max_r, max_c = np.unravel_index(np.argmax((rf**2).reshape(-1)), (36, 36))

    # Crop a 26x26 window: start 13 pixels before the peak (clamped at 0),
    # or at pixel 10 when the peak is within 15 pixels of the far edge.
    row_start = max(max_r-13, 0) if (36-max_r) > 15 else 10
    col_start = max(max_c-13, 0) if (36-max_c) > 15 else 10
    rf = rf[row_start:row_start+26, col_start:col_start+26]

    # Symmetric colour limits so that zero maps to the middle of 'bwr'
    vmax = np.max(np.abs(rf))

    ax.imshow(rf, vmax=vmax, vmin=-vmax, cmap='bwr')
    ax.set_xticks([])
    ax.set_yticks([])

save_plot(1, 'model_example_rfs_7')
plt.show()

# OSI and DSI distribution

In [None]:
# Load the V1 drifting-grating tuning data. The .npy file stores a pickled Python
# object (hence allow_pickle=True) and .item() unwraps it from the 0-d array;
# later cells index the result with the keys 'OSI', 'DSI', 'pref_ori' and 'response'.
# NOTE(review): allow_pickle=True can execute arbitrary code while loading —
# only use this with a trusted data file.
v1_data = np.load('./v1_data/drifting_grating_tuning.npy', allow_pickle=True).item()

In [None]:
# Polar plot of orientation vs temporal frequency

temporal_frequency_arr = [1, 2, 4, 8, 15]
orientation_arr        = [0, 45, 90, 135, 180, 225, 270, 315]

# Hand-picked V1 cells to draw
v1_example_cells = (66, 14, 33, 159)

v1_records = zip(v1_data['OSI'], v1_data['DSI'], v1_data['pref_ori'], v1_data['response'])

for idx, (OSI, DSI, pref_ori, responses) in enumerate(v1_records):
    if idx in v1_example_cells:
        print(idx)

        # Repeat the first orientation row at the end so the mesh closes the circle
        z     = np.concatenate((responses, responses[:1]),axis=0)
        rad   = temporal_frequency_arr
        azm   = np.radians([*orientation_arr, orientation_arr[0]])
        r, th = np.meshgrid(rad, azm)

        fig = plt.figure(dpi=100, figsize=[2,2])
        ax = plt.subplot(projection="polar")
        im = plt.pcolormesh(th, r, z, shading='gouraud')

        font_size = 20

        # No radial ticks; angle measured clockwise from the top of the plot
        ax.set_yticks([])
        ax.tick_params(axis="x", labelsize=font_size)
        plt.setp(ax.get_yticklabels(), color="w", size=font_size)
        ax.set_rlabel_position(0)
        ax.set_theta_direction(-1)
        ax.set_theta_zero_location('N')
        ax.spines['polar'].set_visible(False)
        ax.grid(axis='x')

        save_plot(1, f'polar_plot_V1_{idx}')
        plt.show()


In [None]:
orientation_arr        = np.array(vphys.orientations)
temporal_frequency_arr = np.array(vphys.temporal_frequencies)
spatial_frequency_arr  = np.array(vphys.spatial_frequencies)

# Hand-picked model units to draw (positions in vphys.data[0])
model_example_units = {101, 495, 725, 689}

for unit_idx, unit in enumerate(vphys.data[0]):
    if unit_idx not in model_example_units:
        continue

    # Only orientation-selective units are plotted
    if unit['OSI'] < vphys.osi_thresh:
        continue

    # NOTE: an earlier check here,
    #   `not preferred_orientation > 40 and not preferred_orientation < 60`,
    # could never be true for a real-valued orientation (it asks for a value
    # that is both <= 40 and >= 60), so it filtered nothing. It is removed
    # rather than "corrected" so the hand-picked units above are still drawn.

    print(unit_idx)

    # Responses at the unit's preferred spatial frequency
    # (assumes the remaining axes are orientation x temporal frequency,
    #  matching the mesh built below — TODO confirm)
    sf_idx = np.where(spatial_frequency_arr == unit["preferred_sf"])[0][0]
    z     = unit["mean_grating_responses"][sf_idx, :, :]
    # Repeat the first orientation row at the end so the mesh closes the circle
    z     = np.concatenate((z, z[:1]),axis=0)
    # Radial axis uses the temporal-frequency *index*, not the frequency value
    rad   = np.arange(len(temporal_frequency_arr))
    azm   = np.radians([*orientation_arr, orientation_arr[0]])
    r, th = np.meshgrid(rad, azm)

    fig = plt.figure(dpi=100, figsize=[2,2])
    ax = plt.subplot(projection="polar")
    im = plt.pcolormesh(th, r, z, shading='gouraud')

    font_size = 20

    # No radial ticks; angle measured clockwise from the top of the plot
    ax.set_yticks([])
    ax.tick_params(axis="x", labelsize=font_size)
    plt.setp(ax.get_yticklabels(), color="w", size=font_size)
    ax.set_rlabel_position(0)
    ax.set_theta_direction(-1)
    ax.set_theta_zero_location('N')
    ax.spines['polar'].set_visible(False)
    plt.grid(axis='x')

    save_plot(1, f'polar_plot_model_{unit_idx}')
    plt.show()

In [None]:
# Selectivity indices of the model's excitatory units ...
OSI_dist = np.array([unit['OSI'] for unit in excitatory_units])
DSI_dist = np.array([unit['DSI'] for unit in excitatory_units])

# ... and of the recorded V1 cells
OSI_dist_V1 = np.array(v1_data['OSI'])
DSI_dist_V1 = np.array(v1_data['DSI'])

# Weight every sample by 1/N so each histogram shows proportions, not counts
osi_weights = [
    np.ones_like(OSI_dist)/len(OSI_dist),
    np.ones_like(OSI_dist_V1)/len(OSI_dist_V1),
]
dsi_weights = [
    np.ones_like(DSI_dist)/len(DSI_dist),
    np.ones_like(DSI_dist_V1)/len(DSI_dist_V1),
]

# --- OSI ---
fig = plt.figure()
plt.hist([OSI_dist, OSI_dist_V1], weights=osi_weights, color=['black', 'gray'])
plt.ylabel('Proportion units')
plt.xlabel('OSI')
format_plot(plt.gca(), fontsize=20)
save_plot(1, 'OSI')
plt.show()

# --- DSI ---
fig = plt.figure()
b = plt.hist(
    [DSI_dist, DSI_dist_V1],
    weights=dsi_weights,
    label=['Model', 'V1'],
    color=['black', 'gray'],
)
plt.ylabel('Proportion units')
plt.xlabel('DSI')
format_plot(plt.gca(), fontsize=20)
save_plot(1, 'DSI')
plt.show()


# Preferred orientation and direction

In [None]:
def bin_orientations (pref_ori):
    """Count preferred orientations in four 45-degree bins centred on 0, 45, 90, 135.

    Bins are half-open: [22.5, 67.5) -> 45, [67.5, 112.5) -> 90,
    [112.5, 157.5) -> 135, and everything below 22.5 or at/above 157.5 -> 0
    (the 0-degree bin wraps around 180). Callers are expected to pass values
    already folded into [0, 180).

    Non-finite values (NaN/inf) are skipped instead of being counted in the
    135-degree bin.

    Parameters
    ----------
    pref_ori : iterable of float
        Preferred orientations in degrees.

    Returns
    -------
    dict
        Mapping {0, 45, 90, 135} -> number of orientations in that bin.
    """
    pref_ori_binned = {
        0:   0,
        45:  0,
        90:  0,
        135: 0
    }

    for ori in pref_ori:
        # An undefined orientation belongs to no bin
        if not np.isfinite(ori):
            continue

        # `>=` keeps the wrap-around edge half-open like every other edge:
        # 157.5 is the lower edge of the 0-degree bin, not part of the 135 bin.
        if ori < 22.5 or ori >= 157.5:
            pref_ori_binned[0] += 1
        elif ori < 67.5:
            pref_ori_binned[45] += 1
        elif ori < 112.5:
            pref_ori_binned[90] += 1
        else:
            pref_ori_binned[135] += 1

    return pref_ori_binned

def norm_bin_values (pref_ori_binned):
    """Return the bin counts of a {bin: count} dict as fractions of their total.

    The output follows the dict's iteration order.
    """
    counts = np.asarray([count for count in pref_ori_binned.values()])
    total = counts.sum()
    return counts/total

# Preferred orientations (folded with % 180) of units / cells whose OSI
# exceeds the model's OSI threshold
pref_ori = [
    unit['preferred_orientation']%180
    for unit in vphys.data[0]
    if unit['OSI']>vphys.osi_thresh
]
pref_ori_V1 = [
    ori%180
    for ori, osi in zip(v1_data['pref_ori'], v1_data['OSI'])
    if osi>vphys.osi_thresh
]

pref_ori_binned    = bin_orientations (pref_ori)
pref_ori_V1_binned = bin_orientations (pref_ori_V1)

# Grouped bars: model (black) to the left, V1 (gray) to the right of each tick
fig = plt.figure()
x = np.arange(len(pref_ori_binned))
bar_width = 1/3

plt.bar(x-bar_width/2, norm_bin_values(pref_ori_binned), width=bar_width, facecolor='black')
plt.bar(x+bar_width/2, norm_bin_values(pref_ori_V1_binned), width=bar_width, facecolor='gray')

plt.xticks(x, list(pref_ori_binned.keys()))
plt.xlabel('Preferred orientation (degrees)')
plt.ylabel('Proportion units')
format_plot(plt.gca(), fontsize=20)
save_plot(1, 'preferred_orientation')
plt.show()


In [None]:
def bin_directions (pref_dir):
    """Count preferred directions in eight 45-degree bins centred on 0, 45, ..., 315.

    Bins are half-open: a direction d falls in the bin centred on c when
    c - 22.5 <= d < c + 22.5. The 0-degree bin wraps around 360, so it takes
    everything below 22.5 or at/above 337.5. Callers are expected to pass
    values in [0, 360).

    Non-finite values (NaN/inf) are skipped instead of being counted in the
    315-degree bin.

    Parameters
    ----------
    pref_dir : iterable of float
        Preferred directions in degrees.

    Returns
    -------
    dict
        Mapping {0, 45, ..., 315} -> number of directions in that bin.
    """
    pref_dir_binned = {
        0:   0,
        45:  0,
        90:  0,
        135: 0,
        180: 0,
        225: 0,
        270: 0,
        315: 0
    }

    for dir_ in pref_dir:
        # An undefined direction belongs to no bin
        if not np.isfinite(dir_):
            continue

        # `>=` keeps the wrap-around edge half-open like every other edge:
        # 337.5 is the lower edge of the 0-degree bin, not part of the 315 bin.
        if dir_ < 22.5 or dir_ >= 337.5:
            pref_dir_binned[0] += 1
            continue

        # Here 22.5 <= dir_ < 337.5, so the first centre whose upper edge
        # lies above dir_ is the right bin (315 always matches last).
        for centre in (45, 90, 135, 180, 225, 270, 315):
            if dir_ < centre + 22.5:
                pref_dir_binned[centre] += 1
                break

    return pref_dir_binned
        
# Preferred directions of units / cells whose DSI exceeds the model's DSI threshold
pref_dir = [
    unit['preferred_orientation']
    for unit in vphys.data[0]
    if unit['DSI']>vphys.dsi_thresh
]
pref_dir_V1 = [
    direction
    for direction, dsi in zip(v1_data['pref_ori'], v1_data['DSI'])
    if dsi>vphys.dsi_thresh
]

pref_dir_binned    = bin_directions (pref_dir)
pref_dir_V1_binned = bin_directions (pref_dir_V1)

# Grouped bars: model (black) to the left, V1 (gray) to the right of each tick
fig = plt.figure()
x = np.arange(len(pref_dir_binned))
bar_width = 1/3

plt.bar(x-bar_width/2, norm_bin_values(pref_dir_binned), width=bar_width, facecolor='black')
plt.bar(x+bar_width/2, norm_bin_values(pref_dir_V1_binned), width=bar_width, facecolor='gray')

plt.xticks(x, list(pref_dir_binned.keys()))
plt.xlabel('Preferred direction (degrees)')
plt.ylabel('Proportion units')
format_plot(plt.gca(), fontsize=20)
save_plot(1, 'preferred_direction')
plt.show()



# Modulation distribution

In [None]:
def fit_sine (x, y, verbose=False):
    """Fit ``a*sin(b*x + c) + d`` to ``y`` with a five-stage random search.

    Stage 0 draws 10000 parameter sets from broad uniform ranges. Each of the
    four following stages draws 1000 sets by perturbing the previous stage's
    best set, with a perturbation range that shrinks from stage to stage.
    Every stage keeps the candidate with the lowest mean squared error.
    Draws come from numpy's global RNG.

    Parameters
    ----------
    x : array-like
        Sample positions.
    y : array-like
        Values to fit, one per entry of ``x``.
    verbose : bool
        Print the lowest loss found in each stage.

    Returns
    -------
    final_params : list
        ``[a, b, c, d]`` of the last stage's best candidate.
    final_y_est : array
        The fitted curve evaluated at ``x``.
    final_rsq : float
        R-squared of the fitted curve against ``y``.
    final_loss : float
        Mean squared error of the fitted curve against ``y``.
    """
    def sine_model(t, amplitude, frequency, phase, offset):
        return amplitude*np.sin(frequency*t + phase) + offset

    def mean_squared_error(target, estimate):
        return np.sum((target-estimate)**2)/ len(target)

    # R-squared as in https://stackoverflow.com/a/37899817
    def r_squared(target, estimate):
        ss_res = np.sum((target - estimate)**2)
        ss_tot = np.sum((target-np.mean(target))**2)
        return 1 - (ss_res / ss_tot)

    stage_winners = []

    for iteration in range(5):
        first_stage = (iteration == 0)
        scale = 1-0.2*iteration
        n_random_guesses = 10000 if first_stage else 1000

        candidates = []
        losses = []

        for _ in range(n_random_guesses):
            if first_stage:
                # Broad initial ranges derived from the data
                candidate = [
                    np.random.uniform(low=np.mean(y)-2.5, high=np.mean(y)+2.5),
                    np.random.uniform(low=0, high=10),
                    np.random.uniform(low=0, high=len(y)),
                    np.random.uniform(low=np.min(y)-2.5, high=np.max(y)+2.5)
                ]
            else:
                # Perturb the previous stage's winner
                centre = stage_winners[-1]
                jitter = [
                    np.random.uniform(low=-2*scale, high=2*scale),
                    np.random.uniform(low=-1*scale, high=1*scale),
                    np.random.uniform(low=-2*scale, high=2*scale),
                    np.random.uniform(low=-2*scale, high=2*scale)
                ]
                candidate = [delta+centre[k] for k, delta in enumerate(jitter)]

            # Score the candidate curve against the data
            losses.append(mean_squared_error(y, sine_model(x, *candidate)))
            candidates.append(candidate)

        # Keep the candidate with the lowest loss for this stage
        stage_winners.append(candidates[np.argmin(losses)])

        if verbose:
            print('Iteration {}, min loss = {}'.format(iteration, min(losses)))

    final_params = stage_winners[-1]
    final_y_est = sine_model(x, *final_params)

    return (
        final_params,
        final_y_est,
        r_squared(y, final_y_est),
        mean_squared_error(y, final_y_est),
    )


# Takes list of response as well as a start and
# end offset for where curve fitting should occur
# Returns modulation ratio, estimated curve and RSQ of curve fit
def get_modulation_ratio (activity, start_offset, end_offset, verbose=False):
    """Estimate a unit's modulation ratio from a sine fit to its response.

    A sine is fitted (see ``fit_sine``) to ``activity[start_offset:end_offset]``.
    The ratio is ``|fitted amplitude| / mean(activity[vphys.warmup:])``.

    Parameters
    ----------
    activity : array-like
        Response trace of one unit.
    start_offset, end_offset : int
        Slice bounds of the part of ``activity`` that is fitted.
    verbose : bool
        Forwarded to ``fit_sine``.

    Returns
    -------
    tuple
        ``(mod_ratio, y_est, rsq, loss, params)``. ``mod_ratio`` is ``False``
        when the fit is rejected (Pearson correlation between fit and data
        not above 0.9) or the mean activity is zero.
    """
    def fit_is_good(rsq, loss):
        return loss < 0.05 or rsq > 0.5

    x = np.arange(start_offset, end_offset)
    y = activity[start_offset:end_offset]

    final_params, final_y_est, final_rsq, final_loss = fit_sine (x, y, verbose=verbose)

    # Try one more time if it fails. (Previously the retry ran when the fit
    # had *passed* this criterion, which is the opposite of the intent.)
    if not fit_is_good(final_rsq, final_loss):
        retry = fit_sine (x, y, verbose=verbose)
        # Only adopt the retry if it is actually the better fit
        if retry[3] < final_loss:
            final_params, final_y_est, final_rsq, final_loss = retry

    # Average unit activity
    # NOTE(review): this reads the global vphys.warmup instead of start_offset
    # and runs to the end of `activity`, not to end_offset — confirm intended.
    f0 = np.mean(activity[vphys.warmup:])
    # Absolute of the amplitude of the fitted sine
    f1 = (abs(final_params[0]))

    # Reject f values for those units with poor sine fits.
    # f0 is checked first so the ratio is never computed with a zero denominator.
    if f0 != 0 and stats.pearsonr(final_y_est, y)[0] > 0.9:
        mod_ratio = f1/f0
        return mod_ratio, final_y_est, final_rsq, final_loss, final_params
    else:
        return False, final_y_est, final_rsq, final_loss, final_params


In [None]:
# Accepted modulation ratios: all units, excitatory only, inhibitory only
MI_all = []
MI_exc = []
MI_inh = []

# Keys under which the five outputs of get_modulation_ratio are stored, in order
modulation_keys = (
    "modulation_ratio",
    "modulation_ratio_y",
    "modulation_ratio_rsq",
    "modulation_ratio_loss",
    "modulation_ratio_params",
)

for unit_i, unit_data in enumerate(vphys.data[0]):
    if unit_i % 10 == 0:
        print('Starting unit', unit_i)

    fit_outputs = get_modulation_ratio(unit_data['optimum_grating_response'], vphys.warmup, vphys.t_steps)
    for key, value in zip(modulation_keys, fit_outputs):
        unit_data[key] = value

    # get_modulation_ratio returns False as the ratio for rejected units
    if unit_data["modulation_ratio"] != False:
        is_excitatory = unit_data['hidden_unit_index'] < EXCITATORY_UNITS
        target = MI_exc if is_excitatory else MI_inh
        target.append(unit_data["modulation_ratio"])

        MI_all.append(unit_data["modulation_ratio"])

In [None]:
# Twelve equal-width modulation-ratio bins spanning [0, 2]
x           = np.arange(12)
bins        = np.linspace(0, 2, 12+1)
bin_centres = np.round(bins[:-1] + np.diff(bins)[0]/2, 1)

# Mouse V1 bin heights, source:
# https://www.jneurosci.org/content/36/48/12144.abstract
v1_exc_count = np.array([
    0,
    0.8174603174603176,
    1.0158730158730158,
    1.0079365079365081,
    0.4761904761904763,
    0.35714285714285726,
    0.5476190476190479,
    0.7063492063492064,
    0.6269841269841271,
    0.3968253968253969,
    0.1984126984126987,
    0.2380952380952383
])
v1_exc_count = v1_exc_count/v1_exc_count.sum() # * 176

v1_inh_count = np.array([
    0,
    0.20901639344262302,
    1.0122950819672132,
    0.7131147540983607,
    0.9098360655737706,
    0.6106557377049181,
    0.21311475409836064,
    0.31147540983606564,
    0.40983606557377056,
    0.11065573770491807,
    0.11065573770491807,
    0.20901639344262302
])
v1_inh_count = v1_inh_count/v1_inh_count.sum() # * 51

# Combined V1 distribution: the two normalised distributions are summed and
# renormalised, i.e. weighted equally.
# NOTE(review): the "* 176" / "* 51" remarks above look like cell counts; if
# so, a pooled distribution would weight by those counts — confirm intent.
v1_all_count = (v1_exc_count+v1_inh_count)/(v1_exc_count+v1_inh_count).sum()

# Model distribution over the same bins, as proportions
model_all_count, _ = np.histogram(MI_all, bins)
model_all_count = model_all_count/model_all_count.sum()

fig = plt.figure(dpi=100)
plt.bar(x, v1_all_count, width=1/3, label='Mouse V1', facecolor='black')
plt.bar(x-1/3, model_all_count, width=1/3, label='Model', facecolor='gray')
# Only the first and last bars are labelled, with the ends of the ratio range
plt.xticks([x[0], x[-1]], [0, 2])
plt.xlabel('Modulation ratio')
plt.ylabel('Proportion of units')
# Pass the axes explicitly, as every other format_plot call in this notebook does
format_plot(plt.gca(), fontsize=20)
save_plot(1, 'modulation_ratio_all_OS')
plt.show()

# Exemplar tuning curves

In [None]:
def plot_orientation (ax, unit_data):
    """Draw a unit's orientation tuning curve as a closed line on polar axes ``ax``.

    Angles come from ``vphys.orientations`` (converted from degrees to radians)
    and the curve from ``vphys.get_orientation_tuning_curve(unit_data)``.
    """
    angles = np.deg2rad(vphys.orientations).tolist()
    tuning = vphys.get_orientation_tuning_curve(unit_data).tolist()

    # Repeat the first sample at the end so the line closes the loop
    angles = angles + [angles[0]]
    tuning = tuning + [tuning[0]]

    ax.plot(angles, tuning, c='black')

    # Zero at the top, angles increasing clockwise, no radial ticks
    ax.set_theta_zero_location("N")
    ax.set_theta_direction(-1)
    ax.tick_params(labelsize=20)
    ax.set_rticks([])

# Position in excitatory_units from which the direction-selective example is
# searched. (This used to be a `for offset in range(100)` loop that skipped
# every value except 72, so the figure was only ever drawn for this offset.)
offset = 72

fig, axs = plt.subplots(nrows=1, ncols=2, dpi=100, subplot_kw={'projection': 'polar'}, figsize=[6, 4])

# Left panel: first unit from position 30 onwards that is orientation-selective,
# not direction-selective, and prefers 15 degrees
for unit_data in excitatory_units[30:]:
    if (unit_data['OSI']>vphys.osi_thresh and unit_data['DSI']<vphys.dsi_thresh and unit_data['preferred_orientation']==15):
        plot_orientation(axs[0], unit_data)
        break

# Right panel: first unit from `offset` onwards that is direction-selective
# and prefers 225 degrees
for unit_data in excitatory_units[offset:]:
    if (unit_data['DSI'] > vphys.dsi_thresh) and (unit_data['preferred_orientation']==225):
        plot_orientation(axs[1], unit_data)
        break

plt.tight_layout()
save_plot(1, 'example_tuning_curve')
plt.show()