In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import utils as utls

import snntorch as snn
from snntorch import utils
from snntorch import surrogate

import numpy as np

from tqdm.auto import tqdm
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
#https://github.com/jeshraghian/snntorch/blob/master/examples/tutorial_sae.ipynb
class SAE(torch.nn.Module):
    """Spiking convolutional autoencoder built from snntorch ``Leaky`` neurons.

    Adapted from the snntorch spiking-autoencoder tutorial. For each of
    ``num_steps`` time steps, the encoder turns a 1-channel 32x32 image into
    latent spikes. ``linear_net`` + ``decoder`` then map each step's latent
    spikes back to a 1-channel 32x32 map. The reconstruction is the membrane
    potential of the final decoder neuron at the last time step.

    Args:
        latent_dim: Number of latent (bottleneck) neurons.
        beta: Membrane decay rate shared by every ``Leaky`` layer.
        spike_grad: Surrogate gradient for the spike non-linearity. Defaults to
            ``surrogate.fast_sigmoid(slope=25)`` when ``None``.
        threshold: Firing threshold of every ``Leaky`` layer except the output one.
        num_steps: Simulation time steps per forward pass (default 5, the value
            previously hard-coded in ``forward``). Must be at least 1.

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """

    def __init__(self, latent_dim, beta=0.9, spike_grad=None, threshold=1.0, num_steps=5):
        super().__init__()
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        self.latent_dim = latent_dim
        self.num_steps = num_steps
        if spike_grad is None:
            spike_grad = surrogate.fast_sigmoid(slope=25)

        # Encoder: three stride-2 convolutions shrink 32x32 -> 16x16 -> 8x8 -> 4x4,
        # which is why the Linear layer expects 128*4*4 features (inputs must be 32x32).
        # init_hidden=True makes every Leaky keep its own membrane state between calls.
        # output=True on the last one makes the Sequential return (spikes, membrane).
        self.encoder = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1, stride=2),
                                     nn.BatchNorm2d(32),
                                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, threshold=threshold),
                                     nn.Conv2d(32, 64, 3, padding=1, stride=2),
                                     nn.BatchNorm2d(64),
                                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, threshold=threshold),
                                     nn.Conv2d(64, 128, 3, padding=1, stride=2),
                                     nn.BatchNorm2d(128),
                                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, threshold=threshold),
                                     nn.Flatten(start_dim=1, end_dim=3),
                                     nn.Linear(128*4*4, latent_dim),
                                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True, threshold=threshold))

        # From latent spikes back to a flat 128*4*4 vector for the transposed convolutions.
        self.linear_net = nn.Sequential(nn.Linear(latent_dim, 128*4*4),
                                        snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True, threshold=threshold))

        # Decoder: each stride-2 transposed conv (k=3, padding=1, output_padding=1)
        # doubles the spatial size, 4x4 -> 8x8 -> 16x16 -> 32x32.
        self.decoder = nn.Sequential(nn.Unflatten(1, (128, 4, 4)),  # [B, 128*4*4] -> [B, 128, 4, 4]
                                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, threshold=threshold),
                                     nn.ConvTranspose2d(128, 64, 3, padding=1, stride=(2, 2), output_padding=1),
                                     nn.BatchNorm2d(64),
                                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, threshold=threshold),
                                     nn.ConvTranspose2d(64, 32, 3, padding=1, stride=(2, 2), output_padding=1),
                                     nn.BatchNorm2d(32),
                                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, threshold=threshold),
                                     nn.ConvTranspose2d(32, 1, 3, padding=1, stride=(2, 2), output_padding=1),
                                     # Threshold set very large so this neuron effectively never fires;
                                     # its membrane potential is the (trainable) continuous output.
                                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True, threshold=20000))

    def forward(self, x):
        """Reconstruct an image batch.

        Args:
            x: Image batch of shape [Batch, 1, 32, 32].

        Returns:
            Output membrane potential of the decoder at the final time step,
            shape [Batch, 1, 32, 32].
        """
        # Leaky states persist across calls (init_hidden=True), so clear them to make
        # every forward pass an independent simulation.
        for net in (self.encoder, self.linear_net, self.decoder):
            utils.reset(net)

        # Phase 1: present the same static image at every step and record the latent
        # spikes. Each entry has shape [Batch, Latent Dim].
        latent_spikes = []
        for _ in range(self.num_steps):
            spk_latent, _mem_latent = self.encoder(x)
            latent_spikes.append(spk_latent)

        # Phase 2: drive the decoder with one latent spike vector per step. Only the
        # final step's membrane is returned, but every step must run so the decoder's
        # membrane state evolves.
        mem_out = None
        for spk_latent in latent_spikes:
            _spk_out, mem_out = self.decode(spk_latent)
        return mem_out

    def encode(self, x):
        """Run one encoder time step; returns (latent spikes, latent membrane)."""
        return self.encoder(x)

    def decode(self, x):
        """Run one decoder time step on latent spikes ``x``.

        Returns:
            (output spikes, output membrane) of the final decoder ``Leaky`` layer.
        """
        spk_x, _mem_x = self.linear_net(x)
        spk_x2, mem_x2 = self.decoder(spk_x)
        return spk_x2, mem_x2

In [3]:
# dataloader arguments
batch_size = 250
data_path = 'dataset/'  # MNIST is downloaded to / loaded from here (relative to the notebook)
dtype = torch.float

# The SAE encoder's three stride-2 convolutions reduce 32 -> 16 -> 8 -> 4 to match its
# 128*4*4 Linear layer, so the original 28x28 MNIST digits are resized to 32x32.
input_size = 32

transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.Grayscale(),            # no-op for MNIST, whose images are already single-channel
    transforms.ToTensor(),             # PIL image -> float tensor with values in [0, 1]
    transforms.Normalize((0,), (1,)),  # identity (mean 0, std 1): pixels stay in [0, 1]
])

# Training data: shuffled every epoch.
train_dataset = datasets.MNIST(root=data_path, train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Testing data: evaluation does not need shuffling, and a fixed order makes the
# saved final-batch test figures comparable across epochs.
test_dataset = datasets.MNIST(root=data_path, train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
def train(cortex, train_loader, optimizer, epoch, max_epoch=10):
    """Run one training epoch, minimizing MSE between inputs and reconstructions.

    Args:
        cortex: Autoencoder module; ``cortex(x)`` must return a tensor shaped like ``x``.
        train_loader: Iterable of ``(image, label)`` batches; labels are ignored.
        optimizer: Optimizer over ``cortex``'s parameters.
        epoch: Current epoch index (progress-bar display only).
        max_epoch: Total number of epochs (progress-bar display only).

    Returns:
        Mean per-batch training loss for the epoch as a float, or ``float('nan')``
        if the loader yields no batches.
    """
    cortex.train()
    batch_losses = []

    pbar = tqdm(train_loader)
    for batch_idx, (real_img, _labels) in enumerate(pbar):
        real_img = real_img.to(device)

        optimizer.zero_grad()
        x_recon = cortex(real_img)
        loss = F.mse_loss(x_recon, real_img)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        pbar.set_description(
            f'Train[{epoch:02}/{max_epoch}][{batch_idx+1:03}/{len(train_loader)}] Loss: {batch_losses[-1]:.5f}'
        )

    return sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
        

In [5]:
def test(cortex, test_loader, epoch, max_epoch=10, fig_dir='figures/testing'):
    """Evaluate reconstruction MSE and save the last batch's inputs/reconstructions.

    Args:
        cortex: Autoencoder module; ``cortex(x)`` must return a tensor shaped like ``x``.
        test_loader: Iterable of ``(image, label)`` batches supporting ``len()``;
            labels are ignored.
        epoch: Current epoch index (progress bar and figure file names).
        max_epoch: Total number of epochs (progress-bar display only).
        fig_dir: Existing directory that receives
            ``epoch{epoch}_finalbatch_inputs.png`` / ``..._recons.png``.

    Returns:
        Mean per-batch test loss as a float, or ``float('nan')`` if the loader
        yields no batches.
    """
    cortex.eval()
    batch_losses = []

    with torch.no_grad():
        pbar = tqdm(test_loader)
        for batch_idx, (real_img, _labels) in enumerate(pbar):
            real_img = real_img.to(device)
            x_recon = cortex(real_img)

            loss = F.mse_loss(x_recon, real_img)
            batch_losses.append(loss.item())

            pbar.set_description(
                f'Test[{epoch:02}/{max_epoch}][{batch_idx+1:03}/{len(test_loader)}] Loss: {batch_losses[-1]:.5f}'
            )

            if batch_idx == len(test_loader) - 1:
                # Inputs are already in [0, 1] (ToTensor + identity Normalize in the data
                # cell), so they are saved as-is; the former (x+1)/2 mapping assumed [-1, 1]
                # and washed the images out. save_image clamps the unbounded
                # reconstruction to [0, 1].
                utls.save_image(real_img, os.path.join(fig_dir, f'epoch{epoch}_finalbatch_inputs.png'))
                utls.save_image(x_recon, os.path.join(fig_dir, f'epoch{epoch}_finalbatch_recons.png'))

    return sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')

In [6]:
# Make sure the output folders for saved training/testing figures exist.
# exist_ok=True makes this a no-op when a folder is already there.
for fig_subdir in ('figures/training', 'figures/testing'):
    os.makedirs(fig_subdir, exist_ok=True)

In [7]:
# Model / optimization hyperparameters. batch_size, input_size and dtype are set
# in the data-loading cell, where the DataLoaders are actually built.
SEED = 0
# Seeds weight initialization and the train loader's shuffling (its sampler draws
# from the global torch RNG). Results can still vary under nondeterministic CUDA kernels.
torch.manual_seed(SEED)

spike_grad = surrogate.atan(alpha=2)
beta = 0.5
num_steps = 5  # NOTE(review): not passed to SAE below — confirm it matches SAE's simulation step count
latent_dim = 32
threshold = 1
epochs = 10
max_epoch = epochs

cortex = SAE(latent_dim=latent_dim, beta=beta, spike_grad=spike_grad, threshold=threshold).to(device)

optimizer = torch.optim.AdamW(cortex.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-3)

# One (train_loss, test_loss) pair per epoch, as returned by train()/test().
loss_history = []
for idx_epoch in range(epochs):
    train_loss = train(cortex, train_loader=train_loader, optimizer=optimizer, epoch=idx_epoch, max_epoch=max_epoch)
    test_loss = test(cortex, test_loader=test_loader, epoch=idx_epoch, max_epoch=max_epoch)
    loss_history.append((train_loss, test_loss))

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]