# Demo notebook: how to run models on static mouse datasets

In [None]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Automatically reload edited project modules (e.g. nnsysident) before each cell executes.
%load_ext autoreload
%autoreload 2 

In [None]:
import os
import datajoint as dj

# NOTE(review): this cell runs before the project imports below — presumably because the
# nnfabrik / nnsysident table modules read dj.config['schema_name'] when they are imported.
# Keep this ordering.

# Connection settings come from environment variables so no credentials live in the notebook.
for config_key, env_var in (
    ('database.host', 'DJ_HOST'),
    ('database.user', 'DJ_USER'),
    ('database.password', 'DJ_PASS'),
):
    dj.config[config_key] = os.environ[env_var]
dj.config['enable_python_native_blobs'] = True

# Suffix that selects which DataJoint schema this notebook works in.
name = "test"
dj.config['schema_name'] = f"konstantin_nnsysident_{name}"

In [None]:
import torch
import numpy as np
import pickle
import pandas as pd
from collections import OrderedDict
# The ABC aliases (Iterable, Mapping, ...) were removed from `collections` in Python 3.10;
# they must be imported from `collections.abc`.
from collections.abc import Iterable

import nnfabrik
from nnfabrik.main import *
from nnfabrik import builder

# NOTE(review): `TrainedModel`, used further down, is provided by one of these star imports.
from nnsysident.tables.experiments import *
from nnsysident.datasets.mouse_loaders import static_loader, static_loaders, static_shared_loaders

# Get Dataloaders

In [None]:
# Location(s) of the preprocessed static dataset archive(s) — change this to point at your data.
paths = ['data/static0-0-2-preproc0.zip']

# Dotted path of the dataset function; nnfabrik's builder calls it with `dataset_config`.
dataset_fn = 'nnsysident.datasets.mouse_loaders.static_loaders'
dataset_config = {
    'paths': paths,
    'batch_size': 64,
    'seed': 1,
}
dataloaders = builder.get_data(dataset_fn, dataset_config)

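# Get Model

Each model cell below builds a different model and assigns it to `model`, overwriting the previous one. Run only the variant you want; under *Run All*, the last model cell wins.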

### The old Gaussian readout

In [None]:
# Legacy model from the `nnvision` package, which must be installed for this cell to run.
# Overwrites any `model` built by another cell.
model_fn = 'nnvision.models.se_core_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=20,
    gamma_readout=0.012,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
)
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

### Spatial Transformer

In [None]:
# Legacy `nnvision` model with a spatial-x-feature readout; requires `nnvision` to be installed.
# Overwrites any `model` built by another cell.
model_fn = 'nnvision.models.se_core_spatialXfeature_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=20,
    gamma_readout=0.005,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    init_noise=1e-3,
    depth_separable=True,
)
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

### The new Gaussian readout: change `gauss_type` for the different modes

In [None]:
# nnsysident model with the full 2D Gaussian readout; only the kernel sizes differ from the
# function's defaults. Overwrites any `model` built by another cell.
model_fn = 'nnsysident.models.models.se2d_fullgaussian2d'
model_config = dict(input_kern=15, hidden_kern=11)
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

In [None]:
# nnsysident model with a spatial-x-feature linear readout, built entirely from the function's
# default arguments. Overwrites any `model` built by another cell.
model_fn = 'nnsysident.models.models.se2d_spatialxfeaturelinear'
model_config = dict()
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1000,
)

In [None]:
# Primary-key hashes identifying a previously trained model whose core is transferred below.
# NOTE(review): the trainer hash equals the MD5 of empty input — presumably the hash of an
# empty trainer config; verify against the TrainedModel table.
trained_model_table = TrainedModel
t_model_hash, t_dataset_hash, t_trainer_hash = (
    "1ad851e940460d0d962b2d323f7d3b6c",
    "53c03d1dcf82d468513dfd5f2e20e85c",
    "d41d8cd98f00b204e9800998ecf8427e",
)

In [None]:
# Transfer the core weights of the best previously trained model (selected by the hashes above)
# into the freshly built `model`.

# A dict restriction lets DataJoint build and quote the SQL condition itself instead of
# formatting strings by hand.
restricted_trained_model_table = trained_model_table & dict(
    model_hash=t_model_hash,
    dataset_hash=t_dataset_hash,
    trainer_hash=t_trainer_hash,
)
trained_model_entries = pd.DataFrame(restricted_trained_model_table.fetch())
assert not trained_model_entries.empty, "no TrainedModel entries match the given model/dataset/trainer hashes"

# Pick exactly one best-scoring entry. `idxmax` returns a single row label even when several seeds
# tie on score. A `score == score.max()` mask could select several rows and make int(...) fail.
trained_model_entry = trained_model_entries.loc[trained_model_entries['score'].idxmax()]
best_seed = int(trained_model_entry['seed'])

# NOTE: despite its name, `state_dict` holds what fetch1 returns with `download_path` set —
# presumably the local path of the downloaded checkpoint, which is handed to torch.load below.
state_dict = (
    restricted_trained_model_table * restricted_trained_model_table.ModelStorage & dict(seed=best_seed)
).fetch1('model_state', download_path='models/')

# map_location='cpu' lets checkpoints saved on a GPU load on CPU-only machines. load_state_dict
# copies the values onto the model's own parameters, so the model's device is unaffected.
checkpoint = torch.load(state_dict, map_location='cpu')
core_dict = OrderedDict((k, v) for k, v in checkpoint.items() if k.startswith('core.'))
# strict=False: only 'core.' parameters are provided; all other parameters of `model` keep their
# current values.
model.load_state_dict(core_dict, strict=False)

# Get Trainer

In [None]:
# An empty config means the trainer runs with its default hyperparameters.
trainer_config = {}
# Dotted path of the trainer function, resolved by nnfabrik's builder.
trainer_fn = 'nnsysident.training.trainers.standard_trainer'
trainer = builder.get_trainer(trainer_fn, trainer_config)

# Run Training

In [None]:
# Train whichever `model` was built last. The trainer's result is unpacked into three values
# (presumably the score, trainer-specific output and final model state — see the trainer docs).
score, output, model_state = trainer(
    model=model,
    dataloaders=dataloaders,
    seed=123,
)