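# How to Generate Submission Files with the Baseline Model

### This notebook will show how to
- instantiate the dataloaders for the demo data
- instantiate the PyTorch model (the state-of-the-art "SOTA" baseline)
- load its pretrained weights (the checkpoint is read from './model_checkpoints/pretrained/')
- generate the submission files for the competition

Training the baselines (instantiating a trainer, training, and saving the model weights) is not shown in this notebook; pretrained weights are loaded instead.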

### Imports

In [1]:
import os
import warnings

import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

### Instantiate DataLoader

In [2]:
# Build the train/validation/test/final_test dataloaders for the SENSORIUM+ demo
# session through nnfabrik's builder, which resolves `dataset_fn` by import path.
dataset_fn = 'sensorium.datasets.static_loaders'

filenames = [
    './data/static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip',
]

dataset_config = dict(
    paths=filenames,
    normalize=True,
    include_behavior=True,
    include_eye_position=True,
    batch_size=128,
    scale=0.25,  # NOTE(review): presumably an image down-scaling factor — confirm in sensorium.datasets
)

dataloaders = get_data(dataset_fn, dataset_config)

In [3]:
# Inspect the loader tree: one OrderedDict per tier, each keyed by recording session.
dataloaders

OrderedDict([('train',
              OrderedDict([('27204-5-13',
                            <torch.utils.data.dataloader.DataLoader at 0x7fd54294ee80>)])),
             ('validation',
              OrderedDict([('27204-5-13',
                            <torch.utils.data.dataloader.DataLoader at 0x7fd54295c370>)])),
             ('test',
              OrderedDict([('27204-5-13',
                            <torch.utils.data.dataloader.DataLoader at 0x7fd54295c340>)])),
             ('final_test',
              OrderedDict([('27204-5-13',
                            <torch.utils.data.dataloader.DataLoader at 0x7fd54295c310>)]))])

In [None]:
# Peek at one training batch to check the shape of the images fed to the model.
# The batch is fetched explicitly so this cell also runs on a fresh kernel.
# NOTE(review): assumes each batch exposes an `.images` attribute (as the sensorium
# static loaders' batches appear to) — confirm.
first_session_key = next(iter(dataloaders['train']))
first_batch = next(iter(dataloaders['train'][first_session_key]))
first_batch.images.shape

### Instantiate the State-of-the-Art (SOTA) Model

In [7]:
# Baseline "SOTA" model (a stacked core with a full Gaussian readout, per the model_fn
# name), built through nnfabrik's builder. These hyperparameters define the
# architecture, so they must stay compatible with the pretrained checkpoint loaded below.
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'

model_config = dict(
    pad_input=False,
    layers=4,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_dilation=1,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=1,
        hidden_features=30,
        final_tanh=True,
    ),
    init_sigma=0.1,
    init_mu_range=0.3,
    gauss_type='full',
    shifter=True,
    stack=-1,
    use_avg_reg=False,
)

model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=42,
)

### Load the pretrained checkpoint

In [8]:
# Load the pretrained weights into the freshly built model. The checkpoint is mapped
# to CPU first so it also loads on machines without the GPU it was saved from;
# load_state_dict then copies the tensors into the model's existing parameters.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
sota_checkpoint_path = "./model_checkpoints/pretrained/sensorium_p_sota_model.pth"
sota_state_dict = torch.load(sota_checkpoint_path, map_location="cpu")
model.load_state_dict(sota_state_dict);

In [9]:
# Put the model into inference mode; Module.eval() is defined as train(False).
model.train(False);

# Now Evaluate

### Generate the submission files

In [10]:
# Session key of the recording to submit (matches the keys in the `dataloaders` output).
dataset_name = "27204-5-13"

In [11]:
# import the API from the competition repo

from sensorium import evaluate
from sensorium.utility import submission

In [None]:
# generate the submission file
submission.generate_submission_file(trained_model=model, 
                                    dataloaders=dataloaders,
                                    data_key=dataset_name,
                                    path="./submission_files/",
                                    device="cuda")

In [None]:
# Now you can upload the "test" and "final_test" submission files to the competition homepage.

# submission_file_test: these are the LIVE scores
# submission_file_final_test: these are the FINAL scores


---