<img src="./IMTA.png" alt="Logo IMT Atlantique" width="300"/>

##  **Introduction à PyTorch/MONAI - Structuration d’un projet de Deep Learning**
## TAF Health - UE B - 2025/2026 

Pierre-Henri.Conze@imt-atlantique.fr - Vincent.Jaouen@imt-atlantique.fr


# 03 — Synthèse image à image

In [None]:
import sys, os
sys.path.append(os.path.abspath(".."))
from utils.model_utils import model_factory
import torch
import yaml

In [None]:
from utils.data_utils import get_i2i_dataloaders
from utils.training import train_i2i
from utils.vis_utils import show_i2i_triplet

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the experiment configuration (sections used below: data, experiment,
# model; training/save are read in later cells). Read explicitly as UTF-8 so
# the result does not depend on the platform's default encoding.
with open("../configs/synthesis.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Seed PyTorch's RNGs (CPU and CUDA) before building the model so that weight
# initialisation is reproducible across "Restart & Run All". The same seed is
# also handed to the data pipeline (presumably for the train/val split).
seed = cfg["experiment"]["seed"]
torch.manual_seed(seed)

# --- Data ---
train_loader, val_loader = get_i2i_dataloaders(
    data_dir=cfg["data"]["data_dir"],
    batch_size=cfg["data"]["batch_size"],
    num_workers=cfg["data"]["num_workers"],
    cache_rate=cfg["data"]["cache_rate"],
    target_size=tuple(cfg["data"]["target_size"]),
    val_fraction=cfg["data"]["val_fraction"],
    seed=seed,
)

# --- Model ---
model = model_factory(cfg["model"]).to(device)




In [None]:
# --- Loss & Optim ---
# Supported values of cfg["training"]["loss"] (case-insensitive). An unknown
# name raises instead of silently falling back to L1, which would otherwise
# train the model with an objective nobody asked for.
LOSS_CLASSES = {
    "mse": torch.nn.MSELoss,
    "l1": torch.nn.L1Loss,
    "mae": torch.nn.L1Loss,
}
loss_name = cfg["training"]["loss"].lower()
if loss_name not in LOSS_CLASSES:
    raise ValueError(
        f"Unsupported loss {loss_name!r} in config; expected one of {sorted(LOSS_CLASSES)}"
    )
loss_fn = LOSS_CLASSES[loss_name]()

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=cfg["training"]["lr"],
    weight_decay=cfg["training"]["weight_decay"],
)

# --- Train ---
# train_i2i returns the per-epoch training losses, the validation metric
# history, the best validation score and the corresponding weights.
# show_i2i_triplet is passed as a visualisation callback.
train_losses, val_metrics, best_val, best_w = train_i2i(
    model, train_loader, val_loader,
    loss_fn, optimizer,
    device=device,
    max_epochs=cfg["training"]["max_epochs"],
    val_metric=cfg["training"]["val_metric"],
    overlay_fn=show_i2i_triplet,
)

In [None]:
# --- Save the best model ---
out_dir = cfg["save"]["out_dir"]
best_model_path = cfg["save"]["best_model_path"]

os.makedirs(out_dir, exist_ok=True)
# best_model_path is configured independently of out_dir, so its own parent
# directory must exist too; dirname() is "" for a bare filename (cwd).
model_dir = os.path.dirname(best_model_path)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)

# best_w is whatever train_i2i returned as the best weights (presumably a
# state_dict — TODO confirm); it is serialised as-is.
torch.save(best_w, best_model_path)
print(f"[INFO] Best model saved at: {best_model_path} (val {cfg['training']['val_metric']}={best_val:.4f})")