# Mouse System Identification Models

In [1]:
import copy

import numpy as np
import torch

from nnsysident.training.trainers import standard_trainer
from nnsysident.models.models import stacked2d_gamma, stacked2d_zig, stacked2d_poisson
from nnsysident.datasets.mouse_loaders import static_loaders
from nnsysident.utility.data_helpers import extract_data_key
from nnsysident.utility.measures import get_model_performance

# Seed passed explicitly to the loaders, model builders and trainers below.
random_seed = 27121992
# Fall back to CPU so the notebook still runs (slowly) on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

  warn(f"Failed to load image Python extension: {e}")


Connecting <user>@<db-host>:3306


___

## Data

In [2]:
# Preprocessed static mouse dataset; its data key is derived from the path.
paths = ['/project/notebooks/data/static20457-5-9-preproc0']
data_key = extract_data_key(paths[0])

dataset_config = dict(
    paths=paths,
    batch_size=64,
    seed=random_seed,
    loader_outputs=["images", "responses"],
    normalize=True,
    # NOTE(review): presumably the data excluded from normalization — confirm in static_loaders.
    exclude=["images"],
)

dataloaders = static_loaders(**dataset_config)

## Model

In [3]:
# Hyperparameters shared by every model below; each model cell copies this dict and
# adds or overrides its own entries.
model_config_base = dict(
    hidden_kern=13,
    input_kern=15,
    init_sigma=0.4,
    init_mu_range=0.55,
    gamma_input=1.0,
    # Nested settings for the readout's grid-mean predictor, passed through unchanged.
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
    feature_reg_weight=0.78,
    readout_type="MultipleGeneralizedFullGaussian2d",
)

#### ZIG model

In [4]:
# Value passed to the model as `zero_thresholds`.
# NOTE(review): presumably responses at or below this count as "zero" in the ZIG likelihood — confirm in stacked2d_zig.
loc = np.exp(-10)

# Deep copy: model_config_base contains a nested dict ('grid_mean_predictor'). A shallow
# .copy() would share that dict across all model configs, so if a model builder modified
# it in place, the change would silently leak into the other models.
zig_model_config = copy.deepcopy(model_config_base)
zig_model_config['zero_thresholds'] = loc

zig_model = stacked2d_zig(dataloaders, random_seed, **zig_model_config)



To load a saved state_dict, uncomment and run the cell below. If the state_dict comes from a model trained on a different dataset, set `strict=False`.
In that case, you will need to fine-tune the model's readout on the new dataset. To do so, set `detach_core=True` in the trainer config (`trainer_config_base`).

In [6]:
# Optional: uncomment to load ZIG weights. The filename matches the torch.save call in the
# ZIG training cell; map_location lets GPU-saved weights load on a CPU-only machine.
# zig_model.load_state_dict(torch.load("mouseV1_gaussian_ZIG_statedict" + data_key + ".tar", map_location=device), strict=True)

#### Poisson model

In [7]:
# Deep copy so the nested 'grid_mean_predictor' dict is not shared with the other model
# configs; a shallow .copy() would let in-place changes leak between models.
poisson_model_config = copy.deepcopy(model_config_base)
poisson_model = stacked2d_poisson(dataloaders, random_seed, **poisson_model_config)

To load a saved state_dict, uncomment and run the cell below. If the state_dict comes from a model trained on a different dataset, set `strict=False`.
In that case, you will need to fine-tune the model's readout on the new dataset. To do so, set `detach_core=True` in the trainer config (`trainer_config_base`).

In [8]:
# Optional: uncomment to load Poisson weights. The filename matches the torch.save call in
# the Poisson training cell; map_location lets GPU-saved weights load on a CPU-only machine.
# poisson_model.load_state_dict(torch.load("mouseV1_gaussian_Poisson_statedict" + data_key + ".tar", map_location=device), strict=True)

#### Lurz model

In [5]:
# Poisson model variant built with init_with_lurz_core=True. Deep copy so the nested
# 'grid_mean_predictor' dict is not shared with the other model configs.
model_config_lurz_model = copy.deepcopy(model_config_base)
model_config_lurz_model['batch_norm_scale'] = False
model_config_lurz_model['feature_reg_weight'] = 2.439  # overrides the base value (0.78)
model_config_lurz_model['independent_bn_bias'] = True
model_config_lurz_model['init_with_lurz_core'] = True

lurz_poisson_model = stacked2d_poisson(dataloaders, random_seed, **model_config_lurz_model)



## Training

Set `track_training=True` to see more detail about training progress. If the core was loaded and you only want to fine-tune the readout, set `detach_core=True`.

In [11]:
# Trainer options shared by all models; individual training cells may override entries.
trainer_config_base = dict(
    track_training=False,  # True shows more detail while training
    device=device,
    detach_core=False,  # True trains only the readout
)

#### ZIG model

In [None]:
# Train the ZIG model. stop_function="get_loss" with maximize=False: presumably the
# stopping criterion is the loss, with lower being better.
zig_score, zig_output, zig_state_dict = standard_trainer(
    zig_model,
    dataloaders,
    random_seed,
    loss_function="ZIGLoss",
    stop_function="get_loss",
    maximize=False,
    **trainer_config_base
)
# Uncomment to save the trained weights:
# torch.save(zig_state_dict, "mouseV1_gaussian_ZIG_statedict" + data_key + ".tar")

In [11]:
# Evaluate the ZIG model: correlation and loss on the train/validation/test splits.
zig_performance = get_model_performance(
    zig_model, dataloaders, "ZIGLoss", device=device
)

c̲o̲r̲r̲e̲l̲a̲t̲i̲o̲n̲ 

train:        0.305 
validation:   0.253 
test:         0.253 

l̲o̲s̲s̲ 

train:        -8.130 
validation:   -8.068 
test:         -8.081 



#### Poisson model

In [None]:
# Train the Poisson model using the trainer's default stopping settings.
poisson_score, poisson_output, poisson_state_dict = standard_trainer(
    poisson_model,
    dataloaders,
    random_seed,
    loss_function="PoissonLoss",
    **trainer_config_base
)
# Uncomment to save the trained weights:
# torch.save(poisson_state_dict, "mouseV1_gaussian_Poisson_statedict" + data_key + ".tar")

In [13]:
# Evaluate the Poisson model: correlation and loss on the train/validation/test splits.
poisson_performance = get_model_performance(
    poisson_model, dataloaders, "PoissonLoss", device=device
)



c̲o̲r̲r̲e̲l̲a̲t̲i̲o̲n̲ 

train:        0.377 
validation:   0.278 
test:         0.275 

l̲o̲s̲s̲ 

train:        0.565 
validation:   0.611 
test:         0.617 



#### Lurz model

In [None]:
# Train only the readout (detach_core=True); every other trainer option is inherited
# from trainer_config_base.
lurz_model_trainer_config = {**trainer_config_base, "detach_core": True}

lurz_poisson_score, lurz_poisson_output, lurz_poisson_state_dict = standard_trainer(
    lurz_poisson_model,
    dataloaders,
    random_seed,
    loss_function="PoissonLoss",
    **lurz_model_trainer_config
)
# Uncomment to save the trained weights:
# torch.save(lurz_poisson_state_dict, "mouseV1_gaussian_lurz_Poisson_statedict" + data_key + ".tar")

In [15]:
# Evaluate the Lurz-core Poisson model: correlation and loss on the train/validation/test splits.
lurz_performance = get_model_performance(
    lurz_poisson_model, dataloaders, "PoissonLoss", device=device
)

c̲o̲r̲r̲e̲l̲a̲t̲i̲o̲n̲ 

train:        0.384 
validation:   0.325 
test:         0.324 

l̲o̲s̲s̲ 

train:        0.562 
validation:   0.587 
test:         0.591 



___