<img src="./IMTA.png" alt="Logo IMT Atlantique" width="300"/>

## **Introduction à PyTorch/MONAI - Structuration d’un projet de Deep Learning**
## TAF Health - UE B - 2025/2026 

Pierre-Henri.Conze@imt-atlantique.fr - Vincent.Jaouen@imt-atlantique.fr


# 02 — Classification

In [None]:
import sys, os
sys.path.append(os.path.abspath(".."))
from utils.data_utils import get_classif_dataloaders
from utils.training import train_classification
from utils.model_utils import model_factory

import torch
import yaml

In [None]:
# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the experiment configuration. An explicit encoding avoids depending on
# the platform default (e.g. cp1252 on Windows) when the YAML contains
# non-ASCII characters.
with open("../configs/classification.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# --- Data ---
train_loader, val_loader = get_classif_dataloaders(
    data_dir=cfg["data"]["data_dir"],
    batch_size=cfg["data"]["batch_size"],
    num_workers=cfg["data"]["num_workers"],
    cache_rate=cfg["data"]["cache_rate"],
    # YAML yields a list; the loader receives it as a tuple.
    target_size=tuple(cfg["data"]["target_size"]),
)

# --- Model ---
# Built from the "model" section of the config and moved to the chosen device.
model = model_factory(cfg["model"]).to(device)

In [None]:
# CrossEntropyLoss applies log-softmax internally, so the model should output
# raw logits (assumed — confirm the head built by model_factory has no softmax).
loss_fn = torch.nn.CrossEntropyLoss()

# Learning rate comes from the config like the other training hyperparameters,
# defaulting to the previous hard-coded value. float() is needed because
# PyYAML parses a literal such as `3e-4` (no decimal point) as a string.
learning_rate = float(cfg["training"].get("lr", 3e-4))
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model; train_classification returns per-epoch training losses,
# per-epoch validation accuracies, and the best accuracy with its weights.
# The best model is also written to the path configured under "save".
train_losses, val_accs, best_acc, best_weights = train_classification(
    model, train_loader, val_loader,
    loss_fn, optimizer,
    device=device, max_epochs=cfg["training"]["max_epochs"],
    save_path=cfg["save"]["best_model_path"],
)


In [None]:
from utils.evaluation import evaluate_classification

# Put the retained weights returned by train_classification back into the
# model (presumably the best-validation-accuracy checkpoint).
model.load_state_dict(best_weights)

# Integer class index -> MRI sequence name, used to label the displayed results.
# NOTE(review): assumes this order matches the label encoding produced by
# get_classif_dataloaders — confirm.
class_names = dict(enumerate(("T1ce", "T2", "FLAIR")))

# Evaluate on the validation loader; max_examples presumably caps how many
# samples appear in the mosaic figure.
final_acc = evaluate_classification(
    model,
    val_loader,
    device=device,
    max_examples=50,
    class_names=class_names,
)

