<table class="ee-notebook-buttons" align="left"><td>
<a target="_blank"  href="https://colab.research.google.com/github/walkerlab/FENS-2022/blob/main/notebooks/UseYourOwnTraining.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" /> Run in Google Colab</a>
</td><td>
<a target="_blank"  href="https://github.com/walkerlab/FENS-2022/blob/main/notebooks/UseYourOwnTraining.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" /> View source on GitHub</a></td></table>

# Digging deeper into the training

In the [main notebook](./Deep-Learning-in-Neuroscience.ipynb), we relied exclusively on the `train_model` function to take care of many details of the training. In this notebook, we are going to develop our own, albeit simpler, training routine to get a better appreciation of what goes into training neural networks.

## Part 0: Preparing the environment

Again, we are going to prepare the environment by downloading the necessary library (`FENS-2022`) and the dataset. This is necessary because each Colab notebook starts in its own separate environment by default.

In [None]:
# Clone and install the FENS package
!git clone https://github.com/walkerlab/FENS-2022.git
!pip3 install FENS-2022

In [None]:
# download the dataset
!wget -nc "https://onedrive.live.com/download?cid=06D44059794C5B46&resid=6D44059794C5B46%21121992&authkey=AHJVfxtvAASasjQ" -O dataset.zip

# Unzip
!unzip -nq 'dataset.zip'

# get trained network weights
!git clone https://gin.g-node.org/walkerlab/fens-2022.git /content/trained_nets

Finally, we go ahead and import a bunch of standard libraries.

In [None]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
import seaborn as sns

### Prepare the dataloaders

As before, we prepare PyTorch [dataloaders](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) to conveniently load all images and responses in **batches** using the `load_dataset` function.

In [None]:
from fens.dataset import load_dataset

In [None]:
dataloaders = load_dataset(path = './Lurz2020/static20457-5-9-preproc0', batch_size=60)

In [None]:
dataloaders

In [None]:
train_loader = dataloaders['train']
valid_loader = dataloaders['validation']
test_loader = dataloaders['test']

Let's also extract the underlying dataset object so that we can gain access to additional *meta* information. Keep in mind that the additional data attributes available on the dataset are specific to the way we designed the dataset in the `fens` library!

In [None]:
# Access to the dataset object that underlies all dataloaders
dataset = dataloaders['test'].dataset

Recall that a single neuron's responses vary widely even across repeated presentations of an identical stimulus!

This so-called **noisiness** of neural responses makes predicting the responses of the neurons to images fundamentally challenging, and in fact makes a perfect fit essentially impossible!

Instead, we often try to fit the **distribution of responses** as well as we can, and we will briefly revisit this point later.
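
To make the idea of response noise concrete, here is a minimal sketch that simulates repeated presentations of one image to one neuron. It assumes Poisson variability, which is only a simplified stand-in for real neural noise, and `mean_rate` and `n_repeats` are made-up values for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

mean_rate = 3.0     # hypothetical mean response of one neuron to one image
n_repeats = 10_000  # hypothetical number of repeated presentations

# Each repeat yields a different response even though the stimulus is identical
responses = rng.poisson(lam=mean_rate, size=n_repeats)

sample_mean = responses.mean()
sample_var = responses.var()

# For Poisson noise the variance is roughly equal to the mean,
# so no model can predict the individual trials exactly.
print("first ten responses:", responses[:10])
print("mean: {:.2f}, variance: {:.2f}".format(sample_mean, sample_var))
```

Even a model that knew `mean_rate` exactly could only predict the average response, not the trial-to-trial scatter around it.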

## Setting up the LN Model

Let us again define a Linear-Nonlinear (LN) model that we can use to test our training routine.

In [None]:
class Linear(nn.Module):
    def __init__(
        self,
        input_height,
        input_width,
        n_neurons,
        momentum=0.1,
        init_std=1e-3,
        gamma=0.0,
    ):
        super().__init__()
        self.bn = nn.BatchNorm2d(1, momentum=momentum, affine=False)
        self.linear = nn.Linear(input_height * input_width, n_neurons)
        self.gamma = gamma
        self.init_std = init_std
        self.initialize()
        
    
    def forward(self, x):
        x = self.bn(x)
        x = self.linear(x.flatten(1))
        return nn.functional.elu(x) + 1
        

    def initialize(self, std=None):
        if std is None:
            std = self.init_std
        nn.init.normal_(self.linear.weight.data, std=std)


    def regularizer(self):
        return self.gamma * self.linear.weight.abs().sum()
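
The output nonlinearity `elu(x) + 1` in `forward` keeps the predicted responses positive, which matters because a Poisson rate must be positive. The following sketch checks this on a handful of made-up input values; it only illustrates the nonlinearity and is not part of the model definition.

```python
import torch
from torch.nn import functional as F

# A few made-up pre-activation values, negative and positive
x = torch.linspace(-5.0, 5.0, steps=11)

# Same output nonlinearity as in the LN model above
rates = F.elu(x) + 1

# For x <= 0 this equals exp(x); for x > 0 it equals x + 1
for xi, ri in zip(x.tolist(), rates.tolist()):
    print("x = {:+.1f} -> elu(x) + 1 = {:.4f}".format(xi, ri))
```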


### Performing the training - simple routine

It's time to train the network, and this time we will do so without relying on the `train_model` convenience function from `fens`.

We are going to want to periodically check the performance on the validation set during the training to monitor the progress. It's fairly common to use a metric here that is different from the training objective and more intuitive. 

Here, we are going to compute the correlation between the real and the predicted responses across images, averaged over the neurons. Let's go ahead and define a function to compute the correlation.


In [None]:
def corr(y1, y2, axis=-1, eps=1e-8, **kwargs):
    """
    Compute the correlation between two matrices along certain dimensions.

    Args:
        y1:      first numpy array
        y2:      second numpy array
        axis:    dimension along which the correlation is computed.
        eps:     offset to the standard deviation to make sure the correlation is well defined (default 1e-8)
        **kwargs passed to final mean of standardized y1 * y2

    Returns: correlation vector

    """
    y1 = (y1 - y1.mean(axis=axis, keepdims=True)) / (
        y1.std(axis=axis, keepdims=True, ddof=0) + eps
    )
    y2 = (y2 - y2.mean(axis=axis, keepdims=True)) / (
        y2.std(axis=axis, keepdims=True, ddof=0) + eps
    )
    return (y1 * y2).mean(axis=axis, **kwargs)
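
As a quick sanity check, the sketch below applies the same `corr` function (repeated here so the example runs on its own) to synthetic data. The array shapes and values are made up: rows stand for images and columns for neurons, so `axis=0` gives one correlation per neuron.

```python
import numpy as np

def corr(y1, y2, axis=-1, eps=1e-8, **kwargs):
    # Same computation as above: standardize both arrays, then average the product
    y1 = (y1 - y1.mean(axis=axis, keepdims=True)) / (
        y1.std(axis=axis, keepdims=True, ddof=0) + eps
    )
    y2 = (y2 - y2.mean(axis=axis, keepdims=True)) / (
        y2.std(axis=axis, keepdims=True, ddof=0) + eps
    )
    return (y1 * y2).mean(axis=axis, **kwargs)

rng = np.random.default_rng(0)
targets = rng.normal(size=(200, 5))    # synthetic: 200 images x 5 neurons

perfect = 2.0 * targets + 1.0          # linearly related predictions
unrelated = rng.normal(size=(200, 5))  # predictions unrelated to the targets

r_perfect = corr(perfect, targets, axis=0)      # one value per neuron, close to 1
r_unrelated = corr(unrelated, targets, axis=0)  # one value per neuron, near 0

print("perfect:  ", np.round(r_perfect, 3))
print("unrelated:", np.round(r_unrelated, 3))
```

Note that the correlation ignores the scale and offset of the predictions, which is one reason it is easier to interpret than the raw loss value.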

Let's go ahead and build our network and define our training routine. The imports below bring in both `SGD` (stochastic gradient descent) and `Adam`; the training cell uses the `Adam` optimizer.

We first instantiate the model to train.

In [None]:
ln_model = Linear(input_height=64, input_width=36, n_neurons=5335, gamma=0.1)

# let us refer to our target model as `model`
model = ln_model

Now that we have instantiated a network to train, let's define a routine to train the network. We will train on our training dataset, and periodically report the performance on the validation dataset. As discussed before, we will make use of the Poisson loss, which is simply the negative log-likelihood of the targets under a Poisson distribution. This is conveniently available as `PoissonNLLLoss` in PyTorch.
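
To see what this loss computes, the sketch below compares `PoissonNLLLoss(log_input=False, reduction='sum')` with a manual evaluation of `rate - count * log(rate + eps)`. The rates and counts are made-up numbers, and `eps = 1e-8` is assumed to match the PyTorch default. The `log(count!)` term of the full Poisson log-likelihood is left out because it does not depend on the model.

```python
import torch
from torch.nn import PoissonNLLLoss

# Made-up predicted rates (model outputs) and observed responses
rates = torch.tensor([[0.5, 2.0], [1.5, 4.0]])
counts = torch.tensor([[0.0, 3.0], [1.0, 5.0]])

# Built-in loss, configured as in the training cell
criterion = PoissonNLLLoss(log_input=False, reduction='sum')
loss_builtin = criterion(rates, counts)

# Manual negative log-likelihood, dropping the model-independent log(count!) term
eps = 1e-8  # assumed to match the default eps of PoissonNLLLoss
loss_manual = (rates - counts * torch.log(rates + eps)).sum()

print("built-in:", loss_builtin.item())
print("manual:  ", loss_manual.item())
```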

In [None]:
from torch.optim import SGD, Adam
from torch.nn import PoissonNLLLoss
from tqdm import tqdm # to show progress bar

In [None]:
N_EPOCHS = 100
REPORT_FREQ = 2 # how often to report the performance on the validation set
lr = 0.01
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is available

# log_input=False: the model output is the predicted response itself, not the log of the response
criterion = PoissonNLLLoss(log_input=False, reduction='sum')
#criterion = PoissonLoss(avg=False)

# move model and criterion into the target device
model.to(device)
criterion.to(device)

# define and setup the optimizer
optimizer = Adam(model.parameters(), lr=lr)

train_loader = dataloaders['train']
valid_loader = dataloaders['validation']
test_loader = dataloaders['test']

optimizer.zero_grad()
for epoch in range(N_EPOCHS):
    for batch_no, (images, targets) in tqdm(enumerate(train_loader), 
                                            desc="Epoch {}".format(epoch),
                                            total=len(train_loader)):
        # put model into training mode
        model.train()

        # zero out the gradient
        optimizer.zero_grad()

        # move data into the target device
        images, targets = images.to(device), targets.to(device)

        # get predicted responses
        responses = model(images)

        # compute the loss, with regularizers
        loss = criterion(responses, targets) + model.regularizer()
        
        # compute the gradient
        loss.backward()

        # apply the learning step
        optimizer.step()
        
    if epoch % REPORT_FREQ == 0:
        # put model into evaluation mode (BatchNorm then uses its running statistics)
        model.eval()

        total_responses = []
        total_targets = []
        # no gradients are needed for validation
        with torch.no_grad():
            for images, targets in tqdm(valid_loader,
                                        desc="Validation",
                                        total=len(valid_loader)):
                images = images.to(device)
                total_responses.append(model(images).cpu())
                total_targets.append(targets.cpu())

        # concatenate batches into one big numpy array
        total_responses = torch.cat(total_responses).numpy()
        total_targets = torch.cat(total_targets).numpy()

        # compute the correlation across images (axis 0), averaged over neurons
        print('Correlation: {:.3f}'.format(corr(total_responses, total_targets, axis=0).mean()))
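
The validation step can also be wrapped in a small reusable function, for example to score a model on another dataloader. The sketch below is one possible way to do that: `evaluate_correlation` is a hypothetical helper (not part of `fens`), and the tiny model and dataset are synthetic stand-ins so that the example runs on its own.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_correlation(model, loader, device="cpu", eps=1e-8):
    """Hypothetical helper: correlation across images, averaged over neurons."""
    model.eval()
    preds, targs = [], []
    with torch.no_grad():
        for images, targets in loader:
            preds.append(model(images.to(device)).cpu())
            targs.append(targets.cpu())
    preds = torch.cat(preds).numpy()
    targs = torch.cat(targs).numpy()
    # standardize across images (axis 0), then average the product per neuron
    preds = (preds - preds.mean(0, keepdims=True)) / (preds.std(0, keepdims=True) + eps)
    targs = (targs - targs.mean(0, keepdims=True)) / (targs.std(0, keepdims=True) + eps)
    return float((preds * targs).mean(0).mean())

# Synthetic stand-in data: 40 small "images" and 3 "neurons"
torch.manual_seed(0)
images = torch.randn(40, 1, 8, 6)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(8 * 6, 3))
with torch.no_grad():
    targets = toy_model(images)  # targets generated by the model itself

toy_loader = DataLoader(TensorDataset(images, targets), batch_size=10)
score = evaluate_correlation(toy_model, toy_loader)
print("Correlation: {:.3f}".format(score))
```

In this notebook, the same helper could be called as `evaluate_correlation(model, test_loader, device)` once training has finished.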
    

Now try and see how well the above routine works when you use a CNN model!

In [None]:
from collections import OrderedDict
class CNN(nn.Module):
    def __init__(
        self,
        input_height,
        input_width,
        n_neurons,
        momentum=0.1,
        init_std=1e-3,
        gamma=0.1,
        hidden_channels=8,
    ):
        super().__init__()
        self.init_std = init_std
        self.gamma = gamma

        # CNN core
        self.cnn_core = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(1, hidden_channels, 15, padding=15 // 2, bias=False)),
                    ("bn1", nn.BatchNorm2d(hidden_channels, momentum=momentum)),
                    ("elu1", nn.ELU()),
                    ("conv2", nn.Conv2d(hidden_channels, hidden_channels, 13, padding=13 // 2, bias=False)),
                    ("bn2", nn.BatchNorm2d(hidden_channels, momentum=momentum)),
                    ("elu2", nn.ELU()),
                    ("conv3", nn.Conv2d(hidden_channels, hidden_channels, 13, padding=13 // 2, bias=False)),
                    ("bn3", nn.BatchNorm2d(hidden_channels, momentum=momentum)),
                    ("elu3", nn.ELU()),
                    ("conv4", nn.Conv2d(hidden_channels, hidden_channels, 13, padding=13 // 2, bias=False)),
                    ("bn4", nn.BatchNorm2d(hidden_channels, momentum=momentum)),
                    ("elu4", nn.ELU()),
                ]
            )
        )

        # Fully connected readout
        self.readout = nn.Sequential(
            OrderedDict(
                [
                    ('fc_ro', nn.Linear(input_height * input_width * hidden_channels, n_neurons)),
                    ('bn_ro', nn.BatchNorm1d(n_neurons, momentum=momentum)),
                ]
            )
        )


    def initialize(self, std=None):
        if std is None:
            std = self.init_std
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, std=std)

    def forward(self, x):
        x = self.cnn_core(x)
        x = x.view(x.size(0), -1)
        x = self.readout(x)
        return nn.functional.elu(x) + 1
    
    def regularizer(self):
        return self.readout[0].weight.abs().sum() * self.gamma
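
The readout expects `input_height * input_width * hidden_channels` input features because every convolution uses `padding = kernel_size // 2`, which preserves the spatial size for odd kernel sizes. The sketch below checks this for a single layer with the same settings as `conv1`; the all-zero input batch is a placeholder.

```python
import torch
from torch import nn

hidden_channels = 8
input_height, input_width = 64, 36

# Same settings as conv1 in the CNN above
conv = nn.Conv2d(1, hidden_channels, 15, padding=15 // 2, bias=False)

# Placeholder batch of 2 single-channel images
x = torch.zeros(2, 1, input_height, input_width)
out = conv(x)

# Flatten exactly as in CNN.forward
flat = out.view(out.size(0), -1)

print("conv output shape:", tuple(out.shape))
print("flattened features:", flat.shape[1])
```

A fully connected readout on this many features for thousands of neurons is large, which is where the L1 penalty in `regularizer` comes in.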


Let us now instantiate the model and train it!

In [None]:
cnn_model = CNN(input_height=64, input_width=36, n_neurons=5335)