In [1]:
import tkinter as tk
from tkinter import ttk
import threading
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
import torchattacks
from spikingjelly.activation_based import neuron, encoding, functional, layer
from torch.utils.tensorboard import SummaryWriter
import os
import time
import torch.utils.data as data
from tqdm import tqdm

In [2]:
from foolbox import TensorFlowModel, accuracy, samples, Model
import eagerpy as ep

In [3]:
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF

In [4]:
from robustbench.utils import clean_accuracy

In [5]:
class SNN(nn.Module):
    def __init__(self, tau=None, beta=None, spike_grad=None, model_type='QIF'):
        super(SNN, self).__init__()
        self.model_type = model_type
        
        # Define the common layers
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(28 * 28, 10, bias=False)

        # Add the appropriate spiking neuron to the layer based on the model type
        if model_type == 'QIF':
            self.layer = nn.Sequential(
                self.flatten,
                self.fc,
                neuron.QIFNode(tau=tau)
            )
        elif model_type == 'Izhikevich':
            self.layer = nn.Sequential(
                self.flatten,
                self.fc,
                neuron.IzhikevichNode(tau=tau)
            )
        elif model_type == 'LIF':
            self.snn1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

    def forward(self, x: torch.Tensor):
        if self.model_type == 'LIF':
            return self.forward_leaky(x)
        else:
            return self.layer(x)

    def forward_leaky(self, x):
        # Flatten and apply linear transformation
        x = self.flatten(x)
        x = self.fc(x)
        
        # Initialize and reset membrane potential
        mem1 = self.snn1.reset_mem()
        spk1, mem1 = self.snn1(x, mem1)
        return spk1

In [6]:
# Function to start training
def start_training():
    # Create the args dictionary
    args = {
        'model': model_var.get(),
        'attack': attack_var.get(),
        'eps': float(epsilon_var.get().split('/')[0]) / 225,  # Convert fraction to float
        'epochs': int(epochs_var.get()),
        'lr': 1e-2,  # Keep learning rate fixed
        'b': int(batch_size_var.get()),
        'tau': 2.0,  # Fixed tau value
        'beta': 0.6, #fixed
        'T': int(timesteps_var.get()),  # Get T from the user input
        'device': 'cpu',  # or 'cuda' if GPU is available
        'data_dir': './mnist_data',
        'out_dir': './results',
        'resume': None,
        'amp': False,  # Disable AMP since using CPU
        'opt': 'adam',  # Fixed optimizer
        'momentum': 0.9,
        'j': 2,  # Number of workers for data loading
        'dim': float (pixels_var.get()) # Size of patch, in %, for Pixle attack
    }

    # Run the training in a separate thread to keep the GUI responsive
    threading.Thread(target=train_and_evaluate, args=(args,)).start()

In [7]:
# Function to calculate accuracy
@torch.no_grad()
def get_accuracy(model, data_loader, atk=None, n_limit=1e10, device=None):
    model = model.eval()

    if device is None:
        device = next(model.parameters()).device

    correct = 0
    total = 0

    for images, labels in data_loader:
        X = images.to(device)
        Y = labels.to(device)

        if atk:
            X = atk(X, Y)

        outputs = model(X)
        _, predicted = torch.max(outputs.data, 1)
        total += predicted.size(0)
        correct += (predicted == Y).sum().item()

        if total > n_limit:
            break

    return 100 * float(correct) / total

In [8]:
def train_and_evaluate(args):
    # Check the model type and initialize accordingly
    if args['model'] == 'LIF':
        net = SNN(beta=args['beta'], spike_grad=surrogate.fast_sigmoid(slope=25), model_type='LIF')
    else:
        net = SNN(tau=args['tau'], model_type=args['model'])
    
    # Move the model to the specified device (CPU or GPU)
    device = torch.device(args['device'])
    net.to(device)

    # Load datasets
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,))
    ])
    
    train_dataset = datasets.MNIST(
        root=args['data_dir'],
        train=True,
        transform=transforms.ToTensor(),
        download=False
    )
    test_dataset = datasets.MNIST(
        root=args['data_dir'],
        train=False,
        transform=transforms.ToTensor(),
        download=False
    )
    
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args['b'],
        shuffle=True,
        drop_last=True,
        num_workers=args['j'],
        pin_memory=True
    )
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args['b'],
        shuffle=False,
        drop_last=False,
        num_workers=args['j'],
        pin_memory=True
    )

    # AMP setup and optimizer
    scaler = None
    start_epoch = 0
    max_test_acc = -1
    optimizer = optim.Adam(net.parameters(), lr=args['lr'])

    # Resume from checkpoint if provided
    if args['resume']:
        checkpoint = torch.load(args['resume'], map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    # Output directory setup
    out_dir = os.path.join(args['out_dir'], f"T{args['T']}_b{args['b']}_{args['opt']}_lr{args['lr']}")
    if args['amp']:
        out_dir += '_amp'
    os.makedirs(out_dir, exist_ok=True)

    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))

    writer = SummaryWriter(out_dir, purge_step=start_epoch)
    encoder = encoding.PoissonEncoder()

    # Training and evaluation loop
    for epoch in range(start_epoch, args['epochs']):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0

        for img, label in tqdm(train_data_loader, desc="Training", unit="batch"):
            optimizer.zero_grad()
            img = img.to(device)
            label = label.to(device)
            label_onehot = F.one_hot(label, 10).float()

            out_fr = 0
            for t in range(args['T']):
                encoded_img = encoder(img)
                out_fr += net(encoded_img)
                
            out_fr /= args['T']
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            if args['model'] != 'LIF':
                functional.reset_net(net)

        train_loss /= train_samples
        train_acc /= train_samples
        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)

        # Testing phase with tqdm
        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        
        with torch.no_grad():
            for img, label in tqdm(test_data_loader, desc="Testing", unit="batch"):
                img = img.to(device)
                label = label.to(device)
                label_onehot = F.one_hot(label, 10).float()
                out_fr = 0.
                
                for t in range(args['T']):
                    encoded_img = encoder(img)
                    out_fr += net(encoded_img)
                out_fr /= args['T']
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                
                if args['model'] != 'LIF':
                    functional.reset_net(net)
                
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        # Save checkpoint
        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        # Print results
        print(f'epoch = {epoch}, train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}, test_loss = {test_loss:.4f}, test_acc = {test_acc:.4f}, max_test_acc = {max_test_acc:.4f}')

    # Adversarial accuracy evaluation using the updated approach
    if args['attack'] == 'PGD':
        attack = attack = torchattacks.PGD(net, eps=args['eps'], alpha=2/255, steps=10, random_start=True)
    elif args['attack'] == 'FGSM':
        attack = torchattacks.FGSM(net, eps=args['eps'])
    else: 
        attack = torchattacks.Pixle(model=net,x_dimensions=args['dim'],y_dimensions=args['dim'])

    
    attack.set_mode_targeted_by_label()

    
    # Define a transform
    transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

    dataset = torchvision.datasets.MNIST(
    root=args['data_dir'],
    transform=transform)

    data_loader = data.DataLoader(dataset, batch_size=args['b'], shuffle=False, drop_last=False, num_workers=2)
    
    # Perform adversarial attack
    images, labels = next(iter(data_loader))
    images, labels = images.to(device), labels.to(device)

    new_labels = (labels + 1) % 10 # New labels

    adv_images = attack(images, new_labels)

    # Calculate clean accuracy
    #clean_acc = get_accuracy(net, data_loader, device=device)
    clean_acc = clean_accuracy(net, images, labels)
    
    # Calculate adversarial accuracy
    with torch.no_grad():
        outputs = net(adv_images)
        _, preds = torch.max(outputs, 1)
        correct = (preds == labels).sum().item()
        adv_acc1 = correct / labels.size(0)

    adv_acc2 = clean_accuracy(net, adv_images, labels)

    print(f'Model Choice: {args['model']}')
    print(f'Attack Choice: {args['attack']}')
    print(f'Clean Accuracy: {clean_acc * 100:.2f}%')
    print(f'Robust Accuracy1: {adv_acc1 * 100:.2f}%')
    print(f'Robust Accuracy2: {adv_acc2 * 100:.2f}%')
    print(f'Accuracy Degradation: {(clean_acc - adv_acc2) * 100:.2f}%')
    print(f'Time Steps: {args['T']}')
    print(f'Epsilon x 225: {args['eps'] * 225}')
    
    if args['attack'] == 'Pixle':
        print(f'(Pixle) Size of Patch: {args['dim'] * 100:.2f}%')


    with open(os.path.join(out_dir, 'accuracies.txt'), 'w') as f:
        f.write(f'Clean Accuracy: {clean_acc * 100:.2f}%\n')
        f.write(f'Adversarial Accuracy1: {adv_acc1 * 100:.2f}%\n')
        f.write(f'Adversarial Accuracy1: {adv_acc1 * 100:.2f}%\n')

In [9]:
# Tkinter UI setup
root = tk.Tk()
root.title("(T) SNN Training App")

# Model selection
ttk.Label(root, text="Select Model:").grid(row=0, column=0, padx=10, pady=10)
model_var = tk.StringVar(value="QIF")
model_dropdown = ttk.Combobox(root, textvariable=model_var, values=["QIF", "Izhikevich", "LIF"])
model_dropdown.grid(row=0, column=1, padx=10, pady=10)

# Attack selection
ttk.Label(root, text="Select Attack:").grid(row=1, column=0, padx=10, pady=10)
attack_var = tk.StringVar(value="PGD")
attack_dropdown = ttk.Combobox(root, textvariable=attack_var, values=["PGD", "FGSM", "Pixle"])
attack_dropdown.grid(row=1, column=1, padx=10, pady=10)

# Epsilon selection (fractions of 2k/225)
ttk.Label(root, text="Select Epsilon:").grid(row=2, column=0, padx=10, pady=10)
epsilon_var = tk.StringVar(value="8/225")
epsilon_dropdown = ttk.Combobox(root, textvariable=epsilon_var, values=[f'{2*k}/225' for k in [4, 8, 16, 32]])
epsilon_dropdown.grid(row=2, column=1, padx=10, pady=10)

epochs_var = tk.StringVar(value="1")
batch_size_var = tk.StringVar(value="16") #Batch Size stays at 16

# Timesteps (T) entry
ttk.Label(root, text="Timesteps (T):").grid(row=5, column=0, padx=10, pady=10)
timesteps_var = tk.StringVar(value="15")
timesteps_entry = ttk.Entry(root, textvariable=timesteps_var)
timesteps_entry.grid(row=5, column=1, padx=10, pady=10)

# Pixle Attack entry
ttk.Label(root, text="(For Pixle Attack): Size of patch:").grid(row=6, column=0, padx=10, pady=10)
value_mapping = {f"{i}%": f"{i / 100:.2f}" for i in range(10, 105, 10)}
pixels_var = tk.StringVar(value="0.1")
pixels_combobox = ttk.Combobox(
    root,
    textvariable=pixels_var,
    values=list(value_mapping.keys()),
    state="readonly"
)
pixels_combobox.grid(row=6, column=1, padx=10, pady=10)

# Bind the combobox selection event to the update function
pixels_combobox.bind("<<ComboboxSelected>>", lambda event: pixels_var.set(value_mapping[pixels_combobox.get()]))

# Start button
start_button = ttk.Button(root, text="Start Training", command=start_training)
start_button.grid(row=7, columnspan=2, padx=10, pady=20)

In [10]:
# Run the application
root.mainloop()

Training: 100%|█████████████████████████████████████████████████████████████████| 3750/3750 [00:50<00:00, 73.99batch/s]
Testing: 100%|████████████████████████████████████████████████████████████████████| 625/625 [00:08<00:00, 74.91batch/s]


epoch = 0, train_loss = 0.0394, train_acc = 0.7931, test_loss = 0.0491, test_acc = 0.7443, max_test_acc = 0.7443
Attack mode is changed to 'targeted(label)'.
Model Choice: Izhikevich
Attack Choice: Pixle
Clean Accuracy: 81.25%
Robust Accuracy1: 12.50%
Robust Accuracy2: 25.00%
Accuracy Degradation: 56.25%
Time Steps: 15
Epsilon x 225: 8.0
(Pixle) Size of Patch: 20.00%


Training: 100%|█████████████████████████████████████████████████████████████████| 3750/3750 [00:51<00:00, 72.14batch/s]
Testing: 100%|████████████████████████████████████████████████████████████████████| 625/625 [00:08<00:00, 75.57batch/s]


epoch = 0, train_loss = 0.0330, train_acc = 0.8318, test_loss = 0.0332, test_acc = 0.8384, max_test_acc = 0.8384
Attack mode is changed to 'targeted(label)'.
Model Choice: Izhikevich
Attack Choice: Pixle
Clean Accuracy: 81.25%
Robust Accuracy1: 6.25%
Robust Accuracy2: 12.50%
Accuracy Degradation: 68.75%
Time Steps: 15
Epsilon x 225: 8.0
(Pixle) Size of Patch: 40.00%


Training: 100%|█████████████████████████████████████████████████████████████████| 3750/3750 [00:52<00:00, 72.01batch/s]
Testing: 100%|████████████████████████████████████████████████████████████████████| 625/625 [00:08<00:00, 75.91batch/s]


epoch = 0, train_loss = 0.0321, train_acc = 0.8336, test_loss = 0.0262, test_acc = 0.8595, max_test_acc = 0.8595
Attack mode is changed to 'targeted(label)'.
Model Choice: Izhikevich
Attack Choice: Pixle
Clean Accuracy: 81.25%
Robust Accuracy1: 0.00%
Robust Accuracy2: 6.25%
Accuracy Degradation: 75.00%
Time Steps: 15
Epsilon x 225: 8.0
(Pixle) Size of Patch: 60.00%
