In [1]:
import torch
import torch.nn as nn
import tonic
from tonic import datasets, transforms
import os
import numpy as np
from tqdm.auto import tqdm
import sinabs.exodus.layers as sel
import sinabs
import sinabs.layers as sl
import sinabs.activation as sa

In [2]:
# Width of each time bin handed to ToSparseTensor.
# NOTE(review): assumes tonic's SHD event timestamps are in microseconds (i.e. 4 ms bins) — confirm.
dt = 4000
# Number of input channels after spatial downsampling; this is also the model's input width.
encoding_dim = 100
# Channel count of the raw SHD sensor that `encoding_dim` is downsampled from
# (previously a bare 700 inside the Downsample call).
SHD_N_CHANNELS = 700
# Events later than this timestamp are cropped away (presumably 1 s if timestamps are in µs).
MAX_EVENT_TIME = 1e6

transform = transforms.Compose([
    # time_factor=1 leaves timing untouched; the spatial factor maps SHD_N_CHANNELS -> encoding_dim.
    transforms.Downsample(time_factor=1, spatial_factor=encoding_dim / SHD_N_CHANNELS),
    transforms.CropTime(max=MAX_EVENT_TIME),
    # Bin events into `dt`-wide frames stored as a sparse tensor;
    # include_incomplete presumably keeps the trailing partial bin — confirm in tonic docs.
    transforms.ToSparseTensor(
        sensor_size=(encoding_dim, 1, 1),
        time_window=dt,
        include_incomplete=True,
    ),
])

In [3]:
# Load SHD from ./data, applying the event transform above to each sample and
# wrapping each label in a tensor.
dataset = tonic.datasets.SHD('./data', transform=transform, target_transform=torch.tensor)

n_samples = len(dataset)
print(f"This dataset has {n_samples} samples.")

This dataset has 8156 samples.


In [4]:
batch_size = 64

# Shared collate for both loaders: presumably pads each batch's variable-length
# samples along their first (time) dimension so they can be stacked, with batch first.
pad_collate = tonic.collation.PadTensors(batch_first=True)

# On-disk cache of transformed samples. The path is keyed by encoding_dim and dt so
# caches built with different preprocessing settings never collide. Build it from
# real path components (the old single-argument os.path.join was a no-op).
disk_cached_dataset = tonic.DiskCachedDataset(
    dataset=dataset,
    cache_path=os.path.join("cache", "hsd", "train", str(encoding_dim), str(dt)),
)
disk_cached_dataloader = torch.utils.data.DataLoader(
    disk_cached_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=pad_collate,
    drop_last=True,
)

# In-memory cache placed on the GPU. The transform turns each cached sparse sample
# into a dense tensor (NOTE(review): presumably applied on every access — confirm).
gpu_cached_dataset = tonic.MemoryCachedDataset(
    dataset=dataset,
    device="cuda",
    transform=lambda x: x.to_dense(),
)
gpu_cached_dataloader = torch.utils.data.DataLoader(
    gpu_cached_dataset,
    batch_size=batch_size,
    collate_fn=pad_collate,
    drop_last=True,
    shuffle=True,
)

In [5]:
# Batches per epoch; with drop_last=True any final partial batch is discarded.
num_batches = len(gpu_cached_dataloader)
num_batches

127

In [6]:
# gpu_cached_dataset[0]

In [7]:
# next(iter(gpu_cached_dataloader))

In [8]:
class ExodusModel(nn.Sequential):
    """Feed-forward spiking MLP built from EXODUS integrate-and-fire layers.

    Layout: Linear -> IAF -> Linear -> IAF -> Linear -> IAF -> Linear (readout).
    Each nn.Linear acts on the last input dimension, so that dimension must equal
    ``in_features``. NOTE(review): the sel.IAF layers presumably expect
    (batch, time, ...) inputs — confirm against the sinabs.exodus docs.

    Args:
        in_features: Size of the last input dimension (should match ``encoding_dim``).
        hidden_dim: Width of each of the three hidden layers.
        n_classes: Number of readout units. Defaults to 20, the number of classes
            in SHD. Use 35 for SSC.
    """

    def __init__(self, in_features: int = 100, hidden_dim: int = 128, n_classes: int = 20):
        super().__init__(
            nn.Linear(in_features, hidden_dim),
            sel.IAF(),
            nn.Linear(hidden_dim, hidden_dim),
            sel.IAF(),
            nn.Linear(hidden_dim, hidden_dim),
            sel.IAF(),
            nn.Linear(hidden_dim, n_classes),
        )

    def reset_states(self):
        """Call ``reset_states()`` on every sinabs StatefulLayer in the model.

        Walks ``modules()`` rather than only direct children, so stateful layers
        nested inside sub-modules are reset too.
        """
        for layer in self.modules():
            if isinstance(layer, sl.StatefulLayer):
                layer.reset_states()


# Tie the input width to the preprocessing config instead of repeating 100.
exodus_model = ExodusModel(in_features=encoding_dim).cuda()

In [10]:
# One pass of forward + backward over every GPU-cached batch. There is no optimizer
# step, so the weights never change; labels are not used.
for batch, _labels in tqdm(gpu_cached_dataloader):
    # Start each batch from fresh neuron state and cleared parameter gradients.
    exodus_model.reset_states()
    exodus_model.zero_grad()

    # A plain sum is enough to drive a full backward pass through the network.
    scalar_out = exodus_model(batch).sum()
    scalar_out.backward()

  0%|          | 0/127 [00:00<?, ?it/s]