In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

mnist_train = datasets.MNIST("./data", train=True, download=True, transform=transform)
mnist_test = datasets.MNIST("./data", train=False, download=True, transform=transform)

batch_size = 128
train_loader = DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True, drop_last=True
)
test_loader = DataLoader(
    mnist_test, batch_size=batch_size, shuffle=True, drop_last=True
)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn


class CNNDenseSNN(nn.Module):
    def __init__(self, beta=0.95):
        super(CNNDenseSNN, self).__init__()

        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 9 * 9)
        self.fc4 = nn.Linear(64, 10)

        self.lif1 = snn.Leaky(beta=beta)

    def forward(self, x0, filters=None):
        mem1 = self.lif1.init_leaky()

        for _ in range(10):
            if filters is not None:
                outputs = []
                for i in range(x0.shape[0]):
                    output = F.conv2d(
                        x0[i].unsqueeze(0),
                        filters[i].unsqueeze(0).unsqueeze(0),
                        padding=1,
                    )
                    outputs.append(output.reshape(1, 28, 28))

                x0 = x0 + F.relu(torch.stack(outputs, dim=0)) / 10

            x = self.pool(F.relu(self.conv1(x0)))
            x = self.pool(F.relu(self.conv2(x)))

            x = x.flatten(1)

            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            filter_x = F.relu(self.fc3(x))

            spk1, mem1 = self.lif1(filter_x, mem1)
            filters = spk1.view(-1, 3, 3)

        return self.fc4(x)


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
net = CNNDenseSNN().to(device)

In [160]:
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

counter = 0

for epoch in range(1):
    train_batch = iter(train_loader)

    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        result = net(data)

        loss_val = torch.zeros((1), device=device)
        loss_val = loss(result, targets)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Print train/test loss/accuracy
        # if counter % 10 == 0:
        print(f"Iteration: {counter} \t Train Loss: {loss_val.item()}")
        counter += 1

        if counter == 100:
            break

Iteration: 0 	 Train Loss: 2.2970423698425293
Iteration: 1 	 Train Loss: 2.299041509628296
Iteration: 2 	 Train Loss: 2.290811538696289
Iteration: 3 	 Train Loss: 2.2823262214660645
Iteration: 4 	 Train Loss: 2.2812108993530273
Iteration: 5 	 Train Loss: 2.256357192993164
Iteration: 6 	 Train Loss: 2.2369134426116943
Iteration: 7 	 Train Loss: 2.2182374000549316
Iteration: 8 	 Train Loss: 2.170041561126709
Iteration: 9 	 Train Loss: 2.1805713176727295
Iteration: 10 	 Train Loss: 2.100017547607422
Iteration: 11 	 Train Loss: 2.0457100868225098
Iteration: 12 	 Train Loss: 1.9376020431518555
Iteration: 13 	 Train Loss: 2.0795230865478516
Iteration: 14 	 Train Loss: 1.959992527961731
Iteration: 15 	 Train Loss: 1.7186599969863892
Iteration: 16 	 Train Loss: 1.6078206300735474
Iteration: 17 	 Train Loss: 1.5651899576187134
Iteration: 18 	 Train Loss: 1.2995526790618896
Iteration: 19 	 Train Loss: 1.3713992834091187
Iteration: 20 	 Train Loss: 1.226878046989441
Iteration: 21 	 Train Loss: 0.

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn


class CNNDenseSNN(nn.Module):
    def __init__(self, beta=0.9):
        super(CNNDenseSNN, self).__init__()

        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 9 * 9)
        self.fc4 = nn.Linear(64, 10)

        self.lif1 = snn.Leaky(beta=beta)

    def forward(self, x0, filters=None):
        mem1 = self.lif1.init_leaky()

        for _ in range(10):
            if filters is not None:
                outputs = []
                for i in range(x0.shape[0]):
                    output = F.conv2d(
                        x0[i].unsqueeze(0),
                        filters[i].unsqueeze(0).unsqueeze(0),
                        padding=1,
                    )
                    outputs.append(output.reshape(1, 28, 28))

                x0 = x0 + F.relu(torch.stack(outputs, dim=0)) / 8

            x = self.pool(F.relu(self.conv1(x0)))
            x = self.pool(F.relu(self.conv2(x)))

            x = x.flatten(1)

            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            filter_x = F.relu(self.fc3(x))

            spk1, mem1 = self.lif1(filter_x, mem1)
            filters = spk1.view(-1, 3, 3)

        return self.fc4(x)


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
net = CNNDenseSNN().to(device)

In [12]:
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

counter = 0

for epoch in range(1):
    train_batch = iter(train_loader)

    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        result = net(data)

        loss_val = torch.zeros((1), device=device)
        loss_val = loss(result, targets)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Print train/test loss/accuracy
        # if counter % 10 == 0:
        print(f"Iteration: {counter} \t Train Loss: {loss_val.item()}")
        counter += 1

        if counter == 100:
            break

Iteration: 0 	 Train Loss: 2.3100595474243164
Iteration: 1 	 Train Loss: 2.300555467605591
Iteration: 2 	 Train Loss: 2.301348924636841
Iteration: 3 	 Train Loss: 2.289250612258911
Iteration: 4 	 Train Loss: 2.283111572265625
Iteration: 5 	 Train Loss: 2.263258695602417
Iteration: 6 	 Train Loss: 2.276064395904541
Iteration: 7 	 Train Loss: 2.2461369037628174
Iteration: 8 	 Train Loss: 2.2103350162506104
Iteration: 9 	 Train Loss: 2.19960618019104
Iteration: 10 	 Train Loss: 2.1158485412597656
Iteration: 11 	 Train Loss: 2.1273183822631836
Iteration: 12 	 Train Loss: 2.0381686687469482
Iteration: 13 	 Train Loss: 1.772718906402588
Iteration: 14 	 Train Loss: 1.8430920839309692
Iteration: 15 	 Train Loss: 1.722664713859558
Iteration: 16 	 Train Loss: 1.6047680377960205
Iteration: 17 	 Train Loss: 1.488071084022522
Iteration: 18 	 Train Loss: 1.4315389394760132
Iteration: 19 	 Train Loss: 1.1387362480163574
Iteration: 20 	 Train Loss: 1.005287766456604
Iteration: 21 	 Train Loss: 1.32261

In [161]:
net.eval()

correct = 0
total = 0

counter = 0

with torch.no_grad():
    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)

        outputs = net(data)

        _, predicted = torch.max(outputs, 1)

        total += targets.size(0)
        correct += (predicted == targets).sum().item()
        print(f"Iteration: {counter}")
        counter += 1
        if counter == 100:
            break

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Iteration: 0
Iteration: 1
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
Iteration: 6
Iteration: 7
Iteration: 8
Iteration: 9
Iteration: 10
Iteration: 11
Iteration: 12
Iteration: 13
Iteration: 14
Iteration: 15
Iteration: 16
Iteration: 17
Iteration: 18
Iteration: 19
Iteration: 20
Iteration: 21
Iteration: 22
Iteration: 23
Iteration: 24
Iteration: 25
Iteration: 26
Iteration: 27
Iteration: 28
Iteration: 29
Iteration: 30
Iteration: 31
Iteration: 32
Iteration: 33
Iteration: 34
Iteration: 35
Iteration: 36
Iteration: 37
Iteration: 38
Iteration: 39
Iteration: 40
Iteration: 41
Iteration: 42
Iteration: 43
Iteration: 44
Iteration: 45
Iteration: 46
Iteration: 47
Iteration: 48
Iteration: 49
Iteration: 50
Iteration: 51
Iteration: 52
Iteration: 53
Iteration: 54
Iteration: 55
Iteration: 56
Iteration: 57
Iteration: 58
Iteration: 59
Iteration: 60
Iteration: 61
Iteration: 62
Iteration: 63
Iteration: 64
Iteration: 65
Iteration: 66
Iteration: 67
Iteration: 68
Iteration: 69
Iteration: 70
Iteration: 71
It

In [13]:
net.eval()

correct = 0
total = 0

counter = 0

with torch.no_grad():
    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)

        outputs = net(data)

        _, predicted = torch.max(outputs, 1)

        total += targets.size(0)
        correct += (predicted == targets).sum().item()
        print(f"Iteration: {counter}")
        counter += 1
        if counter == 100:
            break

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Iteration: 0
Iteration: 1
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
Iteration: 6
Iteration: 7
Iteration: 8
Iteration: 9
Iteration: 10
Iteration: 11
Iteration: 12
Iteration: 13
Iteration: 14
Iteration: 15
Iteration: 16
Iteration: 17
Iteration: 18
Iteration: 19
Iteration: 20
Iteration: 21
Iteration: 22
Iteration: 23
Iteration: 24
Iteration: 25
Iteration: 26
Iteration: 27
Iteration: 28
Iteration: 29
Iteration: 30
Iteration: 31
Iteration: 32
Iteration: 33
Iteration: 34
Iteration: 35
Iteration: 36
Iteration: 37
Iteration: 38
Iteration: 39
Iteration: 40
Iteration: 41
Iteration: 42
Iteration: 43
Iteration: 44
Iteration: 45
Iteration: 46
Iteration: 47
Iteration: 48
Iteration: 49
Iteration: 50
Iteration: 51
Iteration: 52
Iteration: 53
Iteration: 54
Iteration: 55
Iteration: 56
Iteration: 57
Iteration: 58
Iteration: 59
Iteration: 60
Iteration: 61
Iteration: 62
Iteration: 63
Iteration: 64
Iteration: 65
Iteration: 66
Iteration: 67
Iteration: 68
Iteration: 69
Iteration: 70
Iteration: 71
It

In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn


class CNNDenseSNN(nn.Module):
    def __init__(self, beta=0.9):
        super(CNNDenseSNN, self).__init__()

        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 9 * 9)
        self.fc4 = nn.Linear(64, 10)

        self.lif1 = snn.Leaky(beta=beta)

    def forward(self, x0, filters=None):
        mem1 = self.lif1.init_leaky()

        for j in range(10):
            if filters is not None:
                outputs = []
                for i in range(x0.shape[0]):
                    output = F.conv2d(
                        x0[i].unsqueeze(0),
                        filters[i].unsqueeze(0).unsqueeze(0),
                        padding=1,
                    )
                    outputs.append(output.reshape(1, 28, 28))

                # x0 = x0 + F.relu(torch.stack(outputs, dim=0)) / 8

                x0 = (
                    x0
                    + pow(-1, j)
                    * torch.exp(torch.tensor((-j / 16)))
                    * torch.stack(outputs, dim=0)
                    / 5
                )

            x = self.pool(F.relu(self.conv1(x0)))
            x = self.pool(F.relu(self.conv2(x)))

            x = x.flatten(1)

            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            filter_x = F.relu(self.fc3(x))

            spk1, mem1 = self.lif1(filter_x, mem1)
            filters = spk1.view(-1, 3, 3)

        return self.fc4(x)


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
net = CNNDenseSNN().to(device)

In [24]:
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

counter = 0

for epoch in range(1):
    train_batch = iter(train_loader)

    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        result = net(data)

        loss_val = torch.zeros((1), device=device)
        loss_val = loss(result, targets)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Print train/test loss/accuracy
        # if counter % 10 == 0:
        print(f"Iteration: {counter} \t Train Loss: {loss_val.item()}")
        counter += 1

        if counter == 100:
            break

Iteration: 0 	 Train Loss: 2.3011057376861572
Iteration: 1 	 Train Loss: 2.298447608947754
Iteration: 2 	 Train Loss: 2.2773585319519043
Iteration: 3 	 Train Loss: 2.273264169692993
Iteration: 4 	 Train Loss: 2.254110813140869
Iteration: 5 	 Train Loss: 2.2486047744750977
Iteration: 6 	 Train Loss: 2.230839729309082
Iteration: 7 	 Train Loss: 2.218226671218872
Iteration: 8 	 Train Loss: 2.1741645336151123
Iteration: 9 	 Train Loss: 2.1841020584106445
Iteration: 10 	 Train Loss: 2.118741750717163
Iteration: 11 	 Train Loss: 2.081988573074341
Iteration: 12 	 Train Loss: 2.089444160461426
Iteration: 13 	 Train Loss: 2.04512882232666
Iteration: 14 	 Train Loss: 2.011805295944214
Iteration: 15 	 Train Loss: 1.9291667938232422
Iteration: 16 	 Train Loss: 1.9229484796524048
Iteration: 17 	 Train Loss: 1.8611953258514404
Iteration: 18 	 Train Loss: 1.8005448579788208
Iteration: 19 	 Train Loss: 1.6944576501846313
Iteration: 20 	 Train Loss: 1.6482499837875366
Iteration: 21 	 Train Loss: 1.6814

In [25]:
net.eval()

correct = 0
total = 0

counter = 0

with torch.no_grad():
    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)

        outputs = net(data)

        _, predicted = torch.max(outputs, 1)

        total += targets.size(0)
        correct += (predicted == targets).sum().item()
        print(f"Iteration: {counter}")
        counter += 1
        if counter == 100:
            break

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Iteration: 0
Iteration: 1
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
Iteration: 6
Iteration: 7
Iteration: 8
Iteration: 9
Iteration: 10
Iteration: 11
Iteration: 12
Iteration: 13
Iteration: 14
Iteration: 15
Iteration: 16
Iteration: 17
Iteration: 18
Iteration: 19
Iteration: 20
Iteration: 21
Iteration: 22
Iteration: 23
Iteration: 24
Iteration: 25
Iteration: 26
Iteration: 27
Iteration: 28
Iteration: 29
Iteration: 30
Iteration: 31
Iteration: 32
Iteration: 33
Iteration: 34
Iteration: 35
Iteration: 36
Iteration: 37
Iteration: 38
Iteration: 39
Iteration: 40
Iteration: 41
Iteration: 42
Iteration: 43
Iteration: 44
Iteration: 45
Iteration: 46
Iteration: 47
Iteration: 48
Iteration: 49
Iteration: 50
Iteration: 51
Iteration: 52
Iteration: 53
Iteration: 54
Iteration: 55
Iteration: 56
Iteration: 57
Iteration: 58
Iteration: 59
Iteration: 60
Iteration: 61
Iteration: 62
Iteration: 63
Iteration: 64
Iteration: 65
Iteration: 66
Iteration: 67
Iteration: 68
Iteration: 69
Iteration: 70
Iteration: 71
It