### Init Parameters

In [1]:
# Reload edited project modules automatically before each cell runs, so changes
# under utils/ and models/ are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

import os
from pathlib import Path
import matplotlib.pyplot as plt
# Default size (width, height in inches) for every figure drawn in this notebook.
plt.rcParams['figure.figsize'] = [10, 5]

# Paths are resolved against the directory the kernel was started in.
root_folder = Path.cwd()
dataset_dir = root_folder.joinpath("data/har-up-spiking-dataset-240")

# Loader batch size and number of simulation steps per sample.
batch_size, nb_steps = 1, 1500

# Length of each sample's time window. Passed to the dataset as `time_duration`
# and divided by `nb_steps` to obtain the model time step.
max_length = 60.0

# Membrane / synaptic time constants; the model cells scale them by 1e-3
# before handing them to the constructors.
tau_mem, tau_syn = 100, 50

### Init Dataset

In [None]:
from utils.SpikingDataset import SpikingDataset

# Options for the full dataset: binary targets (multiclass=False) and not
# restricted to camera 1.
dataset_options = dict(
    root_dir=dataset_dir,
    time_duration=max_length,
    camera1_only=False,
    multiclass=False,
)
dataset = SpikingDataset(**dataset_options)

# split_by_trials() yields the train / dev / test subsets, in that order.
train_dataset, dev_dataset, test_dataset = dataset.split_by_trials()

# Sanity check: print the shape of the first test sample's event data.
first_test_sample = test_dataset[0]
events, target = first_test_sample
print(events.shape)

### Init Dataloaders

In [None]:
from utils.SpikingDataLoader import SpikingDataLoader

# Settings shared by all three loaders; none of them shuffles its dataset.
loader_options = dict(nb_steps=nb_steps, batch_size=batch_size, shuffle=False)

train_loader = SpikingDataLoader(dataset=train_dataset, **loader_options)
dev_loader = SpikingDataLoader(dataset=dev_dataset, **loader_options)
test_loader = SpikingDataLoader(dataset=test_dataset, **loader_options)

## Binary Classification Models

### Init SpikingNN Model

In [4]:
from models.SpikingNN import SpikingNN

# Layer sizes for the 2-output (binary) network.
# NOTE(review): the first entry is hard-coded as 240 * 180, whereas the other
# model cells use dataset.nb_pixels — presumably the same value; confirm.
binary_layer_sizes = [240 * 180, 5, 2]

# Duration of one simulation step: the sample window divided by the step count.
simulation_dt = max_length / nb_steps

# Re-binds `model`; running this cell replaces any previously created model.
model = SpikingNN(
    layer_sizes=binary_layer_sizes,
    nb_steps=nb_steps,
    time_step=simulation_dt,
    tau_mem=tau_mem * 1e-3,
    tau_syn=tau_syn * 1e-3,
)

### Init Leaky Model

In [None]:
from models.SNNTorchLeaky import SNNTorchLeaky

# Constructor arguments for the leaky model: input width taken from the
# dataset, 250 hidden units, 2 outputs.
leaky_config = dict(
    num_inputs=dataset.nb_pixels,
    num_hidden=250,
    num_outputs=2,
    nb_steps=nb_steps,
    time_step=max_length / nb_steps,
    tau_mem=tau_mem * 1e-3,
)

# Re-binds `model`; running this cell replaces any previously created model.
model = SNNTorchLeaky(**leaky_config)

### Init Synaptic Model

In [None]:
from models.SNNTorchSyn import SNNTorchSyn

# Values derived from the notebook-level configuration.
synaptic_dt = max_length / nb_steps
synaptic_tau_mem = tau_mem * 1e-3
synaptic_tau_syn = tau_syn * 1e-3

# Re-binds `model`; running this cell replaces any previously created model.
model = SNNTorchSyn(
    num_inputs=dataset.nb_pixels,
    num_hidden=250,
    num_outputs=2,
    nb_steps=nb_steps,
    time_step=synaptic_dt,
    tau_mem=synaptic_tau_mem,
    tau_syn=synaptic_tau_syn,
)

### Init Convolutional Model

In [None]:
from models.SNNTorchConv import SNNTorchConv

# Constructor arguments for the convolutional model; it receives the same
# arguments as the leaky model cell (no synaptic time constant).
conv_config = {
    "num_inputs": dataset.nb_pixels,
    "num_hidden": 250,
    "num_outputs": 2,
    "nb_steps": nb_steps,
    "time_step": max_length / nb_steps,
    "tau_mem": tau_mem * 1e-3,
}

# Re-binds `model`; running this cell replaces any previously created model.
model = SNNTorchConv(**conv_config)


### Train model

In [None]:
from utils.BinaryTrainer import BinaryTrainer

# Training settings: 5 epochs, no early stopping, evaluated against the dev loader.
binary_training_options = dict(
    evaluate_dataloader=dev_loader,
    nb_epochs=5,
    stop_early=False,
    dataset_bias_ratio=5.0,
)

# Trains whichever `model` was bound most recently.
trainer = BinaryTrainer(model=model)
trainer.train(train_loader, **binary_training_options)

## Multiclass Classification Models

### Init SpikingNN Model

In [7]:
from models.SpikingNN import SpikingNN

# Layer sizes for the 12-output (multiclass) network.
# NOTE(review): the dataset in the "Init Dataset" cell is built with
# multiclass=False — confirm the loaders provide 12-class targets before
# training this model.
multiclass_layer_sizes = [240 * 180, 5, 12]

# Duration of one simulation step: the sample window divided by the step count.
multiclass_dt = max_length / nb_steps

# Re-binds `model`; running this cell replaces any previously created model.
model = SpikingNN(
    layer_sizes=multiclass_layer_sizes,
    nb_steps=nb_steps,
    time_step=multiclass_dt,
    tau_mem=tau_mem * 1e-3,
    tau_syn=tau_syn * 1e-3,
)

### Train model

In [None]:
from utils.MultiTrainer import MultiTrainer

# Same training settings as the binary training cell.
multiclass_training_options = {
    "evaluate_dataloader": dev_loader,
    "nb_epochs": 5,
    "stop_early": False,
    "dataset_bias_ratio": 5.0,
}

# Trains whichever `model` was bound most recently.
trainer = MultiTrainer(model=model)
trainer.train(train_loader, **multiclass_training_options)

## Evaluation

### Save model weights

In [7]:
# File stem of the checkpoint to write.
model_name = "B1_H5_LR25_W5_Atan_spk_local"
# Despite its name, this is the DIRECTORY that holds saved checkpoints.
model_save_file = root_folder / "models/saved"

# Make sure the checkpoint directory exists so saving works on a fresh
# checkout; a no-op when the directory is already there.
model_save_file.mkdir(parents=True, exist_ok=True)

model.save(model_save_file / f"{model_name}.pth")

### Load model weights

In [5]:
import torch

# File stem of the checkpoint to load.
# NOTE(review): this differs from the name written by the "Save model weights"
# cell ("B1_H5_LR25_W5_Atan_spk_local"), so this cell does not reload the model
# saved there — confirm that is intended.
model_name = "B4_H5_LR25_W5_Atan"
# Despite its name, this is the directory that holds saved checkpoints.
model_save_file = root_folder / "models/saved"
path = model_save_file / f"{model_name}.pth"

# SECURITY: weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the file. Only load checkpoints
# that come from a trusted source.
# Re-binds `model` to whatever object the checkpoint contains.
model = torch.load(path, weights_only=False)

### Test model

In [None]:
from utils.BinaryTrainer import BinaryTrainer

# Run the test routine on whichever `model` is currently bound (trained in this
# session or loaded from a checkpoint).
# NOTE(review): BinaryTrainer is used unconditionally; if `model` is the
# 12-output multiclass network, confirm this is the right trainer for testing.
trainer = BinaryTrainer(model=model)
trainer.test(test_loader)

### Evaluate model

In [21]:
import torch

# Index of the test batch whose inputs and outputs are kept for the
# visualisation cells below. If the loader has fewer batches, the last one is used.
sample_batch_index = 2

with torch.inference_mode():
    # Advance to the wanted batch without running the network on the batches
    # that are skipped; only the selected batch's outputs are ever used.
    # NOTE(review): assumes model.forward keeps no state between calls, so
    # skipping the earlier forward passes does not change the result — confirm.
    x_local = y_local = None
    for i, (x_batch, y_batch) in enumerate(test_loader):
        x_local, y_local = x_batch, y_batch
        if i == sample_batch_index:
            break

    # Fail with a clear message instead of a NameError further down.
    if x_local is None:
        raise RuntimeError("test_loader yielded no batches; nothing to evaluate")

    x_local = x_local.to(model.device, model.dtype)
    y_local = y_local.to(model.device, model.dtype)

    # Densify once and reuse the result for the forward pass and the NumPy export.
    x_dense = x_local.to_dense()
    mem_rec, spk_rec = model.forward(x_dense)

# Move everything to host memory as NumPy arrays for plotting.
x_local = x_dense.cpu().detach().numpy()
y_local = y_local.cpu().detach().numpy()

mem_rec = mem_rec.cpu().detach().numpy()
spk_rec = spk_rec.cpu().detach().numpy()

### Visualize results

In [None]:
from utils.snn_visualizers import visualize_events, visualize_input_grid

# Uncomment to inspect the array shapes produced by the evaluation cell:
# print(x_local.shape)
# print(y_local.shape)

# First sample of the evaluated batch and its target.
sample_events = x_local[0]
sample_target = y_local[0]

# NOTE(review): the literal 60 presumably mirrors max_length — confirm.
visualize_events(sample_events, sample_target, 60)

In [None]:
from utils.snn_visualizers import visualize_snn_output

# Uncomment to inspect the recorded output shapes:
# print(mem_rec.shape)
# print(spk_rec.shape)

# Recordings for the first sample of the evaluated batch.
sample_mem = mem_rec[0]
sample_spk = spk_rec[0]

# NOTE(review): time_range presumably selects simulation steps 0-300 out of
# nb_steps — confirm against visualize_snn_output.
visualize_snn_output(sample_mem, sample_spk, time_range=[0, 300])