In [1]:
#!pip uninstall powerapi
#!pip install powerapi
#!pip install torchvision

In [2]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np

import time


In [3]:
# Leaky neuron model, overriding the backward pass with a custom function
class LeakySurrogate(nn.Module):
    """Leaky integrate-and-fire neuron with a hand-written surrogate gradient.

    The forward pass emits a spike wherever the incoming membrane potential
    exceeds ``threshold`` (Heaviside step). Because that step has no useful
    derivative, the backward pass substitutes ``1 / (1 + (pi * x)^2)``, where
    ``x`` is the membrane potential minus the threshold.

    Args:
        beta: Decay rate applied to the membrane potential at every step.
        threshold: Membrane potential above which the neuron spikes.
    """

    # Forward pass: Heaviside function
    # Backward pass: Override Dirac Delta with the ArcTan function
    # NOTE: a nested class needs no @staticmethod decorator; that decorator is
    # meant for functions and would only wrap the class in a staticmethod object.
    class ATan(torch.autograd.Function):
        """Heaviside step on the way forward, smooth surrogate on the way back."""

        @staticmethod
        def forward(ctx, mem):
            """Return 1.0 where ``mem`` is strictly positive, else 0.0.

            ``mem`` is expected to be already shifted by the threshold.
            """
            spk = (mem > 0).float()  # Heaviside on the forward pass: Eq(2)
            ctx.save_for_backward(mem)  # store the membrane for use in the backward pass
            return spk

        @staticmethod
        def backward(ctx, grad_output):
            """Scale the incoming gradient by the surrogate derivative."""
            (mem,) = ctx.saved_tensors  # retrieve the membrane potential
            # Out-of-place pow: never mutate anything derived from a saved tensor.
            grad = 1 / (1 + (np.pi * mem).pow(2)) * grad_output  # Eqn 5
            return grad

    def __init__(self, beta, threshold=1.0):
        super().__init__()

        # initialize decay rate beta and threshold
        self.beta = beta
        self.threshold = threshold
        self.spike_gradient = self.ATan.apply

    # the forward function is called each time we call Leaky
    def forward(self, input_, mem):
        """Advance the neuron by one time step.

        Args:
            input_: Input current for this step.
            mem: Membrane potential carried over from the previous step.

        Returns:
            Tuple ``(spk, mem)``: the spikes computed from the incoming
            membrane potential, and the updated membrane potential.
        """
        spk = self.spike_gradient((mem - self.threshold))  # call the Heaviside function
        reset = (self.beta * spk * self.threshold).detach()  # remove reset from computational graph
        mem = self.beta * mem + input_ - reset  # Eq (1)
        return spk, mem

In [4]:
# Instantiate the hand-written surrogate-gradient neuron with decay rate 0.9
# (threshold keeps its default of 1.0).
lif1 = LeakySurrogate(beta=0.9)

In [5]:
# Rebind lif1 to snnTorch's built-in Leaky neuron; this replaces the
# LeakySurrogate instance assigned to the same name in the previous cell.
lif1 = snn.Leaky(beta=0.9)

In [6]:
# dataloader arguments
batch_size = 1024
data_path='../datos'

dtype = torch.float

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#device = torch.device("cpu")
#device = torch_ipu.IPUDevice()

# Device label used when reporting results. It is derived from the device that
# was actually selected, so it cannot claim "cuda" on a CPU-only machine.
dispositivo = device.type

In [7]:
# Define a transform
# Preprocessing applied in order to every image: resize to 28x28, convert to
# grayscale, convert to a tensor, then normalize with mean 0 and std 1.
preprocessing_steps = [
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
]
transform = transforms.Compose(preprocessing_steps)

# Build the training split first, then the test split; both are downloaded to
# data_path if missing and share the same transform.
mnist_train, mnist_test = (
    datasets.MNIST(data_path, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [8]:
# Create DataLoaders
# Both loaders share one batching policy: shuffled batches of batch_size
# samples, dropping the final incomplete batch.
loader_options = dict(batch_size=batch_size, shuffle=True, drop_last=True)
train_loader = DataLoader(mnist_train, **loader_options)
test_loader = DataLoader(mnist_test, **loader_options)

In [9]:
# Network parameters: input size (one value per pixel of a 28x28 image),
# the three hidden-layer widths, and the number of output classes.
num_inputs = 28 * 28
num_hidden1, num_hidden2, num_hidden3 = 256, 64, 128
num_outputs = 10

# Temporal parameters: number of simulated time steps and the decay rate
# passed to every leaky neuron.
num_steps, beta = 25, 0.9

# Definition of a very different architecture
class Net(nn.Module):
    """Four-layer fully connected spiking network.

    Each linear layer feeds a leaky neuron layer. Layer sizes, the decay rate
    ``beta`` and the number of simulated steps ``num_steps`` are read from the
    module-level constants defined above.
    """

    def __init__(self):
        super().__init__()

        # Layers are registered in the order they are applied:
        # linear -> leaky neuron, four times.
        self.fc1 = nn.Linear(num_inputs, num_hidden1)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden1, num_hidden2)
        self.lif2 = snn.Leaky(beta=beta)
        self.fc3 = nn.Linear(num_hidden2, num_hidden3)
        self.lif3 = snn.Leaky(beta=beta)
        self.fc4 = nn.Linear(num_hidden3, num_outputs)
        self.lif4 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Present ``x`` to the network for ``num_steps`` steps.

        Returns:
            Tuple ``(spikes, potentials)`` with the output layer's spikes and
            membrane potentials, each stacked along a new leading time axis.
        """
        stages = (
            (self.fc1, self.lif1),
            (self.fc2, self.lif2),
            (self.fc3, self.lif3),
            (self.fc4, self.lif4),
        )

        # Hidden states at t=0, one per neuron layer
        membranes = [neuron.init_leaky() for _, neuron in stages]

        # Output-layer activity recorded at every step
        spikes_over_time = []
        potentials_over_time = []

        for _ in range(num_steps):
            # The same input is presented at every step.
            signal = x
            for index, (linear, neuron) in enumerate(stages):
                signal, membranes[index] = neuron(linear(signal), membranes[index])
            spikes_over_time.append(signal)
            potentials_over_time.append(membranes[-1])

        return (
            torch.stack(spikes_over_time, dim=0),
            torch.stack(potentials_over_time, dim=0),
        )

        
# Load the network onto CUDA if available
# (`device` is chosen in the configuration cell and falls back to the CPU).
net = Net().to(device)

In [10]:
# pass data into the network, sum the spikes over time
# and compare the neuron with the highest number of spikes
# with the target

def print_batch_accuracy(data, targets, train=False):
    """Print and return the accuracy of ``net`` on a single minibatch.

    The prediction for each sample is the output neuron with the highest
    spike count summed over all time steps.

    Args:
        data: Batch of input images; flattened to one row per sample.
        targets: Ground-truth class indices for the batch.
        train: Selects the wording of the printed message
            (training set when True, test set otherwise).

    Returns:
        Fraction of samples in the batch that were classified correctly.
    """
    # Flatten using the actual batch length so a final, smaller batch also works.
    # No gradients are needed just to measure accuracy.
    with torch.no_grad():
        output, _ = net(data.view(data.size(0), -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")
    return acc

def train_printer(
    data, targets, epoch,
    counter, iter_counter,
        loss_hist, test_loss_hist, test_data, test_targets):
    """Progress hook called from the training loop.

    All detailed reporting below is currently commented out, so the function
    ignores its arguments and only emits blank lines.
    """
    #print(f"Epoch {epoch}, Iteration {iter_counter}")
    #print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    #print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    #print_batch_accuracy(data, targets, train=True)
    #print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

In [11]:
# Cross-entropy criterion; the cells below apply it to the output layer's
# membrane potential at each time step.
loss = nn.CrossEntropyLoss()

In [12]:
# Adam optimizer over all parameters of `net`, learning rate 5e-4.
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

In [13]:
# Pull one minibatch from the training loader and move it to the compute device.
data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)

In [14]:
# Forward pass on the flattened minibatch; Net.forward returns the output
# layer's spikes and membrane potentials stacked over time.
spk_rec, mem_rec = net(data.view(batch_size, -1))
#print(mem_rec.size())

torch.Size([25, 1024, 10])


In [15]:
# One-element accumulator for the total loss, on the compute device
loss_val = torch.zeros(1, dtype=dtype, device=device)

# Accumulate the loss of every time step's membrane potential in place
for step in range(num_steps):
    loss_val.add_(loss(mem_rec[step], targets))

#print(f"Training loss: {loss_val.item():.3f}")

Training loss: 58.938


In [16]:
#print_batch_accuracy(data, targets, train=True)

Train set accuracy for a single minibatch: 12.40%


In [17]:
# Single manual optimization step on the minibatch loaded above.
# clear previously stored gradients
optimizer.zero_grad()

# calculate the gradients of the summed loss w.r.t. the network parameters
loss_val.backward()

# weight update
optimizer.step()

In [18]:
# Run the same minibatch through the network again, now with updated weights
spk_rec, mem_rec = net(data.view(batch_size, -1))

# One-element accumulator for the total loss, on the compute device
loss_val = torch.zeros(1, dtype=dtype, device=device)

# Accumulate the loss of every time step's membrane potential in place
for step in range(num_steps):
    loss_val.add_(loss(mem_rec[step], targets))

#print(f"Training loss: {loss_val.item():.3f}")
#print_batch_accuracy(data, targets, train=True)

Training loss: 54.836
Train set accuracy for a single minibatch: 18.85%


In [19]:
# Full training run: for every training minibatch, do one optimizer step and
# then evaluate the loss on one test minibatch. Both losses are recorded.
num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

# Start timing the training run.
# NOTE: the timed region also covers the per-iteration test-set evaluation below.
start_train_time = time.time()

# Outer training loop
for epoch in range(num_epochs):
    
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        
        data = data.to(device)
        targets = targets.to(device)
        
        # forward pass
        # (switch back to training mode; the test block below sets eval mode)
        net.train()
        # view(batch_size, -1) requires every batch to hold exactly batch_size samples
        spk_rec, mem_rec = net(data.view(batch_size, -1))

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            # A new iterator is created on every pass, so this always takes the
            # first batch of a fresh pass over test_loader.
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            # (every 50th iteration, counted across epochs)
            if counter % 50 == 0:
                train_printer(
                    data, targets, epoch,
                    counter, iter_counter,
                    loss_hist, test_loss_hist,
                    test_data, test_targets)
            counter += 1
            iter_counter +=1
            
# Stop timing and keep the elapsed wall-clock seconds for the final report.
end_train_time = time.time()
total_train_time = end_train_time - start_train_time


Epoch 0, Iteration 0
Train Set Loss: 56.03
Test Set Loss: 53.13
Train set accuracy for a single minibatch: 31.84%
Test set accuracy for a single minibatch: 28.61%


Epoch 0, Iteration 50
Train Set Loss: 14.94
Test Set Loss: 13.88
Train set accuracy for a single minibatch: 88.18%
Test set accuracy for a single minibatch: 88.28%




In [20]:
# Final evaluation: save the trained weights, then measure accuracy over the
# whole test set and report the elapsed times.
total = 0
correct = 0

# Start timing
start_time = time.time()
# Run label: device actually in use, batch size and number of time steps.
# Taken from the live configuration so the label cannot drift from the run.
print(f"{device.type} {batch_size} {num_steps}")

# drop_last switched to False to keep all samples
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)
# Save the trained model
# NOTE: the file name is fixed and does not follow the device/batch settings.
torch.save(net.state_dict(), '2modelo_entrenado_gpu_1024_25.pth')

with torch.no_grad():
    net.eval()
    for data, targets in test_loader:

        data = data.to(device)
        targets = targets.to(device)

        # forward pass (flatten by the actual batch length: the last batch may be smaller)
        test_spk, _ = net(data.view(data.size(0), -1))

        # calculate total accuracy: predicted class = output neuron with most spikes
        _, predicted = test_spk.sum(dim=0).max(1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

end_time = time.time()
total_time = end_time - start_time

print(f"Total correctly classified test set images: {correct}/{total}")
print(f"Test Set Accuracy: {100 * correct / total:.2f}%")

print(f"Tiempo total: {total_time} segundos")
print(f"Tiempo total de entrenamiento: {total_train_time} segundos")

cpu
Total correctly classified test set images: 8908/10000
Test Set Accuracy: 89.08%
Tiempo total: 1.8388621807098389 segundos
Tiempo total de entrenamiento: 23.53195571899414 segundos
