In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline

In [3]:
import os
import torch
import numpy as np
import pickle 

# import nnfabrik
from nnfabrik import builder

# import nnvision
from numpy import linalg
import matplotlib.pyplot as plt
import seaborn as sns
from neuralpredictors.data.datasets import FileTreeDataset
# from neuralpredictors.data.transforms import Subsample, ToTensor, NeuroNormalizer, AddBehaviorAsChannels
from collections import OrderedDict
from cotton2020 import get_oracles, get_correlations

# Training Configuration

Configuration of the model, the trainer and the dataset used for training. 

In [8]:
# Recording session identifier and the preprocessed dataset directory derived from it.
scan = '20457-5-9'
paths = [f'data/static{scan}-preproc0/']

# Dotted name of the model builder and the keyword configuration handed to it
# through builder.get_model.
model_fn = 'model_components.se_core_spatialXfeature_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=15,
    gamma_input=1.0,  # other values left in the original comment: 20, 6.3831
    gamma_readout=0.002362354239446914,  # other value left in the original comment: 0.0076
    hidden_dilation=1,
    hidden_kern=13,
    hidden_channels=64,
    n_se_blocks=0,
    depth_separable=True,
    normalize=False,
    init_noise=4.1232e-05,
)

# Dotted name of the trainer and its settings, consumed by builder.get_trainer.
trainer_fn = 'cotton2020.nnvision_trainer'
trainer_config = {
    'max_iter': 100,
    'verbose': False,
    'lr_decay_steps': 4,
    'avg_loss': False,
    'patience': 5,
    'lr_init': 0.0041,
}

# Key under which this scan appears inside each split of the dataloaders dict.
readout_key = f'{scan}-0/'

# Train Full Klindt et al. model

## Get DataLoader

In [9]:
# Unit ids used for training the full model, read from the configs folder.
unit_ids = np.load(f'configs/{scan}_train_units.npy')

dataset_fn = 'cotton2020.neurips_loaders'
dataset_config = {
    'paths': paths,
    'areas': ('V1',),
    'batch_size': 64,
    'normalize': True,
    # presumably one id array per entry in `paths` — confirm against the loader
    'neuron_ids': [unit_ids],
}

# Build the loaders; the bare name on the last line displays them as the cell output.
dataloaders = builder.get_data(dataset_fn, dataset_config)
dataloaders

OrderedDict([('train',
              OrderedDict([('20457-5-9-0/',
                            <torch.utils.data.dataloader.DataLoader at 0x7f2fb903fa00>)])),
             ('validation',
              OrderedDict([('20457-5-9-0/',
                            <torch.utils.data.dataloader.DataLoader at 0x7f2fa5dd1700>)])),
             ('test',
              OrderedDict([('20457-5-9-0/',
                            <torch.utils.data.dataloader.DataLoader at 0x7f2fa5dd15b0>)]))])

## Define Model

In [10]:
# Instantiate the network for the loaders built above, with a fixed seed;
# the bare name on the last line displays the module tree as the cell output.
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)
model

Encoder(
  (core): SE2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv): Conv2d(1, 64, kernel_size=(15, 15), stride=(1, 1), bias=False)
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0, inplace=True)
      )
      (layer1): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (spatial_conv): Conv2d(64, 64, kernel_size=(13, 13), stride=(1, 1), padding=(6, 6), groups=64, bias=False)
          (out_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0, inplace=True)
      )
      (layer2): Sequential(
        (ds_conv): DepthSeparableConv2d(
        

## Get Trainer

In [11]:
# Build the trainer callable from its dotted name and settings (both defined in the configuration cell).
trainer = builder.get_trainer(trainer_fn, trainer_config)

## Run Training

In [13]:
# Train the full model on the current loaders with a fixed seed; the trainer's
# return value is unpacked into three names.
score, output, model_state = trainer(model=model, dataloaders=dataloaders, seed=1000)

Epoch 1: 100%|██████████| 70/70 [00:12<00:00,  5.71it/s]


[001|00/05] ---> 0.06483694911003113


Epoch 2: 100%|██████████| 70/70 [00:07<00:00,  9.55it/s]


[002|00/05] ---> 0.09954481571912766


Epoch 3: 100%|██████████| 70/70 [00:07<00:00,  9.59it/s]


[003|00/05] ---> 0.11677460372447968


Epoch 4: 100%|██████████| 70/70 [00:07<00:00,  9.58it/s]


[004|00/05] ---> 0.1514599472284317


Epoch 5: 100%|██████████| 70/70 [00:07<00:00,  9.46it/s]


[005|00/05] ---> 0.17109455168247223


Epoch 6: 100%|██████████| 70/70 [00:07<00:00,  9.54it/s]


[006|00/05] ---> 0.1721285581588745


Epoch 7: 100%|██████████| 70/70 [00:07<00:00,  9.51it/s]


[007|00/05] ---> 0.19527208805084229


Epoch 8: 100%|██████████| 70/70 [00:07<00:00,  9.57it/s]


[008|00/05] ---> 0.20207111537456512


Epoch 9: 100%|██████████| 70/70 [00:07<00:00,  9.56it/s]


[009|01/05] -/-> 0.2010180652141571


Epoch 10: 100%|██████████| 70/70 [00:07<00:00,  9.49it/s]


[010|01/05] ---> 0.20870552957057953


Epoch 11: 100%|██████████| 70/70 [00:07<00:00,  9.56it/s]


[011|00/05] ---> 0.21070457994937897


Epoch 12: 100%|██████████| 70/70 [00:07<00:00,  9.61it/s]


[012|00/05] ---> 0.21426929533481598


Epoch 13: 100%|██████████| 70/70 [00:07<00:00,  9.44it/s]


[013|00/05] ---> 0.21739770472049713


Epoch 14: 100%|██████████| 70/70 [00:07<00:00,  9.40it/s]


[014|00/05] ---> 0.22558771073818207


Epoch 15: 100%|██████████| 70/70 [00:07<00:00,  9.54it/s]


[015|00/05] ---> 0.23083560168743134


Epoch 16: 100%|██████████| 70/70 [00:07<00:00,  9.50it/s]


[016|01/05] -/-> 0.2300853282213211


Epoch 17: 100%|██████████| 70/70 [00:07<00:00,  9.49it/s]


[017|01/05] ---> 0.2366214096546173


Epoch 18: 100%|██████████| 70/70 [00:07<00:00,  9.44it/s]


[018|00/05] ---> 0.24023956060409546


Epoch 19: 100%|██████████| 70/70 [00:07<00:00,  9.63it/s]


[019|00/05] ---> 0.24157123267650604


Epoch 20: 100%|██████████| 70/70 [00:07<00:00,  9.45it/s]


[020|01/05] -/-> 0.24120983481407166


Epoch 21: 100%|██████████| 70/70 [00:07<00:00,  9.47it/s]


[021|01/05] ---> 0.24260498583316803


Epoch 22: 100%|██████████| 70/70 [00:07<00:00,  9.43it/s]


[022|01/05] -/-> 0.23708751797676086


Epoch 23: 100%|██████████| 70/70 [00:07<00:00,  9.42it/s]


[023|01/05] ---> 0.2494056671857834


Epoch 24: 100%|██████████| 70/70 [00:07<00:00,  9.37it/s]


[024|00/05] ---> 0.25135719776153564


Epoch 25: 100%|██████████| 70/70 [00:07<00:00,  9.54it/s]


[025|01/05] -/-> 0.2490425556898117


Epoch 26: 100%|██████████| 70/70 [00:07<00:00,  9.42it/s]


[026|02/05] -/-> 0.25037428736686707


Epoch 27: 100%|██████████| 70/70 [00:07<00:00,  9.44it/s]


[027|02/05] ---> 0.25214189291000366


Epoch 28: 100%|██████████| 70/70 [00:07<00:00,  9.35it/s]


[028|00/05] ---> 0.2539500594139099


Epoch 29: 100%|██████████| 70/70 [00:07<00:00,  9.40it/s]


[029|01/05] -/-> 0.2494632601737976


Epoch 30: 100%|██████████| 70/70 [00:07<00:00,  9.34it/s]


[030|01/05] ---> 0.2581081688404083


Epoch 31: 100%|██████████| 70/70 [00:07<00:00,  9.41it/s]


[031|00/05] ---> 0.2581847012042999


Epoch 32: 100%|██████████| 70/70 [00:07<00:00,  9.61it/s]


[032|01/05] -/-> 0.2568398714065552


Epoch 33: 100%|██████████| 70/70 [00:07<00:00,  9.63it/s]


[033|01/05] ---> 0.2582002580165863


Epoch 34: 100%|██████████| 70/70 [00:07<00:00,  9.38it/s]


[034|01/05] -/-> 0.2567746043205261


Epoch 35: 100%|██████████| 70/70 [00:07<00:00,  9.43it/s]


[035|02/05] -/-> 0.25109171867370605


Epoch 36: 100%|██████████| 70/70 [00:07<00:00,  9.34it/s]


[036|02/05] ---> 0.2582603991031647


Epoch 37: 100%|██████████| 70/70 [00:07<00:00,  9.36it/s]


[037|01/05] -/-> 0.2561570703983307


Epoch 38: 100%|██████████| 70/70 [00:07<00:00,  9.29it/s]


[038|02/05] -/-> 0.25739333033561707


Epoch 39: 100%|██████████| 70/70 [00:07<00:00,  9.36it/s]


[039|02/05] ---> 0.2602827548980713


Epoch 40: 100%|██████████| 70/70 [00:07<00:00,  9.33it/s]


[040|01/05] -/-> 0.2558630406856537


Epoch 41: 100%|██████████| 70/70 [00:07<00:00,  9.31it/s]


[041|01/05] ---> 0.2617564797401428


Epoch 42: 100%|██████████| 70/70 [00:07<00:00,  9.42it/s]


[042|01/05] -/-> 0.25816231966018677


Epoch 43: 100%|██████████| 70/70 [00:07<00:00,  9.30it/s]


[043|02/05] -/-> 0.2582595646381378


Epoch 44: 100%|██████████| 70/70 [00:07<00:00,  9.38it/s]


[044|03/05] -/-> 0.25726228952407837


Epoch 45: 100%|██████████| 70/70 [00:07<00:00,  9.50it/s]


[045|04/05] -/-> 0.25653913617134094


Epoch 46: 100%|██████████| 70/70 [00:07<00:00,  9.40it/s]


[046|05/05] -/-> 0.25719794631004333


Epoch 47:   3%|▎         | 2/70 [00:00<00:04, 14.43it/s]

Restoring best model after lr decay! 0.257198 ---> 0.261756


Epoch 47: 100%|██████████| 70/70 [00:07<00:00,  9.59it/s]


[047|01/05] -/-> 0.257427841424942


Epoch 48: 100%|██████████| 70/70 [00:07<00:00,  9.48it/s]


[048|01/05] ---> 0.26940762996673584


Epoch 49: 100%|██████████| 70/70 [00:07<00:00,  9.40it/s]


[049|01/05] -/-> 0.26811179518699646


Epoch 50: 100%|██████████| 70/70 [00:07<00:00,  9.48it/s]


[050|01/05] ---> 0.2696393132209778


Epoch 51: 100%|██████████| 70/70 [00:07<00:00,  9.28it/s]


[051|01/05] -/-> 0.2678483724594116


Epoch 52: 100%|██████████| 70/70 [00:07<00:00,  9.36it/s]


[052|02/05] -/-> 0.26898083090782166


Epoch 53: 100%|██████████| 70/70 [00:07<00:00,  9.33it/s]


[053|03/05] -/-> 0.2667967975139618


Epoch 54: 100%|██████████| 70/70 [00:07<00:00,  9.33it/s]


[054|04/05] -/-> 0.26896795630455017


Epoch 55: 100%|██████████| 70/70 [00:07<00:00,  9.41it/s]


[055|05/05] -/-> 0.26902708411216736


Epoch 56:   3%|▎         | 2/70 [00:00<00:04, 14.82it/s]

Restoring best model after lr decay! 0.269027 ---> 0.269639


Epoch 56: 100%|██████████| 70/70 [00:07<00:00,  9.45it/s]


[056|01/05] -/-> 0.26879826188087463


Epoch 57: 100%|██████████| 70/70 [00:07<00:00,  9.35it/s]


[057|02/05] -/-> 0.2693394422531128


Epoch 58: 100%|██████████| 70/70 [00:07<00:00,  9.56it/s]


[058|03/05] -/-> 0.2695113718509674


Epoch 59: 100%|██████████| 70/70 [00:07<00:00,  9.42it/s]


[059|03/05] ---> 0.2713336646556854


Epoch 60: 100%|██████████| 70/70 [00:07<00:00,  9.34it/s]


[060|01/05] -/-> 0.27051350474357605


Epoch 61: 100%|██████████| 70/70 [00:07<00:00,  9.26it/s]


[061|02/05] -/-> 0.26925569772720337


Epoch 62: 100%|██████████| 70/70 [00:07<00:00,  9.37it/s]


[062|03/05] -/-> 0.2696235477924347


Epoch 63: 100%|██████████| 70/70 [00:07<00:00,  9.31it/s]


[063|04/05] -/-> 0.26621273159980774


Epoch 64: 100%|██████████| 70/70 [00:07<00:00,  9.33it/s]


[064|05/05] -/-> 0.2676547169685364


Epoch 65:   3%|▎         | 2/70 [00:00<00:04, 14.49it/s]

Restoring best model after lr decay! 0.267655 ---> 0.271334


Epoch 65: 100%|██████████| 70/70 [00:07<00:00,  9.62it/s]


[065|01/05] -/-> 0.2681646943092346


Epoch 66: 100%|██████████| 70/70 [00:07<00:00,  9.34it/s]


[066|02/05] -/-> 0.26834338903427124


Epoch 67: 100%|██████████| 70/70 [00:07<00:00,  9.42it/s]


[067|03/05] -/-> 0.2673272490501404


Epoch 68: 100%|██████████| 70/70 [00:07<00:00,  9.28it/s]


[068|04/05] -/-> 0.2684917151927948


Epoch 69: 100%|██████████| 70/70 [00:07<00:00,  9.35it/s]


[069|05/05] -/-> 0.26971253752708435
Restoring best model after lr decay! 0.269713 ---> 0.271334
Restoring best model! 0.271334 ---> 0.271334


In [14]:
# Pickle the trained weights to disk; the readout-only section reloads this
# file to copy the core into a fresh model.
state_dict = model.state_dict()
with open(f'models/{scan}_full_model_statedict.pkl', 'wb') as fid:
    pickle.dump(state_dict, fid)

# Train Readout only

In [15]:
# Accumulates evaluation results of the readout-only runs, keyed by data_size.
total_results = dict()

In [76]:
# Number of training trials kept for the readout-only run; also used in the
# saved model's filename and as the key into total_results.
data_size = 1000


In [77]:
# Switch to the units listed in the *_test_units file and load the stored
# list of training trials.
unit_ids = np.load(f'configs/{scan}_test_units.npy')
train_trials = np.load(f'configs/{scan}_train_trials.npy')

dataset_fn = 'cotton2020.neurips_loaders'
dataset_config = {
    'paths': paths,
    'batch_size': 64,
    'normalize': True,
    'neuron_ids': [unit_ids],
    # Keep only the first `data_size` training trials; for sizes above 1000
    # None is passed instead of a trial selection.
    'trial_ids': {'train': train_trials[:data_size]} if data_size <= 1000 else None,
}

# Rebuild the loaders; the bare name on the last line displays them as the cell output.
dataloaders = builder.get_data(dataset_fn, dataset_config)
dataloaders

Subsampling train set to 1000 trials


OrderedDict([('train',
              OrderedDict([('20457-5-9-0/',
                            <torch.utils.data.dataloader.DataLoader at 0x7f2fdc24d250>)])),
             ('validation',
              OrderedDict([('20457-5-9-0/',
                            <torch.utils.data.dataloader.DataLoader at 0x7f2f3c01a280>)])),
             ('test',
              OrderedDict([('20457-5-9-0/',
                            <torch.utils.data.dataloader.DataLoader at 0x7f2f34465610>)]))])

In [78]:
# Fresh model instance for the new loaders, built with the same builder,
# config and seed as the full model; the bare name displays the module tree.
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)
model

Encoder(
  (core): SE2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv): Conv2d(1, 64, kernel_size=(15, 15), stride=(1, 1), bias=False)
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0, inplace=True)
      )
      (layer1): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (spatial_conv): Conv2d(64, 64, kernel_size=(13, 13), stride=(1, 1), padding=(6, 6), groups=64, bias=False)
          (out_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0, inplace=True)
      )
      (layer2): Sequential(
        (ds_conv): DepthSeparableConv2d(
        

## Load old model, copy core, and freeze it

In [79]:
# Read back the full-model weights that the save cell pickled earlier.
# pickle.load can execute arbitrary code, so only point this at trusted files.
with open(f'models/{scan}_full_model_statedict.pkl', 'rb') as fid:
    state_dict = pickle.load(fid)

# Keep the entries under the 'core.' prefix and strip that prefix so the
# keys line up with model.core's own state dict.
core_state_dict = OrderedDict(
    (k[len('core.'):], v)
    for k, v in state_dict.items()
    if k.startswith('core.')
)

# Reference copy of the loaded core weights, compared against the model after training.
old_state = OrderedDict(core_state_dict)
model.core.load_state_dict(core_state_dict)

# Exclude every core parameter from gradient updates.
for param in model.core.parameters():
    param.requires_grad = False

## Train

In [80]:
# Build a new trainer callable with the same name and settings as for the full model.
trainer = builder.get_trainer(trainer_fn, trainer_config)

In [None]:
# Train the model whose core parameters were frozen above, on the subsampled
# loaders; the trainer's return value is unpacked into three names.
score, output, model_state = trainer(model=model, dataloaders=dataloaders, seed=1000)

Epoch 1:  62%|██████▎   | 10/16 [00:00<00:00, 10.66it/s]

## Make sure core was actually frozen

In [None]:
# Compare every core tensor with the snapshot taken before training. The
# BatchNorm running statistics and batch counters are skipped.
for k, v in model.core.state_dict().items():
    if any(tag in k for tag in ('running_mean', 'running_var', 'num_batches_tracked')):
        continue
    max_diff = (v - old_state[k]).abs().max()
    print(f'Checking {k}', max_diff)
    assert max_diff < 1e-6
    

## Save model

In [None]:
# Pickle the weights of the readout-only model; data_size is embedded in the
# filename so runs with different training-set sizes do not overwrite each other.
state_dict = model.state_dict()
with open(f'models/{scan}_datasize{data_size}readout_model_statedict.pkl', 'wb') as fid:
    pickle.dump(state_dict, fid)

## Evaluate model

In [None]:
# Correlations of the model on the test loaders (presumably one value per unit — confirm),
# and oracle correlations from the same loaders, returned as a dict that is
# later indexed by readout_key.
corrs = get_correlations(model, dataloaders['test'])
oracles = get_oracles(dataloaders["test"], as_dict=True)

In [None]:
# Mean of the test correlations, displayed as the cell output.
corrs.mean()

In [None]:
# Slope of a least-squares fit without intercept of test correlation on
# oracle correlation, reported as "percent oracle".
# - rcond=None selects NumPy's current default cutoff explicitly; NumPy 1.14-1.x
#   emits a FutureWarning when the argument is omitted.
# - Taking element [0] of the result avoids unpacking into `_`, which would
#   shadow IPython's last-output variable for the rest of the session.
# - np.ravel keeps the formatting valid whether the solution is 1-D or a column.
a = np.linalg.lstsq(oracles[readout_key][:, None], corrs, rcond=None)[0]
print(f'Percent oracle is {np.ravel(a)[0]*100:.2f}%')

In [None]:
# Record this run under its training-set size; `a` is the least-squares
# solution computed in the percent-oracle cell.
total_results[data_size] = {
    'unit_ids': dataloaders['train'][readout_key].dataset.neurons.unit_ids,
    'test_correlation': corrs.squeeze(),
    'oracle_correlation': oracles[readout_key],
    'mean_test_correlation': corrs.mean(),
    'fraction_oracle': a,
}
# NOTE(review): the dict is passed positionally, so NumPy stores it as a pickled
# object array under the name 'arr_0' and loading requires allow_pickle=True.
# The filename spelling ("efficieny") is kept as-is because other code may read it.
np.savez(f'results/{scan}-data_efficieny.npz', total_results)

In [None]:
# Mean test correlation for each training-set size recorded so far.
{size: result['mean_test_correlation'] for size, result in total_results.items()}

Values from Cotton et al. (2020):