# Classification de Malaria par Deep Learning

Entraînement et évaluation de 3 modèles CNN pour détecter le paludisme.

## 1. Imports et Configuration

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
from sklearn.metrics import confusion_matrix, classification_report

# Plot configuration: default figure size and font size for every matplotlib
# figure created in this notebook (individual cells may override figsize).
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 12

# Imports du projet
from src.config import APPAREIL, CLASSES, TAILLE_IMAGE, TAILLE_BATCH, NOMBRE_EPOCHS, TAUX_APPRENTISSAGE, PATIENCE, CHEMIN_RESULTATS
from src.data_manager import DataManager
from src.models.simple_cnn import SimpleCNN
from src.models.vgg16_model import VGG16Model
from src.models.resnet50_model import ResNet50Model

# Echo the settings imported from src.config so each run records the
# configuration it was executed with.
print(f"Appareil: {APPAREIL}")
print(f"Classes: {CLASSES}")
print(f"Taille image: {TAILLE_IMAGE}x{TAILLE_IMAGE}")
print(f"Batch size: {TAILLE_BATCH}")
print(f"Epochs max: {NOMBRE_EPOCHS}")

## 2. Chargement des Données

In [None]:
# Build the project's DataManager and load the dataset; the loading logic
# itself lives in src.data_manager.
gestionnaire = DataManager()
gestionnaire.load_data()

# One loader per split (train / validation / test), with augmentation requested.
# NOTE(review): presumably augmentation only affects the training split —
# confirm in DataManager.get_dataloaders.
chargeur_train, chargeur_val, chargeur_test = gestionnaire.get_dataloaders(avec_augmentation=True)

# Report how many images each split's dataset holds.
print(f"\nDataLoaders créés:")
print(f"  Train: {len(chargeur_train.dataset)} images")
print(f"  Val: {len(chargeur_val.dataset)} images")
print(f"  Test: {len(chargeur_test.dataset)} images")

## 3. Visualisation des Données

In [None]:
# Class distribution of the training labels, shown as a bar chart and a pie
# chart. Label 0 is plotted against CLASSES[0] and label 1 against CLASSES[1].
etiquettes = gestionnaire.etiquettes_entrainement
compte_parasitized = etiquettes.count(0)
compte_uninfected = etiquettes.count(1)
comptes = [compte_parasitized, compte_uninfected]

# This is the first figure saved by the notebook: make sure the output
# directory exists so savefig does not fail on a fresh checkout.
os.makedirs(CHEMIN_RESULTATS, exist_ok=True)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

axes[0].bar(CLASSES, comptes, color=['#e74c3c', '#2ecc71'])
axes[0].set_title('Distribution des Classes', fontweight='bold')
axes[0].set_ylabel('Nombre images')
# Put each count just above its bar. The gap is proportional to the tallest
# bar (2 %) so the labels stay readable whatever the dataset size, instead of
# a fixed offset of 200 images.
decalage = 0.02 * max(comptes)
for i, v in enumerate(comptes):
    axes[0].text(i, v + decalage, str(v), ha='center', fontweight='bold')

axes[1].pie(comptes, labels=CLASSES, autopct='%1.1f%%',
            colors=['#e74c3c', '#2ecc71'], explode=(0.05, 0))
axes[1].set_title('Repartition des Classes', fontweight='bold')

plt.tight_layout()
plt.savefig(os.path.join(CHEMIN_RESULTATS, 'distribution_classes.png'), dpi=150)
plt.show()

In [None]:
# Sample images: one row per class (labels 0 and 1), up to five images each,
# taken from the start of the training file list.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
fig.suptitle('Exemples Images', fontsize=14, fontweight='bold')

for i in range(2):
    # Positions in the training set whose label is class i.
    indices = [j for j, e in enumerate(gestionnaire.etiquettes_entrainement) if e == i]
    for j in range(5):
        # A class with fewer than five images leaves the remaining panels
        # blank instead of raising IndexError.
        if j < len(indices):
            chemin = gestionnaire.chemins_entrainement[indices[j]]
            # The context manager releases the file handle as soon as the
            # image has been handed to matplotlib (avoids leaking open files).
            with Image.open(chemin) as img:
                axes[i, j].imshow(img)
        axes[i, j].axis('off')
        if j == 0:
            # Only the first panel of each row carries the class name.
            axes[i, j].set_title(CLASSES[i], fontweight='bold')

plt.tight_layout()
plt.savefig(os.path.join(CHEMIN_RESULTATS, 'exemples_images.png'), dpi=150)
plt.show()

## 4. Fonction d'entraînement

In [None]:
def _passe_epoch(modele, lots, critere, optimiseur=None):
    """Run one full pass over `lots` and return (mean loss, accuracy in %).

    When `optimiseur` is given, the model is switched to training mode and
    updated after every batch. Otherwise the model is switched to eval mode
    and gradient tracking is disabled.

    Args:
        modele: network to run; must already be on APPAREIL.
        lots: iterable yielding (images, labels) batches.
        critere: loss function applied to (outputs, labels).
        optimiseur: optimizer used for the update step, or None to evaluate.

    Raises:
        ValueError: if `lots` yields no sample, because the averages would
            be undefined.
    """
    en_entrainement = optimiseur is not None
    modele.train(en_entrainement)
    perte_sum, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(en_entrainement):
        for images, labels in lots:
            images, labels = images.to(APPAREIL), labels.to(APPAREIL)

            if en_entrainement:
                optimiseur.zero_grad()
            outputs = modele(images)
            perte = critere(outputs, labels)
            if en_entrainement:
                perte.backward()
                optimiseur.step()

            # The loss is a per-batch mean; weight it by the batch size so the
            # epoch average stays exact when the last batch is smaller.
            perte_sum += perte.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("Le chargeur de donnees ne contient aucun echantillon")
    return perte_sum / total, 100 * correct / total


def entrainer(modele, nom, chargeur_train, chargeur_val):
    """Train `modele` with Adam and cross-entropy, using early stopping.

    After each epoch the model is validated. Whenever validation accuracy
    improves, the weights are saved to ``<CHEMIN_RESULTATS>/<nom>_best.pth``.
    Training stops after NOMBRE_EPOCHS epochs, or earlier once PATIENCE
    consecutive epochs bring no improvement. The best weights are reloaded
    into `modele` before returning.

    Args:
        modele: torch module to train; it is moved to APPAREIL in place.
        nom: model name, used in log lines and in the checkpoint file name.
        chargeur_train: loader yielding (images, labels) training batches.
        chargeur_val: loader yielding (images, labels) validation batches.

    Returns:
        dict with keys 'perte_train', 'acc_train', 'perte_val', 'acc_val',
        each a list holding one value per completed epoch (accuracies in %).

    Raises:
        ValueError: if one of the loaders yields no sample.
    """
    modele.to(APPAREIL)
    critere = nn.CrossEntropyLoss()
    optimiseur = optim.Adam(modele.parameters(), lr=TAUX_APPRENTISSAGE)

    # Make sure the checkpoint can be written even on a fresh checkout.
    os.makedirs(CHEMIN_RESULTATS, exist_ok=True)
    chemin_meilleur = os.path.join(CHEMIN_RESULTATS, f'{nom}_best.pth')

    historique = {'perte_train': [], 'acc_train': [], 'perte_val': [], 'acc_val': []}
    # None means "no checkpoint written during this run yet". Starting from 0
    # would skip the save when accuracy never rises above 0 %, and the final
    # reload would then crash or pick up a stale file from an earlier run.
    meilleure_acc = None
    patience_compteur = 0

    print(f"\n{'='*50}")
    print(f"Entrainement: {nom}")
    print(f"{'='*50}")

    for epoch in range(NOMBRE_EPOCHS):
        barre = tqdm(chargeur_train, desc=f'Epoch {epoch+1}/{NOMBRE_EPOCHS}')
        perte_train, acc_train = _passe_epoch(modele, barre, critere, optimiseur)
        perte_val, acc_val = _passe_epoch(modele, chargeur_val, critere)

        historique['perte_train'].append(perte_train)
        historique['acc_train'].append(acc_train)
        historique['perte_val'].append(perte_val)
        historique['acc_val'].append(acc_val)

        print(f"Epoch {epoch+1:2d} | Train: {acc_train:.1f}% | Val: {acc_val:.1f}%")

        if meilleure_acc is None or acc_val > meilleure_acc:
            meilleure_acc = acc_val
            patience_compteur = 0
            torch.save(modele.state_dict(), chemin_meilleur)
            print("  -> Meilleur modele sauvegarde!")
        else:
            patience_compteur += 1
            if patience_compteur >= PATIENCE:
                print(f"Early stopping apres {epoch+1} epochs")
                break

    if meilleure_acc is None:
        # NOMBRE_EPOCHS <= 0: nothing was trained or saved, so there is no
        # checkpoint from this run to restore.
        print("Aucune epoch executee: modele inchange")
        return historique

    # Restore the best weights seen during this run (the last epoch is not
    # necessarily the best one).
    modele.load_state_dict(torch.load(chemin_meilleur, map_location=APPAREIL))
    print(f"Entrainement termine! Meilleure precision: {meilleure_acc:.2f}%")

    return historique

In [None]:
def afficher_courbes(historique, nom):
    """Plot the loss and accuracy curves of one training run, side by side.

    The left panel shows training and validation loss per epoch, the right
    panel shows training and validation accuracy. The figure is saved as
    ``<CHEMIN_RESULTATS>/<nom>_courbes.png`` and then displayed.

    Args:
        historique: dict with the lists 'perte_train', 'perte_val',
            'acc_train' and 'acc_val' (one value per epoch).
        nom: model name used in the panel titles and in the file name.
    """
    figure, (ax_perte, ax_acc) = plt.subplots(1, 2, figsize=(14, 5))
    abscisses = range(1, len(historique['perte_train']) + 1)

    # (axis, train key, validation key, title, y label) for each panel.
    panneaux = (
        (ax_perte, 'perte_train', 'perte_val', f'Perte - {nom}', 'Perte'),
        (ax_acc, 'acc_train', 'acc_val', f'Precision - {nom}', 'Precision (%)'),
    )
    for panneau, cle_train, cle_val, titre, etiquette_y in panneaux:
        panneau.plot(abscisses, historique[cle_train], 'b-o', label='Train')
        panneau.plot(abscisses, historique[cle_val], 'r-o', label='Val')
        panneau.set_title(titre, fontweight='bold')
        panneau.set_xlabel('Epochs')
        panneau.set_ylabel(etiquette_y)
        panneau.legend()
        panneau.grid(True, alpha=0.3)

    plt.tight_layout()
    chemin_sortie = os.path.join(CHEMIN_RESULTATS, f'{nom}_courbes.png')
    plt.savefig(chemin_sortie, dpi=150)
    plt.show()

## 5. Entraînement CNN Simple

In [None]:
# Instantiate the project's SimpleCNN and report its size: the sum of the
# element counts of all its parameter tensors.
modele_cnn = SimpleCNN()
print(f"Modele cree: {sum(p.numel() for p in modele_cnn.parameters())} parametres")

# Train on the training loader, validating on the validation loader; the
# returned per-epoch history is kept for the plots below.
historique_cnn = entrainer(modele_cnn, 'CNN_Simple', chargeur_train, chargeur_val)

In [None]:
# Plot and save the loss/accuracy curves recorded while training SimpleCNN.
afficher_courbes(historique_cnn, 'CNN_Simple')

## 6. Entraînement VGG16

In [None]:
# Instantiate the project's VGG16Model and report its size: the sum of the
# element counts of all its parameter tensors.
modele_vgg = VGG16Model()
print(f"Modele cree: {sum(p.numel() for p in modele_vgg.parameters())} parametres")

# Same training procedure and data loaders as the other models, so the
# histories are directly comparable.
historique_vgg = entrainer(modele_vgg, 'VGG16', chargeur_train, chargeur_val)

In [None]:
# Plot and save the loss/accuracy curves recorded while training VGG16.
afficher_courbes(historique_vgg, 'VGG16')

## 7. Entraînement ResNet50

In [None]:
# Instantiate the project's ResNet50Model and report its size: the sum of the
# element counts of all its parameter tensors.
modele_resnet = ResNet50Model()
print(f"Modele cree: {sum(p.numel() for p in modele_resnet.parameters())} parametres")

# Same training procedure and data loaders as the other models, so the
# histories are directly comparable.
historique_resnet = entrainer(modele_resnet, 'ResNet50', chargeur_train, chargeur_val)

In [None]:
# Plot and save the loss/accuracy curves recorded while training ResNet50.
afficher_courbes(historique_resnet, 'ResNet50')

## 8. Comparaison des Modèles

In [None]:
# Training histories of the three models, keyed by display name.
tous_historiques = {
    'CNN_Simple': historique_cnn,
    'VGG16': historique_vgg,
    'ResNet50': historique_resnet
}

# One fixed colour per model, reused by later comparison plots.
couleurs = {'CNN_Simple': '#3498db', 'VGG16': '#e74c3c', 'ResNet50': '#2ecc71'}

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_precision, ax_perte = axes

# Left panel: validation accuracy per epoch; right panel: validation loss.
# Runs may have different lengths because of early stopping, so each model
# gets its own epoch range.
for nom, hist in tous_historiques.items():
    epochs = range(1, len(hist['acc_val']) + 1)
    trait = dict(label=nom, color=couleurs[nom], linewidth=2)
    ax_precision.plot(epochs, hist['acc_val'], '-o', **trait)
    ax_perte.plot(epochs, hist['perte_val'], '-o', **trait)

# Shared cosmetics for both panels.
habillage = (
    (ax_precision, 'Comparaison Precision Validation', 'Precision (%)'),
    (ax_perte, 'Comparaison Perte Validation', 'Perte'),
)
for panneau, titre, etiquette_y in habillage:
    panneau.set_title(titre, fontweight='bold')
    panneau.set_xlabel('Epochs')
    panneau.set_ylabel(etiquette_y)
    panneau.legend()
    panneau.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(CHEMIN_RESULTATS, 'comparaison.png'), dpi=150)
plt.show()

## 9. Évaluation sur le jeu de test

In [None]:
def evaluer(modele, chargeur):
    """Run `modele` over `chargeur` and collect predictions and true labels.

    The model is moved to APPAREIL and switched to eval mode; gradients are
    disabled during inference.

    Args:
        modele: torch module producing one score per class for each sample.
        chargeur: iterable yielding (images, labels) batches.

    Returns:
        (preds, labels): two 1-D numpy arrays of equal length, holding the
        predicted class index and the true label of every sample, in loader
        order. Both are empty when the loader yields no batch.
    """
    # Make sure the model and the inputs live on the same device, even if the
    # caller evaluates a model that was never trained in this session.
    modele.to(APPAREIL)
    modele.eval()
    lots_preds, lots_labels = [], []

    with torch.no_grad():
        for images, labels in chargeur:
            outputs = modele(images.to(APPAREIL))
            _, preds = torch.max(outputs, 1)
            # Collect whole batches; .cpu() on the labels keeps this working
            # when the loader yields device-resident tensors.
            lots_preds.append(preds.cpu().numpy())
            lots_labels.append(labels.cpu().numpy())

    if not lots_preds:
        # np.concatenate rejects an empty list; return empty arrays instead.
        return np.array([]), np.array([])
    return np.concatenate(lots_preds), np.concatenate(lots_labels)

# Trained models to score on the test loader, keyed by display name.
tous_modeles = {
    'CNN_Simple': modele_cnn,
    'VGG16': modele_vgg,
    'ResNet50': modele_resnet
}

# Per-model test results: accuracy in percent, confusion matrix, and the raw
# prediction/label arrays, all reused by the cells below.
resultats = {}
for nom, modele in tous_modeles.items():
    preds, labels = evaluer(modele, chargeur_test)
    # Share of predictions equal to the true label, as a percentage.
    acc = np.mean(preds == labels) * 100
    # True labels are passed first, predictions second.
    cm = confusion_matrix(labels, preds)
    resultats[nom] = {'acc': acc, 'cm': cm, 'preds': preds, 'labels': labels}
    print(f"{nom}: {acc:.2f}%")

In [None]:
# Confusion matrices: one annotated heatmap per evaluated model.
# The number of panels follows `resultats` instead of being hard-coded to
# three, so adding or removing a model does not break this cell.
# squeeze=False keeps `axes` an array even when there is a single panel.
nb_panneaux = max(len(resultats), 1)
fig, axes = plt.subplots(1, nb_panneaux, figsize=(16, 5), squeeze=False)
axes = axes.ravel()

for idx, (nom, data) in enumerate(resultats.items()):
    sns.heatmap(data['cm'], annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASSES, yticklabels=CLASSES, ax=axes[idx],
                annot_kws={'size': 14})
    # Title shows the model name and its test accuracy.
    axes[idx].set_title(f"{nom}\n{data['acc']:.2f}%", fontweight='bold')
    axes[idx].set_ylabel('Vraie classe')
    axes[idx].set_xlabel('Prediction')

plt.tight_layout()
plt.savefig(os.path.join(CHEMIN_RESULTATS, 'matrices_confusion.png'), dpi=150)
plt.show()

In [None]:
# Classification reports: print scikit-learn's per-class text report for each
# model, under a banner carrying the model name. Rows are labelled with CLASSES.
for nom, data in resultats.items():
    print(f"\n{'='*50}")
    print(f"Rapport: {nom}")
    print(f"{'='*50}")
    print(classification_report(data['labels'], data['preds'], target_names=CLASSES))

## 10. Résumé Final

In [None]:
# Final chart: test accuracy of every model as one bar each.
fig, ax = plt.subplots(figsize=(10, 6))

noms = list(resultats)
accs = [info['acc'] for info in resultats.values()]
colors = [couleurs[n] for n in noms]

bars = ax.bar(noms, accs, color=colors, edgecolor='black', linewidth=2)

# Write each accuracy centred just above the top of its bar.
for bar, acc in zip(bars, accs):
    centre_x = bar.get_x() + bar.get_width() / 2
    hauteur_texte = bar.get_height() + 0.5
    ax.text(centre_x, hauteur_texte, f'{acc:.2f}%',
            ha='center', fontsize=12, fontweight='bold')

# Leave headroom above 100 % for the labels; the dashed line marks 90 %.
ax.set_ylim(0, 105)
ax.set_ylabel('Precision (%)')
ax.set_title('Comparaison Modeles - Test', fontweight='bold', fontsize=14)
ax.axhline(y=90, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.savefig(os.path.join(CHEMIN_RESULTATS, 'resultat_final.png'), dpi=150)
plt.show()

# Summary: name and test accuracy of the best-scoring model.
meilleur = max(resultats, key=lambda cle: resultats[cle]['acc'])
print(f"\n{'='*50}")
print(f"MEILLEUR MODELE: {meilleur}")
print(f"Precision: {resultats[meilleur]['acc']:.2f}%")
print(f"{'='*50}")
print(f"\nResultats sauvegardes dans: {CHEMIN_RESULTATS}")