# Baseline CNN model

In this notebook we provide an explanation of the baseline CNN model, including: 
- a written description of the model (with references for further information)
- a code demo exploring different parts of the model

---

## Baseline CNN model description

### The model has two main parts

The baseline CNN model (which is mostly based on [this work](https://openreview.net/forum?id=Tp7kI90Htd)) is constructed from two main parts:
- **core**: the core aims to (nonlinearly) extract features that are common to all neurons. That is, we assume there exists a set of features that all neurons use, but that each neuron combines in its own unique way.
- **readout**: once the core has extracted the features, each neuron reads out from them by simply linearly combining them into a single value. Finally, this single value is passed through a final nonlinearity (in this case `ELU() + 1`), which ensures that the model output is positive and gives the inferred firing rate of the neuron.
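To make the core/readout split concrete, here is a minimal toy sketch in PyTorch. The layer sizes and the fully connected readout are illustrative assumptions, not the baseline architecture; the point is only the data flow: a shared nonlinear core, a per-neuron linear combination of its features, and the final `ELU() + 1`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes (hypothetical, much smaller than the real model)
batch, height, width = 4, 36, 64
channels, n_neurons = 8, 10

# Core: a small conv stack shared by all neurons (nonlinear feature extractor)
core = nn.Sequential(
    nn.Conv2d(1, channels, kernel_size=9),
    nn.ELU(),
    nn.Conv2d(channels, channels, kernel_size=3, padding=1),
    nn.ELU(),
)

images = torch.randn(batch, 1, height, width)
features = core(images)  # (batch, channels, 28, 56)

# Naive readout: every neuron linearly combines *all* core features
readout = nn.Linear(features[0].numel(), n_neurons)
linear_out = readout(features.flatten(start_dim=1))  # (batch, n_neurons)

# Final nonlinearity ELU() + 1 keeps the predictions positive
firing_rates = F.elu(linear_out) + 1
print(firing_rates.shape)  # torch.Size([4, 10])
```

Since `ELU(x) > -1` for every finite input, adding 1 yields positive values that can be read as firing rates.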

### Learning where neurons "look" 👀

From experimental evidence, we know that a neuron is sensitive to only a limited area of the visual field, referred to as the neuron's receptive field (RF). Building on this, the readout is equipped with a mechanism that lets the model learn where each neuron is "looking" in the visual field. In other words, the model learns the RF position of each neuron, picks the core's output at that spatial position, and then linearly combines the features along the channel dimension. Compared with a readout that combines the features at every spatial position, this significantly reduces the number of readout parameters: each neuron needs only one weight per channel plus its position, rather than one weight per channel *and* per spatial location.
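A minimal sketch of such a position-based readout, assuming the RF position is a point in normalized `[-1, 1]` coordinates and using `torch.nn.functional.grid_sample` to pick the core output there. The actual baseline readout (`gauss_type='full'`) additionally learns a Gaussian around each position and, as we understand it, samples from it during training; this sketch ignores that and uses a fixed point. All tensor sizes are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

batch, channels, height, width = 4, 8, 28, 56
n_neurons = 10
features = torch.randn(batch, channels, height, width)  # stand-in for the core output

# Per-neuron parameters: an RF position (x, y) in normalized [-1, 1] coordinates ...
positions = nn.Parameter(torch.rand(n_neurons, 2) * 2 - 1)
# ... one weight per feature channel, and a bias
weights = nn.Parameter(torch.randn(n_neurons, channels) * 0.1)
bias = nn.Parameter(torch.zeros(n_neurons))

# grid_sample expects a grid of shape (batch, H_out, W_out, 2); here H_out = n_neurons, W_out = 1
grid = positions.view(1, n_neurons, 1, 2).expand(batch, -1, -1, -1)
sampled = F.grid_sample(features, grid, align_corners=True)  # (batch, channels, n_neurons, 1)
sampled = sampled.squeeze(-1)  # (batch, channels, n_neurons)

# Linear combination along the channel dimension, with separate weights per neuron
out = torch.einsum("bcn,nc->bn", sampled, weights) + bias  # (batch, n_neurons)
firing_rates = F.elu(out) + 1

n_params = positions.numel() + weights.numel() + bias.numel()
print(firing_rates.shape, n_params)  # torch.Size([4, 10]) 110
```

For these toy sizes, a readout that weights every spatial location would need 8 × 28 × 56 × 10 + 10 = 125,450 parameters, versus 110 here.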

### RF as a function of the cortex positions

While each neuron's RF position could be defined as a free model parameter and learned during training, we can take further inspiration from experimental evidence: neurons located close to each other on the cortex also have RFs that are close to each other in the visual field (a property known as retinotopy). To exploit this, the readout is equipped with an additional module, the **readout position network**, which learns to map cortical positions to RF positions.
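A minimal sketch of a position network, assuming a small MLP with one hidden layer of 30 units and a final `tanh` (mirroring the `grid_mean_predictor` settings used later in this notebook); the ELU hidden activation and the random stand-in cortical coordinates are assumptions. Because the MLP is a smooth function, nearby cortical positions are mapped to nearby RF positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical position network: cortical (x, y) -> RF position (x, y)
hidden_features = 30
position_net = nn.Sequential(
    nn.Linear(2, hidden_features),
    nn.ELU(),
    nn.Linear(hidden_features, 2),
    nn.Tanh(),  # keeps predicted RF positions inside the normalized [-1, 1] range
)

n_neurons = 7538
cortex_xy = torch.randn(n_neurons, 2)  # stand-in for normalized cortical coordinates
rf_positions = position_net(cortex_xy)  # (n_neurons, 2)

n_params = sum(p.numel() for p in position_net.parameters())
print(rf_positions.shape)  # torch.Size([7538, 2])
print(n_params)            # 152
```

The network has 152 parameters regardless of how many neurons are recorded, whereas free per-neuron positions would need 2 × 7538 = 15,076.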

### Accounting for RF shifts due to eye movement

Finally, we account for shifts in the RF positions caused by eye movements with yet another module, the **shifter**. The shifter takes the pupil position (a 2D vector) as input and outputs a 2D vector that is used to shift all neurons' RF positions globally (i.e. all by the same amount).
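A minimal sketch of the shifting logic, assuming a tiny MLP (hidden size 5, `tanh` activations; both are hypothetical choices) that maps each image's pupil center to one `(dx, dy)` shift, which is then broadcast over all neurons' RF positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical shifter: pupil center (x, y) -> one global (dx, dy) shift per image
shifter = nn.Sequential(
    nn.Linear(2, 5),
    nn.Tanh(),
    nn.Linear(5, 2),
    nn.Tanh(),
)

batch, n_neurons = 32, 7538
pupil_center = torch.randn(batch, 2)              # one eye position per image
rf_positions = torch.rand(n_neurons, 2) * 2 - 1   # per-neuron RF positions in [-1, 1]

shift = shifter(pupil_center)  # (batch, 2)
# Broadcasting: within one image, every neuron receives the same shift
shifted = rf_positions.unsqueeze(0) + shift.unsqueeze(1)  # (batch, n_neurons, 2)
print(shifted.shape)  # torch.Size([32, 7538, 2])
```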

### References
- [**Paper**: Lurz, K. K., Bashiri, M., Willeke, K., Jagadish, A. K., Wang, E., Walker, E. Y., ... & Sinz, F. H. (2021). Generalization in data-driven models of primary visual cortex. BioRxiv, 2020-10.](https://www.biorxiv.org/content/10.1101/2020.10.05.326256v2)
- [**This Video**](https://youtu.be/xwLMO8nVvxs?t=220) (which is a talk explaining the above paper) also explains the readout.

---

## Baseline CNN model exploration

### Imports

In [1]:
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from nnfabrik.builder import get_data

device = "cuda"
random_seed = 42

### Instantiate DataLoader

To initialize the model, we use a *model function*, which takes the dataloaders as an input argument.

In [2]:
# loading the SENSORIUM+ dataset
filenames = ['../data/static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', ]

dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = {'paths': filenames,
                 'normalize': True,
                 'include_behavior': False,
                 'include_eye_position': True,
                 'batch_size': 32,
                 'scale':1,
                 }

dataloaders = get_data(dataset_fn, dataset_config)

---

### Import the model function

In [3]:
from sensorium.models import stacked_core_full_gauss_readout

Let's have a quick look at the input arguments of the model function:

In [4]:
stacked_core_full_gauss_readout?

Signature:
stacked_core_full_gauss_readout(
    dataloaders,
    seed,
    hidden_channels=32,
    input_kern=13,
    hidden_kern=3,
    layers=3,
    gamma_input=15.5,
    skip=0,
    final_nonlinearity=True,
    momentum=0.9,
    pad_input=False,
    batch_norm=True

### Specify (some of) the input arguments to initialize the model

In [5]:
grid_mean_predictor = {
    'type': 'cortex',
    'input_dimensions': 2,
    'hidden_layers': 1,
    'hidden_features': 30,
    'final_tanh': True
}

model_config = {
    # core args
    'input_kern': 9,
    'hidden_kern': 7,
    'hidden_channels': 64,
    'layers': 4,
    'depth_separable': True,
    'stack': -1,
    'gamma_input': 6.3831,
    # readout args
    'gamma_readout': 0.0076,
    'grid_mean_predictor': grid_mean_predictor,
    'gauss_type': 'full',
    'shifter': True,
}

With the above model config, we are defining a model which:
- has four convolutional layers, where
    - the first layer is a standard convolution with a kernel size of 9 (`input_kern=9`)
    - the remaining three layers are depth-separable convolutions (`depth_separable=True`) with a kernel size of 7 (`hidden_kern=7`)
    - each layer outputs an activation tensor with 64 channels (`hidden_channels=64`)
- uses only the output of the last layer (`stack=-1`) as the final output of the core (other options result in stacking the outputs of multiple layers)
- uses the cortical positions (x and y) of the neurons to infer their receptive field positions (specified with the `grid_mean_predictor`)
- uses the pupil center to shift the neurons' receptive field positions globally (all neurons are shifted by the same amount) depending on the pupil position of the subject (`shifter=True`)
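Why use depth-separable convolutions? A quick parameter count helps. The sketch below compares a standard 64 → 64 channel 7×7 convolution with a stand-in built from plain `nn.Conv2d` layers that follows the three-step structure shown in the model printout for layers 1 onward (1×1, per-channel 7×7 with `groups=64`, 1×1). Batch-norm parameters are left out.

```python
import torch
import torch.nn as nn


def n_params(module: nn.Module) -> int:
    """Count all learnable parameters of a module."""
    return sum(p.numel() for p in module.parameters())


channels, k = 64, 7

# Standard 7x7 convolution, 64 -> 64 channels, without bias (as in the printed model)
standard = nn.Conv2d(channels, channels, kernel_size=k, padding=k // 2, bias=False)

# Depth-separable stand-in mirroring the printed layer structure:
# 1x1 channel mixing -> 7x7 per-channel spatial filter (groups=64) -> 1x1 channel mixing
depth_separable = nn.Sequential(
    nn.Conv2d(channels, channels, kernel_size=1, bias=False),
    nn.Conv2d(channels, channels, kernel_size=k, padding=k // 2, groups=channels, bias=False),
    nn.Conv2d(channels, channels, kernel_size=1, bias=False),
)

x = torch.zeros(1, channels, 20, 20)
assert depth_separable(x).shape == standard(x).shape  # both preserve H and W

print(n_params(standard))         # 200704 = 64 * 64 * 7 * 7
print(n_params(depth_separable))  # 11328  = 64 * 64 + 64 * 7 * 7 + 64 * 64
```

The depth-separable layer uses roughly 18 times fewer parameters while producing an output of the same shape.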

### Instantiate the model

In [6]:
model = stacked_core_full_gauss_readout(dataloaders, random_seed, **model_config).to(device)
model.eval();

In [7]:
model

FiringRateEncoder(
  (core): Stacked2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), bias=False)
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): AdaptiveELU()
      )
      (layer1): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (spatial_conv): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64, bias=False)
          (out_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): AdaptiveELU()
      )
      (layer2): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): 

Looking at the model, we can see three top-level modules:
- core: `model.core`
- readout: `model.readout`
- shifter: `model.shifter`

**Note** that when you display the readout or the shifter module, you get a `ModuleDict`, whereas the core is a plain `Module`. The model is designed this way so that the core can be shared between multiple datasets: the features (i.e. the output of the core) can be shared between different neurons and different subjects/sessions, whereas the other components are most likely specific to each dataset. Therefore, we keep the readout and the shifter dataset-specific. Each key of the readout (or shifter) `ModuleDict`, referred to as `data_key`, corresponds to a specific dataset.
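A minimal sketch of this sharing pattern using `nn.ModuleDict`. The class name, the second dataset key, and the pooling-based readouts are hypothetical placeholders (not the Gaussian readout of the baseline); only the "one core, one readout per `data_key`" structure is the point.

```python
import torch
import torch.nn as nn


class SharedCoreModel(nn.Module):
    """Hypothetical sketch: one shared core, one readout per dataset key."""

    def __init__(self, n_neurons_per_key: dict, channels: int = 8):
        super().__init__()
        self.core = nn.Sequential(nn.Conv2d(1, channels, kernel_size=9), nn.ELU())
        # One readout per data_key (placeholder readouts: global average pool + linear)
        self.readout = nn.ModuleDict({
            key: nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(channels, n))
            for key, n in n_neurons_per_key.items()
        })

    def forward(self, images: torch.Tensor, data_key: str) -> torch.Tensor:
        features = self.core(images)             # shared across all datasets
        return self.readout[data_key](features)  # dataset-specific neurons


model = SharedCoreModel({"27204-5-13": 7538, "hypothetical-session": 100})
images = torch.randn(2, 1, 36, 64)
print(model(images, "27204-5-13").shape)            # torch.Size([2, 7538])
print(model(images, "hypothetical-session").shape)  # torch.Size([2, 100])
```

Training on either dataset updates the same `core` parameters, while each readout only sees its own neurons.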

---

Now we are going to go through the modules and explore each of them individually. But before doing that, let's get a batch of data:

### Get a single batch from the dataloader

In [8]:
data_key = '27204-5-13'
batch = next(iter(dataloaders["train"][data_key]))

In [9]:
len(batch)

3

In [10]:
batch._asdict().keys()

dict_keys(['images', 'responses', 'pupil_center'])

In [11]:
batch.images.shape

torch.Size([32, 1, 144, 256])

In [12]:
batch.responses.shape

torch.Size([32, 7538])

In [13]:
batch.pupil_center.shape

torch.Size([32, 2])

### Core

In [14]:
core_output = model.core(batch.images)

In [15]:
core_output.shape

torch.Size([32, 64, 136, 248])

We see that the output of the core has:
- **64** channels, as specified by the `hidden_channels` argument
- a height of **136**, which is `input_h - input_kern + 1`: **144 - 9 + 1 = 136**. This also implies that the remaining convolutional layers preserve the spatial dimensions (i.e. "same" padding, as in the `padding=(3, 3)` of the 7×7 spatial convolutions in the model printout)
- a width of **248**, which is `input_w - input_kern + 1`: **256 - 9 + 1 = 248**.
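The same arithmetic can be checked directly with standalone `nn.Conv2d` layers (stand-ins with the same kernel sizes and padding as the printed model, not the model's own layers). For stride 1, each spatial dimension becomes `in + 2 * padding - kernel + 1`.

```python
import torch
import torch.nn as nn

x = torch.zeros(1, 1, 144, 256)  # one grayscale image with the batch's spatial size

# First layer: 9x9 kernel, no padding, so each spatial dim shrinks by 9 - 1 = 8
first = nn.Conv2d(1, 64, kernel_size=9, bias=False)
# Later layers: 7x7 per-channel kernel with padding 3, so 136 + 6 - 7 + 1 = 136 ("same")
later = nn.Conv2d(64, 64, kernel_size=7, padding=3, groups=64, bias=False)

y = first(x)
print(y.shape)         # torch.Size([1, 64, 136, 248])
print(later(y).shape)  # torch.Size([1, 64, 136, 248])
```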

### Readout

In [16]:
readout_output = model.readout[data_key](core_output)

In [17]:
readout_output.shape

torch.Size([32, 7538])

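And here we have one output value for each of the neurons (n=7538) and each of the images (n=32) in this batch. Strictly speaking, these are the readout's outputs before the final `ELU() + 1` nonlinearity described above, which the full model applies on top to produce the predicted firing rates.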

### Where is the shifter then?


The shifter takes the eye position (available in the batch as `pupil_center`) and outputs a global shift, which is then passed to the readout via its `shift` argument.

In [18]:
shifter_output = model.shifter[data_key](batch.pupil_center)

In [19]:
shifter_output.shape

torch.Size([32, 2])

In [20]:
readout_output_shifted = model.readout[data_key](core_output, shift=shifter_output)

In [21]:
readout_output_shifted.shape

torch.Size([32, 7538])

We can check whether this results in a different output from the readout:

In [22]:
torch.equal(readout_output, readout_output_shifted)

False

### References
- Code for the [model function](https://github.com/sinzlab/sensorium/blob/8660c0c925b3944e723637db4725083f84ee28c3/sensorium/models/models.py#L17)
- Code for the [core Module](https://github.com/sinzlab/neuralpredictors/blob/0d3d793cc0e1f55ec61c5f9f7a98318b5241a2e9/neuralpredictors/layers/cores/conv2d.py#L27)
- Code for the [readout Module](https://github.com/sinzlab/neuralpredictors/blob/0d3d793cc0e1f55ec61c5f9f7a98318b5241a2e9/neuralpredictors/layers/readouts/gaussian.py#L210)
- Code for the [shifter module](https://github.com/sinzlab/neuralpredictors/blob/0d3d793cc0e1f55ec61c5f9f7a98318b5241a2e9/neuralpredictors/layers/shifters/mlp.py#L13)