In [2]:
# === Imports ===
import os, random, logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datetime import datetime
from torch.amp import autocast, GradScaler
import timm

# === Reproducibility ===
def set_seed(seed=42):
    """Seed every random number generator used here and pin cuDNN behaviour.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), then turns off cuDNN autotuning and requests
    deterministic cuDNN kernels so repeated runs pick the same algorithms.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed(42)

# === Configuration ===
model_name = "efficientformer_l1"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mixed precision (autocast + GradScaler) is only enabled when CUDA is available.
use_amp = torch.cuda.is_available()
batch_size = 32
epochs = 10
image_size = 224  # TODO: input resolution may be worth tuning (original note: "size (1) to improve")
num_classes = 2
lr = 1e-4
log_dir = "runs/efficientformer_exp"
model_save_path = "best_model_efficientformer.pth"
info_save_path = "model_info_efficientformer.txt"
early_stopping_patience = 5
# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to 1 so min() does not raise TypeError and DataLoader still gets a worker.
num_workers = min(os.cpu_count() or 1, 4)

writer = SummaryWriter(log_dir=log_dir)
logging.basicConfig(filename="training.log", level=logging.INFO)

# === Transforms ===
# Both pipelines resize to a square `image_size` input and end with tensor
# conversion plus per-channel normalisation using the usual ImageNet mean/std
# (presumably matching the pretrained weights — TODO confirm against the timm
# model's pretrained config). Only the training pipeline adds random
# flips and colour jitter for augmentation.
_resize_to_input = transforms.Resize((image_size, image_size))
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]

_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2),
]

train_transform = transforms.Compose([_resize_to_input, *_augmentations, *_to_normalized_tensor])
test_transform = transforms.Compose([_resize_to_input, *_to_normalized_tensor])

# === Datasets ===
train_dataset = datasets.ImageFolder("DATASET/train", transform=train_transform)
val_dataset = datasets.ImageFolder("DATASET/validation", transform=test_transform)
test_dataset = datasets.ImageFolder("DATASET/test", transform=test_transform)

# ImageFolder derives label indices from the sorted sub-folder names of EACH
# split independently, so a split with a missing or extra class folder would
# silently get shifted labels. Fail fast instead, and make sure the model's
# output size matches the number of classes actually found on disk.
for _split_name, _split in (("validation", val_dataset), ("test", test_dataset)):
    if _split.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"{_split_name} classes {_split.class_to_idx} do not match "
            f"train classes {train_dataset.class_to_idx}"
        )
if len(train_dataset.classes) != num_classes:
    raise ValueError(
        f"num_classes={num_classes} but DATASET/train contains "
        f"{len(train_dataset.classes)} classes: {train_dataset.classes}"
    )

# === Class-imbalance handling ===
# Each training sample gets weight 1 / (number of samples in its class), so the
# sampler draws every class with equal total probability (balanced in expectation).
labels = [s[1] for s in train_dataset.samples]
# minlength keeps one count per class; np.maximum avoids a divide-by-zero for a
# class folder that happens to be empty (no sample uses that weight anyway).
class_counts = np.bincount(labels, minlength=len(train_dataset.classes))
weights = 1.0 / np.maximum(class_counts, 1)
sample_weights = [weights[lbl] for lbl in labels]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

# === DataLoaders ===
# pin_memory only benefits host-to-GPU copies (and warns on CPU-only machines);
# persistent_workers raises ValueError unless num_workers > 0.
loader_kwargs = dict(
    num_workers=num_workers,
    pin_memory=device.type == "cuda",
    persistent_workers=num_workers > 0,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, **loader_kwargs)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **loader_kwargs)

# === EfficientFormer model (hybrid CNN + Transformer) ===
# Pretrained timm model with a classifier sized for `num_classes`, moved to the
# selected device in channels-last memory format.
model = timm.create_model(model_name, pretrained=True, num_classes=num_classes).to(
    device, memory_format=torch.channels_last
)

# === Optimisation ===
# Label-smoothed cross-entropy, AdamW with weight decay, and a cosine learning-
# rate decay spanning `epochs` scheduler steps.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), weight_decay=1e-4, lr=lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
# Gradient scaling is disabled (a pass-through) when AMP is off.
scaler = GradScaler(enabled=use_amp)

# === Training and evaluation functions ===
def train_one_epoch(model, loader, optimizer, criterion, scaler):
    """Train *model* for one pass over *loader*.

    Uses the module-level ``device`` and ``use_amp`` settings. Inputs are moved
    to ``device`` in channels-last layout. The forward pass runs under autocast
    when ``use_amp`` is true, and *scaler* scales the loss for the backward pass.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the per-sample average
        (each batch's mean loss weighted by its size). ``accuracy`` is the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in tqdm(loader, desc="Train", leave=False):
        x = x.to(device, memory_format=torch.channels_last, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optimizer.zero_grad()
        # Use the real device type so CPU runs don't request a CUDA autocast context.
        with autocast(device_type=device.type, enabled=use_amp):
            output = model(x)
            loss = criterion(output, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        n = y.size(0)
        # loss is a batch mean; weight it by batch size so a short final batch
        # does not skew the epoch average.
        total_loss += loss.item() * n
        correct += (output.argmax(1) == y).sum().item()
        total += n

    if total == 0:
        raise ValueError("train_one_epoch(): loader yielded no samples")
    return total_loss / total, correct / total

def evaluate(model, loader, criterion):
    """Evaluate *model* on *loader* in eval mode, without gradient tracking.

    Inputs are moved to the module-level ``device`` in channels-last layout.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the per-sample average
        (each batch's mean loss weighted by its size). ``accuracy`` is the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in tqdm(loader, desc="Eval", leave=False):
            x = x.to(device, memory_format=torch.channels_last, non_blocking=True)
            y = y.to(device, non_blocking=True)
            output = model(x)
            n = y.size(0)
            # Weight the batch-mean loss by batch size for a true per-sample mean.
            total_loss += criterion(output, y).item() * n
            correct += (output.argmax(1) == y).sum().item()
            total += n

    if total == 0:
        raise ValueError("evaluate(): loader yielded no samples")
    return total_loss / total, correct / total

# === Training loop ===
# Start below any achievable accuracy so the first epoch always writes a
# checkpoint: the final test reloads model_save_path unconditionally and would
# raise FileNotFoundError if validation accuracy never rose above 0.
best_val_acc, patience = float("-inf"), 0
for epoch in range(epochs):
    print(f"\n📅 Epoch {epoch+1}/{epochs}")
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, scaler)
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    scheduler.step()  # the LR schedule advances once per epoch

    print(f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
    # logging.basicConfig points at training.log; record each epoch there too.
    logging.info(
        "epoch=%d train_loss=%.4f train_acc=%.4f val_loss=%.4f val_acc=%.4f",
        epoch + 1, train_loss, train_acc, val_loss, val_acc,
    )
    writer.add_scalar("Loss/train", train_loss, epoch)
    writer.add_scalar("Loss/val", val_loss, epoch)
    writer.add_scalar("Acc/train", train_acc, epoch)
    writer.add_scalar("Acc/val", val_acc, epoch)

    if val_acc > best_val_acc:
        # New best validation accuracy: reset the early-stopping counter and save weights.
        best_val_acc = val_acc
        patience = 0
        torch.save(model.state_dict(), model_save_path)
        print("✅ Meilleur modèle sauvegardé.")
    else:
        patience += 1
        if patience >= early_stopping_patience:
            print("⏹️ Early stopping.")
            break

# === Final test ===
# Reload the best checkpoint saved during training. weights_only=True limits
# unpickling to tensors and plain containers. map_location lets a checkpoint
# written on another device load on the current one.
state_dict = torch.load(model_save_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for x, y in tqdm(test_loader, desc="Test"):
        x = x.to(device, memory_format=torch.channels_last, non_blocking=True)
        preds = model(x).argmax(1).cpu().numpy()
        y_true.extend(y.numpy())
        y_pred.extend(preds)

# Pass the full label set explicitly. Otherwise classification_report raises
# ValueError when a class never appears in y_true or y_pred, because
# target_names would then be longer than the labels it discovered.
# The confusion matrix also keeps one row and column per class.
class_ids = list(range(len(test_dataset.classes)))
print("\n=== Rapport de classification ===")
print(classification_report(y_true, y_pred, labels=class_ids, target_names=test_dataset.classes))
print(confusion_matrix(y_true, y_pred, labels=class_ids))

# === Save run information ===
# Write a short summary of the run plus a copy-pasteable snippet that rebuilds
# the model and loads the saved weights. The file contains non-ASCII text
# ("modèle"), so the encoding is pinned rather than left to the platform default.
with open(info_save_path, "w", encoding="utf-8") as f:
    f.write("=== Model Info ===\n")
    f.write(f"Model: {model_name}\n")
    f.write(f"Date: {datetime.now()}\n")
    f.write(f"Device: {device}\n")
    f.write(f"Classes: {num_classes}\n")
    f.write(f"Epochs: {epoch+1}\n")
    f.write(f"Best Val Acc: {best_val_acc:.4f}\n")
    f.write("\n=== Recharger le modèle ===\n")
    # The snippet calls torch.load, so it must import torch as well as timm.
    # map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only machine.
    f.write(f"import torch\nimport timm\nmodel = timm.create_model('{model_name}', pretrained=False, num_classes={num_classes})\n")
    f.write(f"model.load_state_dict(torch.load('{model_save_path}', map_location='cpu'))\nmodel.eval()\n")

print(f"\n📄 Infos sauvegardées dans {info_save_path}")
writer.close()


📅 Epoch 1/10


                                              

KeyboardInterrupt: 