In [2]:
import torch
import torch.nn as nn
import snntorch as snn
import snntorch.functional as SF
from snntorch import surrogate


In [3]:

# Define the network
class SimpleSNN(nn.Module):
    """Two-layer fully connected spiking network.

    Each ``nn.Linear`` layer feeds an ``snn.Leaky`` neuron layer that uses the
    fast-sigmoid surrogate gradient. Both neuron layers share the same
    ``beta`` value.
    """

    def __init__(self, num_inputs, num_hidden, num_outputs, beta):
        """Create the layers.

        Args:
            num_inputs: Size of the last dimension of the input tensor.
            num_hidden: Number of units in the hidden layer.
            num_outputs: Number of units in the output layer.
            beta: Value passed as ``beta`` to both ``snn.Leaky`` layers.
        """
        super().__init__()

        # Input -> hidden: linear projection followed by leaky neurons.
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())

        # Hidden -> output: linear projection followed by leaky neurons.
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())

    def forward(self, x):
        """Run one pass through both layers.

        Fresh membrane state is requested from ``init_leaky()`` on every call,
        so this method does not carry state from one call to the next.

        Args:
            x: Input tensor whose last dimension is ``num_inputs``
                (assumed ``(batch, num_inputs)`` -- TODO confirm with caller).

        Returns:
            Tuple ``(spikes, membrane)`` produced by the output neuron layer.
        """
        # Starting membrane state for each neuron layer.
        hidden_mem = self.lif1.init_leaky()
        output_mem = self.lif2.init_leaky()

        # Hidden layer: the linear output is the input current of the neurons.
        hidden_current = self.fc1(x)
        hidden_spikes, hidden_mem = self.lif1(hidden_current, hidden_mem)

        # Output layer is driven by the hidden layer's spikes.
        output_current = self.fc2(hidden_spikes)
        output_spikes, output_mem = self.lif2(output_current, output_mem)

        return output_spikes, output_mem


In [4]:

# Hyperparameters
# Layer sizes: input features -> hidden neurons -> output neurons
# (e.g., 10 outputs for 10-class classification).
num_inputs, num_hidden, num_outputs = 100, 50, 10
# Decay rate for the LIF neurons.
beta = 0.9


In [5]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines without a usable GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [6]:

# Build the spiking network from the hyperparameters defined above.
snn_model = SimpleSNN(
    num_inputs=num_inputs,
    num_hidden=num_hidden,
    num_outputs=num_outputs,
    beta=beta,
)


In [7]:
# Wrap the network for data-parallel execution on GPUs 0 and 1.
dp_device_ids = [0, 1]
if torch.cuda.is_available():
    # Never request a GPU index the host does not have (e.g. a single-GPU box).
    dp_device_ids = [i for i in dp_device_ids if i < torch.cuda.device_count()]
    # DataParallel requires the module's parameters and buffers to live on
    # device_ids[0] before the first forward pass.
    snn_model.to(torch.device('cuda', dp_device_ids[0]))
model = torch.nn.DataParallel(snn_model, device_ids=dp_device_ids)