In [1]:
from hsd import HSD
from hsd_exodus import ExodusNetwork
from hsd_slayer import SlayerNetwork

In [2]:
dataset = HSD(
    batch_size=128,
    encoding_dim=100,
    num_workers=4,
    download_dir="./data",
)
dataset.setup()
trainloader = dataset.train_dataloader()
testloader = dataset.val_dataloader()

In [None]:
next(iter(trainloader))[0].shape

In [3]:
from tqdm.auto import tqdm

def cycle_through_trainloader():
    for data, targets in tqdm(trainloader):
        data = data.cuda()
        targets = targets.cuda()

In [None]:
cycle_through_trainloader()

In [4]:
dict_args = dict(
    encoding_dim=100,
    n_hidden_layers=2,
    hidden_dim=128,
    tau_mem=100000.0,
    output_dim=20,
    spike_threshold=1.,
    learning_rate=1e-3,
    width_grad=1.,
    scale_grad=1.,
    decoding_func='max_over_time',
    optimizer="sgd",
)
slayer_model = SlayerNetwork(**dict_args, n_time_bins=250).cuda()
init_weights = slayer_model.state_dict()

exodus_model = ExodusNetwork(**dict_args, slayer_mode=True, init_weights=init_weights).cuda()

# sinabs_model = ExodusNetwork(**dict_args, init_weights=init_weights, backend='sinabs').cuda()


In [5]:
for m in exodus_model.network.modules():
    if hasattr(m, "does_spike"):
        m.tau_mem.requires_grad_(False)

In [6]:
from time import time
import numpy as np

algorithms = ["EXODUS", "SLAYER", "BPTT"]
times = {algo: {"forward": [], "backward": [], "reset": []} for algo in algorithms}

for algo, model in zip(["EXODUS"], [exodus_model]):
    for i in tqdm(range(3)):
        times_epoch = {"forward": [], "backward": [], "reset": []}
        for data, target in tqdm(trainloader):        
            data = data.cuda()
            target = target.cuda()
            t0 = time()
            model.reset_states()
            t1 = time()
            y_hat = model(data)
            t2 = time()
            y_hat.sum().backward()
            t3 = time()
            times_epoch["reset"].append(t1-t0)
            times_epoch["forward"].append(t2-t1)
            times_epoch["backward"].append(t3-t2)
        for step, t in times_epoch.items():
            times[algo][step].append(np.mean(t))

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f71a871b940>
Traceback (most recent call last):
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/home/felix/.pyenv/versions/3.8.10/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f71a871b940>
Traceback (most recent call last):
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/fel

  0%|          | 0/63 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f71a871b940>
Traceback (most recent call last):
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/home/felix/.pyenv/versions/3.8.10/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f71a871b940>
Traceback (most recent call last):
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/fel

In [7]:
times # In slayer mode

{'EXODUS': {'forward': [0.02937603375268361,
   0.029612980191669767,
   0.030455755808996777],
  'backward': [0.20491660208929152, 0.2096554892403739, 0.20776023183550155],
  'reset': [0.00027866590590704056,
   0.0002933343251546224,
   0.00028858487568204363]},
 'SLAYER': {'forward': [], 'backward': [], 'reset': []},
 'BPTT': {'forward': [], 'backward': [], 'reset': []}}

In [11]:
times # In normal mode

{'EXODUS': {'forward': [0.028515297269064283,
   0.029192508213103763,
   0.029273180734543575],
  'backward': [0.03350253332228888, 0.034576832302032956, 0.03430612503536164],
  'reset': [0.00034044659326946925,
   0.00036979100060841394,
   0.0003590016137985956]},
 'SLAYER': {'forward': [], 'backward': [], 'reset': []},
 'BPTT': {'forward': [], 'backward': [], 'reset': []}}

In [7]:
times # In previous (torch-based) slayer mode

{'EXODUS': {'forward': [0.029698534617348324,
   0.02905493312411838,
   0.02906798937964061],
  'backward': [0.047602104762243844,
   0.04563159791250077,
   0.045841391124422586],
  'reset': [0.0003617226131378658,
   0.00035459654671805244,
   0.00032133147830054875]},
 'SLAYER': {'forward': [], 'backward': [], 'reset': []},
 'BPTT': {'forward': [], 'backward': [], 'reset': []}}

In [None]:
for i in tqdm(range(3)):
    t0 = time()
    times_epoch = {"forward": [], "backward": [], "reset": []}
    for data, target in tqdm(trainloader):
        data = data.cuda()
        target = target.cuda()
        t0 = time()
        y_hat = slayer_model(data)
        t1 = time()
        y_hat.sum().backward()
        t2 = time()
        times_epoch["forward"].append(t1-t0)
        times_epoch["backward"].append(t2-t1)
    for step, t in times_epoch.items():
        times["SLAYER"][step].append(np.mean(t))

In [None]:
import numpy as np
for action in times["EXODUS"].keys():
    print(action)
    for model, ts in times.items():
        t = np.array(ts[action])
        print(f"\t{model}: ({np.mean(t):.2e} +- {np.std(t):.2e}) s")
        # np.save(f"timings_{model}.npy", t)

In [None]:
import pandas as pd

In [None]:
pd.DataFrame(times).to_csv("times_new.csv")

In [None]:
# Convert previously saved csv to one line per measurement

times = pd.read_csv("times_new.csv", index_col=0)

table = [
    pd.DataFrame(
        {"algorithm": algo, "time": t, "step": step}
        for step in times.index
        for algo in times.loc[step].index
        for t in eval((times.loc[step].loc[algo]).replace("nan, ", "").replace("nan", ""))
    )
]
table = pd.concat(table, ignore_index=True)
table.to_csv("times_new_table.csv")