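# Mouse System Identification Models

This notebook builds, trains and evaluates several system identification models on a single mouse static dataset:
- a ZIG model;
- a Poisson model;
- a Gamma model with a point-pooled readout;
- the Lurz model, a Poisson model whose core is loaded from a pretrained checkpoint (`lurz_core_poisson.tar`) and whose readout is then fine-tuned on this dataset.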

In [None]:
import os
import datajoint as dj
# Database connection settings. Credentials are read from environment variables,
# never written into the notebook. The settings are applied one by one, in order.
for dj_setting, env_var in (('database.host', 'DJ_HOST'),
                            ('database.user', 'DJ_USERNAME'),
                            ('database.password', 'DJ_PASSWORD')):
    dj.config[dj_setting] = os.environ[env_var]
dj.config['enable_python_native_blobs'] = True
dj.config['display.limit'] = 200

# The schema name is published both as an environment variable and in dj.config.
name = 'nsc'
schema_name = f"metrics_{name}"
os.environ["DJ_SCHEMA_NAME"] = schema_name
dj.config["nnfabrik.schema_name"] = schema_name

In [None]:
import copy

import numpy as np
import torch

from nnsysident.training.trainers import standard_trainer
from nnsysident.models.models import Stacked2dFullGaussian2d_Poisson, Stacked2dFullGaussian2d_ZIG, Stacked2dPointPooled_Poisson, Stacked2dPointPooled_Gamma
from nnsysident.datasets.mouse_loaders import static_loaders
from nnsysident.utility.data_helpers import extract_data_key
from nnsysident.utility.measures import get_model_performance

# Seed passed to data loading, model construction and training for reproducibility.
random_seed = 27121992
# Fall back to the CPU so a fresh "Restart & Run All" also works on machines
# without a CUDA GPU (a hard-coded 'cuda' would fail there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

___

## Data

In [None]:
# Preprocessed static dataset(s). The data key of the first path identifies the
# dataset in per-dataset settings further down (e.g. the ZIG zero thresholds).
paths = ['./data/static20457-5-9-preproc0']
data_key = extract_data_key(paths[0])

# NOTE(review): presumably `exclude` lists loader outputs that are left
# un-normalized even though normalize=True. Confirm against static_loaders.
dataset_config = dict(
    paths=paths,
    batch_size=64,
    seed=random_seed,
    loader_outputs=["images", "responses"],
    normalize=True,
    exclude=["images"],
)
dataloaders = static_loaders(**dataset_config)

## Model

In [None]:
# Hyper-parameters shared by every model built in this notebook.
model_config_base = {"hidden_kern": 13,
                     "input_kern": 15,
                     "init_sigma": 0.4,
                     'init_mu_range': 0.55,
                     'gamma_input': 1.0,
                     'grid_mean_predictor': {'type': 'cortex',
                                             'input_dimensions': 2,
                                             'hidden_layers': 0,
                                             'hidden_features': 0,
                                             'final_tanh': False},
                     "feature_reg_weight": 0.78,
                    }

# This is for the Lurz model (Poisson).
# Use deepcopy, not dict.copy(). A shallow copy would share the nested
# 'grid_mean_predictor' dict with model_config_base, so any in-place change made
# to it (e.g. by a model builder) would leak into every other model config.
model_config_lurz_model = copy.deepcopy(model_config_base)
model_config_lurz_model['batch_norm_scale'] = False
model_config_lurz_model['feature_reg_weight'] = 2.439
model_config_lurz_model['independent_bn_bias'] = True

#### ZIG model

In [None]:
# Per-dataset zero threshold for the ZIG model. Presumably responses below this
# value are treated as zeros by the likelihood. TODO confirm in nnsysident.
loc = np.exp(-10)

# Deep copy so the nested 'grid_mean_predictor' dict is not shared with
# model_config_base or with the other models' configs.
zig_model_config = copy.deepcopy(model_config_base)
zig_model_config['zero_thresholds'] = {data_key: loc}

zig_model = Stacked2dFullGaussian2d_ZIG().build_model(dataloaders, random_seed, **zig_model_config)

To load a saved state_dict, uncomment and run the next cell. If the state_dict comes from a model trained on another dataset, set `strict=False`.
In that case, fine-tune the model's readout on the new dataset by setting `detach_core=True` in the trainer config (`trainer_config_base`).

In [None]:
# Optional: restore ZIG weights written by the commented-out torch.save in the
# ZIG training cell. The file name below matches the one used there; the old
# "ZIG_statedict<key>.pt" name would never find that file.
# zig_model.load_state_dict(torch.load("mouseV1_gaussian_ZIG_statedict" + data_key + ".tar"), strict=True)

#### Poisson model

In [None]:
# Deep copy so the nested 'grid_mean_predictor' dict is not shared with
# model_config_base or with the other models' configs.
poisson_model_config = copy.deepcopy(model_config_base)

poisson_model = Stacked2dFullGaussian2d_Poisson().build_model(dataloaders, random_seed, **poisson_model_config)
# Same model class, built with the Lurz hyper-parameters. Its core is loaded from
# a pretrained checkpoint in the "Lurz model" cell.
lurz_poisson_model = Stacked2dFullGaussian2d_Poisson().build_model(dataloaders, random_seed, **model_config_lurz_model)

In [None]:
# Optional: restore Poisson weights written by the commented-out torch.save in the
# Poisson training cell. The file name below matches the one used there; the old
# "Poisson_statedict<key>.pt" name would never find that file.
# poisson_model.load_state_dict(torch.load("mouseV1_gaussian_Poisson_statedict" + data_key + ".tar"), strict=True)

#### Gamma model

In [None]:
# Deep copy so the nested 'grid_mean_predictor' dict is not shared with
# model_config_base or with the other models' configs.
# NOTE(review): this model uses a point-pooled readout, yet the shared config
# carries options that look specific to the Gaussian readout (init_sigma,
# init_mu_range, grid_mean_predictor). Confirm that Stacked2dPointPooled_Gamma
# accepts and uses them as intended.
gamma_model_config = copy.deepcopy(model_config_base)

gamma_model = Stacked2dPointPooled_Gamma().build_model(dataloaders, random_seed, **gamma_model_config)

There is no load cell for the Gamma model. To load a saved state_dict, call `gamma_model.load_state_dict(torch.load(<path>), strict=True)` in a new cell, as in the commented-out ZIG and Poisson load cells. If the state_dict comes from a model trained on another dataset, set `strict=False`.
In that case, fine-tune the model's readout on the new dataset by setting `detach_core=True` in the trainer config.

#### Lurz model

In [None]:
# Load the pretrained core into the Lurz model. strict=False because the readout
# parameters are dataset-specific and are not expected to match the checkpoint.
# map_location="cpu" lets the checkpoint load even without a GPU;
# load_state_dict then copies the tensors onto whatever device the model is on.
load_result = lurz_poisson_model.load_state_dict(
    torch.load("lurz_core_poisson.tar", map_location="cpu"), strict=False
)

# Readout mismatches are expected. Any other missing or unexpected key means the
# core was not (fully) loaded, and training with detach_core=True would then
# fine-tune a readout on top of an untrained core.
not_matching_keys = [
    key
    for key in list(load_result.missing_keys) + list(load_result.unexpected_keys)
    if not key.startswith("readout")
]
print("{} not matching keys".format(len(not_matching_keys)))
if not_matching_keys:
    print(not_matching_keys)

## Training

Set `track_training=True` to see more details about how training is progressing. If the core was loaded and you only want to fine-tune the readout, set `detach_core=True`.

In [None]:
# Trainer settings shared by all models. The Lurz variant is identical except that
# it sets detach_core=True, so only the readout on top of the loaded core is trained.
trainer_config_base = dict(track_training=False, device=device, detach_core=False)
lurz_model_trainer_config = {**trainer_config_base, "detach_core": True}

#### ZIG model

In [None]:
# Train the ZIG model. Its stopping criterion is "get_loss" with maximize=False,
# i.e. a loss for which lower is better.
zig_score, zig_output, zig_state_dict = standard_trainer(
    zig_model,
    dataloaders,
    random_seed,
    loss_function="ZIGLoss",
    stop_function="get_loss",
    maximize=False,
    **trainer_config_base,
)
# To keep the trained weights, uncomment:
# torch.save(zig_state_dict, "mouseV1_gaussian_ZIG_statedict" + data_key + ".tar")

In [None]:
# Evaluate the trained ZIG model with the same loss it was trained with.
zig_performance = get_model_performance(
    zig_model, dataloaders, "ZIGLoss", device=device
)

#### Poisson model

In [None]:
# Train the Poisson model with the trainer's default stopping criterion.
poisson_score, poisson_output, poisson_state_dict = standard_trainer(
    poisson_model,
    dataloaders,
    random_seed,
    loss_function="PoissonLoss",
    **trainer_config_base,
)
# To keep the trained weights, uncomment:
# torch.save(poisson_state_dict, "mouseV1_gaussian_Poisson_statedict" + data_key + ".tar")

In [None]:
# Evaluate the trained Poisson model with the same loss it was trained with.
poisson_performance = get_model_performance(
    poisson_model, dataloaders, "PoissonLoss", device=device
)

#### Gamma model

In [None]:
# Train the Gamma model with the trainer's default stopping criterion.
gamma_score, gamma_output, gamma_state_dict = standard_trainer(
    gamma_model,
    dataloaders,
    random_seed,
    loss_function="GammaLoss",
    **trainer_config_base,
)
# To keep the trained weights, uncomment:
# torch.save(gamma_state_dict, "mouseV1_pointpooled_Gamma_statedict" + data_key + ".tar")

In [None]:
# Evaluate the trained Gamma model with the same loss it was trained with.
gamma_performance = get_model_performance(
    gamma_model, dataloaders, "GammaLoss", device=device
)

#### Lurz model

In [None]:
# Fine-tune the Lurz model. lurz_model_trainer_config sets detach_core=True, so
# only the readout is trained on top of the core loaded from the checkpoint.
lurz_poisson_score, lurz_poisson_output, lurz_poisson_state_dict = standard_trainer(
    lurz_poisson_model,
    dataloaders,
    random_seed,
    loss_function="PoissonLoss",
    **lurz_model_trainer_config,
)
# To keep the trained weights, uncomment:
# torch.save(lurz_poisson_state_dict, "mouseV1_gaussian_lurz_Poisson_statedict" + data_key + ".tar")

In [None]:
# Evaluate the fine-tuned Lurz model with the same loss it was trained with.
lurz_performance = get_model_performance(
    lurz_poisson_model, dataloaders, "PoissonLoss", device=device
)

___