In [21]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

mnist_train = datasets.CIFAR10("./data", train=True, download=True, transform=transform)
mnist_test = datasets.CIFAR10("./data", train=False, download=True, transform=transform)

batch_size = 128
train_loader = DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True, drop_last=True
)
test_loader = DataLoader(
    mnist_test, batch_size=batch_size, shuffle=True, drop_last=True
)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn


class CNNDenseSNN(nn.Module):
    def __init__(self, beta=0.9):
        super(CNNDenseSNN, self).__init__()

        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 3 * 3 * 3 * 3)
        self.fc4 = nn.Linear(128, 10)

        self.lif1 = snn.Leaky(beta=beta)

    def forward(self, x0, filters=None):
        mem1 = self.lif1.init_leaky()

        for _ in range(10):
            if filters is not None:
                outputs = []
                for i in range(x0.shape[0]):
                    output = F.conv2d(
                        x0[i].unsqueeze(0),
                        filters[i],
                        padding=1,
                    )
                    outputs.append(output.reshape(3, 32, 32))

                x0 = F.relu(x0 + torch.stack(outputs, dim=0) / 24)

            x = self.pool(F.relu(self.conv1(x0)))
            x = self.pool(F.relu(self.conv2(x)))

            x = x.flatten(1)

            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            filter_x = F.relu(self.fc3(x))

            spk1, mem1 = self.lif1(filter_x, mem1)
            filters = spk1.view(-1, 3, 3, 3, 3)

        return self.fc4(x)


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
net = CNNDenseSNN().to(device)

In [13]:
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

counter = 0

for epoch in range(1):
    train_batch = iter(train_loader)

    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        result = net(data)

        loss_val = torch.zeros((1), device=device)
        loss_val = loss(result, targets)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Print train/test loss/accuracy
        # if counter % 10 == 0:
        print(f"Iteration: {counter} \t Train Loss: {loss_val.item()}")
        counter += 1

        if counter == 100:
            break

Iteration: 0 	 Train Loss: 2.3092870712280273
Iteration: 1 	 Train Loss: 2.3108346462249756
Iteration: 2 	 Train Loss: 2.299056053161621
Iteration: 3 	 Train Loss: 2.2966506481170654
Iteration: 4 	 Train Loss: 2.320936679840088
Iteration: 5 	 Train Loss: 2.291184663772583
Iteration: 6 	 Train Loss: 2.2873988151550293
Iteration: 7 	 Train Loss: 2.290024518966675
Iteration: 8 	 Train Loss: 2.2734811305999756
Iteration: 9 	 Train Loss: 2.2841196060180664
Iteration: 10 	 Train Loss: 2.2756154537200928
Iteration: 11 	 Train Loss: 2.2098610401153564
Iteration: 12 	 Train Loss: 2.253403902053833
Iteration: 13 	 Train Loss: 2.2462620735168457
Iteration: 14 	 Train Loss: 2.1922690868377686
Iteration: 15 	 Train Loss: 2.26705265045166
Iteration: 16 	 Train Loss: 2.218339681625366
Iteration: 17 	 Train Loss: 2.1807172298431396
Iteration: 18 	 Train Loss: 2.2445011138916016
Iteration: 19 	 Train Loss: 2.165106773376465
Iteration: 20 	 Train Loss: 2.2003538608551025
Iteration: 21 	 Train Loss: 2.12

In [91]:
net.eval()

correct = 0
total = 0

counter = 0

with torch.no_grad():
    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)

        outputs = net(data)

        _, predicted = torch.max(outputs, 1)

        total += targets.size(0)
        correct += (predicted == targets).sum().item()
        print(f"Iteration: {counter}")
        counter += 1
        if counter == 100:
            break

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Iteration: 0
Iteration: 1
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
Iteration: 6
Iteration: 7
Iteration: 8
Iteration: 9
Iteration: 10
Iteration: 11
Iteration: 12
Iteration: 13
Iteration: 14
Iteration: 15
Iteration: 16
Iteration: 17
Iteration: 18
Iteration: 19
Iteration: 20
Iteration: 21
Iteration: 22
Iteration: 23
Iteration: 24
Iteration: 25
Iteration: 26
Iteration: 27
Iteration: 28
Iteration: 29
Iteration: 30
Iteration: 31
Iteration: 32
Iteration: 33
Iteration: 34
Iteration: 35
Iteration: 36
Iteration: 37
Iteration: 38
Iteration: 39
Iteration: 40
Iteration: 41
Iteration: 42
Iteration: 43
Iteration: 44
Iteration: 45
Iteration: 46
Iteration: 47
Iteration: 48
Iteration: 49
Iteration: 50
Iteration: 51
Iteration: 52
Iteration: 53
Iteration: 54
Iteration: 55
Iteration: 56
Iteration: 57
Iteration: 58
Iteration: 59
Iteration: 60
Iteration: 61
Iteration: 62
Iteration: 63
Iteration: 64
Iteration: 65
Iteration: 66
Iteration: 67
Iteration: 68
Iteration: 69
Iteration: 70
Iteration: 71
It

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn


class CNNDenseSNN(nn.Module):
    def __init__(self, beta=0.9):
        super(CNNDenseSNN, self).__init__()

        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 3 * 3 * 3 * 3)
        self.fc4 = nn.Linear(128, 10)

        self.lif1 = snn.Leaky(beta=beta)

    def forward(self, x0, filters=None):
        mem1 = self.lif1.init_leaky()

        for j in range(10):
            if filters is not None:
                outputs = []
                for i in range(x0.shape[0]):
                    output = F.conv2d(
                        x0[i].unsqueeze(0),
                        filters[i],
                        padding=1,
                    )
                    outputs.append(output.reshape(3, 32, 32))

                x0 = (
                    x0
                    + pow(-1, j)
                    * torch.exp(torch.tensor((-j / 24)))
                    * torch.stack(outputs, dim=0)
                    / 24
                )

            x = self.pool(F.relu(self.conv1(x0)))
            x = self.pool(F.relu(self.conv2(x)))

            x = x.flatten(1)

            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            filter_x = F.relu(self.fc3(x))

            spk1, mem1 = self.lif1(filter_x, mem1)
            filters = spk1.view(-1, 3, 3, 3, 3)

        return self.fc4(x)


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
net = CNNDenseSNN().to(device)

In [42]:
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

counter = 0

for epoch in range(1):
    train_batch = iter(train_loader)

    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        result = net(data)

        loss_val = torch.zeros((1), device=device)
        loss_val = loss(result, targets)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Print train/test loss/accuracy
        # if counter % 10 == 0:
        print(f"Iteration: {counter} \t Train Loss: {loss_val.item()}")
        counter += 1

        if counter == 100:
            break

Iteration: 0 	 Train Loss: 2.302978038787842
Iteration: 1 	 Train Loss: 2.301532506942749
Iteration: 2 	 Train Loss: 2.3033957481384277
Iteration: 3 	 Train Loss: 2.2999377250671387
Iteration: 4 	 Train Loss: 2.291844129562378
Iteration: 5 	 Train Loss: 2.293843984603882
Iteration: 6 	 Train Loss: 2.289381980895996
Iteration: 7 	 Train Loss: 2.273117780685425
Iteration: 8 	 Train Loss: 2.2849888801574707
Iteration: 9 	 Train Loss: 2.2770159244537354
Iteration: 10 	 Train Loss: 2.255248546600342
Iteration: 11 	 Train Loss: 2.2611918449401855
Iteration: 12 	 Train Loss: 2.263638734817505
Iteration: 13 	 Train Loss: 2.2289857864379883
Iteration: 14 	 Train Loss: 2.240945339202881
Iteration: 15 	 Train Loss: 2.242457866668701
Iteration: 16 	 Train Loss: 2.1963462829589844
Iteration: 17 	 Train Loss: 2.188034772872925
Iteration: 18 	 Train Loss: 2.18349027633667
Iteration: 19 	 Train Loss: 2.199744939804077
Iteration: 20 	 Train Loss: 2.2231099605560303
Iteration: 21 	 Train Loss: 2.0947494

In [43]:
net.eval()

correct = 0
total = 0

counter = 0

with torch.no_grad():
    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)

        outputs = net(data)

        _, predicted = torch.max(outputs, 1)

        total += targets.size(0)
        correct += (predicted == targets).sum().item()
        print(f"Iteration: {counter}")
        counter += 1
        if counter == 100:
            break

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Iteration: 0
Iteration: 1
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
Iteration: 6
Iteration: 7
Iteration: 8
Iteration: 9
Iteration: 10
Iteration: 11
Iteration: 12
Iteration: 13
Iteration: 14
Iteration: 15
Iteration: 16
Iteration: 17
Iteration: 18
Iteration: 19
Iteration: 20
Iteration: 21
Iteration: 22
Iteration: 23
Iteration: 24
Iteration: 25
Iteration: 26
Iteration: 27
Iteration: 28
Iteration: 29
Iteration: 30
Iteration: 31
Iteration: 32
Iteration: 33
Iteration: 34
Iteration: 35
Iteration: 36
Iteration: 37
Iteration: 38
Iteration: 39
Iteration: 40
Iteration: 41
Iteration: 42
Iteration: 43
Iteration: 44
Iteration: 45
Iteration: 46
Iteration: 47
Iteration: 48
Iteration: 49
Iteration: 50
Iteration: 51
Iteration: 52
Iteration: 53
Iteration: 54
Iteration: 55
Iteration: 56
Iteration: 57
Iteration: 58
Iteration: 59
Iteration: 60
Iteration: 61
Iteration: 62
Iteration: 63
Iteration: 64
Iteration: 65
Iteration: 66
Iteration: 67
Iteration: 68
Iteration: 69
Iteration: 70
Iteration: 71
It