In [1]:
from ssc_exodus import ExodusNetwork
from ssc_slayer import SlayerNetwork
import torch
import matplotlib.pyplot as plt
import torch.nn as nn

In [2]:
batch_size = 4
kw_args = dict(
    learning_rate = 1e-3,
    tau_mem = 100000.,
    spike_threshold = 1,
    width_grad = 1.,
    scale_grad = 1.,
    n_hidden_layers = 2,
    encoding_dim = 100,
    hidden_dim = 128,
    output_dim = 35,
    optimizer = "sgd",
    decoding_func = 'max_over_time',
    grad_mode = True
)
n_time_bins = 250


slayer_model = SlayerNetwork(
    **kw_args,
    n_time_bins=n_time_bins,
).cuda()

init_weights = slayer_model.state_dict()

exodus_model = ExodusNetwork(
    **kw_args,
    init_weights=init_weights,
).cuda()

In [3]:
dummy_input = (torch.rand((batch_size, n_time_bins, kw_args['encoding_dim'])) < 0.3).float().cuda()
dummy_input.sum()

tensor(29765., device='cuda:0')

In [4]:
exodus_model.reset_states()
exodus_output = exodus_model(dummy_input)
print(f"shape: {exodus_output.shape}, sum: {exodus_output.sum()}")

Output: 97.92607116699219 0.002797887660562992 0.04841223359107971
shape: torch.Size([4, 250, 35]), sum: 97.92607116699219


In [5]:
slayer_output = slayer_model(dummy_input)
print(f"shape: {slayer_output.shape}, sum: {slayer_output.sum()}")

Output: 97.92607116699219 0.002797887660562992 0.04841223359107971
shape: torch.Size([4, 250, 35]), sum: 97.92607116699219


In [6]:
from ssc import SSC

data = SSC(
    batch_size=128,
    encoding_dim=100,
    num_workers=4,
    download_dir="./data",
    shuffle=False,
)
data.prepare_data()
data.setup()
dataloader = data.train_dataloader()
inp, *__ = next(iter(dataloader))

In [7]:
exodus_model.reset_states()
exodus_output = exodus_model(inp.cuda())
print(f"shape: {exodus_output.shape}, sum: {exodus_output.sum()}")

Output: 4772.669921875 0.004261312540620565 0.06951548904180527
shape: torch.Size([128, 250, 35]), sum: 4772.669921875


In [8]:
slayer_output = slayer_model(inp.cuda())
print(f"shape: {slayer_output.shape}, sum: {slayer_output.sum()}")

Output: 4772.66943359375 0.004261312074959278 0.06951487064361572
shape: torch.Size([128, 250, 35]), sum: 4772.66943359375
