# Demo notebook: how to run models on static mouse datasets

In [1]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# each cell runs, so edits to project code (e.g. nnsysident) are picked up live.
%load_ext autoreload
%autoreload 2 

In [2]:
import os
import datajoint as dj

# DataJoint connection settings are taken from environment variables so that no
# credentials are stored in the notebook; a missing variable raises KeyError.
_DJ_ENV_SETTINGS = {
    'database.host': 'DJ_HOST',
    'database.user': 'DJ_USER',
    'database.password': 'DJ_PASS',
}
for config_key, env_var in _DJ_ENV_SETTINGS.items():
    dj.config[config_key] = os.environ[env_var]
dj.config['enable_python_native_blobs'] = True

# NOTE(review): presumably read by nnfabrik to choose the DataJoint schema — verify.
name = "test"
dj.config['schema_name'] = f"konstantin_nnsysident_{name}"

In [3]:
import torch
import numpy as np
import pickle
import pandas as pd
from collections import OrderedDict
# `Iterable` must come from collections.abc: importing it from `collections`
# is deprecated and raises ImportError on Python >= 3.10.
from collections.abc import Iterable

import nnfabrik
# NOTE(review): star imports; `trained_model_table`, used further down, presumably
# comes from one of them — consider importing it explicitly.
from nnfabrik.main import *
from nnfabrik import builder

from nnsysident.tables.experiments import *
from nnsysident.datasets.mouse_loaders import static_shared_loaders
from nnsysident.datasets.mouse_loaders import static_loaders
from nnsysident.datasets.mouse_loaders import static_loader

  from collections import OrderedDict, Iterable


Connecting <DJ_USER>@<DJ_HOST>:3306


# Get Dataloaders

In [6]:
# change path here: one preprocessed zip archive per static22564 scan
DATA_DIR = 'data'
SCAN_IDS = ['22564-2-12', '22564-2-13', '22564-3-8', '22564-3-12']
paths = [f'{DATA_DIR}/static{scan_id}-preproc0.zip' for scan_id in SCAN_IDS]

dataset_fn = 'nnsysident.datasets.mouse_loaders.static_shared_loaders'
# multi_match_n / image_n are forwarded to static_shared_loaders; see its docstring.
dataset_config = dict(paths=paths, batch_size=64, multi_match_n=3625, image_n=4399, seed=1)
dataloaders = builder.get_data(dataset_fn, dataset_config)

data/static22564-2-12-preproc0 exists already. Not unpacking data/static22564-2-12-preproc0.zip
data/static22564-2-13-preproc0 exists already. Not unpacking data/static22564-2-13-preproc0.zip
data/static22564-3-8-preproc0 exists already. Not unpacking data/static22564-3-8-preproc0.zip
data/static22564-3-12-preproc0 exists already. Not unpacking data/static22564-3-12-preproc0.zip
data/static22564-2-12-preproc0 exists already. Not unpacking data/static22564-2-12-preproc0.zip
data/static22564-2-13-preproc0 exists already. Not unpacking data/static22564-2-13-preproc0.zip
data/static22564-3-8-preproc0 exists already. Not unpacking data/static22564-3-8-preproc0.zip
data/static22564-3-12-preproc0 exists already. Not unpacking data/static22564-3-12-preproc0.zip


# Get Model

### The old Gaussian readout (nnvision)

In [None]:
# Legacy nnvision model with the old Gaussian readout.
model_fn = 'nnvision.models.se_core_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=20,
    gamma_readout=0.012,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
)
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

### Spatial Transformer

In [None]:
# Legacy nnvision model with a spatial x feature readout.
model_fn = 'nnvision.models.se_core_spatialXfeature_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=20,
    gamma_readout=0.005,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    init_noise=1e-3,
    depth_separable=True,
)
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

### The new Gaussian readout (nnsysident): change `gauss_type` to switch between the different modes

In [8]:
model_fn = 'nnsysident.models.models.se2d_fullgaussian2d'

# NOTE(review): 'cortex' predictor presumably maps 2-D cortical coordinates of each
# neuron to its readout grid position (no hidden layers) — verify in nnsysident.
grid_mean_predictor = dict(
    type='cortex',
    input_dimensions=2,
    hidden_layers=0,
    hidden_features=0,
    final_tanh=False,
)

model_config = dict(
    share_features=True,
    init_mu_range=0.55,
    init_sigma=0.4,
    input_kern=15,
    hidden_kern=13,
    gamma_input=1.0,
    # TODO: document where this value comes from (looks like a hyperparameter search result).
    gamma_readout=2.117604964706911,
    grid_mean_predictor=grid_mean_predictor,
)

model = builder.get_model(model_fn=model_fn, model_config=model_config, dataloaders=dataloaders, seed=1)

  all_multi_unit_ids = set(np.hstack(shared_match_ids.values()))


In [None]:
model_fn = 'nnsysident.models.models.se2d_deterministicgaussian2d'
# Empty config -> the model function's defaults; e.g. dict(share_features=True).
model_config = dict()
model = builder.get_model(model_fn=model_fn, model_config=model_config, dataloaders=dataloaders, seed=1)

In [None]:
model_fn = 'nnsysident.models.models.se2d_spatialxfeaturelinear'
# Empty config -> the model function's defaults.
model_config = dict()
model = builder.get_model(model_fn=model_fn, model_config=model_config, dataloaders=dataloaders, seed=1)

In [None]:
restricted_trained_model_table = (trained_model_table &
                                          "model_hash = '{}'".format(t_model_hash) &
                                          "dataset_hash = '{}'".format(t_dataset_hash) &
                                          "trainer_hash = '{}'".format(t_trainer_hash))
trained_model_entries = pd.DataFrame(restricted_trained_model_table.fetch())
trained_model_entry = trained_model_entries.loc[trained_model_entries['score'] == trained_model_entries['score'].max()]
state_dict = (restricted_trained_model_table * restricted_trained_model_table.ModelStorage & "seed = {}".format(int(trained_model_entry['seed']))).fetch1('model_state', download_path='models/')
core_dict = OrderedDict([(k, v) for k, v in torch.load(state_dict).items() if k[0:5] == 'core.'])
model.load_state_dict(core_dict, strict=False)

# Get Trainer

In [9]:
# Standard nnsysident trainer; an empty config means the trainer's defaults are used.
trainer_fn = 'nnsysident.training.trainers.standard_trainer'
trainer_config = {}
trainer = builder.get_trainer(trainer_fn, trainer_config)

# Run Training

In [10]:
# Run training; the trainer returns a 3-tuple (presumably score, extra output, model state dict).
training_result = trainer(model=model, dataloaders=dataloaders, seed=1)
score, output, model_state = training_result

Epoch 1: 100%|██████████| 276/276 [01:00<00:00,  4.55it/s]


[001|00/05] ---> 0.052468642592430115


Epoch 2: 100%|██████████| 276/276 [00:33<00:00,  8.16it/s]


[002|00/05] ---> 0.07284536957740784


Epoch 3: 100%|██████████| 276/276 [00:34<00:00,  8.01it/s]


[003|00/05] ---> 0.10009320080280304


Epoch 4: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[004|00/05] ---> 0.14443719387054443


Epoch 5: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[005|00/05] ---> 0.18207822740077972


Epoch 6: 100%|██████████| 276/276 [00:35<00:00,  7.83it/s]


[006|00/05] ---> 0.21691153943538666


Epoch 7: 100%|██████████| 276/276 [00:35<00:00,  7.83it/s]


[007|00/05] ---> 0.24420961737632751


Epoch 8: 100%|██████████| 276/276 [00:35<00:00,  7.84it/s]


[008|00/05] ---> 0.258683979511261


Epoch 9: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[009|00/05] ---> 0.27518966794013977


Epoch 10: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[010|00/05] ---> 0.2845032215118408


Epoch 11: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[011|00/05] ---> 0.2974141240119934


Epoch 12: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[012|00/05] ---> 0.3081170320510864


Epoch 13: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[013|00/05] ---> 0.3166813850402832


Epoch 14: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[014|00/05] ---> 0.31989458203315735


Epoch 15: 100%|██████████| 276/276 [00:34<00:00,  7.89it/s]


[015|00/05] ---> 0.32478633522987366


Epoch 16: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[016|00/05] ---> 0.32777178287506104


Epoch 17: 100%|██████████| 276/276 [00:34<00:00,  7.89it/s]


[017|00/05] ---> 0.3292769193649292


Epoch 18: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[018|00/05] ---> 0.3345966041088104


Epoch 19: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[019|00/05] ---> 0.3368801772594452


Epoch 20: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[020|00/05] ---> 0.34146085381507874


Epoch 21: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[021|00/05] ---> 0.34321942925453186


Epoch 22: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[022|00/05] ---> 0.34492987394332886


Epoch 23: 100%|██████████| 276/276 [00:35<00:00,  7.86it/s]


[023|00/05] ---> 0.34803536534309387


Epoch 24: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[024|00/05] ---> 0.3480927348136902


Epoch 25: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[025|00/05] ---> 0.35232987999916077


Epoch 26: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[026|00/05] ---> 0.3527892529964447


Epoch 27: 100%|██████████| 276/276 [00:35<00:00,  7.86it/s]


[027|00/05] ---> 0.353314608335495


Epoch 28: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[028|01/05] -/-> 0.35273006558418274


Epoch 29: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[029|01/05] ---> 0.3569779098033905


Epoch 30: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[030|00/05] ---> 0.3578653335571289


Epoch 31: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[031|00/05] ---> 0.3583291471004486


Epoch 32: 100%|██████████| 276/276 [00:35<00:00,  7.86it/s]


[032|00/05] ---> 0.35934266448020935


Epoch 33: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[033|00/05] ---> 0.3596382439136505


Epoch 34: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[034|00/05] ---> 0.3622651696205139


Epoch 35: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[035|01/05] -/-> 0.36096712946891785


Epoch 36: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[036|02/05] -/-> 0.36044755578041077


Epoch 37: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[037|02/05] ---> 0.3634190857410431


Epoch 38: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[038|00/05] ---> 0.36404094099998474


Epoch 39: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[039|01/05] -/-> 0.36339324712753296


Epoch 40: 100%|██████████| 276/276 [00:35<00:00,  7.86it/s]


[040|01/05] ---> 0.3649275600910187


Epoch 41: 100%|██████████| 276/276 [00:35<00:00,  7.86it/s]


[041|01/05] -/-> 0.36426717042922974


Epoch 42: 100%|██████████| 276/276 [00:35<00:00,  7.86it/s]


[042|01/05] ---> 0.36650732159614563


Epoch 43: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[043|01/05] -/-> 0.3647529184818268


Epoch 44: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[044|01/05] ---> 0.3676719665527344


Epoch 45: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[045|01/05] -/-> 0.3668738603591919


Epoch 46: 100%|██████████| 276/276 [00:34<00:00,  7.89it/s]


[046|02/05] -/-> 0.36648064851760864


Epoch 47: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[047|02/05] ---> 0.36806029081344604


Epoch 48: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[048|01/05] -/-> 0.3661724328994751


Epoch 49: 100%|██████████| 276/276 [00:35<00:00,  7.86it/s]


[049|01/05] ---> 0.3693622946739197


Epoch 50: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[050|01/05] -/-> 0.3681705594062805


Epoch 51: 100%|██████████| 276/276 [00:35<00:00,  7.86it/s]


[051|02/05] -/-> 0.3682176172733307


Epoch 52: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[052|03/05] -/-> 0.3686762750148773


Epoch 53: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[053|03/05] ---> 0.3704022467136383


Epoch 54: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[054|01/05] -/-> 0.3698284327983856


Epoch 55: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[055|02/05] -/-> 0.36886951327323914


Epoch 56: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[056|02/05] ---> 0.3705386519432068


Epoch 57: 100%|██████████| 276/276 [00:34<00:00,  7.89it/s]


[057|00/05] ---> 0.37063488364219666


Epoch 58: 100%|██████████| 276/276 [00:35<00:00,  7.86it/s]


[058|01/05] -/-> 0.36978578567504883


Epoch 59: 100%|██████████| 276/276 [00:35<00:00,  7.86it/s]


[059|01/05] ---> 0.37068450450897217


Epoch 60: 100%|██████████| 276/276 [00:35<00:00,  7.85it/s]


[060|01/05] -/-> 0.3682824969291687


Epoch 61: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[061|02/05] -/-> 0.36889246106147766


Epoch 62: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[062|03/05] -/-> 0.3699387013912201


Epoch 63: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[063|04/05] -/-> 0.36893272399902344


Epoch 64: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[064|05/05] -/-> 0.37035292387008667


Epoch 65:   0%|          | 0/276 [00:00<?, ?it/s]

Restoring best model after lr decay! 0.370353 ---> 0.370685


Epoch 65: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[065|00/05] ---> 0.37152981758117676


Epoch 66: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[066|01/05] -/-> 0.37110868096351624


Epoch 67: 100%|██████████| 276/276 [00:34<00:00,  7.90it/s]


[067|02/05] -/-> 0.36912456154823303


Epoch 68: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[068|02/05] ---> 0.3716086447238922


Epoch 69: 100%|██████████| 276/276 [00:34<00:00,  7.91it/s]


[069|01/05] -/-> 0.3704628646373749


Epoch 70: 100%|██████████| 276/276 [00:35<00:00,  7.87it/s]


[070|01/05] ---> 0.3722069263458252


Epoch 71: 100%|██████████| 276/276 [00:34<00:00,  7.94it/s]


[071|01/05] -/-> 0.37006059288978577


Epoch 72: 100%|██████████| 276/276 [00:35<00:00,  7.86it/s]


[072|01/05] ---> 0.37241315841674805


Epoch 73: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[073|01/05] -/-> 0.3714292347431183


Epoch 74: 100%|██████████| 276/276 [00:34<00:00,  7.89it/s]


[074|02/05] -/-> 0.3705821931362152


Epoch 75: 100%|██████████| 276/276 [00:34<00:00,  7.93it/s]


[075|03/05] -/-> 0.37209048867225647


Epoch 76: 100%|██████████| 276/276 [00:34<00:00,  7.91it/s]


[076|04/05] -/-> 0.3694833517074585


Epoch 77: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[077|05/05] -/-> 0.3691495954990387


Epoch 78:   0%|          | 0/276 [00:00<?, ?it/s]

Restoring best model after lr decay! 0.369150 ---> 0.372413


Epoch 78: 100%|██████████| 276/276 [00:34<00:00,  7.90it/s]


Epoch    78: reducing learning rate of group 0 to 1.5000e-03.
[078|01/05] -/-> 0.3719525635242462


Epoch 79: 100%|██████████| 276/276 [00:34<00:00,  7.93it/s]


[079|01/05] ---> 0.37389376759529114


Epoch 80: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[080|00/05] ---> 0.3748821020126343


Epoch 81: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[081|01/05] -/-> 0.374287873506546


Epoch 82: 100%|██████████| 276/276 [00:34<00:00,  7.92it/s]


[082|02/05] -/-> 0.37449920177459717


Epoch 83: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[083|02/05] ---> 0.3749876916408539


Epoch 84: 100%|██████████| 276/276 [00:34<00:00,  7.96it/s]


[084|01/05] -/-> 0.3716907799243927


Epoch 85: 100%|██████████| 276/276 [00:34<00:00,  7.89it/s]


[085|02/05] -/-> 0.3733169436454773


Epoch 86: 100%|██████████| 276/276 [00:34<00:00,  7.91it/s]


[086|03/05] -/-> 0.3728507161140442


Epoch 87: 100%|██████████| 276/276 [00:34<00:00,  7.91it/s]


[087|04/05] -/-> 0.3697288930416107


Epoch 88: 100%|██████████| 276/276 [00:35<00:00,  7.88it/s]


[088|05/05] -/-> 0.3739948272705078
Restoring best model after lr decay! 0.373995 ---> 0.374988
Restoring best model! 0.374988 ---> 0.374988
