In [1]:
!pip install nirtorch norse snntorch tonic --quiet

In [2]:
import norse
import snntorch as snn
import tonic
import tonic.transforms as transforms
import torch
import torch.nn as nn
from tonic import DiskCachedDataset
from tonic.collation import PadTensors
from torch.utils.data import DataLoader
from tqdm import tqdm

save_path = "./data/nmnist"

# Hyperparameters
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 64

# Seed torch's RNG so weight init, the random per-neuron beta2 drawn in Net,
# and the training shuffle order are reproducible on Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Data Loading
# ToFrame bins each event stream into frames spanning `time_window` timestamp
# units (presumably microseconds for N-MNIST in tonic -- confirm).
frame_transform = transforms.Compose([
    transforms.ToFrame(
        sensor_size=tonic.datasets.NMNIST.sensor_size,
        time_window=60000
    )
])

# The frame transform is attached to the underlying dataset (not to
# DiskCachedDataset) so the *framed* samples are what gets written to the
# cache. DiskCachedDataset's own `transform` runs after every cache read, which
# would redo the event-to-frame conversion on each access.
# The cache directories are named after the cached representation so that a
# previously written raw-event cache is never read back as frames.
trainset = DiskCachedDataset(
    tonic.datasets.NMNIST(train=True, save_to=save_path, transform=frame_transform),
    cache_path="./cache/nmnist_frames/train"
)

testset = DiskCachedDataset(
    tonic.datasets.NMNIST(train=False, save_to=save_path, transform=frame_transform),
    cache_path="./cache/nmnist_frames/test"
)


# batch_first=False pads variable-length frame sequences into time-major
# batches, which Net.forward iterates over along dim 0.
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                          collate_fn=PadTensors(batch_first=False))

test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                         collate_fn=PadTensors(batch_first=False))

Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/1afc103f-8799-464a-a214-81bb9b1f9337 to ./data/nmnist/NMNIST/train.zip


  0%|          | 0/1011893601 [00:00<?, ?it/s]

Extracting ./data/nmnist/NMNIST/train.zip to ./data/nmnist/NMNIST
Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/a99d0fee-a95b-4231-ad22-988fdb0a2411 to ./data/nmnist/NMNIST/test.zip


  0%|          | 0/169674850 [00:00<?, ?it/s]

Extracting ./data/nmnist/NMNIST/test.zip to ./data/nmnist/NMNIST


In [3]:
# Layer sizes. Each input frame is flattened to 2 x 34 x 34 values
# (presumably 2 event polarities on the 34 x 34 N-MNIST sensor).
num_inputs = 34 * 34 * 2  # 2312 input features
num_hidden = 128          # width of the hidden spiking layer
num_outputs = 10          # one output neuron per class

In [None]:
class Net(nn.Module):
    """Two-layer fully connected spiking network built from snnTorch Synaptic neurons.

    At every time step the input frame is flattened and passed through
    fc1 -> lif1 -> fc2 -> lif2. The hidden layer uses scalar, fixed decay
    rates; the output layer uses per-neuron tensors and learns its beta.
    """

    def __init__(self):
        super().__init__()

        # Hidden-layer decay rates: one scalar shared by all neurons, not learned.
        hidden_alpha, hidden_beta = 0.5, 0.9

        # Output-layer parameters, one value per output neuron.
        # beta is drawn uniformly from [0, 1). This must happen before the
        # nn.Linear layers are built, so the RNG consumption order is unchanged.
        out_beta = torch.rand(num_outputs, dtype=torch.float)
        # NIR export requires threshold to have the same shape as beta.
        out_threshold = torch.ones_like(out_beta)
        out_alpha = 0.9 * torch.ones_like(out_beta)

        # Attribute names and registration order are kept stable: they name the
        # nodes in the exported NIR graph.
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Synaptic(alpha=hidden_alpha, beta=hidden_beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Synaptic(
            alpha=out_alpha,
            beta=out_beta,
            threshold=out_threshold,
            learn_beta=True,  # output-layer beta is a trainable parameter
        )

    def forward(self, x):
        """Simulate the network over every time step of ``x``.

        Args:
            x: input indexed by time along dim 0; each ``x[t]`` is flattened
                from dim 1 onward before fc1. Assumed (time, batch, ...), which
                matches the PadTensors(batch_first=False) loaders.

        Returns:
            ``(spikes, membranes)``: output-layer spikes and membrane potentials,
            each stacked along a new leading time dimension.
        """
        # Re-initialise both layers' synaptic/membrane state at t = 0 on every call.
        syn_h, mem_h = self.lif1.init_synaptic()
        syn_o, mem_o = self.lif2.init_synaptic()

        spikes_out = []
        membranes_out = []
        num_steps = x.size(0)
        for t in range(num_steps):
            hidden_current = self.fc1(x[t].flatten(1))
            spk_h, syn_h, mem_h = self.lif1(hidden_current, syn_h, mem_h)
            spk_o, syn_o, mem_o = self.lif2(self.fc2(spk_h), syn_o, mem_o)
            spikes_out.append(spk_o)
            membranes_out.append(mem_o)

        return torch.stack(spikes_out), torch.stack(membranes_out)

# Build the network on the selected device. The training loop applies
# cross-entropy to the output membrane potentials and updates with Adam.
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [None]:
# Sanity check: run one training batch through the untrained network.
# No gradients are needed here, so skip building the autograd graph.
# `data` is kept as a global: the NIR export cell reuses it as the sample input.
data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)

with torch.no_grad():
    spk_rec, mem_rec = model(data)

# Records are time-major (time steps, batch, num_outputs), because the loader
# pads with batch_first=False and Net.forward stacks along a new dim 0.
assert mem_rec.shape[1:] == (data.size(1), num_outputs)
mem_rec.size()

In [None]:
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        model.eval()
        for data, targets in tqdm(test_loader):
            data, targets = data.to(device), targets.to(device)
            spk_rec, _ = model(data)
            spike_count = spk_rec.sum(0)
            _, max_spike = spike_count.max(1)

            # correct classes for one batch
            total += targets.size(0)
            correct += (max_spike == targets).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

In [None]:
from tqdm import tqdm

In [None]:
num_epochs = 1


def train(epoch):
    """Train the global ``model`` for one pass over ``train_loader``.

    The loss for a batch is the cross-entropy between the output membrane
    potential and the targets, computed at every time step and summed over time.

    Args:
        epoch: epoch index, used only for progress labels and log messages.
    """
    model.train()
    # Wrap the loader itself (not enumerate(...)) so tqdm can read its len()
    # and show a total and ETA.
    progress = tqdm(train_loader, desc=f"Epoch {epoch}")
    for batch_idx, (data, targets) in enumerate(progress):
        data, targets = data.to(device), targets.to(device)

        _, mem_rec = model(data)
        # Sum the per-time-step cross-entropy of the output membrane potential
        # (mem_rec is stacked along dim 0 by Net.forward).
        loss = torch.zeros((1), dtype=dtype, device=device)
        for step in range(mem_rec.size(0)):
            loss += criterion(mem_rec[step], targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            # tqdm.write keeps the message from breaking the progress bar.
            tqdm.write(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")


for epoch in range(num_epochs):
    train(epoch)
    test()

In [None]:
from snntorch.export_nir import export_to_nir
import nir

In [None]:
# Export the trained snnTorch network to a NIR graph and save it to disk.
# nn.Module.cpu() moves the model in place, so later cells see a CPU model.
# The sample batch `data` (from the sanity-check cell) is copied to CPU to match.
sample_batch = data.cpu()
nir_graph = export_to_nir(model.cpu(), sample_batch)
nir.write("nir_model.nir", nir_graph)

In [None]:
# Rebuild the exported NIR graph as an executable Norse module.
# NOTE(review): dt should match the time step snnTorch assumed when it
# converted the decay rates to time constants during export -- confirm.
norse_model = norse.torch.from_nir(nir_graph, dt=1e-4)

In [None]:
def apply(data):
    """Run the Norse model rebuilt from NIR over one input batch, step by step.

    Each slice along dim 0 is flattened from dim 1 onward and passed to
    ``norse_model`` together with the state returned by the previous step
    (``None`` at the first step).

    Args:
        data: tensor iterated along dim 0; assumed (time, batch, ...) as
            produced by the PadTensors(batch_first=False) loaders.

    Returns:
        ``(spk_out, hid_rec)``: the per-step outputs stacked along a new
        leading dimension, and the list of per-step states.
    """
    state = None
    outputs, states = [], []
    for frame in data:
        z, state = norse_model(frame.flatten(1), state)
        outputs.append(z)
        states.append(state)
    return torch.stack(outputs), states

In [None]:
def measure_accuracy2(model, dataloader):
    """Return the spike-count classification accuracy of ``model`` on ``dataloader``.

    For each batch the predicted class is the output index with the largest
    spike count after summing the spike record over dim 0 (time).

    Args:
        model: callable mapping a data batch to a tuple whose first element is
            the output spike record, assumed (time, batch, classes), e.g. ``apply``.
        dataloader: iterable yielding ``(data, targets)`` batches.

    Returns:
        Accuracy as a Python float in [0, 1] (a fraction, not a percentage).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for data, targets in dataloader:
            spk_rec, _ = model(data)
            _, predicted = spk_rec.sum(0).max(1)
            # Accumulate plain Python ints instead of device tensors.
            num_correct += (predicted == targets).sum().item()
            num_samples += len(targets)

    if num_samples == 0:
        raise ValueError("dataloader yielded no samples; accuracy is undefined")

    return num_correct / num_samples

In [None]:
# Test-set accuracy (fraction in [0, 1]) of the Norse model rebuilt from NIR.
measure_accuracy2(model=apply, dataloader=test_loader)