In [1]:
# Clone the repository
!git clone https://github.com/miladmozafari/SpykeTorch

Cloning into 'SpykeTorch'...
remote: Enumerating objects: 1026, done.[K
remote: Counting objects: 100% (28/28), done.[K
remote: Compressing objects: 100% (19/19), done.[K
remote: Total 1026 (delta 9), reused 24 (delta 7), pack-reused 998[K
Receiving objects: 100% (1026/1026), 6.11 MiB | 11.45 MiB/s, done.
Resolving deltas: 100% (91/91), done.


In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import torch
import torch.nn as nn
import pylab
import torchvision
import random
import skimage.io
import glob
import shutil
from tqdm import tqdm
from PIL import Image
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.nn.parameter import Parameter
from scipy import signal
from scipy.io import wavfile
from pathlib import Path
from torchvision import transforms
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

In [3]:
from SpykeTorch.SpykeTorch import snn
from SpykeTorch.SpykeTorch import functional as sf
from SpykeTorch.SpykeTorch import visualization as vis
from SpykeTorch.SpykeTorch import utils

In [4]:
use_cuda = True

!nvidia-smi

Tue Sep 12 17:53:04 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P8     9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [5]:
class CIFAR10SpykeTorch(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = snn.Convolution(4, 32, 5, 0.8, 0.05)
        self.conv1_t = 15
        self.k1 = 5
        self.r1 = 3

        self.conv2 = snn.Convolution(32, 256, 3, 0.8, 0.05)
        self.conv2_t = 10
        self.k2 = 8
        self.r2 = 1

        self.conv3 = snn.Convolution(256, 200, 5, 0.8, 0.05)

        self.stdp1 = snn.STDP(self.conv1, (0.002, -0.001))
        self.stdp2 = snn.STDP(self.conv2, (0.002, -0.001))
        self.stdp3 = snn.STDP(self.conv3, (0.002, -0.001), False, 0.2, 0.8)
        self.anti_stdp3 = snn.STDP(self.conv3, (-0.002, 0.0003), False, 0.2, 0.8)
        self.max_ap = Parameter(torch.Tensor([0.05]))

        self.decision_map = []
        for i in range(10):
            self.decision_map.extend([i]*20)


        self.ctx = {"input_spikes":None, "potentials":None, "output_spikes":None, "winners":None}
        self.spk_cnt1 = 0
        self.spk_cnt2 = 0

    def forward(self, input, max_layer):
        input = sf.pad(input.float(), (2,2,2,2), 0)
        if self.training:
            pot = self.conv1(input)
            spk, pot = sf.fire(pot, self.conv1_t, True)
            if max_layer == 1:
                self.spk_cnt1 += 1
                if self.spk_cnt1 >= 500:
                    self.spk_cnt1 = 0
                    ap = torch.tensor(self.stdp1.learning_rate[0][0].item(), device=self.stdp1.learning_rate[0][0].device) * 2
                    ap = torch.min(ap, self.max_ap)
                    an = ap * -0.75
                    self.stdp1.update_all_learning_rate(ap.item(), an.item())
                pot = sf.pointwise_inhibition(pot)
                spk = pot.sign()
                winners = sf.get_k_winners(pot, self.k1, self.r1, spk)
                self.ctx["input_spikes"] = input
                self.ctx["potentials"] = pot
                self.ctx["output_spikes"] = spk
                self.ctx["winners"] = winners
                return spk, pot
            spk_in = sf.pad(sf.pooling(spk, 2, 2), (1,1,1,1))
            pot = self.conv2(spk_in)
            spk, pot = sf.fire(pot, self.conv2_t, True)
            if max_layer == 2:
                self.spk_cnt2 += 1
                if self.spk_cnt2 >= 500:
                    self.spk_cnt2 = 0
                    ap = torch.tensor(self.stdp2.learning_rate[0][0].item(), device=self.stdp2.learning_rate[0][0].device) * 2
                    ap = torch.min(ap, self.max_ap)
                    an = ap * -0.75
                    self.stdp2.update_all_learning_rate(ap.item(), an.item())
                pot = sf.pointwise_inhibition(pot)
                spk = pot.sign()
                winners = sf.get_k_winners(pot, self.k2, self.r2, spk)
                self.ctx["input_spikes"] = spk_in
                self.ctx["potentials"] = pot
                self.ctx["output_spikes"] = spk
                self.ctx["winners"] = winners
                return spk, pot
            spk_in = sf.pad(sf.pooling(spk, 3, 3), (2,2,2,2))
            pot = self.conv3(spk_in)
            spk = sf.fire(pot)
            winners = sf.get_k_winners(pot, 1, 0, spk)
            self.ctx["input_spikes"] = spk_in
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            output = -1
            if len(winners) != 0:
                output = self.decision_map[winners[0][0]]
            return output
        else:
            pot = self.conv1(input)
            spk, pot = sf.fire(pot, self.conv1_t, True)
            if max_layer == 1:
                return spk, pot
            pot = self.conv2(sf.pad(sf.pooling(spk, 2, 2), (1,1,1,1)))
            spk, pot = sf.fire(pot, self.conv2_t, True)
            if max_layer == 2:
                return spk, pot
            pot = self.conv3(sf.pad(sf.pooling(spk, 3, 3), (2,2,2,2)))
            spk = sf.fire(pot)
            winners = sf.get_k_winners(pot, 1, 0, spk)
            output = -1
            if len(winners) != 0:
                output = self.decision_map[winners[0][0]]
            return output

    def stdp(self, layer_idx):
        if layer_idx == 1:
            self.stdp1(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])
        if layer_idx == 2:
            self.stdp2(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])

    def update_learning_rates(self, stdp_ap, stdp_an, anti_stdp_ap, anti_stdp_an):
        self.stdp3.update_all_learning_rate(stdp_ap, stdp_an)
        self.anti_stdp3.update_all_learning_rate(anti_stdp_an, anti_stdp_ap)

    def reward(self):
        self.stdp3(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])

    def punish(self):
        self.anti_stdp3(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])

In [6]:
def train_unsupervise(network, data, layer_idx):
    network.train()
    for i in range(len(data)):
        data_in = data[i]
        if use_cuda:
            data_in = data_in.cuda()
        network(data_in, layer_idx)
        network.stdp(layer_idx)

In [7]:
def train_rl(network, data, target):
    network.train()
    perf = np.array([0,0,0]) # correct, wrong, silence
    for i in range(len(data)):
        data_in = data[i]
        target_in = target[i]
        if use_cuda:
            data_in = data_in.cuda()
            target_in = target_in.cuda()
        d = network(data_in, 3)
        if d != -1:
            if d == target_in:
                perf[0]+=1
                network.reward()
            else:
                perf[1]+=1
                network.punish()
        else:
            perf[2]+=1
    return perf/len(data)

In [8]:
def test(network, data, target):
    network.eval()
    perf = np.array([0,0,0]) # correct, wrong, silence
    for i in range(len(data)):
        data_in = data[i]
        target_in = target[i]
        if use_cuda:
            data_in = data_in.cuda()
            target_in = target_in.cuda()
        d = network(data_in, 3)
        if d != -1:
            if d == target_in:
                perf[0]+=1
            else:
                perf[1]+=1
        else:
            perf[2]+=1
    return perf/len(data)

In [9]:
class S1C1Transform:
    def __init__(self, filter, timesteps = 15):
        self.gray_scale = transforms.Grayscale()
        self.to_tensor = transforms.ToTensor()
        self.filter = filter
        self.temporal_transform = utils.Intensity2Latency(timesteps)
        self.cnt = 0

    def __call__(self, image):
        if self.cnt % 1000 == 0:
            print(self.cnt)
        self.cnt+=1
        image = self.to_tensor(image) * 255
        image = self.gray_scale(image)
        image.unsqueeze_(0)
        image = self.filter(image)
        image = sf.local_normalization(image, 8)
        temporal_image = self.temporal_transform(image)
        return temporal_image.sign().byte()

In [10]:
kernels = [
            utils.DoGKernel(3,6/9,3/9),
            utils.DoGKernel(3,12/9,3/9),
            utils.DoGKernel(2,4/9,2/9),
            utils.DoGKernel(2,8/9,2/9),
            ]

In [11]:
filter = utils.Filter(kernels, padding = 6, thresholds = 15)
s1c1 = S1C1Transform(filter)

In [12]:
data_root = "data"
CIFAR_train = utils.CacheDataset(torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform = s1c1))
CIFAR_test = utils.CacheDataset(torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform = s1c1))
CIFAR_loader = DataLoader(CIFAR_train, batch_size=1000, shuffle=False)
CIFAR_testLoader = DataLoader(CIFAR_test, batch_size=len(CIFAR_test), shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 43573515.77it/s]


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


In [13]:
net = CIFAR10SpykeTorch()

if use_cuda:
    net.cuda()

In [14]:
# Training The First Layer
print("Training the first layer")
if os.path.isfile("/content/saved_l1.net"):
    net.load_state_dict(torch.load("/content/saved_l1.net"))

Training the first layer


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000


 10%|█         | 1/10 [04:00<36:04, 240.47s/it]

Epoch 1


 20%|██        | 2/10 [05:49<21:45, 163.21s/it]

Epoch 2


 30%|███       | 3/10 [07:38<16:09, 138.49s/it]

Epoch 3


 40%|████      | 4/10 [09:27<12:41, 126.96s/it]

Epoch 4


 50%|█████     | 5/10 [11:16<10:01, 120.34s/it]

Epoch 5


 60%|██████    | 6/10 [13:05<07:45, 116.37s/it]

Epoch 6


 70%|███████   | 7/10 [14:53<05:41, 113.78s/it]

Epoch 7


 80%|████████  | 8/10 [16:41<03:43, 111.95s/it]

Epoch 8


 90%|█████████ | 9/10 [18:31<01:51, 111.42s/it]

Epoch 9


100%|██████████| 10/10 [20:20<00:00, 122.01s/it]


In [27]:
for epoch in tqdm(range(15)):
    print("Epoch", epoch)
    for data,targets in CIFAR_loader:
        train_unsupervise(net, data, 1)
torch.save(net.state_dict(), "saved_l1.net")

  0%|          | 0/15 [00:00<?, ?it/s]

Epoch 0


  7%|▋         | 1/15 [01:43<24:13, 103.80s/it]

Epoch 1


 13%|█▎        | 2/15 [03:27<22:28, 103.75s/it]

Epoch 2


 20%|██        | 3/15 [05:12<20:50, 104.23s/it]

Epoch 3


 27%|██▋       | 4/15 [06:57<19:08, 104.42s/it]

Epoch 4


 33%|███▎      | 5/15 [08:41<17:23, 104.30s/it]

Epoch 5


 40%|████      | 6/15 [10:24<15:35, 103.97s/it]

Epoch 6


 47%|████▋     | 7/15 [12:07<13:49, 103.72s/it]

Epoch 7


 53%|█████▎    | 8/15 [13:50<12:03, 103.38s/it]

Epoch 8


 60%|██████    | 9/15 [15:34<10:21, 103.57s/it]

Epoch 9


 67%|██████▋   | 10/15 [17:17<08:37, 103.57s/it]

Epoch 10


 73%|███████▎  | 11/15 [19:01<06:54, 103.60s/it]

Epoch 11


 80%|████████  | 12/15 [20:45<05:10, 103.61s/it]

Epoch 12


 87%|████████▋ | 13/15 [22:27<03:26, 103.26s/it]

Epoch 13


 93%|█████████▎| 14/15 [24:10<01:43, 103.23s/it]

Epoch 14


100%|██████████| 15/15 [25:54<00:00, 103.66s/it]


In [15]:
# Training The Second Layer
print("Training the second layer")
if os.path.isfile("/content/saved_l2.net"):
    net.load_state_dict(torch.load("/content/saved_l2.net"))

Training the second layer


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0


 10%|█         | 1/10 [02:56<26:30, 176.74s/it]

Epoch 1


 20%|██        | 2/10 [05:53<23:34, 176.86s/it]

Epoch 2


 30%|███       | 3/10 [08:50<20:36, 176.62s/it]

Epoch 3


 40%|████      | 4/10 [11:47<17:41, 176.88s/it]

Epoch 4


 50%|█████     | 5/10 [14:43<14:43, 176.77s/it]

Epoch 5


 60%|██████    | 6/10 [17:39<11:46, 176.54s/it]

Epoch 6


 70%|███████   | 7/10 [20:38<08:51, 177.30s/it]

Epoch 7


 80%|████████  | 8/10 [23:36<05:54, 177.42s/it]

Epoch 8


 90%|█████████ | 9/10 [26:32<02:57, 177.09s/it]

Epoch 9


100%|██████████| 10/10 [29:29<00:00, 176.94s/it]


In [30]:
for epoch in tqdm(range(10)):
    print("Epoch", epoch)
    for data,targets in CIFAR_loader:
        train_unsupervise(net, data, 2)
torch.save(net.state_dict(), "saved_l2.net")

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0


 10%|█         | 1/10 [02:57<26:38, 177.66s/it]

Epoch 1


 20%|██        | 2/10 [05:51<23:21, 175.18s/it]

Epoch 2


 30%|███       | 3/10 [08:44<20:19, 174.21s/it]

Epoch 3


 40%|████      | 4/10 [11:37<17:23, 173.92s/it]

Epoch 4


 50%|█████     | 5/10 [14:31<14:29, 173.95s/it]

Epoch 5


 60%|██████    | 6/10 [17:25<11:35, 173.77s/it]

Epoch 6


 70%|███████   | 7/10 [20:18<08:40, 173.58s/it]

Epoch 7


 80%|████████  | 8/10 [23:09<05:45, 172.77s/it]

Epoch 8


 90%|█████████ | 9/10 [26:01<02:52, 172.75s/it]

Epoch 9


100%|██████████| 10/10 [28:52<00:00, 173.26s/it]


In [16]:
# initial adaptive learning rates
apr = net.stdp3.learning_rate[0][0].item()
anr = net.stdp3.learning_rate[0][1].item()
app = net.anti_stdp3.learning_rate[0][1].item()
anp = net.anti_stdp3.learning_rate[0][0].item()

In [17]:
adaptive_min = 0.0
adaptive_int = 1.0

In [18]:
apr_adapt = ((1.0 - 1.0 / 10) * adaptive_int + adaptive_min) * apr
anr_adapt = ((1.0 - 1.0 / 10) * adaptive_int + adaptive_min) * anr
app_adapt = ((1.0 / 10) * adaptive_int + adaptive_min) * app
anp_adapt = ((1.0 / 10) * adaptive_int + adaptive_min) * anp

In [19]:
# perf
best_train = np.array([0.0,0.0,0.0,0.0]) # correct, wrong, sile, epoch
best_test = np.array([0.0,0.0,0.0,0.0]) # correct, wrong, silence, epoch

In [20]:
if os.path.isfile("/content/saved.net"):
    net.load_state_dict(torch.load("/content/saved.net"))

In [35]:
# Training The Third Layer
print("Training the third layer")
for epoch in tqdm(range(3)):
    print("Epoch #:", epoch)
    perf_train = np.array([0.0,0.0,0.0])
    perf_test = np.array([0.0,0.0,0.0])
    for data,targets in CIFAR_loader:
        perf_train_batch = train_rl(net, data, targets)
#         print(perf_train_batch)
        # update adaptive learning rates
        apr_adapt = apr * (perf_train_batch[1] * adaptive_int + adaptive_min)
        anr_adapt = anr * (perf_train_batch[1] * adaptive_int + adaptive_min)
        app_adapt = app * (perf_train_batch[0] * adaptive_int + adaptive_min)
        anp_adapt = anp * (perf_train_batch[0] * adaptive_int + adaptive_min)
        net.update_learning_rates(apr_adapt, anr_adapt, app_adapt, anp_adapt)
        perf_train += perf_train_batch
    perf_train /= len(CIFAR_loader)
    if best_train[0] <= perf_train[0]:
        best_train = np.append(perf_train, epoch)
    print("Current Train:", perf_train)
    print("   Best Train:", best_train)

    for data,targets in CIFAR_testLoader:
        perf_test += test(net, data, targets)
    perf_test /= len(CIFAR_testLoader)
    if best_test[0] <= perf_test[0]:
        best_test = np.append(perf_test, epoch)
        torch.save(net.state_dict(), "saved.net")
    print(" Current Test:", perf_test)
#     print("    Best Test:", best_test)

Training the third layer


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch #: 0
Current Train: [0.50908 0.49092 0.     ]
   Best Train: [0.50938 0.49062 0.      7.     ]


 33%|███▎      | 1/3 [02:22<04:44, 142.03s/it]

 Current Test: [0.3838 0.6162 0.    ]
Epoch #: 1
Current Train: [0.5096 0.4904 0.    ]
   Best Train: [0.5096 0.4904 0.     1.    ]


 67%|██████▋   | 2/3 [04:38<02:18, 138.82s/it]

 Current Test: [0.3866 0.6134 0.    ]
Epoch #: 2
Current Train: [0.51024 0.48976 0.     ]
   Best Train: [0.51024 0.48976 0.      2.     ]


100%|██████████| 3/3 [06:54<00:00, 138.15s/it]

 Current Test: [0.3863 0.6137 0.    ]



