# How to train the Baseline Models

### This notebook shows how to
- instantiate a dataloader for the demo data
- instantiate a PyTorch model
- instantiate a trainer function
- train two baselines on the demo data
- save the model weights (pretrained weights are already available in '/notebooks/precomputed_checkpoints/')

### Imports

In [2]:
from pathlib import Path

import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from nnfabrik.builder import get_data, get_model, get_trainer

### Instantiate DataLoader

In [3]:
# The single demo recording used to train both baselines.
filenames = ['./data/lurz2020/static20457-5-9-preproc0']

# Dotted path of the dataset builder, handed to nnfabrik's get_data together
# with its keyword configuration.
dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = dict(
    paths=filenames,
    normalize=True,
    include_behavior=False,
    include_eye_position=True,
    batch_size=128,
    exclude=None,
    file_tree=True,
    scale=1,
)

dataloaders = get_data(dataset_fn, dataset_config)

dataport not available, will only be able to load data locally


# Instantiate the State-of-the-Art (SOTA) Model

In [4]:
# Builder for the SOTA baseline. Judging by its name, this is a stacked core
# with a full-Gaussian readout (not verifiable from this notebook).
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
    init_sigma=0.1,
    init_mu_range=0.3,
    gauss_type='full',
)

# The dataloaders are passed so the builder can size the model to the data.
# The fixed seed keeps model construction reproducible.
model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=42,
)



## Configure Trainer

In [5]:
# Dotted path of the training routine plus its keyword configuration.
# get_trainer returns a callable that is reused for both baselines below.
trainer_fn = "sensorium.training.standard_trainer"
trainer_config = dict(
    max_iter=100,
    verbose=False,
    lr_decay_steps=4,
    avg_loss=False,
    lr_init=0.009,
)

trainer = get_trainer(trainer_fn=trainer_fn, trainer_config=trainer_config)

# Run model training

In [6]:
# Fit the SOTA model with a fixed seed. The trainer returns a validation
# score, additional trainer output, and a state dict.
(validation_score,
 trainer_output,
 state_dict) = trainer(model, dataloaders, seed=42)

Epoch 1: 100%|██████████| 35/35 [00:16<00:00,  2.16it/s]
Epoch 2: 100%|██████████| 35/35 [00:08<00:00,  4.02it/s]
Epoch 3: 100%|██████████| 35/35 [00:08<00:00,  4.08it/s]
Epoch 4: 100%|██████████| 35/35 [00:08<00:00,  4.07it/s]
Epoch 5: 100%|██████████| 35/35 [00:08<00:00,  4.13it/s]
Epoch 6: 100%|██████████| 35/35 [00:08<00:00,  4.07it/s]
Epoch 7: 100%|██████████| 35/35 [00:08<00:00,  4.13it/s]
Epoch 8: 100%|██████████| 35/35 [00:08<00:00,  4.11it/s]
Epoch 9: 100%|██████████| 35/35 [00:08<00:00,  4.12it/s]
Epoch 10: 100%|██████████| 35/35 [00:08<00:00,  4.04it/s]
Epoch 11: 100%|██████████| 35/35 [00:08<00:00,  4.04it/s]
Epoch 12: 100%|██████████| 35/35 [00:08<00:00,  4.11it/s]
Epoch 13: 100%|██████████| 35/35 [00:08<00:00,  4.10it/s]
Epoch 14: 100%|██████████| 35/35 [00:08<00:00,  4.16it/s]
Epoch 15: 100%|██████████| 35/35 [00:08<00:00,  4.12it/s]
Epoch 16: 100%|██████████| 35/35 [00:08<00:00,  4.13it/s]
Epoch 17: 100%|██████████| 35/35 [00:08<00:00,  4.12it/s]
Epoch 18: 100%|████████

## Save model checkpoints

In [7]:
# Persist the trained SOTA weights. torch.save does not create missing parent
# directories, so make sure the checkpoint folder exists before writing
# (otherwise this cell fails on a fresh checkout).
checkpoint_dir = Path('./checkpoints')
checkpoint_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_dir / 'sota_model.pth')

---

# Train a Simple LN Model

In [8]:
# Same builder as the SOTA model, but 'linear': True removes all
# nonlinearities from the CNN. Together with 3 layers (instead of 4), this
# yields a 3-layer LN model.
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=3,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
    init_sigma=0.1,
    init_mu_range=0.3,
    gauss_type='full',
    linear=True,
)

ln_model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=42,
)

In [9]:
# Fit the LN baseline with the same trainer and seed as the SOTA model.
# Note: this rebinds the three result names that the SOTA training cell set.
(validation_score,
 trainer_output,
 state_dict) = trainer(ln_model, dataloaders, seed=42)

Epoch 1: 100%|██████████| 35/35 [00:07<00:00,  4.39it/s]
Epoch 2: 100%|██████████| 35/35 [00:08<00:00,  4.33it/s]
Epoch 3: 100%|██████████| 35/35 [00:08<00:00,  4.37it/s]
Epoch 4: 100%|██████████| 35/35 [00:08<00:00,  4.33it/s]
Epoch 5: 100%|██████████| 35/35 [00:08<00:00,  4.31it/s]
Epoch 6: 100%|██████████| 35/35 [00:08<00:00,  4.34it/s]
Epoch 7: 100%|██████████| 35/35 [00:08<00:00,  4.36it/s]
Epoch 8: 100%|██████████| 35/35 [00:08<00:00,  4.34it/s]
Epoch 9: 100%|██████████| 35/35 [00:08<00:00,  4.33it/s]
Epoch 10: 100%|██████████| 35/35 [00:08<00:00,  4.37it/s]
Epoch 11: 100%|██████████| 35/35 [00:08<00:00,  4.37it/s]
Epoch 12: 100%|██████████| 35/35 [00:08<00:00,  4.36it/s]
Epoch 13: 100%|██████████| 35/35 [00:07<00:00,  4.43it/s]
Epoch 14: 100%|██████████| 35/35 [00:08<00:00,  4.31it/s]
Epoch 15: 100%|██████████| 35/35 [00:07<00:00,  4.38it/s]
Epoch 16: 100%|██████████| 35/35 [00:08<00:00,  4.37it/s]
Epoch 17: 100%|██████████| 35/35 [00:08<00:00,  4.29it/s]
Epoch 18: 100%|████████

In [10]:
# Persist the trained LN weights. torch.save does not create missing parent
# directories, so make sure the checkpoint folder exists before writing
# (otherwise this cell fails on a fresh checkout).
checkpoint_dir = Path('./checkpoints')
checkpoint_dir.mkdir(parents=True, exist_ok=True)
torch.save(ln_model.state_dict(), checkpoint_dir / 'ln_model.pth')

---