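# Mouse System Identification Models

This notebook fits `nnsysident` system-identification models (`stacked2d_gamma` and `stacked2d_zig`) to the static mouse recording `static20457-5-9`. It evaluates each trained model with `get_model_performance`.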

In [1]:
import os
import datajoint as dj
# Database credentials are taken from the environment so none are stored in the
# notebook; a missing variable raises KeyError right here.
for _dj_key, _env_var in (('database.host', 'DJ_HOST'),
                          ('database.user', 'DJ_USERNAME'),
                          ('database.password', 'DJ_PASSWORD')):
    dj.config[_dj_key] = os.environ[_env_var]
del _dj_key, _env_var  # keep the notebook namespace free of loop temporaries

dj.config['enable_python_native_blobs'] = True
dj.config['display.limit'] = 200

# The nnfabrik schema name is derived from a short experiment tag and exported
# through the environment before being mirrored into dj.config.
name = 'vei'
os.environ["DJ_SCHEMA_NAME"] = f"metrics_{name}"
dj.config["nnfabrik.schema_name"] = os.environ["DJ_SCHEMA_NAME"]

In [2]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import hiplot as hip

from nnsysident.training.trainers import standard_trainer
from nnsysident.models.models import stacked2d_gamma, stacked2d_zig
from nnsysident.datasets.mouse_loaders import static_loaders
from nnsysident.utility.data_helpers import extract_data_key
from nnsysident.utility.measures import get_model_performance

# Seed passed explicitly to the model builders and the trainer below.
random_seed = 27121992
# Fall back to CPU when no GPU is visible so the notebook still runs end to end;
# the data loaders are configured for CUDA only when device == 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

  warn(f"Failed to load image Python extension: {e}")


Connecting <user>@<db-host>:3306


In [3]:
from nnsysident.tables.bayesian import TrainedModelBayesian, ModelBayesian, DatasetBayesian
from nnfabrik.main import Model, Trainer, Dataset, Seed
from nnsysident.tables.experiments import TrainedModel, schema

___

## Data

In [8]:
# Single static recording; `data_key` selects it inside the per-tier
# dataloader dicts (e.g. dataloaders["train"][data_key]).
paths = ['./data/static20457-5-9-preproc0']
data_key = extract_data_key(paths[0])

dataset_config = {'paths': paths,
                  'batch_size': 64,
                  'seed': random_seed,
                  # Use ["images", "responses"] here to load only images and responses.
                  'loader_outputs': ["images", "responses", "pupil_center", "behavior"],
                  'subtract_behavior_mean': True,
                  'normalize': True,
                  'exclude': ["images"],
                  # Plain boolean: load onto the GPU only when the notebook runs on CUDA.
                  "cuda": device == "cuda",
                  }

dataloaders = static_loaders(**dataset_config)

## Model

In [9]:
# Earlier hand-picked configuration, kept for reference only:
#   hidden_kern=13, input_kern=15, init_sigma=0.4, init_mu_range=0.55, gamma_input=1.0,
#   grid_mean_predictor=cortex (input_dimensions=2, hidden_layers=0, hidden_features=0,
#   final_tanh=False), feature_reg_weight=0.78,
#   readout_type="MultipleGeneralizedFullGaussian2d"

# Base configuration from the Gamma hypersearch, plus the behavior modulator
# and pupil shifter settings.
model_config_base = dict(
    init_sigma=0.4,
    init_mu_range=0.55,
    gamma_input=1.0,
    grid_mean_predictor=dict(type='cortex',
                             input_dimensions=2,
                             hidden_layers=0,
                             hidden_features=0,
                             final_tanh=False),
    readout_type='MultipleGeneralizedFullGaussian2d',
    feature_reg_weight=0.26702978129164495,
    hidden_channels=128,
    layers=5,
    hidden_kern=11,
    input_kern=15,
    modulator_kwargs=dict(mod_type='MLP',
                          layers=2,
                          hidden_channels=10,
                          gamma_modulator=0.0,
                          bias=False),
    shifter_kwargs=dict(shift_type='MLP',
                        shift_layers=3,
                        hidden_channels_shifter=5,
                        gamma_shifter=0.0,
                        bias=False),
)

# Shallow copy: top-level keys can be changed per model without touching the base.
gamma_model_config = dict(model_config_base)
gamma_model = stacked2d_gamma(dataloaders, random_seed, **gamma_model_config)

# ZIG counterpart built from the same base configuration (currently disabled).
# NOTE(review): the Training cells that call `zig_model` need these lines enabled
# when run on a fresh kernel.
# loc = np.exp(-10)
# zig_model_config = model_config_base.copy()
# zig_model_config['zero_thresholds'] = {data_key: loc}
# zig_model = stacked2d_zig(dataloaders, random_seed, **zig_model_config)

# gamma_model.to(device);
# zig_model.to(device);

In [None]:
# Grab one training batch for inspection. Each batch carries one entry per
# `loader_outputs` field (four are requested), so unpacking it into exactly two
# names would raise "too many values to unpack". Keep the first two entries,
# presumably images and responses given the requested order — TODO confirm.
batch = next(iter(dataloaders["train"][data_key]))
x, y = batch[0], batch[1]

In [None]:
# Shape of `x`, the first entry of the inspected training batch;
# presumably (batch, channels, height, width) — TODO confirm.
x.shape

## Training

In [10]:
# Trainer settings shared by every model in this notebook. The stopping
# criterion is computed by "get_correlations" and is maximized.
trainer_config_base = dict(
    track_training=False,
    device=device,
    detach_core=False,
    stop_function="get_correlations",
    maximize=True,
)

In [11]:
# Train the Gamma model with its own loss function; the trainer's result is
# unpacked into score, output and state dict.
gamma_score, gamma_output, gamma_state_dict = standard_trainer(
    gamma_model,
    dataloaders,
    random_seed,
    loss_function=gamma_model.loss_fn,
    **trainer_config_base,
)
# Optional: persist the trained weights.
# torch.save(gamma_state_dict, "mouseV1_gaussian_Gamma_statedict" + data_key + ".pt")

Epoch 1: 100% 70/70 [00:11<00:00,  6.12it/s]
Epoch 2: 100% 70/70 [00:29<00:00,  2.38it/s]
Epoch 3: 100% 70/70 [00:32<00:00,  2.13it/s]
Epoch 4: 100% 70/70 [00:32<00:00,  2.14it/s]
Epoch 5: 100% 70/70 [00:32<00:00,  2.15it/s]
Epoch 6: 100% 70/70 [00:32<00:00,  2.13it/s]
Epoch 7: 100% 70/70 [00:32<00:00,  2.14it/s]
Epoch 8: 100% 70/70 [00:32<00:00,  2.17it/s]
Epoch 9: 100% 70/70 [00:31<00:00,  2.20it/s]
Epoch 10: 100% 70/70 [00:32<00:00,  2.17it/s]
Epoch 11:  53% 37/70 [00:17<00:15,  2.17it/s]


KeyboardInterrupt: 

In [None]:
# Evaluate the trained Gamma model using its own loss function.
gamma_performance = get_model_performance(
    gamma_model, dataloaders, gamma_model.loss_fn, device=device
)

In [None]:
# Train the ZIG model with its own loss function.
# NOTE(review): on a fresh kernel `zig_model` is not defined at this point. Its
# construction in the Model section is commented out, and the ZIG model is
# otherwise only built further down. Enable that construction (or skip this
# cell) before running top to bottom.
zig_score, zig_output, zig_state_dict = standard_trainer(
    zig_model,
    dataloaders,
    random_seed,
    loss_function=zig_model.loss_fn,
    **trainer_config_base,
)
# Optional: persist the trained weights.
# torch.save(zig_state_dict, "mouseV1_gaussian_ZIG_statedict" + data_key + ".pt")

In [None]:
# Evaluate the trained ZIG model using its own loss function.
zig_performance = get_model_performance(
    zig_model, dataloaders, zig_model.loss_fn, device=device
)

___

## ZIG model with its own configuration

In [None]:
# ZIG model with its own hyperparameters (distinct from the Gamma configuration).
# The zero threshold is keyed by `data_key`, so the config follows whatever
# dataset `paths` points to instead of a hard-coded session key.
loc = np.exp(-10)  # ≈ 4.54e-05

model_config_base = {'zero_thresholds': {data_key: loc},
                     'init_sigma': 0.4,
                     'init_mu_range': 0.55,
                     'gamma_input': 1.0,
                     'grid_mean_predictor': {'type': 'cortex',
                                             'input_dimensions': 2,
                                             'hidden_layers': 0,
                                             'hidden_features': 0,
                                             'final_tanh': False},
                     'readout_type': 'MultipleGeneralizedFullGaussian2d',
                     'feature_reg_weight': 0.01977603972740348,
                     'hidden_channels': 256,
                     'layers': 5,
                     'hidden_kern': 13,
                     'input_kern': 15}

zig_model_config = model_config_base.copy()
zig_model = stacked2d_zig(dataloaders, random_seed, **zig_model_config)

zig_model.to(device);

In [None]:
# Train the separately configured ZIG model with its own loss function.
zig_score, zig_output, zig_state_dict = standard_trainer(
    zig_model,
    dataloaders,
    random_seed,
    loss_function=zig_model.loss_fn,
    **trainer_config_base,
)
# Optional: persist the trained weights.
# torch.save(zig_state_dict, "mouseV1_gaussian_ZIG_statedict" + data_key + ".pt")

In [None]:
# Evaluate the separately configured ZIG model using its own loss function.
zig_performance = get_model_performance(
    zig_model, dataloaders, zig_model.loss_fn, device=device
)