In [53]:
import time

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, surrogate
from norse.torch.module.lif import LIFCell

from src.snn.block.blocks import Blocks
from src.snn.snn import SNN

## Setting up the different implementations

Network benchmarked: 200 input units -> 100 spiking units over 1000 simulation steps using a batch size of 128.

In [93]:
def time_jelly():
    input_tensor = torch.rand(128, 200, 1000).cuda()
    
    jelly_layer = nn.Sequential(
        layer.Linear(200, 100, bias=False),
        neuron.LIFNode(tau=100.0, surrogate_function=surrogate.ATan())
    ).cuda()
    
    start_time = time.time()
    
    for t in range(1000):
        out = jelly_layer(input_tensor[:, :, t])
        
    end_time = time.time()
    return end_time - start_time

In [94]:
def time_norse():
    input_tensor = torch.rand(128, 200, 1000).cuda()
    
    norse_layer = nn.Sequential(
        layer.Linear(200, 100, bias=False),
        LIFCell()
    ).cuda()
    
    start_time = time.time()
    
    for t in range(1000):
        out = norse_layer(input_tensor[:, :, t])
    
    end_time = time.time()
    
    return end_time - start_time

In [95]:
def time_blocks():
    input_tensor = torch.rand(128, 200, 1000).cuda()
    
    blocks_snn = Blocks(200, 100, 1, 1000, t_latency=50, recurrent=False, init_beta=0.99, init_p=0.99).cuda()
    start_time = time.time()
    out = blocks_snn(input_tensor)
    end_time = time.time()
    
    return end_time - start_time
    
def time_standard():
    input_tensor = torch.rand(128, 200, 1000).cuda()
    
    blocks_snn = SNN(200, 100, 1, 1000, t_latency=1, recurrent=False, init_beta=0.99, init_p=0.99).cuda()
    start_time = time.time()
    out = blocks_snn(input_tensor)
    end_time = time.time()
    
    return end_time - start_time

## Benchmarking the differnet implementations

In [97]:
print(f"Norse={time_norse()}")
print(f"Jelly={time_jelly()}")
print(f"Standard={time_standard()}")
print(f"Blocks={time_blocks()}")

Norse=0.3080105781555176
Jelly=0.1869184970855713
Standard=0.3317074775695801
Blocks=0.016329288482666016
