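# Monkey System Identification Models

Trains a `stacked2d_poisson` model with a Poisson loss on the CSRF19_V1 monkey V1 recordings and evaluates its performance.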

In [1]:
import torch
import numpy as np

from nnvision.datasets.monkey_loaders import monkey_static_loader
from nnsysident.models.models import stacked2d_gamma, stacked2d_poisson
from nnsysident.training.trainers import standard_trainer
from nnsysident.utility.measures import get_model_performance

import matplotlib.pyplot as plt

# Seed shared by the model builder and the trainer further down.
random_seed = 27121992
# Also seed the global torch / numpy RNGs so any randomness outside the
# explicitly seeded model/trainer calls is reproducible too.
torch.manual_seed(random_seed)
np.random.seed(random_seed)
# Fall back to CPU instead of failing later on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

  warn(f"Failed to load image Python extension: {e}")


## Data

In [25]:
# Every session of the CSRF19_V1 dataset lives under one root directory, so the
# per-session pickle paths are built from their session IDs instead of repeating
# the absolute path 32 times. The order of SESSION_IDS is the order in which the
# files are handed to the loader.
DATASET_NAME = 'CSRF19_V1'
DATA_ROOT = f'/project/notebooks/data/monkey/{DATASET_NAME}'
SESSION_IDS = [
    '3631896544452', '3632669014376', '3632932714885', '3633364677437',
    '3634055946316', '3634142311627', '3634658447291', '3634744023164',
    '3635178040531', '3635949043110', '3636034866307', '3636552742293',
    '3637161140869', '3637248451650', '3637333931598', '3637760318484',
    '3637851724731', '3638367026975', '3638456653849', '3638885582960',
    '3638373332053', '3638541006102', '3638802601378', '3638973674012',
    '3639060843972', '3639406161189', '3640011636703', '3639664527524',
    '3639492658943', '3639749909659', '3640095265572', '3631807112901',
]
# Guard against accidentally listing the same session twice when editing the list.
assert len(set(SESSION_IDS)) == len(SESSION_IDS), "duplicate session IDs in SESSION_IDS"

dataset_config = {'dataset': DATASET_NAME,
                  'neuronal_data_files': [f'{DATA_ROOT}/neuronal_data/{DATASET_NAME}_{session_id}.pickle'
                                          for session_id in SESSION_IDS],
                  'image_cache_path': f'{DATA_ROOT}/images/individual',
                  'crop': 70,
                  'subsample': 1,
                  'seed': 1000,
                  'time_bins_sum': 12,
                  'batch_size': 128}

# Build the dataloaders from dataset_config; they are indexed further down as
# dataloaders[tier][session_id] (e.g. dataloaders["train"][...]).
dataloaders = monkey_static_loader(
    **dataset_config
)

In [26]:
# Pull a single training batch from one session to sanity-check input shapes.
# next(iter(...)) states the intent directly and raises StopIteration on an
# empty loader instead of silently leaving x / y undefined.
x, y = next(iter(dataloaders["train"]['3631896544452']))

In [27]:
# Display the size of the sampled input batch.
x.size()

torch.Size([128, 1, 93, 93])

## Model

In [11]:
# Hyper-parameters forwarded as keyword arguments to stacked2d_poisson.
poisson_model_config = dict(
    layers=3,
    input_kern=24,
    gamma_input=10,
    gamma_readout=0.5,
    hidden_dilation=2,
    hidden_kern=9,
    hidden_channels=32,
    readout_type="MultipleGeneralizedFullGaussian2d",
    grid_mean_predictor=None,
)

In [12]:
# Instantiate the Poisson model from the dataloaders, the shared seed and the
# hyper-parameters in poisson_model_config.
poisson_model = stacked2d_poisson(
    dataloaders,
    random_seed,
    **poisson_model_config
)

## Training

In [None]:
# Keyword arguments forwarded to standard_trainer; `device` comes from the config cell.
trainer_config = dict(
    max_iter=100,
    verbose=False,
    lr_decay_steps=3,
    avg_loss=False,
    patience=3,
    lr_init=0.0042,
    device=device,
)

In [None]:
# Train with the Poisson loss; the trainer's result is unpacked into a score,
# an output object and the trained state dict.
poisson_score, poisson_output, poisson_state_dict = standard_trainer(
    poisson_model,
    dataloaders,
    random_seed,
    loss_function="PoissonLoss",
    **trainer_config
)

In [None]:
# Opt-in checkpointing of the trained weights. It is off by default so that
# Restart & Run All does not write files as a side effect.
# NOTE(review): the file name says "pointpooled", but the model above is built with
# readout_type "MultipleGeneralizedFullGaussian2d" — confirm the intended name.
SAVE_STATE_DICT = False
STATE_DICT_PATH = "monkeyV1_pointpooled_poisson_statedict.tar"
if SAVE_STATE_DICT:
    torch.save(poisson_state_dict, STATE_DICT_PATH)

In [None]:
# Evaluate the trained model with the same Poisson loss it was trained with, and
# end the cell with the result so it is actually displayed (a bare assignment
# leaves the cell without output).
poisson_performance = get_model_performance(poisson_model, dataloaders, "PoissonLoss", device=device)
poisson_performance