In [1]:
import tonic

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import tonic.transforms as transforms

In [3]:
# Sensor geometry of N-MNIST; passed to ToFrame in the next cell to size each event frame.
# NOTE(review): presumably (34, 34, 2) — consistent with the 2x34x34 frames printed later; confirm.
sensor_size = tonic.datasets.NMNIST.sensor_size

In [18]:
# Per-sample preprocessing, applied in order:
#   1. Denoise drops isolated events (per tonic docs: no neighbouring event within `filter_time`).
#   2. ToFrame bins the remaining events into dense frames, one per `time_window`.
# NOTE(review): assuming microsecond timestamps, these are 10 ms and 100 ms respectively — confirm.
frame_transform = transforms.Compose(
    [
        transforms.Denoise(filter_time=10_000),
        transforms.ToFrame(sensor_size=sensor_size, time_window=100_000),
    ]
)

In [19]:
import os

# Root directory for the raw N-MNIST download. Defaults to the original author's location;
# set the NMNIST_DATA_DIR environment variable to point elsewhere instead of editing the cell.
DATA_DIR = os.environ.get("NMNIST_DATA_DIR", "/DATA/hwkang/")

# Both splits share the same frame preprocessing; only the `train` flag differs.
train_dataset = tonic.datasets.NMNIST(save_to=DATA_DIR, transform=frame_transform, train=True)
test_dataset = tonic.datasets.NMNIST(save_to=DATA_DIR, transform=frame_transform, train=False)

In [20]:
from torch.utils.data import DataLoader
from tonic import DiskCachedDataset

In [21]:
# Cache the preprocessed training frames on disk so later epochs skip Denoise/ToFrame.
cached_train_dataset = DiskCachedDataset(
    train_dataset,
    cache_path="./cache/nmnist/train",
)

# Small loader for a quick sanity check. Recordings are padded to a common length,
# time-first (batch_first=False), so dim 0 of each batch is the time axis.
cached_train_loader = DataLoader(
    cached_train_dataset,
    batch_size=8,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)

In [22]:
# Pull a single padded batch to confirm the time-first layout (time, batch, ...).
x, y = next(iter(cached_train_loader))
print(x.size())

torch.Size([303, 8, 2, 34, 34])


In [None]:
# Provisional loader over the un-augmented cache (redefined further down once
# augmentation is added).
batch_size = 128

# batch_first=False keeps time on dim 0, which forward_pass iterates over. The
# default PadTensors() uses batch_first=True, which would put the batch on dim 0
# and make the network step over samples instead of time. Shuffle to match the
# final training loader.
train_loader = DataLoader(
    cached_train_dataset,
    batch_size=batch_size,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
    shuffle=True,
)

In [None]:
import torch
import torchvision

In [None]:
# Training-only augmentation: wrap the cached numpy frames as tensors, then rotate
# each sample by a random angle drawn from [-10, 10] degrees.
transform = tonic.transforms.Compose(
    [
        torch.from_numpy,
        torchvision.transforms.RandomRotation([-10, 10]),
    ]
)

# The augmentation is attached to the DiskCachedDataset rather than to train_dataset,
# so it is not part of the preprocessing that train_dataset itself performs
# (per the tonic docs it runs on each sample after it is read from the cache).
cached_train_dataset = DiskCachedDataset(
    train_dataset,
    cache_path="./cache/nmnist/train",
    transform=transform,
)

# The test set is cached without any augmentation.
cached_test_dataset = DiskCachedDataset(
    test_dataset,
    cache_path="./cache/nmnist/test",
)

In [None]:
batch_size = 128

# Both loaders pad recordings to a common length time-first (batch_first=False),
# so dim 0 is the time axis; only the training loader shuffles.
train_loader = DataLoader(
    cached_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)
test_loader = DataLoader(
    cached_test_dataset,
    batch_size=batch_size,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
)

In [None]:
# Draw one shuffled, augmented training batch to inspect its shape.
event_tensor, labels = next(iter(train_loader))

In [None]:
# Time-first layout (time steps, batch, 2, 34, 34), judging by the earlier printed shape.
print(event_tensor.size())

In [None]:
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import spikeplot as splt
from snntorch import utils
import torch.nn as nn

In [None]:
# Pick the best available device instead of assuming CUDA exists:
# CUDA first, then Apple-silicon MPS (guarded for torch builds without it), otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# neuron and simulation parameters
spike_grad = surrogate.atan()  # arctan surrogate gradient used in place of the spike's derivative
beta = 0.5  # membrane decay factor shared by every Leaky layer

# Initialize network. Spatial sizes for 34x34 inputs:
#   34 -conv5-> 30 -pool2-> 15 -conv5-> 11 -pool2-> 5, hence 32*5*5 features into Linear.
# init_hidden=True lets each Leaky layer keep its own state inside nn.Sequential;
# output=True on the last layer makes the net return (spikes, membrane).
net = nn.Sequential(nn.Conv2d(2, 12, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(12, 32, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(32*5*5, 10),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)

In [None]:
def forward_pass(net, data):
    """Step ``net`` through every time slice of ``data`` and stack the output spikes.

    Hidden states of all LIF neurons in ``net`` are reset first, so each call
    starts from a clean state.

    Args:
        net: spiking network whose call returns a 2-tuple; the first element is
            the output spike tensor (the second is ignored).
        data: tensor whose dim 0 is the time axis; ``data[t]`` is fed at step ``t``.

    Returns:
        Output spikes for every step, stacked along a new leading time dimension.
    """
    utils.reset(net)  # drop hidden state left over from the previous batch

    step_outputs = (net(data[t]) for t in range(data.size(0)))
    return torch.stack([spikes for spikes, _ in step_outputs])

In [None]:
# Adam over every network parameter.
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.02, betas=(0.9, 0.999))

# Spike-count MSE: targets a firing ratio of 0.8 for the correct class and 0.2 for the others.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

In [None]:
# Training budget: passes over the data, and the iteration limit checked by the training loop.
num_epochs = 1
num_iters = 50

# Per-iteration loss and accuracy, appended by the training loop for later plotting.
loss_hist, acc_hist = [], []

In [None]:
from itertools import islice

# Training loop.
# islice caps each epoch at exactly `num_iters` mini-batches (i = 0 .. num_iters-1) and
# stops without fetching an extra batch; set num_iters = None to run the full loader.
# (An `if i == num_iters: break` check would run num_iters + 1 batches.)
for epoch in range(num_epochs):
    net.train()  # nothing inside the loop changes the mode, so once per epoch suffices
    for i, (data, targets) in enumerate(islice(train_loader, num_iters)):
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass over all time steps, then the spike-count loss.
        spk_rec = forward_pass(net, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        train_loss = loss_val.item()
        loss_hist.append(train_loss)

        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {train_loss:.2f}")

        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")