In [1]:
# Clone the repository
!git clone https://github.com/miladmozafari/SpykeTorch

Cloning into 'SpykeTorch'...
remote: Enumerating objects: 1026, done.[K
remote: Counting objects: 100% (28/28), done.[K
remote: Compressing objects: 100% (19/19), done.[K
remote: Total 1026 (delta 9), reused 24 (delta 7), pack-reused 998[K
Receiving objects: 100% (1026/1026), 6.11 MiB | 11.91 MiB/s, done.
Resolving deltas: 100% (91/91), done.


In [2]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import wave
import pylab
import random
import torchvision
import glob
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.nn.parameter import Parameter
from scipy import signal
from scipy.io import wavfile
from pathlib import Path
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
from pathlib import Path
from sklearn import datasets
from SpykeTorch.SpykeTorch import snn
from SpykeTorch.SpykeTorch import functional as sf
from SpykeTorch.SpykeTorch import visualization as vis
from SpykeTorch.SpykeTorch import utils



In [4]:
use_cuda = True

In [5]:
!nvidia-smi

Mon Aug 28 08:37:30 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
class MNISTSpykeTorch(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = snn.Convolution(6, 30, 5, 0.8, 0.05)
        self.conv1_t = 15
        self.k1 = 5
        self.r1 = 3
        
        self.conv2 = snn.Convolution(30, 250, 3, 0.8, 0.05)
        self.conv2_t = 10
        self.k2 = 8
        self.r2 = 1
        
        self.conv3 = snn.Convolution(250, 200, 5, 0.8, 0.05)
        
        self.stdp1 = snn.STDP(self.conv1, (0.004, -0.003))
        self.stdp2 = snn.STDP(self.conv2, (0.004, -0.003))
        self.stdp3 = snn.STDP(self.conv3, (0.004, -0.003), False, 0.2, 0.8)
        self.anti_stdp3 = snn.STDP(self.conv3, (-0.004, 0.0005), False, 0.2, 0.8)
        self.max_ap = Parameter(torch.Tensor([0.15]))
        
        self.decision_map = []
        for i in range(10):
            self.decision_map.extend([i]*20)

            
        self.ctx = {"input_spikes":None, "potentials":None, "output_spikes":None, "winners":None}
        self.spk_cnt1 = 0
        self.spk_cnt2 = 0

    def forward(self, input, max_layer):
        input = sf.pad(input.float(), (2,2,2,2), 0)
        if self.training:
            pot = self.conv1(input)
            spk, pot = sf.fire(pot, self.conv1_t, True)
            if max_layer == 1:
                self.spk_cnt1 += 1
                if self.spk_cnt1 >= 500:
                    self.spk_cnt1 = 0
                    ap = torch.tensor(self.stdp1.learning_rate[0][0].item(), device=self.stdp1.learning_rate[0][0].device) * 2
                    ap = torch.min(ap, self.max_ap)
                    an = ap * -0.75
                    self.stdp1.update_all_learning_rate(ap.item(), an.item())
                pot = sf.pointwise_inhibition(pot)
                spk = pot.sign()
                winners = sf.get_k_winners(pot, self.k1, self.r1, spk)
                self.ctx["input_spikes"] = input
                self.ctx["potentials"] = pot
                self.ctx["output_spikes"] = spk
                self.ctx["winners"] = winners
                return spk, pot
            spk_in = sf.pad(sf.pooling(spk, 2, 2), (1,1,1,1))
            pot = self.conv2(spk_in)
            spk, pot = sf.fire(pot, self.conv2_t, True)
            if max_layer == 2:
                self.spk_cnt2 += 1
                if self.spk_cnt2 >= 500:
                    self.spk_cnt2 = 0
                    ap = torch.tensor(self.stdp2.learning_rate[0][0].item(), device=self.stdp2.learning_rate[0][0].device) * 2
                    ap = torch.min(ap, self.max_ap)
                    an = ap * -0.75
                    self.stdp2.update_all_learning_rate(ap.item(), an.item())
                pot = sf.pointwise_inhibition(pot)
                spk = pot.sign()
                winners = sf.get_k_winners(pot, self.k2, self.r2, spk)
                self.ctx["input_spikes"] = spk_in
                self.ctx["potentials"] = pot
                self.ctx["output_spikes"] = spk
                self.ctx["winners"] = winners
                return spk, pot
            spk_in = sf.pad(sf.pooling(spk, 3, 3), (2,2,2,2))
            pot = self.conv3(spk_in)
            spk = sf.fire(pot)
            winners = sf.get_k_winners(pot, 1, 0, spk)
            self.ctx["input_spikes"] = spk_in
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            output = -1
            if len(winners) != 0:
                output = self.decision_map[winners[0][0]]
            return output
        else:
            pot = self.conv1(input)
            spk, pot = sf.fire(pot, self.conv1_t, True)
            if max_layer == 1:
                return spk, pot
            pot = self.conv2(sf.pad(sf.pooling(spk, 2, 2), (1,1,1,1)))
            spk, pot = sf.fire(pot, self.conv2_t, True)
            if max_layer == 2:
                return spk, pot
            pot = self.conv3(sf.pad(sf.pooling(spk, 3, 3), (2,2,2,2)))
            spk = sf.fire(pot)
            winners = sf.get_k_winners(pot, 1, 0, spk)
            output = -1
            if len(winners) != 0:
                output = self.decision_map[winners[0][0]]
            return output
        
    def stdp(self, layer_idx):
        if layer_idx == 1:
            self.stdp1(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])
        if layer_idx == 2:
            self.stdp2(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])
            
    def update_learning_rates(self, stdp_ap, stdp_an, anti_stdp_ap, anti_stdp_an):
        self.stdp3.update_all_learning_rate(stdp_ap, stdp_an)
        self.anti_stdp3.update_all_learning_rate(anti_stdp_an, anti_stdp_ap)
        
    def reward(self):
        self.stdp3(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])

    def punish(self):
        self.anti_stdp3(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])

In [7]:
def train_unsupervise(network, data, layer_idx):
    network.train()
    for i in range(len(data)):
        data_in = data[i]
        if use_cuda:
            data_in = data_in.cuda()
        network(data_in, layer_idx)
        network.stdp(layer_idx)

In [8]:
def train_rl(network, data, target):
    network.train()
    perf = np.array([0,0,0]) # correct, wrong, silence
    for i in range(len(data)):
        data_in = data[i]
        target_in = target[i]
        if use_cuda:
            data_in = data_in.cuda()
            target_in = target_in.cuda()
        d = network(data_in, 3)
        if d != -1:
            if d == target_in:
                perf[0]+=1
                network.reward()
            else:
                perf[1]+=1
                network.punish()
        else:
            perf[2]+=1
    return perf/len(data)

In [9]:
def test(network, data, target):
    network.eval()
    perf = np.array([0,0,0]) # correct, wrong, silence
    for i in range(len(data)):
        data_in = data[i]
        target_in = target[i]
        if use_cuda:
            data_in = data_in.cuda()
            target_in = target_in.cuda()
        d = network(data_in, 3)
        if d != -1:
            if d == target_in:
                perf[0]+=1
            else:
                perf[1]+=1
        else:
            perf[2]+=1
    return perf/len(data)

In [10]:
class S1C1Transform:
    def __init__(self, filter, timesteps = 15):
        self.resize = transforms.Resize((64, 64)),
        self.gray_scale = transforms.Grayscale()
        self.to_tensor = transforms.ToTensor()
        self.filter = filter
        self.temporal_transform = utils.Intensity2Latency(timesteps)
        self.cnt = 0
        
    def __call__(self, image):
        if self.cnt % 1000 == 0:
            print(self.cnt)
        self.cnt+=1
        image = self.to_tensor(image) * 255
        image = self.gray_scale(image)
        image.unsqueeze_(0)
        image = self.filter(image)
        image = sf.local_normalization(image, 8)
        temporal_image = self.temporal_transform(image)
        return temporal_image.sign().byte()

In [11]:
kernels = [ utils.DoGKernel(3,3/9,6/9),
            utils.DoGKernel(3,6/9,3/9),
            utils.DoGKernel(7,7/9,14/9),
            utils.DoGKernel(7,14/9,7/9),
            utils.DoGKernel(13,13/9,26/9),
            utils.DoGKernel(13,26/9,13/9)]

In [12]:
filter = utils.Filter(kernels, padding = 6, thresholds = 50)
s1c1 = S1C1Transform(filter)

In [13]:
data_root = "data"
MNIST_train = utils.CacheDataset(torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform = s1c1))
MNIST_test = utils.CacheDataset(torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform = s1c1))
MNIST_loader = DataLoader(MNIST_train, batch_size=1000, shuffle=False)
MNIST_testLoader = DataLoader(MNIST_test, batch_size=len(MNIST_test), shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 169769538.96it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 65620635.87it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 25192485.11it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 11943905.18it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [14]:
net = MNISTSpykeTorch()

In [15]:
if use_cuda:
    net.cuda()

In [17]:
# Training The First Layer
print("Training the first layer")
if os.path.isfile("/kaggle/input/mnist_saved/saved_l1.net"):
    fsdd.load_state_dict(torch.load("/kaggle/input/mnist_saved/saved_l1.net"))
else:
    for epoch in range(2):
        print("Epoch", epoch)
        iter = 0
        for data,targets in MNIST_loader:
            print("Iteration", iter)
            train_unsupervise(net, data, 1)
            print("Done!")
            iter+=1
    torch.save(net.state_dict(), "saved_l1.net")

Training the first layer
Epoch 0
Iteration 0
Done!
2000
Iteration 1
Done!
3000
Iteration 2
Done!
4000
Iteration 3
Done!
5000
Iteration 4
Done!
6000
Iteration 5
Done!
7000
Iteration 6
Done!
8000
Iteration 7
Done!
9000
Iteration 8
Done!
10000
Iteration 9
Done!
11000
Iteration 10
Done!
12000
Iteration 11
Done!
13000
Iteration 12
Done!
14000
Iteration 13
Done!
15000
Iteration 14
Done!
16000
Iteration 15
Done!
17000
Iteration 16
Done!
18000
Iteration 17
Done!
19000
Iteration 18
Done!
20000
Iteration 19
Done!
21000
Iteration 20
Done!
22000
Iteration 21
Done!
23000
Iteration 22
Done!
24000
Iteration 23
Done!
25000
Iteration 24
Done!
26000
Iteration 25
Done!
27000
Iteration 26
Done!
28000
Iteration 27
Done!
29000
Iteration 28
Done!
30000
Iteration 29
Done!
31000
Iteration 30
Done!
32000
Iteration 31
Done!
33000
Iteration 32
Done!
34000
Iteration 33
Done!
35000
Iteration 34
Done!
36000
Iteration 35
Done!
37000
Iteration 36
Done!
38000
Iteration 37
Done!
39000
Iteration 38
Done!
40000
Iteration 

In [18]:
# Training The Second Layer
print("Training the second layer")
if os.path.isfile("/kaggle/input/mnist_saved/saved_l2.net"):
    net.load_state_dict(torch.load("/kaggle/input/mnist_saved/saved_l2.net"))
else:
    for epoch in range(4):
        print("Epoch", epoch)
        iter = 0
        for data,targets in MNIST_loader:
            print("Iteration", iter)
            train_unsupervise(net, data, 2)
            print("Done!")
            iter+=1
    torch.save(net.state_dict(), "saved_l2.net")

Training the second layer
Epoch 0
Iteration 0
Done!
Iteration 1
Done!
Iteration 2
Done!
Iteration 3
Done!
Iteration 4
Done!
Iteration 5
Done!
Iteration 6
Done!
Iteration 7
Done!
Iteration 8
Done!
Iteration 9
Done!
Iteration 10
Done!
Iteration 11
Done!
Iteration 12
Done!
Iteration 13
Done!
Iteration 14
Done!
Iteration 15
Done!
Iteration 16
Done!
Iteration 17
Done!
Iteration 18
Done!
Iteration 19
Done!
Iteration 20
Done!
Iteration 21
Done!
Iteration 22
Done!
Iteration 23
Done!
Iteration 24
Done!
Iteration 25
Done!
Iteration 26
Done!
Iteration 27
Done!
Iteration 28
Done!
Iteration 29
Done!
Iteration 30
Done!
Iteration 31
Done!
Iteration 32
Done!
Iteration 33
Done!
Iteration 34
Done!
Iteration 35
Done!
Iteration 36
Done!
Iteration 37
Done!
Iteration 38
Done!
Iteration 39
Done!
Iteration 40
Done!
Iteration 41
Done!
Iteration 42
Done!
Iteration 43
Done!
Iteration 44
Done!
Iteration 45
Done!
Iteration 46
Done!
Iteration 47
Done!
Iteration 48
Done!
Iteration 49
Done!
Iteration 50
Done!
Iterati

In [20]:
# initial adaptive learning rates
apr = net.stdp3.learning_rate[0][0].item()
anr = net.stdp3.learning_rate[0][1].item()
app = net.anti_stdp3.learning_rate[0][1].item()
anp = net.anti_stdp3.learning_rate[0][0].item()

adaptive_min = 0
adaptive_int = 1
apr_adapt = ((1.0 - 1.0 / 10) * adaptive_int + adaptive_min) * apr
anr_adapt = ((1.0 - 1.0 / 10) * adaptive_int + adaptive_min) * anr
app_adapt = ((1.0 / 10) * adaptive_int + adaptive_min) * app
anp_adapt = ((1.0 / 10) * adaptive_int + adaptive_min) * anp

# perf
best_train = np.array([0.0,0.0,0.0,0.0]) # correct, wrong, silence, epoch
best_test = np.array([0.0,0.0,0.0,0.0]) # correct, wrong, silence, epoch



Training the third layer
Epoch #: 0
Current Train: [0.8174 0.1826 0.    ]
   Best Train: [0.8174 0.1826 0.     0.    ]
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
 Current Test: [0.847 0.153 0.   ]
Epoch #: 1
Current Train: [0.84758333 0.15241667 0.        ]
   Best Train: [0.84758333 0.15241667 0.         1.        ]
 Current Test: [0.86 0.14 0.  ]
Epoch #: 2
Current Train: [0.86125 0.13875 0.     ]
   Best Train: [0.86125 0.13875 0.      2.     ]
 Current Test: [0.8659 0.1341 0.    ]
Epoch #: 3
Current Train: [0.87675 0.12325 0.     ]
   Best Train: [0.87675 0.12325 0.      3.     ]
 Current Test: [0.882 0.118 0.   ]
Epoch #: 4
Current Train: [0.88515 0.11485 0.     ]
   Best Train: [0.88515 0.11485 0.      4.     ]
 Current Test: [0.8922 0.1078 0.    ]
Epoch #: 5
Current Train: [0.89068333 0.10931667 0.        ]
   Best Train: [0.89068333 0.10931667 0.         5.        ]
 Current Test: [0.8924 0.1076 0.    ]
Epoch #: 6
Current Train: [0.89303333 0.10696667 0.       

In [None]:
if os.path.isfile("/kaggle/input/mnist_saved/saved.net"):
    net.load_state_dict(torch.load("/kaggle/input/mnist_saved/saved.net"))

In [21]:
# Training The Third Layer
print("Training the third layer")
for epoch in range(20):
    print("Epoch #:", epoch)
    perf_train = np.array([0.0,0.0,0.0])
    for data,targets in MNIST_loader:
        perf_train_batch = train_rl(net, data, targets)
        #update adaptive learning rates
        apr_adapt = apr * (perf_train_batch[1] * adaptive_int + adaptive_min)
        anr_adapt = anr * (perf_train_batch[1] * adaptive_int + adaptive_min)
        app_adapt = app * (perf_train_batch[0] * adaptive_int + adaptive_min)
        anp_adapt = anp * (perf_train_batch[0] * adaptive_int + adaptive_min)
        net.update_learning_rates(apr_adapt, anr_adapt, app_adapt, anp_adapt)
        perf_train += perf_train_batch
    perf_train /= len(MNIST_loader)
    if best_train[0] <= perf_train[0]:
        best_train = np.append(perf_train, epoch)
    print("Current Train:", perf_train)
    print("   Best Train:", best_train)

    for data,targets in MNIST_testLoader:
        perf_test = test(net, data, targets)
        if best_test[0] <= perf_test[0]:
            best_test = np.append(perf_test, epoch)
            torch.save(net.state_dict(), "saved.net")
    print(" Current Test:", perf_test)


Training the third layer
Epoch #: 0
Current Train: [0.9242 0.0758 0.    ]
   Best Train: [0.9242 0.0758 0.     0.    ]
 Current Test: [0.9212 0.0788 0.    ]
Epoch #: 1
Current Train: [0.92553333 0.07446667 0.        ]
   Best Train: [0.92553333 0.07446667 0.         1.        ]
 Current Test: [0.9207 0.0793 0.    ]
Epoch #: 2
Current Train: [0.92705 0.07295 0.     ]
   Best Train: [0.92705 0.07295 0.      2.     ]
 Current Test: [0.922 0.078 0.   ]
Epoch #: 3
Current Train: [0.92553333 0.07446667 0.        ]
   Best Train: [0.92705 0.07295 0.      2.     ]
 Current Test: [0.9243 0.0757 0.    ]
Epoch #: 4
Current Train: [0.92848333 0.07151667 0.        ]
   Best Train: [0.92848333 0.07151667 0.         4.        ]
 Current Test: [0.9271 0.0729 0.    ]
Epoch #: 5
Current Train: [0.92936667 0.07063333 0.        ]
   Best Train: [0.92936667 0.07063333 0.         5.        ]
 Current Test: [0.9263 0.0737 0.    ]
Epoch #: 6
Current Train: [0.93031667 0.06968333 0.        ]
   Best Train: [0.