In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torch
from tqdm import tqdm
from art.attacks.evasion import FastGradientMethod,  ProjectedGradientDescent, CarliniL2Method
from art.estimators.classification import PyTorchClassifier
from art.utils import load_mnist
from art.defences.preprocessor import LabelSmoothing, SpatialSmoothing, GaussianAugmentation
from art.defences.trainer import AdversarialTrainerMadryPGD

# Step 0: Define the neural network model, return logits instead of activation in forward method


class Net(nn.Module):
    """Small convolutional classifier that returns raw logits (no softmax).

    Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, feed a
    two-layer fully connected head with 10 outputs. The flatten size of
    4 * 4 * 10 corresponds to single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 4 -> 10 channels.
        self.conv_1 = nn.Conv2d(1, 4, kernel_size=5, stride=1)
        self.conv_2 = nn.Conv2d(4, 10, kernel_size=5, stride=1)
        # Classification head: 160 -> 100 -> 10.
        self.fc_1 = nn.Linear(4 * 4 * 10, 100)
        self.fc_2 = nn.Linear(100, 10)

    def forward(self, x):
        """Compute the logits for a batch of images `x`."""
        features = F.max_pool2d(F.relu(self.conv_1(x)), 2, 2)
        features = F.max_pool2d(F.relu(self.conv_2(features)), 2, 2)
        # Collapse channel and spatial axes into one vector per sample.
        flat = features.view(-1, self.fc_1.in_features)
        hidden = F.relu(self.fc_1(flat))
        return self.fc_2(hidden)


# Step 1: Load the MNIST dataset

(train_images, train_labels), (test_images, test_labels), min_pixel_value, max_pixel_value = load_mnist()

# Step 1a: Swap axes to PyTorch's NCHW format

train_images = np.transpose(train_images, (0, 3, 1, 2)).astype(np.float32)
test_images = np.transpose(test_images, (0, 3, 1, 2)).astype(np.float32)


# Step 1b: Apply preprocessing defences to the training data
#
# The defences are applied once, directly to the training arrays; the test
# images are left unprocessed.
# NOTE(review): whether the smoothed (soft) labels reach the loss unchanged
# depends on how the installed ART version reduces labels for
# CrossEntropyLoss - confirm against the PyTorchClassifier in use.
label_smoothing = LabelSmoothing(max_value=0.7, apply_fit=True, apply_predict=False)
# The images were transposed to NCHW above, so the filter must be told that the
# channel axis comes first; with the default (channels last) the median window
# would run over the wrong axes.
spatial_smoothing = SpatialSmoothing(channels_first=True)
# Clip the noisy images back into the valid pixel range so the training data
# stays inside the clip_values the classifier is configured with.
gaussian_augmentation = GaussianAugmentation(
    sigma=0.1,
    augmentation=False,
    ratio=1.0,
    clip_values=(min_pixel_value, max_pixel_value),
)
# Apply defences sequentially; each defence call returns an (x, y) pair.
train_images, train_labels = label_smoothing(train_images, train_labels)
train_images = spatial_smoothing(train_images)[0]
train_images = gaussian_augmentation(train_images)[0]


# Step 2: Instantiate the network together with its optimizer and loss

model = Net()
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Step 3: Wrap model, loss and optimizer in an ART classifier
#
# clip_values is the pixel range reported by the dataset loader; input_shape
# is a single NCHW sample without the batch axis.

classifier = PyTorchClassifier(
    model=model,
    loss=criterion,
    optimizer=optimizer,
    input_shape=(1, 28, 28),
    nb_classes=10,
    clip_values=(min_pixel_value, max_pixel_value),
)

# Step 4: Adversarially train the ART classifier

# Attacks that craft the adversarial share of every epoch's training set
fgsm_attack = FastGradientMethod(estimator=classifier, eps=0.1)
pgd_attack = ProjectedGradientDescent(classifier, eps=0.1, max_iter=10)

# Five rounds of: attack the current weights, then fit for a single epoch
for epoch in tqdm(range(5)):
    # Fresh adversarial examples each round (generate() is called without labels)
    train_images_fgsm = fgsm_attack.generate(train_images)
    train_images_pgd = pgd_attack.generate(train_images)

    # Stack clean, FGSM and PGD images; the labels repeat once per image set
    combined_images = np.concatenate(
        [train_images, train_images_fgsm, train_images_pgd],
        axis=0,
    )
    combined_labels = np.concatenate([train_labels] * 3, axis=0)

    # One pass over the mixed dataset
    classifier.fit(combined_images, combined_labels, batch_size=64, nb_epochs=1)

# Persist the trained weights
torch.save(model.state_dict(), "mnist_model_robust.pth")

# Attacks used to measure robustness on the test set
attack_methods = [
    FastGradientMethod(classifier, eps=0.1),
    ProjectedGradientDescent(classifier, eps=0.1),
    CarliniL2Method(classifier),
]

# Class index per test sample, taken from the test label array (argmax over axis 1)
true_classes = np.argmax(test_labels, axis=1)

# Baseline: accuracy on the unmodified test images
test_predictions = classifier.predict(test_images)
clean_accuracy = np.mean(test_predictions.argmax(axis=1) == true_classes)
print(f"\nClean test accuracy: {clean_accuracy * 100:.2f}%")

# Accuracy on adversarial versions of the test images, one attack at a time
for attack_method in attack_methods:
    adversarial_images = attack_method.generate(test_images)
    adversarial_predictions = classifier.predict(adversarial_images)
    adversarial_accuracy = np.mean(
        adversarial_predictions.argmax(axis=1) == true_classes
    )
    print(f"Accuracy against {attack_method.__class__.__name__}: {adversarial_accuracy * 100:.2f}%")