In [15]:
from torchvision import datasets
import torch

class MNIST_Dataset(datasets.MNIST):
    def __init__(self, root, train=True, single_channel=False):
        datasets.MNIST.__init__(self, root, train=train, download=True)
        self.single_channel = single_channel
        
    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = img.float() / 255.
        
        # default is  by row, output is [time, channels] = [28, 28]
        # OR if we want by single item, output is [784, 1]
        if self.single_channel:
            img = img.reshape(-1).unsqueeze(1)
        
        spikes = torch.rand(size=img.shape) < img
        spikes = spikes.float()
        
        return spikes, target


In [74]:
import sinabs.layers as sl
from torch import nn

class LinModel(nn.Module):
    def __init__(self, batch_size=1):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(28, 128),
            sl.SpikingLayer(batch_size=batch_size),
            nn.Linear(128, 128),
            sl.SpikingLayer(batch_size=batch_size),
            nn.Linear(128, 256),
            sl.SpikingLayer(batch_size=batch_size),
            nn.Linear(256, 256),
            sl.SpikingLayer(batch_size=batch_size),
            nn.Linear(256, 10),
            sl.SpikingLayer(batch_size=batch_size)
        )

    def forward(self, x):
        return self.model(x)
    
    def reset_states(self):
        for lyr in self.model:
            if isinstance(lyr, sl.SpikingLayer):
                lyr.reset_states()

                
class LinModelSingle(nn.Module):
    def __init__(self, batch_size=1):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(1, 1024),
            sl.SpikingLayer(batch_size=batch_size),
            nn.Linear(1024, 512),
            sl.SpikingLayer(batch_size=batch_size),
            nn.Linear(512, 10),
            sl.SpikingLayer(batch_size=batch_size)
        )

    def forward(self, x):
        return self.model(x)
    
    def reset_states(self):
        for lyr in self.model:
            if isinstance(lyr, sl.SpikingLayer):
                lyr.reset_states()


In [77]:
BATCH_SIZE = 64

dataset = MNIST_Dataset(root="./data/", train=True, single_channel=True)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, drop_last=True)

model = LinModelSingle(batch_size=BATCH_SIZE)

In [78]:
from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(10):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()
        model.reset_states()

        out = model(img)
        out = out.sum(1)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())
    
    torch.save(model.state_dict(), "sequential_mnist.pth")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=937.0), HTML(value='')))




RuntimeError: python_error

In [34]:
model.load_state_dict(torch.load("sequential_mnist.pth"))

<All keys matched successfully>

In [39]:
accs = []

dataset_test = MNIST_Dataset(root="./data/", train=False)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=BATCH_SIZE, drop_last=True)

pbar = tqdm(dataloader_test)
for img, target in pbar:
    model.reset_states()

    out = model(img)
    out = out.sum(1)

    predicted = torch.max(out, axis=1)[1]
    acc = (predicted == target).sum().numpy() / BATCH_SIZE
    accs.append(acc)
        
print(sum(accs)/len(accs))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=156.0), HTML(value='')))


0.6365184294871795


## Training a baseline

In [71]:
ann = nn.Sequential(
    nn.Linear(28, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    nn.ReLU()
)

In [72]:
from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ann.parameters())

for epoch in range(10):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()
        
        target = target.unsqueeze(1).repeat([1, 28])
        img = img.reshape([-1, 28])
        target = target.reshape([-1])

        out = ann(img)
#         out = out.sum(1)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())
    
    torch.save(model.state_dict(), "sequential_mnist.pth")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=937.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=937.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=937.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=937.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=937.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=937.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=937.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=937.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=937.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=937.0), HTML(value='')))




KeyboardInterrupt: 

In [73]:
accs = []

dataset_test = MNIST_Dataset(root="./data/", train=False)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=BATCH_SIZE, drop_last=True)

pbar = tqdm(dataloader_test)
for img, target in pbar:
    
    img = img.reshape([-1, 28])
    out = ann(img)
    out = out.reshape([64, 28, 10])
    out = out.sum(1)

    predicted = torch.max(out, axis=1)[1]
    acc = (predicted == target).sum().numpy() / BATCH_SIZE
    accs.append(acc)
        
print(sum(accs)/len(accs))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=156.0), HTML(value='')))


0.23858173076923078
