In [1]:
import torch
from tqdm.notebook import tqdm
import torch.nn as nn
import numpy as np
from sinabs.layers.functional import threshold_subtract, threshold_reset 

In [2]:
class SpikeSomaLayer(nn.Module):
    """Spiking soma layer that integrates input and applies a subtractive reset.

    On every ``forward`` call the membrane potential ``vmem`` is:

    1. decayed by ``decay``;
    2. incremented by the input;
    3. reduced by ``v_th`` times the activations emitted on the previous step.

    New activations are then produced by sinabs' ``threshold_subtract``.
    State (``vmem`` and ``activations``) persists across calls until
    ``reset_states`` is invoked or the input shape/device changes.

    Args:
        n_neurons: Number of neurons; sets the default state shape ``(1, n_neurons)``.
        decay: Multiplicative factor applied to ``vmem`` each step (1.0 = no leak).
        v_th: Firing threshold passed to ``threshold_subtract`` and used for the
            subtractive reset. Defaults to 1 (the previously hard-coded value).
    """

    def __init__(self, n_neurons, decay=1.0, v_th=1):
        super().__init__()
        self.n_neurons = n_neurons
        self.v_th = v_th
        self.decay = decay
        self.reset_states()

    def reset_states(self, shape=None, device="cpu", randomize=False):
        """(Re)initialise the membrane potential and last-step activations.

        Args:
            shape: Shape of the state tensors; defaults to ``(1, n_neurons)``.
            device: Device on which the state tensors are allocated.
            randomize: If True, ``vmem`` is drawn from ``torch.rand`` (uniform on
                [0, 1)); otherwise it starts at zero. Activations always start at zero.
        """
        if shape is None:
            shape = (1, self.n_neurons)

        if randomize:
            self.vmem = torch.rand(shape, device=device)
        else:
            self.vmem = torch.zeros(shape, device=device)

        # zeros_like already inherits vmem's device and dtype.
        self.activations = torch.zeros_like(self.vmem)

    def forward(self, inp):
        """Advance the layer by one time step and return the new activations.

        If ``inp`` differs from the stored state in shape or device, the state is
        first reset to zeros with ``inp``'s shape on ``inp``'s device.

        Args:
            inp: Input tensor for this time step.

        Returns:
            The output of ``threshold_subtract`` for the updated ``vmem``.
        """
        # Explicit check rather than assert-based control flow: assertions are
        # removed under ``python -O``, which would silently skip this reset.
        if inp.shape != self.vmem.shape or inp.device != self.vmem.device:
            self.reset_states(inp.shape, inp.device)
        # Subtractive reset: last step's activations remove v_th each from vmem.
        self.vmem = self.decay * self.vmem + inp - self.activations * self.v_th
        # NOTE(review): 0.5 is presumably sinabs' surrogate-gradient window — confirm.
        self.activations = threshold_subtract(self.vmem, self.v_th, 0.5)
        return self.activations

In [4]:
def accuracy(preds, labels):
    """Top-1 accuracy of ``preds`` against ``labels``, as a Python float percentage.

    The predicted class is the argmax over dim 1 of ``preds``. The hit count is
    divided by ``len(labels)``. Gradients are not tracked.
    """
    with torch.no_grad():
        hits = preds.argmax(1).eq(labels)
        score = hits.float().sum().mul(100).div(len(labels))
    return score.item()


def binarize(data):
    """Return a float32 tensor with 1.0 where ``data`` is strictly positive, else 0.0."""
    positive_mask = data.gt(0)
    return positive_mask.float()

In [5]:
import torchvision
from datetime import datetime
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

batch_size = 128
n_epochs = 20
n_neurons = 512
decay = 1.0
num_workers = 8  # DataLoader worker processes; lower on machines with fewer cores
seed = 0  # fixes weight init and DataLoader shuffling for reproducible runs

randomize_vmem = True

# Seed before building the model so its initial weights are reproducible.
torch.manual_seed(seed)

# NOTE(review): `MyRNN` is not defined in any visible cell (execution count
# In [3] is missing), so this line raises NameError on a fresh kernel.
# TODO restore the cell that defines MyRNN above this one.
rnn = MyRNN(n_neurons=n_neurons, decay=decay)

device = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)

# Convert to tensor, binarize, and squeeze out size-1 dimensions.
transform = transforms.Compose([transforms.ToTensor(), binarize, torch.squeeze])

# Download and load the training split
trainset = torchvision.datasets.MNIST(
    root="./", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)

# Download and load the test split
testset = torchvision.datasets.MNIST(
    root="./", train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)



In [6]:
# Training parameters
criterion = nn.CrossEntropyLoss()

# Move the model to its device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
rnn.to(device)
optimizer = torch.optim.Adam(rnn.parameters(), lr=1e-4)

# Optional warm start (replaces the old commented-out loader): set a checkpoint
# path to enable. torch.load unpickles the file, so only load trusted checkpoints.
resume_path = None  # e.g. "trained/srnn_mnist_19-09-14-22:09.pt"
if resume_path is not None:
    checkpoint = torch.load(resume_path, map_location=device)
    rnn.load_state_dict(checkpoint["model_state_dict"])

# Log data of the experiment
writer = SummaryWriter()
# NOTE(review): save_path is not used anywhere in this notebook.
save_path = writer.get_logdir()

pbar_epoch = tqdm(range(n_epochs), desc="epoch")
try:
    for epoch in pbar_epoch:
        # --- Training pass ---
        rnn.train()
        running_loss = 0.0
        running_accuracy = []
        # leave=False: per-epoch bars disappear instead of piling up as widgets.
        for imgs, labels in tqdm(trainloader, desc="train", leave=False):
            optimizer.zero_grad()
            # NOTE(review): init_states is defined on MyRNN (not visible here);
            # presumably resets the network state before each batch — confirm.
            rnn.init_states(randomize=randomize_vmem)
            imgs = imgs.to(device)
            labels = labels.to(device)

            out, _ = rnn(imgs)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()

            running_accuracy.append(accuracy(out, labels))
            running_loss += loss.item()

        # --- Evaluation pass and reporting (no gradients needed) ---
        rnn.eval()
        with torch.no_grad():
            test_accuracy = []
            for imgs, labels in tqdm(testloader, desc="test", leave=False):
                rnn.init_states()
                imgs = imgs.to(device)
                labels = labels.to(device)
                out, _ = rnn(imgs)
                test_accuracy.append(accuracy(out, labels))

            # Mean training loss over batches; max() guards an empty trainloader.
            mean_loss = running_loss / max(len(running_accuracy), 1)
            train_acc = np.mean(running_accuracy)
            test_acc = np.mean(test_accuracy)
            weight_means = [p.abs().mean().item() for p in rnn.parameters()]

            pbar_epoch.set_postfix(
                loss=mean_loss,
                weights=weight_means,
                train_accuracy=train_acc,
                test_accuracy=test_acc,
            )

            writer.add_scalar("Loss/train", mean_loss, epoch)
            writer.add_scalars(
                "Accuracy", {"train": train_acc, "test": test_acc}, epoch
            )
            # NOTE(review): assumes rnn.parameters() yields the input, recurrent
            # and output weights in that order — confirm against MyRNN.
            writer.add_scalar("Weight/Input", weight_means[0], epoch)
            writer.add_scalar("Weight/Recurrent", weight_means[1], epoch)
            writer.add_scalar("Weight/Output", weight_means[2], epoch)
            writer.flush()
finally:
    # Close even if the run is interrupted (e.g. KeyboardInterrupt) so pending
    # events are flushed to disk.
    writer.close()

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))



