In [2]:
import copy

from hsd import HSD
from hsd_exodus import ExodusNetwork
from hsd_slayer import SlayerNetwork

In [3]:
# Build the HSD data module and fetch its loaders. Presumably `setup()` has to
# run before the loaders are requested — TODO confirm against hsd.HSD.
dataset = HSD(batch_size=128, encoding_dim=100, num_workers=4, download_dir="./data")
dataset.setup()
# NOTE(review): despite its name, `testloader` is backed by the validation split
# (`val_dataloader()`).
trainloader, testloader = dataset.train_dataloader(), dataset.val_dataloader()

In [4]:
next(iter(trainloader))[0].shape

torch.Size([128, 250, 100])

In [5]:
from tqdm.auto import tqdm

def cycle_through_trainloader(loader=None):
    """Iterate over one full pass of a dataloader, copying every batch to the GPU.

    No model work is done, so the tqdm rate reported here is a baseline for
    data loading plus host-to-device transfer (presumably to compare against
    the model forward/backward loops later in the notebook).

    Parameters
    ----------
    loader : iterable of (data, targets) pairs, optional
        Batches to cycle through. When omitted, the notebook-level
        ``trainloader`` is used (looked up at call time), so existing
        ``cycle_through_trainloader()`` calls behave as before. Pass e.g.
        ``testloader`` to measure another loader.

    Returns
    -------
    None
    """
    if loader is None:
        loader = trainloader
    for data, targets in tqdm(loader):
        # The GPU copies are discarded on purpose: only the cost of loading
        # and transferring each batch is being exercised.
        data.cuda()
        targets.cuda()

In [7]:
# Baseline pass: load every training batch and copy it to the GPU, no model work.
cycle_through_trainloader();

  0%|          | 0/63 [00:00<?, ?it/s]

In [8]:
# Hyper-parameters shared by all three networks, so they differ only in backend.
dict_args = dict(
    encoding_dim=100,  # matches the HSD encoding_dim / last axis of each batch
    n_hidden_layers=2,
    hidden_dim=128,
    # NOTE(review): very large membrane time constant, presumably to make the
    # neurons (almost) non-leaky — confirm units and intent.
    tau_mem=100000.0,
    output_dim=20,
    spike_threshold=1.,
    learning_rate=1e-3,
    width_grad=1.,
    scale_grad=1.,
    decoding_func='max_over_time',
)
# 250 matches the middle axis of each input batch (torch.Size([128, 250, 100]));
# presumably the number of time bins — TODO confirm, and keep in sync with the data.
slayer_model = SlayerNetwork(**dict_args, n_time_bins=250).cuda()

# `state_dict()` returns references to slayer_model's live tensors, not copies.
# Deep-copy it so the shared initialisation is a true snapshot that later
# updates to slayer_model cannot silently change.
init_weights = copy.deepcopy(slayer_model.state_dict())

# Both Exodus-based models start from the SLAYER weights so the backends are
# compared from identical initial conditions.
# NOTE(review): both receive the same dict object; if ExodusNetwork wraps these
# tensors without copying them, the two models would share storage — confirm.
exodus_model = ExodusNetwork(**dict_args, init_weights=init_weights).cuda()

sinabs_model = ExodusNetwork(**dict_args, init_weights=init_weights, backend='sinabs').cuda()


In [8]:
# One pass of forward + backward through each Exodus-based model (no optimizer
# step), presumably to compare the sinabs and exodus backends' speed. The dict
# preserves the original order (sinabs first) and labels each progress bar.
for backend, model in {"sinabs": sinabs_model, "exodus": exodus_model}.items():
    for data, target in tqdm(trainloader, desc=backend):
        data = data.cuda()
        # Unused below; presumably kept so data movement mirrors a real training step.
        target = target.cuda()
        # Presumably clears neuron state carried over from the previous batch.
        model.reset_states()
        # Without this, .grad keeps accumulating over every batch of the pass.
        model.zero_grad()
        y_hat = model(data)
        y_hat.sum().backward()

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

In [9]:
# One pass of forward + backward through the SLAYER model (no optimizer step).
# NOTE(review): unlike the Exodus/sinabs loop there is no reset_states() call;
# presumably SlayerNetwork keeps no state between batches — confirm.
for data, target in tqdm(trainloader, desc="slayer"):
    data = data.cuda()
    # Unused below; presumably kept so data movement mirrors a real training step.
    target = target.cuda()
    # Without this, .grad keeps accumulating over every batch of the pass.
    slayer_model.zero_grad()
    y_hat = slayer_model(data)
    y_hat.sum().backward()

  0%|          | 0/63 [00:00<?, ?it/s]