In [None]:
import sys
import numpy as np
import torch
import pandas as pd
import scipy.stats as stats
import pingouin as pg
from scipy.optimize import curve_fit
from scipy.ndimage import gaussian_filter
from statsmodels.formula.api import ols
import statsmodels.api as sm
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Arial"

# Import custom modules
sys.path.append("../")
from models.network_hierarchical_recurrent import NetworkHierarchicalRecurrent
from models.network_feedforward_stacked import NetworkFeedforwardStacked
from virtual_physiology.VirtualNetworkPhysiology import VirtualPhysiology
from plotting_functions import *

In [None]:
# Paths to the trained model checkpoint and to the previously saved virtual-physiology
# results. NOTE(review): both are empty placeholders — fill in before running.
MODEL_PATH = ''
VPHYS_PATH = ''
# Hidden units per group; passed three times below as hidden_units for the three groups.
UNITS      = 800

# Load model (also returns its hyperparameters and training loss history)
model, hyperparameters, loss_history = NetworkHierarchicalRecurrent.load(
    model_path=MODEL_PATH,
    device='cpu',
    plot_loss_history=False
)

# Load previously processed vphys data for this model; `vphys` is the object the
# single-unit, modulation-ratio and SF/TF cells below read from.
vphys = VirtualPhysiology.load(
    data_path=VPHYS_PATH,
    model=model,
    hyperparameters=hyperparameters,
    hidden_units=[UNITS, UNITS, UNITS],
    # presumably the stimulus frame (height, width) in pixels — TODO confirm
    frame_shape=(20, 40),
    device='cpu'
)

display(hyperparameters)

# Single unit plots

In [None]:
orientation_arr        = np.array(vphys.orientations)
temporal_frequency_arr = np.array(vphys.temporal_frequencies)
spatial_frequency_arr  = np.array(vphys.spatial_frequencies)

# Hand-picked example units as (group index, unit index) pairs, one per panel.
EXAMPLE_UNITS = [(0, 132), (1, 35), (2, 2)]
units = [vphys.data[group_idx][idx] for group_idx, idx in EXAMPLE_UNITS]

# squeeze=False keeps `axs` indexable even if only one example unit is chosen.
fig, axs = plt.subplots(nrows=1, ncols=len(units), sharex=True, sharey=False, squeeze=False)
axs = axs[0]
last_idx = len(units) - 1

for unit_idx, (ax, unit) in enumerate(zip(axs, units)):
    # Position of the unit's preferred SF / TF / orientation in the tested stimulus grids
    # (the preferred values are expected to be exact members of those grids).
    sf_idx  = np.flatnonzero(spatial_frequency_arr == unit["preferred_sf"])[0]
    tf_idx  = np.flatnonzero(temporal_frequency_arr == unit["preferred_tf"])[0]
    ori_idx = np.flatnonzero(orientation_arr == unit["preferred_orientation"])[0]

    # mean_grating_responses is indexed [sf, orientation, tf]: take the TF tuning curve at the
    # preferred SF/orientation and the SF tuning curve at the preferred orientation/TF.
    responses = unit["mean_grating_responses"]
    tf_curve  = responses[sf_idx, ori_idx, :]
    sf_curve  = responses[:, ori_idx, tf_idx]

    # Only the last panel carries legend labels.
    ax.plot(temporal_frequency_arr, tf_curve, c='tab:blue', label='TF' if unit_idx == last_idx else None)
    ax.plot(spatial_frequency_arr, sf_curve, '--', c='tab:blue', label='SF' if unit_idx == last_idx else None)

    ax.set_yticks([])
    ax.set_xticks([0.0, 0.1, 0.2])
    if unit_idx == 1:
        ax.set_xlabel('Spatial/temporal frequency')
    if unit_idx == 0:
        ax.set_ylabel('Response')
    format_plot(ax, fontsize=20)

# The legend is not wanted in the saved figure; remove it if format_plot attached one.
legend = axs[-1].get_legend()
if legend is not None:
    legend.remove()
fig.set_size_inches(14, 3)
fig.savefig('./figures/modulation_ratio/d.pdf', bbox_inches='tight')
plt.show()
    


# Modulation ratio

In [None]:
# https://journals.physiology.org/doi/prev/20171106-aop/pdf/10.1152/jn.00668.2004
# V1 reference data as "x, y" rows (presumably digitized from a figure in the paper above).
# Only the first column is used, as the per-unit V1 modulation ratio (F1/F0).
# NOTE(review): confirm that the first column is the F1/F0 axis of the source figure.

V1_data_raw = """0.10973797959000778, 0.7321993335031416
0.12935738929670756, 0.3679296524391725
0.1901122873620213, 0.8706308523154551
0.18329272000698338, 0.7585775278525113
0.18229588870627783, 0.5772228943793516
0.2036100291739779, 0.9194127823059944
0.20423264652676465, 0.7969249379897148
0.2176239367619749, 0.5394936534216338
0.2167483636595806, 0.8776240660553577
0.24376723244945514, 0.4223061534216338
0.2511001264586374, 0.6973477616163493
0.2636244533880089, 0.6767397372219395
0.28502035733032954, 0.48075918631589176
0.32537567458777256, 0.8717395325425129
0.3383172951021069, 0.7803788955322593
0.33742796947732795, 0.5562688380879608
0.3536139900999173, 0.7237221939208696
0.3519931375115564, 0.8740421760910174
0.37937848670981955, 0.7904647435897438
0.40862906517286535, 0.3387628341581158
0.4433424065596303, 0.7693763664246904
0.4525875085620016, 0.892697852988623
0.45670893254709477, -0.3969317795890641
0.464198700335578, 0.9317148686716156
0.47635484092955127, 0.3630258745118018
0.47688904475690486, 0.49764523900492463
0.4669322434869828, 0.7824054911699782
0.49320772797076606, 0.6985295678383429
0.515056190714562, 0.3372064176855156
0.5336977383693473, 0.5646777511947216
0.5326342378768171, 0.29667562707711737
0.5596622102985928, 0.6818140813380882
0.5594009462157818, -0.780961553622977
0.6169648690085362, 0.5939383808796062
0.6373363943857527, 0.6985892660044153
0.6787519970984375, 0.5185992952963153
0.683490894376656, -0.8016853455595176
0.7305443187934171, 0.43905575315964407
0.7850989684653418, 0.761526843088363
0.8194593900435168, 0.7740477479198509
0.8762009909041233, 0.8241728866470177
0.8856468889064087, 0.6900524282560709
1.0778472835749104, 0.7905287059105356
1.086729264036797, 0.5143010273391071
1.1202322200300707, -0.2627388305678866
1.1815488124440061, 0.9402841795982081
1.3700195673110613, 0.38027864450670756
1.4923500270914456, -0.36322363653179024
1.5419033798974815, 0.6693371646289694
1.6178004436628122, -0.04110938902433081
1.6589535978052634, 0.27153843500594355
1.6883737271821544, -0.13743664413676826
1.694129828415949, -0.36322363653179024
1.7196000306652137, -0.09125584852509494
1.7281780039696018, 0.6734563380879608
1.760417636442242, -0.4219410470186544
1.7654836045401914, -0.2628667552094701
1.8233342240528336, -0.21265633338791434
1.8849218167758628, -0.3382783314229918
1.9099856552331125, 0.9488354498702183
1.9180555680396765, 0.18857930493899056
2.0603255367919506, -0.279752807898503
2.195717525509241, -0.2462365518036047
0.24888074888074896, 0.7179487179487181"""

hist_bins        = np.linspace(0, 2, 11)

# Per-unit V1 modulation ratios: the first value of every "x, y" row.
V1_mod_ratios = np.array([float(row.split(',')[0]) for row in V1_data_raw.splitlines()])

# Proportion of units per bin, normalised over the ratios that fall inside the histogram range
# (np.histogram silently drops values outside [0, 2]). This matches the model histograms, which
# only keep ratios in [0, 2], and the V2 reference proportions, which sum to 1.
V1_counts, _ = np.histogram(V1_mod_ratios, bins=hist_bins)
V1_data = V1_counts / V1_counts.sum()

In [None]:
l         = 10
cc_thresh = 0

def get_modulation_ratio (u, start_frame=None):
    """Estimate a unit's modulation ratio (F1/F0) from its optimum-grating response.

    A sinusoid ``a*sin(b*x + c) + d`` is fitted to the response time course. The
    largest positive-frequency FFT component of the mean-subtracted response gives
    the initial guess, which is refined with ``scipy.optimize.curve_fit``. If the
    fit fails, the FFT-based estimate is used as is.

    Parameters
    ----------
    u : dict-like
        Unit record; only ``u['optimum_grating_response']`` is read.
    start_frame : int, optional
        Number of leading samples to discard before fitting. Defaults to the
        module-level ``l`` (looked up at call time, as before).

    Returns
    -------
    F1_F0 : float
        |fitted amplitude| divided by the fitted offset.
    cc : float
        Pearson correlation between the response and the fitted sinusoid.
    """
    if start_frame is None:
        start_frame = l
    y = np.asarray(u['optimum_grating_response'][start_frame:], dtype=float)
    x = np.arange(len(y))

    # Offset of sine wave (initial F0 guess)
    mn = np.mean(y)

    # One FFT of the mean-subtracted signal. After fftshift, index len(y)//2 is the DC
    # term, so the slice keeps strictly positive frequencies only.
    pos      = slice(1 + len(y)//2, None)
    spectrum = np.fft.fftshift(np.fft.fft(y - mn))[pos]
    frq      = np.fft.fftshift(np.fft.fftfreq(len(y)))[pos]
    peak     = np.argmax(np.abs(spectrum))

    # Initial guesses: angular frequency (rad/sample); phase (+pi/2 converts the FFT's
    # cosine phase into a sine phase); amplitude |X_k| / (N/2).
    max_frq = abs(2*np.pi*frq[peak])
    max_phs = np.angle(spectrum[peak]) + np.pi/2
    max_amp = np.abs(spectrum[peak])/(len(y)/2)

    def sine(x, a, b, c, d):
        return a*np.sin(b*x + c) + d

    try:
        max_amp, max_frq, max_phs, mn = curve_fit(sine, x, y, p0=[max_amp, max_frq, max_phs, mn])[0]
    except Exception:
        # Best effort: no convergence, NaNs, or fewer samples than parameters —
        # keep the FFT-based initial estimate.
        pass

    # Get correlation between the response and the fit
    y_est = mn + max_amp*np.sin(max_frq*x + max_phs)
    cc    = stats.pearsonr(y, y_est)[0]

    # A negative fitted amplitude is the same sinusoid shifted by pi; F1 is a magnitude.
    F1_F0 = abs(max_amp)/mn

    return F1_F0, cc

hist_bins        = np.linspace(0, 2, 11)
hist_bin_centres = hist_bins[1:] - np.diff(hist_bins)[0]/2
b                = np.diff(hist_bin_centres)[0]   # bin width, used for bar widths/offsets

# Per-group results for the main model (one entry per group in vphys.data):
raw_mod    = []   # accepted F1/F0 values
gabor_fit  = []   # OSI of accepted units that carry a 'gabor_r' entry
mean_mod   = []   # mean accepted F1/F0 (NaN for a group with no accepted units)
model_data = []   # proportion of accepted units per modulation-ratio bin

for g in vphys.data:
    mod_arr = []
    gab_arr = []
    for u in g:
        try:
            F1_F0, cc = get_modulation_ratio(u)
        except Exception:
            # Units whose response cannot be analysed are skipped
            # (without also swallowing KeyboardInterrupt/SystemExit).
            continue

        # Keep ratios inside the histogram range whose sine fit correlates with the response.
        if 0 <= F1_F0 <= 2 and cc > cc_thresh:
            mod_arr.append(F1_F0)
            if 'gabor_r' in u:
                # NOTE(review): tests for 'gabor_r' but stores u['OSI'] — confirm intended.
                gab_arr.append(u['OSI'])

    # Proportions rather than counts; max(..., 1) avoids 0/0 for an empty group.
    counts, _ = np.histogram(mod_arr, bins=hist_bins)
    model_data.append(counts / max(len(mod_arr), 1))

    mean_mod.append(np.mean(mod_arr) if mod_arr else np.nan)
    raw_mod.append(mod_arr)
    gabor_fit.append(gab_arr)

# Ringach (2002)
#V1_data = np.array([70, 48, 38, 17, 11, 11, 18, 40, 40, 15])
#V1_data = V1_data/V1_data.sum()
# Levitt (1994)
V2_data = [0.23790776, 0.21709786, 0.14510686, 0.09223847, 0.06130484,
           0.06130484, 0.0832396 , 0.04049494, 0.04105737, 0.02024747]
# Neural reference histogram per group; the third group has no neural data.
neural_data = [
    V1_data,
    V2_data,
    []
]

# Despite its name, holds the mean squared error between model and neural histograms.
cc_arr = []

fig, axs = plt.subplots(nrows=1, ncols=3, sharex=True, sharey=True)
for ax_idx, (ax, m, d) in enumerate(zip(axs, model_data, neural_data)):
    # Histogram-weighted mean modulation ratio of the model (and of the data, if present).
    if len(d):
        print('Model:', np.sum(hist_bin_centres*m)/np.sum(m), 'Neural:', np.sum(hist_bin_centres*d)/np.sum(d))
    else:
        print('Model:', np.sum(hist_bin_centres*m)/np.sum(m))

    ax.bar(hist_bin_centres-b/3, m, width=b/3, facecolor='tab:blue')
    if len(d):
        cc_arr.append(np.mean((m - np.asarray(d))**2))
        ax.bar(hist_bin_centres, d, width=b/3, facecolor='black')

    # Dashed reference line at F1/F0 = 1 (presumably the simple/complex-cell boundary).
    ax.plot([1, 1], [0, 0.225], '--', c='black')
    if ax_idx == 0:
        ax.set_ylabel('Proportion units')
    ax.set_xlabel('Modulation ratio')
    format_plot(ax, fontsize=20)

fig.set_size_inches(12, 4)
fig.tight_layout()
plt.show()

# Mean modulation ratio per group. Raw strings keep the mathtext backslashes literal.
group_ticks = list(range(len(mean_mod)))
fig = plt.figure()
plt.plot(group_ticks, mean_mod, '.-')
plt.xticks(group_ticks, [r'$\mathregular{G_%d}$' % i for i in group_ticks])
plt.ylabel('F1/F0')
format_plot(fontsize=20)
fig.set_size_inches(2, 1)
plt.show()

# Modulation ratio model comparison

In [None]:
# Re-declare analysis settings so this cell can run on its own: `l` is the number of leading
# response frames get_modulation_ratio discards; `cc_thresh` is the minimum fit correlation.
l         = 10
cc_thresh = 0

# Ringach (2002)
#V1_data = np.array([70, 48, 38, 17, 11, 11, 18, 40, 40, 15])
#V1_data = V1_data/V1_data.sum()
# Levitt (1994) V2 modulation-ratio proportions (same 10 bins as hist_bins over [0, 2])
V2_data = [0.23790776, 0.21709786, 0.14510686, 0.09223847, 0.06130484,
           0.06130484, 0.0832396 , 0.04049494, 0.04105737, 0.02024747]
# Neural reference per group; V1_data comes from the V1 reference-data cell, and the
# third group has no neural data.
neural_data = [
    V1_data,
    V2_data,
    []
]

# NOTE(review): all model/data paths below are empty placeholders — fill in before running.
# NOTE: this reassigns the global `hyperparameters` and `loss_history` from the first loading cell.
TP_model, hyperparameters, loss_history = NetworkHierarchicalRecurrent.load(
    model_path='',
    device='cpu',
    plot_loss_history=False
)
TP_vphys = VirtualPhysiology.load(
    data_path='',
    model=TP_model,
    hyperparameters=hyperparameters,
    hidden_units=[800, 800, 800],
    frame_shape=(20, 40),
    device='cpu'
)

# NOTE(review): constructed with an empty list, yet `.stacks[0]` is read below — presumably a
# placeholder for the real stacked feedforward model; confirm it does not raise IndexError.
TP_ff_model = NetworkFeedforwardStacked([])
TP_ff_vphys = VirtualPhysiology.load(
    data_path='',
    model=TP_ff_model,
    hyperparameters=TP_ff_model.stacks[0].hyperparameters,
    hidden_units=[800, 800, 800],
    frame_shape=(20, 40),
    device='cpu'
)


# Autoencoder baseline, loaded with the same loader as the TP model.
AE_model, hyperparameters, loss_history = NetworkHierarchicalRecurrent.load(
    model_path='',
    device='cpu',
    plot_loss_history=False
)
AE_vphys = VirtualPhysiology.load(
    data_path='',
    model=AE_model,
    hyperparameters=hyperparameters,
    hidden_units=[800, 800, 800],
    frame_shape=(20, 40),
    device='cpu'
)

hist_bins        = np.linspace(0, 2, 11)
hist_bin_centres = hist_bins[1:] - np.diff(hist_bins)[0]/2
b                = np.diff(hist_bin_centres)[0]   # bin width

# all_model_data[model][group] -> proportion of accepted units per modulation-ratio bin,
# for models in the order TP (full), TP (feedforward), autoencoder.
all_model_data = []

# The loop variable must not be called `vphys`: that would silently replace the main
# model's VirtualPhysiology object that later cells (e.g. the SF/TF figure) read.
for model_vphys in [TP_vphys, TP_ff_vphys, AE_vphys]:
    model_data = []

    for g in model_vphys.data:
        mod_arr = []
        for u in g:
            try:
                F1_F0, cc = get_modulation_ratio(u)
            except Exception:
                # Units whose response cannot be analysed are skipped.
                continue

            # Same acceptance criteria as the single-model analysis (range + fit quality).
            if 0 <= F1_F0 <= 2 and cc > cc_thresh:
                mod_arr.append(F1_F0)

        # Proportions rather than counts; max(..., 1) avoids 0/0 for an empty group.
        counts, _ = np.histogram(mod_arr, bins=hist_bins)
        binned_mod_arr = counts / max(len(mod_arr), 1)
        print(binned_mod_arr)
        model_data.append(binned_mod_arr)

    all_model_data.append(model_data)

In [None]:
# Per-model plot colours and legend labels (TP full, TP feedforward, autoencoder).
# Raw strings: '\m' is not a valid Python escape and warns in recent Python versions.
COLORS = ['tab:blue', 'tab:red', 'tab:green']
LABELS = [r'TP$\mathregular{_{full}}$', r'TP$\mathregular{_{FF}}$', 'Autoencoder']

def get_KS (a, b):
    """Return the Kolmogorov-Smirnov distance between two binned distributions.

    Each argument is a sequence of (not necessarily normalised) bin weights over
    the same bins. Both are normalised to sum to one and accumulated into CDFs;
    the result is the largest absolute gap between the two CDFs.
    """
    cdf_a, cdf_b = (np.cumsum(np.asarray(h) / np.sum(h)) for h in (a, b))
    return np.abs(cdf_a - cdf_b).max()

n_bars      = 4       # bars per histogram bin: three models + neural data
total_width = 1/6
bar_width   = total_width / n_bars

# KS_dist_all[model] -> KS distances to the neural reference, one per group that has data.
KS_dist_all = [[], [], []]

fig, axs = plt.subplots(nrows=1, ncols=3, sharex=True, sharey=True)
for ax_idx, (ax, m_tp, m_tp_ff, m_ae, d) in enumerate(zip(axs, *all_model_data, neural_data)):
    model_hists = [m_tp, m_tp_ff, m_ae]
    # Side-by-side bars centred on each bin: slot i is offset by (i - n_bars/2 + 1/2) * bar_width.
    for m_i, m in enumerate(model_hists):
        x_offset = (m_i - n_bars / 2) * bar_width + bar_width / 2
        ax.bar(hist_bin_centres+x_offset, m, width=bar_width, facecolor=COLORS[m_i], label=LABELS[m_i] if ax_idx==2 else '')

    if len(d):
        for m_i, m in enumerate(model_hists):
            KS_dist_all[m_i].append(get_KS(d, m))

        # Neural data occupies the last slot.
        x_offset = (3 - n_bars / 2) * bar_width + bar_width / 2
        ax.bar(hist_bin_centres+x_offset, d, width=bar_width, facecolor='gray')

    ax.plot([1, 1], [0, 1], '--', c='black')

    ax.set_xlabel('Modulation ratio')
    format_plot(ax, fontsize=20)
    # Legends are not wanted in the saved figure; remove any that format_plot attached.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

fig.set_size_inches(12, 3)
fig.tight_layout()
fig.savefig('./figures/model_comparison/A_i.pdf', bbox_inches='tight')
plt.show()

# Mean KS distance per model with error bars std(ddof=0)/sqrt(n) over groups.
# The bar container is not bound to `b`, so the global bin width `b` stays intact.
x = [0, 1, 2]
fig = plt.figure()
plt.bar(
    x,
    [np.mean(dists) for dists in KS_dist_all],
    yerr=[np.std(dists)/(len(dists)**0.5) for dists in KS_dist_all],
    color=COLORS,
)
plt.xticks(x, [r'$\mathregular{TP_{full}}$', r'$\mathregular{TP_{FF}}$', 'AE'], rotation=0)
plt.ylabel('KS distance')
format_plot(fontsize=20)
fig.set_size_inches(3, 4)
fig.savefig('./figures/model_comparison/A_ii.pdf', bbox_inches='tight')
plt.show()

# SF and TF

In [None]:
# https://elifesciences.org/articles/81794.pdf
# https://physoc.onlinelibrary.wiley.com/doi/abs/10.1113/jphysiol.1985.sp015776
# https://onlinelibrary.wiley.com/doi/pdf/10.1111/j.1460-9568.2007.05453.x
# https://www.jneurosci.org/content/jneuro/26/11/2941.full.pdf

In [None]:
mode_key = 'tf'   # 'sf' (spatial frequency) or 'tf' (temporal frequency)
unit_model = 'cycles/pixel' if mode_key == 'sf' else 'cycles/frame'
unit_data  = 'cycles/degree' if mode_key == 'sf' else 'Hz'

macaque_sf  = [2.4, 2.16, 0.9]  # Jones et al., 2001; Hu 2018; Bair and Movshon, 2004
macaque_tf  = [7.8, 8.29, 11.5] # Bair and Movshon, 2004; Hu 2018, Bair and Movshon, 2004
macaque_data = macaque_sf if mode_key == 'sf' else macaque_tf

hist_bins = np.linspace(0, 0.25, 6)
shift = np.diff(hist_bins)[0]
hist_bins_centres = hist_bins[:-1] + shift/2

GROUP_ALPHAS = [1, 0.5, 0.25]   # one transparency level per group G0, G1, G2

dist_rw = []   # per group: preferred values of the selected (DSI above threshold) units
dist_mn = []   # per group: mean preferred value
dist_er = []   # per group: standard error of the mean

fig, axs = plt.subplots(nrows=1, ncols=2, gridspec_kw={'width_ratios': [3, 1.5]})

# `vphys` is expected to be the main model's VirtualPhysiology object loaded at the top of the
# notebook; verify it has not been reassigned by an intervening cell.
for g_idx, (g, alpha) in enumerate(zip(vphys.data, GROUP_ALPHAS)):
    dist = np.array([u[f'preferred_{mode_key}'] for u in g if u['DSI']>vphys.dsi_thresh])
    dist_mn.append(np.mean(dist))
    dist_er.append(np.std(dist)/(len(dist)**0.5))
    dist_rw.append(dist)

    # Proportion of selected units per bin.
    hist = np.histogram(dist, weights=np.ones_like(dist)/len(dist), bins=hist_bins)[0]

    # Named `group_label` (not `l`) so the frame offset used by get_modulation_ratio is not
    # clobbered; braces balanced (the old label had a stray closing brace).
    group_label = r'$\mathregular{G_{%d}}$' % g_idx
    axs[0].bar(hist_bins_centres-shift/4 + shift/4*g_idx, hist, width=shift/4, facecolor='tab:blue', alpha=alpha, label=group_label)
axs[0].set_xticks(hist_bins_centres)
axs[0].set_xticklabels(np.round(hist_bins_centres, 2))
# Histogram values are fractions of units (weights 1/len), not percentages.
axs[0].set_ylabel('Proportion units')
axs[0].set_xlabel(f'{mode_key.upper()} ({unit_model})')
format_plot(axs[0], fontsize=20)

axs[1].set_ylabel(f'Mean {mode_key.upper()} ({unit_model})')
for idx, (mn, er, alpha) in enumerate(zip(dist_mn, dist_er, GROUP_ALPHAS)):
    axs[1].errorbar([idx], [mn], yerr=[er], marker='o', markersize=10, c='tab:blue', alpha=alpha)

# Macaque reference values are in physical units, so they get their own y-axis.
axs_2 = axs[1].twinx()
for idx, (value, alpha) in enumerate(zip(macaque_data, GROUP_ALPHAS)):
    axs_2.scatter([idx], [value], marker='o', s=100, c='black', alpha=alpha)
axs_2.set_ylabel(f'Mean {mode_key.upper()} ({unit_data})')

axs[1].set_xticks([0, 1, 2])
axs[1].set_xlim(-0.5, 2.5)
axs[1].set_xticklabels([r'$\mathregular{G_{%d}}$' % i for i in range(3)])
format_plot(axs[1], fontsize=20)
format_plot(axs_2, fontsize=20)

axs[1].yaxis.label.set_color('tab:blue')
axs_2.spines['right'].set_visible(True)


fig.set_size_inches(12, 4)
fig.tight_layout()
# The legend is not wanted in the saved figure; remove it if format_plot attached one.
legend = axs[0].get_legend()
if legend is not None:
    legend.remove()
fig.savefig(f'./figures/modulation_ratio/{"e" if mode_key=="tf" else "f"}.pdf', bbox_inches='tight')
plt.show()