# Demo notebook: how to run models on static mouse datasets

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2 

In [2]:
import os
import datajoint as dj
# Connection settings are read from the environment so no credentials live in the notebook;
# a missing variable raises KeyError right here.
for config_key, env_var in (('database.host', 'DJ_HOST'),
                            ('database.user', 'DJ_USER'),
                            ('database.password', 'DJ_PASS')):
    dj.config[config_key] = os.environ[env_var]
dj.config['enable_python_native_blobs'] = True

# Suffix of the DataJoint schema this notebook works in.
name = "test"
dj.config['schema_name'] = f"konstantin_nnsysident_{name}"

In [3]:
import pickle
from collections import OrderedDict
from collections.abc import Iterable

import numpy as np
import pandas as pd
import torch

import nnfabrik
from nnfabrik.main import *
from nnfabrik import builder

from nnsysident.tables.experiments import *
from nnsysident.datasets.mouse_loaders import static_shared_loaders
from nnsysident.datasets.mouse_loaders import static_loaders
from nnsysident.datasets.mouse_loaders import static_loader

  from collections import OrderedDict, Iterable


Connecting konstantin@134.2.168.16:3306


# Get Dataloaders

In [4]:
# Datasets loaded together -- change these paths to point at your local copies.
# (For a single-dataset run use just 'data/static22564-2-12-preproc0.zip'.)
paths = [
    'data/static22564-2-12-preproc0.zip',
    'data/static22564-2-13-preproc0.zip',
    'data/static22564-3-8-preproc0.zip',
    'data/static22564-3-12-preproc0.zip',
]

dataset_fn = 'nnsysident.datasets.mouse_loaders.static_shared_loaders'
# The multi_match_* options are presumably the number / sampling seed of neurons matched across
# the datasets (static_shared_loaders) -- confirm in nnsysident.datasets.mouse_loaders.
# Other options tried before: image_n=50, image_base_seed=1.
dataset_config = {
    'paths': paths,
    'batch_size': 64,
    'seed': 1,
    'multi_match_n': 972,
    'multi_match_base_seed': 1,
    'exclude_multi_match_n': 3625,
}

# Build the dataloaders from the config above; the result is indexed by split
# (dataloaders['test'] is used for the evaluation further down).
dataloaders = builder.get_data(dataset_fn, dataset_config)

data/static22564-2-12-preproc0 exists already. Not unpacking data/static22564-2-12-preproc0.zip
data/static22564-2-13-preproc0 exists already. Not unpacking data/static22564-2-13-preproc0.zip
data/static22564-3-8-preproc0 exists already. Not unpacking data/static22564-3-8-preproc0.zip
data/static22564-3-12-preproc0 exists already. Not unpacking data/static22564-3-12-preproc0.zip
data/static22564-2-12-preproc0 exists already. Not unpacking data/static22564-2-12-preproc0.zip
data/static22564-2-13-preproc0 exists already. Not unpacking data/static22564-2-13-preproc0.zip
data/static22564-3-8-preproc0 exists already. Not unpacking data/static22564-3-8-preproc0.zip
data/static22564-3-12-preproc0 exists already. Not unpacking data/static22564-3-12-preproc0.zip


# Get Model

### The new Gaussian readout: change `gauss_type` for the different modes

In [5]:
model_fn = 'nnsysident.models.models.se2d_fullgaussian2d'

# Hyperparameters passed to se2d_fullgaussian2d.
model_config = dict(
    share_features=True,
    share_transform=True,
    init_mu_range=0.55,
    init_sigma=0.4,
    input_kern=15,
    hidden_kern=13,
    gamma_input=1.0,
    # Presumably the readout regularisation weight; values tried before:
    # 2.117604964706911, 626.6499356203459
    gamma_readout=0.333593,
    # Presumably predicts readout grid positions from 2-D ('cortex') neuron coordinates with
    # no hidden layers -- confirm in nnsysident.models.
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
)


# Three identically built copies (same config, same seed):
#   model       -- trained below,
#   model_real  -- receives the complete checkpoint ('"True" position' in the plot),
#   model_start -- receives the checkpoint minus the dropped readout entries ('Init position').
model, model_real, model_start = (
    builder.get_model(model_fn=model_fn, model_config=model_config, dataloaders=dataloaders, seed=1)
    for _ in range(3)
)

  all_multi_unit_ids = set(np.hstack(shared_match_ids.values()))


In [None]:
# Load a saved state dict; its readout entries are filtered further down in this cell.
# (An earlier variant kept only the 'core.' entries of dfc6a9fbcca790d7a6b59ff787d96356.pth.tar.)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints. If it was saved
# from GPU tensors, loading on a CPU-only machine needs map_location.
core_dict = torch.load('21ee7ec25ed7fc959838386c61f5c91a.pth.tar')

def get_grids(my_model):
    """Collect the readout grid of every data key of ``my_model`` as numpy arrays.

    Args:
        my_model: model whose ``readout`` maps data keys to readouts exposing a ``grid`` tensor.

    Returns:
        dict mapping each data key to its ``readout.grid`` with size-1 dimensions squeezed out,
        detached from autograd and copied to the CPU as a numpy array.
    """
    return {
        data_key: readout.grid.squeeze().detach().cpu().numpy()
        for data_key, readout in my_model.readout.items()
    }

# Readout entries (the part of a key after "readout.<data_key>.") that are transferred from the
# checkpoint into `model_start` and `model`; every other readout entry is dropped.
# Candidates: 'scales', '_features', 'mu_transform.0.weight', 'mu_transform.0.bias'
list_of_load = []
# Readout parameters (same naming) that are frozen (requires_grad=False) in `model`.
list_of_detach = []


def readout_entry_name(key):
    """Return the part of a state-dict / parameter ``key`` after ``readout.<data_key>.``.

    Args:
        key: dotted key such as ``'readout.<data_key>.mu_transform.0.weight'``.

    Returns:
        The remaining dotted suffix (possibly ``''``), or None when ``key`` is not a readout key.
    """
    parts = key.split('.')
    if parts[0] != 'readout':
        return None
    return '.'.join(parts[2:])


# Reference ("true") positions: the complete checkpoint.
model_real.load_state_dict(core_dict, strict=False)
real_grids = get_grids(model_real)

# Decide which readout entries to keep. A local helper is used instead of a loop variable
# called `name`, so the schema suffix `name` from the config cell is no longer overwritten.
remove = []
keep = []
for key in core_dict.keys():
    entry = readout_entry_name(key)
    if entry is None:
        continue  # non-readout entries are always loaded
    if entry in list_of_load:
        keep.append(key)
    else:
        print('Not loading:    {}'.format(key))
        remove.append(key)
for key in keep:
    print('Loading:  {}'.format(key))

# core_dict is pruned in place; it is re-read from disk at the top of this cell, so
# re-running the cell starts from the full checkpoint again.
for k in remove:
    del core_dict[k]

# Initial positions: the checkpoint without the dropped readout entries.
model_start.load_state_dict(core_dict, strict=False)
start_grids = get_grids(model_start)

# The model to train starts from the same pruned state ...
model.load_state_dict(core_dict, strict=False)

# ... optionally with some readout parameters frozen.
for param_name, param in model.named_parameters():
    entry = readout_entry_name(param_name)
    if entry is not None and entry in list_of_detach:
        print('detaching:    {}'.format(param_name))
        param.requires_grad = False

# Get Trainer

In [6]:
# Trainer with `track_training` enabled (detach_core=True was also tried before).
trainer_config = {'track_training': True}
trainer_fn = 'nnsysident.training.trainers.standard_trainer'
trainer = builder.get_trainer(trainer_fn, trainer_config)

# Run Training

In [7]:
# Train `model`. Returns (score, output, model_state) -- presumably the best validation score
# and the corresponding state, judging by the "Restoring best model" log; confirm in standard_trainer.
score, output, model_state = trainer(model=model, dataloaders=dataloaders, seed=1)

correlation 0.0019739168
poisson_loss 2381225.8


Epoch 1: 100%|██████████| 280/280 [00:26<00:00, 10.42it/s]


[001|00/05] ---> 0.11457011848688126
correlation 0.11457012
poisson_loss 1215937.2


Epoch 2: 100%|██████████| 280/280 [00:10<00:00, 27.09it/s]


[002|00/05] ---> 0.16468699276447296
correlation 0.164687
poisson_loss 1177759.2


Epoch 3: 100%|██████████| 280/280 [00:09<00:00, 28.73it/s]


[003|00/05] ---> 0.22402040660381317
correlation 0.2240204
poisson_loss 1126747.0


Epoch 4: 100%|██████████| 280/280 [00:09<00:00, 28.48it/s]


[004|00/05] ---> 0.2636294960975647
correlation 0.2636295
poisson_loss 1091534.6


Epoch 5: 100%|██████████| 280/280 [00:09<00:00, 29.01it/s]


[005|00/05] ---> 0.28658995032310486
correlation 0.28658995
poisson_loss 1070227.4


Epoch 6: 100%|██████████| 280/280 [00:09<00:00, 28.47it/s]


[006|00/05] ---> 0.3041680157184601
correlation 0.30416802
poisson_loss 1055588.0


Epoch 7: 100%|██████████| 280/280 [00:09<00:00, 29.45it/s]


[007|00/05] ---> 0.3140873610973358
correlation 0.31408736
poisson_loss 1046370.75


Epoch 8: 100%|██████████| 280/280 [00:09<00:00, 29.22it/s]


[008|00/05] ---> 0.32155558466911316
correlation 0.32155558
poisson_loss 1038759.6


Epoch 9: 100%|██████████| 280/280 [00:09<00:00, 29.38it/s]


[009|00/05] ---> 0.3264850974082947
correlation 0.3264851
poisson_loss 1033640.9


Epoch 10: 100%|██████████| 280/280 [00:09<00:00, 28.65it/s]


[010|00/05] ---> 0.32948240637779236
correlation 0.3294824
poisson_loss 1030951.25


Epoch 11: 100%|██████████| 280/280 [00:09<00:00, 29.32it/s]


[011|00/05] ---> 0.33533450961112976
correlation 0.3353345
poisson_loss 1024743.5


Epoch 12: 100%|██████████| 280/280 [00:09<00:00, 29.42it/s]


[012|00/05] ---> 0.33724096417427063
correlation 0.33724096
poisson_loss 1023318.6


Epoch 13: 100%|██████████| 280/280 [00:09<00:00, 28.96it/s]


[013|00/05] ---> 0.34007319808006287
correlation 0.3400732
poisson_loss 1021328.5


Epoch 14: 100%|██████████| 280/280 [00:09<00:00, 28.89it/s]


[014|00/05] ---> 0.3440011143684387
correlation 0.3440011
poisson_loss 1016834.4


Epoch 15: 100%|██████████| 280/280 [00:09<00:00, 29.31it/s]


[015|01/05] -/-> 0.34258317947387695
correlation 0.34258318
poisson_loss 1017582.1


Epoch 16: 100%|██████████| 280/280 [00:09<00:00, 29.17it/s]


[016|01/05] ---> 0.34568220376968384
correlation 0.3456822
poisson_loss 1016028.5


Epoch 17: 100%|██████████| 280/280 [00:09<00:00, 28.02it/s]


[017|00/05] ---> 0.34950020909309387
correlation 0.3495002
poisson_loss 1011286.0


Epoch 18: 100%|██████████| 280/280 [00:09<00:00, 28.98it/s]


[018|00/05] ---> 0.3508361279964447
correlation 0.35083613
poisson_loss 1010315.0


Epoch 19: 100%|██████████| 280/280 [00:09<00:00, 28.42it/s]


[019|00/05] ---> 0.3517776429653168
correlation 0.35177764
poisson_loss 1010146.25


Epoch 20: 100%|██████████| 280/280 [00:09<00:00, 28.63it/s]


[020|00/05] ---> 0.35322341322898865
correlation 0.3532234
poisson_loss 1008098.9


Epoch 21: 100%|██████████| 280/280 [00:09<00:00, 28.60it/s]


[021|00/05] ---> 0.35605987906455994
correlation 0.35605988
poisson_loss 1005133.1


Epoch 22: 100%|██████████| 280/280 [00:09<00:00, 28.70it/s]


[022|00/05] ---> 0.356294184923172
correlation 0.35629418
poisson_loss 1006189.5


Epoch 23: 100%|██████████| 280/280 [00:09<00:00, 28.75it/s]


[023|00/05] ---> 0.35785990953445435
correlation 0.3578599
poisson_loss 1002774.2


Epoch 24: 100%|██████████| 280/280 [00:09<00:00, 28.78it/s]


[024|00/05] ---> 0.35946714878082275
correlation 0.35946715
poisson_loss 1001375.1


Epoch 25: 100%|██████████| 280/280 [00:09<00:00, 28.39it/s]


[025|00/05] ---> 0.3625336289405823
correlation 0.36253363
poisson_loss 998479.8


Epoch 26: 100%|██████████| 280/280 [00:09<00:00, 28.44it/s]


[026|01/05] -/-> 0.36138710379600525
correlation 0.3613871
poisson_loss 1001708.0


Epoch 27: 100%|██████████| 280/280 [00:09<00:00, 28.16it/s]


[027|01/05] ---> 0.3651205003261566
correlation 0.3651205
poisson_loss 995775.5


Epoch 28: 100%|██████████| 280/280 [00:09<00:00, 28.50it/s]


[028|01/05] -/-> 0.36412161588668823
correlation 0.36412162
poisson_loss 997228.25


Epoch 29: 100%|██████████| 280/280 [00:09<00:00, 28.86it/s]


[029|02/05] -/-> 0.3635002374649048
correlation 0.36350024
poisson_loss 998789.0


Epoch 30: 100%|██████████| 280/280 [00:09<00:00, 28.49it/s]


[030|03/05] -/-> 0.36460208892822266
correlation 0.3646021
poisson_loss 996971.9


Epoch 31: 100%|██████████| 280/280 [00:09<00:00, 29.15it/s]


[031|04/05] -/-> 0.36369574069976807
correlation 0.36369574
poisson_loss 998193.0


Epoch 32: 100%|██████████| 280/280 [00:10<00:00, 27.70it/s]


[032|05/05] -/-> 0.36376360058784485
Restoring best model after lr decay! 0.363764 ---> 0.365121
correlation 0.3651205
poisson_loss 995775.5


Epoch 33: 100%|██████████| 280/280 [00:09<00:00, 28.23it/s]


Epoch    33: reducing learning rate of group 0 to 1.5000e-03.
[033|01/05] -/-> 0.36147627234458923
correlation 0.36147627
poisson_loss 1000436.1


Epoch 34: 100%|██████████| 280/280 [00:10<00:00, 27.34it/s]


[034|01/05] ---> 0.3701150715351105
correlation 0.37011507
poisson_loss 988723.56


Epoch 35: 100%|██████████| 280/280 [00:10<00:00, 27.95it/s]


[035|00/05] ---> 0.3701483905315399
correlation 0.3701484
poisson_loss 989184.9


Epoch 36: 100%|██████████| 280/280 [00:09<00:00, 29.57it/s]


[036|01/05] -/-> 0.3694760203361511
correlation 0.36947602
poisson_loss 992000.75


Epoch 37: 100%|██████████| 280/280 [00:09<00:00, 28.84it/s]


[037|01/05] ---> 0.3705351650714874
correlation 0.37053517
poisson_loss 988919.75


Epoch 38: 100%|██████████| 280/280 [00:09<00:00, 28.62it/s]


[038|01/05] -/-> 0.3696020543575287
correlation 0.36960205
poisson_loss 989862.5


Epoch 39: 100%|██████████| 280/280 [00:09<00:00, 29.47it/s]


[039|02/05] -/-> 0.36921530961990356
correlation 0.3692153
poisson_loss 990644.06


Epoch 40: 100%|██████████| 280/280 [00:09<00:00, 29.46it/s]


[040|02/05] ---> 0.37127459049224854
correlation 0.3712746
poisson_loss 988646.25


Epoch 41: 100%|██████████| 280/280 [00:09<00:00, 29.42it/s]


[041|01/05] -/-> 0.37050527334213257
correlation 0.37050527
poisson_loss 989302.25


Epoch 42: 100%|██████████| 280/280 [00:09<00:00, 29.02it/s]


[042|02/05] -/-> 0.3704655170440674
correlation 0.37046552
poisson_loss 991292.75


Epoch 43: 100%|██████████| 280/280 [00:09<00:00, 29.00it/s]


[043|03/05] -/-> 0.370671808719635
correlation 0.3706718
poisson_loss 989206.8


Epoch 44: 100%|██████████| 280/280 [00:09<00:00, 29.48it/s]


[044|04/05] -/-> 0.3710007071495056
correlation 0.3710007
poisson_loss 989149.75


Epoch 45: 100%|██████████| 280/280 [00:09<00:00, 28.98it/s]


[045|04/05] ---> 0.37204253673553467
correlation 0.37204254
poisson_loss 988327.94


Epoch 46: 100%|██████████| 280/280 [00:09<00:00, 29.21it/s]


[046|01/05] -/-> 0.370368093252182
correlation 0.3703681
poisson_loss 989729.75


Epoch 47: 100%|██████████| 280/280 [00:09<00:00, 29.01it/s]


[047|02/05] -/-> 0.3708271384239197
correlation 0.37082714
poisson_loss 989385.25


Epoch 48: 100%|██████████| 280/280 [00:09<00:00, 29.26it/s]


[048|03/05] -/-> 0.37131428718566895
correlation 0.3713143
poisson_loss 988657.9


Epoch 49: 100%|██████████| 280/280 [00:09<00:00, 29.53it/s]


[049|04/05] -/-> 0.37005454301834106
correlation 0.37005454
poisson_loss 990344.1


Epoch 50: 100%|██████████| 280/280 [00:10<00:00, 27.96it/s]


[050|05/05] -/-> 0.3701555132865906
Restoring best model after lr decay! 0.370156 ---> 0.372043
correlation 0.37204254
poisson_loss 988327.94


Epoch 51: 100%|██████████| 280/280 [00:10<00:00, 27.18it/s]


Epoch    51: reducing learning rate of group 0 to 4.5000e-04.
[051|01/05] -/-> 0.368363618850708
correlation 0.36836362
poisson_loss 992332.25


Epoch 52: 100%|██████████| 280/280 [00:09<00:00, 28.03it/s]


[052|01/05] ---> 0.37224283814430237
correlation 0.37224284
poisson_loss 988017.25


Epoch 53: 100%|██████████| 280/280 [00:10<00:00, 27.96it/s]


[053|00/05] ---> 0.37297868728637695
correlation 0.3729787
poisson_loss 986660.0


Epoch 54: 100%|██████████| 280/280 [00:09<00:00, 28.11it/s]


[054|01/05] -/-> 0.3723757863044739
correlation 0.3723758
poisson_loss 986953.75


Epoch 55: 100%|██████████| 280/280 [00:09<00:00, 28.96it/s]


[055|01/05] ---> 0.37322643399238586
correlation 0.37322643
poisson_loss 986726.1


Epoch 56: 100%|██████████| 280/280 [00:09<00:00, 29.48it/s]


[056|01/05] -/-> 0.3720886707305908
correlation 0.37208867
poisson_loss 987605.4


Epoch 57: 100%|██████████| 280/280 [00:09<00:00, 29.46it/s]


[057|02/05] -/-> 0.3719620406627655
correlation 0.37196204
poisson_loss 987140.1


Epoch 58: 100%|██████████| 280/280 [00:09<00:00, 29.45it/s]


[058|03/05] -/-> 0.3728509545326233
correlation 0.37285095
poisson_loss 986870.44


Epoch 59: 100%|██████████| 280/280 [00:09<00:00, 29.23it/s]


[059|03/05] ---> 0.3735097050666809
correlation 0.3735097
poisson_loss 986059.9


Epoch 60: 100%|██████████| 280/280 [00:09<00:00, 28.86it/s]


[060|01/05] -/-> 0.3733965754508972
correlation 0.37339658
poisson_loss 986663.9


Epoch 61: 100%|██████████| 280/280 [00:09<00:00, 29.51it/s]


[061|02/05] -/-> 0.3717556297779083
correlation 0.37175563
poisson_loss 989135.6


Epoch 62: 100%|██████████| 280/280 [00:09<00:00, 29.27it/s]


[062|03/05] -/-> 0.3716176450252533
correlation 0.37161765
poisson_loss 988323.9


Epoch 63: 100%|██████████| 280/280 [00:10<00:00, 27.34it/s]


[063|04/05] -/-> 0.37281760573387146
correlation 0.3728176
poisson_loss 986451.75


Epoch 64: 100%|██████████| 280/280 [00:10<00:00, 27.99it/s]


[064|05/05] -/-> 0.3726043403148651
Restoring best model after lr decay! 0.372604 ---> 0.373510
Restoring best model! 0.373510 ---> 0.373510


Best scores so far:

- Old: 0.376 – 0.379
- New (no shared position): 0.377
- New (with shared position): 0.373510

In [9]:
from torch.nn import Parameter

# Scratch check: a trainable (1, 10, 1, 1) parameter. torch.empty makes explicit that the storage
# is left uninitialised, which is what the legacy torch.Tensor(*sizes) constructor also does.
a = Parameter(torch.empty(1, 10, 1, 1))

In [13]:
# Fill `a` in place with normal samples (mean 1, std 1e-3); going through `.data` keeps the
# in-place op out of autograd. No seed is set in this cell, so the values differ between runs.
a.data.normal_(1, 1.e-3)

tensor([[[[1.0010]],

         [[0.9991]],

         [[0.9998]],

         [[0.9999]],

         [[0.9995]],

         [[0.9991]],

         [[0.9998]],

         [[1.0010]],

         [[1.0010]],

         [[1.0003]]]])

In [None]:
from nnsysident.utility.measures import get_correlations

grids = get_grids(model)

# One number per data key: the mean (np.mean) of the test correlations get_correlations returns.
cors = {
    data_key: np.mean(correlations)
    for data_key, correlations in get_correlations(
        model, dataloaders['test'], per_neuron=False, as_dict=True
    ).items()
}

In [None]:
# Run description; the parts before and after '-' become the two lines of the figure title.
title = 'feat|scale: random init + learned - position: random 0.33 ortho init|0 + learned + shared'

# Results of this run, bundled for optional saving via torch.save(dictionary, title).
dictionary = {
    'cors': cors,
    'grids': grids,
    'real_grids': real_grids,
    'start_grids': start_grids,
}

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# One panel per data key: learned readout positions vs. the positions from the complete
# checkpoint ("True") and the initial positions, annotated with that key's mean test correlation.
# The layout follows the number of data keys (4 keys -> the original 2x2 grid, figsize (6, 4.3)).
n_panels = len(grids)
n_cols = 2
n_rows = max(1, -(-n_panels // n_cols))  # ceiling division, at least one row
# squeeze=False keeps `axes` 2-D even for a single row.
fig, axes = plt.subplots(n_rows, n_cols, dpi=200, figsize=(3 * n_cols, 2.15 * n_rows), squeeze=False)

for ind, (ax, key) in enumerate(zip(axes.flat, grids)):
    # Pair grids by data key instead of relying on the three dicts sharing an iteration order.
    assert key in real_grids and key in start_grids, 'keys do not match'
    grid, real_grid, start_grid = grids[key], real_grids[key], start_grids[key]

    ax.scatter(*grid.T, color='navy', s=.1, label='Learned position')
    ax.scatter(*real_grid.T, color='red', s=.1, label='"True" position')
    ax.scatter(*start_grid.T, color='black', s=.1, label='Init position', alpha=0.2)

    ax.set(xlim=(-1.1, 1.1), ylim=(-1.1, 1.1))
    # Mean test correlation of this data key.
    ax.text(-1, .8, round(cors[key], 3), fontsize=8.5, color='navy')
    sns.despine(top=True, ax=ax)

    ax.set(xticks=[-1, -.5, 0, .5, 1])
    ax.set_axisbelow(True)
    ax.grid(ls='--')
    # Only the last panel keeps its tick labels.
    if ind != n_panels - 1:
        ax.set(xticklabels=[], yticklabels=[])

    ax.set_title(key, fontsize=10)

# Hide panels left over when the data keys do not fill the grid.
for ax in list(axes.flat)[n_panels:]:
    ax.set_visible(False)

# The scatter labels were set but never shown: add one shared legend below the panels.
# markerscale enlarges the tiny (s=.1) markers so the legend entries are visible.
if n_panels:
    handles, labels = axes.flat[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.0), ncol=3,
               fontsize=7, markerscale=10, frameon=False)

fig.suptitle(title.split('-')[0] + '\n' + title.split('-')[-1], y=1.05, fontsize=10.)

plt.tight_layout()
# To save: plt.savefig(title + '.png', dpi=200, bbox_inches='tight')