# Export input.txt and output.txt

In [1]:
import torch
import torchvision
from tqdm import tqdm
from spikingjelly.clock_driven import encoding

## For MNIST16

In [7]:
test_dataset = torchvision.datasets.MNIST(
        root="../Datasets/",
        train=False,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize(16),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                (0.1307,), (0.3081,))
        ])
    )
test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=100,
        shuffle=False
    )

In [8]:
f1 = open('../Datasets/MNIST_16/input.txt', 'w')
f2 = open('../Datasets/MNIST_16/output.txt', 'w')

encoder = encoding.PoissonEncoder();

for img, label in tqdm(test_data_loader):
    for i in range(100):
        print(label[i].tolist(), file = f2)
        T = 999;
        total_spike = encoder(img).float()
        for t in range(T):
            total_spike += encoder(img).float()
        for j in range(16):
            for k in range(16): 
                print(int(total_spike[i][0][j][k].tolist()), file = f1, end = ' ')
        print("", file = f1)
f1.close()
f2.close()

100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [37:10<00:00, 22.30s/it]


## For MNIST32

In [9]:
test_dataset = torchvision.datasets.MNIST(
        root="../Datasets/",
        train=False,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize(32),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                (0.1307,), (0.3081,))
        ])
    )
test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=100,
        shuffle=False
    )

In [10]:
f1 = open('../Datasets/MNIST_32/input.txt', 'w')
f2 = open('../Datasets/MNIST_32/output.txt', 'w')

encoder = encoding.PoissonEncoder();

for img, label in tqdm(test_data_loader):
    for i in range(100):
        print(label[i].tolist(), file = f2)
        T = 999;
        total_spike = encoder(img).float()
        for t in range(T):
            total_spike += encoder(img).float()
        for j in range(32):
            for k in range(32): 
                print(int(total_spike[i][0][j][k].tolist()), file = f1, end = ' ')
        print("", file = f1)
f1.close()
f2.close()

100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [2:26:05<00:00, 87.65s/it]


## For Fashion-MNIST

In [11]:
test_dataset = torchvision.datasets.FashionMNIST(
        root="../Datasets/",
        train=False,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                (0.1307,), (0.3081,))
        ])
    )
test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=100,
        shuffle=False
    )

In [12]:
f1 = open('../Datasets/FashionMNIST/input.txt', 'w')
f2 = open('../Datasets/FashionMNIST/output.txt', 'w')

encoder = encoding.PoissonEncoder();

for img, label in tqdm(test_data_loader):
    for i in range(100):
        print(label[i].tolist(), file = f2)
        T = 999;
        total_spike = encoder(img).float()
        for t in range(T):
            total_spike += encoder(img).float()
        for j in range(28):
            for k in range(28): 
                print(int(total_spike[i][0][j][k].tolist()), file = f1, end = ' ')
        print("", file = f1)
f1.close()
f2.close()

100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [1:54:51<00:00, 68.91s/it]


## For SVHN

In [13]:
test_dataset = torchvision.datasets.SVHN(
        root="../Datasets/",
        split="test",
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    )
test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=100,
        shuffle=False
    )

Using downloaded and verified file: ../Datasets/test_32x32.mat


In [16]:
f1 = open('../Datasets/SVHN/input.txt', 'w')
f2 = open('../Datasets/SVHN/output.txt', 'w')

cnt = 0;
for img, label in tqdm(test_data_loader):
    cnt += 1
    rg = 100
    if (cnt == 261):
        rg = 32
    for i in range(rg):
        print(label[j].tolist(), file = f2)
        T = 99;
        total_spike = encoder(img).float()
        for t in range(T):
            total_spike += encoder(img).float()
        for l in range(3):
            for j in range(32):
                for k in range(32):
                    print(int(total_spike[i][l][j][k].tolist() * 10), file = f1, end = ' ')
    
        print("", file = f1)

f1.close()
f2.close()

100%|███████████████████████████████████████████████████████████████████████████████| 261/261 [1:36:59<00:00, 22.30s/it]


##  For CIFAR-10

In [17]:
test_dataset = torchvision.datasets.CIFAR10(
        root="../Datasets/",
        train=False,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    )
test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=100,
        shuffle=False
    )

Files already downloaded and verified


In [18]:
f1 = open('../Datasets/CIFAR10/input.txt', 'w')
f2 = open('../Datasets/CIFAR10/output.txt', 'w')

for img, label in tqdm(test_data_loader):
    for i in range(100):
        print(label[j].tolist(), file = f2)
        T = 99;
        total_spike = encoder(img).float()
        for t in range(T):
            total_spike += encoder(img).float()
        for l in range(3):
            for j in range(32):
                for k in range(32):
                    print(int(total_spike[i][l][j][k].tolist() * 10), file = f1, end = ' ')
    
        print("", file = f1)

f1.close()
f2.close()

100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [36:44<00:00, 22.04s/it]


## For EuroSAT

In [19]:
all_dataset = torchvision.datasets.EuroSAT(
        root="../Datasets/",
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    )

train_dataset, test_dataset = torch.utils.data.random_split(all_dataset, [21600, 5400])

test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=100,
        shuffle=False
)

In [20]:
f1 = open('../Datasets/EuroSAT_64/input.txt', 'w')
f2 = open('../Datasets/EuroSAT_64/output.txt', 'w')

for img, label in tqdm(test_data_loader):
    for i in range(100):
        print(label[j].tolist(), file = f2)
        T = 99;
        total_spike = encoder(img).float()
        for t in range(T):
            total_spike += encoder(img).float()
        for l in range(3):
            for j in range(64):
                for k in range(64):
                    print(int(total_spike[i][l][j][k].tolist() * 10), file = f1, end = ' ')
    
        print("", file = f1)

f1.close()
f2.close()

100%|█████████████████████████████████████████████████████████████████████████████████| 54/54 [1:24:48<00:00, 94.23s/it]


# Extract Weights of Model

## Fashion-MNIST

In [21]:
# from lif_snn_fashion_mnist.ckpt 784 * 196 * 10
model = torch.load("../Models/lif_snn_fashion_mnist.ckpt")

w_12 = model[1].weight.detach().cpu().numpy()
w_23 = model[3].weight.detach().cpu().numpy()

f = open('../Models/FashionMNIST/weights.txt', 'w')

tmp = 784
tmp_1 = 0
for i in range(784):
    print(i + tmp_1, file = f, end = " ")
    for j in range(196):
        print(j + tmp, file = f, end = " ");
        print(w_12[j][i], file = f, end = " ");
    print("", file = f)
        
tmp = 784 + 196
tmp_1 = 784
for i in range(196):
    print(i + tmp_1, file = f, end = " ")
    for j in range(10):
        print(j + tmp, file = f, end = " ");
        print(w_23[j][i], file = f, end = " ");
    print("", file = f)
for i in range(10):
    print(tmp + i, file = f)

f.close()

## SVHN

In [22]:
# from lif_snn_SVHN.ckpt 3072 * 192 * 10
model = torch.load("../Models/lif_snn_SVHN.ckpt")

w_12 = model[1].weight.detach().cpu().numpy()
w_23 = model[3].weight.detach().cpu().numpy()

f = open('../Models/SVHN/weights.txt', 'w')

tmp = 3072
tmp_1 = 0
for i in range(3072):
    print(i + tmp_1, file = f, end = " ")
    for j in range(192):
        print(j + tmp, file = f, end = " ");
        print(w_12[j][i], file = f, end = " ");
    print("", file = f)
        
tmp = 3072 + 192
tmp_1 = 3072
for i in range(192):
    print(i + tmp_1, file = f, end = " ")
    for j in range(10):
        print(j + tmp, file = f, end = " ");
        print(w_23[j][i], file = f, end = " ");
    print("", file = f)
for i in range(10):
    print(tmp + i, file = f)

f.close()

## CIFAR-10

In [2]:
# from lif_snn_cifar10.ckpt 3072 * 384 * 10
model = torch.load("../Models/lif_snn_cifar10.ckpt")

w_12 = model[1].weight.detach().cpu().numpy()
w_23 = model[3].weight.detach().cpu().numpy()

f = open('../Models/CIFAR10/weights.txt', 'w')

tmp = 3072
tmp_1 = 0
for i in range(3072):
    print(i + tmp_1, file = f, end = " ")
    for j in range(384):
        print(j + tmp, file = f, end = " ");
        print(w_12[j][i], file = f, end = " ");
    print("", file = f)
        
tmp = 3072 + 384
tmp_1 = 3072
for i in range(384):
    print(i + tmp_1, file = f, end = " ")
    for j in range(10):
        print(j + tmp, file = f, end = " ");
        print(w_23[j][i], file = f, end = " ");
    print("", file = f)
for i in range(10):
    print(tmp + i, file = f)

f.close()

## EuroSAT

In [3]:
# from lif_snn_EuroSAT.ckpt 12288 * 128 * 10
model = torch.load("../Models/lif_snn_EuroSAT.ckpt")

w_12 = model[1].weight.detach().cpu().numpy()
w_23 = model[3].weight.detach().cpu().numpy()

f = open('../Models/EuroSAT/weights.txt', 'w')

tmp = 12288
tmp_1 = 0
for i in range(12288):
    print(i + tmp_1, file = f, end = " ")
    for j in range(128):
        print(j + tmp, file = f, end = " ");
        print(w_12[j][i], file = f, end = " ");
    print("", file = f)
        
tmp = 12288 + 128
tmp_1 = 12288
for i in range(128):
    print(i + tmp_1, file = f, end = " ")
    for j in range(10):
        print(j + tmp, file = f, end = " ");
        print(w_23[j][i], file = f, end = " ");
    print("", file = f)
for i in range(10):
    print(tmp + i, file = f)

f.close()

# Create Dummy Networks

## 5 layers

In [1]:
# 1024 * 512 * 256 * 128 * 64
f = open('../Models/Dummy_5/weights.txt', 'w')

tmp = 1024
tmp_1 = 0
for i in range(1024):
    print(i + tmp_1, file = f, end = " ")
    for j in range(512):
        print(j + tmp, file = f, end = " ");
        print(0, file = f, end = " ");
    print("", file = f)
        
tmp = 1024 + 512
tmp_1 = 1024
for i in range(512):
    print(i + tmp_1, file = f, end = " ")
    for j in range(256):
        print(j + tmp, file = f, end = " ");
        print(0, file = f, end = " ");
    print("", file = f)
    
tmp = 1024 + 512 + 256
tmp_1 = 1024 + 512
for i in range(256):
    print(i + tmp_1, file = f, end = " ")
    for j in range(128):
        print(j + tmp, file = f, end = " ");
        print(0, file = f, end = " ");
    print("", file = f)
    
tmp = 1024 + 512 + 256 + 128
tmp_1 = 1024 + 512 + 256
for i in range(128):
    print(i + tmp_1, file = f, end = " ")
    for j in range(64):
        print(j + tmp, file = f, end = " ");
        print(0, file = f, end = " ");
    print("", file = f)
    
for i in range(64):
    print(tmp + i, file = f)
f.close()

## 4 layers

In [2]:
# 1024 * 512 * 256 * 64
f = open('../Models/Dummy_4/weights.txt', 'w')

tmp = 1024
tmp_1 = 0
for i in range(1024):
    print(i + tmp_1, file = f, end = " ")
    for j in range(512):
        print(j + tmp, file = f, end = " ");
        print(0, file = f, end = " ");
    print("", file = f)
        
tmp = 1024 + 512
tmp_1 = 1024
for i in range(512):
    print(i + tmp_1, file = f, end = " ")
    for j in range(256):
        print(j + tmp, file = f, end = " ");
        print(0, file = f, end = " ");
    print("", file = f)
    
tmp = 1024 + 512 + 256
tmp_1 = 1024 + 512
for i in range(256):
    print(i + tmp_1, file = f, end = " ")
    for j in range(64):
        print(j + tmp, file = f, end = " ");
        print(0, file = f, end = " ");
    print("", file = f)
    
for i in range(64):
    print(tmp + i, file = f)
f.close()

## 3 layers

In [1]:
# 1024 * 256 * 64
f = open('../Models/Dummy_3/weights.txt', 'w')

tmp = 1024
tmp_1 = 0
for i in range(1024):
    print(i + tmp_1, file = f, end = " ")
    for j in range(256):
        print(j + tmp, file = f, end = " ");
        print(0, file = f, end = " ");
    print("", file = f)
        
tmp = 1024 + 256
tmp_1 = 1024
for i in range(256):
    print(i + tmp_1, file = f, end = " ")
    for j in range(64):
        print(j + tmp, file = f, end = " ");
        print(0, file = f, end = " ");
    print("", file = f)
    
for i in range(64):
    print(tmp + i, file = f)
f.close()

## 2 layers

In [4]:
# 1024 * 64
f = open('../Models/Dummy_2/weights.txt', 'w')

tmp = 1024
tmp_1 = 0
for i in range(1024):
    print(i + tmp_1, file = f, end = " ")
    for j in range(64):
        print(j + tmp, file = f, end = " ");
        print(0, file = f, end = " ");
    print("", file = f)
    
for i in range(64):
    print(tmp + i, file = f)
f.close()