In [None]:
import random
import statistics

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FashionMNIST

import art
from art.estimators.classification import PyTorchClassifier

# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

### Model Setup

In [None]:
batch_size = 64
num_classes = 10
epochs = 5  # used by the manual train_network loop; the ART fits pass nb_epochs explicitly

# Seed every RNG that training and the attacks may draw from (Python's
# `random`, NumPy, PyTorch incl. CUDA) so Restart & Run All is repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# NOTE(review): `load_mnist` presumably loads the original digit MNIST, not
# FashionMNIST. Every name assigned here (x/y train/test, min/max pixel value)
# is reassigned by the FashionMNIST cell that follows, so this cell only
# serves as a shape sanity check.
from art.utils import load_mnist

(x_train, y_train), (x_test, y_test), min_pixel_value, max_pixel_value = load_mnist()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

# Move the last axis to position 1, i.e. channels-last (N, H, W, C) to
# PyTorch's channels-first (N, C, H, W), and cast to float32.
x_train, x_test = (np.moveaxis(arr, -1, 1).astype(np.float32) for arr in (x_train, x_test))
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
print(min_pixel_value, max_pixel_value)

In [None]:
mnist_train = FashionMNIST(root="data", download=True, transform=transforms.ToTensor(), train=True)
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

mnist_test = FashionMNIST(root="data", download=True, transform=transforms.ToTensor(), train=False)
# Batch the test set as well: DataLoader defaults to batch_size=1, which makes
# evaluation issue one model / ART predict call per image.
test_loader = DataLoader(mnist_test, batch_size=batch_size)

# Raw uint8 images and integer labels straight from the datasets (this bypasses
# the ToTensor transform, so pixels are rescaled to [0, 1] by hand below).
x_train, y_train = mnist_train.data, mnist_train.targets
x_test, y_test = mnist_test.data, mnist_test.targets
min_pixel_value, max_pixel_value = 0.0, 1.0

# Flattened (N, 784) copies for the MLP; (N, 1, 28, 28) copies for the CNN.
# `-1` lets the sample count follow the dataset instead of being hard-coded.
x_train_lin = x_train.reshape((-1, 28 * 28)).type(torch.float32) / 255.0
x_test_lin = x_test.reshape((-1, 28 * 28)).type(torch.float32) / 255.0
x_train = x_train.reshape((-1, 1, 28, 28)).type(torch.float32) / 255.0
x_test = x_test.reshape((-1, 1, 28, 28)).type(torch.float32) / 255.0
# ART's fit/predict below work with one-hot float labels.
y_train = F.one_hot(y_train, num_classes).type(torch.float32)
y_test = F.one_hot(y_test, num_classes).type(torch.float32)

print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

In [None]:
# Convolutional Neural Network

# model architecture from: https://www.kaggle.com/code/pankajj/fashion-mnist-with-pytorch-93-accuracy

class ConvClassifier(nn.Module):
    """Two conv blocks followed by a three-layer fully connected head.

    Expects single-channel 28x28 inputs, i.e. (N, 1, 28, 28). The spatial size
    goes 28 -> 28 (3x3 conv, padding=1) -> 14 (pool) -> 12 (3x3 conv) -> 6
    (pool), which is where the ``64 * 6 * 6`` flatten size comes from.

    Args:
        num_classes: number of output logits.
        activation_function: activation module inserted after each conv block
            and each hidden linear layer. The same instance is reused at every
            position. Defaults to a fresh ``nn.ReLU()`` per model.
        loss_fn: loss stored on the model for the train_network/test loops.
            Defaults to a fresh ``nn.CrossEntropyLoss()`` per model.
    """

    def __init__(self, num_classes: int, activation_function=None, loss_fn=None):
        super().__init__()
        # Build defaults per instance: a module object in the signature would be
        # created once at definition time and shared by every model.
        if activation_function is None:
            activation_function = nn.ReLU()
        if loss_fn is None:
            loss_fn = nn.CrossEntropyLoss()

        self.conv_layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            activation_function,
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv_layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            activation_function,
            nn.MaxPool2d(2)
        )
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(64 * 6 * 6, 600),
            nn.Dropout(0.25),
            activation_function,
            nn.Linear(600, 120),
            activation_function,
            nn.Linear(120, num_classes)  # raw logits; CrossEntropyLoss applies log-softmax itself
        )

        self.num_classes = num_classes
        self.loss_fn = loss_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class logits of shape (N, num_classes)."""
        x = self.conv_layer1(x)
        x = self.conv_layer2(x)
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

class MLPClassifier(nn.Module):
    """Fully connected classifier: input_dim -> 256 -> 128 -> num_classes.

    Expects already-flattened inputs of shape (N, input_dim).

    Args:
        input_dim: number of input features.
        num_classes: number of output logits.
        activation_function: activation module placed after each hidden layer.
            The same instance is reused at both positions. Defaults to a fresh
            ``nn.ReLU()`` per model.
        loss_fn: loss stored on the model for the train_network/test loops.
            Defaults to a fresh ``nn.CrossEntropyLoss()`` per model.
    """

    def __init__(self, input_dim: int, num_classes: int, activation_function=None, loss_fn=None):
        super().__init__()
        # Build defaults per instance so no module object is shared between models.
        if activation_function is None:
            activation_function = nn.ReLU()
        if loss_fn is None:
            loss_fn = nn.CrossEntropyLoss()

        # Honour the activation_function argument (it used to be ignored in
        # favour of hard-coded ReLUs).
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=256),
            activation_function,
            nn.Linear(in_features=256, out_features=128),
            activation_function,
            nn.Linear(in_features=128, out_features=num_classes),
        )

        self.num_classes = num_classes
        self.loss_fn = loss_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class logits of shape (N, num_classes)."""
        return self.fc_layers(x)

def train_network(model, dataloader: DataLoader, optimizer: torch.optim.Optimizer):
    """Run one training epoch of ``model`` over ``dataloader`` and print the mean batch loss.

    Each batch is moved to the module-level ``device``. Integer labels are
    one-hot encoded to ``model.num_classes`` floats and scored with
    ``model.loss_fn``.
    """
    model.train()
    losses = []
    for X, y in tqdm.tqdm(dataloader):
        X, y = X.to(device), y.to(device)
        # Compute prediction error. Call the module (not .forward) so hooks run.
        y_pred = model(X)
        y_true = F.one_hot(y, model.num_classes).type(y_pred.dtype)
        loss = model.loss_fn(y_pred, y_true)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    if not losses:
        # statistics.mean raises StatisticsError on an empty sequence.
        print("Epoch Train Loss: n/a (empty dataloader)")
        return
    print(f"Epoch Train Loss: {statistics.mean(losses)}")

def test(model, dataloader: DataLoader):
    """Evaluate ``model`` on ``dataloader`` and print the mean batch loss and accuracy.

    Batches are moved to the module-level ``device``. The loss is
    ``model.loss_fn`` against one-hot float targets, averaged over batches.
    Accuracy is the fraction of samples whose argmax prediction equals the
    integer label.
    """
    num_batches = len(dataloader)
    num_samples = len(dataloader.dataset)
    if num_batches == 0 or num_samples == 0:
        # Nothing to average; avoids ZeroDivisionError below.
        print("Test Error: no samples to evaluate \n")
        return
    model.eval()
    test_loss = 0.0
    num_correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            # Call the module (not .forward) so registered hooks run.
            y_pred = model(X)
            y_true = F.one_hot(y, model.num_classes).type(y_pred.dtype)
            test_loss += model.loss_fn(y_pred, y_true).item()
            num_correct += (torch.argmax(y_pred, dim=1) == y).sum().item()
    test_loss /= num_batches
    accuracy = num_correct / num_samples
    print(f"Test Error: Avg loss: {test_loss:>8f} Accuracy: {float(accuracy):>8f} \n")

def test_art(model, dataloader: DataLoader, criterion, num_classes):
    num_batches = len(dataloader)
    #model.eval()
    test_loss = 0
    num_correct = 0
    with torch.no_grad(): 
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            y_pred = torch.from_numpy(model.predict(X, training_mode=False))
            test_loss += criterion(y_pred, F.one_hot(y, num_classes).type(y_pred.dtype)).item()
            num_correct += (torch.argmax(y_pred, dim=1) == y).sum().item()
    test_loss /= num_batches
    accuracy = num_correct / len(dataloader.dataset)
    print(f"Test Error: Avg loss: {test_loss:>8f} Accuracy: {float(accuracy):>8f} \n")

In [None]:
# Wrap the CNN in an ART PyTorchClassifier; training goes through
# classifier.fit rather than the manual train_network/test loop.
criterion = nn.CrossEntropyLoss()
model = ConvClassifier(num_classes=num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

classifier = PyTorchClassifier(
    model=model,
    loss=criterion,
    optimizer=optimizer,
    input_shape=(1, 28, 28),
    nb_classes=num_classes,
    clip_values=(min_pixel_value, max_pixel_value),
)

In [None]:
# ART estimators take NumPy arrays, not torch tensors.
classifier.fit(x_train.numpy(), y_train.numpy(), batch_size=batch_size, nb_epochs=3)

In [None]:
# ART's predict takes NumPy input; y_test is one-hot, so compare argmax class ids.
predictions = classifier.predict(x_test.numpy())
# Mean of the per-sample hits = fraction correct (previously the division by
# len(y_test) was commented out, so a raw count was printed as a percentage).
accuracy = np.mean(np.argmax(predictions, axis=1) == np.argmax(y_test.numpy(), axis=1))
print(accuracy)
print("Accuracy on benign test examples: {}%".format(accuracy * 100))

In [None]:
test_art(model=classifier, dataloader=test_loader, criterion=criterion, num_classes=num_classes)

In [None]:
linear_model = MLPClassifier(input_dim=28 * 28, num_classes=num_classes).to(device)
# Optimise the MLP's own parameters. Handing Adam the CNN's `model.parameters()`
# would leave linear_model untouched by linear_classifier.fit.
optimizer = torch.optim.Adam(linear_model.parameters(), lr=1e-3)

criterion = nn.CrossEntropyLoss()

linear_classifier = PyTorchClassifier(
    model=linear_model,
    clip_values=(min_pixel_value, max_pixel_value),
    loss=criterion,
    optimizer=optimizer,
    input_shape=(28 * 28,),  # a 1-tuple; `(28*28)` is just the int 784
    nb_classes=num_classes,
)

In [None]:
# ART estimators take NumPy arrays, not torch tensors.
linear_classifier.fit(x_train_lin.numpy(), y_train.numpy(), batch_size=batch_size, nb_epochs=3)

In [None]:
# test_loader yields (N, 1, 28, 28) images, but the MLP's first Linear layer
# expects flattened 784-feature rows, so evaluate it on the flattened test set
# paired with the integer labels.
linear_test_loader = DataLoader(
    torch.utils.data.TensorDataset(x_test_lin, mnist_test.targets),
    batch_size=batch_size,
)
test_art(linear_classifier, linear_test_loader, criterion, num_classes)

### Attacks

In [None]:
def show_image(image):
    """Display a single image in grayscale.

    Accepts a torch tensor or anything ``np.asarray`` understands. Leading
    axes are peeled off by taking index 0 until two dimensions remain. So
    (28, 28), (1, 28, 28) and (1, 1, 28, 28) inputs all show one 28x28 image,
    and a (C, H, W) input shows its first channel.
    """
    if isinstance(image, torch.Tensor):
        pixels = image.detach().cpu().numpy()
    else:
        pixels = np.asarray(image)
    while pixels.ndim > 2:
        pixels = pixels[0]
    plt.imshow(pixels, cmap='gray', interpolation='none')
    plt.show()

# camelCase name kept for existing callers.
showImage = show_image

show_image(x_test[0])

In [None]:
def plot_first_samples(images, labels, label_list=mnist_train.classes):
    """Plot images titled with their class name on an 8x16 subplot grid.

    Image ``i`` goes into slot ``2*i + 1``, so only odd slots are used and at
    most 64 images fit. ``labels`` are reduced with ``torch.argmax``
    (presumably one-hot rows) to index ``label_list``.
    """
    n_rows, n_cols = 8, 16
    figure = plt.figure(figsize=(10, 10))
    for idx, (img, label_row) in enumerate(zip(images, labels)):
        slot = 2 * idx + 1
        figure.add_subplot(n_rows, n_cols, slot)
        plt.axis('off')
        class_idx = torch.argmax(label_row).item()
        plt.title(label_list[class_idx])
        plt.imshow(img[0].detach().numpy(), cmap='gray')

#plot_first_samples(x_test[0:64], y_test[0:64])

In [None]:
def plot_classified_first_samples(X, y, model, label_list=mnist_train.classes):
    """Plot up to 64 images on an 8x8 grid, each titled with the model's prediction.

    Correct predictions are titled in green. Wrong ones are titled in red with
    the true class appended.

    Args:
        X: image batch accepted by ``model`` (e.g. (N, 1, 28, 28) for the CNN).
        y: integer class labels for ``X``.
        model: torch module returning per-class scores.
        label_list: class names indexed by class id.
    """
    rows = 8
    columns = 8
    fig = plt.figure(figsize=(10, 10), layout="constrained")

    # Run the model on its own device, in eval mode and without autograd
    # (BatchNorm/Dropout behave differently in train mode, and train-mode
    # BatchNorm updates its running stats). Restore the caller's mode afterwards.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else X.device
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            y_pred = torch.argmax(model(X.to(model_device)), dim=1).cpu()
    finally:
        model.train(was_training)

    for i, (image, true_label, predicted_label) in enumerate(zip(X, y, y_pred)):
        true_idx = int(true_label)
        pred_idx = int(predicted_label)
        correct = pred_idx == true_idx
        fig.add_subplot(rows, columns, i + 1)
        plt.axis('off')
        plt.title(label_list[pred_idx] + ("" if correct else "\nCorrect: " + label_list[true_idx]),
                  color="green" if correct else "red",
                  size="small")
        plt.imshow(image[0].detach().cpu().numpy(), cmap='gray')

def plot_both(X, X_p, y, label_list=mnist_train.classes):
    """Plot up to 16 (original, perturbed) image pairs side by side on a 4x8 grid.

    The original image is titled with its class name (``y`` holds integer
    labels); the perturbed image next to it is untitled.
    """
    rows = 4
    columns = 8
    fig = plt.figure(figsize=(10, 10), layout="constrained")
    for i, (image, image_p, label) in enumerate(zip(X, X_p, y)):
        fig.add_subplot(rows, columns, 2 * i + 1)
        plt.axis('off')
        plt.title(label_list[label.item()])
        plt.imshow(image[0].detach().cpu().numpy(), cmap='gray')

        fig.add_subplot(rows, columns, 2 * i + 2)
        plt.axis('off')
        plt.imshow(image_p[0].detach().cpu().numpy(), cmap='gray')

In [162]:
# Universal perturbation: one additive noise pattern (attack.noise) meant to
# change the classifier's prediction on most of these 64 test images.
# delta=0.5: ART keeps iterating until the fooling rate reaches 1 - delta.
attack_inputs = x_test[0:64].numpy()
attack = art.attacks.evasion.UniversalPerturbation(classifier, delta=0.5)
# generate() returns the perturbed (adversarial) images, not the noise itself.
pertubations = attack.generate(attack_inputs, y_test[0:64].numpy())
# attack.noise presumably carries a leading batch axis, (1, 1, 28, 28), or is
# the scalar 0 if no perturbation was found; broadcasting normalises both to a
# single (1, 28, 28) image for showImage.
showImage(np.broadcast_to(attack.noise, (1,) + attack_inputs.shape[1:])[0])

In [None]:
plot_classified_first_samples(X=x_test[:64], y=torch.argmax(y_test[:64], dim=1), model=model)

In [None]:
plot_both(X=x_test[:4], X_p=torch.from_numpy(pertubations[:4]), y=torch.argmax(y_test[:4], dim=1))

In [None]:
plot_classified_first_samples(X=torch.from_numpy(pertubations[:64]), y=torch.argmax(y_test[:64], dim=1), model=model)