In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from poisson_models import ExodusNet, SlayerNet, smooth
from matplotlib.lines import Line2D

In [None]:
# Layer sizes handed to SlayerNet / ExodusNet; `encoding_dim` is also the size of the
# third axis of the random input spike tensor built in the sweep cell.
encoding_dim, hidden_dim = 250, 25

# Forwarded unchanged to both network constructors.
# NOTE(review): exact semantics of `spike_threshold` / `width_grad` live in
# poisson_models — confirm there before tuning.
spike_threshold, width_grad = 1, 1

# Length of the time axis of the input and target tensors.
n_time_steps = 200

# Number of optimisation steps per training run (also the length of every loss curve).
epochs = 3000

In [None]:
def train(model, input_spikes, target, lr, n_epochs=None):
    """Fit ``model`` to ``target`` with Adam on an MSE loss, recording every step.

    The same ``input_spikes`` tensor is presented at every epoch. For each epoch
    the model output (computed *before* that epoch's optimiser step) and the
    scalar loss are recorded.

    Args:
        model: ``nn.Module`` whose output for ``input_spikes`` is compared to
            ``target`` by ``nn.MSELoss``. If it exposes a non-empty
            ``spiking_layers`` iterable, ``zero_grad()`` is called on each of
            those layers before every forward pass.
        input_spikes: tensor passed to ``model`` unchanged at every epoch.
        target: tensor the model output is regressed onto.
        lr: learning rate handed to ``torch.optim.Adam``.
        n_epochs: number of optimisation steps. ``None`` (default) falls back to
            the notebook-level ``epochs`` constant, which is what the original
            four-argument call used.

    Returns:
        tuple ``(output_spikes, losses)``:
            * ``output_spikes`` - numpy array of shape ``(out.numel(), n_epochs)``;
              column ``e`` is the flattened model output at epoch ``e``.
            * ``losses`` - 1-D numpy array with the loss value of every epoch.
    """
    if n_epochs is None:
        n_epochs = epochs  # notebook-level default defined in the config cell

    criterion = nn.MSELoss()
    optimiser = torch.optim.Adam(model.parameters(), lr)
    out_spikes = []
    losses = []
    for _epoch in range(n_epochs):
        optimiser.zero_grad()
        # Models without a `spiking_layers` attribute (or with an empty one) are
        # simply trained without the extra per-layer reset.
        for layer in (getattr(model, "spiking_layers", None) or ()):
            layer.zero_grad()
        out = model(input_spikes)
        loss = criterion(out, target)
        loss.backward()
        optimiser.step()

        # Detach before storing: keeping the attached tensor would hold a
        # reference to this epoch's autograd graph for the rest of the run.
        out_spikes.append(out.detach().flatten())
        losses.append(loss.item())

    # Stack to (n_epochs, n_outputs) and transpose so that epochs run along axis 1.
    output_spikes = torch.stack(out_spikes).cpu().numpy().T
    return output_spikes, np.array(losses)

In [None]:
# Seed torch's global RNG once, before any network or random tensor is created.
torch.manual_seed(8935)

# Number of independent repetitions averaged for every sweep point.
n_experiments = 5

# Hyper-parameter grid: three log-spaced membrane time constants (2**2 .. 2**7),
# a single gradient scale, and three log-spaced learning rates (1e-4 .. 1e-2).
taus = np.logspace(2, 7, num=3, base=2)
scales = [1] # np.logspace(-1, 1, num=3, endpoint=True, base=10)
learning_rates = np.logspace(-4, -2, num=3, base=10)

# Show the grid so the sweep that follows is self-describing.
for grid in (taus, scales, learning_rates):
    print(grid)

In [None]:
# Sweep results: one `[mean_loss_curve, params]` entry per (tau_mem, scale_grad, lr)
# combination, appended in loop order (tau_mem outermost, lr innermost).
exodus_param_losses, slayer_param_losses = [], []

for tau_mem in tqdm(taus):
    for scale_grad in scales:
        # Both networks are always built from the very same constructor arguments.
        net_kwargs = dict(
            encoding_dim=encoding_dim,
            hidden_dim=hidden_dim,
            tau_mem=tau_mem,
            spike_threshold=spike_threshold,
            n_time_steps=n_time_steps,
            width_grad=width_grad,
            scale_grad=scale_grad,
        )
        for lr in learning_rates:
            sweep_point = {'tau_mem': tau_mem, 'scale_grad': scale_grad, 'lr': lr}

            # Running sums of the per-epoch losses over the repetitions.
            exodus_all_losses = np.zeros((epochs))
            slayer_all_losses = np.zeros((epochs))

            for repeat in range(n_experiments):
                # Construction order matters for the RNG stream: SLAYER first, then EXODUS.
                slayernet = SlayerNet(**net_kwargs).cuda()
                exodusnet = ExodusNet(**net_kwargs).cuda()

                # Start both networks from identical weights (copied from the SLAYER net).
                for layer_name in ("lin1", "lin2"):
                    source_weight = getattr(slayernet, layer_name).weight.data
                    getattr(exodusnet, layer_name).weight.data = source_weight.clone()

                # Random binary input: each entry is 1 where a uniform sample exceeds 0.9.
                uniform_noise = torch.rand(1, n_time_steps, encoding_dim, 1, 1)
                input_spikes = (uniform_noise > 0.9).float().cuda()

                # Target: four randomly drawn time steps (possibly repeated) in the
                # last 80 % of the time axis are set to 1, everything else stays 0.
                target = torch.zeros((1, n_time_steps, 1, 1, 1)).float().cuda()
                target_steps = torch.randint(n_time_steps//5, n_time_steps, (4, ))
                target[0, target_steps, 0] = 1

                # One forward pass per network before training; the results are not
                # used further in this cell (the consistency check below is disabled).
                out1 = exodusnet(input_spikes)
                out2 = slayernet(input_spikes)

                # assert torch.allclose(exodusnet.lif1.v_mem_recorded, slayernet.psp_post1, atol=1e-3)

                # SLAYER is trained before EXODUS on the same input/target pair.
                slayer_output, slayer_losses = train(slayernet, input_spikes, target, lr)
                exodus_output, exodus_losses = train(exodusnet, input_spikes, target, lr)

                exodus_all_losses += exodus_losses
                slayer_all_losses += slayer_losses

            # Turn the sums into means over the repetitions.
            exodus_all_losses /= n_experiments
            slayer_all_losses /= n_experiments

            exodus_param_losses.append([exodus_all_losses, dict(sweep_point)])
            slayer_param_losses.append([slayer_all_losses, dict(sweep_point)])

In [None]:
# Figure for a single training run. It uses `target`, `exodus_output`, `slayer_output`,
# `exodus_losses` and `slayer_losses` as left behind by the sweep cell, i.e. the values
# from the last repetition of the last sweep point.
fig = plt.figure(figsize=(8, 8))
ax1 = fig.add_subplot(211)

# Pad the x-axis by 2 % of the epoch range on either side.
xpad = epochs/50
xmin = -xpad
xmax = epochs+xpad
# axhspan expects its horizontal extent in axes-fraction units (0 = left edge,
# 1 = right edge). The axis covers [xmin, xmax], so epoch 0 and epoch `epochs`
# sit at these fractions of the total width (epochs + 2 * xpad).
xmin_fraction = (0 - xmin) / (xmax - xmin)
xmax_fraction = (epochs - xmin) / (xmax - xmin)

legend1 = [
    Line2D([0], [0], color='C0', lw=4, label='EXODUS spikes'),
    Line2D([0], [0], color='C1', lw=4, label='SLAYER spikes'),
    Line2D([0], [0], color='black', lw=4, label='Target'),
]

# Axis 1 of the target tensor is time; draw one horizontal bar per target spike time.
target_times = np.where(target.cpu().numpy())[1]
for spike in target_times:
    # Zero-height span: only its edge (linewidth=3) is visible, giving a horizontal line.
    ax1.axhspan(spike, spike, xmin=xmin_fraction, xmax=xmax_fraction, alpha=1, linewidth=3, color='black')

# Rows of the recorded outputs index the flattened model output, columns index the epoch.
exodus_rows, exodus_epochs = np.where(exodus_output)
slayer_rows, slayer_epochs = np.where(slayer_output)
ax1.scatter(exodus_epochs, exodus_rows, s=0.3, alpha=0.5)
ax1.scatter(slayer_epochs, slayer_rows, s=0.3, alpha=0.5)
ax1.set_ylabel("Spike output [t]")
ax1.set_xlim(xmin, xmax)
ax1.legend(handles=legend1, loc='best')

legend2 = [
    Line2D([0], [0], color='C0', lw=4, label='EXODUS smoothed loss'),
    Line2D([0], [0], color='C1', lw=4, label='SLAYER smoothed loss'),
]
ax2 = fig.add_subplot(212)
ax2.plot(smooth(exodus_losses, window_len=30))
ax2.plot(smooth(slayer_losses, window_len=30))
# Alternative: plot the averaged curves of the last sweep point instead.
# ax2.plot(smooth(exodus_param_losses[-1][0], window_len=30))
# ax2.plot(smooth(slayer_param_losses[-1][0], window_len=30))
ax2.legend(handles=legend2)
ax2.set_xlim(xmin, xmax)
ax2.set_ylabel("Loss")
ax2.set_xlabel("Epochs")

plt.savefig("poisson_result.png")

In [None]:
# One panel per method, each showing the smoothed mean loss curve of every sweep point.
fig = plt.figure(figsize=(8, 8))

panels = (
    (211, "EXODUS", exodus_param_losses),
    (212, "SLAYER", slayer_param_losses),
)
axes = []
for position, title, sweep_results in panels:
    axis = fig.add_subplot(position)
    for loss, params in sweep_results:
        curve_label = f"tau_mem: {round(params['tau_mem'], 1)}, scale_grad: {params['scale_grad']}, lr: {params['lr']}"
        axis.plot(smooth(loss, window_len=30), label=curve_label)

    axis.legend()
    # axis.set_ylim(top=0.5)
    axis.set_ylabel('loss')
    axis.set_title(title)
    axes.append(axis)

ax1, ax2 = axes
# Only the bottom panel carries the x-axis label.
ax2.set_xlabel('epochs')

plt.savefig("poisson_param_sweep.png")

In [None]:
# Scatter map of the summed loss per sweep point: x = tau_mem, y = learning rate,
# colour = sum of the averaged loss curve over all epochs. Both panels share one
# colour scale so they can be compared directly.
fig = plt.figure(figsize=(6, 6))

exodus_loss_integrals = np.array([loss.sum() for (loss, param) in exodus_param_losses])
slayer_loss_integrals = np.array([loss.sum() for (loss, param) in slayer_param_losses])

# Common upper bound of the colour scale.
vmax = np.maximum(exodus_loss_integrals.max(), slayer_loss_integrals.max())

# Each panel reads the coordinates from its own result list (previously both panels
# used whichever `params` was bound last in the paired loop).
exodus_taus = [param['tau_mem'] for (loss, param) in exodus_param_losses]
exodus_lrs = [param['lr'] for (loss, param) in exodus_param_losses]
slayer_taus = [param['tau_mem'] for (loss, param) in slayer_param_losses]
slayer_lrs = [param['lr'] for (loss, param) in slayer_param_losses]

ax1 = fig.add_subplot(211)
ax1.set_xscale('log')
ax1.set_yscale('log')
# To sweep the gradient scale instead, plot param['scale_grad'] on the y-axis.
ax1.set_ylabel('learning rate')
ax1.set_title("EXODUS")
ax1.scatter(exodus_taus, exodus_lrs, c=exodus_loss_integrals, s=150., vmin=0, vmax=vmax)

ax2 = fig.add_subplot(212)
ax2.set_xscale('log')
ax2.set_yscale('log')
ax2.set_xlabel('tau membrane')
ax2.set_ylabel('learning rate')
ax2.set_title("SLAYER")
# Keep the returned collection: it is the mappable the colour bar is built from.
loss_mappable = ax2.scatter(slayer_taus, slayer_lrs, c=slayer_loss_integrals, s=150., vmin=0, vmax=vmax)

# Reserve the right-hand 20 % of the figure for a colour bar shared by both panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(loss_mappable, cax=cbar_ax)

plt.savefig("poisson_param_sweep_loss_integral.png")

In [None]:
# Grouped bar chart comparing the summed loss of EXODUS and SLAYER per tau value.
labels = ['tau=4', 'tau=22', 'tau=128']

x = np.arange(len(labels))  # the label locations
width = 0.35  # the width of the bars

fig, (ax1, ax2) = plt.subplots(2, 1)

# The sweep appends its results with the learning rate varying fastest. With one
# scale_grad value and three learning rates, every third entry starting at index 2
# is the largest learning rate (1e-2) for each successive tau.
exodus_selected = exodus_loss_integrals[2::3]
slayer_selected = slayer_loss_integrals[2::3]

# NOTE(review): both panels show identical data, exactly as in the original cell;
# presumably one of them was meant to differ (other learning rate / y-scale) - confirm.
for axis in (ax1, ax2):
    rects1 = axis.bar(x - width/2, exodus_selected, width, label='EXODUS')
    rects2 = axis.bar(x + width/2, slayer_selected, width, label='SLAYER')

    # Labels, ticks and legend go on the real Axes objects. The name `ax` used here
    # before referred to the scatter collection of the previous cell, not an Axes.
    axis.set_ylabel('Sum of loss')
    axis.set_xticks(x)
    axis.set_xticklabels(labels)
    axis.legend()

ax1.set_title('Poisson task with learning rate = 1e-2')

fig.tight_layout()

plt.show()

In [None]:
# Display the summed EXODUS loss of every sweep point (one entry per
# hyper-parameter combination, in sweep order) as the cell output.
exodus_loss_integrals