Valerio Pepe, CS 226r Spring 2024



---



Code References:

Alfredo Canziani's MNIST Tutorial: https://www.youtube.com/watch?v=OMDn66kM9Qc (base for the MNIST Model and training loop)

FGSM Attack PyTorch implementation adapted from here: https://github.com/pytorch/tutorials/blob/main/beginner_source/fgsm_tutorial.py

In [3]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from torch.utils.data import random_split, DataLoader

In [None]:
# Load the training split and convert every image to a float tensor.
train_data = datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor())
#train_data = datasets.FashionMNIST("data", train=True, download=True, transform=transforms.ToTensor())
# Hold out 5,000 of the 60,000 samples for validation.
train, val = random_split(train_data, [55000,5000])
# Reshuffle the training subset every epoch so SGD does not replay the same
# sequence of mini-batches on each pass.
train_loader = DataLoader(train, batch_size=32, shuffle=True)
# The validation subset is only evaluated, so it is left in a fixed order.
val_loader = DataLoader(val, batch_size=32)

In [4]:
class mnist_classifier(nn.Module):
    """Fully connected classifier: 784 inputs, six 64-unit ReLU layers, 10 log-probabilities."""

    def __init__(self):
        super().__init__()
        # Layers are created in network order. The attribute names are the
        # state-dict keys and the node names used when this module is wrapped
        # by create_feature_extractor, so they must not be renamed.
        self.linear_in = nn.Linear(28*28, 64)
        self.linear1 = nn.Linear(64, 64)
        self.linear2 = nn.Linear(64, 64)
        self.linear3 = nn.Linear(64, 64)
        self.linear4 = nn.Linear(64, 64)
        self.linear5 = nn.Linear(64, 64)
        self.linear_out = nn.Linear(64, 10)

    def forward(self, x):
        """Map a (batch, 784) tensor to (batch, 10) log-probabilities."""
        relu_stack = (
            self.linear_in,
            self.linear1,
            self.linear2,
            self.linear3,
            self.linear4,
            self.linear5,
        )
        # Every layer except the output layer is followed by a ReLU.
        for layer in relu_stack:
            x = F.relu(layer(x))
        return F.log_softmax(self.linear_out(x), dim=1)

In [None]:
model = mnist_classifier()
# Wrap the classifier so that a forward pass returns a dict holding the output
# of every listed linear layer plus the final log-softmax (under 'feat_out').
model = create_feature_extractor(model, {'linear_in': 'feat_in',
                                         'linear1': 'feat1',
                                         'linear2': 'feat2',
                                         'linear3': 'feat3',
                                         'linear4': 'feat4',
                                         'linear5': 'feat5',
                                         'log_softmax': 'feat_out'})
params = model.parameters()
optimiser = optim.SGD(params, lr=1e-2)
#optimiser = optim.AdamW(params)
# 'feat_out' already holds log-probabilities; CrossEntropyLoss applies
# log_softmax once more, which leaves log-probabilities unchanged.
loss = nn.CrossEntropyLoss()

nb_epochs = 20
for epoch in range(nb_epochs):
  # ---- training pass ----
  model.train()
  # Metrics are accumulated per sample (not per batch) so that a shorter final
  # batch does not get the same weight as a full one.
  loss_sum, correct, seen = 0.0, 0, 0
  for x, y in train_loader:
    x = x.view(x.size(0), -1)  # flatten each image into one row

    out = model(x)['feat_out']
    J = loss(out, y)

    optimiser.zero_grad()
    J.backward()
    optimiser.step()

    n = y.size(0)
    loss_sum += J.item() * n  # J is the batch mean, so scale back to a sum
    correct += y.eq(out.detach().argmax(dim=1)).sum().item()
    seen += n

  print(f"Epoch {epoch + 1}", end=", ")
  print(f"Training Loss: {loss_sum / max(seen, 1):.2f}", end=", ")
  print(f"Training Accuracy: {correct / max(seen, 1):.2f}")

  # ---- validation pass ----
  model.eval()
  loss_sum, correct, seen = 0.0, 0, 0
  with torch.no_grad():  # no gradients are needed for evaluation
    for x, y in val_loader:
      x = x.view(x.size(0), -1)

      out = model(x)['feat_out']
      J = loss(out, y)

      n = y.size(0)
      loss_sum += J.item() * n
      correct += y.eq(out.argmax(dim=1)).sum().item()
      seen += n

  print(f"Epoch {epoch + 1}", end=", ")
  print(f"Validation Loss: {loss_sum / max(seen, 1):.2f}", end=", ")
  print(f"Validation Accuracy: {correct / max(seen, 1):.2f}")

# Persist the weights of the wrapped (feature-extractor) model.
torch.save(model.state_dict(), "FC_MNIST-FASHION.pt")

In [24]:
# Rebuild the network with the same feature-extractor wrapping and the same
# output names, then restore the weights saved in "FC_MNIST-FASHION.pt".
model = create_feature_extractor(
    mnist_classifier(),
    dict(
        linear_in='feat_in',
        linear1='feat1',
        linear2='feat2',
        linear3='feat3',
        linear4='feat4',
        linear5='feat5',
        log_softmax='feat_out',
    ),
)
model.load_state_dict(torch.load("FC_MNIST-FASHION.pt"))

# Same loss and optimiser configuration as the training cell.
loss = nn.CrossEntropyLoss()
params = model.parameters()
optimiser = optim.SGD(params, lr=1e-2)

In [6]:
def fgsm_attack(image, epsilon, data_grad):
    """Return ``image`` moved by ``epsilon`` along the sign of ``data_grad``.

    Args:
        image: Tensor to perturb.
        epsilon: Step size applied to every element.
        data_grad: Gradient of the loss with respect to ``image``; only its
            element-wise sign is used.

    Returns:
        The perturbed tensor, with every value clamped to the range [0, 1].
    """
    # Each element moves by exactly +epsilon, -epsilon or 0.
    step = epsilon * data_grad.sign()
    # Clamp so the result stays within the [0, 1] range.
    return (image + step).clamp(0, 1)

In [None]:
#makes centroids
from PIL import Image
from numpy import asarray, divide
import numpy as np
from collections import Counter
from IPython.display import Image as im
classes = range(0,10)
#centroid_sets = ["","a"]
#centroid_sets = ["","a","b","c","d"]
centroid_sets = [""]

# centroids[class][set_index] is the dict of per-layer activations the model
# produces for that class's reference image from the given centroid set.
centroids = {class_num: [] for class_num in classes}

# The stored activations are only compared against later and never
# differentiated, so they are recorded without building autograd graphs.
with torch.no_grad():
  for centroid_set in centroid_sets:
    for class_num in classes:
      filestring = f"{class_num}{centroid_set}.png" #jpg for mnist, png for mnist-fashion
      # Close the file handle as soon as the pixels have been copied out.
      with Image.open(filestring) as img:
        # Dividing by 255 assumes 8-bit pixel values -- TODO confirm for the
        # image files actually used.
        data = torch.from_numpy(divide(asarray(img), 255)).float()

      # Flatten to a single row, the layout the model is fed elsewhere in
      # this file.
      data = data.view(1, -1)

      centroids[class_num].append(model(data))

In [None]:
from tqdm import tqdm
import random
import math

total = 5000
test_num = 150
max_eps = 128

# Attack a fixed random subset of the validation split, one image at a time.
perturb, _ = random_split(val, [test_num,total-test_num])
perturb_loader = DataLoader(perturb, batch_size=1)

# Per-class cut-offs on the halved cosine distance (1 - cos) / 2 between a
# sample's layer activation and a class centroid's activation.
# Exactly one table must be active, because the vote below looks thresholds up
# by class.
# NOTE(review): the MNIST-Fashion table is enabled because the checkpoint name
# and the .png centroid files point to that configuration -- confirm, and swap
# in the MNIST table if the model was trained on MNIST.
#thresholds = {0: 0.3, #MNIST Thresholds
#              1: 0.4,
#              2: 0.35,
#              3: 0.3,
#              4: 0.35,
#              5: 0.3,
#              6: 0.3,
#              7: 0.4,
#              8: 0.3,
#              9: 0.3}
thresholds = {0: 0.2, #MNIST-Fashion Thresholds
              1: 0.2,
              2: 0.3,
              3: 0.2,
              4: 0.35,
              5: 0.3,
              6: 0.3,
              7: 0.25,
              8: 0.2,
              9: 0.2}
bias = 0  # uniform offset added to every threshold

NO_MATCH = 1000  # pseudo-class that collects votes from layers matching no centroid


def _defense_vote(features):
  """Return the label(s) that collect the most per-layer centroid matches.

  For every layer in `features`, each (centroid set, class) pair whose halved
  cosine distance to the sample is at or below that class's threshold (plus
  `bias`) earns the class one vote. A layer with no match at all votes for
  NO_MATCH. Ties are all returned, in class order with NO_MATCH last.
  Expects activations for a single sample (batch size 1).
  """
  n_classes = len(classes)
  votes = {c: 0 for c in classes}
  votes[NO_MATCH] = 0
  for label, activation in features.items():
    # Distances are laid out set by set, so position i belongs to class
    # i % n_classes. The list is completed for ALL centroid sets before it is
    # thresholded.
    distances = [
      (1 - float(F.cosine_similarity(activation, centroids[c][set_idx][label]).detach())) / 2
      for set_idx in range(len(centroid_sets))
      for c in classes
    ]
    matched = [i for i, dist in enumerate(distances)
               if dist <= thresholds[i % n_classes] + bias]
    if not matched:
      votes[NO_MATCH] += 1
    for i in matched:
      votes[i % n_classes] += 1
  best = max(votes.values())
  return [c for c, v in votes.items() if v == best]


# eps step -> [undefended accuracy, defended accuracy, loose defended accuracy]
graph_accuracies = dict()

for eps in range(1,max_eps,4):
  losses = []
  accuracies = []
  defended_accuracies = []
  defended_looser = []
  for x, y in perturb_loader:
    # Flatten first, then track gradients on the flattened leaf tensor so the
    # input gradient is available as x.grad after backward().
    x = x.view(x.size(0), -1).requires_grad_(True)

    J = loss(model(x)["feat_out"], y)
    J.backward()
    data_grad = x.grad.data

    # The perturbed sample is only evaluated, so cut it from the graph and
    # run the second forward pass without autograd.
    perturbed_input = fgsm_attack(x, eps/255, data_grad).detach()
    with torch.no_grad():
      perturbed_features = model(perturbed_input)
      J_p = loss(perturbed_features["feat_out"], y)

    losses.append(J_p.item())
    prediction = perturbed_features["feat_out"].argmax(dim=1)
    accuracies.append(y.eq(prediction).float().mean().item())

    #DEFENSE
    max_scores = _defense_vote(perturbed_features)

    # Strict: the first top-voted label must be the true class.
    defended_accuracies.append(float(y.item() == max_scores[0]))
    # Loose: the true class is among fewer than four tied top-voted labels.
    defended_looser.append(float(y.item() in max_scores and len(max_scores) < 4))

  graph_accuracies[eps] = [torch.tensor(accuracies).mean().item(), torch.tensor(defended_accuracies).mean().item(), torch.tensor(defended_looser).mean().item()]


In [None]:
import matplotlib.pyplot as plt

# graph_accuracies maps each raw eps step to a three-element list:
# [defenseless, defended, defended (loose)] accuracy. Split it into one
# series per curve.
eps = list(graph_accuracies.keys())
packed_acc = list(graph_accuracies.values())
defenseless = [pack[0] for pack in packed_acc]
defended = [pack[1] for pack in packed_acc]
defended_looser = [pack[2] for pack in packed_acc]

# The keys are integer steps; dividing by 255 gives the epsilon that was
# actually passed to the attack.
eps = [step/255 for step in eps]

# The flat 1/11 line is drawn as the "Chance" reference level.
for series, curve_label in (
    (defenseless, "Defenseless"),
    (defended, "Defense"),
    (defended_looser, "Defense (Loose)"),
    ([1/11] * len(eps), "Chance"),
):
  plt.plot(eps, series, label=curve_label)

plt.title("Neural Network accuracy against FGSM strength")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.legend()
#plt.show()

#for idx,e in enumerate(eps):
#  print(e*255, round(e,2), defenseless[idx], defended[idx], defended_looser[idx])