In [1]:
from hsd import HSD
from hsd_exodus import ExodusNetwork
from hsd_slayer import SlayerNetwork

In [2]:
dataset = HSD(
    batch_size=128,
    encoding_dim=100,
    num_workers=4,
    download_dir="./data",
)
dataset.setup()
trainloader = dataset.train_dataloader()
testloader = dataset.val_dataloader()

Downloading https://zenkelab.org/datasets/shd_train.h5.zip to ./data/SHD/shd_train.h5.zip


  0%|          | 0/130863613 [00:00<?, ?it/s]

Extracting ./data/SHD/shd_train.h5.zip to ./data/SHD
Downloading https://zenkelab.org/datasets/shd_test.h5.zip to ./data/SHD/shd_test.h5.zip


  0%|          | 0/38141465 [00:00<?, ?it/s]

Extracting ./data/SHD/shd_test.h5.zip to ./data/SHD


In [4]:
next(iter(trainloader))[0].shape

torch.Size([128, 250, 100])

In [3]:
from tqdm.auto import tqdm

def cycle_through_trainloader():
    for data, targets in tqdm(trainloader):
        data = data.cuda()
        targets = targets.cuda()

In [4]:
cycle_through_trainloader()

  0%|          | 0/63 [00:00<?, ?it/s]

In [5]:
dict_args = dict(
    encoding_dim=100,
    n_hidden_layers=2,
    hidden_dim=128,
    tau_mem=100000.0,
    output_dim=20,
    spike_threshold=1.,
    learning_rate=1e-3,
    width_grad=1.,
    scale_grad=1.,
    decoding_func='max_over_time',
)
slayer_model = SlayerNetwork(**dict_args, n_time_bins=250).cuda()
init_weights = slayer_model.state_dict()

exodus_model = ExodusNetwork(**dict_args, init_weights=init_weights).cuda()

sinabs_model = ExodusNetwork(**dict_args, init_weights=init_weights, backend='sinabs').cuda()


In [6]:
from time import time

times = {"EXODUS": [], "SLAYER": [], "BPTT": []}

for name, model in zip(["BPTT", "EXODUS"], [sinabs_model, exodus_model]):
    for i in tqdm(range(3)):
        t0 = time()
        for data, target in tqdm(trainloader):        
            data = data.cuda()
            target = target.cuda()
            model.reset_states()
            y_hat = model(data)
            y_hat.sum().backward()
        times[name].append(time()-t0)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fd33af38940>
Traceback (most recent call last):
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
Exception ignored in:     self._shutdown_workers()
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fd33af38940>  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    
if w.is_alive():
Traceback (most recent call last):
  File "/home/felix/.pyenv/versions/3.8.10/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
AssertionError: can only test a child process
    self._shutdown_workers()
  File "/home/fel

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fd33af38940>
Traceback (most recent call last):
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/home/felix/.pyenv/versions/3.8.10/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fd33af38940>
Traceback (most recent call last):
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/fel

  0%|          | 0/63 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fd33af38940>
Traceback (most recent call last):
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/home/felix/.pyenv/versions/3.8.10/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fd33af38940>
Traceback (most recent call last):
  File "/home/felix/.pyenv/versions/3.8.10/envs/exodus/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/home/fel

In [7]:
for i in tqdm(range(3)):
    t0 = time()
    for data, target in tqdm(trainloader):
        data = data.cuda()
        target = target.cuda()
        y_hat = slayer_model(data)
        y_hat.sum().backward()
    times["SLAYER"].append(time()-t0)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

In [8]:
import numpy as np
for model, ts in times.items():
    t = np.array(ts)
    print(f"{model}: ({np.mean(t)} +- {np.std(t)}) s")
    # np.save(f"timings_{model}.npy", t)

EXODUS: (15.730019410451254 +- 0.06395552805859003) s
SLAYER: (16.082467158635456 +- 0.14825376463280038) s
BPTT: (40.355397860209145 +- 9.329345988191516) s


In [9]:
import pandas as pd

In [10]:
pd.DataFrame(times).to_csv("times.csv")