# Mouse System Identification Models

In [1]:
import os
import datajoint as dj
# DataJoint connection settings are read from the environment so that no
# credentials are stored in the notebook itself.
for _cfg_key, _env_var in (
    ('database.host', 'DJ_HOST'),
    ('database.user', 'DJ_USERNAME'),
    ('database.password', 'DJ_PASSWORD'),
):
    dj.config[_cfg_key] = os.environ[_env_var]
del _cfg_key, _env_var

dj.config['enable_python_native_blobs'] = True
dj.config['display.limit'] = 200

# Schema name is built from a short tag, exported to the environment, and then
# mirrored into the nnfabrik entry of the DataJoint config.
name = 'vei'
os.environ["DJ_SCHEMA_NAME"] = f"metrics_{name}"
dj.config["nnfabrik.schema_name"] = os.environ["DJ_SCHEMA_NAME"]

In [2]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import hiplot as hip

from nnsysident.training.trainers import standard_trainer
from nnsysident.models.models import stacked2d_gamma, stacked2d_poisson
from nnsysident.datasets.mouse_loaders import static_loaders
from nnsysident.utility.data_helpers import extract_data_key
from nnsysident.utility.measures import get_model_performance

# Single seed handed to the data loaders, model builders and trainer below.
random_seed = 27121992
# Fall back to CPU when no CUDA device is present, so the notebook still runs
# (slowly) on a GPU-less machine instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Connecting konstantin@134.76.19.44:3306


In [3]:
from nnsysident.tables.bayesian import TrainedModelBayesian, ModelBayesian, DatasetBayesian
from nnfabrik.main import Model, Trainer, Dataset, Seed
from nnsysident.tables.experiments import TrainedModel, schema

___

## Data

In [4]:
# Dataset directory(ies) handed to static_loaders; the key derived from the
# first path selects this session's loaders further down.
paths = ['./data/static20457-5-9-preproc0']
data_key = extract_data_key(paths[0])

dataset_config = dict(
    paths=paths,
    batch_size=64,
    seed=random_seed,
    loader_outputs=["images", "responses"],
    normalize=True,
    exclude=["images"],
    cuda=device == "cuda",  # already a bool; no ternary needed
)

dataloaders = static_loaders(**dataset_config)

## Model

In [20]:
# Architecture / readout hyper-parameters shared by the gamma and Poisson models.
model_config_base = dict(
    hidden_kern=13,
    input_kern=15,
    init_sigma=0.4,
    init_mu_range=0.55,
    gamma_input=1.0,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
    feature_reg_weight=0.78,
    # Previously commented-out alternative: "MultipleGeneralizedFullGaussian2d"
    readout_type="MultipleGeneralizedPointPooled2d",
)

# Previously commented-out overrides, kept for reference only (not applied):
#   batch_norm_scale=False, feature_reg_weight=2.439, independent_bn_bias=True

# The gamma model receives a shallow copy of the base config; the Poisson model
# is built directly from the base dict.
gamma_model_config = dict(model_config_base)
gamma_model = stacked2d_gamma(dataloaders, random_seed, **gamma_model_config)
poisson_model = stacked2d_poisson(dataloaders, random_seed, **model_config_base)

gamma_model.to(device)
poisson_model.to(device);

In [None]:
# Copy the weights stored in "lurz_core_poisson.tar" into the gamma model
# (NOTE(review): presumably a pretrained core from a Poisson model — confirm).
# strict=False tolerates key mismatches; mismatches under "readout" are ignored
# and only the remaining ones are counted and reported.
# map_location places the loaded tensors on `device` rather than on whatever
# device they were saved from. torch.load unpickles the file: only load trusted
# checkpoints.
gamma_load_result = gamma_model.load_state_dict(
    torch.load("lurz_core_poisson.tar", map_location=device), strict=False
)
not_matching_keys = [
    key
    for key in gamma_load_result.missing_keys + gamma_load_result.unexpected_keys
    if not key.startswith("readout")
]
print("{} not matching keys".format(len(not_matching_keys)))

In [None]:
# Copy the weights stored in "lurz_core_poisson.tar" into the Poisson model.
# strict=False tolerates key mismatches; mismatches under "readout" are ignored
# and only the remaining ones are counted and reported.
# map_location places the loaded tensors on `device` rather than on whatever
# device they were saved from. torch.load unpickles the file: only load trusted
# checkpoints.
poisson_load_result = poisson_model.load_state_dict(
    torch.load("lurz_core_poisson.tar", map_location=device), strict=False
)
not_matching_keys = [
    key
    for key in poisson_load_result.missing_keys + poisson_load_result.unexpected_keys
    if not key.startswith("readout")
]
print("{} not matching keys".format(len(not_matching_keys)))

## Training

In [None]:
# Keyword arguments shared by every standard_trainer call below.
trainer_config_base = dict(
    track_training=False,
    device=device,
    detach_core=False,
    stop_function="get_loss",
    maximize=False,  # presumably: the stop criterion is a loss, so lower is better
)

In [None]:
# Train the gamma model with its own loss and the shared trainer settings.
# Must use `trainer_config_base`: no `trainer_config` is defined in this
# notebook, so referencing it raises NameError on a fresh kernel.
gamma_score, gamma_output, gamma_state_dict = standard_trainer(
    gamma_model,
    dataloaders,
    random_seed,
    loss_function=gamma_model.loss_fn,
    **trainer_config_base,
)
# torch.save(gamma_state_dict, "mouseV1_gaussian_Gamma_statedict" + data_key + ".pt")

In [None]:
# Compute performance of the gamma model on the loaders, using its own loss.
gamma_performance = get_model_performance(
    gamma_model,
    dataloaders,
    gamma_model.loss_fn,
    device=device,
)

In [None]:
# NOTE(review): this repeats the gamma training run from an earlier cell. If both
# cells are executed, gamma_model is trained a second time starting from the
# weights it currently holds (presumably standard_trainer updates the model in
# place — confirm). Keep only one of the two cells.
gamma_score, gamma_output, gamma_state_dict = standard_trainer(
    gamma_model,
    dataloaders,
    random_seed,
    loss_function=gamma_model.loss_fn,
    **trainer_config_base,
)
# torch.save(gamma_state_dict, "mouseV1_gaussian_Gamma_statedict" + data_key + ".pt")

In [None]:
# Train the Poisson model with its own loss and the shared trainer settings.
poisson_score, poisson_output, poisson_state_dict = standard_trainer(
    poisson_model,
    dataloaders,
    random_seed,
    loss_function=poisson_model.loss_fn,
    **trainer_config_base,
)
# torch.save(poisson_state_dict, "mouseV1_gaussian_Poisson_statedict" + data_key + ".pt")

In [None]:
# Compute performance of the Poisson model on the loaders, using its own loss.
poisson_performance = get_model_performance(
    poisson_model,
    dataloaders,
    poisson_model.loss_fn,
    device=device,
)

___

In [8]:
# Take one batch from this session's training loader for the shape checks below.
# next(iter(...)) raises StopIteration on an empty loader, instead of silently
# leaving x / y undefined (or stale from an earlier run) as a for/break would.
x, y = next(iter(dataloaders["train"][data_key]))

In [10]:
x.size()  # same torch.Size as x.shape

torch.Size([64, 1, 36, 64])

In [21]:
# Forward a single image; slicing with x[:1] keeps the leading batch dimension.
out = poisson_model(x[:1])
out.size()

torch.Size([5335])

In [None]:
# NOTE(review): this displays the full prediction vector (one value per output
# unit); consider showing a slice or its shape instead.
poisson_model(x[:1])

In [26]:
poisson_model.core(x[:1]).size()  # core feature-map shape for one image

torch.Size([1, 64, 22, 50])

In [27]:
poisson_model.core(x).size()  # core feature-map shape for the whole batch

torch.Size([64, 64, 22, 50])