*Note: This is an adaptation of the notebook found in optical_flow repo, commit `b654828`, which in its original form works with sinabs from commit `3fafee4` and sinabs-slayer, commit `7114a11`*

This adaptation should work with sinabs and exodus in their most recent versions (25.02.2022) and is independent of the optical_flow project. 

# Why does the slayer model not work?

This notebook shows that, for the wheel-motion classification toy task, the same model architecture can be trained with exodus, whereas training with slayer fails. We will explore why this is the case.

The notebook's code is based on the script `binary_task.py` from the optical_flow repo.

SPOILER:
The problem was exploding gradients in the slayer model. Enabling an option to scale down the surrogate gradients resolved it.
The problem can be reproduced by setting `kwargs_model["grad_scale"]` to 1 (in the settings cell), which corresponds to the original value.

In [None]:
### --- Run this cell only to generate the data!
# Executes data_generation.py with the arguments below. The --save_path file name
# is the same one that the data-loading cell reads with np.load.
%run data_generation.py --size=256 --num_segments=4 --num_timesteps=300 --save_path=rotating_wedge_events.npy

In [None]:
### --- Imports

# Set this to inline or notebook if widget is not supported. Animation might not work then.
%matplotlib widget

import torch
from torch import nn
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation

from slayerSNN import loss as SpikeLoss

from sinabs.from_torch import from_model
from sinabs.utils import get_activations
import sinabs.layers as sl

from data import InvertDirDataset
from binary_models import SlayerModel, ExodusModel

In [None]:
### --- Data loading and inspection
# Load the stored events as float32 and move the leading axis to the end.
raster = np.transpose(
    np.load("rotating_wedge_events.npy").astype(np.float32), (1, 2, 3, 0)
)
print(raster.shape)

# Difference between entries 1 and 0 along the first axis, with the last axis
# moved to the front, so that indexing the first axis yields one 2D image.
raster_merged = np.moveaxis(raster[1] - raster[0], -1, 0)

fig, ax = plt.subplots()
# Use the global value range for the color limits so they are the same for every frame.
screen = ax.imshow(
    raster_merged[0],
    vmin=raster_merged.min(),
    vmax=raster_merged.max(),
)

def update_plot(frm):
    """Show ``frm`` in the existing image and return the modified artist.

    Intended as the ``func`` callback of ``FuncAnimation``. The artist is
    returned inside a tuple because the callback has to return an iterable of
    the modified artists when the animation is created with ``blit=True``.

    Args:
        frm: Image data for one frame, passed on to ``screen.set_data``.

    Returns:
        One-element tuple containing the updated image artist ``screen``.
    """
    screen.set_data(frm)
    return (screen,)

# anim = FuncAnimation(fig, update_plot, frames=raster_merged)

In [None]:
### --- Settings and hyperparameters
# - Optimisation
optimizer_class = torch.optim.SGD  # torch.optim.Adam
lr = 1e-3
num_epochs = 40

# - Number of time steps: size of the last raster axis, integer-divided by `downsample`
downsample = 1
num_ts = raster.shape[-1] // downsample

# - Model parameters (the same keyword arguments are passed to both models)
# NOTE(review): the trailing comments look like alternative values that were
# tried for the respective entry - confirm before relying on them.
kwargs_model = dict(
    grad_width=0.5,  # 0.5
    grad_scale=1,  # 0.02
    thr=1,
    num_ts=num_ts,
)

# Use LIF model
# kwargs_model["neuron_type"] = "LIF"
# kwargs_model["tau_leak"] = 20

In [None]:
# - Dataset and loader
# Set sample_size and step_size such that each sample corresponds to all frames of one class
ds = InvertDirDataset(raster, sample_size=num_ts, step_size=num_ts, downsample=downsample)
dl = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=True)

# - Loss function
loss_func = nn.CrossEntropyLoss()

# Seed the torch RNG before the models are constructed
torch.manual_seed(1234)

# - Sinabs model
model_exodus = ExodusModel(**kwargs_model).cuda()
model_exodus.reset()
print("Sinabs model:", model_exodus, sep="\n")

# - Slayer model, built from the same keyword arguments
model_slayer = SlayerModel(**kwargs_model).cuda()
print("Slayer model:", model_slayer, sep="\n")

In [None]:
# - Transfer weights from exodus to slayer model to ensure same initial conditions
# Convolution weights are copied with an extra trailing axis of size 1.
for layer_name in ("conv0", "conv1"):
    exodus_weight = getattr(model_exodus, layer_name).weight.data
    getattr(model_slayer, layer_name).weight.data = exodus_weight.unsqueeze(-1).clone()

# The linear weight is copied and reshaped to (2, 8, 4, 4, 1).
linear_weight = model_exodus.linear.weight.data.clone()
model_slayer.linear.weight.data = linear_weight.reshape(2, 8, 4, 4, 1)

Now the two models should produce the same output for a given input. Let's make sure this is the case.

In [None]:
# - Compare outputs before training
outputs_slayer, outputs_exodus = [], []

# Only forward passes are run here, so gradient tracking is switched off.
with torch.no_grad():
    # Iterate over the dataset itself rather than the loader to avoid shuffling.
    for inp, *__ in ds:
        # Prepend the batch dimension and move the sample to the GPU
        inp = inp[None].cuda()
        # Slayer first, then exodus - same order for every sample
        for net, collected in (
            (model_slayer, outputs_slayer),
            (model_exodus, outputs_exodus),
        ):
            collected.append(net(inp))
        # Reset the exodus model before the next sample
        # (presumably clears its internal state - confirm in ExodusModel)
        model_exodus.reset()

In [None]:
# Display the number of samples in the dataset
len(ds)

In [None]:
# Plot the outputs of both models for the sample at dataset index 1.
# The names deliberately avoid `os`, which would shadow the standard-library module.
sample_idx = 1
out_slayer = outputs_slayer[sample_idx].squeeze(0).cpu().detach()
out_exodus = outputs_exodus[sample_idx].squeeze(0).cpu().detach()

fig, (ax1, ax2) = plt.subplots(2, 1)
# One subplot per index along the first axis that remains after dropping the
# batch dimension (presumably one per output class - confirm against the models).
for row, ax in enumerate((ax1, ax2)):
    ax.plot(out_exodus[row].t(), lw=2, label="exodus")
    ax.plot(out_slayer[row].t(), ls="--", label="slayer")
    ax.set_title(f"Output {row}")
    ax.legend()

Outputs seem very similar.

In [None]:
def correlation(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
    """Return the uncentered correlation (cosine similarity) of two tensors.

    Both tensors are flattened first, so their shapes may differ as long as
    they contain the same number of elements.

    Args:
        a: First tensor.
        b: Second tensor with the same number of elements as ``a``.
        eps: Small constant added to the denominator to avoid division by zero.

    Returns:
        Scalar tensor ``sum(a * b) / (sqrt(sum(a**2) * sum(b**2)) + eps)``.

    Raises:
        ValueError: If ``a`` and ``b`` have different numbers of elements.
    """
    a = a.flatten()
    b = b.flatten()
    # Without this check a one-element tensor would silently broadcast against
    # a longer one and produce a meaningless value.
    if a.numel() != b.numel():
        raise ValueError(
            f"Cannot correlate tensors with {a.numel()} and {b.numel()} elements"
        )
    return torch.sum(a * b) / (torch.sqrt(torch.sum(a**2) * torch.sum(b**2)) + eps)

In [None]:
# - Set up training
loss_func = torch.nn.CrossEntropyLoss()

# - Do a single batch
dl_iter = iter(dl)
inp, tgt, onehot = next(dl_iter)
print("Target:", tgt)
inp = inp.cuda()

# Average over different initialisations
all_max_grads_exo = []
all_max_grads_slr = []
all_std_grads_exo = []
all_std_grads_slr = []
all_corrs = []
for seed in [11, 22, 33, 44, 55, 66, 77, 88, 99]:
    torch.manual_seed(seed)

    model_exodus = ExodusModel(**kwargs_model).cuda()
    model_exodus.reset()

    model_slayer = SlayerModel(**kwargs_model).cuda()

    # - Transfer weights from exodus to slayer model to ensure same initial conditions
    model_slayer.conv0.weight.data = model_exodus.conv0.weight.data.unsqueeze(-1).clone()
    model_slayer.conv1.weight.data = model_exodus.conv1.weight.data.unsqueeze(-1).clone()
    model_slayer.linear.weight.data = model_exodus.linear.weight.data.clone().reshape(2, 8, 4, 4, 1)

    # Both models were just constructed, so they carry no gradients from an
    # earlier backward pass and nothing has to be zeroed here.

    # Sinabs
    out_exo = model_exodus(inp).sum(-1).cpu()
    loss_exo = loss_func(out_exo, tgt)
    loss_exo.backward()
    # Skip parameters without gradient, exactly as for the slayer model, so the
    # two lists are built by the same rule before they are zipped below.
    grads_exo = [p.grad for p in model_exodus.parameters() if p.grad is not None]
    model_exodus.reset()

    # Slayer
    out_slr = model_slayer(inp).sum(-1).cpu()
    loss_slr = loss_func(out_slr, tgt)
    loss_slr.backward()
    # Drop the trailing axis that the slayer weights have in addition
    grads_slr = [p.grad.squeeze(-1) for p in model_slayer.parameters() if p.grad is not None]

    print(f"Losses - Sinabs: {loss_exo.item()}, Slayer: {loss_slr.item()}")

    # Largest gradient magnitude: a strongly negative entry indicates exploding
    # gradients just as much as a positive one, and a signed maximum could be
    # negative, which cannot be shown on a logarithmic axis.
    max_grads_exo = [g.abs().max().item() for g in grads_exo]
    max_grads_slr = [g.abs().max().item() for g in grads_slr]
    std_grads_exo = [torch.std(g).item() for g in grads_exo]
    std_grads_slr = [torch.std(g).item() for g in grads_slr]
    correlations = [correlation(gs, ge).item() for gs, ge in zip(grads_slr, grads_exo)]

    all_max_grads_exo.append(max_grads_exo)
    all_max_grads_slr.append(max_grads_slr)
    all_std_grads_exo.append(std_grads_exo)
    all_std_grads_slr.append(std_grads_slr)
    all_corrs.append(correlations)

# Create the optimizers only now: the loop rebinds `model_exodus` and
# `model_slayer` to new instances, and an optimizer only updates the parameter
# objects it was constructed with. Built before the loop, the optimizers would
# point at discarded models and never change the ones that remain.
optim_exo = optimizer_class(model_exodus.parameters(), lr=lr)
optim_slr = optimizer_class(model_slayer.parameters(), lr=lr)

In [None]:
# Average each per-seed statistic over the seed axis (dim 0), which leaves one
# value per gradient tensor.
(
    avg_max_grads_exo,
    avg_max_grads_slr,
    avg_std_grads_exo,
    avg_std_grads_slr,
    avg_corrs,
) = (
    torch.tensor(per_seed).mean(dim=0)
    for per_seed in (
        all_max_grads_exo,
        all_max_grads_slr,
        all_std_grads_exo,
        all_std_grads_slr,
        all_corrs,
    )
)

In [None]:
# Three panels: gradient maximum and gradient std per layer for both models,
# and the correlation between the gradients of the two models.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 4))

labels = ['layer1', 'layer2', 'layer3']
x = np.arange(len(labels))  # the label locations
width = 0.35  # the width of the bars

# The first two panels share the same layout: grouped bars on a log scale.
for ax, title, values_slr, values_exo in (
    (ax1, 'max gradient', avg_max_grads_slr, avg_max_grads_exo),
    (ax2, 'gradient std', avg_std_grads_slr, avg_std_grads_exo),
):
    ax.bar(x - width/2, values_slr, width, label='SLAYER')
    ax.bar(x + width/2, values_exo, width, label='EXODUS')
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_yscale('log')
    ax.set_title(title)
    ax.legend()

# Place the bars at the label locations instead of deriving the positions from
# a variable that only exists as a leftover of another cell's loop.
ax3.bar(x, avg_corrs, color='C3')
ax3.set_xticks(x)
ax3.set_xticklabels(labels)
ax3.set_title('gradient correlation\nbetween slayer and exodus')
ax3.set_ylim(bottom=0.5)

fig.suptitle('NEURON MODEL IAF')
plt.tight_layout()
plt.savefig('gradient_comparison_optical_flow.png')

It turns out that with the default scaling (`kwargs_model["grad_scale"] = 1`), the gradients in the slayer model explode. After trying a few values, setting the scale to 0.02 seems to give reasonable gradients.
Let's train both models to see whether everything works now.

In [None]:
# - Training loop

# Create the optimizers from the models that are trained below. An optimizer
# only updates the parameter objects it was constructed with, so an optimizer
# built for an earlier model instance would leave these models unchanged.
optim_exo = optimizer_class(model_exodus.parameters(), lr=lr)
optim_slr = optimizer_class(model_slayer.parameters(), lr=lr)

# One entry per processed sample: 1 for a wrong prediction, 0 for a correct one
mistakes_exo = []
mistakes_slr = []

for ep in range(num_epochs):
    print(f"Epoch {ep} ------------------------------------------------------")
    for inp, tgt, __ in dl:
        inp = inp.cuda()

        # Sinabs
        # Sum the output over its last axis and compute the loss on the CPU,
        # where the target tensor is.
        out_exo = model_exodus(inp).sum(-1).cpu()
        __, predict_exo = torch.max(out_exo, 1)
        optim_exo.zero_grad()
        loss_exo = loss_func(out_exo, tgt)
        loss_exo.backward()
        # `.item()` requires a single element, i.e. the batch size of 1 set in the loader
        exo_right = tgt.item() == predict_exo.item()
        mistakes_exo.append(int(not exo_right))

        # Slayer
        out_slr = model_slayer(inp).sum(-1).cpu()
        __, predict_slr = torch.max(out_slr, 1)
        optim_slr.zero_grad()
        loss_slr = loss_func(out_slr, tgt)
        loss_slr.backward()
        slr_right = tgt.item() == predict_slr.item()
        mistakes_slr.append(int(not slr_right))

        # Get correlation
        # Must happen after both backward passes and before the optimizer steps.
        grads_exo = [p.grad.squeeze(-1) for p in model_exodus.parameters() if p.grad is not None]
        grads_slr = [p.grad.squeeze(-1) for p in model_slayer.parameters() if p.grad is not None]
        correls = [correlation(gs, ge).item() for gs, ge in zip(grads_slr, grads_exo)]

        # Optimizer step
        optim_exo.step()
        # Reset the exodus model before the next sample
        # (presumably clears its internal state - confirm in ExodusModel)
        model_exodus.reset()
        optim_slr.step()

        # Print statement
        print(f"Target: {tgt.item()}")
        print(f"Prediction exodus: {predict_exo.item()} ({'correct' if exo_right else 'wrong'})")
        print(f"Prediction slayer: {predict_slr.item()} ({'correct' if slr_right else 'wrong'})")
        print(f"Gradient correlation: {correls}")

    # The mistake lists are never cleared, so these totals are cumulative over all epochs so far.
    print(f"Total mistakes exodus: {sum(mistakes_exo)}, slayer: {sum(mistakes_slr)} -----------------------")

In [None]:
# Plot the per-sample mistake indicators of both models over the whole training.
# Labels and a legend are required to tell the two curves apart.
plt.figure()
plt.plot(mistakes_exo, label="exodus")
plt.plot(mistakes_slr, label="slayer")
plt.xlabel("training sample")
plt.ylabel("mistake (1 = wrong prediction)")
plt.legend()