In [1]:
import copy, torch, torch.nn as nn, torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import snntorch as snn
from snntorch import spikegen
from tqdm import tqdm

**device**

In [2]:
device = torch.device("cuda")
torch.manual_seed(0)


<torch._C.Generator at 0x2636cd23c90>

**parameters**

In [3]:
num_in, num_hid, num_out = 784, 256, 10
beta, T, bs, epochs, lr = 0.9, 60, 128, 5, 1e-3

In [4]:
tr = transforms.Compose([transforms.ToTensor()])
train_ds = torchvision.datasets.FashionMNIST("data", True , download=True, transform=tr)
test_ds  = torchvision.datasets.FashionMNIST("data", False, download=True, transform=tr)
train_ld = DataLoader(train_ds, bs, True , num_workers=2, pin_memory=True)
test_ld  = DataLoader(test_ds , bs, False, num_workers=2, pin_memory=True)

**arch network**

In [5]:
class SNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1  = nn.Linear(num_in , num_hid, bias=False)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2  = nn.Linear(num_hid, num_out, bias=False)
        self.lif2 = snn.Leaky(beta=beta)
    def forward(self, x):
        x = x.to(self.fc1.weight.dtype)
        mem1 = self.lif1.init_leaky().to(dtype=x.dtype, device=x.device)
        mem2 = self.lif2.init_leaky().to(dtype=x.dtype, device=x.device)
        out  = []
        for t in range(x.size(0)):
            spk1, mem1 = self.lif1(self.fc1(x[t]), mem1)
            spk1 = spk1.to(dtype=self.fc2.weight.dtype)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            out.append(spk2)
        return torch.stack(out)                      # [T,B,10]

In [7]:
net = SNN().to(device)
opt   = torch.optim.Adam(net.parameters(), lr=lr)
lossf = nn.CrossEntropyLoss().to(device)

**train phase**

In [12]:
def train(net,train_ld,epochs):
    for ep in range(1, epochs + 1):
        net.train()
        for imgs, lbls in tqdm(train_ld, desc=f"train {ep}/{epochs}"):
            imgs = imgs.to(device).view(imgs.size(0), -1) * 5
            lbls = lbls.to(device)
            spk  = spikegen.rate(imgs, T).to(device)
            loss = lossf(net(spk).sum(0), lbls)
            opt.zero_grad(); loss.backward(); opt.step()

In [15]:
train(net,train_ld,epochs)

train 1/5: 100%|██████████| 469/469 [00:44<00:00, 10.52it/s]
train 2/5: 100%|██████████| 469/469 [00:44<00:00, 10.56it/s]
train 3/5: 100%|██████████| 469/469 [00:45<00:00, 10.41it/s]
train 4/5: 100%|██████████| 469/469 [00:43<00:00, 10.69it/s]
train 5/5: 100%|██████████| 469/469 [00:43<00:00, 10.78it/s]


**test phase**

In [20]:
def test(model,loader):
    @torch.no_grad()
    def eval_once(model, loader, dtype):
        model.eval(); c = t = 0
        for imgs, lbls in loader:
            imgs = (imgs.to(device).view(imgs.size(0), -1) * 5).to(dtype)
            spk  = spikegen.rate(imgs, T).to(dtype).to(device)
            c   += (model(spk).sum(0).argmax(1) == lbls.to(device)).sum().item()
            t   += lbls.size(0)
        return 100 * c / t

    acc32 = eval_once(net, test_ld, torch.float32)


    net16 = SNN().to(device)
    net16.load_state_dict(net.state_dict())
    net16.half().eval()
    acc16 = eval_once(net16, test_ld, torch.float16)

    print(f"Fashion-MNIST accuracy  FP32: {acc32:.2f}%")
    print(f"Fashion-MNIST accuracy  FP16: {acc16:.2f}%")

In [21]:
test(net,test_ld)

Fashion-MNIST accuracy  FP32: 82.83%
Fashion-MNIST accuracy  FP16: 82.90%
