# Demo notebook: how to run models on static mouse datasets

In [1]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Automatically reload edited Python modules before each cell is executed,
# so changes to the project packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2 

In [2]:
import os
import datajoint as dj
# DataJoint connection settings come from environment variables so that no
# credentials are stored in the notebook. A missing variable raises KeyError.
_DJ_ENV_KEYS = {
    'database.host': 'DJ_HOST',
    'database.user': 'DJ_USER',
    'database.password': 'DJ_PASS',
}
for _config_key, _env_var in _DJ_ENV_KEYS.items():
    dj.config[_config_key] = os.environ[_env_var]
dj.config['enable_python_native_blobs'] = True

# Suffix of the DataJoint schema used by this notebook; change it to work in a
# separate schema.
name = "test"
dj.config['schema_name'] = f"konstantin_nnsysident_{name}"

In [3]:
import torch
import numpy as np
import pickle 

import nnfabrik
from nnfabrik.main import *
from nnfabrik import builder

from nnsysident.datasets.mouse_loaders import static_shared_loaders
from nnsysident.datasets.mouse_loaders import static_loaders
from nnsysident.datasets.mouse_loaders import static_loader

Connecting <DJ_USER>@<DJ_HOST>:3306


# Get Dataloaders

In [4]:
# Path(s) to the preprocessed static dataset archive(s) -- change path here.
paths = ['data/static22564-2-12-preproc0.zip']

# Build train/validation/test dataloaders through the nnfabrik builder.
dataset_fn = 'nnsysident.datasets.mouse_loaders.static_loaders'
dataset_config = {'paths': paths, 'batch_size': 64, 'seed': 1}
dataloaders = builder.get_data(dataset_fn, dataset_config)

data/static22564-2-12-preproc0 exists already. Not unpacking data/static22564-2-12-preproc0.zip


In [5]:
# Display the nested OrderedDict: tier ('train' / 'validation' / 'test') -> data key -> DataLoader.
dataloaders

OrderedDict([('train',
              OrderedDict([('22564-2-12-0',
                            <torch.utils.data.dataloader.DataLoader at 0x7f486f8d5520>)])),
             ('validation',
              OrderedDict([('22564-2-12-0',
                            <torch.utils.data.dataloader.DataLoader at 0x7f486f8d57c0>)])),
             ('test',
              OrderedDict([('22564-2-12-0',
                            <torch.utils.data.dataloader.DataLoader at 0x7f486f8d5730>)]))])

# Get Model

### The old Gaussian readout

In [None]:
# Model built by nnvision's `se_core_gauss_readout` with the settings below.
model_fn = 'nnvision.models.se_core_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=20,
    gamma_readout=0.012,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
)
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

### Spatial Transformer

In [None]:
# Model built by nnvision's `se_core_spatialXfeature_readout` with the settings below.
model_fn = 'nnvision.models.se_core_spatialXfeature_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=20,
    gamma_readout=0.005,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    init_noise=1e-3,
    depth_separable=True,
)
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

### The new Gaussian readout: set `gauss_type` in `model_config` to switch between the different modes

In [None]:
# Empty config: presumably the model function's defaults are used (verify in
# nnsysident). Add e.g. a 'gauss_type' entry here to select a readout mode.
model_config = dict()
model_fn = 'nnsysident.models.models.se2d_fullgaussian2d'
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

In [6]:
# Empty config: presumably the model function's defaults are used (verify in nnsysident).
model_config = dict()
model_fn = 'nnsysident.models.models.se2d_spatialxfeaturelinear'
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

In [7]:
# Display the architecture of the most recently built model (the last model cell that was executed).
model

Encoder(
  (core): SE2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), bias=False)
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0, inplace=True)
      )
      (layer1): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (spatial_conv): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64, bias=False)
          (out_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0, inplace=True)
      )
      (layer2): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (i

# Get Trainer

In [8]:
# nnsysident's standard trainer, built with an empty config.
trainer_config = {}
trainer_fn = 'nnsysident.training.trainers.standard_trainer'
trainer = builder.get_trainer(trainer_fn, trainer_config)

# Run Training

In [None]:
# Train the model; the trainer returns a (score, output, model_state) triple.
score, output, model_state = trainer(
    model=model,
    dataloaders=dataloaders,
    seed=123,
)

Epoch 1: 100%|██████████| 70/70 [00:15<00:00,  4.47it/s]


[001|00/05] ---> 0.11776267737150192


Epoch 2: 100%|██████████| 70/70 [00:09<00:00,  7.44it/s]


[002|00/05] ---> 0.16628608107566833


Epoch 3: 100%|██████████| 70/70 [00:09<00:00,  7.53it/s]


[003|00/05] ---> 0.19108696281909943


Epoch 4: 100%|██████████| 70/70 [00:09<00:00,  7.37it/s]


[004|00/05] ---> 0.2120126485824585


Epoch 5: 100%|██████████| 70/70 [00:09<00:00,  7.27it/s]


[005|00/05] ---> 0.21928304433822632


Epoch 6: 100%|██████████| 70/70 [00:09<00:00,  7.35it/s]


[006|00/05] ---> 0.22545452415943146


Epoch 7: 100%|██████████| 70/70 [00:09<00:00,  7.27it/s]


[007|00/05] ---> 0.23142990469932556


Epoch 8: 100%|██████████| 70/70 [00:09<00:00,  7.34it/s]


[008|00/05] ---> 0.23637323081493378


Epoch 9: 100%|██████████| 70/70 [00:09<00:00,  7.30it/s]


[009|00/05] ---> 0.23967482149600983


Epoch 10: 100%|██████████| 70/70 [00:09<00:00,  7.32it/s]


[010|00/05] ---> 0.24208907783031464


Epoch 11: 100%|██████████| 70/70 [00:09<00:00,  7.24it/s]


[011|00/05] ---> 0.24248376488685608


Epoch 12: 100%|██████████| 70/70 [00:10<00:00,  6.71it/s]


[012|00/05] ---> 0.2486010044813156


Epoch 13: 100%|██████████| 70/70 [00:10<00:00,  6.87it/s]


[013|00/05] ---> 0.250047504901886


Epoch 14: 100%|██████████| 70/70 [00:10<00:00,  6.56it/s]


[014|00/05] ---> 0.2508845925331116


Epoch 15: 100%|██████████| 70/70 [00:11<00:00,  6.23it/s]


[015|01/05] -/-> 0.24836283922195435


Epoch 16: 100%|██████████| 70/70 [00:10<00:00,  6.63it/s]


[016|01/05] ---> 0.25437527894973755


Epoch 17: 100%|██████████| 70/70 [00:11<00:00,  6.33it/s]


[017|00/05] ---> 0.25563788414001465


Epoch 18: 100%|██████████| 70/70 [00:11<00:00,  6.16it/s]


[018|01/05] -/-> 0.25360172986984253


Epoch 19: 100%|██████████| 70/70 [00:11<00:00,  5.85it/s]


[019|01/05] ---> 0.2567400336265564


Epoch 20: 100%|██████████| 70/70 [00:11<00:00,  6.09it/s]


[020|01/05] -/-> 0.2566647529602051


Epoch 21: 100%|██████████| 70/70 [00:11<00:00,  6.08it/s]


[021|01/05] ---> 0.26207494735717773


Epoch 22: 100%|██████████| 70/70 [00:11<00:00,  6.31it/s]


[022|00/05] ---> 0.26504236459732056


Epoch 23: 100%|██████████| 70/70 [00:12<00:00,  5.43it/s]


[023|01/05] -/-> 0.2628045678138733


Epoch 24: 100%|██████████| 70/70 [00:13<00:00,  5.34it/s]
