In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from matplotlib.lines import Line2D
from tqdm.auto import tqdm

from poisson_models import ExodusNet, SlayerNet, smooth

In [None]:
# Hyper-parameters shared by both networks (SLAYER and EXODUS) and by `train`.

# Network size / simulation length
encoding_dim = 250      # input channels: input_spikes is built as (1, n_time_steps, encoding_dim, 1, 1)
hidden_dim = 25
n_time_steps = 200      # number of simulated time steps per forward pass

# Neuron parameters
tau_mem = 20.0          # membrane time constant (units as expected by poisson_models)
spike_threshold = 1

# Gradient parameters (presumably the surrogate-gradient width/scale — confirm in poisson_models)
width_grad = 1
scale_grad = 1

# Optimisation
learning_rate = 1e-3    # Adam learning rate used by `train`
epochs = 3_000          # optimisation steps per model

In [None]:
# Build both networks, give them identical weights, create one random input /
# target pair, and check that the untrained networks agree on a forward pass.
torch.manual_seed(8935)

slayernet = SlayerNet(
    encoding_dim=encoding_dim,
    hidden_dim=hidden_dim,
    tau_mem=tau_mem,
    spike_threshold=spike_threshold,
    n_time_steps=n_time_steps,
    width_grad=width_grad,
    scale_grad=scale_grad,
).cuda()

exodusnet = ExodusNet(
    encoding_dim=encoding_dim,
    hidden_dim=hidden_dim,
    tau_mem=tau_mem,
    spike_threshold=spike_threshold,
    n_time_steps=n_time_steps,
    width_grad=width_grad,
    scale_grad=scale_grad,
).cuda()

# Start both models from the same weights. An in-place copy under no_grad keeps
# the existing Parameter objects (no `.data` reassignment) and raises if the
# two layers' weight shapes ever diverge.
# NOTE(review): only the lin1/lin2 *weights* are synchronised; if these layers
# carry biases they are not copied — confirm bias=False in poisson_models.
with torch.no_grad():
    exodusnet.lin1.weight.copy_(slayernet.lin1.weight)
    exodusnet.lin2.weight.copy_(slayernet.lin2.weight)

criterion = nn.MSELoss()

# Random input: each channel spikes in each time step with probability 5%
# (uniform sample > 0.95).
input_spikes = (torch.rand(1, n_time_steps, encoding_dim, 1, 1) > 0.95).float().cuda()

# Target: spikes at 4 random time steps in [n_time_steps // 5, n_time_steps).
# NOTE(review): randint samples with replacement, so a repeated step yields
# fewer than 4 target spikes.
target = torch.zeros((1, n_time_steps, 1, 1, 1), dtype=torch.float32, device="cuda")
target[0, torch.randint(n_time_steps // 5, n_time_steps, (4, )), 0] = 1

out1 = exodusnet(input_spikes)
out2 = slayernet(input_spikes)

# Output spike counts must match; they are mostly 0 before training, so this
# is a weak check on its own.
assert out1.sum() == out2.sum()
# Stronger check: the recorded first-layer traces (presumably membrane
# potentials in both implementations) agree to within 1e-3.
assert torch.allclose(exodusnet.lif1.v_mem_recorded, slayernet.psp_post1, atol=1e-3)

In [None]:
def train(model, input_spikes, target, n_epochs=None, lr=None, loss_fn=None):
    """Train ``model`` in place to map ``input_spikes`` onto ``target`` with Adam.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise. Must expose a ``spiking_layers`` attribute (an
        iterable of layers, or a falsy value); ``zero_grad()`` is called on each
        of those layers every epoch in addition to the optimiser's.
    input_spikes, target : torch.Tensor
        Model input and desired output, passed unchanged to the model / loss.
    n_epochs : int, optional
        Number of optimisation steps; defaults to the notebook-level ``epochs``
        (read at call time).
    lr : float, optional
        Adam learning rate; defaults to the notebook-level ``learning_rate``.
    loss_fn : callable, optional
        ``loss_fn(out, target)`` returning a scalar tensor; defaults to the
        notebook-level ``criterion``.

    Returns
    -------
    output_spikes : np.ndarray
        One column per epoch, each holding that epoch's flattened model output.
        Shape ``(0, 0)`` when no epochs are run.
    losses : np.ndarray
        Loss value of every epoch.
    """
    # Fall back to the notebook globals at call time so existing calls behave as before.
    if n_epochs is None:
        n_epochs = epochs
    if lr is None:
        lr = learning_rate
    if loss_fn is None:
        loss_fn = criterion

    optimiser = torch.optim.Adam(model.parameters(), lr=lr)

    out_spikes = []
    losses = []
    out = None
    for _ in tqdm(range(n_epochs)):
        optimiser.zero_grad()
        # The spiking layers get their own zero_grad() call as well
        # (presumably they hold state beyond model.parameters() — confirm).
        for layer in model.spiking_layers or ():
            layer.zero_grad()
        out = model(input_spikes)
        loss = loss_fn(out, target)
        loss.backward()
        optimiser.step()

        # Detach so the stored outputs do not keep every epoch's autograd graph alive.
        out_spikes.append(out.detach().flatten())
        losses.append(loss.item())

    if not out_spikes:
        # No epochs were run: there is no output to print or stack.
        return np.empty((0, 0)), np.array(losses)

    print(out.sum())
    output_spikes = torch.stack(out_spikes).cpu().numpy().T
    return output_spikes, np.array(losses)


In [None]:
# Train SLAYER first, then EXODUS. The mean lin1 weight of both networks is
# printed before, between and after the two runs so that any change to either
# model can be tracked.
def _print_lin1_weight_means():
    for net in (exodusnet, slayernet):
        print(net.lin1.weight.mean())


_print_lin1_weight_means()
slayer_output, slayer_losses = train(slayernet, input_spikes, target)
_print_lin1_weight_means()
exodus_output, exodus_losses = train(exodusnet, input_spikes, target)
_print_lin1_weight_means()


In [None]:
# Top panel: output-spike raster over training (x = epoch, y = time step of
# each output spike) for both models, with the target spike times as black
# lines. Bottom panel: smoothed per-epoch loss curves.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8))

# Pad the x-axis by 2% of the epoch count on each side.
xpad = epochs / 50
xmin = -xpad
xmax = epochs + xpad

# Proxy artists for the top legend: the raster markers (s=0.3) are too small to
# be legible as legend entries.
legend1 = [
    Line2D([0], [0], color='C0', lw=4, label='EXODUS spikes'),
    Line2D([0], [0], color='C1', lw=4, label='SLAYER spikes'),
    Line2D([0], [0], color='black', lw=4, label='Target'),
]

# Time-step indices (axis 1 of `target`) at which the target spikes. Lines are
# drawn in data coordinates so they span exactly the trained epochs,
# independent of the axis padding.
target_steps = np.nonzero(target.cpu().numpy())[1]
ax1.hlines(target_steps, 0, epochs - 1, colors='black', linewidth=3)

# `train` returns outputs with rows = flattened output index (time step) and
# columns = epoch, so nonzero() yields (time steps, epochs).
for output, color in ((exodus_output, 'C0'), (slayer_output, 'C1')):
    spike_steps, spike_epochs = np.nonzero(output)
    ax1.scatter(spike_epochs, spike_steps, s=0.3, alpha=0.5, color=color)

ax1.set_ylabel("Spike output [t]")
ax1.set_xlim(xmin, xmax)
ax1.legend(handles=legend1, loc='best')

smooth_window = 30
ax2.plot(smooth(exodus_losses, window_len=smooth_window), color='C0', label='EXODUS smoothed loss')
ax2.plot(smooth(slayer_losses, window_len=smooth_window), color='C1', label='SLAYER smoothed loss')
ax2.legend()
ax2.set_xlim(xmin, xmax)
ax2.set_ylabel("Loss")
ax2.set_xlabel("Epochs")

fig.savefig("poisson_result.png")

In [None]:
# Inspect the legend actually drawn on the top panel. Note that
# `ax1.get_legend_handles_labels()` returns empty lists here: the raster and
# target artists are added without labels, and the legend is built from the
# proxy Line2D handles passed explicitly via `handles=`.
labels = [text.get_text() for text in ax1.get_legend().get_texts()]
assert labels == ['EXODUS spikes', 'SLAYER spikes', 'Target'], labels