In [1]:
from nmnist_exodus import ExodusNetwork
from nmnist_slayer import SlayerNetwork
import torch
import matplotlib.pyplot as plt
import torch.nn as nn

In [19]:
batch_size = 4
learning_rate = 1e-3
architecture = 'paper'
tau_mem = 10000.
spike_threshold = 0.1
width_grad = 1.
scale_grad = 1. 
n_time_bins = 100


slayer_model = SlayerNetwork(
    tau_mem=tau_mem,
    spike_threshold=spike_threshold,
    learning_rate=learning_rate,
    n_time_bins=n_time_bins,
    architecture=architecture,
).cuda()

init_weights = slayer_model.state_dict()

exodus_model = ExodusNetwork(
    batch_size=batch_size,
    tau_mem=tau_mem,
    spike_threshold=spike_threshold,
    learning_rate=learning_rate,
    architecture=architecture,
    width_grad=width_grad,
    scale_grad=scale_grad,
    init_weights=init_weights,
).cuda()


In [20]:
dummy_input = (torch.rand((batch_size, n_time_bins, 2, 34, 34)) < 0.3).float().cuda()
dummy_input.sum()

tensor(276779., device='cuda:0')

In [21]:
exodus_model.reset_states()
exodus_output = exodus_model(dummy_input)
print(f"shape: {exodus_output.shape}, sum: {exodus_output.sum()}")

shape: torch.Size([4, 100, 10]), sum: 1312.61474609375


In [22]:
slayer_output = slayer_model(dummy_input)
print(f"shape: {slayer_output.shape}, sum: {slayer_output.sum()}")

shape: torch.Size([4, 100, 10]), sum: 1312.6156005859375
