In [11]:
import copy, torch, torch.nn as nn, torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import snntorch as snn
from snntorch import spikegen
from tqdm import tqdm

**Hyperparameters**

In [12]:
# Layer sizes: one input per pixel of a flattened 28x28 MNIST image,
# a hidden LIF layer, and one output neuron per digit class.
num_in = 784
num_hid = 256
num_out = 10

# LIF membrane decay, number of simulation time steps, batch size,
# number of training epochs, and Adam learning rate.
beta = 0.9
T = 50
bs = 128
epochs = 3
lr = 1e-3

**Device and random seed**

In [13]:
# Use the GPU when one is visible; later cells move tensors to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so weight init and later random ops are reproducible.
# The trailing `;` suppresses the returned Generator's repr in the cell output.
SEED = 0
torch.manual_seed(SEED);


<torch._C.Generator at 0x1216d2a8ad0>

**Network architecture**

In [15]:

class SNN(nn.Module):
    """Two-layer fully connected spiking network built from snntorch Leaky neurons.

    Layout: Linear(n_in -> n_hid, no bias) -> Leaky -> Linear(n_hid -> n_out, no bias) -> Leaky.

    Args:
        n_in, n_hid, n_out: layer sizes. When omitted they fall back to the
            notebook-level `num_in`, `num_hid`, `num_out`, read when the
            network is constructed.
        membrane_beta: decay factor passed to both Leaky layers. It defaults to
            the notebook-level `beta`.
    """
    def __init__(self, n_in=None, n_hid=None, n_out=None, membrane_beta=None):
        super().__init__()
        n_in = num_in if n_in is None else n_in
        n_hid = num_hid if n_hid is None else n_hid
        n_out = num_out if n_out is None else n_out
        membrane_beta = beta if membrane_beta is None else membrane_beta
        self.fc1  = nn.Linear(n_in , n_hid, bias=False)
        self.lif1 = snn.Leaky(beta=membrane_beta)
        self.fc2  = nn.Linear(n_hid, n_out, bias=False)
        self.lif2 = snn.Leaky(beta=membrane_beta)

    def forward(self, x):
        """Unroll the network over the leading (time) dimension of `x`.

        In this notebook `x` is the output of `spikegen.rate(imgs, T)`,
        i.e. [T, B, n_in]. The result stacks the output spikes of every
        step: [T, B, n_out].
        """
        # Match the input to the layer precision so the same class serves the FP16 copy.
        x = x.to(self.fc1.weight.dtype)
        mem1 = self.lif1.init_leaky().to(device=x.device, dtype=x.dtype)
        mem2 = self.lif2.init_leaky().to(device=x.device, dtype=x.dtype)
        out  = []
        for t in range(x.size(0)):
            spk1, mem1 = self.lif1(self.fc1(x[t]), mem1)
            # NOTE(review): the surrogate spike function presumably emits float32
            # regardless of input dtype. Casting keeps fc2's input dtype equal to its
            # weights' dtype (needed when the network is in half precision).
            spk1 = spk1.to(dtype=self.fc2.weight.dtype)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            out.append(spk2)
        return torch.stack(out)                                # [T,B,n_out]

In [16]:
# ToTensor turns each image into a float tensor in [0, 1]; spikegen.rate later
# treats these values as per-step firing probabilities.
tr = transforms.Compose([transforms.ToTensor()])
train_ds = torchvision.datasets.MNIST("data", train=True, download=True, transform=tr)
test_ds  = torchvision.datasets.MNIST("data", train=False, download=True, transform=tr)

# Pinned host memory only speeds up host->GPU copies. On a CPU-only machine,
# requesting it just emits a warning, so enable it only when training on CUDA.
pin = device.type == "cuda"
train_ld = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=2, pin_memory=pin)
test_ld  = DataLoader(test_ds, batch_size=bs, shuffle=False, num_workers=2, pin_memory=pin)

In [17]:
# Build the network on the selected device, then its loss and optimizer.
net = SNN()
net = net.to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
opt = torch.optim.Adam(params=net.parameters(), lr=lr)

In [18]:
# `net`, `opt` and `loss_fn` are already built in the previous cell.
# Alias instead of constructing a second Adam instance. A separate optimizer
# would keep its own moment estimates and diverge from `opt` (which `train`
# uses), and its hard-coded lr would duplicate the `lr` config value.
optimizer = opt

**Training**

In [21]:
def train(net, train_ld, epochs, optimizer=None, criterion=None):
    """Train `net` on rate-coded spike trains for `epochs` passes over `train_ld`.

    Each image batch is flattened to (B, -1) and rate-encoded into T time steps
    with `spikegen.rate`. The output spikes are summed over time and used as
    class scores for the loss.

    Args:
        net: the network to train (updated in place).
        train_ld: iterable of (images, labels) batches.
        epochs: number of passes over `train_ld`.
        optimizer: optimizer to step; defaults to the notebook-level `opt`.
        criterion: loss function; defaults to the notebook-level `loss_fn`.

    Returns:
        List with the mean training loss of each epoch (NaN for an empty loader).
    """
    optimizer = opt if optimizer is None else optimizer
    criterion = loss_fn if criterion is None else criterion
    history = []
    for ep in range(1, epochs + 1):
        net.train()
        running, n_batches = 0.0, 0
        bar = tqdm(train_ld, desc=f"train {ep}/{epochs}")
        for imgs, lbls in bar:
            imgs = imgs.to(device).view(imgs.size(0), -1)
            lbls = lbls.to(device)
            # `imgs` is already on `device`, so the generated spikes are too.
            spk  = spikegen.rate(imgs, T)
            loss = criterion(net(spk).sum(0), lbls)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()
            n_batches += 1
            bar.set_postfix(loss=f"{running / n_batches:.4f}")
        history.append(running / n_batches if n_batches else float("nan"))
    return history

In [20]:
train(net,train_ld,epochs)

train 1/3:  27%|██▋       | 126/469 [00:17<00:48,  7.01it/s]


KeyboardInterrupt: 

**Evaluation: FP32 vs FP16 accuracy**

In [22]:
def test(model, loader):
    """Print the accuracy of `model` on `loader` in FP32 and for an FP16 copy.

    Images are flattened and rate-encoded with `spikegen.rate`. Predictions are
    the argmax of the output spikes summed over time. `model` itself is left in
    FP32; the FP16 run uses a deep copy. Because spike encoding is random, part
    of the gap between the two accuracies can be sampling noise.
    """
    @torch.no_grad()
    def evaluate(m, ld, dtype):
        """Return the percentage accuracy of `m` over `ld`, feeding spikes in `dtype`."""
        m.eval()
        correct = total = 0
        for imgs, lbls in ld:
            imgs = imgs.to(device).view(imgs.size(0), -1)
            # Sample from the full-precision probabilities, then cast. Spikes are 0/1,
            # so the cast is exact, and both precisions get identically distributed input.
            spk  = spikegen.rate(imgs, T).to(dtype)
            preds = m(spk).sum(0).argmax(1)
            correct += (preds == lbls.to(device)).sum().item()
            total   += lbls.size(0)
        return 100 * correct / total if total else float("nan")

    acc32 = evaluate(model, loader, torch.float32)

    # deepcopy leaves `model` untouched (`.half()` works in place) and works for any
    # module type, not only SNN.
    # NOTE(review): half-precision matmul on CPU requires a recent PyTorch.
    model_fp16 = copy.deepcopy(model).half().eval()
    acc16 = evaluate(model_fp16, loader, torch.float16)

    print(f"Accuracy FP32 : {acc32:.2f}%")
    print(f"Accuracy FP16 : {acc16:.2f}%")

In [25]:
test(net,test_ld)

Accuracy FP32 : 96.33%
Accuracy FP16 : 96.27%


In [15]:
# Persist only the learned tensors (not the module object), so the file can be
# loaded into a freshly constructed SNN.
weights = net.state_dict()
torch.save(obj=weights, f="model_weights.pth")

In [24]:
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code on load (PyTorch >= 1.13).
net.load_state_dict(torch.load('model_weights.pth', map_location=device, weights_only=True))

<All keys matched successfully>

In [43]:
# Hidden->output weight matrix of fc2, shape (out_features, in_features).
# `.detach()` shares storage with the parameter just as `.data` did, but it keeps
# autograd's version tracking, so an accidental in-place edit is detected
# instead of silently corrupting gradients.
fc2w = net.fc2.weight.detach()

In [44]:
# Incoming weights of output neuron 9: one per hidden neuron.
fc2w[9].size()

torch.Size([256])

In [None]:
def export_weights_cpp(model, path="weights.cpp"):
    """Write every tensor in `model.state_dict()` to a C++ source file.

    Each entry becomes a flat `std::vector<float>` named after its state-dict key,
    with dots replaced by underscores (e.g. `fc1.weight` -> `fc1_weight`). A
    comment with the original shape precedes it. Values are flattened in
    numpy's default row-major (C) order.

    Values are written as `.8e` float literals (9 significant digits). That is
    enough for every float32 to round-trip exactly, whereas a fixed `.6f` format
    rounds any weight below 5e-7 to zero.

    Raises:
        ValueError: if a tensor contains NaN or Inf, which have no C++ float literal.
    """
    with open(path, "w", encoding="utf-8") as cpp_file:
        cpp_file.write("#include <vector>\n\n")

        for name, param in model.state_dict().items():
            cpp_name = name.replace(".", "_")
            tensor = param.detach().cpu()
            if not torch.isfinite(tensor).all():
                raise ValueError(f"{name} contains NaN/Inf, which cannot be written as a C++ literal")
            data = tensor.numpy()

            cpp_file.write(f"// shape: {data.shape}\n")
            cpp_file.write(f"std::vector<float> {cpp_name} = {{\n    ")
            # The trailing `f` makes each literal a float rather than a double.
            cpp_file.write(", ".join(f"{v:.8e}f" for v in data.flatten()))
            cpp_file.write("\n};\n\n")


net.eval()
export_weights_cpp(net, "weights.cpp")