In [None]:
from dvs_gesture import DVSGesture
import torch.nn as nn
import torch
import pytorch_lightning as pl
from tqdm.auto import tqdm
from dvs_gesture_model import ExodusNetwork, SlayerNetwork

In [None]:
# Build the DVS Gesture data object that the cells below draw their loaders from.
# NOTE(review): `model_kwargs` further down hard-codes the same batch size (16);
# keep the two values in sync when changing either.
batch_size = 16
dataset = DVSGesture(
    batch_size=batch_size,
    bin_dt=5000,  # presumably the event-binning time step; unit not visible here -- TODO confirm
    spatial_factor=0.5,  # presumably a spatial rescaling factor -- TODO confirm
    augmentation=True,
)
# NOTE(review): setup() runs before prepare_data() here. If DVSGesture follows the
# LightningDataModule convention, the usual order is the reverse -- confirm this is intended.
dataset.setup(reset_cache=False)
dataset.prepare_data()

In [None]:
# dataloader = dataset.val_dataloader()
# Training loader; reused below for the batch inspection, the data-loading dry run
# and the timing benchmark.
trainloader = dataset.train_dataloader()

In [None]:
# Take element 0 of the first training batch (the `data` part of a
# `(data, targets)` pair) for manual inspection.
events = next(iter(trainloader))[0]

In [None]:
# Display the shape of the sampled training batch.
events.shape

In [None]:
import matplotlib.pyplot as plt

# Quick visual sanity check: take entry 7 along the first axis, sum entries 100-149
# of the second axis, and show slice 0 of the result as an image.
# Presumably the axes are (batch, time, channel, height, width) -- TODO confirm.
plt.imshow(events[7, 100:150].sum(0)[0])

In [None]:
from tqdm.auto import tqdm

def cycle_through_trainloader():
    """Run once through the training loader, copying every batch to the GPU.

    Nothing is computed or returned; the batches are discarded right after the
    transfer. Reads the notebook-level ``trainloader`` and shows a tqdm
    progress bar while iterating.
    """
    for batch in tqdm(trainloader):
        inputs, labels = batch
        inputs = inputs.cuda()
        labels = labels.cuda()


cycle_through_trainloader()

In [None]:
# Fetch one validation batch. `data` is only inspected in the next cell; the
# benchmark loop further down rebinds the name to training batches.
data, label = next(iter(dataset.val_dataloader()))

In [None]:
# Display the shape of the validation batch.
data.shape

In [None]:
# Keyword arguments passed unchanged to every network constructed below.
# NOTE(review): "batch_size" repeats the notebook-level `batch_size` used for the
# dataset; keep both values in sync.
model_kwargs = {
    "batch_size": 16,
    "tau_mem": 10,
    "spike_threshold": 0.25,
    "base_channels": 2,
    "kernel_size": 3,
    "num_conv_layers": 4,
    "width_grad": 1.0,
    "scale_grad": 1.0,
    "iaf": True,
    "num_timesteps": 300,
    "dropout": True,
    "batchnorm": False,
    "norm_weights": True,
}

In [None]:
from train_dvs_gesture import compare_forward

# Build three network variants from the same keyword arguments: ExodusNetwork with
# the "sinabs" backend, ExodusNetwork with its default backend, and SlayerNetwork.
sinabs_model = ExodusNetwork(backend="sinabs", **model_kwargs)
exodus_model = ExodusNetwork(**model_kwargs)
slayer_model = SlayerNetwork(**model_kwargs)

# Copy the default-backend model's parameters into the other two models,
# presumably so that all three start from identical weights.
proto_params = exodus_model.parameter_copy
sinabs_model.import_parameters(proto_params)
slayer_model.import_parameters(proto_params)

# Compare each ExodusNetwork variant against the SLAYER model.
# NOTE(review): the second call passes the sinabs-backend model under the "exodus"
# key; presumably compare_forward expects exactly these two keys -- confirm intended.
compare_forward({"exodus": exodus_model, "slayer": slayer_model}, data=dataset, no_lightning=True)
compare_forward({"exodus": sinabs_model, "slayer": slayer_model}, data=dataset, no_lightning=True)

In [None]:
from time import time
import numpy as np

# Benchmark the state reset, forward pass and backward pass of each model on the
# training set. For every algorithm, `times[algo][step]` collects one mean
# duration (in seconds) per repetition over the whole loader.
algorithms = ["EXODUS", "SLAYER", "BPTT"]
times = {algo: {"forward": [], "backward": [], "reset": []} for algo in algorithms}

models = {"EXODUS": exodus_model, "SLAYER": slayer_model, "BPTT": sinabs_model}

# Number of full passes over the training loader per algorithm.
num_repetitions = 3


def _synced_time():
    """Return the current wall-clock time once all queued CUDA work has finished.

    CUDA kernels are launched asynchronously, so reading the clock without
    synchronizing would mostly measure launch overhead and attribute the real
    GPU work to whichever later phase happens to block first.
    """
    torch.cuda.synchronize()
    return time()


for algo, model in models.items():
    for repetition in tqdm(range(num_repetitions)):
        times_epoch = {"forward": [], "backward": [], "reset": []}
        for data, target in tqdm(trainloader):
            # Host-to-device copies happen before the first timestamp; the
            # synchronization inside _synced_time() keeps them out of "reset".
            data = data.cuda()
            target = target.cuda()
            t0 = _synced_time()
            model.reset_states()
            t1 = _synced_time()
            y_hat = model(data)
            t2 = _synced_time()
            # Summing the output gives a scalar to back-propagate from; no real
            # loss is needed for a timing measurement.
            y_hat.sum().backward()
            t3 = _synced_time()
            times_epoch["reset"].append(t1 - t0)
            times_epoch["forward"].append(t2 - t1)
            times_epoch["backward"].append(t3 - t2)
        # Store one mean per step for this repetition.
        for step, t in times_epoch.items():
            times[algo][step].append(np.mean(t))

In [None]:
import numpy as np
# Report mean +- standard deviation of the per-repetition timings, grouped by step.
for step_name in times["EXODUS"]:
    print(step_name)
    for algo_name, per_step in times.items():
        samples = np.asarray(per_step[step_name])
        mean, spread = samples.mean(), samples.std()
        print(f"\t{algo_name}: ({mean:.2e} +- {spread:.2e}) s")
        # np.save(f"timings_{algo_name}.npy", samples)

In [None]:
import pandas as pd

In [None]:
# Persist the raw timings: one column per algorithm, one row per step. Each cell
# holds a Python list (one mean per repetition), which to_csv writes out as its
# string representation.
pd.DataFrame(times).to_csv("times_new.csv")

In [None]:
# Convert the previously saved CSV (one row per step, one column per algorithm,
# each cell a stringified list of timings) to long format: one line per measurement.


def _parse_timings(cell):
    """Parse one stringified list of timings from the CSV into a list of floats.

    Accepts plain list reprs such as ``"[0.1, 0.2]"`` as well as NumPy >= 2
    scalar reprs such as ``"[np.float64(0.1), np.float64(0.2)]"``. NaN entries
    are dropped. The text is parsed with ``float`` rather than ``eval`` so that
    the contents of the file are never executed as code.

    Args:
        cell: A single cell value read from the CSV.

    Returns:
        The finite-or-infinite (non-NaN) timings contained in the cell, in order.

    Raises:
        ValueError: If an entry cannot be interpreted as a float.
    """
    import math
    import re

    # Unwrap "np.float64(x)" / "numpy.float32(x)" down to "x".
    text = re.sub(r"(?:np|numpy)\.float\d*\(([^()]*)\)", r"\1", str(cell))
    tokens = (tok.strip() for tok in text.strip().strip("[]").split(","))
    values = (float(tok) for tok in tokens if tok)
    return [val for val in values if not math.isnan(val)]


# NOTE: this rebinds `times` from the in-memory dict of timings to the DataFrame
# read back from disk.
times = pd.read_csv("times_new.csv", index_col=0)

records = [
    {"algorithm": algo, "time": t, "step": step}
    for step in times.index
    for algo in times.columns
    for t in _parse_timings(times.loc[step, algo])
]
# Passing `columns` keeps the schema stable even when there are no measurements.
table = pd.DataFrame(records, columns=["algorithm", "time", "step"])
table.to_csv("times_new_table.csv")