In [None]:
import torch
import torch.nn as nn
import tonic
from tonic import datasets, transforms
import os
import numpy as np
from tqdm.auto import tqdm
import sinabs.exodus.layers as sel
import sinabs
import sinabs.layers as sl
import sinabs.activation as sa

In [None]:
# Frame length handed to tonic as `time_window` (presumably the same unit as
# the event timestamps, i.e. microseconds -- TODO confirm) and the number of
# input channels after spatial downsampling.
dt = 4000
encoding_dim = 100


def _preprocessing(final_step):
    """Build the SSC preprocessing pipeline shared by both dataset views.

    Every pipeline rescales the channel axis by ``encoding_dim / 700``
    (presumably from the SSC's 700 channels), crops the recording with
    ``CropTime(max=1e6)`` and then applies ``final_step``, which decides the
    output container (dense frames vs. sparse tensor).
    """
    return transforms.Compose(
        [
            transforms.Downsample(spatial_factor=encoding_dim / 700),
            transforms.CropTime(max=1e6),
            final_step,
        ]
    )


# Binning parameters shared by the dense and sparse frame conversions.
_frame_kwargs = dict(
    sensor_size=(encoding_dim, 1, 1),
    time_window=dt,
    include_incomplete=True,
)

dense_transform = _preprocessing(transforms.ToFrame(**_frame_kwargs))
sparse_transform = _preprocessing(transforms.ToSparseTensor(**_frame_kwargs))

In [None]:
# Two views of the same SSC recordings: `dense_dataset` yields binned frames,
# `sparse_dataset` yields sparse tensors. The sparse view converts labels with
# `torch.tensor` (presumably so the GPU memory cache can move them to a
# device -- TODO confirm).
data_root = './data'
dense_dataset = tonic.datasets.SSC(data_root, transform=dense_transform)
sparse_dataset = tonic.datasets.SSC(
    data_root,
    transform=sparse_transform,
    target_transform=torch.tensor,
)
print(f"This dataset has {len(sparse_dataset)} samples.")

In [None]:
# Baseline loader with no caching: every access goes through `dense_dataset`
# and its `dense_transform`. Uses 8 worker processes and pinned host memory.
# PadTensors pads each batch along the first (time) dimension to the longest
# sample in the batch.
dataloader = torch.utils.data.DataLoader(
    dataset=dense_dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
)

In [None]:
batch_size = 64

# Disk cache around the dense (frame) dataset. Samples are presumably written
# under `cache_path` on first access and read back afterwards.
# NOTE(review): the cache path is keyed only on dataset class, "train",
# encoding_dim and dt. Changing other transform settings (CropTime,
# Downsample, include_incomplete) would silently reuse a stale cache, so
# clear the directory in that case.
# NOTE(review): "train" assumes SSC's default split -- confirm.
disk_cached_dataset = tonic.DiskCachedDataset(
    dataset=dense_dataset,
    cache_path=os.path.join(
        "cache",
        type(dense_dataset).__name__,
        "train",
        str(encoding_dim),
        str(dt),
    ),
)

# NOTE(review): unlike `dataloader`, this uses 4 workers and no pin_memory.
# Keep that in mind when comparing the benchmark numbers.
disk_cached_dataloader = torch.utils.data.DataLoader(
    disk_cached_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
    drop_last=True,
)

# CPU memory cache around the sparse dataset. Samples presumably stay in this
# process's memory after their first access. They are still sparse when
# batched and are densified in the benchmark loop.
# NOTE(review): num_workers is left at its default (0). With worker processes,
# each worker would presumably build its own copy of the cache -- confirm
# before enabling workers.
cpu_cached_dataset = tonic.MemoryCachedDataset(dataset=sparse_dataset)
cpu_cached_dataloader = torch.utils.data.DataLoader(
    dataset=cpu_cached_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
)

def _sparse_to_dense(sample):
    """Densify a cached sparse sample before it is collated."""
    return sample.to_dense()


# GPU memory cache around the sparse dataset. Samples are placed on the CUDA
# device and densified on access, so batches reach the model without a
# host->device copy. num_workers stays at its default (0): CUDA tensors
# generally cannot be produced inside forked DataLoader workers.
gpu_cached_dataset = tonic.MemoryCachedDataset(
    dataset=sparse_dataset,
    device="cuda",
    transform=_sparse_to_dense,
)
gpu_cached_dataloader = torch.utils.data.DataLoader(
    dataset=gpu_cached_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
)

In [None]:
class SNN(nn.Sequential):
    """Feed-forward spiking network: three Linear + IAF stages and a Linear readout.

    Args:
        backend: ``'exodus'`` builds ``sinabs.exodus.layers.IAF`` neurons,
            ``'sinabs'`` builds ``sinabs.layers.IAF`` neurons.
        input_dim: size of the last input dimension (default 100, the same
            value as ``encoding_dim`` in the configuration cell).
        hidden_dim: width of each hidden layer (default 128).
        output_dim: number of readout units (default 35; presumably the
            number of SSC classes).

    Raises:
        ValueError: if ``backend`` is neither ``'exodus'`` nor ``'sinabs'``.
    """

    def __init__(self, backend='exodus', input_dim=100, hidden_dim=128, output_dim=35):
        # Resolve the neuron class once; an unknown name used to fall back to
        # sinabs silently, which would mislabel a benchmark run.
        if backend == 'exodus':
            iaf = sel.IAF
        elif backend == 'sinabs':
            iaf = sl.IAF
        else:
            raise ValueError(
                f"Unknown backend {backend!r}; expected 'exodus' or 'sinabs'."
            )

        # Each spiking layer needs its own instance, since the layers hold state.
        super().__init__(
            nn.Linear(input_dim, hidden_dim),
            iaf(),
            nn.Linear(hidden_dim, hidden_dim),
            iaf(),
            nn.Linear(hidden_dim, hidden_dim),
            iaf(),
            nn.Linear(hidden_dim, output_dim),
        )

    def reset_states(self):
        """Reset every direct child that is a sinabs ``StatefulLayer``."""
        for layer in self.children():
            if isinstance(layer, sl.StatefulLayer):
                layer.reset_states()

# Model under benchmark: exodus IAF backend, moved to the GPU.
# To compare against the pure-sinabs backend, build
# `SNN(backend='sinabs').cuda()` and assign it to `model` instead.
exodus_model = SNN(backend='exodus').cuda()
model = exodus_model


In [None]:
# Benchmark: disk-cached dense frames. One pass of forward + backward; tqdm
# reports the throughput. If the disk cache is still empty, this first pass
# presumably also pays for populating it -- re-run the cell to time warm reads.
for data, target in tqdm(disk_cached_dataloader):
    # Host -> device copies (the target is moved too, although unused below).
    data, target = data.cuda(), target.cuda()
    # Fresh spiking state and cleared gradients for every batch.
    model.reset_states()
    model.zero_grad()
    output = model(data)
    loss = output.sum()
    loss.backward()

In [None]:
# Benchmark: uncached loader (every sample transformed on access).
for data, target in tqdm(dataloader):
    # Move the pinned host batch and its labels to the GPU.
    data, target = data.cuda(), target.cuda()
    # Clear neuron state and accumulated gradients before each batch.
    model.reset_states()
    model.zero_grad()
    output = model(data)
    loss = output.sum()
    loss.backward()

In [None]:
# Benchmark: CPU memory cache. Batches arrive as sparse CPU tensors.
for data, target in tqdm(cpu_cached_dataloader):
    # Copy the sparse batch to the GPU, then densify it and cast it to
    # float32 so it can be fed to nn.Linear.
    data, target = data.cuda(), target.cuda()
    data = data.to_dense().float()
    model.reset_states()
    model.zero_grad()
    output = model(data)
    loss = output.sum()
    loss.backward()

In [None]:
# Benchmark: GPU memory cache. Batches are already dense and on the CUDA
# device, so there is no host->device copy in this loop.
for data, target in tqdm(gpu_cached_dataloader):
    model.reset_states()
    model.zero_grad()
    output = model(data)
    output.sum().backward()
    # The other benchmark loops call `.cuda()` (without non_blocking=True) on
    # every batch. That copy blocks the host until previously queued GPU work
    # on the stream has finished. This loop has no such copy, so tqdm could
    # end up timing kernel launches rather than completed work. Synchronize
    # explicitly to keep the reported rates comparable.
    torch.cuda.synchronize()