In [None]:
import torch
import torch.nn as nn
import tonic
from tonic import datasets, transforms
import os
import numpy as np
from tqdm.auto import tqdm
import sinabs.exodus.layers as sel
import sinabs
import sinabs.layers as sl
import sinabs.activation as sa

In [None]:
# `dt`: width of one raster time step, in the units of the raw event timestamps
# (Downsample below scales "t" by 1/dt); presumably microseconds — TODO confirm.
# `encoding_dim`: number of input channels kept after spatial downsampling.
dt, encoding_dim = 4000, 100

class ToRaster():
    """Bin a structured event array into a dense (time_step, channel) count raster.

    Meant to run after ``transforms.Downsample``, whose rescaled ``"t"`` and ``"x"``
    fields are used directly as row and column indices.

    Args:
        encoding_dim: number of columns (channels); every ``events["x"]`` value must
            lie in ``[0, encoding_dim)``.
        max_time_steps: keep at most this many leading time steps (rows). Defaults to
            250, the cut-off this notebook has always used; ``None`` keeps everything.
    """

    def __init__(self, encoding_dim, max_time_steps=250):
        self.encoding_dim = encoding_dim
        self.max_time_steps = max_time_steps

    def __call__(self, events):
        """Return an int array of shape (min(t_max + 1, max_time_steps), encoding_dim).

        Entry ``[t, x]`` counts the events at time step ``t`` on channel ``x``.
        An empty event array yields a ``(0, encoding_dim)`` raster instead of raising.
        """
        times = events["t"]
        channels = events["x"]
        # The row count comes from the latest event, then is capped. `.max()` raises
        # on an empty array, so that case is handled explicitly.
        n_steps = int(times.max()) + 1 if len(times) else 0
        if self.max_time_steps is not None:
            n_steps = min(n_steps, self.max_time_steps)
        # Drop events past the cut-off before allocating, rather than building the
        # full-length raster and slicing it afterwards.
        keep = times < n_steps
        raster = np.zeros((n_steps, self.encoding_dim), dtype=int)
        # np.add.at accumulates repeated (t, x) pairs; fancy-index `+=` would not.
        np.add.at(raster, (times[keep], channels[keep]), 1)
        return raster

transform = transforms.Compose(
    [
        # Scale timestamps by 1/dt and map the 700 input channels onto encoding_dim.
        transforms.Downsample(
            time_factor=1 / dt,
            spatial_factor=encoding_dim / 700,
        ),
        # Accumulate the rescaled events into a dense (time_step, channel) raster.
        ToRaster(encoding_dim),
    ]
)

In [None]:
# SSC samples, each converted to a raster by `transform`. The split is passed
# explicitly so it matches the "train" directory used for the disk cache.
dataset = tonic.datasets.SSC('./data', split='train', transform=transform)
print(f"This dataset has {len(dataset)} samples.")

In [None]:
batch_size = 128

# Cache the transformed rasters on disk. The directory is keyed by the encoding
# parameters, so changing `encoding_dim` or `dt` never reuses stale rasters.
cached_dataset = tonic.DiskCachedDataset(
    dataset=dataset,
    cache_path=os.path.join("cache", "ssc", "train", str(encoding_dim), str(dt)),
)
# PadTensors presumably pads each sample along time so a batch can be stacked
# (batch_first=True puts the batch on dim 0) — confirm against tonic docs.
# drop_last=True discards the final partial batch, so every batch has
# `batch_size` samples. If the loader is iterated only once (as in the pre-loading
# step), the shuffled batch composition is fixed from then on.
cached_dataloader = torch.utils.data.DataLoader(
    cached_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
    drop_last=True,
)

In [None]:
# Materialize the whole loader once. Each padded batch is stored on the GPU as a
# sparse tensor (rasters are presumably mostly zeros), and targets as uint8.
data_list, target_list = [], []
for data, targets in tqdm(cached_dataloader):
    sparse_batch_gpu = data.to_sparse().coalesce().cuda()
    data_list.append(sparse_batch_gpu)
    target_list.append(targets.byte().cuda())

# (sparse data batch, target batch) pairs consumed by the timing loops.
sparse_tensor_dataset = [*zip(data_list, target_list)]

In [None]:
# Timing check: densify every pre-loaded sparse batch. The dense result is thrown
# away; only tqdm's throughput matters. CUDA kernels launch asynchronously, so
# synchronize each iteration to make tqdm measure the actual GPU work.
for data, target in tqdm(sparse_tensor_dataset):
    data.to_dense()
    torch.cuda.synchronize()


In [None]:
# Keyword arguments shared by every EXODUS IAF layer: single-spike activation,
# subtractive membrane reset and a single-exponential surrogate gradient.
# Note: the same reset/surrogate instances are shared by all three layers.
kw_args = dict(
    spike_fn=sa.SingleSpike,
    reset_fn=sa.MembraneSubtract(),
    surrogate_grad_fn=sa.SingleExponential(),
)

hidden_dim = 128

# Three Linear+IAF hidden stages followed by a linear readout with 35 outputs
# (presumably one per SSC class). The input width follows `encoding_dim`, so it
# stays in sync with the raster transform.
model = nn.Sequential(
    nn.Linear(encoding_dim, hidden_dim),
    sel.IAF(**kw_args),
    nn.Linear(hidden_dim, hidden_dim),
    sel.IAF(**kw_args),
    nn.Linear(hidden_dim, hidden_dim),
    sel.IAF(**kw_args),
    nn.Linear(hidden_dim, 35),
).cuda()

In [None]:
# Non-spiking baseline of comparable depth, used only for timing. Its output width
# equals its input width (`encoding_dim`) because each output is added to the next
# input time step in the recurrent timing loop.
dummy_model = nn.Sequential(
    nn.Linear(encoding_dim, hidden_dim),
    nn.Linear(hidden_dim, hidden_dim),
    nn.Linear(hidden_dim, hidden_dim),
    nn.Linear(hidden_dim, encoding_dim),
).cuda()

In [None]:
def reset_states(model):
    """Reset the state of every sinabs stateful layer inside ``model``.

    Walks ``model.modules()`` (the module itself plus all descendants), so stateful
    layers nested inside sub-containers are reset too, not only direct children.

    Args:
        model: any ``nn.Module``; modules that are not ``sl.StatefulLayer`` are skipped.
    """
    for layer in model.modules():
        if isinstance(layer, sl.StatefulLayer):
            layer.reset_states()

In [None]:
# Shape of the last pre-loaded sparse batch, read explicitly instead of relying on a
# loop variable leaked from an earlier cell. Expected (batch, time_steps, channels)
# given batch_first padding — TODO confirm.
sparse_tensor_dataset[-1][0].shape

In [None]:
# Timing baseline: run the non-spiking dummy model as a simple recurrence over time,
# out_t = dummy_model(x_t + out_{t-1}), then backpropagate through every step.
for data, target in tqdm(sparse_tensor_dataset):
    # ToRaster emits integer counts and nn.Linear needs floating-point input
    # (.float() is a no-op if collation already produced float32).
    dense = data.to_dense().float()
    dummy_model.zero_grad()
    # Start the recurrence from zeros; seeding it with dense[:, 0] fed 2 * x_0 into
    # the first step.
    output = torch.zeros_like(dense[:, 0])
    for step in range(dense.shape[1]):
        output = dummy_model(dense[:, step] + output)
    output.sum().backward()
    # CUDA launches are asynchronous; wait so tqdm's timing includes the GPU work.
    torch.cuda.synchronize()

In [None]:
# Timing run for the EXODUS spiking model: the whole (batch, time, channel) tensor
# goes through the network in a single call, followed by a backward pass.
for data, target in tqdm(sparse_tensor_dataset):
    # Clear neuron state left over from the previous batch.
    reset_states(model)
    # ToRaster emits integer counts and nn.Linear needs floating-point input
    # (.float() is a no-op if collation already produced float32).
    dense = data.to_dense().float()
    model.zero_grad()
    output = model(dense)
    output.sum().backward()
    # CUDA launches are asynchronous; wait so tqdm's timing includes the GPU work.
    torch.cuda.synchronize()