In [1]:
import torch, random
import torch.nn as nn
from tqdm.notebook import tqdm

from tonic.transforms import ToFrame
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

import matplotlib.pyplot as plt
import numpy as np

from seq_model import SNN

In [2]:
# Single source of truth for every random-number generator seeded in this notebook.
SEED = 1

# Seed Python, NumPy and PyTorch (CPU and every visible CUDA device).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable

# Ask cuDNN for deterministic kernels and keep its autotuner off, since the
# autotuner may pick different algorithms from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# --- Data loading ---
num_workers = 1        # worker processes handed to every DataLoader
batch_size_pre = 32    # mini-batch size for the NMNIST (pre-training) loaders
batch_size = 3         # mini-batch size for the DVSGesture (post-training) loaders
n_time_steps = 50      # number of time bins ToFrame produces for each sample

# --- Optimisation ---
lr = 0.001             # Adam learning rate, shared by both phases
epochs_pretrain = 1    # epochs of pre-training
epochs = 30            # epochs of post-training

In [4]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Only the GPU branch reports which device was picked.
    print('device: ', torch.cuda.get_device_name(0))

device:  NVIDIA GeForce RTX 3070 Ti


## Training/Testing helper functions

In [5]:
def train(batch_size, feature_map_size, dataloader_train, model, loss_fn, optimizer, epochs, test_func, dataloader_test, phase):
    """Train `model` for `epochs` epochs and evaluate it after each one.

    Every batch is flattened from [Batch, Time, Channel, Height, Width] to
    [Batch*Time, Channel, Height, Width], passed through the model, and the
    per-time-step outputs are summed before the loss is computed.

    Args:
        batch_size: Kept for backward compatibility with existing callers. The
            actual batch size is read from each label tensor, so the reshape
            cannot disagree with what the dataloader yields.
        feature_map_size: Indexable sensor size; entry 2 is used as the channel
            count and entries 0 and 1 as the two spatial sizes.
        dataloader_train: Iterable yielding `(X, y)` training batches.
        model: Network to train; must expose `detach_neuron_states()`.
        loss_fn: Callable `(pred, y) -> loss`.
        optimizer: Optimizer updating `model`'s parameters.
        epochs: Number of passes over `dataloader_train`.
        test_func: Callable `(feature_map_size, dataloader, model) -> accuracy`.
        dataloader_test: Iterable passed to `test_func` after every epoch.
        phase: Label used in the progress bar and in the printed accuracy.

    Returns:
        Tuple `(epochs_x, epochs_y, epochs_acc)`: per epoch, the 1-based batch
        indices, the per-batch losses, and the value returned by `test_func`.
    """
    epochs_y = []
    epochs_x = []
    epochs_acc = []

    for e in range(epochs):
        # Re-assert training mode every epoch: the evaluation callback may
        # leave the model in eval mode.
        model.train()

        losses = []
        train_p_bar = tqdm(dataloader_train)

        for X, y in train_p_bar:
            # reshape the input from [Batch, Time, Channel, Height, Width] into [Batch*Time, Channel, Height, Width]
            X = X.reshape(-1, feature_map_size[2], feature_map_size[0], feature_map_size[1]).to(dtype=torch.float, device=device)
            y = y.to(dtype=torch.long, device=device)

            # forward
            pred = model(X)

            # reshape the output from [Batch*Time, num_classes] into [Batch, Time, num_classes];
            # the batch size comes from the labels and the time dimension is inferred,
            # so no notebook-level global is needed here
            pred = pred.reshape(y.shape[0], -1, pred.shape[-1])

            # accumulate all time-steps output for final prediction
            pred = pred.sum(dim=1)
            loss = loss_fn(pred, y)

            # gradient update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # detach the neuron states and activations from current computation graph(necessary)
            model.detach_neuron_states()

            loss_value = loss.item()
            train_p_bar.set_description(f"{phase} - Epoch {e} - BPTT Training Loss: {round(loss_value, 4)}")
            losses.append(loss_value)

        epochs_y.append(losses)
        # 1-based index of every batch processed in this epoch
        epochs_x.append(list(range(1, len(losses) + 1)))

        acc = test_func(feature_map_size, dataloader_test, model)
        print(f'{phase} - Epoch {e} accuracy: {acc}')
        epochs_acc.append(acc)

    return epochs_x, epochs_y, epochs_acc


In [6]:
def test(feature_map_size, dataloader, model):
    """Return the classification accuracy (in percent) of `model` on `dataloader`.

    The model is evaluated in eval mode with gradients disabled, and its
    previous train/eval mode is restored before returning.

    Args:
        feature_map_size: Indexable sensor size; entry 2 is used as the channel
            count and entries 0 and 1 as the two spatial sizes.
        dataloader: Iterable yielding `(X, y)` batches, with X laid out as
            [Batch, Time, Channel, Height, Width].
        model: Network mapping [Batch*Time, Channel, Height, Width] inputs to
            [Batch*Time, num_classes] outputs.

    Returns:
        Percentage (0-100) of samples whose time-summed output has its maximum
        at the labelled class.

    Raises:
        RuntimeError: If `dataloader` yields no batches.
    """
    n_correct = 0
    n_total = 0

    # Remember the current mode so a surrounding training loop is not left in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            test_p_bar = tqdm(dataloader)
            for X, y in test_p_bar:
                # reshape the input from [Batch, Time, Channel, Height, Width] into [Batch*Time, Channel, Height, Width]
                X = X.reshape(-1, feature_map_size[2], feature_map_size[0], feature_map_size[1]).to(dtype=torch.float, device=device)
                y = y.to(dtype=torch.long, device=device)

                # forward
                output = model(X)

                # reshape the output from [Batch*Time, num_classes] into [Batch, Time, num_classes];
                # the batch size is taken from the labels so loaders with any batch size work
                # (the notebook-level `batch_size` does not match the pre-training loader)
                output = output.reshape(y.shape[0], -1, output.shape[-1])

                # accumulate all time-steps output for final prediction
                output = output.sum(dim=1)

                # predicted class = index of the largest accumulated output
                pred = output.argmax(dim=1, keepdim=True)

                # running totals of correct and seen samples
                n_correct += pred.eq(y.view_as(pred)).sum().item()
                n_total += y.shape[0]

                test_p_bar.set_description("Testing Model...")
    finally:
        model.train(was_training)

    if n_total == 0:
        raise RuntimeError("test(): the dataloader produced no batches, accuracy is undefined")

    return n_correct / n_total * 100

## Pre-training loop

Loading the pre-training data. This dataset is used to pre-train the network so that its parameters start within a "good" region of the parameter space (i.e., training first on a "simpler" dataset hopefully sets the weights to values that improve the subsequent training on a harder dataset).

In [7]:
from tonic.datasets.nmnist import NMNIST

root_dir = "../NMNIST"

# Instantiate both splits once without a transform.
# NOTE(review): presumably this only makes sure the data is present under root_dir — confirm against tonic.
_ = NMNIST(train=True, save_to=root_dir)
_ = NMNIST(train=False, save_to=root_dir)

# Transform that bins every sample into `n_time_steps` frames.
to_raster = ToFrame(n_time_bins=n_time_steps, sensor_size=NMNIST.sensor_size)

# Train split first, then test split, both with the rasterising transform.
snn_train_dataset_pre, snn_test_dataset_pre = (
    NMNIST(save_to=root_dir, train=is_train, transform=to_raster)
    for is_train in (True, False)
)

# Sanity check: show the shape of one transformed sample.
sample_data, label = snn_train_dataset_pre[0]
print(f"The transformed array is in shape [Time-Step, Channel, Height, Width] --> {sample_data.shape}")

# drop_last=True keeps every batch at exactly `batch_size_pre` samples.
snn_train_dataloader_pre = DataLoader(
    snn_train_dataset_pre,
    batch_size=batch_size_pre,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)
snn_test_dataloader_pre = DataLoader(
    snn_test_dataset_pre,
    batch_size=batch_size_pre,
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
)

The transformed array is in shape [Time-Step, Channel, Height, Width] --> (50, 2, 34, 34)


instantiating model...

In [8]:
# Build the network used for pre-training and move it to the selected device.
# NOTE(review): the positional arguments are defined by seq_model.SNN, which is not
# visible here; the third one is the pre-training batch size — confirm the meaning
# of the first two.
snn = SNN(10, 10, batch_size_pre).to(device)
# NOTE(review): init_weights() is a project method; presumably it (re)initialises
# the network parameters — confirm in seq_model.SNN.
snn.init_weights()

loss and optimizer...

In [9]:
# Classification loss applied to the time-summed network output.
loss_fn = CrossEntropyLoss()

# Adam over every parameter of the current `snn`, with betas and eps spelled out.
optimizer = Adam(
    snn.parameters(),
    lr=lr,
    betas=(0.9, 0.999),
    eps=1e-8,
)

pre-training the model...

In [10]:
# Pre-training phase: train on NMNIST and evaluate on its test split after every epoch.
epochs_x_pre, epochs_y_pre, epochs_acc_pre = train(
    batch_size=batch_size_pre,
    feature_map_size=NMNIST.sensor_size,
    dataloader_train=snn_train_dataloader_pre,
    model=snn,
    loss_fn=loss_fn,
    optimizer=optimizer,
    epochs=epochs_pretrain,
    test_func=test,
    dataloader_test=snn_test_dataloader_pre,
    phase='pre-training',
)

# NOTE(review): presumably stores the convolutional parameters for later reuse — confirm in seq_model.SNN.
snn.export_conv_params()

  0%|          | 0/1875 [00:00<?, ?it/s]

plotting...

In [None]:
# Mean training loss of every pre-training epoch.
y_avg = [np.mean(epoch_losses) for epoch_losses in epochs_y_pre]
epoch_idx = np.arange(len(epochs_x_pre))

fig, ax = plt.subplots()
ax.plot(epoch_idx, y_avg, marker='.')
ax.set_xlabel('epoch')
ax.set_ylabel('average loss')
ax.set_ylim(0)
ax.set_xticks(epoch_idx)

# Annotate only the odd-numbered epochs.
for i in range(1, len(y_avg), 2):
    ax.text(i, y_avg[i], f'{y_avg[i]:.2f}', ha='center', va='bottom', color='k')

plt.show()

In [None]:
# Test accuracy measured after every pre-training epoch.
epoch_idx = np.arange(len(epochs_x_pre))

fig, ax = plt.subplots()
ax.plot(epoch_idx, epochs_acc_pre, marker='.')
ax.set_xlabel('epoch')
ax.set_ylabel('test accuracy')
ax.set_ylim(0, 100)
ax.set_xticks(epoch_idx)

# Annotate only the odd-numbered epochs.
for i in range(1, len(epochs_acc_pre), 2):
    ax.text(i, epochs_acc_pre[i], f'{epochs_acc_pre[i]:.2f}', ha='center', va='bottom', color='k')

plt.show()

## "Post-training" loop

loading the data...

In [None]:
from tonic.datasets.dvsgesture import DVSGesture

root_dir = "../DVSGESTURE"

# Instantiate both splits once without a transform.
# NOTE(review): presumably this only makes sure the data is present under root_dir — confirm against tonic.
_ = DVSGesture(train=True, save_to=root_dir)
_ = DVSGesture(train=False, save_to=root_dir)

# Transform that bins every sample into `n_time_steps` frames.
to_raster = ToFrame(n_time_bins=n_time_steps, sensor_size=DVSGesture.sensor_size)

# Train split first, then test split, both with the rasterising transform.
snn_train_dataset, snn_test_dataset = (
    DVSGesture(save_to=root_dir, train=is_train, transform=to_raster)
    for is_train in (True, False)
)

# Sanity check: show the shape of one transformed sample.
sample_data, label = snn_train_dataset[0]
print(f"The transformed array is in shape [Time-Step, Channel, Height, Width] --> {sample_data.shape}")

# drop_last=True keeps every batch at exactly `batch_size` samples.
snn_train_dataloader = DataLoader(
    snn_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
)
snn_test_dataloader = DataLoader(
    snn_test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=num_workers,
)

instantiating model...

In [None]:
# Rebind `snn` to a fresh network for the DVSGesture phase; the pre-trained
# instance is no longer referenced by this name after this line.
# NOTE(review): the positional arguments are defined by seq_model.SNN, which is not
# visible here; the third one is the post-training batch size — confirm the meaning
# of the first two.
snn = SNN(11, 810, batch_size).to(device)
# NOTE(review): init_weights() is a project method; presumably it (re)initialises
# the network parameters — confirm in seq_model.SNN.
snn.init_weights()

loading weights from pre-training...

In [None]:
# NOTE(review): presumably restores the convolutional parameters saved by the
# snn.export_conv_params() call made right after pre-training; where they are
# stored is decided inside seq_model.SNN — confirm.
snn.load_conv_params()

loss and optimizer...

In [None]:
# Classification loss applied to the time-summed network output.
loss_fn = CrossEntropyLoss()

# New Adam instance bound to the parameters of the current `snn`.
optimizer = Adam(
    snn.parameters(),
    lr=lr,
    betas=(0.9, 0.999),
    eps=1e-8,
)

training the model...

In [None]:
# Post-training phase: train on DVSGesture and evaluate on its test split after every epoch.
epochs_x_dvs128, epochs_y_dvs128, epochs_acc_dvs128 = train(
    batch_size=batch_size,
    feature_map_size=DVSGesture.sensor_size,
    dataloader_train=snn_train_dataloader,
    model=snn,
    loss_fn=loss_fn,
    optimizer=optimizer,
    epochs=epochs,
    test_func=test,
    dataloader_test=snn_test_dataloader,
    phase='post-training',
)

In [None]:
# Mean training loss of every post-training epoch.
y_avg = [np.mean(epoch_losses) for epoch_losses in epochs_y_dvs128]
epoch_idx = np.arange(len(epochs_x_dvs128))

fig, ax = plt.subplots()
ax.plot(epoch_idx, y_avg, marker='.')
ax.set_xlabel('epoch')
ax.set_ylabel('average loss')
ax.set_ylim(0)
ax.set_xticks(epoch_idx)

# Annotate only the odd-numbered epochs.
for i in range(1, len(y_avg), 2):
    ax.text(i, y_avg[i], f'{y_avg[i]:.2f}', ha='center', va='bottom', color='k')

plt.show()

In [None]:
# Test accuracy measured after every post-training epoch.
epoch_idx = np.arange(len(epochs_x_dvs128))

fig, ax = plt.subplots()
ax.plot(epoch_idx, epochs_acc_dvs128, marker='.')
ax.set_xlabel('epoch')
ax.set_ylabel('test accuracy')
ax.set_ylim(0, 100)
ax.set_xticks(epoch_idx)

# Annotate only the odd-numbered epochs.
for i in range(1, len(epochs_acc_dvs128), 2):
    ax.text(i, epochs_acc_dvs128[i], f'{epochs_acc_dvs128[i]:.2f}', ha='center', va='bottom', color='k')

plt.show()