# How to train the Baseline Models

### This notebook will show how to
- instantiate dataloader for the demo data
- instantiate pytorch model
- instantiate a trainer function
- train two baselines on the demo data
- save the model weights (the model weights can already be found in '/notebooks/precomputed_checkpoints/')

### Imports

In [3]:
ls

1_inspect_data.ipynb               [0m[01;34mdata[0m/
2_model_training.ipynb             [01;34mground_truth_files[0m/
3_submission_and_evaluation.ipynb  __init__.py
4_cloud_based_data_demo.ipynb      [01;34msubmission_files[0m/
[01;34mcheckpoints[0m/


In [4]:
cd ..

/home/maria/github/sensorium


In [5]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

### Instantiate DataLoader

In [6]:
#filenames = ['./notebooks/data/static20457-5-9-preproc0', ]
#filenames = ['./notebooks/data/ImageNet_25172_28_1', ]
#filenames = ['/mnt/lab/data/Neuropixels/ImageNet_25172_28_1', ]
#filenames = ['/home/maria/imagenet_data/128_22_1_1']
#filenames = ['/home/maria/imagenet_data/102_110_1_1']
#filenames = ['/home/maria/imagenet_data/102_117_1_1']
filenames = ['/home/maria/imagenet_data/107_124_1_1', '/home/maria/imagenet_data/128_22_1_1', '/home/maria/imagenet_data/102_110_1_1', '/home/maria/imagenet_data/102_117_1_1']

dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = {'paths': filenames,
                 'normalize': True,
                 'include_behavior': False,
                 'include_eye_position': False,
                 'batch_size': 64,
                 'exclude': None,
                 'file_tree': True,
                 'scale': 0.25,
                 }

dataloaders = get_data(dataset_fn, dataset_config)

/home/maria/imagenet_data/107_124_1_1 ['images', 'responses']
FileTreeDataset /home/maria/imagenet_data/107_124_1_1 (n=6000 items)
	images, responses
/home/maria/imagenet_data/128_22_1_1 ['images', 'responses']
FileTreeDataset /home/maria/imagenet_data/128_22_1_1 (n=6000 items)
	images, responses
/home/maria/imagenet_data/102_110_1_1 ['images', 'responses']
FileTreeDataset /home/maria/imagenet_data/102_110_1_1 (n=6000 items)
	images, responses
/home/maria/imagenet_data/102_117_1_1 ['images', 'responses']
FileTreeDataset /home/maria/imagenet_data/102_117_1_1 (n=6000 items)
	images, responses


# Instantiate State of the Art Model (SOTA)

In [23]:
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = {'pad_input': False,
              'stack': -1,
              'layers': 4,
              'input_kern': 9,
              'gamma_input': 6.3831,
              'gamma_readout': 0.0076,
              'hidden_dilation': 1,
              'hidden_kern': 7,
              'hidden_channels': 64,
              'depth_separable': True,
              'init_sigma': 0.1,
              'init_mu_range': 0.3,
              'gauss_type': 'full',
               }

model = get_model(model_fn=model_fn,
                  model_config=model_config,
                  dataloaders=dataloaders,
                  seed=42,)

## Configure Trainer

In [24]:
trainer_fn = "sensorium.training.standard_trainer"

trainer_config = {'max_iter': 100,
                 'verbose': False,
                 'lr_decay_steps': 4,
                 'avg_loss': False,
                 'lr_init': 0.009,
                 }

trainer = get_trainer(trainer_fn=trainer_fn, 
                     trainer_config=trainer_config)

# Run model training

In [25]:
torch.cuda.empty_cache()

In [26]:
%%time
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)

Epoch 1: 100%|██████████| 284/284 [00:40<00:00,  6.93it/s]
Epoch 2: 100%|██████████| 284/284 [00:24<00:00, 11.45it/s]
Epoch 3: 100%|██████████| 284/284 [00:24<00:00, 11.59it/s]
Epoch 4: 100%|██████████| 284/284 [00:24<00:00, 11.51it/s]
Epoch 5: 100%|██████████| 284/284 [00:24<00:00, 11.37it/s]
Epoch 6: 100%|██████████| 284/284 [00:24<00:00, 11.49it/s]
Epoch 7: 100%|██████████| 284/284 [00:24<00:00, 11.51it/s]
Epoch 8: 100%|██████████| 284/284 [00:24<00:00, 11.36it/s]
Epoch 9: 100%|██████████| 284/284 [00:24<00:00, 11.48it/s]
Epoch 10: 100%|██████████| 284/284 [00:25<00:00, 11.35it/s]
Epoch 11: 100%|██████████| 284/284 [00:24<00:00, 11.43it/s]
Epoch 12: 100%|██████████| 284/284 [00:25<00:00, 11.36it/s]
Epoch 13: 100%|██████████| 284/284 [00:24<00:00, 11.41it/s]
Epoch 14: 100%|██████████| 284/284 [00:24<00:00, 11.40it/s]
Epoch 15: 100%|██████████| 284/284 [00:24<00:00, 11.36it/s]
Epoch 16: 100%|██████████| 284/284 [00:24<00:00, 11.37it/s]
Epoch 17: 100%|██████████| 284/284 [00:24<00:00, 

CPU times: user 24min 50s, sys: 6.47 s, total: 24min 57s
Wall time: 24min 10s


## Save model checkpoints

In [27]:
model.state_dict()

OrderedDict([('core._input_weights_regularizer.laplace.filter',
              tensor([[[[ 0., -1.,  0.],
                        [-1.,  4., -1.],
                        [ 0., -1.,  0.]]]], device='cuda:0')),
             ('core.features.layer0.conv.weight',
              tensor([[[[ 9.9055e-03,  3.3260e-02,  4.8413e-02,  ...,  4.3834e-02,
                          4.2487e-02,  2.8706e-02],
                        [ 2.4554e-02,  5.8688e-02,  7.9935e-02,  ...,  7.4701e-02,
                          7.4394e-02,  5.2265e-02],
                        [ 2.4166e-02,  5.7144e-02,  7.9627e-02,  ...,  8.1135e-02,
                          8.5009e-02,  6.1861e-02],
                        ...,
                        [ 4.8600e-02,  7.8239e-02,  8.6211e-02,  ...,  9.3687e-02,
                          1.0887e-01,  8.1600e-02],
                        [ 5.0626e-02,  8.1533e-02,  8.8941e-02,  ...,  9.5372e-02,
                          1.0517e-01,  7.6275e-02],
                        [ 3.3079e-02,

In [28]:
cd ~

/home/maria


In [29]:
file_path = '/home/maria/imagenet_data/'

In [30]:
mkdir $file_path'/checkpoints'

In [31]:
cd $file_path'/checkpoints'

/home/maria/imagenet_data/checkpoints


In [32]:
ls

In [33]:
#version of models:
#1: all units (good+mua) for the range of keep_channels
torch.save(model.state_dict(), './sota_model_1.pth')

---

# Train Simple LN model

In [34]:
# this will remove all nonlinearities from the CNN, and computes a 3 layer LN-model

model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = {'pad_input': False,
              'stack': -1,
              'layers': 3,
              'input_kern': 9,
              'gamma_input': 6.3831,
              'gamma_readout': 0.0076,
              'hidden_dilation': 1,
              'hidden_kern': 7,
              'hidden_channels': 64,
              'depth_separable': True,
              'init_sigma': 0.1,
              'init_mu_range': 0.3,
              'gauss_type': 'full',
              'linear': True
               }
ln_model = get_model(model_fn=model_fn,
                  model_config=model_config,
                  dataloaders=dataloaders,
                  seed=42,)

In [35]:
validation_score, trainer_output, state_dict = trainer(ln_model, dataloaders, seed=42)

Epoch 1: 100%|██████████| 284/284 [00:19<00:00, 14.25it/s]
Epoch 2: 100%|██████████| 284/284 [00:19<00:00, 14.36it/s]
Epoch 3: 100%|██████████| 284/284 [00:19<00:00, 14.32it/s]
Epoch 4: 100%|██████████| 284/284 [00:19<00:00, 14.22it/s]
Epoch 5: 100%|██████████| 284/284 [00:19<00:00, 14.23it/s]
Epoch 6: 100%|██████████| 284/284 [00:19<00:00, 14.26it/s]
Epoch 7: 100%|██████████| 284/284 [00:19<00:00, 14.27it/s]
Epoch 8: 100%|██████████| 284/284 [00:20<00:00, 14.02it/s]
Epoch 9: 100%|██████████| 284/284 [00:19<00:00, 14.21it/s]
Epoch 10: 100%|██████████| 284/284 [00:20<00:00, 14.16it/s]
Epoch 11: 100%|██████████| 284/284 [00:19<00:00, 14.22it/s]
Epoch 12: 100%|██████████| 284/284 [00:20<00:00, 14.14it/s]
Epoch 13: 100%|██████████| 284/284 [00:20<00:00, 14.16it/s]
Epoch 14: 100%|██████████| 284/284 [00:19<00:00, 14.20it/s]
Epoch 15: 100%|██████████| 284/284 [00:20<00:00, 14.04it/s]
Epoch 16: 100%|██████████| 284/284 [00:20<00:00, 14.03it/s]
Epoch 17: 100%|██████████| 284/284 [00:20<00:00, 

In [17]:
ls

sota_model.pth


In [18]:
torch.save(ln_model.state_dict(), './ln_model_1.pth')

---