# Monkey System Identification Models

In [1]:
import torch
import numpy as np

from nnvision.datasets.monkey_loaders import monkey_static_loader
from nnsysident.models.models import stacked2d_gamma, stacked2d_poisson
from nnsysident.training.trainers import standard_trainer
from nnsysident.utility.measures import get_model_performance

import matplotlib.pyplot as plt

random_seed = 27121992
device = 'cuda'

  warn(f"Failed to load image Python extension: {e}")


## Data

In [2]:
dataset_config = {'dataset': 'CSRF19_V1',
                 'neuronal_data_files': [
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3631896544452.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3632669014376.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3632932714885.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3633364677437.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3634055946316.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3634142311627.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3634658447291.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3634744023164.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3635178040531.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3635949043110.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3636034866307.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3636552742293.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3637161140869.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3637248451650.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3637333931598.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3637760318484.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3637851724731.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3638367026975.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3638456653849.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3638885582960.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3638373332053.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3638541006102.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3638802601378.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3638973674012.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3639060843972.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3639406161189.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3640011636703.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3639664527524.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3639492658943.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3639749909659.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3640095265572.pickle',
                  './data/monkey/CSRF19_V1/neuronal_data/CSRF19_V1_3631807112901.pickle'],
                 'image_cache_path': './data/monkey/CSRF19_V1/images/individual',
                 'crop': 70,
                 'subsample': 1,
                 'seed': 1000,
                 'time_bins_sum': 12,
                 'batch_size': 128}

dataloaders = monkey_static_loader(**dataset_config)

In [6]:
for x, y in dataloaders["train"]['3631896544452']:
    break

In [8]:
x.shape

torch.Size([128, 1, 93, 93])

## Model

In [5]:
poisson_model_config =  {'layers': 3,
                         'input_kern': 24,
                         'gamma_input': 10,
                         'gamma_readout': 0.5,
                         'hidden_dilation': 2,
                         'hidden_kern': 9,
                         'hidden_channels': 32}

In [6]:
poisson_model = Stacked2dPointPooled_Poisson().build_model(dataloaders, random_seed, **poisson_model_config)



## Training

In [5]:
trainer_config = {'max_iter': 100,
                  'verbose': False,
                  'lr_decay_steps': 3,
                  'avg_loss': False,
                  'patience': 3,
                  'lr_init': 0.0042,
                  'device': device}

In [6]:
poisson_score, poisson_output, poisson_state_dict = standard_trainer(poisson_model,
                                                                     dataloaders,
                                                                     random_seed,
                                                                     loss_function="PoissonLoss",
                                                                     **trainer_config)

Epoch 1: 100% 3616/3616 [02:52<00:00, 20.96it/s]
Epoch 2: 100% 3616/3616 [02:52<00:00, 21.00it/s]
Epoch 3: 100% 3616/3616 [02:52<00:00, 20.97it/s]
Epoch 4: 100% 3616/3616 [02:52<00:00, 20.98it/s]
Epoch 5: 100% 3616/3616 [02:52<00:00, 20.97it/s]
Epoch 6: 100% 3616/3616 [02:52<00:00, 20.94it/s]
Epoch 7: 100% 3616/3616 [02:52<00:00, 20.97it/s]
Epoch 8: 100% 3616/3616 [02:52<00:00, 20.96it/s]
Epoch 9: 100% 3616/3616 [02:52<00:00, 20.96it/s]
Epoch 10: 100% 3616/3616 [02:52<00:00, 20.97it/s]
Epoch 11: 100% 3616/3616 [02:52<00:00, 20.97it/s]
Epoch 12: 100% 3616/3616 [02:52<00:00, 20.96it/s]
Epoch 13: 100% 3616/3616 [02:52<00:00, 20.92it/s]
Epoch 14: 100% 3616/3616 [02:52<00:00, 20.94it/s]
Epoch 15: 100% 3616/3616 [02:52<00:00, 20.97it/s]
Epoch 16: 100% 3616/3616 [02:52<00:00, 20.92it/s]
Epoch 17: 100% 3616/3616 [02:52<00:00, 20.94it/s]
Epoch 18: 100% 3616/3616 [02:52<00:00, 20.96it/s]
Epoch 19: 100% 3616/3616 [02:52<00:00, 20.94it/s]
Epoch 20: 100% 3616/3616 [02:52<00:00, 20.98it/s]
Epoch 21:

In [None]:
# torch.save(poisson_state_dict, "monkeyV1_pointpooled_poisson_statedict" + ".tar")

In [8]:
poisson_performance = get_model_performance(poisson_model, dataloaders, "PoissonLoss", device=device)



c̲o̲r̲r̲e̲l̲a̲t̲i̲o̲n̲ 

train:        0.432 
validation:   0.409 
test:         0.397 

l̲o̲s̲s̲ 

train:        -0.413 
validation:   -0.414 
test:         -0.365 

