In [1]:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# 1. Load the data with transforms adapted to VGG16
# Seed PyTorch's global RNG so that, when the cells are run in order from a
# fresh kernel, the shuffling order of the training loader and the random
# initialisation of newly created layers are repeatable.
# (Some CUDA kernels may still be non-deterministic.)
SEED = 42
torch.manual_seed(SEED)

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the 224x224 input size used with VGG16
    transforms.Grayscale(num_output_channels=3),  # Expand MNIST (1 channel) to 3 channels
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Rescale [0, 1] to [-1, 1]
])

# Both splits share the same preprocessing pipeline.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Only the training loader shuffles; the test loader keeps a fixed order.
trainloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 2. Load the pre-trained VGG16 model
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in favour of
# `weights=`; kept as-is for compatibility with the installed version — confirm
# before upgrading torchvision.
vgg16 = models.vgg16(pretrained=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 57440168.95it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1667134.97it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 13925950.07it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2546862.13it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw




Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 209MB/s]  


In [2]:
# 3. Replace the last classifier layer so the network outputs one score per
#    MNIST class (10 classes)
num_classes = 10
# Read the input width from the layer being replaced instead of hard-coding
# 4096, so the new head always matches the layer that feeds it. Re-running
# this cell is safe: the replacement layer has the same `in_features`.
in_features = vgg16.classifier[6].in_features
vgg16.classifier[6] = nn.Linear(in_features, num_classes)

# Move the model to the selected device (GPU/CPU)
vgg16 = vgg16.to(device)

# 4. Define the loss and the optimizer
criterion = nn.CrossEntropyLoss()  # Classification loss
# The optimizer is created after the head is replaced and the model is moved,
# so it tracks the parameters that will actually be trained.
optimizer = optim.Adam(vgg16.parameters(), lr=0.0001)



In [3]:
# 5. Train the model
num_epochs = 5
for epoch in range(num_epochs):
    # Switch to training mode at the start of every epoch.
    vgg16.train()
    # Sum of the per-batch losses seen during this epoch.
    running_loss = 0.0

    for images, labels in trainloader:
        # Put the batch on the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        # Forward pass, then loss against the ground-truth labels.
        outputs = vgg16(images)
        loss = criterion(outputs, labels)
        # Backward pass followed by the weight update.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean per-batch loss for the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(trainloader):.4f}")


Epoch [1/5], Loss: 0.0778
Epoch [2/5], Loss: 0.0324
Epoch [3/5], Loss: 0.0267
Epoch [4/5], Loss: 0.0219
Epoch [5/5], Loss: 0.0181


In [5]:
# 6. Evaluate the model on the test set
vgg16.eval()  # Evaluation mode
correct = 0  # Number of correctly classified test images
total = 0    # Number of test images seen
with torch.no_grad():  # Gradients are not needed for evaluation
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = vgg16(images)
        # Index of the largest score along the class dimension is the predicted
        # class. `argmax` avoids the legacy `.data` accessor and does not rebind
        # the notebook-level `_` name to a throw-away tensor.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy on test set: {100 * correct / total:.2f}%')

Accuracy on test set: 99.52%


In [6]:
from sklearn.metrics import classification_report

# Collect the predictions and the true labels over the whole test set
all_labels = []
all_predictions = []

# Do not rely on an earlier cell having switched the model to evaluation mode:
# the training cell leaves it in training mode, so set it explicitly here to
# keep this cell correct when run right after training.
vgg16.eval()
with torch.no_grad():  # Gradients are not needed for inference
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = vgg16(images)
        # Predicted class = index of the largest score along the class dimension.
        predicted = outputs.argmax(dim=1)
        # Bring tensors back to the CPU before handing them to scikit-learn.
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

# Compute the per-class precision / recall / F1 report
print("Classification Report for VGG16:")
print(classification_report(all_labels, all_predictions, digits=4))


Classification Report for VGG16:
              precision    recall  f1-score   support

           0     0.9959    0.9959    0.9959       980
           1     0.9939    0.9974    0.9956      1135
           2     0.9932    0.9981    0.9957      1032
           3     0.9990    0.9960    0.9975      1010
           4     0.9909    0.9980    0.9944       982
           5     0.9955    0.9955    0.9955       892
           6     0.9958    0.9896    0.9927       958
           7     0.9932    0.9932    0.9932      1028
           8     0.9979    0.9979    0.9979       974
           9     0.9970    0.9901    0.9935      1009

    accuracy                         0.9952     10000
   macro avg     0.9952    0.9952    0.9952     10000
weighted avg     0.9952    0.9952    0.9952     10000

