In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import sinabs.activation as sa
from sinabs.slayer.layers import LIF
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from slayer_layer import SlayerLayer

In [None]:
# Configuration shared by the SLAYER and EXODUS experiments below.

# --- network dimensions ---
encoding_dim = 250        # input channels (last dim of `input_spikes`)
hidden_dim = 25           # width of the hidden spiking layer

# --- simulation length ---
n_time_steps = 200        # samples per spike train

# --- neuron parameters ---
tau_mem = 20.0            # membrane time constant passed to LIF / SLAYER (in time steps — TODO confirm)
spike_threshold = 0.01    # firing threshold

# --- optimisation ---
learning_rate = 1e-5      # Adam learning rate
epochs = 12000            # number of training iterations over the single sample

In [None]:
class SlayerNet(torch.nn.Module):
    """SLAYER-based reference network built from a ``SlayerLayer`` helper.

    Parameters
    ----------
    n_time_steps : int
        Forwarded as ``tSample`` in the simulation parameters (``Ts = 1.0``).
    encoding_dim : int
        Input size of the first dense layer ``lin1``.
    hidden_dim : int
        Output size of ``lin1`` / input size of ``lin2``.
    tau_mem : float
        Forwarded as both ``tauSr`` and ``tauRef``.
    spike_threshold : float
        Forwarded as the neuron threshold ``theta``.
    width_grad, scale_grad : float
        Forwarded as ``tauRho`` and ``scaleRho`` (presumably the surrogate
        gradient width / scale — confirm against SlayerLayer).

    NOTE(review): ``lin2`` is constructed but ``forward`` never applies it, so
    the returned spikes come from the hidden layer only. Confirm whether this
    is intentional (e.g. debugging the first layer) before comparing against a
    two-layer model.
    """

    def __init__(self, n_time_steps, encoding_dim, hidden_dim, tau_mem, spike_threshold, width_grad, scale_grad):
        super().__init__()

        sim_params = dict(Ts=1.0, tSample=n_time_steps)
        neuron_params = dict(
            type="CUBALIF",
            theta=spike_threshold,
            tauSr=tau_mem,
            tauRef=tau_mem,
            scaleRef=1,
            tauRho=width_grad,
            scaleRho=scale_grad,
        )
        self.slayer = SlayerLayer(neuron_params, sim_params)

        # Dense layers are created through the SLAYER helper; registration
        # order (lin1 before lin2) determines parameter ordering.
        self.lin1 = self.slayer.dense(encoding_dim, hidden_dim)
        self.lin2 = self.slayer.dense(hidden_dim, 1)

    def forward(self, x):
        """Run ``lin1`` -> psp -> spike and return the spikes.

        Side effects: stores ``self.weighted`` (output of ``lin1``) and two
        snapshots of the psp tensor, ``self.psp_pre`` (before ``spike``) and
        ``self.psp_post`` (after ``spike``), for later inspection.
        """
        weighted = self.lin1(x)
        self.weighted = weighted

        membrane = self.slayer.psp(weighted)
        self.psp_pre = membrane.clone()

        # NOTE(review): taking a second snapshot only makes sense if spike()
        # modifies `membrane` in place (e.g. applying refractoriness) — verify.
        spikes = self.slayer.spike(membrane)
        self.psp_post = membrane.clone()

        return spikes

In [None]:


# Build the EXODUS (sinabs.slayer) network and train it to reproduce a few
# target output spikes from one fixed random input spike train.

# Seed *before* building the model so that the nn.Linear weight
# initialisation is reproducible, not just the input/target sampling.
torch.manual_seed(12345)

act_fn = sa.ActivationFunction(
    spike_threshold=spike_threshold,
    spike_fn=sa.SingleSpike,
    reset_fn=sa.MembraneSubtract(),
    surrogate_grad_fn=sa.SingleExponential(),
)

exodus_model = nn.Sequential(
    nn.Linear(encoding_dim, hidden_dim, bias=False),
    LIF(tau_mem=tau_mem, activation_fn=act_fn),
    nn.Linear(hidden_dim, 1, bias=False),
    LIF(tau_mem=tau_mem, activation_fn=act_fn),
).cuda()
# The rest of this cell refers to `model`, which was previously never defined
# here (NameError on a fresh kernel). Alias it to the model built above.
model = exodus_model

criterion = nn.MSELoss()
optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Random input: each (time step, channel) entry is 1 with probability ~5%.
input_spikes = (torch.rand(1, n_time_steps, encoding_dim) > 0.95).float().cuda()
# Target: spikes at 4 random times in [n_time_steps // 5, n_time_steps).
# torch.randint can draw duplicates, so there may be fewer than 4 spikes.
target = torch.zeros((1, n_time_steps, 1)).float().cuda()
target[0, torch.randint(n_time_steps // 5, n_time_steps, (4,)), 0] = 1

out_spikes = []
losses = []
for epoch in tqdm(range(epochs)):
    optimiser.zero_grad()
    # NOTE(review): optimiser.zero_grad() only touches parameter grads; these
    # calls presumably let the sinabs LIF layers detach their internal state
    # from the previous epoch's graph — verify against the sinabs docs.
    model[1].zero_grad()
    model[3].zero_grad()
    out = model(input_spikes)
    loss = criterion(out, target)
    loss.backward()
    optimiser.step()

    # Detach before storing. Otherwise every stored tensor keeps its epoch's
    # autograd graph (and GPU buffers) alive for the whole run.
    out_spikes.append(out.detach().flatten())
    losses.append(loss.item())

print(out.sum())
# Convert everything to numpy for plotting. Assuming `out` is
# (1, n_time_steps, 1), `output_spikes` ends up as (n_time_steps, epochs).
input_spikes = input_spikes.detach().cpu().int().squeeze(0).numpy()
output_spikes = torch.stack(out_spikes).detach().cpu().numpy().T
target_spikes = target.detach().cpu().int().squeeze(0).numpy()


In [None]:
# Left: output-spike raster across training (x = epoch, y = time step), with
# the target spike times overlaid as red horizontal lines.
# Right: training loss per epoch.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

for spike_time in np.where(target_spikes)[0]:
    # axhline replaces the previous zero-height axhspan(spike, spike) used to
    # draw a horizontal line.
    ax1.axhline(spike_time, alpha=0.8, linewidth=2, color='red')

# Compute the nonzero indices once. Rows of output_spikes are time steps and
# columns are epochs.
out_times, out_epochs = np.where(output_spikes)
ax1.scatter(out_epochs, out_times, s=1.)
ax1.set(title="Output spikes during training", xlabel="epoch", ylabel="time step")

ax2.plot(losses)
ax2.set(title="Training loss", xlabel="epoch", ylabel="MSE loss")

plt.show()