# Demo notebook: how to load the transfer core and train a model

In [None]:
# Show matplotlib figures inline, directly below the cell that produces them.
%matplotlib inline
# Enable IPython's autoreload extension. Mode 2 reloads modules before code is
# executed, so edits to the lurz2020 sources are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2 

In [None]:
import torch
from collections import OrderedDict
import neuralpredictors as neur

## Build the dataloaders

In [None]:
from lurz2020.datasets.mouse_loaders import static_loaders

# Dataset directories to build loaders for (one entry per dataset).
paths = ['data/static20457-5-9-preproc0']

# Keyword arguments forwarded unchanged to static_loaders.
dataset_config = {'paths': paths, 'batch_size': 64, 'seed': 1}

dataloaders = static_loaders(**dataset_config)

## Build the model

In [None]:
from lurz2020.models.models import se2d_fullgaussian2d

# Hyperparameters forwarded to se2d_fullgaussian2d. Judging by the name, this
# builds an SE2d core with a full-Gaussian 2d readout (confirm in lurz2020).
model_config = dict(
    init_mu_range=0.55,
    init_sigma=0.4,
    input_kern=15,
    hidden_kern=13,
    gamma_input=1.0,
    # Settings for the readout's grid-mean predictor; 'cortex' presumably
    # derives readout positions from cortical coordinates — verify upstream.
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
    gamma_readout=2.439,
)

# The model is built against the dataloaders so it can size itself to the data.
model = se2d_fullgaussian2d(
    dataloaders=dataloaders,
    seed=1,
    **model_config,
)

## Load the weights of the transfer core

In [None]:
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
# map_location='cpu' lets a checkpoint saved from GPU tensors be read on a
# CPU-only machine; load_state_dict then copies the values into the model's
# existing parameters on whatever device they live.
transfer_model = torch.load('models/transfer_model.pth.tar', map_location='cpu')

# Keep only the core's entries; everything else in the checkpoint (presumably
# the readout of the source dataset) is discarded.
transfer_core = OrderedDict(
    (k, v) for k, v in transfer_model.items() if k.startswith("core.")
)

# strict=False is needed because transfer_core deliberately has no readout
# keys. However, it would also silently skip core keys whose names do not
# match, leaving part of the core at its random initialisation. Verify that
# the core was transferred completely.
load_result = model.load_state_dict(transfer_core, strict=False)
core_missing = [k for k in load_result.missing_keys if k.startswith("core.")]
if core_missing or load_result.unexpected_keys:
    raise RuntimeError(
        "Transfer core does not match the model's core: "
        f"missing={core_missing}, unexpected={load_result.unexpected_keys}"
    )

## Build the trainer

In [None]:
from lurz2020.training.trainers import standard_trainer

# True keeps the pretrained core fixed during training; set it to False to
# allow the core to be fine-tuned as well.
detach_core = True
print('Core is fixed and will not be fine-tuned' if detach_core else 'Core will be fine-tuned')

# Keyword options intended for the trainer call.
trainer_config = {'track_training': True, 'detach_core': detach_core}
trainer = standard_trainer

## Run training

In [None]:
# Forward trainer_config so detach_core and track_training actually reach the
# trainer. Previously they were built but never passed, so the trainer used its
# own defaults for these options (not visible here). The printed
# "Core is fixed" message was therefore not guaranteed to hold.
score, output, model_state = trainer(model=model, dataloaders=dataloaders, seed=1, **trainer_config)