In [None]:
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import torch.nn.functional as F
from maxout import CustomMaxout

import matplotlib.pyplot as plt
import numpy as np


## 1. Chargement de la base de données

In [None]:
# Preprocessing pipeline: convert images to tensors, then normalize them to [-1, 1].
to_tensor_step = transforms.ToTensor()
normalize_step = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([to_tensor_step, normalize_step])

# Arguments shared by the two MNIST splits; only `train` differs between them.
mnist_kwargs = dict(root="data", download=True, transform=transform)
training_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)

batch_size = 64

# Both loaders use the same batch size and shuffle their dataset.
loader_kwargs = dict(batch_size=batch_size, shuffle=True)
training_dataloader = DataLoader(training_data, **loader_kwargs)
test_dataloader = DataLoader(test_data, **loader_kwargs)

## 2. Modèle de prédiction

## 3. Définition de la fonction de coût
$$\tilde{J}(\theta, x, y) = \alpha J(\theta, x, y) + (1 - \alpha) J(\theta, x + \epsilon \cdot \text{sign}(\nabla_x J(\theta, x, y)), y)$$

On définit dans un premier temps une loss de base $J(\theta, x, y)$. Comme rien n'est précisé dans l'article, on choisit la cross-entropy.

In [None]:
# Standard loss
def loss_fn(model, x, y):
    """Return the cross-entropy between the model's predictions on ``x`` and ``y``.

    Args:
        model: callable applied to ``x`` to produce the scores passed to
            ``F.cross_entropy``.
        x: input batch fed to ``model``.
        y: targets passed to ``F.cross_entropy``.

    Returns:
        The value returned by ``F.cross_entropy`` (default reduction).
    """
    predictions = model(x)
    loss = F.cross_entropy(input=predictions, target=y)
    return loss

# Adversarial loss (fast gradient sign method)
def adversarial_loss_fn(model, x, y, epsilon, alpha):
    """Blend the loss on ``x`` with the loss on an FGSM-perturbed copy of ``x``.

    Computes ``alpha * J(x, y) + (1 - alpha) * J(x_adv, y)`` where
    ``x_adv = x + epsilon * sign(dJ/dx)`` and ``J`` is ``loss_fn``.

    Args:
        model: callable forwarded to ``loss_fn``.
        x: input batch. If it does not already require grad, a detached copy
            that does is used internally (the caller's tensor is not modified).
        y: targets forwarded to ``loss_fn``.
        epsilon: size of the signed-gradient step added to ``x``.
        alpha: weight of the standard loss; ``1 - alpha`` weights the
            adversarial loss.

    Returns:
        The weighted sum of the two losses, still attached to the autograd
        graph so the caller can call ``.backward()`` on it.
    """
    # torch.autograd.grad raises if x is not part of the graph, so make sure
    # the input tracks gradients before the first forward pass.
    if not x.requires_grad:
        x = x.detach().requires_grad_(True)

    # Loss on the unmodified inputs.
    standard_loss = loss_fn(model, x, y)

    # Gradient of the loss w.r.t. the inputs. retain_graph=True keeps the graph
    # alive because standard_loss is back-propagated again by the caller.
    # create_graph is not needed: the gradient only goes through sign(), whose
    # derivative is zero, so a double-backward graph would contribute nothing.
    (input_grad,) = torch.autograd.grad(standard_loss, x, retain_graph=True)

    # Adversarial example: one step of size epsilon along the gradient sign.
    x_adv = x + epsilon * input_grad.sign()

    # Loss on the adversarial example.
    adversarial_loss = loss_fn(model, x_adv, y)

    # Weighted combination of both losses.
    return alpha * standard_loss + (1 - alpha) * adversarial_loss

## 4. Création des modèles utiles

On a besoin d'un modèle à 240 unités par couche et d'un autre à 1600 unités par couche.

Définition des paramètres utiles à l'entraînement.

In [None]:
# The models below are built for single-channel inputs of size 28x28.
n_channels = 1
dropout = 0.5


def _build_maxout_mlp(stage_dims, in_features=28 * 28 * n_channels, n_classes=10):
    """Build a Flatten -> [Linear -> CustomMaxout -> Dropout]* -> Linear -> LogSoftmax stack.

    Args:
        stage_dims: sequence of ``(linear_out, maxout_out)`` pairs, one per
            hidden stage. Each stage is ``nn.Linear(in, linear_out)`` followed by
            ``CustomMaxout(linear_out, maxout_out, n_channels, True)`` and
            ``nn.Dropout(dropout)``.
        in_features: size of the flattened input.
        n_classes: size of the final linear layer.

    Returns:
        An ``nn.Sequential`` ending with ``nn.LogSoftmax(dim=1)``.
    """
    layers = [nn.Flatten()]
    for linear_out, maxout_out in stage_dims:
        layers += [
            nn.Linear(in_features, linear_out),
            # CustomMaxout comes from the local `maxout` module; its arguments are
            # passed through exactly as in the hand-written stacks.
            # NOTE(review): meaning of the 3rd/4th arguments not visible here - confirm.
            CustomMaxout(linear_out, maxout_out, n_channels, True),
            nn.Dropout(dropout),
        ]
        in_features = maxout_out
    layers += [nn.Linear(in_features, n_classes), nn.LogSoftmax(dim=1)]
    return nn.Sequential(*layers)


# First model: hidden stages starting at 240 units.
Maxout_240U_Model = _build_maxout_mlp([(240, 200), (160, 120), (80, 40)])

# Second model: hidden stages starting at 1600 units.
# The hand-written version omitted the Dropout after the (1400, 1300) stage,
# unlike every other stage of both models; the builder applies it uniformly.
Maxout_1600U_Model = _build_maxout_mlp([
    (1600, 1500),
    (1400, 1300),
    (1200, 1100),
    (1000, 800),
    (600, 400),
    (200, 100),
    (50, 25),
])

Entraînement des modèles avec une cross-entropy et un optimiseur Adam

In [None]:
# Device shared by all training cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Early-stopping patience: number of consecutive epochs without improvement
# of the validation loss that is tolerated before training stops.
early_stop_epochs = 20

# Models to train, keyed by a display name.
model_dict = {"Maxout_240U_Model": Maxout_240U_Model, "Maxout_1600U_Model": Maxout_1600U_Model}

# Maximum number of epochs per model.
n_epochs = 100

# Training and validation data loaders.
# NOTE: the MNIST test split is used as the validation set here.
train_dataloader = training_dataloader
valid_dataloader = test_dataloader

for model_name, model in model_dict.items():
    model = model.to(device)

    # Each model gets its own optimizer and loss object.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_func = nn.CrossEntropyLoss()

    # Early-stopping state is reset for every model so that the best loss
    # reached by one model does not influence when the next one stops.
    best_valid_loss = float('inf')
    epochs_no_improve = 0

    # Per-epoch history, used for the plots below.
    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []
    for epoch in range(n_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0
        correct_train = 0
        total_train = 0
        for inputs, labels in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

            loss = loss_func(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
        # Mean of the per-batch losses, and accuracy in percent.
        train_losses.append(train_loss / len(train_dataloader))
        train_accuracies.append(100 * correct_train / total_train)

        # ---- validation pass (no gradient tracking) ----
        model.eval()
        valid_loss = 0
        correct_valid = 0
        total_valid = 0
        with torch.no_grad():
            for inputs, labels in valid_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                predicted = outputs.argmax(dim=1)
                total_valid += labels.size(0)
                correct_valid += (predicted == labels).sum().item()

                valid_loss += loss_func(outputs, labels).item()
        valid_losses.append(valid_loss / len(valid_dataloader))
        valid_accuracies.append(100 * correct_valid / total_valid)

        # Report this epoch before the early-stopping check so that the last
        # epoch's metrics are always printed.
        print(f'Epoch {epoch+1}/{n_epochs}.. '
              f'Train loss: {train_losses[-1]:.3f}.. '
              f'Validation loss: {valid_losses[-1]:.3f}.. '
              f'Train accuracy: {train_accuracies[-1]:.3f}.. '
              f'Validation accuracy: {valid_accuracies[-1]:.3f}')

        # Early stopping: track whether the validation loss improved.
        if valid_losses[-1] < best_valid_loss:
            best_valid_loss = valid_losses[-1]
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Stop once the patience is exhausted (>= so the check cannot be skipped).
        if epochs_no_improve >= early_stop_epochs:
            print("Early stopping!")
            break

    # Plot the loss and accuracy curves of this model.
    plt.figure(figsize=(12, 4))
    plt.suptitle(model_name)
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training loss')
    plt.plot(valid_losses, label='Validation loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(frameon=False)
    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Training accuracy')
    plt.plot(valid_accuracies, label='Validation accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend(frameon=False)
    plt.show()

Entraînement avec adversarial loss

In [None]:
# Adversarial-loss hyper-parameters: epsilon is the size of the signed-gradient
# step added to the inputs, alpha the weight of the standard loss.
# NOTE(review): inputs are normalized to [-1, 1], so epsilon is expressed in
# that normalized scale - confirm this matches the intended perturbation size.
epsilon = 0.1
alpha = 0.5

# NOTE(review): the models in model_dict have already been trained by the
# previous cell; this loop continues from those weights.
for model_name, model in model_dict.items():
    model = model.to(device)

    # A fresh optimizer bound to THIS model's parameters. Reusing the optimizer
    # left over from the previous cell would only ever update the last model
    # trained there.
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Early-stopping state is reset for every model; otherwise it would be
    # inherited from the previous cell / previous model.
    best_valid_loss = float('inf')
    epochs_no_improve = 0

    # Per-epoch history, used for the plots below.
    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []
    for epoch in range(n_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0
        correct_train_preds = 0
        total_train_preds = 0
        for inputs, labels in train_dataloader:
            # Inputs must track gradients: the adversarial example is built from
            # the gradient of the loss w.r.t. the inputs. Labels are integer class
            # indices and cannot (and need not) require gradients.
            inputs = inputs.to(device).requires_grad_()
            labels = labels.to(device)

            # Accuracy is measured on the clean inputs; no graph is needed for it.
            with torch.no_grad():
                predicted = model(inputs).argmax(dim=1)
            total_train_preds += labels.size(0)
            correct_train_preds += (predicted == labels).sum().item()

            optimizer.zero_grad()
            loss = adversarial_loss_fn(model, inputs, labels, epsilon, alpha)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Mean of the per-batch losses, and accuracy in percent.
        train_losses.append(train_loss / len(train_dataloader))
        train_accuracies.append(100 * correct_train_preds / total_train_preds)

        # ---- validation pass ----
        # torch.no_grad() cannot wrap the loss here: adversarial_loss_fn needs the
        # gradient w.r.t. the inputs to build the adversarial examples.
        model.eval()
        valid_loss = 0
        correct_valid_preds = 0
        total_valid_preds = 0
        for inputs, labels in valid_dataloader:
            inputs = inputs.to(device).requires_grad_()
            labels = labels.to(device)

            with torch.no_grad():
                predicted = model(inputs).argmax(dim=1)
            total_valid_preds += labels.size(0)
            correct_valid_preds += (predicted == labels).sum().item()

            loss = adversarial_loss_fn(model, inputs, labels, epsilon, alpha)
            valid_loss += loss.item()

        valid_losses.append(valid_loss / len(valid_dataloader))
        valid_accuracies.append(100 * correct_valid_preds / total_valid_preds)

        # Report this epoch before the early-stopping check so that the last
        # epoch's metrics are always printed.
        print(f'Epoch {epoch+1}/{n_epochs}.. '
            f'Train loss: {train_losses[-1]:.3f}.. '
            f'Train accuracy: {train_accuracies[-1]:.3f}.. '
            f'Validation loss: {valid_losses[-1]:.3f}.. '
            f'Validation accuracy: {valid_accuracies[-1]:.3f}')

        # Early stopping: track whether the validation loss improved.
        if valid_losses[-1] < best_valid_loss:
            best_valid_loss = valid_losses[-1]
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Stop once the patience is exhausted (>= so the check cannot be skipped).
        if epochs_no_improve >= early_stop_epochs:
            print("Early stopping!")
            break

    # Plot the loss and accuracy curves of this model.
    plt.figure(figsize=(12, 4))
    plt.suptitle(f'{model_name} (adversarial loss)')
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training loss')
    plt.plot(valid_losses, label='Validation loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(frameon=False)
    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Training accuracy')
    plt.plot(valid_accuracies, label='Validation accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend(frameon=False)
    plt.show()