In [3]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

mnist_train = datasets.MNIST("./data", train=True, download=True, transform=transform)
mnist_test = datasets.MNIST("./data", train=False, download=True, transform=transform)

batch_size = 128
train_loader = DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True, drop_last=True
)
test_loader = DataLoader(
    mnist_test, batch_size=batch_size, shuffle=True, drop_last=True
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn


class CNN(nn.Module):
    def __init__(self, beta=0.95):
        super(CNN, self).__init__()

        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)

    def forward(self, x0, filters=None):
        x = self.pool(F.relu(self.conv1(x0)))
        x = self.pool(F.relu(self.conv2(x)))

        x = x.flatten(1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        return self.fc4(x)

        # x = F.softmax(self.fc3(x))

        # return x, filter


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
net = CNN().to(device)

In [4]:
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

counter = 0

for epoch in range(1):
    train_batch = iter(train_loader)

    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        result = net(data)

        loss_val = torch.zeros((1), device=device)
        loss_val = loss(result, targets)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Print train/test loss/accuracy
        # if counter % 10 == 0:
        print(f"Iteration: {counter} \t Train Loss: {loss_val.item()}")
        counter += 1

        if counter == 100:
            break

Iteration: 0 	 Train Loss: 2.309497117996216
Iteration: 1 	 Train Loss: 2.301856279373169
Iteration: 2 	 Train Loss: 2.293412923812866
Iteration: 3 	 Train Loss: 2.290461301803589
Iteration: 4 	 Train Loss: 2.2786757946014404
Iteration: 5 	 Train Loss: 2.2800419330596924
Iteration: 6 	 Train Loss: 2.2743091583251953
Iteration: 7 	 Train Loss: 2.2552077770233154
Iteration: 8 	 Train Loss: 2.256417989730835
Iteration: 9 	 Train Loss: 2.234023332595825
Iteration: 10 	 Train Loss: 2.2234766483306885
Iteration: 11 	 Train Loss: 2.2056729793548584
Iteration: 12 	 Train Loss: 2.1975064277648926
Iteration: 13 	 Train Loss: 2.1773080825805664
Iteration: 14 	 Train Loss: 2.147063970565796
Iteration: 15 	 Train Loss: 2.140937089920044
Iteration: 16 	 Train Loss: 2.108513116836548
Iteration: 17 	 Train Loss: 2.0639944076538086
Iteration: 18 	 Train Loss: 2.049619436264038
Iteration: 19 	 Train Loss: 2.02984881401062
Iteration: 20 	 Train Loss: 2.0038208961486816
Iteration: 21 	 Train Loss: 1.95207

In [6]:
net.eval()

correct = 0
total = 0

counter = 0

with torch.no_grad():
    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)

        outputs = net(data)

        _, predicted = torch.max(outputs, 1)

        total += targets.size(0)
        correct += (predicted == targets).sum().item()
        print(f"Iteration: {counter}")
        counter += 1
        if counter == 100:
            break

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Iteration: 0
Iteration: 1
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
Iteration: 6
Iteration: 7
Iteration: 8
Iteration: 9
Iteration: 10
Iteration: 11
Iteration: 12
Iteration: 13
Iteration: 14
Iteration: 15
Iteration: 16
Iteration: 17
Iteration: 18
Iteration: 19
Iteration: 20
Iteration: 21
Iteration: 22
Iteration: 23
Iteration: 24
Iteration: 25
Iteration: 26
Iteration: 27
Iteration: 28
Iteration: 29
Iteration: 30
Iteration: 31
Iteration: 32
Iteration: 33
Iteration: 34
Iteration: 35
Iteration: 36
Iteration: 37
Iteration: 38
Iteration: 39
Iteration: 40
Iteration: 41
Iteration: 42
Iteration: 43
Iteration: 44
Iteration: 45
Iteration: 46
Iteration: 47
Iteration: 48
Iteration: 49
Iteration: 50
Iteration: 51
Iteration: 52
Iteration: 53
Iteration: 54
Iteration: 55
Iteration: 56
Iteration: 57
Iteration: 58
Iteration: 59
Iteration: 60
Iteration: 61
Iteration: 62
Iteration: 63
Iteration: 64
Iteration: 65
Iteration: 66
Iteration: 67
Iteration: 68
Iteration: 69
Iteration: 70
Iteration: 71
It