In [1]:
import copy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from smnist_exodus import ExodusNetwork
from smnist_slayer import SlayerNetwork

In [2]:
batch_size = 4
learning_rate = 1e-3
n_hidden_layers = 3
tau_mem = 10000.
spike_threshold = 0.1
width_grad = 1.
scale_grad = 1.
n_time_bins = 28*28
encoding_dim = 80
hidden_dim = 100
random_seed = 0

# Seed torch's generators (CPU and all CUDA devices) so that the weight
# initialisation below and the random dummy input in the next cell are
# reproducible under Restart & Run All.
torch.manual_seed(random_seed)

slayer_model = SlayerNetwork(
    tau_mem=tau_mem,
    spike_threshold=spike_threshold,
    learning_rate=learning_rate,
    n_hidden_layers=n_hidden_layers,
    width_grad=width_grad,
    scale_grad=scale_grad,
    encoding_dim=encoding_dim,
    hidden_dim=hidden_dim,
    decoding_func='sum_loss',
    n_time_bins=n_time_bins,
).cuda()

# state_dict() returns references that share storage with slayer_model's
# parameters. Deep-copy it so this snapshot keeps the *initial* weights even
# if slayer_model is modified in place later (e.g. by training).
init_weights = copy.deepcopy(slayer_model.state_dict())

# The EXODUS model receives the SLAYER model's initial weights, presumably so
# both implementations start from identical parameters.
# NOTE(review): n_time_bins is passed to SlayerNetwork but not here — confirm
# ExodusNetwork does not need it.
exodus_model = ExodusNetwork(
    tau_mem=tau_mem,
    spike_threshold=spike_threshold,
    learning_rate=learning_rate,
    n_hidden_layers=n_hidden_layers,
    width_grad=width_grad,
    scale_grad=scale_grad,
    encoding_dim=encoding_dim,
    hidden_dim=hidden_dim,
    decoding_func='sum_loss',
    init_weights=init_weights,
).cuda()

In [3]:
# Random binary input of shape (batch_size, n_time_bins, encoding_dim):
# each entry is 1.0 with probability 0.3 and 0.0 otherwise, moved to the GPU.
dummy_input = (
    torch.rand((batch_size, n_time_bins, encoding_dim))
    .lt(0.3)
    .float()
    .cuda()
)
# Total number of ones, displayed as the cell output.
dummy_input.sum()

tensor(75230., device='cuda:0')

In [4]:
# Reset the EXODUS model's internal state before the forward pass so the
# result does not depend on any earlier call.
exodus_model.reset_states()
exodus_output = exodus_model(dummy_input)
# Shape and summed output, for a quick comparison with the SLAYER cell below.
print("shape: {}, sum: {}".format(exodus_output.shape, exodus_output.sum()))

shape: torch.Size([4, 784, 10]), sum: -4498.5498046875


In [5]:
# NOTE(review): unlike the EXODUS cell, slayer_model is not reset before this
# call — confirm SlayerNetwork carries no state between forward passes.
slayer_output = slayer_model(dummy_input)
print(f"shape: {slayer_output.shape}, sum: {slayer_output.sum()}")

# Matching sums can hide element-wise disagreement (errors of opposite sign
# cancel), so also report the largest element-wise deviation between the two
# implementations. Detached copies avoid building an autograd graph here and
# leave both outputs untouched for any later gradient checks.
assert slayer_output.shape == exodus_output.shape, "output shapes differ"
max_abs_diff = (exodus_output.detach() - slayer_output.detach()).abs().max().item()
print(f"max abs difference (exodus vs slayer): {max_abs_diff}")

shape: torch.Size([4, 784, 10]), sum: -4498.2841796875
