In [1]:
import torch
from tqdm.notebook import tqdm
import torch.nn as nn
import numpy as np
from sinabs.layers.functional import threshold_subtract, threshold_reset 
from sinabs.layers import Layer
# - Type alias for array-like objects
from typing import Optional, Union, List, Tuple
ArrayLike = Union[np.ndarray, List, Tuple]
from abc import abstractmethod
import warnings

In [2]:
class SpikingLayer(Layer):
    def __init__(
        self,
        input_shape: ArrayLike,
        threshold: float = 1.0,
        threshold_low: Optional[float] = -1.0,
        membrane_subtract: Optional[float] = 1.0,
        membrane_reset: float = 0,
        layer_name: str = "spiking",
        negative_spikes: bool = False
    ):
        """
        Pytorch implementation of a spiking neuron.
        This class is the base class for any layer that need to implement integrate-and-fire operations.

        :param input_shape: Input data shape
        :param threshold: Spiking threshold of the neuron
        :param threshold_low: Lowerbound for membrane potential
        :param membrane_subtract: Upon spiking if the membrane potential is subtracted as opposed to reset, what is its value
        :param membrane_reset: What is the reset membrane potential of the neuron
        :param layer_name: Name of this layer
        :param negative_spikes: Implement a linear transfer function through negative spiking

        NOTE: SUBTRACT superseeds Reset value
        """
        super().__init__(input_shape=input_shape, layer_name=layer_name)
        # Initialize neuron states
        self.membrane_subtract = membrane_subtract
        self.membrane_reset = membrane_reset
        self.threshold = threshold
        self.threshold_low = threshold_low
        self.negative_spikes = negative_spikes

        # Blank parameter place holders
        self.register_buffer("state", torch.zeros(1))
        self.register_buffer("activations", torch.zeros(1))
        self.spikes_number = None

    @property
    def threshold_low(self):
        return self._threshold_low

    @threshold_low.setter
    def threshold_low(self, new_threshold_low):
        self._threshold_low = new_threshold_low
        if new_threshold_low is None:
            try:
                del self.thresh_lower
            except AttributeError:
                pass
        else:
            # Relu on the layer
            self.thresh_lower = nn.Threshold(new_threshold_low, new_threshold_low)

    def reset_states(self, shape=None):
        """
        Reset the state of all neurons in this layer
        """
        if shape is None:
            shape = self.state.shape
        self.state = torch.zeros(shape, device=self.state.device)
        self.activations = torch.zeros(shape, device=self.state.device)

    @abstractmethod
    def synaptic_output(self, input_spikes: torch.Tensor) -> torch.Tensor:
        """
        This method needs to be overridden/defined by the child class

        :param input_spikes: torch.Tensor input to the layer.
        :return:  torch.Tensor - synaptic output current
        """
        pass

    def forward(self, binary_input: torch.Tensor):
        # Determine no. of time steps from input
        neg_spikes = self.negative_spikes

        # Compute the synaptic current
        syn_out: torch.Tensor = self.synaptic_output(binary_input)
        time_steps = len(syn_out)

        # Local variables
        membrane_subtract = self.membrane_subtract
        threshold = self.threshold
        threshold_low = self.threshold_low
        membrane_reset = self.membrane_reset

        # Create a vector to hold all output spikes
        spikes = []

        # Initialize state as required
        state = self.state
        if state.shape != syn_out.shape[1:]:
            self.reset_states(shape=syn_out.shape[1:])
            
        activations = self.activations
        # Loop over time steps
        for iCurrentTimeStep in range(time_steps):
            # update neuron states
            state = syn_out[iCurrentTimeStep] + state - activations*threshold
            # generate spikes
            activations = threshold_subtract(state, threshold, threshold/2)
            spikes.append(activations)

        self.state = state
        self.tw = time_steps
        self.activations = activations
        all_spikes = torch.stack(spikes)
        spikes_number = all_spikes.sum()
        return all_spikes

In [3]:
class SpikeSomaLayer(SpikingLayer):
    def synaptic_output(self, input_current: torch.Tensor) -> torch.Tensor:
        """
        This method needs to be overridden/defined by the child class

        :param input_spikes: torch.Tensor input to the layer.
        :return:  torch.Tensor - synaptic output current
        """
        return input_current
    
    def get_output_shape(self, in_shape):
        return in_inshape

In [4]:
class MyRNN(nn.Module):
    def __init__(self, n_inp=28, n_neurons=100, n_out=10, decay=0.8):
        super().__init__()
        self.n_inp = n_inp
        self.n_neurons = n_neurons
        self.n_out = n_out
        self.v_th = 1.0
        # Feed forward input
        self.inp_linear = nn.Linear(n_inp, n_neurons, bias=False)
        # Recurrent pool of neurons
        self.rec = nn.Linear(n_neurons, n_neurons, bias=False)
        self.rec_neurons = SpikeSomaLayer(input_shape=n_neurons)
        # Feedforward output
        self.out_linear = nn.Linear(n_neurons, n_out, bias=False)
        self.out_neurons = SpikeSomaLayer(input_shape=n_out)
        self.init_states()

    def init_states(self, randomize=True, batch_size=1):
        self.rec_neurons.reset_states(shape=(batch_size,self.n_neurons))
        self.out_neurons.reset_states(shape=(batch_size, self.n_out))

    def forward(self, inp) -> (torch.Tensor, torch.Tensor):
        activations = self.rec_neurons.activations
        # all_rec_spikes = []
        all_out_spikes = []
        for row in range(inp.shape[1]):
            # Readout layer
            out_linear = self.out_linear(activations)
            out = self.out_neurons(out_linear.unsqueeze(0)).squeeze(0)
            assert out.shape == (128, 10)
            all_out_spikes.append(out)
            # Recurrent input
            recurrent_inputs = self.rec(activations)  # Recurrent input
            # Input activations
            input_ext = self.inp_linear(inp[:, row])  # External input
            activations = self.rec_neurons(
                (input_ext + recurrent_inputs).unsqueeze(0)
            ).squeeze(0)  # recurrent spiking neuron output

        all_out_spikes = torch.stack(all_out_spikes).transpose(0,1)

        return all_out_spikes.sum(1), all_out_spikes

In [5]:
def accuracy(preds, labels):
    with torch.no_grad():
        accuracy = 100 * (torch.argmax(preds, 1) == labels).float().sum() / len(labels)
    return accuracy.detach().item()


def binarize(data):
    return (data > 0).float()

In [6]:
import torchvision
from datetime import datetime
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

batch_size = 128
n_epochs = 20
n_neurons = 512
decay = 1.0

randomize_vmem = True


dataset = torchvision.datasets.MNIST(root="./", train=False, download=True)
device = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)

# Convert to tensor and binarize them
transform = transforms.Compose([transforms.ToTensor(), binarize, torch.squeeze])

# Download and load training dataset
trainset = torchvision.datasets.MNIST(
    root="./", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True
)

# download and load testing dataset
testset = torchvision.datasets.MNIST(
    root="./", train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=8, drop_last=True
)

In [7]:
rnn = MyRNN(n_neurons=n_neurons, decay=decay)

In [8]:
#torch.autograd.set_detect_anomaly(True)

# Training parameters
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=1e-4)

# try:
#    params = torch.load("trained/srnn_mnist_19-09-14-22:09.pt")
#    rnn.load_state_dict(params['model_state_dict'])
# except FileNotFoundError as e:
#    pass
rnn.to(device)

# Log data of the experiment
writer = SummaryWriter()
save_path = writer.get_logdir()

pbar_epoch = tqdm(range(n_epochs))
for epoch in pbar_epoch:
    running_loss = 0
    running_accuracy = []

    # Training dataset
    for data in tqdm(trainloader):
        optimizer.zero_grad()
        rnn.init_states(randomize=randomize_vmem, batch_size=batch_size)
        imgs, labels = data
        imgs = imgs.to(device)
        labels = labels.to(device)

        out, _ = rnn(imgs)

        loss = criterion(out, labels)
        loss.backward()

        optimizer.step()

        running_accuracy.append(accuracy(out, labels))

        running_loss += loss.detach().item()

    # Test dataset
    with torch.no_grad():
        test_accuracy = []
        for data in tqdm(testloader):
            rnn.init_states(randomize=randomize_vmem, batch_size=batch_size)
            imgs, labels = data
            imgs = imgs.to(device)
            labels = labels.to(device)
            out, spikes_out = rnn(imgs)
            test_accuracy.append(accuracy(out, labels))

        pbar_epoch.set_postfix(
            loss=loss.item(),
            weights=[p.abs().mean().item() for p in rnn.parameters()],
            train_accuracy=np.mean(running_accuracy),
            test_accuracy=np.mean(test_accuracy),
        )

        params = list(rnn.parameters())
        writer.add_scalars(
            "Accuracy",
            {"train": np.mean(running_accuracy), "test": np.mean(test_accuracy)},
            epoch,
        )
        writer.add_scalar("Weight/Input", params[0].abs().mean().item(), epoch)
        writer.add_scalar("Weight/Recurrent", params[1].abs().mean().item(), epoch)
        writer.add_scalar("Weight/Output", params[2].abs().mean().item(), epoch)
        writer.flush()
writer.close()

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=468.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=78.0), HTML(value='')))





In [9]:
%debug

ERROR:root:No traceback has been produced, nothing to debug.
