In [None]:
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import pingouin as pg
import scipy.stats as stats
from scipy.stats import pearsonr
from scipy.stats import norm
from scipy.optimize import curve_fit
plt.rcParams["font.family"] = "Arial"

# Import custom modules
sys.path.append("../")
from models.network_feedforward_stacked import NetworkFeedforwardStacked
from models.network_hierarchical_recurrent import NetworkHierarchicalRecurrent
from virtual_physiology.VirtualNetworkPhysiology import VirtualPhysiology
from plotting_functions import *

MODEL_PATH = ''   # trained NetworkHierarchicalRecurrent checkpoint
VPHYS_PATH = ''   # processed virtual-physiology data for that checkpoint
UNITS      = 800  # hidden units per group (three groups are used below)
DEVICE     = 'cpu'

# Seed the global RNGs so the stochastic steps (random feedback masks, random
# curve-fit initialisation) are reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Frame geometry shared by the model and every stimulus: (height, width) in pixels.
FRAME_SHAPE = (20, 40)

# Load the trained model (full recurrent connectivity).
model, hyperparameters, loss_history = NetworkHierarchicalRecurrent.load(
    model_path=MODEL_PATH,
    device=DEVICE,
    plot_loss_history=False
)

# Load previously processed virtual-physiology (vphys) data for this model.
vphys = VirtualPhysiology.load(
    data_path=VPHYS_PATH,
    model=model,
    hyperparameters=hyperparameters,
    hidden_units=[UNITS, UNITS, UNITS],
    frame_shape=FRAME_SHAPE,
    device=DEVICE
)

# Unit records for each hidden-unit group (presumably one entry of vphys.data
# per group in `hidden_units` — confirm against VirtualPhysiology).
physiology_data_group1 = vphys.data[0]
physiology_data_group2 = vphys.data[1]
physiology_data_group3 = vphys.data[2]

# Global vars
FRAME_SIZE           = hyperparameters["frame_size"]
WARMUP               = hyperparameters["warmup"]
T_STEPS              = 50
TEMPORAL_FREQUENCIES = vphys.temporal_frequencies
SPATIAL_FREQUENCIES  = vphys.spatial_frequencies
ORIENTATIONS         = vphys.orientations

HEIGHT, WIDTH = FRAME_SHAPE

In [None]:
def load_no_fb_model(pct_inactivated, units=None):
    """Load a fresh copy of the model with part of its feedback weights zeroed.

    ``rnn.weight_hh_l0`` is treated as a 3x3 grid of ``units x units`` blocks, one
    row/column block per hidden-unit group. Following the PyTorch RNN convention,
    rows index the receiving unit and columns the sending unit. The blocks
    ``[group1 <- group2]`` and ``[group2 <- group3]`` (presumably the feedback
    pathways, given the function name) are multiplied elementwise by a random
    0/1 mask. Each weight is zeroed with probability ``pct_inactivated``.

    Args:
        pct_inactivated: Probability in [0, 1] that any one of those weights is
            zeroed; ``1`` removes them all.
        units: Hidden units per group; defaults to the notebook-level ``UNITS``.

    Returns:
        The modified model.

    Raises:
        ValueError: If ``pct_inactivated`` is outside [0, 1].
    """
    if not 0 <= pct_inactivated <= 1:
        raise ValueError(f'pct_inactivated must be in [0, 1], got {pct_inactivated}')
    n = UNITS if units is None else units

    model_no_fb, _, _ = NetworkHierarchicalRecurrent.load(
        model_path=MODEL_PATH,
        device=DEVICE,
        plot_loss_history=False
    )

    def random_keep_mask():
        # 0 = inactivated weight, 1 = kept weight.
        return np.random.choice([0, 1], size=(n, n), p=[pct_inactivated, 1 - pct_inactivated])

    weights = model_no_fb.rnn.weight_hh_l0.detach().cpu().numpy().copy()
    weights[0:n, n:2*n] *= random_keep_mask()        # group 2 -> group 1
    weights[n:2*n, 2*n:3*n] *= random_keep_mask()    # group 3 -> group 2

    model_no_fb.rnn.weight_hh_l0 = torch.nn.Parameter(torch.Tensor(weights).to(DEVICE))
    return model_no_fb

# Feedback-ablated copy of the model: both masked weight blocks fully zeroed.
model_no_fb = load_no_fb_model(pct_inactivated=1)

# Functions

In [None]:
def R(x, G_e, G_i, W_e, W_i):
    """Divisive size-tuning model built from two Gaussian integrals.

    The excitatory and inhibitory drives are the mass of zero-mean Gaussians of
    width ``W_e`` and ``W_i`` inside ``[-x, x]``. They are scaled by the gains
    ``G_e`` and ``G_i`` and combined as ``(G_e*E / (1 + G_i*I)) ** 2``.
    """
    def gaussian_mass(width):
        cdf = norm(loc=0, scale=width).cdf
        return cdf(x) - cdf(-x)

    excitation = G_e * gaussian_mass(W_e)
    inhibition = 1 + G_i * gaussian_mass(W_i)
    return (excitation / inhibition) ** 2

def get_grating_stimuli(spatial_frequency, orientation, temporal_frequency, grating_amplitude, frames,
                        frame_shape=None, device=None):
    """Build a drifting sinusoidal grating movie as a flattened torch tensor.

    Args:
        spatial_frequency: Cycles per pixel along the grating's modulation axis.
        orientation: Orientation in degrees (rotated by -90 before use).
        temporal_frequency: Phase advance per frame, in cycles.
        grating_amplitude: Peak amplitude of the sinusoid.
        frames: Number of frames to generate.
        frame_shape: ``(height, width)``; defaults to the notebook-level ``FRAME_SHAPE``.
        device: Torch device; defaults to the notebook-level ``DEVICE``.

    Returns:
        Float tensor of shape ``(1, frames, height*width)``. Each frame is
        flattened row-major (rows = y).
    """
    y_size, x_size = FRAME_SHAPE if frame_shape is None else frame_shape
    device = DEVICE if device is None else device

    theta = (orientation-90) * np.pi/180
    x, y = np.meshgrid(np.arange(0, x_size), np.arange(0, y_size))
    x_theta = x * np.cos(theta) + y * np.sin(theta)

    # Spatial phase of each pixel: shape (y_size, x_size).
    spatial_phase = 2*spatial_frequency*np.pi*x_theta
    # Temporal phase of each frame, shaped (frames, 1, 1) to broadcast over pixels.
    phases = (np.arange(frames) * (2*np.pi*temporal_frequency))[:, None, None]

    gratings = grating_amplitude * np.sin(spatial_phase - phases)
    gratings = gratings.reshape(1, frames, y_size*x_size)
    return torch.Tensor(gratings).to(device)

def create_circular_mask(h, w, center=None, radius=None):
    """Return an ``(h, w)`` boolean array that is True inside a disc.

    Args:
        h, w: Output height and width.
        center: ``(x, y)`` centre of the disc; defaults to ``(int(w/2), int(h/2))``.
        radius: Disc radius; defaults to the smallest distance from ``center``
            to an image edge.
    """
    if center is None:
        center = (int(w/2), int(h/2))
    cx, cy = center[0], center[1]

    if radius is None:
        radius = min(cx, cy, w - cx, h - cy)

    rows, cols = np.ogrid[:h, :w]
    distance = np.sqrt((cols - cx)**2 + (rows - cy)**2)
    return distance <= radius

def get_suppression_idx(responses, idxs):
    """Compute a surround-suppression index for selected units.

    For each unit, the (y, x, size) entry with the largest response is found and
    the size-tuning curve at that (y, x) position is extracted. The index is the
    percentage drop from the curve's peak to its value at the largest size::

        SI = (max(curve) - curve[-1]) / max(curve) * 100

    Units whose peak is not a finite, strictly positive number are skipped. The
    index is undefined for them: a zero peak divides by zero, and an all-negative
    curve yields a sign-flipped value.

    Args:
        responses: Array indexed as ``responses[unit, y, x, size]``.
        idxs: Unit indices to evaluate.

    Returns:
        ``(suppression_indices, curves, unit_idxs)`` as numpy arrays, holding only
        the units that were kept.
    """
    sup_arr, res_arr, idx_arr = [], [], []

    for unit_idx, unit_response in zip(idxs, responses[idxs]):
        max_y_idx, max_x_idx, _ = np.unravel_index(
            unit_response.argmax(), unit_response.shape
        )
        response_by_size = unit_response[max_y_idx, max_x_idx]

        peak = np.max(response_by_size)
        if not (np.isfinite(peak) and peak > 0):
            continue

        largest_stim_res = response_by_size[-1]
        sup_arr.append((peak - largest_stim_res) / peak * 100)
        res_arr.append(response_by_size)
        idx_arr.append(unit_idx)

    return np.array(sup_arr), np.array(res_arr), np.array(idx_arr)

def get_init_params(x, y, n_samples=1000, max_value=20, rng=None):
    """Random-search a starting point for fitting ``R`` to ``(x, y)``.

    Draws ``n_samples`` candidates ``[G_e, G_i, W_e, W_i]``, each entry uniform in
    ``[0, max_value)``, and returns the one with the lowest mean squared error.
    Candidates whose loss is NaN (e.g. from a zero Gaussian width) are ignored;
    a plain ``np.argmin`` would return the first NaN instead.

    Args:
        x, y: Data to fit.
        n_samples: Number of random candidates.
        max_value: Upper bound of the uniform sampling range.
        rng: Optional ``numpy.random.Generator``. By default the global
            ``np.random`` state is used, so ``np.random.seed`` controls it.

    Returns:
        The best candidate as a list ``[G_e, G_i, W_e, W_i]``.
    """
    sampler = np.random if rng is None else rng
    params, loss = [], []

    for _ in range(n_samples):
        # Drawn in the order G_e, G_i, W_e, W_i.
        p0 = [sampler.uniform(0, max_value) for _ in range(4)]
        y_hat = R(x, *p0)
        loss.append(np.mean((y-y_hat)**2))
        params.append(p0)

    # Treat NaN losses as worst-possible so they never win the argmin.
    loss = np.asarray(loss, dtype=float)
    loss = np.where(np.isnan(loss), np.inf, loss)
    return params[int(np.argmin(loss))]

def fit_and_plot(x, y, label, c, p0=None):
    """Fit ``R`` to ``(x, y)`` and plot the data together with the fitted curve.

    Args:
        x, y: Stimulus sizes and responses. ``x`` is assumed sorted ascending,
            because the fitted curve is drawn over ``[0, x[-1]]``.
        label: Legend label for the fitted curve.
        c: Colour used for both the fitted curve and the data points.
        p0: Initial ``[G_e, G_i, W_e, W_i]``. If omitted, a starting point is
            random-searched with ``get_init_params``.

    Returns:
        The fitted parameters. Callers that ignore the return value are unaffected.
    """
    xvals_model = np.linspace(0, x[-1], 100)

    # `is None`, not `== None`: p0 may be a numpy array, where `==` is elementwise.
    if p0 is None:
        p0 = get_init_params(x, y)
    fitted_params, _ = curve_fit(R, x, y, p0=p0, maxfev=10000)

    plt.plot(xvals_model, R(xvals_model, *fitted_params), '--', c=c, alpha=0.75, label=label)
    plt.scatter(x, y, color=c)
    return fitted_params

# NOTE: a second, byte-identical definition of R() used to live here and silently
# shadowed the definition at the top of this cell. It was removed so that R() has
# exactly one definition in the notebook.

def get_model_responses(grating_amplitude, net=None, net_no_fb=None, units=None,
                        sizes=None, frames=100, response_window=4, print_every=100):
    """Size-tuning responses of the full and feedback-ablated models.

    For each unit record, a drifting grating at the unit's preferred spatial
    frequency, orientation and temporal frequency is masked to a disc of each
    radius in ``sizes``, centred on the unit's Gabor position. The masked grating
    is fed to both models. A unit's response is its hidden-state activity
    averaged (NaN-aware, identically for both models) over the last
    ``response_window`` time steps.

    Args:
        grating_amplitude: Amplitude of the grating.
        net: Full model; defaults to the notebook-level ``model``.
        net_no_fb: Feedback-ablated model; defaults to ``model_no_fb``. Pass both
            models explicitly if those notebook names may have been rebound.
        units: Records with keys ``gabor_x``, ``gabor_y``, ``preferred_sf``,
            ``preferred_orientation``, ``preferred_tf`` and ``hidden_unit_index``;
            defaults to ``physiology_data_group1``.
        sizes: Disc radii in pixels; defaults to ``np.arange(2, 42, 4)``.
        frames: Number of stimulus frames.
        response_window: Number of final time steps averaged into a response.
        print_every: Progress is printed every this many (unit, size) pairs.

    Returns:
        ``(responses, responses_no_fb)``, each of shape ``(n_units, 1, 1, n_sizes)``.
        The singleton axes stand in for the (y, x) position axes.
    """
    net = model if net is None else net
    net_no_fb = model_no_fb if net_no_fb is None else net_no_fb
    units = physiology_data_group1 if units is None else units
    sizes = np.arange(2, 40+2, 4) if sizes is None else np.asarray(sizes)

    n_units, n_sizes = len(units), len(sizes)
    responses = np.zeros((n_units, 1, 1, n_sizes))
    responses_no_fb = np.zeros((n_units, 1, 1, n_sizes))

    total, count = n_units * n_sizes, 0

    for unit_idx, unit in enumerate(units):
        x_centre, y_centre = unit['gabor_x'], unit['gabor_y']

        g = get_grating_stimuli(
            spatial_frequency=unit['preferred_sf'],
            orientation=unit['preferred_orientation'],
            temporal_frequency=unit['preferred_tf'],
            grating_amplitude=grating_amplitude,
            frames=frames
        )
        g = g.cpu().detach().numpy().reshape(frames, HEIGHT, WIDTH)

        for size_idx, size in enumerate(sizes):
            # NOTE(review): the -0.5 offset presumably maps the Gabor centre onto
            # pixel-index coordinates — confirm against the vphys fitting code.
            mask = create_circular_mask(HEIGHT, WIDTH, center=(x_centre-0.5, y_centre-0.5), radius=size)
            g_masked = g.copy()
            g_masked[:, ~mask] = 0
            stimulus = torch.Tensor(g_masked.reshape(frames, HEIGHT*WIDTH)).unsqueeze(dim=0).to(DEVICE)

            with torch.no_grad():
                # Full model first, then the no-feedback model.
                for network, out in ((net, responses), (net_no_fb, responses_no_fb)):
                    _, hidden_states = network(stimulus)
                    late = hidden_states[0, -response_window:, unit['hidden_unit_index']]
                    out[unit_idx, 0, 0, size_idx] = np.nanmean(late.detach().cpu().numpy(), axis=0)

            if count % print_every == 0:
                print(count, '/', total)
            count += 1

    return responses, responses_no_fb

# Receptive-field (RF) size

Reference: https://www.nature.com/articles/s41467-018-04500-5

In [None]:
# Size-tuning responses of the full and no-feedback models at grating amplitude 2.
responses, responses_no_fb = get_model_responses(grating_amplitude=2)

In [None]:
def get_paired_lists(indices, curves, idxs):
    """Restrict two result sets to their shared units, aligned unit-by-unit.

    Args:
        indices: ``(indices_a, indices_b)`` per-unit values (e.g. suppression indices).
        curves: ``(curves_a, curves_b)`` per-unit curves, parallel to ``indices``.
        idxs: ``(idxs_a, idxs_b)`` unit ids, parallel to ``indices``/``curves``.

    Returns:
        ``(indices_a, curves_a, idxs_a, indices_b, curves_b, idxs_b)`` as lists.
        They contain only unit ids present in both ``idxs_a`` and ``idxs_b``, in
        ascending id order, so position ``k`` refers to the same unit in every
        list. If an id repeats, its first occurrence is used.
    """
    indices_a, indices_b = indices
    curves_a, curves_b = curves
    idxs_a, idxs_b = list(idxs[0]), list(idxs[1])

    def first_positions(ids):
        # unit id -> position of its first occurrence (same as list.index).
        positions = {}
        for pos, unit_id in enumerate(ids):
            positions.setdefault(unit_id, pos)
        return positions

    pos_a, pos_b = first_positions(idxs_a), first_positions(idxs_b)
    # Sorted so the output order is deterministic, not set-iteration order.
    joint = sorted(pos_a.keys() & pos_b.keys())

    return (
        [indices_a[pos_a[u]] for u in joint],
        [curves_a[pos_a[u]] for u in joint],
        [idxs_a[pos_a[u]] for u in joint],
        [indices_b[pos_b[u]] for u in joint],
        [curves_b[pos_b[u]] for u in joint],
        [idxs_b[pos_b[u]] for u in joint],
    )

# Suppression indices for every unit with a usable size-tuning curve. Index over
# the units actually present in the response arrays; a hard-coded 800 raises
# IndexError if there are fewer unit records and ignores any extra ones.
suppression_indices_full, suppression_curve_full, idxs_full = get_suppression_idx(
    responses, idxs=np.arange(len(responses))
)
suppression_indices_nofb, suppression_curve_nofb, idxs_nofb = get_suppression_idx(
    responses_no_fb, idxs=np.arange(len(responses_no_fb))
)

# Keep only units usable in both conditions, aligned unit-by-unit.
(suppression_indices_full, suppression_curve_full, idxs_full,
 suppression_indices_nofb, suppression_curve_nofb, idxs_nofb) = get_paired_lists(
    (suppression_indices_full, suppression_indices_nofb),
    (suppression_curve_full, suppression_curve_nofb),
    (idxs_full, idxs_nofb),
)

In [None]:
def get_sRF_size(curve_arr, sizes=None):
    """Summarise the preferred stimulus size (the size giving the peak response).

    Args:
        curve_arr: Iterable of size-tuning curves, each aligned with ``sizes``.
        sizes: Stimulus size for each curve entry; defaults to ``np.arange(2, 42, 4)``.

    Returns:
        ``(mean, sem, raw)``, where ``raw`` lists the preferred size of each curve.
        Curves containing non-finite values are skipped, because ``np.argmax``
        would otherwise return the position of the first NaN.
    """
    sizes = np.arange(2, 40+2, 4) if sizes is None else np.asarray(sizes)

    raw = [sizes[np.argmax(curve)] for curve in curve_arr if np.all(np.isfinite(curve))]

    return np.mean(raw), np.std(raw)/(len(raw)**0.5), raw

def get_sRF_FR(curves_full, curves_nofb):
    """Peak responses of both conditions, normalised to the full-model peak.

    For each unit pair, both curve peaks are divided by the full-model peak, so
    the full-model values are always 1. Pairs with a non-finite peak or a zero
    full-model peak are skipped; a zero peak would otherwise add 0/0 = NaN and
    poison the mean.

    Returns:
        ``(full_mean, full_sem, full_raw, nofb_mean, nofb_sem, nofb_raw)``.
    """
    raw_full, raw_nofb = [], []

    for c_full, c_nofb in zip(curves_full, curves_nofb):
        r_full = np.max(c_full)
        r_nofb = np.max(c_nofb)

        if not (np.isfinite(r_full) and np.isfinite(r_nofb) and r_full != 0):
            continue
        raw_full.append(r_full/r_full)
        raw_nofb.append(r_nofb/r_full)

    def mean_sem(values):
        return np.mean(values), np.std(values)/(len(values)**0.5)

    return (*mean_sem(raw_full), raw_full, *mean_sem(raw_nofb), raw_nofb)

def get_sRF_proximal_FR(curves_full, curves_nofb):
    """Response just past the RF peak, normalised to the no-feedback condition.

    The "proximal" value of a curve is the first entry after its peak that falls
    below 95% of the peak, or NaN if there is none. Each full-model value is
    divided by the matching no-feedback value, so the no-feedback values are 1.
    Pairs whose ratio is not finite are dropped.

    Returns:
        ``(full_mean, full_sem, full_raw, nofb_mean, nofb_sem, nofb_raw)``.
    """
    def get_proximal(curve):
        peak = np.max(curve)
        peak_idx = np.argmax(curve)
        below = np.flatnonzero(curve < peak*0.95)
        after_peak = below[below > peak_idx]
        return curve[after_peak[0]] if after_peak.size else np.nan

    ratios_full, ratios_nofb = [], []

    for c_full, c_nofb in zip(curves_full, curves_nofb):
        prox_full = get_proximal(c_full)
        prox_nofb = get_proximal(c_nofb)
        ratio = prox_full/prox_nofb
        if np.isfinite(ratio):
            ratios_full.append(ratio)
            ratios_nofb.append(prox_nofb/prox_nofb)

    def summary(values):
        return np.nanmean(values), np.nanstd(values)/(len(values)**0.5)

    return (*summary(ratios_full), ratios_full, *summary(ratios_nofb), ratios_nofb)


# Preferred (RF) size and normalised responses of the paired units.
size_full_mn, size_full_er, size_full_raw = get_sRF_size(suppression_curve_full)
size_nofb_mn, size_nofb_er, size_nofb_raw = get_sRF_size(suppression_curve_nofb)

(sRF_proximal_FR_full_mn, sRF_proximal_FR_full_er, sRF_proximal_FR_full_raw,
 sRF_proximal_FR_nofb_mn, sRF_proximal_FR_nofb_er, sRF_proximal_FR_nofb_raw) = get_sRF_proximal_FR(
    suppression_curve_full, suppression_curve_nofb
)

(sRF_FR_full_mn, sRF_FR_full_er, sRF_FR_full_raw,
 sRF_FR_nofb_mn, sRF_FR_nofb_er, sRF_FR_nofb_raw) = get_sRF_FR(
    suppression_curve_full, suppression_curve_nofb
)

# Mouse V1 reference values (mean, error) for [control, feedback suppressed].
# NOTE(review): presumably from the paper referenced above — confirm the source.
sRF_diameter_mn = [1.268, 1.826]
sRF_diameter_er = [0.094, 0.136]

sRF_proximal_FR_mn = [0.860, 1]
sRF_proximal_FR_er = [0, 0.051][::-1]

sRF_FR_mn = [1, 0.668]
sRF_FR_er = [0, 0.054]

x = np.arange(2)


def _plot_model_vs_mouse(model_mn, model_er, mouse_mn, mouse_er, model_ylabel, mouse_ylabel, save_path=None):
    """Two bar panels: model (full vs. no feedback) and mouse V1 (control vs. suppressed)."""
    fig, axs = plt.subplots(nrows=1, ncols=2)
    positions = np.arange(2)
    panels = (
        (axs[0], model_mn, model_er, ('tab:blue', 'tab:orange'), 'Model', model_ylabel),
        (axs[1], mouse_mn, mouse_er, ('black', 'tab:grey'), 'Mouse V1', mouse_ylabel),
    )
    for ax, mn, er, colors, xticklabel, ylabel in panels:
        bars = ax.bar(positions, mn, yerr=er)
        for bar, color in zip(bars, colors):
            bar.set_facecolor(color)
        ax.set_xticks([0.5])
        ax.set_xticklabels([xticklabel])
        ax.set_ylabel(ylabel)
    format_plot(axs[0], fontsize=20)
    format_plot(axs[1], fontsize=20)
    fig.set_size_inches((4.5, 4))
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()


_plot_model_vs_mouse(
    [size_full_mn, size_nofb_mn], [size_full_er, size_nofb_er],
    sRF_diameter_mn, sRF_diameter_er,
    model_ylabel='RF size (pixels)', mouse_ylabel='RF size (degrees)',
    save_path='./figures/surround_suppression/C_i.pdf',
)
_plot_model_vs_mouse(
    [sRF_proximal_FR_full_mn, sRF_proximal_FR_nofb_mn], [sRF_proximal_FR_full_er, sRF_proximal_FR_nofb_er],
    sRF_proximal_FR_mn, sRF_proximal_FR_er,
    model_ylabel='Model response (norm.)', mouse_ylabel='Firing rate (norm.)',
)
_plot_model_vs_mouse(
    [sRF_FR_full_mn, sRF_FR_nofb_mn], [sRF_FR_full_er, sRF_FR_nofb_er],
    sRF_FR_mn, sRF_FR_er,
    model_ylabel='Model response (norm.)', mouse_ylabel='Firing rate (norm.)',
)

# Model surround suppression

In [None]:
sizes = np.arange(2, 40+2, 4)

# Unpaired suppression indices / curves for every usable unit in each condition
# (unlike the paired lists used for the RF-size analysis).
suppression_indices_full, suppression_curve_full, _ = get_suppression_idx(
    responses, np.arange(len(responses))
)
suppression_indices_nofb, suppression_curve_nofb, _ = get_suppression_idx(
    responses_no_fb, np.arange(len(responses_no_fb))
)

# Histogram of suppression indices, weighted so each condition sums to 1.
fig = plt.figure()
w_full = np.ones_like(suppression_indices_full)/len(suppression_indices_full)
w_nofb = np.ones_like(suppression_indices_nofb)/len(suppression_indices_nofb)
plt.hist([suppression_indices_full, suppression_indices_nofb], weights=[w_full, w_nofb], label=['Full model', 'No feedback'])
plt.xlabel('Suppression index')
plt.ylabel('Proportion of units')
format_plot(fontsize=20)
fig.set_size_inches(2.5, 4)
plt.gca().get_legend().remove()
plt.savefig('./figures/surround_suppression/B_i.pdf', bbox_inches='tight')
plt.show()

print(np.mean(suppression_indices_full))

# Normalise every size-tuning curve to its own peak.
suppression_curve_full = [r/r.max() for r in suppression_curve_full]
suppression_curve_nofb = [r/r.max() for r in suppression_curve_nofb]

# Mean normalised size tuning (± SEM) with an R() fit for each condition.
fig = plt.figure()
sizes_model = np.linspace(0, 40, 100)
for curves, init_params, color, label in (
    (suppression_curve_full, [1.75, 12, 0.5, 15], 'tab:blue', 'Full model'),
    (suppression_curve_nofb, [1, 1, 1, 1], 'tab:orange', 'No feedback'),
):
    mean_curve = np.nanmean(curves, axis=0)
    fitted_params, _ = curve_fit(R, sizes, mean_curve, p0=init_params, maxfev=10000)
    plt.plot(sizes_model, R(sizes_model, *fitted_params), c=color, alpha=0.75, label=label)
    plt.errorbar(
        sizes, mean_curve,
        yerr=np.nanstd(curves, axis=0)/len(curves)**0.5,
        marker='o', linestyle='', color=color
    )

plt.xlabel('Stimulus size (pixels)')
plt.ylabel('Model response (norm.)')
format_plot(fontsize=20)
fig.set_size_inches(2.5, 4)
plt.gca().get_legend().remove()
plt.savefig('./figures/surround_suppression/A_i.pdf', bbox_inches='tight')
plt.show()

# Model comparison

In [None]:
# Compare surround suppression across model variants (TP and AE), each with and
# without feedback. Fill in the checkpoint / vphys paths before running.
TP_MODEL_PATH = ''
TP_VPHYS_PATH = ''
AE_MODEL_PATH = ''
AE_VPHYS_PATH = ''


def _load_with_and_without_feedback(model_path, units=UNITS):
    """Load ``model_path`` twice and zero the two feedback weight blocks of the second copy.

    Returns ``(full_model, no_feedback_model, hyperparameters)``.
    """
    full_model, hyper, _ = NetworkHierarchicalRecurrent.load(
        model_path=model_path, device=DEVICE, plot_loss_history=False
    )
    no_fb_model, _, _ = NetworkHierarchicalRecurrent.load(
        model_path=model_path, device=DEVICE, plot_loss_history=False
    )
    weights = full_model.rnn.weight_hh_l0.detach().cpu().numpy().copy()
    weights[0:units, units:2*units] = 0
    weights[units:2*units, 2*units:3*units] = 0
    no_fb_model.rnn.weight_hh_l0 = torch.nn.Parameter(torch.Tensor(weights).to(DEVICE))
    return full_model, no_fb_model, hyper


TP_model, TP_model_no_fb, TP_hyperparameters = _load_with_and_without_feedback(TP_MODEL_PATH)
TP_vphys = VirtualPhysiology.load(
    data_path=TP_VPHYS_PATH,
    model=TP_model,
    hyperparameters=TP_hyperparameters,
    hidden_units=[UNITS, UNITS, UNITS],
    frame_shape=FRAME_SHAPE,
    device=DEVICE
)

AE_model, AE_model_no_fb, AE_hyperparameters = _load_with_and_without_feedback(AE_MODEL_PATH)
AE_vphys = VirtualPhysiology.load(
    data_path=AE_VPHYS_PATH,
    model=AE_model,
    hyperparameters=AE_hyperparameters,
    hidden_units=[UNITS, UNITS, UNITS],
    frame_shape=FRAME_SHAPE,
    device=DEVICE
)

HEIGHT, WIDTH = FRAME_SHAPE
sizes = np.arange(2, 40+2, 4)
frames = 100
response_window = 4  # final time steps averaged into a response

suppression_indices_all        = []
response_full_tuning_curve_all = []
response_nofb_tuning_curve_all = []

# Distinct loop names so the notebook-level `model`, `model_no_fb`, `vphys` and
# `responses` are not silently rebound by this cell.
for cmp_model, cmp_model_no_fb, cmp_vphys in zip(
    [TP_model      , AE_model      ],
    [TP_model_no_fb, AE_model_no_fb],
    [TP_vphys      , AE_vphys      ]
):
    # Use this model's own unit records: hidden_unit_index and the preferred
    # stimulus parameters are specific to the model they were measured on.
    cmp_units = cmp_vphys.data[0]

    units_len, sizes_len = len(cmp_units), len(sizes)
    cmp_responses = np.zeros((units_len, 1, 1, sizes_len))
    cmp_responses_no_fb = np.zeros((units_len, 1, 1, sizes_len))

    params = sizes_len*units_len
    param_count = 0

    for unit_idx, unit in enumerate(cmp_units):
        x_centre, y_centre = unit['gabor_x'], unit['gabor_y']

        g = get_grating_stimuli(
            spatial_frequency=unit['preferred_sf'],
            orientation=unit['preferred_orientation'],
            temporal_frequency=unit['preferred_tf'],
            grating_amplitude=3,
            frames=frames
        )
        g = g.cpu().detach().numpy().reshape(frames, HEIGHT, WIDTH)

        for size_idx, size in enumerate(sizes):
            mask = create_circular_mask(HEIGHT, WIDTH, center=(x_centre-0.5, y_centre-0.5), radius=size)
            g_masked = g.copy()
            g_masked[:, ~mask] = 0
            stimulus = torch.Tensor(g_masked.reshape(frames, HEIGHT*WIDTH)).unsqueeze(dim=0).to(DEVICE)

            with torch.no_grad():
                # Full model first, then the no-feedback model.
                for network, out in ((cmp_model, cmp_responses), (cmp_model_no_fb, cmp_responses_no_fb)):
                    _, hidden_states = network(stimulus)
                    out[unit_idx, 0, 0, size_idx] = np.mean(
                        hidden_states[0, frames-response_window:frames, unit['hidden_unit_index']].detach().cpu().numpy(),
                        axis=0
                    )

            if param_count % 250 == 0:
                print(param_count, '/', params)
            param_count += 1

    suppression_indices, response_full_tuning_curve, _ = get_suppression_idx(
        cmp_responses, np.arange(len(cmp_responses))
    )
    suppression_indices_all.append(suppression_indices)
    response_full_tuning_curve_all.append(response_full_tuning_curve)

    _, response_nofb_tuning_curve, _ = get_suppression_idx(
        cmp_responses_no_fb, np.arange(len(cmp_responses_no_fb))
    )
    response_nofb_tuning_curve_all.append(response_nofb_tuning_curve)
    
# Mean size-tuning curve of each model, with (solid) and without (dashed) feedback.
fig = plt.figure()

# Raw strings: '\m' is not a valid Python escape sequence.
for r_full, r_nofb, c, l in zip(
    response_full_tuning_curve_all,
    response_nofb_tuning_curve_all,
    ['tab:blue', 'tab:green'],
    [r'TP$\mathregular{_{full}}$', 'AE']
):
    plt.plot(sizes, np.mean(r_full, axis=0),        c=c, label=l,  linewidth=2.5)
    plt.plot(sizes, np.mean(r_nofb, axis=0), '--', c=c, label=f'{l} (no feedback)',  linewidth=2.5)

plt.xticks([0, 10, 20, 30, 40])
plt.xlabel('Stimulus size (pixels)')
plt.ylabel('Unit response')
format_plot(fontsize=20)
plt.gca().get_legend().set_bbox_to_anchor((1, 1))
fig.set_size_inches(4, 4)
# NOTE(review): the following cell saves to this same path and overwrites this file.
plt.savefig('./figures/model_comparison/C_i.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Same comparison as the previous cell, with the y-axis anchored at 0.
fig = plt.figure()

# Raw strings: '\m' is not a valid Python escape sequence.
for r_full, r_nofb, c, l in zip(
    response_full_tuning_curve_all,
    response_nofb_tuning_curve_all,
    ['tab:blue', 'tab:green'],
    [r'TP$\mathregular{_{full}}$', 'AE']
):
    plt.plot(sizes, np.mean(r_full, axis=0),        c=c, label=l, linewidth=2.5)
    plt.plot(sizes, np.mean(r_nofb, axis=0), '--', c=c, label=f'{l} (no feedback)', linewidth=2.5)

plt.xticks([0, 10, 20, 30, 40])
plt.xlabel('Stimulus size (pixels)')
plt.ylabel('Model response')
plt.ylim(0, None)
format_plot(fontsize=20)
plt.gca().get_legend().set_bbox_to_anchor((1, 1))
fig.set_size_inches(4, 4)
# NOTE(review): same output path as the previous cell — this figure overwrites it.
plt.savefig('./figures/model_comparison/C_i.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Mean size-tuning curve of each group of the feed-forward comparison model.
# NOTE(review): `response_tuning_curve_ff_all` is not defined anywhere in this
# notebook, so this cell fails on a fresh kernel. It presumably came from a
# deleted cell that evaluated the TP_FF model. TODO: restore that cell.
FF_LINESTYLES = ['solid', 'dashed', 'dotted']
FF_ALPHAS     = [1, 0.75, 0.5]

fig = plt.figure()

for group_idx, group_curves in enumerate(response_tuning_curve_ff_all):
    group_label = r'$\mathregular{G_{' + str(group_idx) + '}}$'
    plt.plot(
        sizes, np.mean(group_curves, axis=0),
        linestyle=FF_LINESTYLES[group_idx], label=group_label, c='tab:red',
        linewidth=2.5, alpha=FF_ALPHAS[group_idx]
    )

plt.xticks([0, 10, 20, 30, 40])
plt.xlabel('Stimulus size (pixels)')
plt.ylabel('Model response')
format_plot(fontsize=20)
legend = plt.gca().get_legend()
legend.set_bbox_to_anchor((1, 1))
legend.remove()
fig.set_size_inches(3, 4)
plt.savefig('./figures/model_comparison/C_ii.pdf', bbox_inches='tight')
plt.show()

In [None]:
x = np.array([0, 1, 2, 3])

# Mouse V1 suppression indices with feedback suppressed (the same values appear
# again as `v1_supp` in the mouse-data section).
v1_supp = [20.04, 27.7, 67.26, 49.24, 84.64, 86.48, 81.56, 81.57, 79.21, 76.66, 75.92, 73.73, 68.79, 65.33, 63.87, 56.77, 55.85, 52.93, 31.74, 34.85, 35.03, 35.23, 41.06, 41.8, 44.72, 39.27, 37.45, 35.99, 35.62, 45.48, 35.13, 33.45, 29.62, 28.71, 26.89, 24.69, 24.69, 16.15, 13.23, 12.88, 7.57, 12.48, 13.57, 17.93, 10.64, 14.46, 26.12, 29.21, 22.46, 21.9, 17.16, 15.52, 8.25, 9.53, 7.71, 8.08, 8.81]

# NOTE(review): `suppression_indices_ff_all` is not defined in this notebook;
# it presumably came from a deleted TP_FF cell. TODO: restore it.
most_suppressed_ff_layer = np.argmax([np.mean(d) for d in suppression_indices_ff_all])

all_dists = [suppression_indices_all[0], suppression_indices_ff_all[most_suppressed_ff_layer], suppression_indices_all[1], v1_supp]
means     = [np.mean(d) for d in all_dists]
colors    = ['tab:blue', 'tab:red', 'tab:green', 'tab:gray']
# Raw strings: '\m' is not a valid Python escape sequence.
labels    = [r'TP$\mathregular{_{full}}$', r'TP$\mathregular{_{FF}}$', 'AE', 'V1']

fig = plt.figure()
vp = plt.violinplot(all_dists, x, showmedians=False, showextrema=False)
plt.xticks(x, labels)
for i, pc in enumerate(vp["bodies"], 0):
    pc.set_facecolor(colors[i])
    pc.set_edgecolor('black')
    pc.set_alpha(0.8)
plt.hlines(means, x-0.5, x+0.5, linewidth=2, color='black', zorder=3, alpha=0.8)
plt.ylabel('Surround suppression (%)')
format_plot(fontsize=20)
fig.set_size_inches(3.25, 4)
plt.savefig('./figures/model_comparison/C_iii_mouse.pdf', bbox_inches='tight')
plt.show()

# Two-sample KS statistic of each model distribution against mouse V1.
ks_dist = [stats.ks_2samp(d, all_dists[3])[0] for d in all_dists[:3]]

x = [0, 1, 2]
fig = plt.figure()
b = plt.bar(x, ks_dist)
for b_i, c_i in zip(b, colors):
    b_i.set_facecolor(c_i)
plt.xticks(x, [r'$\mathregular{TP_{full}}$', r'$\mathregular{TP_{FF}}$', 'AE'], rotation=0)
plt.ylabel('KS distance')
format_plot(fontsize=20)
fig.set_size_inches(3, 4)
plt.savefig('./figures/model_comparison/C_iv_mouse.pdf', bbox_inches='tight')
plt.show()

# Sanity check: binned CDFs of each model vs. mouse V1; the title shows their
# maximum absolute difference (a binned estimate of the KS statistic).
cdf_bins = np.linspace(0, 100, 25)
d_m = np.histogram(all_dists[-1], bins=cdf_bins)[0]
d_m = np.cumsum(d_m/d_m.sum())
for d in all_dists[:3]:
    print(np.mean(d), np.median(d))

    d_h = np.histogram(d, bins=cdf_bins)[0]
    d_h = np.cumsum(d_h/d_h.sum())
    plt.plot(d_h)
    plt.plot(d_m)

    plt.title(np.max(np.abs(d_h-d_m)))

    plt.show()
    

# Comparison data

## Macaque data

In [None]:
# Macaque per-cell suppression indices, one cell per line as "control, cooled".
# Parsed below into `macaque_control` (column 0) and `macaque_cooled` (column 1)
# and plotted as a suppression-index histogram. NOTE(review): values look
# digitised from a published figure — presumably the JNeurosci paper linked in
# the next cell; confirm the source.
raw_macaque_data = """33.5195530726257, 5.526599599818269
31.843575418994412, 22.280134133696748
33.14711359404096, 38.486106342914795
40.4096834264432, 31.809244405605313
44.1340782122905, 34.61641161151162
60.89385474860335, 34.12017241797834
62.94227188081937, 22.76839743523054
41.340782122905026, 63.84250734302229
50.279329608938546, 77.8422784695997
54.37616387337058, 74.50558139051006
51.769087523277456, 55.315238461831875
54.00372439478584, 53.088924260236
62.94227188081937, 55.170632072102045
64.80446927374302, 45.86658066574422
63.68715083798882, 42.32424428423303
72.43947858472998, 51.85404811196763
79.70204841713223, 46.853163828289425
90.68901303538175, 45.9629849255641
66.10800744878958, 76.97013201835146
70.76350093109869, 81.64296439631164
73.9292364990689, 64.8949783090415
76.72253258845436, 72.5403909574191
77.0949720670391, 69.74848197968572
76.16387337057728, 67.32415758975478
85.84729981378024, 57.118136831628895
84.54376163873368, 64.18963203395647
91.06145251396649, 64.02768674857558
87.52327746741153, 67.18024475585099
84.54376163873368, 69.21756499485032
89.19925512104282, 68.49002493333191
98.13780260707634, 70.57173274519798
88.26815642458101, 74.44558881155739
81.9366852886406, 75.72554608851851
83.4264432029795, 78.33817088521995
84.72998137802608, 79.64656395104882
83.61266294227187, 77.96642496246128
96.64804469273741, 79.50473178462315
95.90316573556797, 81.36415495424262
96.46182495344507, 81.92489483959784
99.06890130353818, 83.7968020140861
99.81378026070763, 88.08263024111466
91.80633147113593, 83.9559730761628
97.20670391061452, 89.00401915601188
97.39292364990689, 90.30825088688448
93.6685288640596, 89.17706133460945
86.59217877094972, 85.98497064524963
86.96461824953445, 85.42769853902463
82.12290502793294, 85.78210556613226
81.00558659217879, 88.01258110268438
86.59217877094972, 89.8955851703893
90.87523277467409, 94.38081069740504
97.20670391061452, 94.2181718561981
99.81378026070763, 94.04166189847034
89.57169459962756, 100.52120720327079
95.90316573556797, 100.1723486227715
99.62756052141526, 96.6480446927374
100.37243947858474, 97.20947813391867
99.62756052141526, 97.7653631284916
99.4413407821229, 98.88198800841974
99.62756052141526, 99.44134078212288
99.62756052141526, 100.18621973929234
99.81378026070763, 99.81447381653365
99.4413407821229, 98.32332879054265
99.81378026070763, 97.39361720573291
62.756052141527, 63.549826784432426
67.78398510242084, 50.71939078056239"""

In [None]:
# https://www.jneurosci.org/content/jneuro/33/19/8504.full.pdf
# Macaque size tuning, control vs. cooled, with an R() fit for each condition.

macaque_x = np.array([0.25, 0.45, 1, 1.8, 4, 7.15, 16])
macaque_x_vals = np.linspace(0, 16, 100)
macaque_cooled_normed_response  = np.array([0.109, 0.502, 1.06, 0.816, 0.400, 0.261, 0.133])
macaque_control_normed_response = np.array([0.093, 0.453, 0.975, 0.557, 0.202, 0.125, 0.063])

fig = plt.figure(dpi=150)

init_params = [1.75, 12, 0.5, 15]
for response, alpha, label in (
    (macaque_control_normed_response, None, 'Control'),
    (macaque_cooled_normed_response, 0.375, 'Cooled'),
):
    fitted_params, _ = curve_fit(R, macaque_x, response, p0=init_params, maxfev=10000)
    plt.scatter(macaque_x, response, c='black', alpha=alpha)
    # Draw the fitted curve, not the initial guess.
    plt.plot(
        macaque_x_vals, R(macaque_x_vals, *fitted_params),
        c='black', alpha=alpha, label=label
    )

plt.xlabel('Stimulus size (degrees)')
plt.ylabel('Spike rate (norm.)')
format_plot(fontsize=20)
fig.set_size_inches(2.5, 4)
plt.gca().get_legend().set_bbox_to_anchor((1, 1))
plt.gca().get_legend().remove()
plt.savefig('./figures/surround_suppression/A_iii.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Split each "control, cooled" row of raw_macaque_data into two float columns.
macaque_rows = [line.split(', ') for line in raw_macaque_data.split('\n')]
macaque_control = [float(row[0]) for row in macaque_rows]
macaque_cooled = [float(row[1]) for row in macaque_rows]

# Weight each condition so its histogram sums to 1.
fig = plt.figure()
w_full, w_nofb = (np.ones_like(values)/len(values) for values in (macaque_control, macaque_cooled))
plt.hist(
    [macaque_control, macaque_cooled],
    weights=[w_full, w_nofb],
    label=['Control', 'Feedback\nsuppressed'],
    color=['black', 'tab:gray']
)
plt.xlabel('Suppression index')
plt.ylabel('Proportion of cells')
plt.xticks([0, 100])
format_plot(fontsize=20)
fig.set_size_inches(2.5, 4)
legend = plt.gca().get_legend()
legend.set_bbox_to_anchor((1, 1))
legend.remove()
plt.savefig('./figures/surround_suppression/B_iii.pdf', bbox_inches='tight')
plt.show()

## Mouse data

In [None]:
# Mouse V1 size tuning: laser off (control) vs. laser on (feedback suppressed),
# with an R() fit for each condition.
xvals_v1 = np.linspace(0, 110, 100)
mouse_x = [10, 20, 30, 40, 60, 80, 110]
mouse_v1_laseron_y = [0.3113, 0.7660, 0.7375, 0.7692, 0.6255, 0.5614, 0.5769]
mouse_v1_laseroff_y = [0.3184, 0.7642, 0.7251, 0.6842, 0.5865, 0.4747, 0.4317]

fig = plt.figure(dpi=150)

for y_vals, init_params, alpha, label in (
    (mouse_v1_laseroff_y, [2, 1, 10, 10], None, 'Control'),
    (mouse_v1_laseron_y, [2, 2, 10, 10], 0.375, 'Feedback\nsuppressed'),
):
    fitted_params, _ = curve_fit(R, mouse_x, y_vals, p0=init_params, maxfev=10000)
    plt.scatter(mouse_x, y_vals, c='black', alpha=alpha)
    plt.plot(xvals_v1, R(xvals_v1, *fitted_params), c='black', alpha=alpha, label=label)

plt.xlabel('Stimulus size (degrees)')
plt.ylabel('Spike rate (norm.)')
format_plot(fontsize=20)
fig.set_size_inches(2.5, 4)
legend = plt.gca().get_legend()
legend.set_bbox_to_anchor((1, 1))
legend.remove()
plt.savefig('./figures/surround_suppression/A_ii.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Mouse V1 suppression indices: control vs. feedback suppressed.
v1_full = [1.09, 4.56, 28.47, 30.84, 68.98, 76.28, 78.1, 84.12, 86.31, 83.76, 81.93, 78.28, 67.52, 63.87, 64.05, 66.79, 64.05, 56.2, 26.09, 30.47, 31.2, 40.15, 42.88, 49.82, 53.65, 56.93, 58.58, 58.76, 57.3, 71.17, 87.23, 65.33, 62.23, 63.32, 62.77, 57.12, 52.92, 64.78, 65.15, 72.08, 56.39, 51.64, 47.81, 46.17, 41.79, 38.5, 39.6, 37.23, 29.93, 24.09, 20.99, 19.71, 27.01, 30.66, 33.39, 35.95, 39.05]
v1_supp = [20.04, 27.7, 67.26, 49.24, 84.64, 86.48, 81.56, 81.57, 79.21, 76.66, 75.92, 73.73, 68.79, 65.33, 63.87, 56.77, 55.85, 52.93, 31.74, 34.85, 35.03, 35.23, 41.06, 41.8, 44.72, 39.27, 37.45, 35.99, 35.62, 45.48, 35.13, 33.45, 29.62, 28.71, 26.89, 24.69, 24.69, 16.15, 13.23, 12.88, 7.57, 12.48, 13.57, 17.93, 10.64, 14.46, 26.12, 29.21, 22.46, 21.9, 17.16, 15.52, 8.25, 9.53, 7.71, 8.08, 8.81]

# Weight each condition so its histogram sums to 1.
fig = plt.figure()
w_full, w_nofb = (np.ones_like(values)/len(values) for values in (v1_full, v1_supp))
plt.hist(
    [v1_full, v1_supp],
    weights=[w_full, w_nofb],
    label=['Control', 'Feedback\nsuppressed'],
    color=['black', 'tab:gray']
)
plt.xlabel('Suppression index')
plt.ylabel('Proportion of cells')
plt.xticks([0, 100])
format_plot(fontsize=20)
fig.set_size_inches(2.5, 4)
legend = plt.gca().get_legend()
legend.set_bbox_to_anchor((1, 1))
legend.remove()
plt.savefig('./figures/surround_suppression/B_ii.pdf', bbox_inches='tight')
plt.show()
