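# Demo notebook: how to run models on static mouse datasets

This notebook loads a static mouse dataset, builds a model and a trainer with nnfabrik's `builder`, and trains the model.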

In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Automatically reload imported modules before executing code, so edits to the
# project packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2 

In [2]:
import os
import datajoint as dj

# Database credentials are taken from the environment (a KeyError is raised if
# any variable is unset), so no secret is stored in the notebook itself.
for dj_key, env_var in (
    ('database.host', 'DJ_HOST'),
    ('database.user', 'DJ_USER'),
    ('database.password', 'DJ_PASS'),
):
    dj.config[dj_key] = os.environ[env_var]

dj.config['enable_python_native_blobs'] = True
dj.config['display.limit'] = 200

# Experiment name; the schema name is derived from it.
# NOTE(review): 'schema_name' is not a core DataJoint option — presumably read by
# nnfabrik/nnsysident when their tables are declared; confirm.
name = 'iclr'
dj.config['schema_name'] = f"konstantin_nnsysident_{name}"

In [3]:
import re
import torch
import numpy as np
import pickle 
import pandas as pd
# Show wide frames in full but keep long frames short.
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 10)
from collections import OrderedDict
# The ABC aliases in `collections` (e.g. `collections.Iterable`) were deprecated
# and removed in Python 3.10; `collections.abc` is the supported location.
from collections.abc import Iterable
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

# NOTE(review): the star imports below pull DataJoint table classes into the
# namespace; their relative order matters because later ones can shadow names
# from earlier ones (and `builder` is imported after `nnfabrik.main` on purpose
# so it always refers to the `nnfabrik.builder` module).
import nnfabrik
from nnfabrik.main import *
from nnfabrik import builder
from nnfabrik.utility.hypersearch import Bayesian

from nnsysident.tables.experiments import *
from nnsysident.tables.bayesian import *
from nnsysident.datasets.mouse_loaders import static_shared_loaders, static_loaders, static_loader
from nnsysident.tables.scoring import OracleScore, OracleScoreTransfer
from nnsysident.tables.scoring import OracleScore, OracleScoreTransfer

  from collections import OrderedDict, Iterable


Connecting konstantin@134.2.168.16:3306
Schema name: konstantin_nnsysident_iclr


---

# Dataset

In [9]:
# change path here
paths = ['data/static20457-5-9-preproc0.zip']

dataset_fn = 'nnsysident.datasets.mouse_loaders.static_loaders'
dataset_config = dict(
    paths=paths,
    batch_size=64,
    seed=1,
    file_tree=True,
)

dataloaders = builder.get_data(dataset_fn, dataset_config)

data/static20457-5-9-preproc0 exists already. Not unpacking data/static20457-5-9-preproc0.zip


In [15]:
# Show the per-neuron `cell_motor_coordinates` of the first dataset in the
# training loaders. The data key is taken from the loader mapping instead of
# being hard-coded, so this cell keeps working when `paths` is changed.
first_data_key = next(iter(dataloaders['train']))
dataloaders['train'][first_data_key].dataset.neurons.cell_motor_coordinates

array([[-301., -305.,  210.],
       [-304., -285.,  210.],
       [-283., -306.,  210.],
       ...,
       [  86., -305.,  245.],
       [ 270., -197.,  245.],
       [-222.,  -40.,  245.]])

# Model

In [5]:
model_fn = 'nnsysident.models.models.se2d_fullgaussian2d'

# Keyword arguments forwarded to the model function by `builder.get_model`.
# NOTE(review): these values (e.g. gamma_readout=2.439) look tuned — record
# where they came from.
model_config = dict(
    init_mu_range=0.55,
    init_sigma=0.4,
    input_kern=15,
    hidden_kern=13,
    gamma_input=1.0,
    grid_mean_predictor=None,
    gamma_readout=2.439,
)

model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1,
)

# Trainer

In [6]:
# Set this to True for transfer learning.
detach_core = False

# The trainer is resolved by dotted name; `track_training` enables the
# per-epoch metric printout seen in the training log.
trainer_fn = 'nnsysident.training.trainers.standard_trainer'
trainer_config = {'track_training': True, 'detach_core': detach_core}
trainer = builder.get_trainer(trainer_fn, trainer_config)

# Run Training

In [7]:
# Train the model; the trainer returns a (score, output, model_state) triple.
score, output, model_state = trainer(
    model=model,
    dataloaders=dataloaders,
    seed=1,
)

correlation 0.0010534218
poisson_loss 4029516.2


Epoch 1: 100%|██████████| 69/69 [00:12<00:00,  5.48it/s]


[001|00/05] ---> 0.07745499163866043
correlation 0.07745499
poisson_loss 2424402.5


Epoch 2: 100%|██████████| 69/69 [00:06<00:00, 10.64it/s]


[002|00/05] ---> 0.09418176859617233
correlation 0.09418177
poisson_loss 2394455.8


Epoch 3: 100%|██████████| 69/69 [00:06<00:00, 10.58it/s]


[003|00/05] ---> 0.11241321265697479
correlation 0.11241321
poisson_loss 2365204.0


Epoch 4: 100%|██████████| 69/69 [00:06<00:00, 10.76it/s]


[004|00/05] ---> 0.12286487221717834
correlation 0.12286487
poisson_loss 2363228.2


Epoch 5: 100%|██████████| 69/69 [00:06<00:00, 10.72it/s]


[005|00/05] ---> 0.13239504396915436
correlation 0.13239504
poisson_loss 2348261.2


Epoch 6: 100%|██████████| 69/69 [00:06<00:00, 10.62it/s]


[006|00/05] ---> 0.14011405408382416
correlation 0.14011405
poisson_loss 2337927.5


Epoch 7: 100%|██████████| 69/69 [00:06<00:00, 10.76it/s]


[007|00/05] ---> 0.1469586044549942
correlation 0.1469586
poisson_loss 2353265.0


Epoch 8: 100%|██████████| 69/69 [00:06<00:00, 10.78it/s]


[008|00/05] ---> 0.15796317160129547
correlation 0.15796317
poisson_loss 2319957.0


Epoch 9: 100%|██████████| 69/69 [00:06<00:00, 10.65it/s]


[009|01/05] -/-> 0.14025233685970306
correlation 0.14025234
poisson_loss 2475463.8


Epoch 10: 100%|██████████| 69/69 [00:06<00:00, 10.49it/s]


[010|01/05] ---> 0.1693505197763443
correlation 0.16935052
poisson_loss 2309706.5


Epoch 11: 100%|██████████| 69/69 [00:06<00:00, 10.66it/s]


[011|00/05] ---> 0.17268837988376617
correlation 0.17268838
poisson_loss 2311567.0


Epoch 12: 100%|██████████| 69/69 [00:06<00:00, 10.74it/s]


[012|00/05] ---> 0.18217767775058746
correlation 0.18217768
poisson_loss 2306825.2


Epoch 13: 100%|██████████| 69/69 [00:06<00:00, 10.70it/s]


[013|00/05] ---> 0.18871542811393738
correlation 0.18871543
poisson_loss 2281200.5


Epoch 14: 100%|██████████| 69/69 [00:06<00:00, 10.72it/s]


[014|01/05] -/-> 0.18461741507053375
correlation 0.18461742
poisson_loss 2291466.2


Epoch 15: 100%|██████████| 69/69 [00:06<00:00, 10.57it/s]


[015|02/05] -/-> 0.18754203617572784
correlation 0.18754204
poisson_loss 2301473.5


Epoch 16: 100%|██████████| 69/69 [00:06<00:00, 10.75it/s]


[016|02/05] ---> 0.19247876107692719
correlation 0.19247876
poisson_loss 2290604.0


Epoch 17: 100%|██████████| 69/69 [00:06<00:00, 10.68it/s]


[017|01/05] -/-> 0.18939627707004547
correlation 0.18939628
poisson_loss 2316726.5


Epoch 18: 100%|██████████| 69/69 [00:06<00:00, 10.78it/s]


[018|01/05] ---> 0.19611643254756927
correlation 0.19611643
poisson_loss 2283056.5


Epoch 19: 100%|██████████| 69/69 [00:06<00:00, 10.55it/s]


[019|00/05] ---> 0.19722841680049896
correlation 0.19722842
poisson_loss 2280220.8


Epoch 20: 100%|██████████| 69/69 [00:06<00:00, 10.79it/s]


[020|00/05] ---> 0.19849933683872223
correlation 0.19849934
poisson_loss 2292558.0


Epoch 21: 100%|██████████| 69/69 [00:06<00:00, 10.77it/s]


[021|00/05] ---> 0.20389460027217865
correlation 0.2038946
poisson_loss 2263162.5


Epoch 22: 100%|██████████| 69/69 [00:06<00:00, 10.74it/s]


[022|01/05] -/-> 0.2020898163318634
correlation 0.20208982
poisson_loss 2273013.5


Epoch 23: 100%|██████████| 69/69 [00:06<00:00, 10.67it/s]


[023|02/05] -/-> 0.20214639604091644
correlation 0.2021464
poisson_loss 2272157.0


Epoch 24: 100%|██████████| 69/69 [00:06<00:00, 10.74it/s]


[024|03/05] -/-> 0.19909223914146423
correlation 0.19909224
poisson_loss 2274666.0


Epoch 25: 100%|██████████| 69/69 [00:06<00:00, 10.66it/s]


[025|04/05] -/-> 0.20384334027767181
correlation 0.20384334
poisson_loss 2275629.5


Epoch 26: 100%|██████████| 69/69 [00:06<00:00, 10.64it/s]


[026|05/05] -/-> 0.19713424146175385
Restoring best model after lr decay! 0.197134 ---> 0.203895
correlation 0.2038946
poisson_loss 2263162.5


Epoch 27: 100%|██████████| 69/69 [00:06<00:00, 10.66it/s]


Epoch    27: reducing learning rate of group 0 to 1.5000e-03.
[027|01/05] -/-> 0.20017088949680328
correlation 0.20017089
poisson_loss 2272522.5


Epoch 28: 100%|██████████| 69/69 [00:06<00:00, 10.50it/s]


[028|02/05] -/-> 0.19888649880886078
correlation 0.1988865
poisson_loss 2278176.8


Epoch 29: 100%|██████████| 69/69 [00:06<00:00, 10.56it/s]


[029|02/05] ---> 0.20408618450164795
correlation 0.20408618
poisson_loss 2268965.5


Epoch 30: 100%|██████████| 69/69 [00:06<00:00, 10.60it/s]


[030|00/05] ---> 0.2066251039505005
correlation 0.2066251
poisson_loss 2268577.0


Epoch 31: 100%|██████████| 69/69 [00:06<00:00, 10.67it/s]


[031|01/05] -/-> 0.20179226994514465
correlation 0.20179227
poisson_loss 2286463.5


Epoch 32: 100%|██████████| 69/69 [00:06<00:00, 10.62it/s]


[032|01/05] ---> 0.20890842378139496
correlation 0.20890842
poisson_loss 2257435.5


Epoch 33: 100%|██████████| 69/69 [00:06<00:00, 10.68it/s]


[033|01/05] -/-> 0.20799380540847778
correlation 0.2079938
poisson_loss 2262540.5


Epoch 34: 100%|██████████| 69/69 [00:06<00:00, 10.80it/s]


[034|02/05] -/-> 0.20465344190597534
correlation 0.20465344
poisson_loss 2271765.0


Epoch 35: 100%|██████████| 69/69 [00:06<00:00, 10.58it/s]


[035|02/05] ---> 0.20976875722408295
correlation 0.20976876
poisson_loss 2259881.5


Epoch 36: 100%|██████████| 69/69 [00:06<00:00, 10.72it/s]


[036|00/05] ---> 0.2103910744190216
correlation 0.21039107
poisson_loss 2258942.0


Epoch 37: 100%|██████████| 69/69 [00:06<00:00, 10.64it/s]


[037|01/05] -/-> 0.2070290446281433
correlation 0.20702904
poisson_loss 2271747.0


Epoch 38: 100%|██████████| 69/69 [00:06<00:00, 10.70it/s]


[038|02/05] -/-> 0.1949913054704666
correlation 0.1949913
poisson_loss 2303824.2


Epoch 39: 100%|██████████| 69/69 [00:06<00:00, 10.59it/s]


[039|03/05] -/-> 0.20876780152320862
correlation 0.2087678
poisson_loss 2258374.5


Epoch 40: 100%|██████████| 69/69 [00:06<00:00, 10.73it/s]


[040|04/05] -/-> 0.2060539871454239
correlation 0.20605399
poisson_loss 2267716.2


Epoch 41: 100%|██████████| 69/69 [00:06<00:00, 10.50it/s]


[041|05/05] -/-> 0.2082003653049469
Restoring best model after lr decay! 0.208200 ---> 0.210391
correlation 0.21039107
poisson_loss 2258942.0


Epoch 42: 100%|██████████| 69/69 [00:06<00:00, 10.73it/s]


Epoch    42: reducing learning rate of group 0 to 4.5000e-04.
[042|01/05] -/-> 0.20726580917835236
correlation 0.20726581
poisson_loss 2263447.5


Epoch 43: 100%|██████████| 69/69 [00:06<00:00, 10.70it/s]


[043|02/05] -/-> 0.20154309272766113
correlation 0.2015431
poisson_loss 2285970.5


Epoch 44: 100%|██████████| 69/69 [00:06<00:00, 10.63it/s]


[044|03/05] -/-> 0.200156107544899
correlation 0.20015611
poisson_loss 2295849.0


Epoch 45: 100%|██████████| 69/69 [00:06<00:00, 10.76it/s]


[045|04/05] -/-> 0.20769421756267548
correlation 0.20769422
poisson_loss 2269215.0


Epoch 46: 100%|██████████| 69/69 [00:06<00:00, 10.61it/s]


[046|04/05] ---> 0.2107335478067398
correlation 0.21073355
poisson_loss 2265564.0


Epoch 47: 100%|██████████| 69/69 [00:06<00:00, 10.74it/s]


[047|00/05] ---> 0.21132005751132965
correlation 0.21132006
poisson_loss 2252679.0


Epoch 48: 100%|██████████| 69/69 [00:06<00:00, 10.77it/s]


[048|01/05] -/-> 0.2068132758140564
correlation 0.20681328
poisson_loss 2279474.0


Epoch 49: 100%|██████████| 69/69 [00:06<00:00, 10.77it/s]


[049|02/05] -/-> 0.20844539999961853
correlation 0.2084454
poisson_loss 2265216.2


Epoch 50: 100%|██████████| 69/69 [00:06<00:00, 10.73it/s]


[050|03/05] -/-> 0.20527330040931702
correlation 0.2052733
poisson_loss 2269974.5


Epoch 51: 100%|██████████| 69/69 [00:06<00:00, 10.63it/s]


[051|04/05] -/-> 0.207217738032341
correlation 0.20721774
poisson_loss 2263188.0


Epoch 52: 100%|██████████| 69/69 [00:06<00:00, 10.73it/s]


[052|05/05] -/-> 0.2079562097787857
Restoring best model after lr decay! 0.207956 ---> 0.211320
Restoring best model! 0.211320 ---> 0.211320
