# Demo notebook: how to load the transfer core and train a model

In [None]:
# Notebook setup: render matplotlib figures inline, and let the autoreload extension
# reload edited modules (e.g. the lurz2020 package) before each cell is executed,
# so library changes are picked up without restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2 

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
import neuralpredictors as neur

# Build the dataloaders

The `dataloaders` object is a dictionary with three entries: `train`, `validation` and `test`. Each entry is itself a dictionary holding one dataloader per dataset specified in `paths`, keyed by the dataset name. Here we only provide one dataset. The responses are normalized, but the input images are excluded from normalization. The following config was used in the paper (all arguments not in the config keep the default value of the function).

In [None]:
from lurz2020.datasets.mouse_loaders import static_loaders

# Preprocessed dataset(s) to load; only a single dataset is used in this demo.
paths = ['data/Lurz2020/static20457-5-9-preproc0']

# Loader settings used in the paper; every argument not listed here keeps the default
# of `static_loaders`. Responses are normalized, the input images are excluded from it.
dataset_config = dict(
    paths=paths,
    batch_size=64,
    seed=1,
    cuda=True,
    normalize=True,
    exclude="images",
)

dataloaders = static_loaders(**dataset_config)

### Look at the data

In [None]:
tier = 'train'
dataset_name = '20457-5-9-0'

# Gather every batch of one tier/dataset into numpy arrays.
# Only the channel axis of the images is squeezed (assumes batches of shape
# (batch, 1, H, W) -> (batch, H, W)), and the batch axis of the responses is kept
# explicitly. A plain `.squeeze()` would also drop the batch axis whenever a batch
# holds a single sample (or the neuron axis if there is only one neuron), which
# breaks or silently corrupts the stacking below.
image_batches, response_batches = [], []
for x, y in dataloaders[tier][dataset_name]:
    image_batches.append(x.squeeze(1).detach().cpu().numpy())
    response_batches.append(y.reshape(y.shape[0], -1).detach().cpu().numpy())

images = np.concatenate(image_batches, axis=0)        # (n_images, H, W)
responses = np.concatenate(response_batches, axis=0)  # (n_images, n_neurons)

print('The \"{}\" set of dataset \"{}\" contains the responses of {} neurons to {} images'.format(tier, dataset_name, responses.shape[1], responses.shape[0]))

In [None]:
# Plot a few example images next to the recorded population response they evoked.
n_images = 5
# Common upper y-limit so the response panels of all examples are directly comparable.
max_response = responses[:n_images].max()

for idx in range(n_images):
    fig, (image_ax, response_ax) = plt.subplots(1, 2, figsize=(15, 4))
    image_ax.imshow(images[idx])
    response_ax.plot(responses[idx])
    response_ax.set_xlabel('neurons')
    response_ax.set_ylabel('responses')
    response_ax.set_ylim([0, max_response])
    plt.show()

# Build the model

If you want to load the transfer core later on, the arguments in the model config that concern the architecture of the model cannot be changed, because the pretrained core weights only fit the architecture they were trained with. The following config was used in the paper (all arguments not in the config keep the default value of the function).

In [None]:
from lurz2020.models.models import se2d_fullgaussian2d

# Model settings used in the paper; every argument not listed here keeps the default of
# `se2d_fullgaussian2d`. The architecture-related entries must stay as they are for the
# pretrained transfer core to fit this model.
model_config = dict(
    init_mu_range=0.55,
    init_sigma=0.4,
    input_kern=15,
    hidden_kern=13,
    gamma_input=1.0,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
    gamma_readout=2.439,
)

model = se2d_fullgaussian2d(dataloaders=dataloaders, seed=1, **model_config)

## Load the weights of the transfer core

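This loads the weights of the transfer core into the model built above. With `strict=False`, keys that exist in only one of the two state dicts are skipped. As a result, only the matching keys (the core) are loaded, and the readout keys of the transfer model are discarded. Keys whose names match but whose tensor shapes differ still raise an error.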

In [None]:
# Load the pretrained state dict of the transfer model. It is mapped to the CPU first, so
# loading does not depend on the device it was saved from; `load_state_dict` then copies
# the tensors into the model's own parameters/buffers.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
transfer_model = torch.load('models/transfer_model.pth.tar', map_location='cpu')

# strict=False skips keys present in only one of the two state dicts (e.g. the readout of
# the dataset the core was trained on). Report what was skipped, so a transfer that
# silently matched nothing (e.g. after an architecture change) does not go unnoticed.
load_result = model.load_state_dict(transfer_model, strict=False)
n_loaded = len(transfer_model) - len(load_result.unexpected_keys)
print('Loaded {} of {} entries from the transfer model; {} model entries keep their initial values.'.format(
    n_loaded, len(transfer_model), len(load_result.missing_keys)))
if n_loaded == 0:
    raise RuntimeError('No key of the transfer model matches the model - check the model config.')

# Build the trainer

In [None]:
from lurz2020.training.trainers import standard_trainer as trainer

# Keep True to leave the transfer core fixed during training;
# set it to False to fine-tune the core as well.
detach_core = True
print('Core is fixed and will not be fine-tuned' if detach_core else 'Core will be fine-tuned')

trainer_config = dict(track_training=True, detach_core=detach_core)

# Run training

In [None]:
# Train the model; the trainer returns three values, unpacked here as the score,
# the trainer output and the model state.
score, output, model_state = trainer(
    model=model,
    dataloaders=dataloaders,
    seed=1,
    **trainer_config
)

# Analyze the trained model

### Predict neural responses to an image (here from the train set)

In [None]:
# Show some example images together with the recorded and the predicted neural responses.
n_images = 5
# Common upper y-limit for the recorded-response panels of all examples.
max_response = responses[:n_images].max()

# Predict in evaluation mode and without gradient tracking: if the model contains layers
# such as batch norm or dropout, training mode would change their behavior for a single
# image. Inputs are sent to whatever device the model lives on instead of assuming CUDA.
model.eval()
device = next(model.parameters()).device

for i in range(n_images):
    input_image = images[i]
    with torch.no_grad():
        # add batch and channel axes: (H, W) -> (1, 1, H, W)
        model_input = torch.from_numpy(input_image)[None, None].to(device)
        predicted_response = model(model_input).squeeze().cpu().numpy()

    fig, axs = plt.subplots(1, 3, figsize=(20,4))
    axs[0].imshow(input_image)
    axs[1].plot(responses[i])
    axs[2].plot(predicted_response)
    axs[1].set_xlabel('neurons')
    axs[2].set_xlabel('neurons')
    axs[1].set_ylabel('responses')
    axs[2].set_ylabel('predicted responses')
    axs[1].set_ylim([0, max_response])
    plt.show()

### Get the performance of your model

In [None]:
from lurz2020.utility.measures import get_correlations, get_fraction_oracles

# Correlation between predicted and recorded responses on each tier, averaged over
# neurons (per_neuron=False) and returned as a single value (as_dict=False).
correlation_kwargs = dict(device='cuda', as_dict=False, per_neuron=False)
train_correlation = get_correlations(model, dataloaders["train"], **correlation_kwargs)
validation_correlation = get_correlations(model, dataloaders["validation"], **correlation_kwargs)
test_correlation = get_correlations(model, dataloaders["test"], **correlation_kwargs)

# The fraction of oracle is only available on the test set: it needs batches made of the
# repeated presentations of an image, which the loaders provide with return_test_sampler=True.
oracle_dataloader = static_loaders(**dataset_config, return_test_sampler=True, tier='test')
fraction_oracle = get_fraction_oracles(model=model, dataloaders=oracle_dataloader, device='cuda')[0]

separator = '-----------------------------------------'
print(separator)
for template, value in (
    ('Correlation (train set):      {0:.3f}', train_correlation),
    ('Correlation (validation set): {0:.3f}', validation_correlation),
    ('Correlation (test set):       {0:.3f}', test_correlation),
):
    print(template.format(value))
print(separator)
print('Fraction oracle (test set):   {0:.3f}'.format(fraction_oracle))