In [None]:
import copy, torch, torch.nn as nn, torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import snntorch as snn
from snntorch import spikegen
from tqdm import tqdm
import numpy as np

**parameters**

In [2]:
# Layer widths of the fully connected network (input, hidden, output).
num_in = 784
num_hid = 256
num_out = 10

# Decay factor handed to every snn.Leaky neuron.
beta = 1
# Number of time steps used when rate-coding an image into spikes.
T = 50
# Mini-batch size for both data loaders.
bs = 128
# Number of passes over the training set.
epochs = 3
# Learning rate for Adam.
lr = 1e-3

**device**

In [3]:
# Prefer the GPU when CUDA is usable, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Seed PyTorch's global RNG; kept as the last expression so the cell still
# displays the Generator returned by manual_seed.
torch.manual_seed(0)


<torch._C.Generator at 0x1f2fca6fc90>

**Network architecture**

In [46]:

class SNN(nn.Module):
    """Two-layer fully connected spiking network built from snn.Leaky neurons.

    Layer widths (`num_in`, `num_hid`, `num_out`) and the neuron decay `beta`
    are read from the notebook-level globals when an instance is created.
    Both linear layers are bias-free and both neuron layers use a firing
    threshold of 0.8.
    """

    def __init__(self):
        super().__init__()
        # Both spiking layers share the same neuron settings.
        neuron_cfg = dict(beta=beta, threshold=0.8)
        self.fc1  = nn.Linear(num_in, num_hid, bias=False)
        self.lif1 = snn.Leaky(**neuron_cfg)
        self.fc2  = nn.Linear(num_hid, num_out, bias=False)
        self.lif2 = snn.Leaky(**neuron_cfg)

    def forward(self, x):
        """Feed a spike train through the network one time step at a time.

        `x` is indexed along dim 0 as the time axis. The output spikes of the
        second neuron layer are collected for every step and stacked along a
        new leading dimension ([T, B, 10] per the original annotation).
        """
        # Match the input to the parameter dtype so the same code runs in FP32 and FP16.
        x = x.to(self.fc1.weight.dtype)
        # Fresh membrane states for both layers, on the input's device and dtype.
        mem1, mem2 = (
            layer.init_leaky().to(device=x.device, dtype=x.dtype)
            for layer in (self.lif1, self.lif2)
        )
        recorded = []
        for step in range(x.shape[0]):
            hidden_spk, mem1 = self.lif1(self.fc1(x[step]), mem1)
            # Hidden spikes must have the dtype of the next layer's weights.
            hidden_spk = hidden_spk.to(dtype=self.fc2.weight.dtype)
            out_spk, mem2 = self.lif2(self.fc2(hidden_spk), mem2)
            recorded.append(out_spk)
        return torch.stack(recorded)

In [47]:
# ToTensor is the only preprocessing step applied to the MNIST images.
tr = transforms.Compose([transforms.ToTensor()])

# Train / test splits, downloaded into ./data when missing.
train_ds = torchvision.datasets.MNIST(root="data", train=True, download=True, transform=tr)
test_ds = torchvision.datasets.MNIST(root="data", train=False, download=True, transform=tr)

# Only the training loader shuffles; both use 2 workers and pinned memory.
train_ld = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=2, pin_memory=True)
test_ld = DataLoader(test_ds, batch_size=bs, shuffle=False, num_workers=2, pin_memory=True)

In [48]:
# Build the network and move it to the selected device (Module.to returns the module itself).
net = SNN()
net = net.to(device)

# Adam over all network parameters, using the learning rate from the parameters cell.
opt = torch.optim.Adam(params=net.parameters(), lr=lr)

# Cross-entropy loss; train() applies it to the output spikes summed over time.
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

In [49]:
# `loss_fn` and the Adam optimizer `opt` are already created in the previous
# cell, and train() steps `opt`. Building a second Adam instance here (with a
# hard-coded lr) produced an optimizer that was never stepped, so `optimizer`
# is kept only as an alias of the one that is actually used.
optimizer = opt

**train phase**

In [50]:
def train(net,train_ld,epochs, optimizer=None, criterion=None):
    """Train `net` on rate-coded inputs for `epochs` passes over `train_ld`.

    Parameters
    ----------
    net : module to train; called on a spike tensor whose first dimension is time.
    train_ld : iterable yielding `(images, labels)` batches.
    epochs : number of passes over `train_ld`.
    optimizer : optimizer that updates `net`'s parameters. Defaults to the
        notebook-level `opt`, which is bound to the global `net`; pass an
        explicit optimizer when training any other model, otherwise that
        model's weights would never be updated.
    criterion : loss applied to the spike counts summed over time and the
        labels. Defaults to the notebook-level `loss_fn`.

    Images are flattened per sample, moved to the global `device` and encoded
    with `spikegen.rate` over the global `T` time steps. Returns None.
    """
    if optimizer is None:
        optimizer = opt
    if criterion is None:
        criterion = loss_fn

    for ep in range(1, epochs + 1):
        net.train()
        for imgs, lbls in tqdm(train_ld, desc=f"train {ep}/{epochs}"):
            # Flatten every image to a vector: [B, ...] -> [B, features].
            imgs = imgs.to(device).view(imgs.size(0), -1)
            lbls = lbls.to(device)
            spk  = spikegen.rate(imgs, T).to(device)
            # Summing the output spikes over the time axis gives per-class spike counts.
            loss = criterion(net(spk).sum(0), lbls)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [51]:
# Run the training loop with the notebook-level network, loader and epoch count.
train(net, train_ld, epochs=epochs)

train 1/3: 100%|██████████| 469/469 [03:28<00:00,  2.25it/s]
train 2/3: 100%|██████████| 469/469 [02:45<00:00,  2.84it/s]
train 3/3: 100%|██████████| 469/469 [01:55<00:00,  4.06it/s]


**test phase**

In [9]:
def test(model,loader):
    """Print the accuracy of `model` on `loader` in FP32 and in FP16.

    Parameters
    ----------
    model : network to evaluate. Its class must be constructible without
        arguments (as `SNN` is), because the FP16 copy is created with
        `type(model)()` and then loaded with `model`'s state dict.
    loader : iterable yielding `(images, labels)` batches.

    Inputs are flattened per sample, moved to the global `device` and
    rate-coded with `spikegen.rate` over the global `T` time steps. The
    predicted class is the output neuron with the most spikes summed over time.
    `model` is left in eval mode. Returns None; the two accuracies are printed.
    """
    @torch.no_grad()
    def evaluate(model, loader, dtype):
        """Return the accuracy (in percent) of `model` on `loader` using `dtype` inputs."""
        model.eval()
        correct = total = 0
        for imgs, lbls in loader:
            imgs = (imgs.to(device).view(imgs.size(0), -1) ).to(dtype)
            spk  = spikegen.rate(imgs, T).to(dtype).to(device)
            preds = model(spk).sum(0).argmax(1)
            correct += (preds == lbls.to(device)).sum().item()
            total   += lbls.size(0)
        # An empty loader has no defined accuracy; avoid dividing by zero.
        if total == 0:
            return float("nan")
        return 100 * correct / total

    # Evaluate the model and loader that were passed in, not the notebook globals.
    acc32 = evaluate(model, loader, torch.float32)

    # Half-precision twin: same architecture, same weights, cast to FP16.
    model_fp16 = type(model)().to(device)
    model_fp16.load_state_dict(model.state_dict())
    model_fp16.half().eval()
    acc16 = evaluate(model_fp16, loader, torch.float16)

    print(f"Accuracy FP32 : {acc32:.2f}%")
    print(f"Accuracy FP16 : {acc16:.2f}%")

Create a test spike train for the C++ program

In [11]:
# Take the first test batch and pick its second sample (index 1).
imgs, lbls = next(iter(test_ld))
img = imgs[1].to(device)
label = lbls[1].item()

# Flatten the image, rate-code it over T steps and drop the singleton batch axis.
img_flat = img.view(-1)
spk = spikegen.rate(img_flat.unsqueeze(0), T).squeeze(1)

print("Label:", label)
print("Spike array shape:", spk.shape)  # [T, 784]
print("Spike array (as int):")
print(spk.int())

# Nested Python lists of ints: one inner list per time step.
spk_t = spk.int().cpu().tolist()

# Emit the label and the spike train as a C++ source file.
with open("C:\\Users\\Lenovo\\Desktop\\shayan\\QNN\\HW\\FP32\\spikes.cpp", "w") as f:
    f.writelines([
        "#include <vector>\n",
        "#include \"spikes.h\"\n",
        f"int label = {label};\n",
        "std::vector<std::vector<int>> spike_input = {\n",
    ])

    last_index = len(spk_t) - 1
    for i, row in enumerate(spk_t):
        row_str = ", ".join(map(str, row))
        # Every row except the last one is followed by a comma.
        comma = "" if i >= last_index else ","
        f.write(f"    {{{row_str}}}{comma}\n")
    f.write("};\n")


Label: 2
Spike array shape: torch.Size([50, 784])
Spike array (as int):
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.int32)


Create a test dataset for the C++ program

In [12]:
# Disabled cell, kept for reference: it collects up to `num_samples` rate-coded
# test samples into `all_spikes` (nested int lists) and their labels into `all_labels`.
# num_samples = 1000  

# all_spikes = []
# all_labels = []

# for batch_imgs, batch_lbls in test_ld:
#     for i in range(len(batch_imgs)):
#         if len(all_labels) >= num_samples:
#             break

#         img = batch_imgs[i].to(device)
#         label = batch_lbls[i].item()
#         img_flat = img.view(-1)

#         spk = spikegen.rate(img_flat.unsqueeze(0), T).squeeze(1)  
#         spk_np = spk.int().cpu().tolist()  

#         all_spikes.append(spk_np)         
#         all_labels.append(label)          

#     if len(all_labels) >= num_samples:
#         break

In [13]:


# Disabled cell, kept for reference: it writes `all_labels` and `all_spikes`
# (built by the disabled cell above) as C++ vector literals to TestDataset.cpp.
#     # with open("C:\\Users\\Lenovo\\Desktop\\QNN\\HW\\FP32\\TestDataset.cpp", "w") as f:
# #     f.write("#include <vector>\n")
# #     f.write("#include \"TestDataset.h\"\n")
# #     f.write("std::vector<int> labels = {\n")
# #     f.write("    " + ", ".join(str(lbl) for lbl in all_labels) + "\n};\n\n")

# #     f.write("std::vector<std::vector<std::vector<int>>> dataset_spikes = {\n")
# #     for i, sample in enumerate(all_spikes):
# #         f.write("    {\n")
# #         for j, timestep in enumerate(sample):
# #             row_str = ", ".join(str(x) for x in timestep)
# #             comma = "," if j < len(sample) - 1 else ""
# #             f.write(f"        {{{row_str}}}{comma}\n")
# #         comma_sample = "," if i < len(all_spikes) - 1 else ""
# #         f.write(f"    }}{comma_sample}\n")
# #     f.write("};\n")



Export the test dataset as text files

In [27]:
import torch

# Re-declared here as in the original cell; the values match the parameter and
# device cells at the top of the notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
T = 50
num_samples = 10000

# Context managers guarantee both files are closed even when the shape check
# below raises; with bare open()/close() the handles leaked on error.
with open("C:\\Users\\Lenovo\\Desktop\\shayan\\QNN\\hls\\labels.txt", "w") as label_file, \
     open("C:\\Users\\Lenovo\\Desktop\\shayan\\QNN\\hls\\spikes.txt", "w") as spike_file, \
     torch.no_grad():
    count = 0
    for batch_imgs, batch_lbls in test_ld:
        for i in range(len(batch_imgs)):
            if count >= num_samples:
                break

            img = batch_imgs[i].to(device)
            label = batch_lbls[i].item()

            img_flat = img.view(-1).unsqueeze(0)                 # [1, 784]
            spk = spikegen.rate(img_flat, T)                     # [T, 1, 784]
            spk = spk.squeeze(1).int().cpu()                     # [T, 784]

            if spk.shape != (T, 784):
                # Report the expected step count from T rather than a hard-coded 50.
                raise ValueError(f"Expected shape [{T}, 784], got {spk.shape}")

            spike_file.write(f"# sample {count}\n")
            # One tolist() per sample replaces one .item() call per spike, which
            # dominated the runtime; the text written is the same.
            for timestep in spk.tolist():
                spike_file.write(" ".join(map(str, timestep)) + "\n")

            label_file.write(f"{label}\n")
            count += 1

        # Leave the outer loop as well once enough samples were written.
        if count >= num_samples:
            break


In [12]:
# Read the saved parameters onto the active device, then load them into `net`.
# The load call stays last so the cell still displays the key-matching report.
saved_state = torch.load('model_weights.pth', map_location=device)
net.load_state_dict(saved_state)

<All keys matched successfully>

In [52]:
# Report FP32 and FP16 accuracy of the network on the MNIST test loader.
test(model=net, loader=test_ld)

Accuracy FP32 : 96.07%
Accuracy FP16 : 96.15%


In [41]:
# Print tensors with 30 decimal places so the FP32 and FP16 weights shown below
# can be compared digit by digit.
torch.set_printoptions(precision=30)

In [53]:
# Fresh SNN instance that will receive the trained weights; note that no
# .to(device) is applied here.
net_fp16 = SNN()

In [54]:
# Copy the trained parameters from `net` into the fresh instance ...
net_fp16.load_state_dict(state_dict=net.state_dict())
# ... and cast its floating-point parameters and buffers to half precision.
net_fp16 = net_fp16.half()

In [55]:
# Display the first-layer weights of the original (FP32) network.
net.fc1.weight

Parameter containing:
tensor([[ 0.010213527828454971313476562500,  0.034269195050001144409179687500,
          0.032603953033685684204101562500,  ...,
          0.014157786965370178222656250000, -0.016394995152950286865234375000,
         -0.006290988996624946594238281250],
        [-0.035448536276817321777343750000, -0.005012039095163345336914062500,
         -0.024140700697898864746093750000,  ...,
          0.029441650956869125366210937500,  0.019238635897636413574218750000,
         -0.001913335174322128295898437500],
        [-0.030705563724040985107421875000,  0.005754742771387100219726562500,
          0.031043078750371932983398437500,  ...,
          0.020431201905012130737304687500, -0.004949940368533134460449218750,
         -0.016702154651284217834472656250],
        ...,
        [-0.023359090089797973632812500000,  0.009700365364551544189453125000,
          0.030937489122152328491210937500,  ...,
          0.014517899602651596069335937500,  0.028145592659711837768554687500

In [56]:
# Display the same first-layer weights after the cast to half precision.
net_fp16.fc1.weight

Parameter containing:
tensor([[ 0.010215759277343750000000000000,  0.034271240234375000000000000000,
          0.032592773437500000000000000000,  ...,
          0.014160156250000000000000000000, -0.016387939453125000000000000000,
         -0.006290435791015625000000000000],
        [-0.035461425781250000000000000000, -0.005012512207031250000000000000,
         -0.024139404296875000000000000000,  ...,
          0.029434204101562500000000000000,  0.019241333007812500000000000000,
         -0.001913070678710937500000000000],
        [-0.030700683593750000000000000000,  0.005756378173828125000000000000,
          0.031036376953125000000000000000,  ...,
          0.020431518554687500000000000000, -0.004951477050781250000000000000,
         -0.016708374023437500000000000000],
        ...,
        [-0.023361206054687500000000000000,  0.009696960449218750000000000000,
          0.030944824218750000000000000000,  ...,
          0.014518737792968750000000000000,  0.028152465820312500000000000000

In [None]:
# Disabled: saves the trained parameters to 'model_weights.pth', the file that
# the torch.load cell earlier in this notebook reads.
# torch.save(net.state_dict(),'model_weights.pth')

Write the weights in C++ format to weights.cpp

With 6 decimal digits

In [None]:
# Disabled cell, kept for reference: same export as the 25-digit cell, but it
# formats every value with 6 decimal places and writes non-const vectors to weightst.cpp.
# net.eval()
# with open("C:\\Users\\Lenovo\\Desktop\\shayan\\QNN\\HW\\FP32\\weightst.cpp", "w") as cpp_file:
#     cpp_file.write("#include <vector>\n\n")
#     cpp_file.write("#include \"weights.h\"\n\n")

#     for name, param in net.state_dict().items():
#         cpp_name = name.replace(".", "_")
#         data = param.cpu().numpy()
#         shape = data.shape

#         if "weight" in name and len(shape) == 2:
#             data = data.T  # Transpose to [input][output]
#             shape = data.shape

#             cpp_file.write(f"// shape: {shape} // Transposed\n")
#             cpp_file.write(f"std::vector<std::vector<float>> {cpp_name} = {{\n")

#             for row in data:
#                 row_str = ", ".join(f"{v:.6f}" for v in row)
#                 cpp_file.write(f"    {{{row_str}}},\n")

#             cpp_file.write("};\n\n")

#         elif "bias" in name and len(shape) == 1:
#             cpp_file.write(f"// shape: {shape}\n")
#             cpp_file.write(f"std::vector<float> {cpp_name} = {{\n    ")
#             cpp_file.write(", ".join(f"{v:.6f}" for v in data))
#             cpp_file.write("\n};\n\n")


With 25 decimal digits

In [58]:
def write_weights_cpp(model, path, digits=25):
    """Write the weights and biases of `model` to `path` as C++ vector literals.

    Parameters
    ----------
    model : module whose `state_dict()` is exported.
    path : destination of the generated C++ source file (overwritten).
    digits : number of decimal places written for every value.

    Every 2-D entry whose name contains "weight" is transposed to
    [input][output] and written as `const std::vector<std::vector<float>>`;
    every 1-D entry whose name contains "bias" is written as
    `const std::vector<float>`. All other state-dict entries are skipped.
    Dots in entry names become underscores in the C++ identifiers.
    """
    value_fmt = f"{{:.{digits}f}}"

    with open(path, "w") as cpp_file:
        cpp_file.write("#include <vector>\n\n")
        cpp_file.write("#include \"weights.h\"\n\n")

        for name, param in model.state_dict().items():
            cpp_name = name.replace(".", "_")
            data = param.cpu().numpy()
            shape = data.shape

            if "weight" in name and len(shape) == 2:
                data = data.T  # Transpose to [input][output]
                shape = data.shape

                cpp_file.write(f"// shape: {shape} // Transposed\n")
                cpp_file.write(f"const std::vector<std::vector<float>> {cpp_name} = {{\n")

                for row in data:
                    row_str = ", ".join(value_fmt.format(v) for v in row)
                    cpp_file.write(f"    {{{row_str}}},\n")

                cpp_file.write("};\n\n")

            elif "bias" in name and len(shape) == 1:
                cpp_file.write(f"// shape: {shape}\n")
                cpp_file.write(f"const std::vector<float> {cpp_name} = {{\n    ")
                cpp_file.write(", ".join(value_fmt.format(v) for v in data))
                cpp_file.write("\n};\n\n")


net.eval()
# Same output file and 25-digit precision as before; pass another `digits`
# value (e.g. 6) for a shorter export instead of duplicating this cell.
write_weights_cpp(net, "C:\\Users\\Lenovo\\Desktop\\shayan\\QNN\\HW\\FP32\\weights_.cpp", digits=25)


In [59]:
# Display the firing threshold stored on the first spiking layer (constructed with threshold=0.8).
net.lif1.threshold

tensor(0.800000011920928955078125000000)