# How to train the Baseline Models for the SENSORIUM track

### This notebook will show how to
- instantiate the dataloaders for the SENSORIUM track
- instantiate a PyTorch model
- instantiate a trainer function
- train two baselines for this competition track
- save the model weights (pretrained weights are already available in `./model_checkpoints/pretrained/`)

### Imports

In [2]:
import collections.abc
# Python 3.10 removed these aliases from `collections`, but `hyper` still needs them, so restore all four manually.
collections.Iterable = collections.abc.Iterable
collections.Mapping = collections.abc.Mapping
collections.MutableSet = collections.abc.MutableSet
collections.MutableMapping = collections.abc.MutableMapping
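The cell above overwrites the four names unconditionally. A slightly more defensive variant, sketched here under the assumption that only missing names should be filled in, leaves any existing attribute untouched and therefore behaves the same on Python versions before and after 3.10 (the release that removed these aliases from `collections`).

```python
import collections
import collections.abc

# Names that older code still looks up on `collections` instead of `collections.abc`.
_ALIASES = ("Iterable", "Mapping", "MutableSet", "MutableMapping")

for _name in _ALIASES:
    # Only add the alias if this Python version no longer provides it.
    if not hasattr(collections, _name):
        setattr(collections, _name, getattr(collections.abc, _name))
```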

In [3]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

In [4]:
%pwd

'c:\\Users\\hp\\sensorium\\notebooks\\model_tutorial'

In [5]:
%cd ../../

c:\Users\hp\sensorium


### Instantiate DataLoader

In [7]:
# loading the SENSORIUM dataset
filenames = ['notebooks/data/static26872-17-20-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip']

dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = {'paths': filenames,
                  'normalize': True,
                  'include_behavior': False,
                  'include_eye_position': False,
                  'batch_size': 16,
                  'scale': 0.25,
                  }

dataloaders = get_data(dataset_fn, dataset_config)
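To make the shape of `dataloaders` concrete without downloading the data, the sketch below builds a stand-in with the layout we assume `get_data` returns: a dict keyed by tier (`"train"`, `"validation"`, `"test"`), each holding a dict from a per-session data key to a PyTorch `DataLoader` whose batches are named tuples with `images` and `responses`. The tier names, the data key `"26872-17-20"`, the `FakeSession` class, the 100 neurons and the 36x64 image size (our assumption for `scale = 0.25`) are all hypothetical; only `batch_size = 16` comes from the config above.

```python
from collections import namedtuple

import torch
from torch.utils.data import DataLoader, Dataset

# Hypothetical batch layout: one image and one response vector per trial.
Batch = namedtuple("Batch", ["images", "responses"])


class FakeSession(Dataset):
    """Random stand-in for one recording session (not real SENSORIUM data)."""

    def __init__(self, n_trials, n_neurons, height=36, width=64):
        self.images = torch.randn(n_trials, 1, height, width)
        self.responses = torch.rand(n_trials, n_neurons)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        return Batch(self.images[i], self.responses[i])


data_key = "26872-17-20"  # hypothetical session key derived from the file name

# tier -> data_key -> DataLoader; the default collate function keeps the namedtuple type.
fake_dataloaders = {
    tier: {data_key: DataLoader(FakeSession(n, 100), batch_size=16, shuffle=(tier == "train"))}
    for tier, n in [("train", 64), ("validation", 16), ("test", 16)]
}

for tier, loaders in fake_dataloaders.items():
    for key, loader in loaders.items():
        batch = next(iter(loader))
        print(tier, key, tuple(batch.images.shape), tuple(batch.responses.shape))
```

Iterating over the real `dataloaders` in the same nested way is a quick check that the data loaded as expected before training.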

# Instantiate the State-of-the-Art (SOTA) Model

In [8]:
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = {'pad_input': False,
                'stack': -1,
                'layers': 4,
                'input_kern': 9,
                'gamma_input': 6.3831,
                'gamma_readout': 0.0076,
                'hidden_kern': 7,
                'hidden_channels': 64,
                'depth_separable': True,
                'grid_mean_predictor': {'type': 'cortex',
                                        'input_dimensions': 2,
                                        'hidden_layers': 1,
                                        'hidden_features': 30,
                                        'final_tanh': True},
                'init_sigma': 0.1,
                'init_mu_range': 0.3,
                'gauss_type': 'full',
                'shifter': False,
                }

model = get_model(model_fn=model_fn,
                  model_config=model_config,
                  dataloaders=dataloaders,
                  seed=42,)
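The two `gamma_*` entries in the config are regularization strengths. As we understand the neuralpredictors-style core and readout, `gamma_input` weights a smoothness (Laplacian) penalty on the first-layer kernels and `gamma_readout` weights an L1 sparsity penalty on the readout feature weights. The sketch below shows that idea with a 3x3 discrete Laplacian; the helper names, the readout tensor shape and the summing (rather than averaging) are assumptions, so the exact values will differ from the library's.

```python
import torch
import torch.nn.functional as F

# 3x3 discrete Laplacian: large response where a kernel changes abruptly between neighbors.
LAPLACE_3x3 = torch.tensor([[0.0, -1.0, 0.0],
                            [-1.0, 4.0, -1.0],
                            [0.0, -1.0, 0.0]])


def laplace_penalty(weight):
    """Sum of squared Laplacian responses over all kernels of a conv weight (out, in, k, k)."""
    out_c, in_c, k, _ = weight.shape
    kernels = weight.reshape(out_c * in_c, 1, k, k)
    lap = LAPLACE_3x3.view(1, 1, 3, 3).to(weight)  # match dtype/device of the weights
    return F.conv2d(kernels, lap, padding=1).pow(2).sum()


def l1_penalty(features):
    """L1 norm pushing most readout feature weights toward zero."""
    return features.abs().sum()


gamma_input, gamma_readout = 6.3831, 0.0076      # values from model_config
first_layer = torch.nn.Conv2d(1, 64, kernel_size=9, bias=False)  # 'input_kern': 9
readout_features = torch.randn(64, 100)          # hypothetical: 64 channels x 100 neurons

reg = gamma_input * laplace_penalty(first_layer.weight) + gamma_readout * l1_penalty(readout_features)
print(float(reg))
```

A term like `reg` would be added to the data loss during training, trading prediction accuracy against smooth input filters and sparse readouts.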

## Configure Trainer

In [9]:
trainer_fn = "sensorium.training.standard_trainer"

trainer_config = {'max_iter': 200,
                 'verbose': False,
                 'lr_decay_steps': 4,
                 'avg_loss': False,
                 'lr_init': 0.009,
                 }

trainer = get_trainer(trainer_fn=trainer_fn, 
                     trainer_config=trainer_config)
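The trainer settings `lr_init` and `lr_decay_steps` interact through a learning-rate schedule. We assume here that `standard_trainer` behaves like other nnfabrik/neuralpredictors-style trainers: it wraps the optimizer in PyTorch's `ReduceLROnPlateau`, multiplies the learning rate by a decay factor whenever the validation score stops improving for a number of epochs, and ends training after `lr_decay_steps` such decays (or at `max_iter` epochs). The decay factor 0.3, the patience of 5 epochs and the validation scores below are made-up assumptions used only to illustrate the mechanism.

```python
import torch

LR_INIT = 0.009       # 'lr_init' from trainer_config
LR_DECAY_STEPS = 4    # 'lr_decay_steps' from trainer_config
DECAY_FACTOR = 0.3    # assumption: multiplicative decay per step
PATIENCE = 5          # assumption: epochs without improvement before a decay

# A dummy parameter stands in for the model; only the learning rate matters here.
dummy = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([dummy], lr=LR_INIT)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=DECAY_FACTOR, patience=PATIENCE)

# Made-up validation scores: a few improving epochs, then a plateau.
val_scores = [0.10, 0.20, 0.30] + [0.30] * 40

decayed_lrs = []
for epoch, score in enumerate(val_scores, start=1):
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(score)  # the validation score is the plateau metric ("max" = higher is better)
    lr_after = optimizer.param_groups[0]["lr"]
    if lr_after < lr_before:
        decayed_lrs.append(lr_after)
        print(f"epoch {epoch}: lr {lr_before:.2e} -> {lr_after:.2e}")
        if len(decayed_lrs) == LR_DECAY_STEPS:
            print(f"learning rate decayed {LR_DECAY_STEPS} times; stopping")
            break
```

Under these assumptions the learning rate after the k-th decay is `lr_init * 0.3**k`, so a larger `lr_decay_steps` lets training continue at ever smaller step sizes before stopping.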

## Run Model Training

In [10]:
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)

Epoch 1: 100%|██████████| 280/280 [02:25<00:00,  1.92it/s]
Epoch 2: 100%|██████████| 280/280 [00:15<00:00, 17.95it/s]
Epoch 3: 100%|██████████| 280/280 [00:15<00:00, 18.15it/s]
Epoch 4: 100%|██████████| 280/280 [00:15<00:00, 18.19it/s]
Epoch 5: 100%|██████████| 280/280 [00:15<00:00, 17.75it/s]
Epoch 6: 100%|██████████| 280/280 [00:15<00:00, 17.62it/s]
Epoch 7: 100%|██████████| 280/280 [00:15<00:00, 17.58it/s]
Epoch 8: 100%|██████████| 280/280 [00:18<00:00, 15.00it/s]
Epoch 9: 100%|██████████| 280/280 [00:15<00:00, 18.13it/s]
Epoch 10: 100%|██████████| 280/280 [00:15<00:00, 17.99it/s]
Epoch 11: 100%|██████████| 280/280 [00:15<00:00, 17.94it/s]
Epoch 12: 100%|██████████| 280/280 [00:15<00:00, 17.94it/s]
Epoch 13: 100%|██████████| 280/280 [00:15<00:00, 18.04it/s]
Epoch 14: 100%|██████████| 280/280 [00:15<00:00, 17.96it/s]
Epoch 15: 100%|██████████| 280/280 [00:15<00:00, 18.07it/s]
Epoch 16: 100%|██████████| 280/280 [00:15<00:00, 18.20it/s]
Epoch 17: 100%|██████████| 280/280 [00:15<00:00, 

KeyboardInterrupt: 

### Save model checkpoints after training is complete

In [None]:
torch.save(model.state_dict(), './model_checkpoints/sensorium_sota_model.pth')

## Load Model Checkpoints

In [5]:
model.load_state_dict(torch.load("./model_checkpoints/pretrained/sensorium_sota_model.pth"));
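Saving and loading only the `state_dict` means the checkpoint stores parameters, not the architecture, which is why the notebook builds the model with `get_model` before calling `load_state_dict`. The sketch below shows the round trip with a tiny `nn.Linear` standing in for the SENSORIUM model; the helper names `save_checkpoint` and `load_checkpoint` are hypothetical, and creating the folder and passing `map_location="cpu"` are our additions for robustness (a missing folder or a GPU-saved checkpoint on a CPU-only machine would otherwise raise an error).

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(model: nn.Module, path: str) -> None:
    """Save only the learned parameters (the state_dict), creating the folder if needed."""
    folder = os.path.dirname(path)
    if folder:
        os.makedirs(folder, exist_ok=True)
    torch.save(model.state_dict(), path)


def load_checkpoint(model: nn.Module, path: str) -> nn.Module:
    """Load a state_dict into an already constructed model of the same architecture."""
    # map_location="cpu" lets a GPU-saved checkpoint open on a CPU-only machine;
    # load_state_dict then copies the tensors onto whatever device the model lives on.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    return model


torch.manual_seed(0)
trained = nn.Linear(4, 2)  # hypothetical stand-in for the trained model
fresh = nn.Linear(4, 2)    # same architecture, different random weights

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_checkpoints", "demo_model.pth")
    save_checkpoint(trained, path)
    load_checkpoint(fresh, path)
```

If the architecture used for loading differs from the one used for saving (for example `'layers': 3` versus `'layers': 4`), `load_state_dict` fails with missing or unexpected keys, so the config must match the checkpoint.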

---

# Train a simple LN model

Our LN (linear-nonlinear) model has essentially the same architecture as our CNN model (a convolutional core followed by a Gaussian readout, here with three core layers instead of four),
but with all nonlinearities removed except the final ELU+1 output nonlinearity.
This effectively turns the CNN into a fully linear model followed by a single output nonlinearity.
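The structure described above can be sketched in a few lines of PyTorch. This `TinyLN` class is a hypothetical, heavily simplified illustration, not the sensorium implementation: it stacks convolutions with no nonlinearity in between, reads each neuron's response at a single learned position with `F.grid_sample` (roughly what a Gaussian readout does at evaluation time, when it uses the Gaussian's mean instead of sampling), weights the channels per neuron, and applies ELU+1 once at the end. Unlike the real model, it learns the positions directly instead of predicting them from cortical coordinates (`grid_mean_predictor`), and the channel count and kernel size are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLN(nn.Module):
    """Linear conv core + point readout + ELU+1 (simplified, hypothetical)."""

    def __init__(self, n_neurons, channels=8, layers=3, kernel=5):
        super().__init__()
        convs = [nn.Conv2d(1, channels, kernel, padding=kernel // 2)]
        convs += [nn.Conv2d(channels, channels, kernel, padding=kernel // 2) for _ in range(layers - 1)]
        self.core = nn.Sequential(*convs)  # no nonlinearities between layers: the core stays linear
        # Per-neuron readout position in normalized [-1, 1] coordinates (0.3 mirrors 'init_mu_range').
        self.mu = nn.Parameter(torch.empty(n_neurons, 2).uniform_(-0.3, 0.3))
        self.features = nn.Parameter(torch.randn(n_neurons, channels) * 0.01)
        self.bias = nn.Parameter(torch.zeros(n_neurons))

    def drive(self, images):
        """Linear (affine) part of the model: images -> one value per neuron."""
        fmap = self.core(images)                                       # (B, C, H, W)
        grid = self.mu.view(1, -1, 1, 2).repeat(images.shape[0], 1, 1, 1)  # (B, N, 1, 2)
        sampled = F.grid_sample(fmap, grid, align_corners=True)       # (B, C, N, 1)
        sampled = sampled.squeeze(-1).transpose(1, 2)                 # (B, N, C)
        return (sampled * self.features).sum(-1) + self.bias          # (B, N)

    def forward(self, images):
        return F.elu(self.drive(images)) + 1  # the single output nonlinearity keeps rates positive


torch.manual_seed(0)
model_sketch = TinyLN(n_neurons=10)
rates = model_sketch(torch.randn(2, 1, 36, 64))  # 2 hypothetical images -> (2, 10) predicted rates
print(rates.shape)
```

Because everything before ELU+1 is affine, the drive for the sum of two images equals the sum of their individual drives (up to the constant offset), which is exactly what makes this an LN model.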


In [6]:
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = {'pad_input': False,
                'stack': -1,
                'layers': 3,
                'input_kern': 9,
                'gamma_input': 6.3831,
                'gamma_readout': 0.0076,
                'hidden_kern': 7,
                'hidden_channels': 64,
                'grid_mean_predictor': {'type': 'cortex',
                                        'input_dimensions': 2,
                                        'hidden_layers': 1,
                                        'hidden_features': 30,
                                        'final_tanh': True},
                'depth_separable': True,
                'init_sigma': 0.1,
                'init_mu_range': 0.3,
                'gauss_type': 'full',
                'linear': True,
                }
model = get_model(model_fn=model_fn,
                  model_config=model_config,
                  dataloaders=dataloaders,
                  seed=42,)
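Both configurations set `'depth_separable': True`. As a generic illustration (not necessarily the exact layer used inside the sensorium core), a depthwise-separable convolution splits one k x k convolution into a per-channel (depthwise) k x k convolution followed by a 1 x 1 (pointwise) convolution that mixes channels. With `hidden_channels = 64` and `hidden_kern = 7`, that reduces the weights of one hidden layer from 64·64·7·7 = 200,704 to 64·7·7 + 64·64 = 7,232 (biases omitted). The 36x64 input size below is a hypothetical choice.

```python
import torch
import torch.nn as nn


def count_params(module):
    """Total number of trainable values in a module."""
    return sum(p.numel() for p in module.parameters())


channels, kernel = 64, 7  # 'hidden_channels' and 'hidden_kern' from the configs

standard = nn.Conv2d(channels, channels, kernel, padding=kernel // 2, bias=False)
separable = nn.Sequential(
    # depthwise: one 7x7 filter per channel (groups=channels), no channel mixing
    nn.Conv2d(channels, channels, kernel, padding=kernel // 2, groups=channels, bias=False),
    # pointwise: 1x1 convolution that mixes the channels
    nn.Conv2d(channels, channels, kernel_size=1, bias=False),
)

x = torch.randn(1, channels, 36, 64)  # hypothetical feature map
print(count_params(standard), count_params(separable))
print(standard(x).shape, separable(x).shape)
```

Both layers map a 64-channel feature map to a 64-channel map of the same size, but the separable version has roughly 28 times fewer weights, which acts as a strong constraint on the core.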

In [None]:
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)

In [None]:
torch.save(model.state_dict(), './model_checkpoints/sensorium_ln_model.pth')

In [7]:
model.load_state_dict(torch.load("./model_checkpoints/pretrained/sensorium_ln_model.pth"));

---