In [2]:
import snntorch as snn
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import random
from torchvision import datasets, transforms #
from snntorch import utils
from torch.utils.data import DataLoader
from snntorch import spikegen
import numpy as np

In [3]:
data_path='/tmp/data/mnist'

# Torch Variables
dtype = torch.float       #para consumir menos memoria y asegurar compatibilidad con ciertas librerías que 
                          # esperan tensores de este tipo en luagr de 64 por ejemplo
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

In [4]:
batch_size=160
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
print(f"The size of mnist_train is {len(mnist_train)}")
print(f"The size of mnist_test is {len(mnist_test)}")
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)
print(f"The size of train_loader is {len(train_loader)}")
print(f"The size of test_loader is {len(test_loader)}")


The size of mnist_train is 60000
The size of mnist_test is 10000
The size of train_loader is 375
The size of test_loader is 63


In [5]:
# Network Architecture
num_inputs = 28*28    
num_hidden = 64    
num_outputs = 10

# Temporal Dynamics
num_steps = 20                  
beta = 0.95                     

In [6]:
# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)                    
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()                                
        mem2 = self.lif2.init_leaky()                               

        # Record the final layer
        spk2_rec = []
        mem2_rec = []
        #print(x[1].size())
        for step in range(num_steps):
            cur1 = self.fc1(x[step])
            spk1, mem1 = self.lif1(cur1, mem1)     #Aquí spk1 es el output donde en este caso es de tamaño batch, input_size. Es un array de 0s y 1s y para está dentro de un for de 25=num_steps de tiempo y mem es el potencial de membrana siguiente para cada elemento en el batch
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)
        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)      # Son una secuencia de tensores y los apila en la dimensión 0 (tiempo). Dimensiones: mem_rec = [num_steps, batch_size, num_outputs]

# Load the network onto CUDA if available
net = Net().to(device)

In [7]:
CE_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

In [15]:
# Training Parameters
num_epochs = 1
num_classes = 10  
loss_hist = []
test_loss_hist = []
counter = 0
# Entrenamiento del modelo
for epoch in range(num_epochs):
    net.train()
    for data, targets in train_loader:
        spike_data = spikegen.rate(data, num_steps=num_steps)
        spike_data = spike_data.to(device)                                 #.to(device) mueve el tensor data al dispositivo especificado para. Data: [batch_size, num_inputs], targets: 
        targets = targets.to(device) 
        spike_data = spike_data.view(num_steps, data.size(0), -1)
        print(spike_data.size())
        optimizer.zero_grad()
        outputs, mem_rec = net(spike_data)
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += CE_loss(mem_rec[step], targets)           #loss devuelve un escalar
        loss_hist.append(loss_val.item())
        loss_val.backward()
        optimizer.step() 
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss_val.item():.4f}")

torch.Size([20, 160, 784])
torch.Size([20, 160, 10])
torch.Size([20, 160, 10])
torch.Size([160])
torch.Size([20, 160, 784])
torch.Size([20, 160, 10])
torch.Size([20, 160, 10])
torch.Size([160])
torch.Size([20, 160, 784])
torch.Size([20, 160, 10])
torch.Size([20, 160, 10])
torch.Size([160])
torch.Size([20, 160, 784])
torch.Size([20, 160, 10])
torch.Size([20, 160, 10])
torch.Size([160])
torch.Size([20, 160, 784])
torch.Size([20, 160, 10])
torch.Size([20, 160, 10])
torch.Size([160])
torch.Size([20, 160, 784])
torch.Size([20, 160, 10])
torch.Size([20, 160, 10])
torch.Size([160])
torch.Size([20, 160, 784])
torch.Size([20, 160, 10])
torch.Size([20, 160, 10])
torch.Size([160])
torch.Size([20, 160, 784])
torch.Size([20, 160, 10])
torch.Size([20, 160, 10])
torch.Size([160])
torch.Size([20, 160, 784])
torch.Size([20, 160, 10])
torch.Size([20, 160, 10])
torch.Size([160])
torch.Size([20, 160, 784])
torch.Size([20, 160, 10])
torch.Size([20, 160, 10])
torch.Size([160])
torch.Size([20, 160, 784])
tor

In [22]:
# Prueba del modelo
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, targets in test_loader:
        targets = targets.to(device)
        spike_data = spikegen.rate(data, num_steps=num_steps)
        spike_data = spike_data.to(device)
        spike_data = spike_data.view(num_steps, data.size(0), -1)
        outputs, mem = net(spike_data)
        #_, predicted = torch.max(outputs, 1)
        _, pred = outputs.sum(dim=0).max(1)
        correct += (pred == targets).sum().item()
        total += targets.size(0)
        #correct += (predicted == targets).sum().item()

print(f"Accuracy: {100 * correct / total:.2f}%")
accuracy = 100. * correct / len(test_loader.dataset)

Accuracy: 91.98%


In [2]:
plt.figure(figsize=(12, 6))
plt.plot(loss_hist, label='Training Loss')
plt.plot(test_loss_hist, label='Test Loss', color='red')
plt.xlabel('Batch Number')
plt.ylabel('Loss')
plt.title('Training Loss per Batch')
plt.legend()
plt.grid(True)
plt.show()

NameError: name 'plt' is not defined