In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from single_neuron_models import ExodusNeuron, SlayerNeuron

In [None]:
# Hyper-parameters shared by both single-neuron models below.
tau_mem = torch.tensor(30.0)  # passed to both neurons (as a tensor) via `params`
# NOTE(review): tau_syn is not passed to either neuron in this notebook — confirm whether
# the CUBALIF (SLAYER) model should receive it.
tau_syn = 30.0
spike_threshold = 1
n_time_steps = 100  # length of the simulated spike trains
# Presumably the width/scale of the surrogate gradient used by both models — confirm in single_neuron_models.
width_grad, scale_grad = 1, 1

In [None]:
# Keyword arguments shared by both models so that they start from identical settings.
params = dict(
    weight=0.01,  # presumably the initial value of `model.lin.weight` — confirm in single_neuron_models
    tau_mem=tau_mem,
    spike_threshold=spike_threshold,
    width_grad=width_grad,
    scale_grad=scale_grad,
)
exodus_neuron = ExodusNeuron(**params).cuda()
# SLAYER is additionally given the simulation length up front.
slayer_neuron = SlayerNeuron(n_time_steps=n_time_steps, **params).cuda()

# One input spike, and the single output spike the neuron should learn to emit.
input_spike_time = 10
target_spike_time = 30
assert 0 <= input_spike_time < n_time_steps, "input spike falls outside the simulated window"
assert 0 <= target_spike_time < n_time_steps, "target spike falls outside the simulated window"

# Dim 1 is time (see the indexing below); the sequence length must match `n_time_steps`,
# which SlayerNeuron was constructed with. The other singleton dims follow what the
# neuron models expect — NOTE(review): confirm the exact layout in single_neuron_models.
# `input` shadows the builtin but is kept because later cells refer to it by this name.
input = torch.zeros(1, n_time_steps, 1, 1, 1)
target = torch.zeros_like(input)
input[:, input_spike_time] = 1
target[:, target_spike_time] = 1

In [None]:
# EXODUS: output spikes plus the membrane potential recorded by its `lif` sub-module.
out1 = exodus_neuron(input.cuda())
v_mem1 = exodus_neuron.lif.v_mem_recorded

# SLAYER: output spikes plus the pre-/post-synaptic PSP traces it exposes after the forward pass.
out2 = slayer_neuron(input.cuda())
psp_pre, psp_post = slayer_neuron.psp_pre, slayer_neuron.psp_post

In [None]:
# Compare the forward-pass traces of both neurons on a shared time axis.
show_psp_pre = False  # set True to also overlay SLAYER's psp_pre trace


def to_numpy(tensor):
    """Detach a (possibly CUDA) tensor and return it as a squeezed NumPy array for plotting."""
    return tensor.detach().cpu().squeeze().numpy()


fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(8, 6))

ax1.plot(to_numpy(out1), label='output')
ax1.plot(to_numpy(v_mem1), label='LIF v_mem')
ax1.legend()
ax1.set_title('Exodus neuron')

ax2.plot(to_numpy(out2), label='output')
if show_psp_pre:
    ax2.plot(to_numpy(psp_pre), label='psp pre')
ax2.plot(to_numpy(psp_post), label='psp post')
ax2.legend()
ax2.set_title('SLAYER CUBALIF neuron')
ax2.set_xlabel('time steps')

plt.show()

In [None]:
def train(model, input, target, lr=1e-2, n_epochs=2000, n_time_steps=None):
    """Fit a single-neuron model so that its output spike train matches ``target``.

    The full (input, target) pair is presented once per iteration and the model
    is optimised with Adam. Training stops early if the loss reaches exactly 0.

    Args:
        model: neuron model exposing ``spiking_layers``, ``reset_states()`` and a
            ``lin`` layer with a single-element weight (traced via ``.item()``).
        input: input spike tensor with time along dim 1.
        target: desired output tensor, same shape as the model output.
        lr: Adam learning rate.
        n_epochs: maximum number of optimisation steps.
        n_time_steps: factor applied to the mean-squared error (turns the
            per-element mean into a per-time-step sum when all other dims are 1).
            Defaults to ``input.shape[1]``.

    Returns:
        ``(output, param_trace, loss_trace, grad_trace)``: the model output of the
        last iteration (``None`` if no iteration ran) and per-iteration lists of
        the weight, loss and weight gradient. Entry ``i`` of each list refers to
        the same weight value, i.e. the one that produced loss ``i``.
    """
    criterion = nn.MSELoss()
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)
    if n_time_steps is None:
        n_time_steps = input.shape[1]

    # Move the data to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    input, target = input.to(device), target.to(device)
    # The regulariser pulls the total output spike count towards the target's.
    target_spike_count = target.sum()

    param_trace, loss_trace, grad_trace = [], [], []
    output = None

    for _ in tqdm(range(n_epochs)):
        if model.spiking_layers:
            model.reset_states()
        # The optimiser owns all of model.parameters(), so this clears every gradient.
        optimiser.zero_grad()

        output = model(input)

        regulariser = 0.01 * (output.sum() - target_spike_count) ** 2
        loss = criterion(output, target) * n_time_steps + regulariser

        if loss == 0:
            break

        loss.backward()
        # Record the weight *before* the step so weight, loss and grad at index i
        # all describe the same point of the loss landscape.
        loss_trace.append(loss.item())
        param_trace.append(model.lin.weight.item())
        grad_trace.append(model.lin.weight.grad.item())
        optimiser.step()

    return output, param_trace, loss_trace, grad_trace

In [None]:
lr = 1e-3
# Train both neurons on the same input/target pair with the same learning rate.
output, param_trace, loss_trace, grad_trace = train(
    exodus_neuron, input, target, lr=lr
)
output2, param_trace2, loss_trace2, grad_trace2 = train(
    slayer_neuron, input, target, lr=lr
)

In [None]:
# Per-iteration training traces of both models: one panel per recorded quantity.
fig, axes = plt.subplots(3, 1, sharex=True, figsize=(16, 8))

trace_panels = [
    ('loss', loss_trace, loss_trace2),
    ('weight', param_trace, param_trace2),
    ('weight grad', grad_trace, grad_trace2),
]
for ax, (ylabel, exodus_trace, slayer_trace) in zip(axes, trace_panels):
    ax.plot(exodus_trace, label='EXODUS')
    ax.plot(slayer_trace, label='SLAYER')
    ax.set_ylabel(ylabel)
    ax.legend()
axes[-1].set_xlabel('epochs')

plt.show()

In [None]:
# Loss and weight gradient as a function of the weight values visited during training.
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(16, 8))

ax1.plot(param_trace, loss_trace, label='EXODUS')
ax1.plot(param_trace2, loss_trace2, label='SLAYER')
ax1.set_ylabel('loss')
ax1.legend()

ax2.plot(param_trace, grad_trace, label='EXODUS')
ax2.plot(param_trace2, grad_trace2, label='SLAYER')
ax2.set_ylabel('weight grad')
ax2.set_xlabel('weight')
ax2.legend()

plt.show()
