Valerio Pepe, CS 226r Spring 2024



---



Code References:

Alfredo Canziani's MNIST Tutorial: https://www.youtube.com/watch?v=OMDn66kM9Qc (base for the MNIST Model and training loop)

LeNet-5 PyTorch Architecture adapted from here: https://github.com/lychengrex/LeNet-5-Implementation-Using-Pytorch/blob/master/LeNet-5%20Implementation%20Using%20Pytorch.ipynb

FGSM Attack PyTorch implementation adapted from here: https://github.com/pytorch/tutorials/blob/main/beginner_source/fgsm_tutorial.py

In [None]:
import torch
import numpy as np
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from torch.utils.data import random_split, DataLoader

In [None]:
# Fetch the MNIST training images (downloaded into ./data on first use) as
# tensors, then carve them into a 55,000-sample training subset and a
# 5,000-sample validation subset, each served in mini-batches of 32.
train_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
#train_data = datasets.FashionMNIST("data", train=True, download=True, transform=transforms.ToTensor())
train, val = random_split(train_data, lengths=[55000, 5000])
train_loader = DataLoader(dataset=train, batch_size=32)
val_loader = DataLoader(dataset=val, batch_size=32)

In [None]:
class LeNet(nn.Module):
    """LeNet-5 style CNN mapping (N, 1, 28, 28) images to 10 log-probabilities.

    Layer layout (spatial sizes for a 28x28 input):
        conv1 (5x5, padding 2) -> 6 x 28 x 28 -> ReLU -> 2x2 max-pool -> 6 x 14 x 14
        conv2 (5x5)            -> 16 x 10 x 10 -> ReLU -> 2x2 max-pool -> 16 x 5 x 5
        fc1 (400 -> 120) -> ReLU -> fc2 (120 -> 84) -> ReLU -> fc3 (84 -> 10)
        log_softmax over the class dimension
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run one forward pass through the network.

        Args:
            x: batch of images; fc1 expects 16*5*5 features per sample after
                the two conv/pool stages, which a (N, 1, 28, 28) input yields.

        Returns:
            Tensor of shape (N, 10) holding per-class log-probabilities.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        # Flatten every sample to a vector before the fully connected layers.
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # dim=1 is the class axis of the (N, 10) logits. Passing it explicitly
        # avoids the deprecated implicit-dimension behaviour (and its warning)
        # while selecting the same axis the implicit rule used for 2-D input.
        x = F.log_softmax(self.fc3(x), dim=1)
        return x

    def num_flat_features(self, x):
        """Return the number of features per sample in a 4-D batch `x`.

        Written as an explicit product of the three non-batch sizes (rather
        than iterating over `x.size()`) so that it also works when `x` is a
        symbolic proxy, as happens when the model is traced by
        `create_feature_extractor` elsewhere in this file.
        """
        return x.size(1) * x.size(2) * x.size(3)

In [None]:
# Build the network and wrap it so that a forward pass returns a dict of
# intermediate activations ('feat1'..'feat5') plus the final log-probabilities
# ('feat_out') instead of a single tensor.
model = LeNet()
model = create_feature_extractor(model, {'conv1': 'feat1',
                                         'conv2': 'feat2',
                                         'fc1': 'feat3',
                                         'fc2': 'feat4',
                                         'fc3': 'feat5',
                                         'log_softmax': 'feat_out'})
params = model.parameters()
optimiser = optim.SGD(params, lr=1e-2)
#optimiser = optim.AdamW(params)
# NOTE: the network already ends in log_softmax; CrossEntropyLoss applies
# log_softmax again, which leaves log-probabilities unchanged.
loss = nn.CrossEntropyLoss()


def _run_epoch(loader, training):
    """Run `model` once over `loader` and return (mean loss, mean accuracy).

    Args:
        loader: iterable of (images, labels) batches.
        training: when True the model is put in train mode and the optimiser
            takes one step per batch; when False the model is put in eval
            mode and no gradients are computed.

    Returns:
        Tuple of Python floats, both averaged per sample (so a smaller final
        batch is weighted correctly). Both are NaN if the loader is empty.
    """
    if training:
        model.train()
    else:
        model.eval()

    loss_sum, correct, seen = 0.0, 0, 0
    for x, y in loader:
        # Reshape with the actual batch size: the last batch of a split is
        # usually smaller than 32 and must not be dropped.
        b = x.size(0)
        x = x.view(b, 1, 28, 28)

        with torch.set_grad_enabled(training):
            l = model(x)
            J = loss(l['feat_out'], y)

        if training:
            optimiser.zero_grad()
            J.backward()
            optimiser.step()

        loss_sum += J.item() * b
        correct += y.eq(l['feat_out'].detach().argmax(dim=1)).sum().item()
        seen += b

    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


nb_epochs = 2
for epoch in range(nb_epochs):
    train_loss, train_accuracy = _run_epoch(train_loader, training=True)
    print(f"Epoch {epoch + 1}", end=", ")
    print(f"Training Loss: {train_loss:.2f}", end=", ")
    print(f"Training Accuracy: {train_accuracy:.2f}")

    val_loss, val_accuracy = _run_epoch(val_loader, training=False)
    print(f"Epoch {epoch + 1}", end=", ")
    print(f"Validation Loss: {val_loss:.2f}", end=", ")
    print(f"Validation Accuracy: {val_accuracy:.2f}")

    # Checkpoint after every epoch so an interrupted run keeps its progress.
    # NOTE(review): this writes "CNN_MNIST-FASHION.pt" although the data cell
    # currently loads MNIST and the next cell reads "CNN_MNIST.pt" -- confirm
    # which checkpoint name is intended before relying on the saved weights.
    torch.save(model.state_dict(), "CNN_MNIST-FASHION.pt")

In [None]:
# Recreate the feature-extracting LeNet and restore its trained weights.
# Each key below names a node of the traced LeNet graph; each value is the
# key under which that node's output appears in the dict returned by
# `model(x)`.
model = create_feature_extractor(
    LeNet(),
    {
        'conv1': 'feat1',
        'conv2': 'feat2',
        'fc1': 'feat3',
        'fc2': 'feat4',
        'fc3': 'feat5',
        'log_softmax': 'feat_out',
    },
)
model.load_state_dict(torch.load("CNN_MNIST.pt"))

# Same loss / optimiser configuration as the training cell.
loss = nn.CrossEntropyLoss()
params = model.parameters()
optimiser = optim.SGD(params, lr=1e-2)

In [None]:
def fgsm_attack(image, epsilon, data_grad):
    """Return an FGSM-perturbed copy of `image`.

    Every pixel is shifted by `epsilon` in the direction of the sign of the
    loss gradient with respect to that pixel, and the result is clipped back
    into the [0, 1] range.

    Args:
        image: tensor of pixel values to perturb.
        epsilon: step size applied to every pixel.
        data_grad: gradient of the loss w.r.t. `image` (same shape).

    Returns:
        The perturbed tensor, with all values inside [0, 1].
    """
    step = epsilon * data_grad.sign()
    return (image + step).clamp(min=0, max=1)

In [None]:
#makes centroids
from PIL import Image
from numpy import asarray, divide
import numpy as np
from collections import Counter
from IPython.display import Image as im
classes = range(0, 10)
#centroid_sets = ["","a"]
centroid_sets = [""]

# centroids[class_num] is a list with one entry per centroid set; each entry
# is the dict of layer activations the model produces for that class's
# reference ("centroid") image.
centroids = {class_num: [] for class_num in classes}

for centroid_set in centroid_sets:
    for class_num in classes:
        filestring = f"{class_num}{centroid_set}.jpg" #jpg for mnist, png for mnist-fashion
        # The context manager closes the file handle once the pixels are read.
        # convert("L") leaves single-channel images unchanged and lets files
        # that were saved as RGB still fit the (1, 1, 28, 28) view below.
        with Image.open(filestring) as centroid_image:
            pixels = asarray(centroid_image.convert("L"))

        # Scale 0..255 pixel values into [0, 1] and add batch/channel axes.
        data = torch.from_numpy(divide(pixels, 255)).float().view(1, 1, 28, 28)

        # Only the activations are needed later, so skip autograd entirely:
        # no graph is kept alive per centroid and no gradients are
        # accumulated into the model's parameters.
        with torch.no_grad():
            features = model(data)

        centroids[class_num].append(features)


In [None]:
from tqdm import tqdm
import random

total = len(val)
test_num = 200
min_eps = 0
max_eps = 128

# Evaluate on a random subset of the validation split, one image at a time.
perturb, _ = random_split(val, [test_num, total - test_num])
perturb_loader = DataLoader(perturb, batch_size=1)

# Per-class upper bounds on the cosine distance (1 - cos) / 2 between an
# input's activations and a class centroid's activations: a layer "votes" for
# a class only when the distance is at or below that class's threshold.
MNIST_THRESHOLDS = {0: 0.3, 1: 0.4, 2: 0.35, 3: 0.3, 4: 0.35,
                    5: 0.3, 6: 0.3, 7: 0.4, 8: 0.3, 9: 0.3}
FASHION_MNIST_THRESHOLDS = {0: 0.2, 1: 0.2, 2: 0.3, 3: 0.2, 4: 0.35,
                            5: 0.3, 6: 0.3, 7: 0.25, 8: 0.2, 9: 0.2}
# Switch to FASHION_MNIST_THRESHOLDS when evaluating the Fashion-MNIST model.
thresholds = MNIST_THRESHOLDS

# Pseudo-class that collects the votes of layers matching no centroid at all.
NO_MATCH = 1000


def _mean_cosine_distance(features, centroid_features):
    """Return (1 - mean cosine similarity) / 2 for one layer's activations.

    Cosine similarity is taken along dim 1 (the default). For convolutional
    activations this yields one value per spatial position, which are
    averaged; for fully connected activations it is a single value.
    """
    similarity = F.cosine_similarity(features, centroid_features).detach()[0]
    return (1 - similarity.double().mean().item()) / 2


def _matching_classes(features, layer):
    """List the classes whose centroid is within threshold at `layer`.

    A class appears once for every centroid set in which it matches.
    """
    matches = []
    for set_index, _ in enumerate(centroid_sets):
        for class_num in classes:
            distance = _mean_cosine_distance(
                features[layer], centroids[class_num][set_index][layer])
            if distance <= thresholds[class_num]:
                matches.append(class_num)
    return matches


def _top_voted_classes(features):
    """Return every label tied for the most layer votes, in label order.

    Each layer gives one vote to every class it matches; a layer matching no
    class votes for NO_MATCH instead.
    """
    scores = {class_num: 0 for class_num in classes}
    scores[NO_MATCH] = 0
    for layer in features:
        matches = _matching_classes(features, layer)
        if not matches:
            scores[NO_MATCH] += 1
        for class_num in matches:
            scores[class_num] += 1
    best = max(scores.values())
    return [label for label, votes in scores.items() if votes == best]


# The gradient of the clean loss w.r.t. the input does not depend on epsilon,
# so compute it once per sample instead of once per (sample, epsilon) pair.
attack_inputs = []
for x, y in perturb_loader:
    x = x.view(1, 1, 28, 28).requires_grad_(True)
    clean_loss = loss(model(x)["feat_out"], y)
    (data_grad,) = torch.autograd.grad(clean_loss, x)
    attack_inputs.append((x.detach(), y, data_grad))

# eps (in 1/255 pixel units) -> [undefended accuracy,
#                                strict defended accuracy,
#                                loose defended accuracy]
graph_accuracies = dict()

for eps in range(min_eps, max_eps, 4):
    accuracies = list()
    defended_accuracies = list()
    defended_looser = list()
    for x, y, data_grad in attack_inputs:
        perturbed_input = fgsm_attack(x, eps / 255, data_grad)
        with torch.no_grad():
            perturbed_features = model(perturbed_input)

        # Undefended prediction: arg-max of the network output.
        prediction = perturbed_features["feat_out"].argmax(dim=1)
        accuracies.append(y.eq(prediction).float().item())

        #DEFENSE
        max_scores = _top_voted_classes(perturbed_features)
        label = y.item()
        # Strict: the first top-voted label must be the true class.
        defended_accuracies.append(float(label == max_scores[0]))
        # Loose: the true class is among fewer than four tied top labels.
        defended_looser.append(float(label in max_scores and len(max_scores) < 4))

    graph_accuracies[eps] = [torch.tensor(accuracies).mean().item(),
                             torch.tensor(defended_accuracies).mean().item(),
                             torch.tensor(defended_looser).mean().item()]


In [None]:
import matplotlib.pyplot as plt

# Unpack graph_accuracies (eps -> [undefended, defended, defended-loose])
# into one series per curve, and convert eps from 1/255 units to [0, 1] scale.
packed_acc = list(graph_accuracies.values())
defenseless = [pack[0] for pack in packed_acc]
defended = [pack[1] for pack in packed_acc]
defended_looser = [pack[2] for pack in packed_acc]
eps = [step / 255 for step in graph_accuracies]

# Constant reference line at 1/11.
chance = [1/11] * len(eps)

curves = (
    (defenseless, "Defenseless"),
    (defended, "Defense"),
    (defended_looser, "Defense (Loose)"),
    (chance, "Chance"),
)
for values, curve_label in curves:
    plt.plot(eps, values, label=curve_label)

plt.title("LeNet accuracy against FGSM strength")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

#for idx,e in enumerate(eps):
#  print(e*255, round(e,2), defenseless[idx], defended[idx], defended_looser[idx])