In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import sinabs.activation as sina
import sinabs.slayer.layers as ssl

In [None]:
# --- Neuron dynamics -------------------------------------------------------
# Membrane time constant, kept as a tensor so torch.exp can be applied to it.
tau_mem = torch.tensor(30.)
# exp(-1 / tau_mem); used further down to scale the input spike amplitude.
alpha_mem = torch.exp(-1 / tau_mem)
# Time constant passed to the ExpLeak layer of the model.
tau_syn = 30.
spike_threshold = 1

# --- Surrogate gradient ----------------------------------------------------
width_grad = 1
scale_grad = 1

# --- Simulation length and number of optimisation steps --------------------
n_time_steps = 100
epochs = 3000

In [None]:
# Seed the RNG *before* building the network: nn.Linear draws its initial
# weight at construction time, so seeding only afterwards leaves the starting
# weight (and therefore the whole training run) non-reproducible.
# The value matches the seed used later in the notebook.
torch.manual_seed(123)

# Spiking non-linearity handed to the LIF layer: spike function, reset
# function and the surrogate gradient used in the backward pass.
act_fn = sina.ActivationFunction(
    spike_threshold=spike_threshold,
    spike_fn=sina.SingleSpike,
    reset_fn=sina.MembraneSubtract(),
    surrogate_grad_fn=sina.SingleExponential(
        grad_width=width_grad,
        grad_scale=scale_grad,
    ),
)

# Single trainable weight -> exponential leak (tau_syn) -> LIF neuron (tau_mem).
# The Linear layer is created without a bias, so the weight is its only parameter.
model = nn.Sequential(
    nn.Linear(1, 1, bias=False),
    ssl.ExpLeak(tau_leak=tau_syn),
    ssl.LIF(tau_mem=tau_mem, activation_fn=act_fn, norm_input=False),
).cuda()

In [None]:
torch.manual_seed(123)

input_spike_time = 10
target_spike_time = 30
# The second axis is indexed by time step below, so size it with n_time_steps
# rather than a literal to keep the tensors and the loop in sync.
input = torch.zeros(1, n_time_steps, 1)
target = torch.zeros_like(input)
# One input event at input_spike_time, scaled by 1 / alpha_mem; one target
# spike at target_spike_time.
input[:, input_spike_time] = 1 / alpha_mem
target[:, target_spike_time] = 1

# Feed the input one time step at a time and record the state of the LIF
# layer (model[2]) and of the ExpLeak layer (model[1]) after every step.
# NOTE: the layer states are not reset in this cell, so re-running it
# continues from whatever state the layers currently hold.
v_mems = []
i_syns = []
input = input.cuda()
# This pass is for inspection only; no_grad avoids building an autograd graph
# across all the single-step forward calls.
with torch.no_grad():
    for step in range(n_time_steps):
        output = model(input[:, step:step+1])
        # clone() so each recorded sample is an independent snapshot even if
        # a layer were to update its state tensor in place.
        v_mems.append(model[2].v_mem.clone())
        i_syns.append(model[1].v_mem.clone())
v_mems = torch.stack(v_mems).squeeze().cpu().detach().numpy()
i_syns = torch.stack(i_syns).squeeze().cpu().detach().numpy()

In [None]:
fig, ax = plt.subplots(figsize=(16, 5))

# Convert to NumPy only for plotting instead of rebinding the global `input`:
# the cell then works regardless of which device `input` / `target` currently
# live on, so it can be re-run after the training cell has moved them to the GPU.
ax.plot(input[0].detach().cpu().numpy(), label='input')
ax.plot(target[0].detach().cpu().numpy(), label='target')
ax.plot(v_mems, label='membrane potential')
ax.plot(i_syns, label='synaptic current')
ax.legend()
ax.set_xlabel('Time steps')
ax.set_ylabel('Activation')
ax.set_title('Response before training');

In [None]:
criterion = nn.MSELoss()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-2)

param_trace = []
loss_trace = []
grad_trace = []

input, target = input.cuda(), target.cuda()

for i in tqdm(range(epochs)):
    # Reset the ExpLeak and LIF layers before every pass over the sequence.
    model[1].reset_states()
    model[2].reset_states()
    # Kept in addition to optimiser.zero_grad(): presumably the stateful
    # layers also clear/detach their state buffers here -- TODO confirm
    # against the installed sinabs version before removing.
    model[1].zero_grad()
    model[2].zero_grad()
    optimiser.zero_grad()

    # Whole sequence in a single forward call.
    output = model(input)

    # Penalise a total output sum different from 1.
    regulariser = 0.01 * (output.sum() - 1) ** 2
    # MSELoss averages over all elements by default; rescale by n_time_steps.
    loss = criterion(output, target) * n_time_steps + regulariser
    # Record a detached tensor: appending `loss` itself would keep a reference
    # to every epoch's autograd graph for as long as the list lives.
    loss_trace.append(loss.detach())

    # Output already matches the target exactly; nothing left to optimise.
    if loss == 0:
        break

    loss.backward()
    optimiser.step()

In [None]:
# Compare the network output from the last training pass with the input and
# the target spike train.
fig, ax = plt.subplots(figsize=(16, 5))

ax.plot(input[0].cpu(), label='input')
ax.plot(target[0].cpu(), label='target')
ax.plot(output.detach().cpu().squeeze().numpy(), label='output')
ax.legend()
ax.set_xlabel('Time steps')
ax.set_ylabel('Activation')