In [None]:
from time import perf_counter

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
from tqdm.auto import tqdm

from data_modules.dvs_gesture import DVSGesture
from dvs_gesture_model import ExodusNetwork, SlayerNetwork
from train_dvs_gesture import compare_forward

In [None]:
# Dataset configuration. `batch_size` is the single source of truth for the
# batch size used by the DataModule.
batch_size = 32
dataset = DVSGesture(
    batch_size=batch_size,
    bin_dt=5000,
    spatial_factor=0.5,
    fraction=1,
    augmentation=False,
    num_time_bins=300,
)
# LightningDataModule convention is prepare_data() first, then setup();
# presumably DVSGesture follows it (setup may need what prepare_data creates)
# — TODO confirm.
dataset.prepare_data()
dataset.setup(reset_cache=False)

In [None]:
# Training loader from the DataModule; it is iterated by the loading baseline
# and by the timing loop further down.
trainloader = dataset.train_dataloader()

In [None]:
# Peek at a single training batch and keep only its first element (the
# event tensor); the temporary batch reference is dropped afterwards.
first_batch = next(iter(trainloader))
events = first_batch[0]
del first_batch

In [None]:
# Shape of the event tensor taken from the first training batch.
events.size()

In [None]:
def cycle_through_trainloader(loader=None):
    """Iterate once over a dataloader, copying every batch to the GPU.

    Serves as a data-pipeline baseline: the rate shown by the tqdm bar
    covers loading plus host-to-device copies only, with no model involved.

    Args:
        loader: Iterable of ``(data, targets)`` batches. Defaults to the
            notebook-level ``trainloader`` so the zero-argument call still works.
    """
    if loader is None:
        loader = trainloader
    for data, targets in tqdm(loader, desc="loading only"):
        # The copies are discarded; only their cost is of interest here.
        data.cuda()
        targets.cuda()


cycle_through_trainloader()

In [None]:
# Grab one validation batch for inspecting input shapes; the loader object
# itself is not kept around.
val_loader = dataset.val_dataloader()
data, label = next(iter(val_loader))
del val_loader

In [None]:
# Shape of the validation input batch.
data.size()

In [None]:
# Hyper-parameters shared by every network implementation compared below.
model_kwargs = dict(
    batch_size=batch_size,  # reuse the DataModule batch size instead of repeating 32
    tau_mem=10,
    spike_threshold=0.25,
    base_channels=2,
    kernel_size=3,
    num_conv_layers=4,
    width_grad=1.0,
    scale_grad=1.0,
    iaf=True,
    # NOTE(review): presumably must equal num_time_bins given to DVSGesture — confirm.
    num_timesteps=300,
    dropout=True,
    batchnorm=False,
    norm_weights=True,
)

In [None]:
# Build the three implementations from the same hyper-parameters and copy the
# EXODUS model's parameters into the other two so the forward passes can be
# compared directly.
sinabs_model = ExodusNetwork(backend="sinabs", **model_kwargs)
exodus_model = ExodusNetwork(**model_kwargs)
slayer_model = SlayerNetwork(**model_kwargs)

proto_params = exodus_model.parameter_copy
sinabs_model.import_parameters(proto_params)
slayer_model.import_parameters(proto_params)

# Compare each reference model against SLAYER.
# NOTE(review): the sinabs model is also passed under the "exodus" key;
# presumably compare_forward looks models up by these exact keys, so do not
# relabel without checking compare_forward — TODO confirm.
compare_forward({"exodus": exodus_model, "slayer": slayer_model}, data=dataset, no_lightning=True)
compare_forward({"exodus": sinabs_model, "slayer": slayer_model}, data=dataset, no_lightning=True)

In [None]:
# Wall-clock time of one full pass over the training set (forward + backward)
# per epoch, for each implementation. CUDA launches kernels asynchronously,
# so the device is synchronised before each clock read; otherwise work still
# queued from the previous pass would leak into or out of a measurement.
N_TIMING_EPOCHS = 10

models = {"BPTT": sinabs_model, "EXODUS": exodus_model, "SLAYER": slayer_model}
times = {k: [] for k in models}

for name, model in models.items():
    for _epoch in tqdm(range(N_TIMING_EPOCHS), desc=name):
        torch.cuda.synchronize()
        t0 = perf_counter()
        # Targets are not needed: the loss is just the sum of the outputs.
        for data, _target in tqdm(trainloader, desc="batches", leave=False):
            data = data.cuda()
            model.reset_states()
            y_hat = model(data)
            y_hat.sum().backward()
        torch.cuda.synchronize()
        times[name].append(perf_counter() - t0)
    

In [None]:
# Mean and standard deviation (numpy default, ddof=0) of the per-epoch
# wall-clock times for each implementation. The raw values are written to
# CSV further down. `model_name` is used so the `model` object is not clobbered.
for model_name, epoch_times in times.items():
    t = np.asarray(epoch_times)
    print(f"{model_name}: ({t.mean():.3f} +- {t.std():.3f}) s")

In [4]:
import pandas as pd

In [16]:
# One column per implementation, one row per timed epoch.
pd.DataFrame.from_dict(times).to_csv("times.csv")