In [None]:
from torchvision import datasets
import torch

torch.manual_seed(0)


class MNIST(datasets.MNIST):
    def __init__(self, root, train=True, single_channel=False):
        datasets.MNIST.__init__(self, root, train=train, download=True)
        self.single_channel = single_channel

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = img.float() / 255.0

        # default is  by row, output is [time, channels] = [28, 28]
        # OR if we want by single item, output is [784, 1]
        if self.single_channel:
            img = img.reshape(-1).unsqueeze(1)

        spikes = torch.rand(size=img.shape) < img
        spikes = spikes.float()

        return spikes, target

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_SIZE = 64

dataset_test = MNIST(root="dataset/", train=False)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=BATCH_SIZE, drop_last=True
)

dataset = MNIST(root="dataset/", train=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, drop_last=True)

In [None]:
from torch import nn

ann = nn.Sequential(
    nn.Linear(28, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    nn.ReLU(),
)

In [7]:
from tqdm.auto import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ann.parameters())

for epoch in range(2):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()

        target = target.unsqueeze(1).repeat([1, 28])
        img = img.reshape([-1, 28])
        target = target.reshape([-1])
        

        out = ann(img)
        #         out = out.sum(1)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())

100%|█████████████████████████████████| 937/937 [00:07<00:00, 118.41it/s, loss=2.2]
100%|█████████████████████████████████| 937/937 [00:08<00:00, 113.96it/s, loss=2.2]


In [8]:
accs = []

pbar = tqdm(dataloader_test)
for img, target in pbar:

    img = img.reshape([-1, 28])
    out = ann(img)
    out = out.reshape([64, 28, 10])
    out = out.sum(1)

    predicted = torch.max(out, axis=1)[1]
    acc = (predicted == target).sum().numpy() / BATCH_SIZE
    accs.append(acc)

print(sum(accs) / len(accs))

100%|███████████████████████████████████████████| 156/156 [00:00<00:00, 235.83it/s]

0.43830128205128205





In [9]:
from sinabs.from_torch import from_model

model = from_model(ann, batch_size=BATCH_SIZE).to(device)
model = model.train()

In [10]:
model

Network(
  (spiking_model): Sequential(
    (0): Linear(in_features=28, out_features=128, bias=True)
    (1): IAFSqueeze(spike_threshold=Parameter containing:
    tensor(1., device='cuda:0'), min_v_mem=Parameter containing:
    tensor(-1., device='cuda:0'), batch_size=64, num_timesteps=-1)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): IAFSqueeze(spike_threshold=Parameter containing:
    tensor(1., device='cuda:0'), min_v_mem=Parameter containing:
    tensor(-1., device='cuda:0'), batch_size=64, num_timesteps=-1)
    (4): Linear(in_features=128, out_features=256, bias=True)
    (5): IAFSqueeze(spike_threshold=Parameter containing:
    tensor(1., device='cuda:0'), min_v_mem=Parameter containing:
    tensor(-1., device='cuda:0'), batch_size=64, num_timesteps=-1)
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): IAFSqueeze(spike_threshold=Parameter containing:
    tensor(1., device='cuda:0'), min_v_mem=Parameter containing:
    tensor(-1., devi