# Demo notebook: how to run models on static mouse datasets

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and set it to mode 2, so that imported
# modules are reloaded automatically before code is executed.
%load_ext autoreload
%autoreload 2 

In [None]:
import os
import datajoint as dj
# Database connection settings are taken from environment variables; a missing
# variable raises KeyError right here.
for _setting, _env_var in (
    ('database.host', 'DJ_HOST'),
    ('database.user', 'DJ_USER'),
    ('database.password', 'DJ_PASS'),
):
    dj.config[_setting] = os.environ[_env_var]
dj.config['enable_python_native_blobs'] = True
dj.config['display.limit'] = 200

# Schema name stored in the DataJoint config -- presumably read by the table
# modules imported later in this notebook (TODO confirm).
name = 'iclr'
dj.config['schema_name'] = "konstantin_nnsysident_" + name

In [None]:
import re
import torch
import numpy as np
import pickle 
import pandas as pd
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 10)
# The abstract base classes live in `collections.abc`; the aliases that used to
# be re-exported from `collections` were removed in Python 3.10, so importing
# `Iterable` from `collections` raises ImportError on current interpreters.
from collections import OrderedDict
from collections.abc import Iterable
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

import nnfabrik
from nnfabrik.main import *
from nnfabrik import builder
from nnfabrik.utility.hypersearch import Bayesian

from nnsysident.tables.experiments import *
from nnsysident.tables.bayesian import *
from nnsysident.datasets.mouse_loaders import static_shared_loaders, static_loaders, static_loader
from nnsysident.tables.scoring import OracleScore, OracleScoreTransfer

---

# Dataset

In [None]:
# Path(s) to the dataset archive(s) -- adjust to wherever the data is stored.
paths = ['data/static20457-5-9-preproc0.zip']

# The loader is named by its dotted import path and handed, together with its
# configuration, to the nnfabrik builder.
dataset_fn = 'nnsysident.datasets.mouse_loaders.static_loaders'
dataset_config = {
    'paths': paths,
    'batch_size': 10000,
    'seed': 1,
    'file_tree': True,
}

dataloaders = builder.get_data(dataset_fn, dataset_config)

# Model

In [None]:
# Dotted import path of the model-building function passed to the builder.
model_fn = 'nnsysident.models.models.se2d_fullgaussian2d'

# Configuration handed to the builder alongside `model_fn`.
model_config = dict(
    init_mu_range=0.55,
    init_sigma=0.4,
    input_kern=15,
    hidden_kern=13,
    gamma_input=1.0,
    grid_mean_predictor=None,
    gamma_readout=2.439,
)

# Build the model for the dataloaders created above.
model = builder.get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=1,
)

# Trainer

In [None]:
# Forwarded to the trainer configuration; this should be True for transfer learning.
detach_core = False

# The trainer is named by its dotted import path and built by the nnfabrik builder.
trainer_fn = 'nnsysident.training.trainers.standard_trainer'
trainer_config = {
    'track_training': True,
    'detach_core': detach_core,
}
trainer = builder.get_trainer(trainer_fn, trainer_config)

# Run Training

In [None]:
# Run the trainer on the model and dataloaders; its result is unpacked into
# `score`, `output` and `model_state`.
score, output, model_state = trainer(
    model=model,
    dataloaders=dataloaders,
    seed=1,
)