# Demo notebook: how to run models on static mouse datasets

In [11]:
import datajoint as dj

import os
import torch
import numpy as np
import pickle 

import nnfabrik
from nnfabrik import main, builder

# Get Dataloaders

In [3]:
# Path(s) to the .h5 dataset file(s) -- change these to point at your own data.
paths = ['/data/mouse/toliaslab/static/static22564-3-12-preproc0.h5']

# The dataset function is given by its dotted import path; builder.get_data
# resolves it and calls it with the keyword arguments in dataset_config.
dataset_fn = 'nnvision.datasets.mouse_static_loaders'
dataset_config = {
    'paths': paths,
    'batch_size': 64,
    'normalize': True,
    'seed': 1000,
}
dataloaders = builder.get_data(dataset_fn, dataset_config)

# Get Model

### The old Gaussian readout

In [5]:
# Model function (dotted import path) plus the configuration passed to it.
model_fn = 'nnvision.models.se_core_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=20,
    gamma_readout=0.012,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
)

# Build the model against the dataloaders from above with a fixed seed.
# This rebinds `model`; whichever model cell ran last is the one trained below.
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)



### Spatial Transformer

In [6]:
# Model function (dotted import path) plus the configuration passed to it.
model_fn = 'nnvision.models.se_core_spatialXfeature_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=20,
    gamma_readout=0.005,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    init_noise=1e-3,
    depth_separable=True,
)

# Build the model against the dataloaders from above with a fixed seed.
# This rebinds `model`; whichever model cell ran last is the one trained below.
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

### The new Gaussian readout: change `gauss_type` to switch between readout modes

In [8]:
# Model function (dotted import path) plus the configuration passed to it.
model_fn = 'nnvision.models.se_core_full_gauss_readout'
model_config = dict(
    # core parameters
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=20,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    n_se_blocks=0,
    depth_separable=True,
    # readout parameters
    init_mu_range=0.3,
    grid_mean_predictor=None,
    share_features=False,
    share_grid=False,
    gauss_type='full',  # edit this to try the other readout modes
)

# Build the model against the dataloaders from above with a fixed seed.
# This rebinds `model`; whichever model cell ran last is the one trained below.
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

# Get Trainer

In [9]:
# Trainer function (dotted import path) and the settings it is built with.
trainer_fn = 'nnvision.training.nnvision_trainer'
trainer_config = {
    'max_iter': 100,
    'verbose': False,
    'lr_decay_steps': 4,
    'avg_loss': False,
    'patience': 5,
    'lr_init': 0.0045,
}
trainer = builder.get_trainer(trainer_fn, trainer_config)

# Run Training

In [None]:
# Train the most recently built `model`; the trainer's result is unpacked into
# three values (presumably a final score, auxiliary output, and the trained
# model's state -- confirm against the trainer implementation).
score, output, model_state = trainer(
    model=model,
    dataloaders=dataloaders,
    seed=1000,
)