In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from poisson_models import ExodusNet, SlayerNet

In [None]:
# --- Network / simulation hyperparameters (passed to both SlayerNet and ExodusNet) ---
encoding_dim = 250      # size of the input dimension of `input_spikes` (dim 2)
hidden_dim = 25         # hidden layer size
tau_mem = 20.0          # membrane time constant; units as defined in poisson_models (TODO confirm)
spike_threshold = 1     # firing threshold
n_time_steps = 200      # number of simulated time steps per sample

# --- Surrogate-gradient settings ---
# NOTE(review): presumably the width / scale of the surrogate spike derivative;
# see poisson_models for the exact definition.
width_grad = 1
scale_grad = 1

# --- Optimisation (used by `train`) ---
learning_rate = 1e-4    # Adam learning rate
epochs = 18_000         # number of optimisation steps on the single fixed sample

In [None]:
# Build the SLAYER and EXODUS versions of the same network from one seed, give
# them identical initial weights, create a fixed random input/target pair, and
# check that both forward passes agree before training.
torch.manual_seed(8935)

# Both models are built from exactly the same hyperparameters (config cell).
model_kwargs = dict(
    encoding_dim=encoding_dim,
    hidden_dim=hidden_dim,
    tau_mem=tau_mem,
    spike_threshold=spike_threshold,
    n_time_steps=n_time_steps,
    width_grad=width_grad,
    scale_grad=scale_grad,
)
slayernet = SlayerNet(**model_kwargs).cuda()
exodusnet = ExodusNet(**model_kwargs).cuda()

# Copy -- do not alias -- SLAYER's initial weights into EXODUS. Assigning the
# same tensor would make both parameters share storage, so the in-place Adam
# updates made while training one model would silently change the other
# model's starting weights and invalidate the comparison.
exodusnet.lin1.weight.data = slayernet.lin1.weight.data.clone()
exodusnet.lin2.weight.data = slayernet.lin2.weight.data.clone()

criterion = nn.MSELoss()

# Random binary input: each element is 1 with probability 0.05. Drawn from the
# CPU generator seeded above, then moved to the GPU.
input_spikes = (torch.rand(1, n_time_steps, encoding_dim, 1, 1) > 0.95).float().cuda()

# Target: ones at 4 random time steps in [n_time_steps // 5, n_time_steps).
# randint samples with replacement, so duplicates can yield fewer than 4 ones.
target = torch.zeros((1, n_time_steps, 1, 1, 1)).float().cuda()
target[0, torch.randint(n_time_steps//5, n_time_steps, (4, )), 0] = 1

out1 = exodusnet(input_spikes)
out2 = slayernet(input_spikes)

# Sanity check before training: equal output totals (mostly 0 at init) and
# matching first-layer recordings -- presumably membrane-potential traces;
# see poisson_models for what each attribute records.
assert out1.sum() == out2.sum(), "EXODUS and SLAYER output sums differ"
assert torch.allclose(exodusnet.lif1.v_mem_recorded, slayernet.psp_post1, atol=1e-3), \
    "EXODUS and SLAYER first-layer recordings differ"

In [None]:
def train(model, input_spikes, target, n_epochs=None, lr=None):
    """Train ``model`` to reproduce ``target`` from one fixed input.

    Runs full-batch Adam on the single (input_spikes, target) pair,
    minimising the module-level ``criterion`` (an ``nn.MSELoss`` in this
    notebook). A fresh optimiser is created on every call.

    Args:
        model: network trained in place. Must provide ``parameters()`` and a
            ``spiking_layers`` attribute (an iterable of layers, or something
            falsy); ``zero_grad()`` is called on each of those layers every epoch.
        input_spikes: tensor fed to ``model`` every epoch.
        target: tensor the model output is compared against by ``criterion``.
        n_epochs: number of optimisation steps. Defaults to the global
            ``epochs`` from the config cell.
        lr: Adam learning rate. Defaults to the global ``learning_rate``.

    Returns:
        ``(output_spikes, losses)``: ``output_spikes`` is a numpy array of
        shape (number of output elements, n_epochs) holding the flattened model
        output of every epoch (one column per epoch); ``losses`` is the list of
        per-epoch loss values as Python floats.

    Raises:
        ValueError: if ``n_epochs`` is not positive.
    """
    if n_epochs is None:
        n_epochs = epochs
    if lr is None:
        lr = learning_rate
    if n_epochs <= 0:
        # Without at least one epoch there is no output to report or stack.
        raise ValueError(f"n_epochs must be positive, got {n_epochs}")

    optimiser = torch.optim.Adam(model.parameters(), lr=lr)

    out_spikes = []
    losses = []
    for _ in tqdm(range(n_epochs)):
        optimiser.zero_grad()
        # `or ()` covers a None / empty `spiking_layers`.
        # NOTE(review): presumably this also resets per-layer neuron state
        # between epochs -- confirm against poisson_models.
        for layer in model.spiking_layers or ():
            layer.zero_grad()
        out = model(input_spikes)
        loss = criterion(out, target)
        loss.backward()
        optimiser.step()

        # Detach before storing: keeping `out` itself would hold a reference to
        # each epoch's autograd graph (grad_fn) for the entire run.
        out_spikes.append(out.detach().flatten())
        losses.append(loss.item())

    print(out.sum())
    output_spikes = torch.stack(out_spikes).cpu().numpy().T
    return output_spikes, losses


In [None]:
# Train both models on the same fixed input/target pair; `train` builds a new
# Adam optimiser for each call, so the two runs do not share optimiser state.
exodus_output, exodus_losses = train(exodusnet, input_spikes=input_spikes, target=target)
slayer_output, slayer_losses = train(slayernet, input_spikes=input_spikes, target=target)

In [None]:
# Top: per-epoch output raster of both models, with the target time steps drawn
# as red horizontal lines across the trained epoch range.
# Bottom: per-epoch training loss of both models (as returned by `train`).
x_pad = 500  # horizontal padding, in epochs, on both sides of the x-axis

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8))

# np.where returns one index array per dimension of `target`; dim 1 holds the
# time steps of the target spikes.
target_times = np.where(target.cpu().numpy())[1]
ax1.hlines(target_times, xmin=0, xmax=epochs, alpha=0.8, linewidth=2, color='red')

# Output arrays are (output element / time step, epoch): plot epoch on x, step on y.
exodus_t, exodus_epoch = np.where(exodus_output)
slayer_t, slayer_epoch = np.where(slayer_output)
ax1.scatter(exodus_epoch, exodus_t, s=1., label='EXODUS spikes')
ax1.scatter(slayer_epoch, slayer_t, s=1., label='SLAYER spikes')
ax1.set_ylabel("Spike output [t]")
ax1.set_xlim(-x_pad, epochs + x_pad)
ax1.set_title("Output spikes per epoch (red: target time steps)")
ax1.legend()

ax2.plot(exodus_losses, label='EXODUS loss')
ax2.plot(slayer_losses, label='SLAYER loss')
ax2.legend()
ax2.set_xlim(-x_pad, epochs + x_pad)
ax2.set_ylabel("Loss")
ax2.set_xlabel("Epochs")

plt.show()