In [1]:
import copy, torch, torch.nn as nn, torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import snntorch as snn
from snntorch import spikegen
from tqdm import tqdm

**parameters**

In [2]:
# Layer widths of the fully connected network (input, hidden, output).
num_in = 784
num_hid = 256
num_out = 10

# beta is handed to the snn.Leaky layers, T is the number of rate-coded time
# steps per sample, bs the DataLoader batch size, epochs/lr drive training.
beta = 0.9
T = 50
bs = 128
epochs = 3
lr = 1e-3

**device**

In [3]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fix the global RNG seed; left as the final expression so the cell still
# displays the returned generator.
torch.manual_seed(0)


<torch._C.Generator at 0x294aadcfc90>

**network architecture**

In [4]:

class SNN(nn.Module):
    """Two-layer fully connected spiking network built from ``snn.Leaky`` neurons.

    Layer widths are taken from the notebook-level ``num_in``, ``num_hid`` and
    ``num_out``; both neuron layers use the notebook-level ``beta``.
    """

    def __init__(self):
        super().__init__()
        # Bias-free linear layers, each followed by a Leaky neuron layer.
        self.fc1 = nn.Linear(num_in, num_hid, bias=False)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hid, num_out, bias=False)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Run the network over every step along dim 0 of ``x``.

        Returns the output-layer spikes of all steps stacked along a new
        leading dimension.
        """
        # Cast the input to the weight dtype so the same code serves the
        # float32 model and its half-precision copy.
        x = x.to(self.fc1.weight.dtype)

        # Initial membrane states, placed on the input's device and dtype.
        state_kwargs = dict(device=x.device, dtype=x.dtype)
        hidden_mem = self.lif1.init_leaky().to(**state_kwargs)
        output_mem = self.lif2.init_leaky().to(**state_kwargs)

        output_spikes = []
        for frame in x:
            hidden_spk, hidden_mem = self.lif1(self.fc1(frame), hidden_mem)
            # Keep the spikes in the dtype expected by the second linear layer.
            hidden_spk = hidden_spk.to(dtype=self.fc2.weight.dtype)
            output_spk, output_mem = self.lif2(self.fc2(hidden_spk), output_mem)
            output_spikes.append(output_spk)
        return torch.stack(output_spikes)  # [T,B,10]

In [5]:
# ToTensor is the only transform applied to the MNIST images.
tr = transforms.Compose([transforms.ToTensor()])

# Train and test splits share the "data" root directory and the same transform.
train_ds = torchvision.datasets.MNIST(root="data", train=True, download=True, transform=tr)
test_ds = torchvision.datasets.MNIST(root="data", train=False, download=True, transform=tr)

# Both loaders use the same batch size and worker settings; only training shuffles.
loader_kwargs = dict(batch_size=bs, num_workers=2, pin_memory=True)
train_ld = DataLoader(train_ds, shuffle=True, **loader_kwargs)
test_ld = DataLoader(test_ds, shuffle=False, **loader_kwargs)

In [6]:
# Build the network and move its parameters to the selected device.
net = SNN()
net = net.to(device)

# Adam over every network parameter, using the configured learning rate.
opt = torch.optim.Adam(params=net.parameters(), lr=lr)

# Cross-entropy criterion, placed on the same device as the network.
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

In [7]:
# `loss_fn` and the Adam optimizer `opt` are already created in the cell above,
# and `train` steps `opt`. This cell used to build a second, independent Adam
# instance with a hard-coded learning rate that was never stepped; `optimizer`
# is kept only as an alias of `opt` so the notebook has a single optimizer state.
optimizer = opt

**train phase**

In [8]:
def train(net, train_ld, epochs):
    """Train ``net`` on the batches of ``train_ld`` for ``epochs`` passes.

    Every batch is flattened, rate-coded into ``T`` time steps, and the output
    spikes summed over the time axis are scored against the labels with the
    notebook-level ``loss_fn``. The notebook-level ``opt`` applies the update.
    """
    for ep in range(1, epochs + 1):
        net.train()
        progress = tqdm(train_ld, desc=f"train {ep}/{epochs}")
        for imgs, lbls in progress:
            # Flatten each image to one vector per sample.
            batch_size = imgs.size(0)
            flat_imgs = imgs.to(device).view(batch_size, -1)
            targets = lbls.to(device)

            # Encode the batch as spikes and sum the network output over time.
            spike_train = spikegen.rate(flat_imgs, T).to(device)
            spike_counts = net(spike_train).sum(0)
            loss = loss_fn(spike_counts, targets)

            opt.zero_grad()
            loss.backward()
            opt.step()

In [9]:
# train(net,train_ld,epochs)

**test phase**

In [10]:
def test(model, loader):
    """Print the accuracy of ``model`` in FP32 and of a half-precision copy of it.

    Args:
        model: Trained network to evaluate. It is switched to eval mode and
            must be constructible without arguments, because the FP16 copy is
            a fresh instance of the same class loaded with ``model``'s weights.
        loader: Iterable yielding ``(images, labels)`` batches.

    The function evaluates the ``model`` and ``loader`` it is given rather than
    the notebook-level ``net`` and ``test_ld``. Each precision pass encodes its
    own spike trains with ``spikegen.rate``, so a small gap between the two
    printed figures may come from the encoding (rate coding is sampled, per the
    snnTorch documentation) rather than from the precision change.
    """
    @torch.no_grad()
    def evaluate(model, loader, dtype):
        """Return the accuracy in percent of ``model`` on ``loader`` with inputs cast to ``dtype``."""
        model.eval()
        correct = total = 0
        for imgs, lbls in loader:
            imgs = imgs.to(device).view(imgs.size(0), -1).to(dtype)
            spk = spikegen.rate(imgs, T).to(dtype).to(device)
            # Predicted class = output neuron with the most spikes over time.
            preds = model(spk).sum(0).argmax(1)
            correct += (preds == lbls.to(device)).sum().item()
            total += lbls.size(0)
        return 100 * correct / total

    acc32 = evaluate(model, loader, torch.float32)

    # Build the FP16 model as a separate instance so the caller's model keeps
    # its FP32 weights.
    model_fp16 = type(model)().to(device)
    model_fp16.load_state_dict(model.state_dict())
    model_fp16.half().eval()
    acc16 = evaluate(model_fp16, loader, torch.float16)

    print(f"Accuracy FP32 : {acc32:.2f}%")
    print(f"Accuracy FP16 : {acc16:.2f}%")

Create a test-data spike train for the C++ program

In [11]:
# # imgs, lbls = next(iter(test_ld))

# # img = imgs[9].to(device)
# # label = lbls[9].item()

# # img_flat = img.view(-1)


# # spk = spikegen.rate(img_flat.unsqueeze(0), T).squeeze(1) 

# # print("Label:", label)
# # print("Spike array shape:", spk.shape)  # [T, 784]
# # print("Spike array (as int):")
# # print(spk.int())

# spk_t = spk.int().cpu().tolist()  

# with open("C:\\Users\\Lenovo\\Desktop\\QNN\\HW\\FP32\\spikes.cpp", "w") as f:
#     f.write("#include <vector>\n")
#     f.write("#include \"spikes.h\"\n")
#     f.write(f"int label = {label};\n")
#     f.write("std::vector<std::vector<int>> spike_input = {\n")
    
#     for i, row in enumerate(spk_t): 
#         row_str = ", ".join(str(x) for x in row)
#         comma = "," if i < len(spk_t) - 1 else ""
#         f.write(f"    {{{row_str}}}{comma}\n")
#     f.write("};\n")


Create a test dataset for the C++ program

In [12]:
# num_samples = 1000  

# all_spikes = []
# all_labels = []

# for batch_imgs, batch_lbls in test_ld:
#     for i in range(len(batch_imgs)):
#         if len(all_labels) >= num_samples:
#             break

#         img = batch_imgs[i].to(device)
#         label = batch_lbls[i].item()
#         img_flat = img.view(-1)

#         spk = spikegen.rate(img_flat.unsqueeze(0), T).squeeze(1)  
#         spk_np = spk.int().cpu().tolist()  

#         all_spikes.append(spk_np)         
#         all_labels.append(label)          

#     if len(all_labels) >= num_samples:
#         break

In [13]:


#     # with open("C:\\Users\\Lenovo\\Desktop\\QNN\\HW\\FP32\\TestDataset.cpp", "w") as f:
# #     f.write("#include <vector>\n")
# #     f.write("#include \"TestDataset.h\"\n")
# #     f.write("std::vector<int> labels = {\n")
# #     f.write("    " + ", ".join(str(lbl) for lbl in all_labels) + "\n};\n\n")

# #     f.write("std::vector<std::vector<std::vector<int>>> dataset_spikes = {\n")
# #     for i, sample in enumerate(all_spikes):
# #         f.write("    {\n")
# #         for j, timestep in enumerate(sample):
# #             row_str = ", ".join(str(x) for x in timestep)
# #             comma = "," if j < len(sample) - 1 else ""
# #             f.write(f"        {{{row_str}}}{comma}\n")
# #         comma_sample = "," if i < len(all_spikes) - 1 else ""
# #         f.write(f"    }}{comma_sample}\n")
# #     f.write("};\n")



Export the test dataset (spikes and labels) as text files

In [20]:

num_samples = 100

# Both text files are written into the same output directory.
out_dir = "C:\\Users\\Lenovo\\Desktop\\shayan\\QNN\\HW\\FP32"
labels_path = f"{out_dir}\\labels.txt"
spikes_path = f"{out_dir}\\spikes.txt"

# Context managers close both files even if encoding or writing fails part-way.
with open(labels_path, "w") as label_file, open(spikes_path, "w") as spike_file:
    with torch.no_grad():
        count = 0
        for batch_imgs, batch_lbls in test_ld:
            # Take only as many samples from this batch as are still missing.
            remaining = num_samples - count
            for img, lbl in zip(batch_imgs[:remaining], batch_lbls[:remaining]):
                label = lbl.item()

                # Flatten the image and rate-code it on its own.
                img_flat = img.to(device).view(-1)
                spk = spikegen.rate(img_flat.unsqueeze(0), T).squeeze(1)  # [T, 784]
                spk_int = spk.int().cpu().tolist()

                # One header line per sample, then one space-separated line per time step.
                spike_file.write(f"# sample {count}\n")
                spike_file.writelines(
                    " ".join(str(x) for x in timestep) + "\n" for timestep in spk_int
                )

                label_file.write(f"{label}\n")
                count += 1

            if count >= num_samples:
                break

In [15]:
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
state_dict = torch.load('model_weights.pth', map_location=device)
net.load_state_dict(state_dict)

<All keys matched successfully>

In [16]:
# Report FP32 and FP16 accuracy of the loaded network on the test split.
test(model=net, loader=test_ld)

Accuracy FP32 : 96.28%
Accuracy FP16 : 96.32%


In [17]:
# torch.save(net.state_dict(),'model_weights.pth')

Write the weights in C++ format to weights.cpp

In [18]:
# net.eval()
# with open("C:\\Users\\Lenovo\\Desktop\\QNN\\HW\\FP32\\weights.cpp", "w") as cpp_file:
#     cpp_file.write("#include <vector>\n\n")
#     cpp_file.write("#include \"weights.h\"\n")

#     for name, param in net.state_dict().items():
#         cpp_name = name.replace(".", "_")
#         data = param.cpu().numpy()
#         shape = data.shape

#         cpp_file.write(f"// shape: {shape}\n")
#         cpp_file.write(f"std::vector<float> {cpp_name} = {{\n    ")
#         cpp_file.write(", ".join(f"{v:.6f}" for v in data.flatten()))
#         cpp_file.write("\n};\n\n")