### Init Parameters

In [4]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

root_folder = Path(os.getcwd())
dataset_dir = root_folder / "data/har-up-spiking-dataset-240"

multiclass = True
batch_size = 4
hidden_layers = [64,32]
nb_steps = 3000
time_duration = 60
tau_mem = 100
tau_syn = 50

last_layer_size = 12 if multiclass else 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Init Dataset

In [5]:
from utils.SpikingDataset import SpikingDataset

# Dataset options shared with the rest of the notebook's configuration.
dataset_options = {
    "time_duration": time_duration,
    "camera1_only": False,
    "multiclass": multiclass,
}
dataset = SpikingDataset(root_dir=dataset_dir, **dataset_options)

# Trial-level split into train / dev / test subsets.
train_dataset, dev_dataset, test_dataset = dataset.split_by_trials()

# Quick sanity check on the first test sample.
events, target = test_dataset[0]
print(events.shape)

(169234,)


### Init Dataloaders

In [6]:
from utils.SpikingDataLoader import SpikingDataLoader

# Identical loader settings for every split (no shuffling anywhere).
# NOTE(review): the training loader is not shuffled either -- confirm this is intended.
loader_kwargs = dict(nb_steps=nb_steps, batch_size=batch_size, shuffle=False)

train_loader = SpikingDataLoader(dataset=train_dataset, **loader_kwargs)
dev_loader = SpikingDataLoader(dataset=dev_dataset, **loader_kwargs)
test_loader = SpikingDataLoader(dataset=test_dataset, **loader_kwargs)

Initializing DataLoader of size 679
Initializing DataLoader of size 66
Initializing DataLoader of size 372


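### Init SpikingNN Model

Each of the following model cells (SpikingNN, Leaky, Synaptic, Convolutional) rebinds `model`; run only the one you intend to use.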

In [7]:
from models.SpikingNN import SpikingNN

# Fully connected SNN: one input neuron per pixel of the 240 x 180 frame,
# then the configured hidden layers, then one readout neuron per class.
nb_inputs = 240 * 180
time_step = time_duration / nb_steps

model = SpikingNN(
    layer_sizes=[nb_inputs, *hidden_layers, last_layer_size],
    nb_steps=nb_steps,
    time_step=time_step,
    tau_mem=tau_mem * 1e-3,
    tau_syn=tau_syn * 1e-3,
)

### Init Leaky Model

In [None]:
from models.SNNTorchLeaky import SNNTorchLeaky

# Alternative model built from snnTorch-style leaky neurons.
# The readout size must follow the label space selected by `multiclass`
# (12 classes or 2); a hard-coded 2 would not fit multiclass targets.
model = SNNTorchLeaky(
    num_inputs=dataset.nb_pixels,
    num_hidden=250,
    num_outputs=last_layer_size,
    nb_steps=nb_steps,
    time_step=time_duration / nb_steps,
    tau_mem=tau_mem * 1e-3,
)

### Init Synaptic Model

In [None]:
from models.SNNTorchSyn import SNNTorchSyn

# Alternative model built from snnTorch-style synaptic (current-based) neurons.
# The readout size must follow the label space selected by `multiclass`
# (12 classes or 2); a hard-coded 2 would not fit multiclass targets.
model = SNNTorchSyn(
    num_inputs=dataset.nb_pixels,
    num_hidden=250,
    num_outputs=last_layer_size,
    nb_steps=nb_steps,
    time_step=time_duration / nb_steps,
    tau_mem=tau_mem * 1e-3,
    tau_syn=tau_syn * 1e-3,
)

### Init Convolutional Model

In [None]:
from models.SNNTorchConv import SNNTorchConv

# Convolutional alternative. No input size is passed here
# (presumably it is fixed inside SNNTorchConv -- confirm).
time_step = time_duration / nb_steps

model = SNNTorchConv(
    nb_steps=nb_steps,
    time_step=time_step,
    tau_mem=tau_mem * 1e-3,
    num_outputs=last_layer_size,
)


### Show model summary

In [None]:
from torchinfo import summary

# Summarise the model for one batch of (batch, time steps, 240, 180) frames.
sensor_dims = (240, 180)
summary(model, input_size=(batch_size, nb_steps, *sensor_dims))

### Train model

In [None]:
from utils.BinaryTrainer import BinaryTrainer
from utils.MultiTrainer import MultiTrainer

# The trainer class has to match the label space chosen by `multiclass`.
trainer_cls = MultiTrainer if multiclass else BinaryTrainer
trainer = trainer_cls(model=model)

# Fixed-length training run, evaluated on the dev split.
trainer.train(
    train_loader,
    evaluate_dataloader=dev_loader,
    nb_epochs=5,
    stop_early=False,
    dataset_bias_ratio=5.0,
)

## Evaluation

### Save model weights

In [6]:
model_name = "B4_Conv_Leaky_local"
# Directory (despite the name) that holds the saved checkpoints.
model_save_file = root_folder / "models/saved"

# Make sure the target directory exists so the save does not fail on a fresh checkout.
model_save_file.mkdir(parents=True, exist_ok=True)
model.save(model_save_file / f"{model_name}.pth")

### Load model weights

In [8]:
# Checkpoint to restore. The prefix appears to encode batch size (B4),
# hidden layers (H64,32) and nb_steps (N3000); LR25 / W5 are not explained here.
model_name = "B4_H64,32_N3000_LR25_W5_Multi"
model_save_file = root_folder / "models/saved"
path = model_save_file.joinpath(f"{model_name}.pth")

model.load(path)

### Test model

In [None]:
from utils.BinaryTrainer import BinaryTrainer
from utils.MultiTrainer import MultiTrainer

# Choose the trainer that matches the label space selected by `multiclass`,
# and do not override it afterwards: a multiclass model must be tested
# with MultiTrainer, a binary one with BinaryTrainer.
if multiclass:
    trainer = MultiTrainer(model=model)
else:
    trainer = BinaryTrainer(model=model)

trainer.test(test_loader)

### Evaluate model

In [12]:
import numpy as np
import torch

last_layer_size = 12 if multiclass else 2

# Output spikes are pooled per second below, so the simulation steps must
# split evenly into seconds; otherwise the reshape fails with an obscure error.
if nb_steps % time_duration != 0:
    raise ValueError(
        f"nb_steps ({nb_steps}) must be divisible by time_duration ({time_duration})"
    )
timesteps_per_sec = nb_steps // time_duration

# Range of test samples to evaluate, converted to batch indices with integer
# division (batches [first_batch, stop_batch) are processed).
start_index = 0
end_index = 400
first_batch = start_index // batch_size
stop_batch = end_index // batch_size

# Only predictions and labels are collected; dense inputs and full neuron
# recordings would be very large for hundreds of samples.
y_preds_list = []
y_locals_list = []

with torch.inference_mode():
    for i, (x_local, y_local) in enumerate(test_loader):
        if i >= stop_batch:
            break  # past the requested range: no need to load the remaining batches
        if i < first_batch:
            continue

        x_local = x_local.to(model.device, model.dtype)
        y_local = y_local.to(model.device, model.dtype)

        mem, spk = model.forward(x_local.to_dense())
        # assumes spk is (batch, nb_steps, last_layer_size) -- TODO confirm.
        # Sum spikes within each second -> (batch, time_duration, last_layer_size).
        spk_reshaped = spk.reshape(-1, time_duration, timesteps_per_sec, last_layer_size).sum(dim=2)

        # The class with the most output spikes in a given second is the prediction.
        y_pred = torch.argmax(spk_reshaped, dim=2)

        y_preds_list.append(y_pred.cpu().numpy())
        y_locals_list.append(y_local.cpu().numpy())

if not y_preds_list:
    raise ValueError(
        f"No test batches fall inside the sample range [{start_index}, {end_index})"
    )

y_preds = np.concatenate(y_preds_list, axis=0)
y_locals = np.concatenate(y_locals_list, axis=0)

### Visualize predictions for a set of samples

In [None]:
from utils.snn_visualizers import plot_correctness_matrix, plot_confusion_matrix

# Human-readable names, listed in class-index order (index 0 = "No label").
activity_names = [
    "No label",
    "Falling forward using hands",
    "Falling forward using knees",
    "Falling backwards",
    "Falling sideward",
    "Falling sitting in empty chair",
    "Walking",
    "Standing",
    "Sitting",
    "Picking up an object",
    "Jumping",
    "Laying",
]

if multiclass:
    class_labels = dict(enumerate(activity_names))
else:
    class_labels = {0: "Not Fall", 1: "Fall"}

plot_confusion_matrix(y_locals, y_preds, class_labels)

### Visualize a specific sample

In [None]:
import torch

from utils.snn_visualizers import plot_predictions_and_labels, plot_snn_activity_combined, visualize_events

# Position of the sample to inspect, in the (unshuffled) order of test_loader.
index = 10

# The evaluation cell keeps only predictions and labels, so per-sample inputs,
# membrane recordings and spikes are not available there. Re-run the model on
# the single batch that contains `index` instead.
batch_idx, offset = divmod(index, batch_size)

with torch.inference_mode():
    for i, (x_local, y_local) in enumerate(test_loader):
        if i == batch_idx:
            x_dense = x_local.to(model.device, model.dtype).to_dense()
            mem, spk = model.forward(x_dense)
            break
    else:
        raise IndexError(f"index {index} is beyond the end of the test set")

# The last batch may be smaller than batch_size.
if offset >= x_dense.shape[0]:
    raise IndexError(f"index {index} is beyond the end of the test set")

x_sample = x_dense[offset].cpu().numpy()
y_sample = y_local.to(model.dtype)[offset].cpu().numpy()
mem_sample = mem[offset].cpu().numpy()
spk_sample = spk[offset].cpu().numpy()

visualize_events(x_sample, y_sample, time_duration)
plot_snn_activity_combined(mem_sample, spk_sample)
plot_predictions_and_labels(spk_sample, y_sample)