In [3]:
import os
import random
import mne
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import scipy.signal as signal
%matplotlib qt

from sklearn.metrics import mean_squared_error, accuracy_score, precision_score
from scipy.stats import pearsonr
from sklearn.metrics.pairwise import cosine_similarity
from skimage.metrics import structural_similarity as ssim
from sklearn.model_selection import train_test_split

import torch
import snntorch as snn
from snntorch import spikegen, surrogate, utils
import snntorch.functional as SF
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

import logging
import warnings
warnings.filterwarnings("ignore")
logging.getLogger('mne').setLevel(logging.WARNING)

# 1

In [2]:
# dataloader arguments
batch_size = 128
data_path='/tmp/data/mnist'

dtype = torch.float
device = torch.device("cuda")

transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

for inputs, targets in train_loader:
    print(inputs.shape)
    print(targets.shape)
    break

torch.Size([128, 1, 28, 28])
torch.Size([128])


In [3]:
spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 0.5
num_steps = 20

In [4]:
def forward_pass(input_data, model, num_steps):
    spk_rec = []
    mem_rec = []
    utils.reset(model)

    for i in range(num_steps):
        spk, mem = model(input_data)
        spk_rec.append(spk)
        mem_rec.append(mem)

    return torch.stack(spk_rec), torch.stack(mem_rec)

def batch_acc(data_loader, model, num_steps):
    with torch.no_grad():
        model.eval()
        acc = 0
        total = 0
        for inputs, targets in iter(data_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
        
            spk_rec, mem_rec = forward_pass(inputs, model, num_steps)
            acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
            total += spk_rec.size(1)
    return acc / total    
    
class MyEEGSNNModel(nn.Module):
    def __init__(self, n_outputs):
        super().__init__()
        
        self.conv1 = nn.Conv2d(1, 12, 5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(12, 64, 5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc1 = nn.Linear(64*4*4, n_outputs)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        # self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(1, 5))
        # self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        # self.conv2 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=(1, 5))
        # self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        # self.fc1 = nn.Linear(64*14*157, n_outputs)
        # self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):

        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        cur1 = F.max_pool2d(self.conv1(x), kernel_size=(2, 2))
        # cur1 = F.max_pool2d(self.conv1(x), kernel_size=(1, 2))
        spk1, mem1 = self.lif1(cur1, mem1)
         
        cur2 = F.max_pool2d(self.conv2(spk1), kernel_size=(2, 2))
        # cur2 = F.max_pool2d(self.conv2(spk1), kernel_size=(1, 2))
        spk2, mem2 = self.lif2(cur2, mem2)
        
        cur3 = self.fc1(spk2.view(batch_size, -1))
        spk3, mem3 = self.lif3(cur3, mem3)

        return spk3, mem3   

In [5]:
num_classes = 10
model = MyEEGSNNModel(num_classes).to("cuda")

for inputs, targets in tqdm(train_loader):
    inputs = inputs.to(device)
    targets = targets.to(device)

    spk_rec, _= forward_pass(inputs, model, num_steps)
    print(spk_rec.shape)
    break

  0%|                                                                                          | 0/468 [00:00<?, ?it/s]

torch.Size([20, 128, 10])





In [8]:
n_epochs = 5
learning_rate = 1e-3
loss_history = []
acc_history = []

num_classes = 10
# num_classes = 4
model = MyEEGSNNModel(num_classes).to("cuda")
for param in model.parameters():
    print(param.requires_grad)
criterion = SF.ce_rate_loss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

for epoch in range(n_epochs):
    epoch_loss = 0.0
    train_acc = 0.0
    counter = 0
    acc_co = 0
    
    for inputs, targets in tqdm(iter(train_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        model.train()
        optimizer.zero_grad()

        spk_rec, _ = forward_pass(inputs, model, num_steps)
        loss = criterion(spk_rec, targets)
            
        loss.backward()         
        optimizer.step()  
        epoch_loss += loss.item()
        counter+=1

        if counter % 75 == 0:
        # if counter % 6 == 0:
            train_acc += batch_acc(test_loader, model, num_steps)
            acc_co+=1
        
    train_loss = epoch_loss / counter
    train_acc = train_acc.item() / acc_co
    # train_acc = train_acc.item() / acc_co
    loss_history.append(train_loss)
    acc_history.append(train_acc)
    print(f'Epoch [{epoch + 1}/{n_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}')

True
True
True
True
True
True


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [01:14<00:00,  6.32it/s]


Epoch [1/5], Loss: 2.3026, Accuracy: 0.0979


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [01:13<00:00,  6.34it/s]


Epoch [2/5], Loss: 2.0290, Accuracy: 0.4005


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [01:13<00:00,  6.35it/s]


Epoch [3/5], Loss: 1.5172, Accuracy: 0.8882


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [01:13<00:00,  6.33it/s]


Epoch [4/5], Loss: 1.5029, Accuracy: 0.9171


100%|████████████████████████████████████████████████████████████████████████████████| 468/468 [01:13<00:00,  6.34it/s]

Epoch [5/5], Loss: 1.4988, Accuracy: 0.9263





In [None]:
model.eval()
for inputs, targets in tqdm(test_loader):
    with torch.no_grad():
        spk_rec, mem_rec = model(inputs, num_steps)
        spk_count = torch.sum(spk_rec, dim=0)
        outputs = F.softmax(spk_count)
        
    loss = criterion(outputs, targets)
    epoch_loss += loss.item()
print(f'Loss: {epoch_loss / len(test_loader):.4f}')

# 2

In [4]:
folder_path = "Extras/GAMEEMO_EPOCH/5sec"
epochs_data = []

for file in tqdm(os.listdir(folder_path)):
    file_path = os.path.join(folder_path, file)
    epochs = mne.read_epochs(file_path, preload=True)
    epochs_choice_idx = np.random.choice(epochs.get_data().shape[0], size=40, replace=False)
    epochs_choice = epochs.get_data()[epochs_choice_idx]
    epochs_data.append(epochs_choice)

epochs_data = np.stack(epochs_data, axis=0)
epochs_data_reshaped = epochs_data.reshape(-1, epochs_data.shape[-2], epochs_data.shape[-1])
print(epochs_data_reshaped.shape)

for i in tqdm(range(epochs_data_reshaped.shape[0])):
    for j in range(epochs_data_reshaped.shape[1]):
        data = epochs_data_reshaped[i,j,:]
        epochs_data_reshaped[i,j,:] = (data - np.min(data)) / (np.max(data) - np.min(data))

plt.figure(1)
plt.plot(epochs_data_reshaped[0][0])

100%|████████████████████████████████████████████████████████████████████████████████| 112/112 [00:02<00:00, 39.83it/s]


(4480, 14, 640)


100%|████████████████████████████████████████████████████████████████████████████| 4480/4480 [00:01<00:00, 3533.89it/s]


[<matplotlib.lines.Line2D at 0x1674dad7100>]

In [5]:
label_list = [0, 1, 2, 3]
labels = []
for i in range(int(112/4)):
    labels.extend(label_list)

labels = np.array(labels)
labels = np.repeat(labels, int(epochs_data_reshaped.shape[0]/112))
oh_labels = np.eye(4)[labels]
print(labels.shape)
print(oh_labels.shape)
np.set_printoptions(threshold=20)
print(labels)
print(oh_labels)

(4480,)
(4480, 4)
[0 0 0 ... 3 3 3]
[[1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 ...
 [0. 0. 0. 1.]
 [0. 0. 0. 1.]
 [0. 0. 0. 1.]]


In [6]:
X_train, X_test, y_train, y_test = train_test_split(epochs_data_reshaped, labels, test_size=0.25, random_state=56)

num_steps = 32
device = torch.device("cuda") 
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test  = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test  = torch.tensor(y_test, dtype=torch.long)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)

batch_size = 80
train_dataset = TensorDataset(X_train.unsqueeze(1), y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

test_dataset = TensorDataset(X_test.unsqueeze(1), y_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

torch.Size([3360, 14, 640]) torch.Size([1120, 14, 640]) torch.Size([3360]) torch.Size([1120])


In [7]:
spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 0.5
num_steps = 20

In [8]:
def forward_pass(input_data, model, num_steps):
    spk_rec = []
    mem_rec = []
    utils.reset(model)

    for i in range(num_steps):
        spk, mem = model(input_data)
        spk_rec.append(spk)
        mem_rec.append(mem)

    return torch.stack(spk_rec), torch.stack(mem_rec)

def batch_acc(data_loader, model, num_steps):
    with torch.no_grad():
        model.eval()
        acc = 0
        total = 0
        for inputs, targets in iter(data_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
        
            spk_rec, mem_rec = forward_pass(inputs, model, num_steps)
            acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
            total += spk_rec.size(1)
    return acc / total    
    
class MyEEGSNNModel(nn.Module):
    def __init__(self, n_outputs):
        super().__init__()
        
        # self.conv1 = nn.Conv2d(1, 12, 5)
        # self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        # self.conv2 = nn.Conv2d(12, 64, 5)
        # self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        # self.fc1 = nn.Linear(64*4*4, n_outputs)
        # self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(1, 5))
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=(1, 5))
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc1 = nn.Linear(64*14*157, n_outputs)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):

        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # cur1 = F.max_pool2d(self.conv1(x), kernel_size=(2, 2))
        cur1 = F.max_pool2d(self.conv1(x), kernel_size=(1, 2))
        spk1, mem1 = self.lif1(cur1, mem1)
         
        # cur2 = F.max_pool2d(self.conv2(spk1), kernel_size=(2, 2))
        cur2 = F.max_pool2d(self.conv2(spk1), kernel_size=(1, 2))
        spk2, mem2 = self.lif2(cur2, mem2)
        
        cur3 = self.fc1(spk2.view(batch_size, -1))
        spk3, mem3 = self.lif3(cur3, mem3)

        return spk3, mem3   

In [9]:
num_classes = 4
model = MyEEGSNNModel(num_classes).to("cuda")

for inputs, targets in tqdm(train_loader):
    inputs = inputs.to(device)
    targets = targets.to(device)

    spk_rec, _= forward_pass(inputs, model, num_steps)
    print(spk_rec.shape)
    break

  0%|                                                                                           | 0/42 [00:01<?, ?it/s]

torch.Size([20, 80, 4])





In [10]:
n_epochs = 5
learning_rate = 1e-3
loss_history = []
acc_history = []

# num_classes = 10
num_classes = 4
model = MyEEGSNNModel(num_classes).to("cuda")
for param in model.parameters():
    print(param.requires_grad)
criterion = SF.ce_rate_loss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

for epoch in range(n_epochs):
    epoch_loss = 0.0
    train_acc = 0.0
    counter = 0
    acc_co = 0
    
    for inputs, targets in tqdm(iter(train_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        model.train()
        optimizer.zero_grad()

        spk_rec, _ = forward_pass(inputs, model, num_steps)
        loss = criterion(spk_rec, targets)
            
        loss.backward()         
        optimizer.step()  
        epoch_loss += loss.item()
        counter+=1

        # if counter % 75 == 0:
        if counter % 6 == 0:
            train_acc += batch_acc(test_loader, model, num_steps)
            acc_co+=1
        
    train_loss = epoch_loss / counter
    train_acc = train_acc.item() / acc_co
    loss_history.append(train_loss)
    acc_history.append(train_acc)
    print(f'Epoch [{epoch + 1}/{n_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}')

True
True
True
True
True
True


100%|██████████████████████████████████████████████████████████████████████████████████| 42/42 [02:20<00:00,  3.36s/it]


Epoch [1/5], Loss: 1.3863, Accuracy: 0.2536


100%|██████████████████████████████████████████████████████████████████████████████████| 42/42 [02:19<00:00,  3.31s/it]


Epoch [2/5], Loss: 1.3863, Accuracy: 0.2536


100%|██████████████████████████████████████████████████████████████████████████████████| 42/42 [02:18<00:00,  3.30s/it]


Epoch [3/5], Loss: 1.3863, Accuracy: 0.2536


100%|██████████████████████████████████████████████████████████████████████████████████| 42/42 [02:18<00:00,  3.31s/it]


Epoch [4/5], Loss: 1.3863, Accuracy: 0.2536


100%|██████████████████████████████████████████████████████████████████████████████████| 42/42 [02:19<00:00,  3.31s/it]

Epoch [5/5], Loss: 1.3863, Accuracy: 0.2536





# 3

In [None]:
# dataloader arguments
batch_size = 128
data_path='/tmp/data/mnist'

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

# neuron and simulation parameters
spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 0.5
num_steps = 50

# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.conv1 = nn.Conv2d(1, 12, 5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(12, 64, 5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc1 = nn.Linear(64*4*4, 10)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):

        # Initialize hidden states and outputs at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        cur1 = F.max_pool2d(self.conv1(x), 2)
        spk1, mem1 = self.lif1(cur1, mem1)

        cur2 = F.max_pool2d(self.conv2(spk1), 2)
        spk2, mem2 = self.lif2(cur2, mem2)

        cur3 = self.fc1(spk2.view(batch_size, -1))
        spk3, mem3 = self.lif3(cur3, mem3)

        return spk3, mem3

def forward_pass(net, num_steps, data):
  mem_rec = []
  spk_rec = []
  utils.reset(net)  # resets hidden states for all LIF neurons in net

  for step in range(num_steps):
      spk_out, mem_out = net(data)
      spk_rec.append(spk_out)
      mem_rec.append(mem_out)

  return torch.stack(spk_rec), torch.stack(mem_rec)

def batch_accuracy(train_loader, net, num_steps):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    train_loader = iter(train_loader)
    for data, targets in train_loader:
      data = data.to(device)
      targets = targets.to(device)
      spk_rec, _ = forward_pass(net, num_steps, data)

      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
      total += spk_rec.size(1)

  return acc/total

# already imported snntorch.functional as SF
loss_fn = SF.ce_rate_loss()

net = Net().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))
num_epochs = 1
loss_hist = []
test_acc_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):

    # Training loop
    for data, targets in tqdm(iter(train_loader)):
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        spk_rec, _ = forward_pass(net, num_steps, data)

        # initialize the loss & sum over time
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        if counter % 50 == 0:
            with torch.no_grad():
                net.eval()
    
                # Test set forward pass
                test_acc = batch_accuracy(test_loader, net, num_steps)
                print(f"Iteration {counter}, Test Acc: {test_acc * 100:.2f}%\n")
                test_acc_hist.append(test_acc.item())

        counter += 1

print(batch_accuracy(test_loader, net, num_steps))

# 4

In [1]:
import os
import random
import mne
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import scipy.signal as signal
import cv2
from PIL import Image
%matplotlib qt

from sklearn.metrics import mean_squared_error, accuracy_score, precision_score
from scipy.stats import pearsonr
from sklearn.metrics.pairwise import cosine_similarity
from skimage.metrics import structural_similarity as ssim
from sklearn.model_selection import train_test_split

import torch
import snntorch as snn
from snntorch import spikegen, surrogate, utils
import snntorch.functional as SF
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

import logging
import warnings
warnings.filterwarnings("ignore")
logging.getLogger('mne').setLevel(logging.WARNING)

device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"

In [2]:
img_dir = "shuvankar_spectro_snn"
img_files = np.array(os.listdir(img_dir))
print(img_files.shape)

data_array = []
labels_array = []

for files in tqdm(img_files):
    data_path = os.path.join(img_dir, files)
    read_data = cv2.imread(data_path, cv2.COLOR_BGR2GRAY)
    read_data = (read_data - np.min(read_data))/(np.max(read_data) - np.min(read_data))
    read_data = cv2.resize(read_data, (200, 200))
    data_array.append(read_data)
    labels_array.append(int(files[4]))

(672,)


100%|████████████████████████████████████████████████████████████████████████████████| 672/672 [00:07<00:00, 85.94it/s]


In [3]:
X_train, X_test, y_train, y_test = train_test_split(data_array, labels_array, test_size=0.2, random_state=42)

In [4]:
train_data_array = torch.tensor(X_train, device=device, dtype=torch.float32)
train_labels_array = torch.tensor(y_train, device=device, dtype=torch.long)
test_data_array = torch.tensor(X_test, device=device, dtype=torch.float32)
test_labels_array = torch.tensor(y_test, device=device, dtype=torch.long)

In [5]:
train_dataset = TensorDataset(train_data_array, train_labels_array)
test_dataset = TensorDataset(test_data_array, test_labels_array)

batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [6]:
spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 0.5
num_steps = 20

In [9]:
def forward_pass(input_data, model, num_steps):
    spk_rec = []
    mem_rec = []
    utils.reset(model)

    for i in range(num_steps):
        spk, mem = model(input_data)
        spk_rec.append(spk)
        mem_rec.append(mem)

    return torch.stack(spk_rec), torch.stack(mem_rec)

def batch_acc(data_loader, model, num_steps):
    with torch.no_grad():
        model.eval()
        acc = 0
        total = 0
        for inputs, targets in iter(data_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
        
            spk_rec, mem_rec = forward_pass(inputs, model, num_steps)
            acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
            total += spk_rec.size(1)
    return acc / total    
    
class MyEEGSNNModel(nn.Module):
    def __init__(self, n_outputs):
        super().__init__()
        
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(16, 64, 5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc1 = nn.Linear(64*47*47, n_outputs)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        # self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(1, 5))
        # self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        # self.conv2 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=(1, 5))
        # self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        # self.fc1 = nn.Linear(64*14*157, n_outputs)
        # self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)

        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        cur1 = F.max_pool2d(self.conv1(x), kernel_size=(2, 2))
        # cur1 = F.max_pool2d(self.conv1(x), kernel_size=(1, 2))
        spk1, mem1 = self.lif1(cur1, mem1)
         
        cur2 = F.max_pool2d(self.conv2(spk1), kernel_size=(2, 2))
        # cur2 = F.max_pool2d(self.conv2(spk1), kernel_size=(1, 2))
        spk2, mem2 = self.lif2(cur2, mem2)
        
        cur3 = self.fc1(spk2.contiguous().view(batch_size, -1))
        spk3, mem3 = self.lif3(cur3, mem3)

        return spk3, mem3  

In [10]:
num_classes = 4
model = MyEEGSNNModel(num_classes).to(device)

for inputs, targets in tqdm(train_loader):
    inputs = inputs.to(device)
    targets = targets.to(device)

    spk_rec, _= forward_pass(inputs, model, num_steps)
    print(spk_rec.shape)
    break

  0%|                                                                                           | 0/16 [00:01<?, ?it/s]

torch.Size([20, 32, 4])





In [13]:
n_epochs = 5
learning_rate = 1e-3
loss_history = []
acc_history = []

num_classes = 4
# num_classes = 4
model = MyEEGSNNModel(num_classes).to(device)
for param in model.parameters():
    print(param.requires_grad)
criterion = SF.ce_rate_loss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

for epoch in range(n_epochs):
    epoch_loss = 0.0
    train_acc = 0.0
    counter = 0
    acc_co = 0
    
    for inputs, targets in tqdm(iter(train_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        model.train()
        optimizer.zero_grad()

        spk_rec, _ = forward_pass(inputs, model, num_steps)
        loss = criterion(spk_rec, targets)
            
        loss.backward()         
        optimizer.step()  
        epoch_loss += loss.item()
        counter+=1

        if counter % 75 == 0:
        # if counter % 6 == 0:
            train_acc += batch_acc(test_loader, model, num_steps)
            acc_co+=1
        
    train_loss = epoch_loss / counter
    train_acc = train_acc.item() / acc_co
    # train_acc = train_acc.item() / acc_co
    loss_history.append(train_loss)
    acc_history.append(train_acc)
    print(f'Epoch [{epoch + 1}/{n_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}')

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
model.eval()
for inputs, targets in tqdm(test_loader):
    with torch.no_grad():
        spk_rec, mem_rec = model(inputs, num_steps)
        spk_count = torch.sum(spk_rec, dim=0)
        outputs = F.softmax(spk_count)
        
    loss = criterion(outputs, targets)
    epoch_loss += loss.item()
print(f'Loss: {epoch_loss / len(test_loader):.4f}')