# SNN Base Models — Tonic + snnTorch
This notebook demonstrates test-set evaluation of:
- **NMNIST** with an MLP-style SNN
- **DVSGesture** with a 3-layer **Convolutional SNN**

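> Make sure you have `torch`, `snntorch`, `tonic`, and `numpy` installed, and that the helper modules `nmnist` and `dvs_conv3` are importable (the `sys.path` setup cell below adds their directory).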


In [4]:
!pip install snntorch tonic --q



In [5]:

import sys, os, torch, numpy as np

# Seed the RNGs up front so that any randomness drawn later (e.g. weight
# initialisation when the models below are constructed, assuming they use
# torch's default generator) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)
np.random.seed(SEED)

print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
# Prefer the GPU when one is visible; every later cell moves data/models here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


Torch: 2.7.1+cu128
CUDA available: True


device(type='cuda')

In [6]:

# Add the directory holding the CLI helper modules (`nmnist`, `dvs_conv3`) to
# the import path. The location can be overridden with the SNN_SCRIPTS_DIR
# environment variable instead of editing a hard-coded absolute path.
import os
import sys
SCRIPTS_DIR = os.environ.get("SNN_SCRIPTS_DIR", "/mnt/data/scripts")
if SCRIPTS_DIR not in sys.path:  # guard so re-running the cell doesn't append duplicates
    sys.path.append(SCRIPTS_DIR)
from nmnist import MLPNet, load_nmnist, collate_fn as collate_nmnist
from dvs_conv3 import ConvSNN, load_dvs, collate_fn as collate_dvs
import torch
from torch.utils.data import DataLoader
import tonic
from tonic.transforms import ToFrame
import snntorch as snn


## NMNIST — MLP-style SNN

In [None]:

# Params
save_to = "./data"
num_steps = 25
batch_size = 64
# Optional path to a trained state_dict for MLPNet. With None, this notebook
# loads no weights, so the accuracy below reflects the freshly constructed
# (presumably randomly initialised) network.
nmnist_ckpt = None

# Data
testset = load_nmnist(save_to, num_steps, split="test")
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, collate_fn=collate_nmnist)

# Model
# The MLP consumes flattened C*H*W frames; assumes load_nmnist yields
# 2-channel 34x34 frames — TODO confirm against nmnist.py.
C, H, W = 2, 34, 34
in_dim = C*H*W
net = MLPNet(in_dim=in_dim, hidden=300, out_dim=10).to(device)
if nmnist_ckpt is not None:
    net.load_state_dict(torch.load(nmnist_ckpt, map_location=device, weights_only=True))
net.eval()

# Eval
def eval_model(model, loader, device, num_steps):
    """Classification accuracy of a spiking model decoded by spike-count voting.

    Args:
        model: callable as ``model(x, num_steps)`` returning a 2-tuple whose
            first element is the spike record; it is switched to eval mode.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to before the forward pass.
        num_steps: number of simulation steps forwarded to the model.

    Returns:
        Fraction of samples whose predicted class equals the target
        (0.0 for an empty loader).
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            spike_rec, _mem = model(inputs, num_steps)
            # Sum spikes over dim 0 (presumably the time axis — TODO confirm
            # against the model) and vote for the most active output unit.
            predicted = spike_rec.sum(dim=0).argmax(dim=1)
            hits += (predicted == targets).sum().item()
            seen += targets.numel()
    # max(1, ...) avoids ZeroDivisionError when the loader yields nothing.
    return hits / max(1, seen)

# Accuracy = fraction of NMNIST test samples classified correctly.
acc = eval_model(model=net, loader=testloader, device=device, num_steps=num_steps)
acc


## DVSGesture — 3-layer Conv SNN

In [None]:

# Params
# Note: this cell rebinds save_to, num_steps, testset, testloader and net
# from the NMNIST section.
save_to = "./data"
num_steps = 25
batch_size = 4
resize = 64
# Optional path to a trained state_dict for ConvSNN. With None, this notebook
# loads no weights, so the accuracy below reflects the freshly constructed
# (presumably randomly initialised) network.
dvs_ckpt = None

# Data
testset = load_dvs(save_to, num_steps, split="test", resize=resize)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, collate_fn=collate_dvs)

# Model
# in_channels=2 presumably matches the two polarity channels produced by
# load_dvs — TODO confirm against dvs_conv3.py.
net = ConvSNN(in_channels=2, img_size=resize).to(device)
if dvs_ckpt is not None:
    net.load_state_dict(torch.load(dvs_ckpt, map_location=device, weights_only=True))
net.eval()

# Eval: reuse the spike-vote evaluator defined in the NMNIST section; the
# result is the fraction of DVSGesture test samples classified correctly.
acc = eval_model(model=net, loader=testloader, device=device, num_steps=num_steps)
acc
