# Report for the SLB EchoCEM Data Challenge

**Students:** CHOQUET Laura and GRAVIER Thomas

### Imports

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm
import gc


### I - Data loading and preprocessing

In [None]:
# Masks table: one row per image file, indexed by the file name without ".npy"
# (first CSV column); the remaining columns hold the flattened mask.
y_train = pd.read_csv(filepath_or_buffer="y_train.csv", index_col=0)

In [None]:
# Preview the first rows of the masks table.
y_train.head(n=5)

In [None]:
# Working directory of the notebook (not referenced by the visible code).
CURRENT_DIR: str = os.getcwd()
# NOTE(review): the datasets below are built from the relative directory
# "images train", not from this path — confirm which location is intended.
DIR_train: str = "/datasets/train"


In [None]:

class CustomImageDataset(Dataset):
    """Image patches stored as ``.npy`` files together with their segmentation masks.

    The mask of ``<name>.npy`` is the row ``<name>`` of ``y_train``: one flattened mask
    per row, possibly padded with -1 after the last pixel. Files are split into
    training / validation sets by file-name prefix (well id).

    Each item is a dict with:
        "image":   float32 tensor (1, 160, 160), min-max normalised to [0, 1]
                   before augmentation;
        "resized": True if the stored image was not 160x160 and had to be resized;
        "label":   float32 tensor (1, 160, 160) holding the class index of each pixel.
    """

    # Default well split (file-name prefixes).
    TRAIN_WELLS = ("well_1", "well_4", "well_6", "well_2")
    VAL_WELLS = ("well_5",)

    def __init__(self, root_dir, y_train, train=True, train_wells=None, val_wells=None):
        """
        Args:
            root_dir (str): Directory containing the .npy files.
            y_train (pd.DataFrame): Masks, one flattened mask per row, indexed by file id.
            train (bool): If True, use the training wells (with data augmentation);
                otherwise use the validation wells.
            train_wells (iterable of str, optional): File-name prefixes of the training
                wells. Defaults to ``TRAIN_WELLS``.
            val_wells (iterable of str, optional): File-name prefixes of the validation
                wells. Defaults to ``VAL_WELLS``.
        """
        self.root_dir = root_dir
        self.y_train = y_train
        self.train = train

        if train_wells is None:
            train_wells = self.TRAIN_WELLS
        if val_wells is None:
            val_wells = self.VAL_WELLS
        # NOTE(review): prefix matching means "well_1" would also match files of a
        # "well_10..." well if one exists — confirm the file naming scheme.
        self.train_files = self._list_npy_files(root_dir, train_wells)
        self.val_files = self._list_npy_files(root_dir, val_wells)
        self.files = self.train_files if train else self.val_files

        # Bilinear for images; nearest for masks so class indices stay integral.
        self.resize_transform = transforms.Resize((160, 160), interpolation=transforms.InterpolationMode.BILINEAR)
        self.resize_label_transform = transforms.Resize((160, 160), interpolation=transforms.InterpolationMode.NEAREST)

        self.augmentations = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=10, p=0.4),

            A.GridDistortion(p=0.3),

            A.GaussianBlur(blur_limit=(3, 5), p=0.2),
            # NOTE(review): images are float32 in [0, 1] at this point; depending on the
            # albumentations version, var_limit may be applied on that scale (i.e. very
            # strong noise) — confirm the intended noise magnitude.
            A.GaussNoise(var_limit=(5.0, 30.0), p=0.2),
            A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.3),
            A.CLAHE(clip_limit=2.0, tile_grid_size=(8,8), p=0.2),

            A.RandomShadow(p=0.1),  # random shadows
            A.Sharpen(alpha=(0.2, 0.5), p=0.2),  # enhances details
            A.RandomGamma(gamma_limit=(80, 120), p=0.2),  # gamma adjustment
        ])

    @staticmethod
    def _list_npy_files(root_dir, prefixes):
        """Return the sorted ``.npy`` file names of ``root_dir`` starting with one of ``prefixes``.

        Sorting makes the order deterministic: ``os.listdir`` returns entries in
        arbitrary order, which would otherwise change the (unshuffled) validation order.
        """
        if isinstance(prefixes, str):
            prefixes = (prefixes,)
        prefixes = tuple(prefixes)  # str.startswith needs a str or a tuple
        return sorted(
            f for f in os.listdir(root_dir)
            if f.startswith(prefixes) and f.endswith(".npy")
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load, normalise, resize and (training split only) augment sample ``idx``."""
        file_name = self.files[idx]
        file_path = os.path.join(self.root_dir, file_name)

        image = np.load(file_path).astype(np.float32)
        original_size = image.shape
        resized = False

        # Per-image min-max normalisation; a constant image becomes all zeros.
        image_min, image_max = image.min(), image.max()
        if image_max > image_min:
            image = (image - image_min) / (image_max - image_min)
        else:
            image = np.zeros_like(image)

        image = torch.tensor(image, dtype=torch.float32).unsqueeze(0)

        if original_size != (160, 160):
            image = self.resize_transform(image)
            resized = True

        file_id = os.path.splitext(file_name)[0]
        label_flatten = self.y_train.loc[file_id].values.astype(np.int64)

        # Rows are padded with -1 after the last pixel: keep the leading valid part.
        if np.any(label_flatten == -1):
            label_flatten = label_flatten[:np.argmax(label_flatten == -1)]

        # Only two mask sizes are expected: 160x272 or 160x160.
        if label_flatten.size == 160 * 272:
            label_shape = (160, 272)
        else:
            label_shape = (160, 160)

        label = label_flatten.reshape(label_shape)
        label = torch.tensor(label, dtype=torch.float32).unsqueeze(0)

        if label_shape != (160, 160):
            label = self.resize_label_transform(label)

        if self.train:
            # albumentations works on HWC numpy arrays; convert back to CHW tensors.
            augmented = self.augmentations(image=image.numpy().transpose(1, 2, 0), mask=label.numpy().transpose(1, 2, 0))
            image = torch.tensor(augmented["image"]).permute(2, 0, 1)
            label = torch.tensor(augmented["mask"]).permute(2, 0, 1)

        return {
            "image": image,
            "resized": resized,
            "label": label
        }


In [None]:

# Both splits read the same directory; the well prefixes of CustomImageDataset
# decide which files belong to which split.
train_dataset = CustomImageDataset("images train", y_train, train=True)
val_dataset = CustomImageDataset("images train", y_train, train=False)

# Identical loader settings; only the training split is shuffled.
train_loader, val_loader = (
    torch.utils.data.DataLoader(ds, batch_size=32, shuffle=shuffle, num_workers=4, pin_memory=True)
    for ds, shuffle in ((train_dataset, True), (val_dataset, False))
)

print(f"Taille du dataset d'entraînement : {len(train_dataset)}")
print(f"Taille du dataset de validation : {len(val_dataset)}")


### II - Segmentation with a U-Net

Chosen model: a U-Net with a ResNet-34 encoder pretrained on ImageNet.

In [None]:


# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Utilisation de : {device}")

import segmentation_models_pytorch as smp
import torch

# `device` is already set earlier in this cell, so it is not recomputed here.
# Alternative architectures (not used):
#   smp.UnetPlusPlus(encoder_name="efficientnet-b4", encoder_weights="imagenet", in_channels=1, classes=3)
#   smp.DeepLabV3Plus(encoder_name="resnet101", encoder_weights="imagenet", in_channels=1, classes=3)
#     (ResNet-50 instead of ResNet-101 for a faster model)
model = smp.Unet(
    encoder_name="resnet34",       # ResNet-34 encoder
    encoder_weights="imagenet",    # initialised with ImageNet-pretrained weights
    in_channels=1,                 # single-channel input images
    classes=3,                     # three segmentation classes (labels 0, 1, 2)
).to(device)

In [None]:
# Freeze the pretrained encoder: only the non-encoder parameters receive gradients
# until train() unfreezes the encoder progressively.
for weight in model.encoder.parameters():
    weight.requires_grad = False

# Plain cross-entropy criterion (the visible training loop uses focal_dice_loss instead).
ce_loss = nn.CrossEntropyLoss()

import torch.nn.functional as F

def focal_dice_loss(preds, targets, alpha=0.5, gamma=2.0, eps=1e-6):
    """Weighted sum of a focal cross-entropy term and a soft multi-class Dice term.

    Args:
        preds: raw logits of shape (N, C, H, W); the number of classes is C.
        targets: class-index masks of shape (N, 1, H, W) (cast to long).
        alpha: weight of the focal term; the Dice term is weighted by ``1 - alpha``.
        gamma: focusing exponent of the focal term.
        eps: added to the Dice denominator to avoid division by zero.

    Returns:
        Scalar loss tensor.

    Note:
        Dice is averaged over every (sample, class) pair, including classes absent from
        a sample's target; such pairs have zero intersection, hence a Dice score of 0
        and no gradient.
    """
    num_classes = preds.shape[1]
    targets_cls = targets.squeeze(1).long()

    # Focal term: down-weights pixels that are already classified confidently.
    pixel_ce = F.cross_entropy(preds, targets_cls, reduction='none')
    pt = torch.exp(-pixel_ce)  # probability assigned to the true class
    focal_component = (((1 - pt) ** gamma) * pixel_ce).mean()

    # Soft Dice term computed per sample and per class over the spatial dims.
    preds_soft = F.softmax(preds, dim=1)
    targets_one_hot = F.one_hot(targets_cls, num_classes=num_classes).permute(0, 3, 1, 2).float()

    intersection = (preds_soft * targets_one_hot).sum(dim=(2, 3))
    cardinality = preds_soft.sum(dim=(2, 3)) + targets_one_hot.sum(dim=(2, 3))

    dice = 2 * intersection / (cardinality + eps)
    dice_component = 1 - dice.mean()

    return alpha * focal_component + (1 - alpha) * dice_component

def compute_multiclass_iou(preds, targets, num_classes=3, ignore_bg=True):
    """Compute per-class IoU and their mean for a batch of segmentations.

    Args:
        preds: either per-class scores of shape (N, C, H, W), whose argmax over dim 1
            is the predicted class, or an integer label map of shape (N, H, W).
        targets: class-index masks of shape (N, 1, H, W) or (N, H, W).
        num_classes: classes 0 .. num_classes - 1 are scored.
        ignore_bg: if True, the returned mean excludes class 0.

    Returns:
        (mean_iou, class_ious): ``class_ious[c]`` is the IoU of class ``c`` averaged over
        the batch. A class absent from both prediction and target scores 1.0 (the epsilon
        appears in both numerator and denominator). ``mean_iou`` is NaN if no class is
        left to average.
    """
    # Label maps are already class indices; only score tensors need an argmax.
    pred_labels = preds if preds.dim() == 3 else torch.argmax(preds, dim=1)
    target_labels = targets.squeeze(1).long()

    class_ious = []
    for cls in range(num_classes):
        pred_mask = (pred_labels == cls).float()
        target_mask = (target_labels == cls).float()

        intersection = (pred_mask * target_mask).sum((1, 2))
        union = ((pred_mask + target_mask) > 0).float().sum((1, 2))

        iou = (intersection + 1e-6) / (union + 1e-6)
        class_ious.append(iou.mean().item())

    scored = class_ious[1:] if ignore_bg else class_ious
    mean_iou = sum(scored) / len(scored) if scored else float("nan")

    return mean_iou, class_ious


# The optimizer receives every parameter, including the currently frozen encoder ones,
# so that they are updated once train() sets requires_grad back to True.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# train() steps this scheduler with the validation mean IoU, where higher is better:
# the plateau must be detected in 'max' mode ('min' would halve the learning rate
# precisely while the IoU keeps improving).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.5)

def train(model, train_loader, val_loader, optimizer, scheduler, num_epochs=20, checkpoint_path="best_model.pth"):
    """Train the segmentation model with a progressive encoder-unfreezing schedule.

    Each epoch optimises ``focal_dice_loss`` over ``train_loader``, then evaluates
    ``compute_multiclass_iou`` on ``val_loader``. ``scheduler`` is stepped with the
    validation mean IoU (so a ``ReduceLROnPlateau`` should use ``mode='max'``), and the
    weights with the best validation mean IoU are saved to ``checkpoint_path``.

    After epoch index 5 the last 10 encoder parameter tensors become trainable; after
    epoch index 10 the whole encoder does.

    Uses the module-level ``device``.

    Returns:
        float: the best validation mean IoU reached.
    """
    best_iou = 0.0

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        total_loss = 0.0
        loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)

        for batch in loop:
            images, masks, resized = batch["image"], batch["label"], batch["resized"]
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = focal_dice_loss(outputs, masks)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            loop.set_postfix(loss=batch_loss, resized=resized.sum().item())

        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1}: Training Loss = {avg_loss:.4f}")

        # ---- validation pass (per-batch IoUs, then averaged over batches) ----
        model.eval()
        iou_score = 0.0
        class_sums = [0.0, 0.0, 0.0]
        with torch.no_grad():
            for batch in val_loader:
                images, masks = batch["image"].to(device), batch["label"].to(device)
                outputs = model(images)
                iou_mean, iou_class = compute_multiclass_iou(outputs, masks)
                iou_score += iou_mean
                for k in range(3):
                    class_sums[k] += iou_class[k]

        n_val = max(len(val_loader), 1)
        avg_iou = iou_score / n_val
        avg_iou_1, avg_iou_2, avg_iou_3 = (s / n_val for s in class_sums)
        print(f"Validation Mean IoU = {avg_iou:.4f}, Mean IoU 1 = {avg_iou_1:.4f}, Mean IoU 2 = {avg_iou_2:.4f} Mean IoU 3 = {avg_iou_3:.4f}")

        scheduler.step(avg_iou)

        # Select on the same metric the scheduler tracks: the mean IoU, not the
        # background-class IoU (avg_iou_1 is class 0).
        if avg_iou > best_iou:
            best_iou = avg_iou
            torch.save(model.state_dict(), checkpoint_path)
            print(f" Meilleur modèle sauvegardé à epoch {epoch+1} avec IoU moyenne {best_iou:.4f}")

        if epoch == 5:
            print(" Déblocage partiel du backbone pour fine-tuning progressif.")
            for param in list(model.encoder.parameters())[-10:]:  # last layers only
                param.requires_grad = True

        if epoch == 10:
            print(" Déblocage total du backbone.")
            for param in model.encoder.parameters():
                param.requires_grad = True

    return best_iou



In [None]:


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Utilisation de : {device}")

# Same architecture as during training, without downloading ImageNet weights:
# the parameters are meant to come from the checkpoint below.
model = smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=1, classes=3).to(device)

checkpoint_path = "best_model_0_682/best_model.pth"
if not os.path.isfile(checkpoint_path):
    print(f" Aucun fichier de poids pré-entraînés trouvé à l'emplacement : {checkpoint_path}. Veuillez vérifier le chemin.")
else:
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print(f" Poids pré-entraînés chargés depuis {checkpoint_path}.")

def compute_multiclass_iou(preds, targets, num_classes=3, ignore_bg=True):
    """Compute per-class IoU and their mean for a batch of segmentations.

    Args:
        preds: either per-class scores of shape (N, C, H, W), whose argmax over dim 1
            is the predicted class, or an integer label map of shape (N, H, W).
        targets: class-index masks of shape (N, 1, H, W) or (N, H, W).
        num_classes: classes 0 .. num_classes - 1 are scored.
        ignore_bg: if True, the returned mean excludes class 0.

    Returns:
        (mean_iou, class_ious): ``class_ious[c]`` is the IoU of class ``c`` averaged over
        the batch. A class absent from both prediction and target scores 1.0 (the epsilon
        appears in both numerator and denominator). ``mean_iou`` is NaN if no class is
        left to average.
    """
    # Label maps are already class indices; only score tensors need an argmax.
    pred_labels = preds if preds.dim() == 3 else torch.argmax(preds, dim=1)
    target_labels = targets.squeeze(1).long()

    class_ious = []
    for cls in range(num_classes):
        pred_mask = (pred_labels == cls).float()
        target_mask = (target_labels == cls).float()

        intersection = (pred_mask * target_mask).sum((1, 2))
        union = ((pred_mask + target_mask) > 0).float().sum((1, 2))

        iou = (intersection + 1e-6) / (union + 1e-6)
        class_ious.append(iou.mean().item())

    scored = class_ious[1:] if ignore_bg else class_ious
    mean_iou = sum(scored) / len(scored) if scored else float("nan")

    return mean_iou, class_ious

def evaluate(model, val_loader):
    """Print and return the mean validation IoU of ``model``.

    The metric is ``compute_multiclass_iou`` (foreground classes by default), computed per
    batch and averaged over batches without weighting by batch size. Uses the
    module-level ``device``.

    Returns:
        float: the mean IoU, or NaN when ``val_loader`` yields no batch.
    """
    model.eval()
    iou_list = []
    with torch.no_grad():
        for batch in val_loader:
            images, masks = batch["image"].to(device), batch["label"].to(device)
            iou, _ = compute_multiclass_iou(model(images), masks)
            iou_list.append(iou)

    if not iou_list:
        print(" val_loader est vide : aucune IoU calculée.")
        return float("nan")

    mean_iou = sum(iou_list) / len(iou_list)
    print(f" Moyenne de l'IoU sur l'ensemble du val_loader : {mean_iou:.4f}")
    return mean_iou

# Exécution de l'évaluation
# Score the reloaded U-Net on the validation set.
evaluate(model=model, val_loader=val_loader)


### III - Post-processing (learned refinement network)

In [None]:
def compute_edges(x):
    """Sobel gradient magnitude of every channel of ``x``, computed independently.

    Args:
        x: tensor of shape (N, C, H, W).

    Returns:
        Tensor (N, C, H, W) equal to sqrt(gx**2 + gy**2 + 1e-6), with zero padding at
        the borders.
    """
    channels = x.shape[1]
    kernel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32)
    kernel_y = kernel_x.t()  # the transposed Sobel kernel gives the vertical derivative

    def depthwise(kernel):
        # One copy of the kernel per channel, applied group-wise (groups == channels).
        weight = kernel.repeat(channels, 1, 1, 1).to(x.device)
        return F.conv2d(x, weight, padding=1, groups=channels)

    grad_x = depthwise(kernel_x)
    grad_y = depthwise(kernel_y)
    # L2 norm of the gradient; the 1e-6 keeps sqrt differentiable where both are zero.
    return torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-6)

def boundary_loss(preds, targets):
    """Mean squared difference between Sobel edge maps of the prediction and the target.

    The prediction is hard-assigned to its most probable class (one-hot of the argmax),
    so only the dominant class contributes edges and soft-probability noise is ignored.
    argmax/one_hot carry no gradient, so a straight-through estimator is used: the
    forward value is exactly the hard one-hot map, while gradients flow through the
    softmax probabilities. Without it the loss is a constant for the optimiser.

    Args:
        preds: logits of shape (N, C, H, W).
        targets: class-index masks of shape (N, 1, H, W).

    Returns:
        Scalar loss tensor.
    """
    num_classes = preds.shape[1]
    preds_soft = torch.softmax(preds, dim=1)

    targets_cls = targets.squeeze(1).long()
    targets_one_hot = F.one_hot(targets_cls, num_classes=num_classes).permute(0, 3, 1, 2).float()

    preds_hard = F.one_hot(preds_soft.argmax(dim=1), num_classes=num_classes).permute(0, 3, 1, 2).float()
    # Straight-through: (soft - soft.detach()) is exactly zero in the forward pass but
    # carries the softmax gradient in the backward pass.
    preds_one_hot = preds_hard + (preds_soft - preds_soft.detach())

    edge_preds = compute_edges(preds_one_hot)
    edge_targets = compute_edges(targets_one_hot)

    return ((edge_preds - edge_targets) ** 2).mean()

def _lovasz_grad(gt_sorted):
    """Gradient of the Lovász extension of the Jaccard loss w.r.t. errors sorted descending.

    ``gt_sorted`` is the binary foreground indicator reordered like the sorted errors
    and must contain at least one foreground pixel.
    """
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.cumsum(0)
    union = gts + (1.0 - gt_sorted).cumsum(0)
    jaccard = 1.0 - intersection / union
    if gt_sorted.numel() > 1:
        jaccard[1:] = jaccard[1:] - jaccard[:-1]
    return jaccard


def lovasz_softmax(preds, targets):
    """Lovász-Softmax loss, a convex surrogate of 1 - IoU.

    Based on "The Lovász-Softmax loss: A tractable surrogate for the optimization of the
    intersection-over-union measure in neural networks" (Berman et al., CVPR 2018).
    Pixels of the whole batch are pooled together (no per-image averaging), and the loss
    is averaged over the classes present in ``targets``.

    Args:
        preds: logits of shape (N, C, H, W).
        targets: class-index masks of shape (N, 1, H, W).

    Returns:
        Scalar loss tensor (zero, still attached to the graph, if no class is present).
    """
    preds_soft = torch.softmax(preds, dim=1)
    targets_cls = targets.squeeze(1).long()

    per_class_losses = []
    for c in range(preds.shape[1]):
        fg = (targets_cls == c).float().reshape(-1)
        if fg.sum() == 0:
            continue  # classes absent from the targets are skipped

        errors = (fg - preds_soft[:, c].reshape(-1)).abs()
        errors_sorted, order = torch.sort(errors, descending=True)
        per_class_losses.append((errors_sorted * _lovasz_grad(fg[order])).sum())

    if not per_class_losses:
        return preds_soft.sum() * 0.0
    return torch.stack(per_class_losses).mean()

In [None]:



class PostProcessingNet(nn.Module):
    """Small three-level U-Net mapping per-class maps to refined per-class logits.

    Args:
        in_channels: channels of the input maps.
        out_channels: number of output logit maps.
        base_filters: width of the first encoder level; doubled at each level down.

    Input (N, in_channels, H, W) -> output (N, out_channels, H, W). Sizes not divisible
    by 8 are handled by zero-padding each upsampled map (right/bottom) to the size of
    its skip connection.
    """

    def __init__(self, in_channels=3, out_channels=3, base_filters=16):
        super().__init__()
        f = base_filters

        # Attribute names (and creation order) define the state_dict keys of saved
        # checkpoints, so they must stay as they are.
        self.enc1 = self._conv_block(in_channels, f)
        self.enc2 = self._conv_block(f, 2 * f)
        self.enc3 = self._conv_block(2 * f, 4 * f)

        self.bridge = self._conv_block(4 * f, 8 * f)

        # Each up-convolution halves the channels; the skip connection then doubles them back.
        self.up1 = nn.ConvTranspose2d(8 * f, 4 * f, kernel_size=2, stride=2)
        self.dec1 = self._conv_block(8 * f, 4 * f)

        self.up2 = nn.ConvTranspose2d(4 * f, 2 * f, kernel_size=2, stride=2)
        self.dec2 = self._conv_block(4 * f, 2 * f)

        self.up3 = nn.ConvTranspose2d(2 * f, f, kernel_size=2, stride=2)
        self.dec3 = self._conv_block(2 * f, f)

        self.final = nn.Conv2d(f, out_channels, kernel_size=1)

    def _conv_block(self, in_c, out_c):
        """Two (3x3 conv -> BatchNorm -> ReLU) stages; the spatial size is preserved."""
        layers = []
        for stage_in in (in_c, out_c):
            layers += [
                nn.Conv2d(stage_in, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        skips = []
        for encoder in (self.enc1, self.enc2, self.enc3):
            x = encoder(x)
            skips.append(x)
            x = F.max_pool2d(x, kernel_size=2, stride=2)

        x = self.bridge(x)

        for up, dec in ((self.up1, self.dec1), (self.up2, self.dec2), (self.up3, self.dec3)):
            skip = skips.pop()
            x = up(x)
            # Pad right/bottom so odd spatial sizes line up with the skip connection.
            x = F.pad(x, [0, skip.size(3) - x.size(3), 0, skip.size(2) - x.size(2)])
            x = dec(torch.cat([x, skip], dim=1))

        return self.final(x)


In [None]:
def train_memory_efficient(model, model_post_pro, train_loader, val_loader, optimizer, scheduler, 
                        criterion, num_epochs=50, checkpoint_path="best_post_pro_model_tho.pth", 
                        device="cuda", batch_size=None, val_frequency=5):
    """Train ``model_post_pro`` to refine the softmax output of a frozen ``model``.

    ``model`` runs under ``torch.no_grad`` and only produces the input probabilities.
    The training loss is ``focal_dice_loss + 0.5 * boundary_loss`` on the refined logits.

    Validation (``validate_memory_efficient``) runs on epoch 0 and every
    ``val_frequency`` epochs, on the last epoch, and whenever the training loss beats
    its best value so far (initial threshold 2). The post-processing weights with the
    best validation IoU are saved to ``checkpoint_path``.

    Args:
        criterion: forwarded to ``validate_memory_efficient``; not used for the loss.
        batch_size: only triggers a warning when smaller than the loader's batch size;
            the loader itself is not modified.
        scheduler: optional. A ``ReduceLROnPlateau`` is stepped with the validation loss
            after each validation; any other scheduler is stepped once per epoch.

    Returns:
        dict with lists "train_loss", "val_loss" and "val_iou" (None on epochs without
        validation).
    """
    is_plateau = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
    best_iou = 0.0
    best_loss = 2  # initial training-loss threshold that also triggers a validation

    history = {"train_loss": [], "val_loss": [], "val_iou": []}

    if batch_size is not None and batch_size < train_loader.batch_size:
        print(f"Attention: Réduction de la taille de batch de {train_loader.batch_size} à {batch_size}")

    for epoch in range(num_epochs):
        # The base model only provides inputs: keep it in eval mode so BatchNorm uses
        # (and does not update) its running statistics, which happens in train mode
        # even under no_grad.
        model.eval()
        model_post_pro.train()
        total_loss = 0.0
        batch_count = 0

        loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)

        for batch in loop:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            images, masks = batch["image"].to(device), batch["label"].to(device)

            with torch.no_grad():
                probabilities = torch.softmax(model(images), dim=1)

            optimizer.zero_grad()
            post_pro_outputs = model_post_pro(probabilities)
            loss = focal_dice_loss(post_pro_outputs, masks) + 0.5 * boundary_loss(post_pro_outputs, masks)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            del images, masks, probabilities, post_pro_outputs, loss

            total_loss += batch_loss
            batch_count += 1
            loop.set_postfix(loss=batch_loss)

        avg_train_loss = total_loss / batch_count if batch_count else float("nan")
        history["train_loss"].append(avg_train_loss)

        if epoch % val_frequency == 0 or epoch == num_epochs - 1 or avg_train_loss < best_loss:
            if avg_train_loss < best_loss:
                best_loss = avg_train_loss
            val_loss, val_iou, _ = validate_memory_efficient(
                model, model_post_pro, val_loader, criterion, device
            )
            history["val_loss"].append(val_loss)
            history["val_iou"].append(val_iou)

            if is_plateau:
                scheduler.step(val_loss)

            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Val IoU: {val_iou:.4f}")
            if val_iou > best_iou:
                best_iou = val_iou
                torch.save(model_post_pro.state_dict(), checkpoint_path)
                print(f"Meilleur modèle de post-processing sauvegardé avec IoU: {best_iou:.4f}")
        else:
            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f} (pas de validation)")
            history["val_loss"].append(None)
            history["val_iou"].append(None)

        # Epoch-based schedulers advance exactly once per epoch, validated or not.
        if scheduler is not None and not is_plateau:
            scheduler.step()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return history

def validate_memory_efficient(model, model_post_pro, val_loader, criterion, device="cuda"):
    """Compare the post-processing network with the raw segmentation model on ``val_loader``.

    For each batch, ``model`` produces logits and ``model_post_pro`` refines their softmax.
    The loss ``focal_dice_loss + 0.5 * boundary_loss`` is computed on the refined logits.
    IoUs (``compute_multiclass_iou``) of both the initial and the post-processed predictions
    are computed over the whole set and printed with their absolute/relative differences.

    Args:
        criterion: unused; kept for call compatibility.

    Returns:
        (val_loss, post_pro_iou, initial_iou): mean loss per batch and the two mean IoUs.

    Raises:
        ValueError: if ``val_loader`` yields no batch.
    """
    model.eval()
    model_post_pro.eval()
    total_loss = 0.0
    n_batches = 0

    all_post_pro_outputs = []
    all_initial_outputs = []
    all_targets = []

    with torch.no_grad():
        for batch in val_loader:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            images, masks = batch["image"].to(device), batch["label"].to(device)

            initial_outputs = model(images)
            # Outputs are moved to the CPU so the whole set can be scored at once.
            all_initial_outputs.append(initial_outputs.cpu())
            probabilities = torch.softmax(initial_outputs, dim=1)
            del initial_outputs

            post_pro_outputs = model_post_pro(probabilities)
            del probabilities
            all_post_pro_outputs.append(post_pro_outputs.cpu())

            loss = focal_dice_loss(post_pro_outputs, masks) + 0.5 * boundary_loss(post_pro_outputs, masks)
            total_loss += loss.item()
            n_batches += 1

            all_targets.append(masks.cpu())
            del images, masks, post_pro_outputs, loss

    if n_batches == 0:
        raise ValueError("val_loader est vide : impossible de valider le modèle.")

    all_post_pro_outputs = torch.cat(all_post_pro_outputs, dim=0)
    all_initial_outputs = torch.cat(all_initial_outputs, dim=0)
    all_targets = torch.cat(all_targets, dim=0)

    initial_iou, initial_class_ious = compute_multiclass_iou(all_initial_outputs, all_targets, num_classes=3)
    post_pro_iou, post_pro_class_ious = compute_multiclass_iou(all_post_pro_outputs, all_targets, num_classes=3)

    print(f"IoU du modèle initial: {initial_iou:.4f}")
    for cls, iou in enumerate(initial_class_ious):
        print(f"  - Classe {cls}: {iou:.4f}")

    print(f"IoU après post-processing: {post_pro_iou:.4f}")
    for cls, iou in enumerate(post_pro_class_ious):
        print(f"  - Classe {cls}: {iou:.4f}")

    # Relative change, guarded against a zero baseline like the per-class figures below.
    global_diff = post_pro_iou - initial_iou
    global_percent = (global_diff / initial_iou * 100) if initial_iou > 0 else float('inf')
    print(f"Différence d'IoU globale: {global_diff:.4f} ({global_percent:.2f}%)")

    for cls, (initial_c, post_c) in enumerate(zip(initial_class_ious, post_pro_class_ious)):
        diff = post_c - initial_c
        percent = (diff / initial_c * 100) if initial_c > 0 else float('inf')
        print(f"  - Classe {cls}: {diff:.4f} ({percent:.2f}%)")

    del all_post_pro_outputs, all_initial_outputs, all_targets
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    val_loss = total_loss / n_batches
    return val_loss, post_pro_iou, initial_iou

def visualize_memory_efficient(model, model_post_pro, val_loader, num_samples=2, device="cuda"):
    """Plot validation samples before and after post-processing.

    The first image of each of the first ``num_samples`` batches of ``val_loader`` is
    drawn as a row of four panels: input image, initial prediction, post-processed
    prediction and ground truth. The mean IoU (``compute_multiclass_iou``) of both
    predictions over every image of the visited batches is also printed.

    Returns:
        matplotlib.figure.Figure: the figure holding the panels.
    """
    model.eval()
    model_post_pro.eval()

    # squeeze=False keeps `axes` two-dimensional, so axes[row, col] also works for num_samples == 1.
    fig, axes = plt.subplots(num_samples, 4, figsize=(20, 5 * num_samples), squeeze=False)
    titles = ('Image originale', 'Prédiction initiale', 'Post-traité', 'Ground Truth')
    cmaps = ('gray', 'jet', 'jet', 'jet')
    sample_idx = 0  # index, within each batch, of the sample that is drawn

    samples_seen = 0
    all_post_pro_outputs = []
    all_initial_outputs = []
    all_targets = []
    with torch.no_grad():
        for batch in val_loader:
            if samples_seen >= num_samples:
                break

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            images, masks = batch["image"].to(device), batch["label"].to(device)

            initial_outputs = model(images)
            probabilities = torch.softmax(initial_outputs, dim=1)
            initial_masks = torch.argmax(initial_outputs, dim=1).cpu()
            del initial_outputs

            post_pro_masks = torch.argmax(model_post_pro(probabilities), dim=1).cpu()
            del probabilities

            if masks.shape[1] == 1:
                target_masks = masks.squeeze(1).cpu()
            else:
                target_masks = torch.argmax(masks, dim=1).cpu()

            all_initial_outputs.append(initial_masks)
            all_post_pro_outputs.append(post_pro_masks)
            all_targets.append(target_masks)

            panels = (
                images[sample_idx, 0].cpu().numpy(),
                initial_masks[sample_idx].numpy(),
                post_pro_masks[sample_idx].numpy(),
                target_masks[sample_idx].numpy(),
            )
            for col, (panel, title, cmap) in enumerate(zip(panels, titles, cmaps)):
                ax = axes[samples_seen, col]
                ax.imshow(panel, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')
            samples_seen += 1

            del images, masks, initial_masks, post_pro_masks, target_masks
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    if all_targets:
        initial_labels = torch.cat(all_initial_outputs, dim=0)
        post_pro_labels = torch.cat(all_post_pro_outputs, dim=0)
        targets = torch.cat(all_targets, dim=0).long().unsqueeze(1)  # (N, 1, H, W)

        # compute_multiclass_iou takes an argmax over per-class scores (N, C, H, W);
        # one-hot encoding the label maps makes that argmax return the labels unchanged.
        initial_scores = F.one_hot(initial_labels, num_classes=3).permute(0, 3, 1, 2).float()
        post_pro_scores = F.one_hot(post_pro_labels, num_classes=3).permute(0, 3, 1, 2).float()

        initial_iou, _ = compute_multiclass_iou(initial_scores, targets, num_classes=3)
        post_pro_iou, _ = compute_multiclass_iou(post_pro_scores, targets, num_classes=3)
        print(f"IoU moyenne par image (initial) : {initial_iou:.4f}")
        print(f"IoU moyenne par image (post-traité) : {post_pro_iou:.4f}")

    return fig

def main():
    """Train the post-processing network on top of the frozen U-Net checkpoint.

    Loads the U-Net from ``best_model_0_682/best_model.pth`` and trains a
    ``PostProcessingNet`` on its softmax outputs with ``train_memory_efficient``. It then
    reloads the best post-processing checkpoint and plots a few validation samples.
    Uses the module-level ``train_loader`` and ``val_loader``.

    Returns:
        (model_post_pro, history): the post-processing network with its best weights and
        the training history dict.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Favour reproducibility over cuDNN autotuning speed.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights=None,
        in_channels=1,
        classes=3
    ).to(device)

    checkpoint_path = "best_model_0_682/best_model.pth"
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    print(" Modèle chargé avec succès !")

    model_post_pro = PostProcessingNet(in_channels=3, out_channels=3, base_filters=32).to(device)

    optimizer = torch.optim.Adam(model_post_pro.parameters(), lr=0.0005)  # reduced learning rate
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)
    criterion = torch.nn.CrossEntropyLoss()

    # `device` is passed explicitly: both helpers default to the string "cuda",
    # which fails on a CPU-only machine.
    post_pro_checkpoint = "best_post_pro_model_tho_li_2.pth"
    history = train_memory_efficient(
        model, model_post_pro, train_loader, val_loader,
        optimizer, scheduler, criterion, num_epochs=30,
        val_frequency=5, checkpoint_path=post_pro_checkpoint, device=device,
    )

    model_post_pro.load_state_dict(torch.load(post_pro_checkpoint, map_location=device))

    visualize_memory_efficient(model, model_post_pro, val_loader, num_samples=2, device=device)

    return model_post_pro, history

In [None]:
#model_post_pro, history = main()

In [None]:
# Reload both trained networks and visualise their predictions on validation samples.
# map_location lets checkpoints saved on a GPU be restored on a CPU-only machine.
model_post_pro = PostProcessingNet(in_channels=3, out_channels=3, base_filters=32).to(device)
model_post_pro.load_state_dict(torch.load("best_model_0_682/best_post_pro_model_tho_li_2.pth", map_location=device))
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=1,
    classes=3
).to(device)

checkpoint_path = "best_model_0_682/best_model.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# `device` is passed explicitly: the function's default is the string "cuda".
fig = visualize_memory_efficient(model, model_post_pro, val_loader, num_samples=9, device=device)
plt.show()


### IV - Manual post-processing

In [None]:
import numpy as np
import torch
from scipy.ndimage import label, generate_binary_structure

def pipeline_modifie(pred_mask, batch_idx=0, i=0):
    """Keep only the largest 8-connected component of class 2 in a predicted label map.

    If the mask contains at least one class-2 pixel:
    - class-1 pixels are kept;
    - pixels of the largest 8-connected class-2 component are kept as 2;
    - every other pixel (other class-2 components included) becomes 0.
    Otherwise the mask is returned unchanged.

    Args:
        pred_mask: label map (H, W) or (1, H, W), as a numpy array or a torch tensor.
        batch_idx, i: unused; kept for call compatibility.

    Returns:
        For tensor input, a tensor on the original device with the leading channel axis
        restored if it was present. For numpy input, a 2-D array (a leading axis of
        size 1 is dropped). In both cases the dtype of the input is preserved.
    """
    is_tensor = isinstance(pred_mask, torch.Tensor)
    original_device = pred_mask.device if is_tensor else None
    mask = pred_mask.cpu().numpy() if is_tensor else pred_mask

    input_had_channel_dim = mask.ndim == 3 and mask.shape[0] == 1
    if input_had_channel_dim:
        mask = mask.squeeze(0)

    # 8-connectivity: diagonal neighbours belong to the same component.
    structure = generate_binary_structure(2, 2)
    labeled, num_features = label(mask == 2, structure=structure)

    if num_features == 0:
        result = mask.copy()
    else:
        component_sizes = np.bincount(labeled.ravel())[1:]  # index 0 is the background
        largest_comp_idx = np.argmax(component_sizes) + 1

        result = np.where(mask == 1, 1, 0)
        result[labeled == largest_comp_idx] = 2
        # np.where yields int64: cast back so both branches return the input dtype.
        result = result.astype(mask.dtype, copy=False)

    if is_tensor:
        if input_had_channel_dim:
            result = np.expand_dims(result, 0)
        result = torch.from_numpy(result).to(original_device)

    return result

In [None]:
import numpy as np
import random
import torch
import os
import matplotlib
matplotlib.use('Agg')  
import matplotlib.pyplot as plt
from datetime import datetime

def compute_multiclass_iou_numpy(preds, targets, num_classes=3, return_class=1):
    """Compute the per-class IoU between a predicted and a target label mask.

    Args:
        preds: One of three forms. A 2-D label mask, used as-is. A
            (C, H, W) array of per-class scores with C > 1, reduced with
            argmax over axis 0. Or a label mask with extra singleton axes,
            which are squeezed.
        targets: Label mask; singleton axes are squeezed.
        num_classes: Classes 0..num_classes-1 are evaluated.
        return_class: Index of the class whose IoU is returned. The default
            of 1 matches what this function has always returned. Pass None to
            get the mean IoU over all ``num_classes`` classes instead.

    Returns:
        The IoU of ``return_class``, or the mean IoU when it is None. A class
        absent from both masks scores 0 (epsilon-guarded division).
    """
    preds = np.asarray(preds)
    if preds.ndim == 2:
        preds_cls = preds
    elif preds.ndim == 3 and preds.shape[0] > 1:
        preds_cls = np.argmax(preds, axis=0)
    else:
        preds_cls = preds.squeeze()

    targets = np.asarray(targets).squeeze()

    class_ious = []
    for cls in range(num_classes):
        pred_mask = preds_cls == cls
        target_mask = targets == cls
        intersection = np.logical_and(pred_mask, target_mask).sum()
        union = np.logical_or(pred_mask, target_mask).sum()
        class_ious.append(intersection / (union + 1e-6))

    if return_class is None:
        return sum(class_ious) / num_classes
    return class_ious[return_class]

def _to_display_image(img):
    """Convert a CHW image array into something ``plt.imshow`` accepts.

    (1|3, H, W) arrays are transposed to channels-last. 3-channel images are
    min-max scaled to [0, 1], and a single trailing channel is squeezed away.
    """
    if len(img.shape) == 3 and img.shape[0] in [1, 3]:  # CHW layout
        img = img.transpose(1, 2, 0)
    if len(img.shape) == 3 and img.shape[2] == 3:
        img = (img - img.min()) / (img.max() - img.min() + 1e-6)
    elif len(img.shape) == 3 and img.shape[2] == 1:
        img = img.squeeze()
    return img


def _save_comparison_figure(result, filename):
    """Save a 1x4 figure (image, ground truth, raw prediction, post-processed) for one result.

    The prediction panels are titled with this result's own ``iou_before`` /
    ``iou_after``.
    """
    img = result['img']
    panels = (
        (img, "gray" if len(img.shape) == 2 else None, "Image d'entrée"),
        (result['true_mask'], "jet", "Masque réel (Ground Truth)"),
        (result['initial_pred_mask'], "jet", f"Masque prédit - IoU: {result['iou_before']:.4f}"),
        (result['post_processed_mask'], "jet", f"Post-traité - IoU: {result['iou_after']:.4f}"),
    )
    plt.figure(figsize=(16, 4))
    for position, (data, cmap, title) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.imshow(data, cmap=cmap)
        plt.title(title)
        plt.axis("off")
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()
    print(f"Image sauvegardée: {filename}")


def test_post_processing_random(model, post_pro_model, val_loader, output_dir="post_processing_results",
                               num_batches=10, samples_per_batch=20):
    """Compare IoU before/after ``pipeline_modifie`` on random validation samples.

    Randomly picks ``num_batches`` batches of ``val_loader`` and up to
    ``samples_per_batch`` random samples in each. Every sample goes through
    ``model`` -> softmax -> ``post_pro_model`` -> argmax. ``pipeline_modifie``
    is then applied to the predicted class-1 mask. Both masks are scored
    against the ground-truth class-1 mask with
    ``compute_multiclass_iou_numpy``. The samples with the best and worst IoU
    improvement are saved as PNG figures in ``output_dir``.

    NOTE(review): the masks given to ``pipeline_modifie`` are boolean class-1
    masks, while it only edits class-2 pixels, so the "after" mask equals the
    "before" mask and every improvement is 0. It never alters class-1 pixels
    either, so a class-1 IoU cannot change even with full label masks.
    Confirm what this evaluation is meant to measure.

    Reads the module-level ``device`` global. ``val_loader`` must support
    ``len()``.

    Returns:
        List of per-sample dicts (batch, sample, iou_before, iou_after,
        improvement, true_mask, initial_pred_mask, post_processed_mask, img),
        sorted by decreasing improvement.
    """
    os.makedirs(output_dir, exist_ok=True)

    model.eval()
    # The refinement net must also run in inference mode (matters if it has
    # BatchNorm/Dropout layers).
    post_pro_model.eval()

    total_batches = len(val_loader)
    if num_batches > total_batches:
        print(f"Warning: Requested {num_batches} batches but only {total_batches} available.")
        num_batches = total_batches
    # Choose batch indices up front so the loader is streamed instead of being
    # materialised in memory with list(val_loader).
    selected_batches = set(random.sample(range(total_batches), num_batches))

    results = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            if batch_idx not in selected_batches:
                continue
            images, masks = batch["image"].to(device), batch["label"].to(device)

            probabilities = torch.softmax(model(images), dim=1)
            pred_indices = torch.argmax(post_pro_model(probabilities), dim=1)

            batch_size = images.size(0)
            print(f"\nTraitement du batch {batch_idx+1}/{total_batches}")

            n_samples = min(samples_per_batch, batch_size)
            for i in sorted(random.sample(range(batch_size), n_samples)):
                pred_mask = (pred_indices[i] == 1).cpu().numpy().squeeze()
                true_mask = (masks[i] == 1).cpu().numpy().squeeze()
                post_processed_mask = pipeline_modifie(pred_mask, batch_idx, i)

                iou_before = compute_multiclass_iou_numpy(pred_mask, true_mask, num_classes=3)
                iou_after = compute_multiclass_iou_numpy(post_processed_mask, true_mask, num_classes=3)

                results.append({
                    'batch': batch_idx,
                    'sample': i,
                    'iou_before': iou_before,
                    'iou_after': iou_after,
                    'improvement': iou_after - iou_before,
                    'true_mask': true_mask,
                    'initial_pred_mask': pred_mask,
                    'post_processed_mask': post_processed_mask,
                    'img': _to_display_image(images[i].cpu().numpy()),
                })

    print("\n----- RÉSUMÉ DES RÉSULTATS -----")
    avg_improvement = sum(r['improvement'] for r in results) / len(results) if results else 0
    print(f"Amélioration moyenne de l'IoU: {avg_improvement:.4f}")

    if results:
        results.sort(key=lambda x: x['improvement'], reverse=True)
        best, worst = results[0], results[-1]
        print(f"Meilleure amélioration: {best['improvement']:.4f} (batch {best['batch']}, échantillon {best['sample']})")
        print(f"Pire amélioration: {worst['improvement']:.4f} (batch {worst['batch']}, échantillon {worst['sample']})")
        _save_comparison_figure(best, f"{output_dir}/batch{best['batch']}_sample{best['sample']}_best.png")
        _save_comparison_figure(worst, f"{output_dir}/batch{worst['batch']}_sample{worst['sample']}_worst.png")

    return results

In [None]:
# Load the post-processing network and the U-Net segmentation model used by
# the evaluation below. Both checkpoints are mapped onto `device` so they also
# load on a CPU-only machine, whatever device they were saved from.
model_post_pro = PostProcessingNet(in_channels=3, out_channels=3, base_filters=32).to(device)
model_post_pro.load_state_dict(torch.load("best_post_pro_model_tho_li_2.pth", map_location=device))
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=1,
    classes=3,
).to(device)

checkpoint_path = "best_model.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Inference mode for both networks.
model.eval()
model_post_pro.eval()

In [None]:

# Run the post-processing evaluation with its default settings; the best/worst
# comparison figures go to the default output_dir "post_processing_results".
test_post_processing_random(model,model_post_pro,val_loader)

### V - Test et mise en forme

In [None]:
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import pandas as pd
import cv2
import numpy as np
from scipy.ndimage import label

# Resizes a predicted label map to 160x272 for images that were resized on
# load (see postprocess_predictions); NEAREST interpolation avoids blending
# neighbouring class labels.
resize_transform = T.Resize((160, 272), interpolation=T.InterpolationMode.NEAREST)

def morphological_postprocessing(mask, kernel_size=3):
    """Apply a morphological opening, then a closing, to a label mask.

    OpenCV's grey-level morphology is used, so the label values are treated
    as intensities.

    Args:
        mask: (H, W) array, or (1, H, W) with a leading singleton channel axis,
            in a dtype accepted by ``cv2.morphologyEx`` (e.g. uint8).
        kernel_size: Side of the square structuring element (default 3).

    Returns:
        The processed mask, with the same shape as ``mask``.
    """
    # cv2 reads a 3-D array as rows x cols x channels, so a (1, H, W) mask
    # would be treated as a single-row image with W channels and the kernel
    # would act along one axis only. Work on the (H, W) plane instead.
    has_leading_axis = mask.ndim == 3 and mask.shape[0] == 1
    image = mask[0] if has_leading_axis else mask

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    image = cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel)
    image = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel)

    return image[np.newaxis] if has_leading_axis else image

def postprocess_predictions(preds, resized, file_names):
    """Turn a batch of network outputs into one flattened label row per file.

    For each sample the steps are:
      1. Take the argmax over the class axis.
      2. If the image was resized on load, resize back to 160x272 with
         ``resize_transform``.
      3. Keep only the largest class-2 component (``pipeline_modifie``).
      4. Apply a morphological opening and closing.
      5. Flatten.
    Rows of images that were not resized are right-padded with -1 up to
    160 * 272 values.

    Args:
        preds: Tensor of per-class scores where ``preds[i]`` is (C, H, W).
        resized: Per-sample flags telling whether the image was resized.
        file_names: Per-sample identifiers used as the DataFrame index.

    Returns:
        DataFrame indexed by file name, one int64 column per flattened pixel.
    """
    target_len = 160 * 272
    results = {}

    for i in range(preds.shape[0]):
        pred = torch.argmax(preds[i], dim=0, keepdim=True)  # (1, H, W)

        if resized[i]:
            pred = resize_transform(pred.unsqueeze(0)).squeeze(0)
        pred = pipeline_modifie(pred)

        # Drop the leading channel axis: cv2 reads a 3-D array as
        # rows x cols x channels, so (1, H, W) would be processed as a
        # single-row image with W channels instead of an H x W image.
        pred_np = pred.squeeze(0).cpu().numpy().astype(np.uint8)
        pred_np = morphological_postprocessing(pred_np)
        processed_pred = pred_np.flatten().astype(np.int64)

        if not resized[i]:
            # Pad in a signed dtype: -1 cannot be represented in uint8.
            pad_size = target_len - processed_pred.shape[0]
            processed_pred = np.pad(processed_pred, (0, pad_size), constant_values=-1)

        results[file_names[i]] = processed_pred

    return pd.DataFrame.from_dict(results, orient="index")


In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class TestImageDataset(Dataset):
    """Dataset over the ``.npy`` test images stored in ``root_dir``.

    Each image is min-max normalised to [0, 1] (a constant image becomes all
    zeros) and given a leading channel axis. It is resized to 160x160 with
    bilinear interpolation when its original shape is not (160, 160).
    Items are dicts with keys ``image``, ``resized`` and ``file_name`` (the
    file name without extension).
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # Keep os.listdir order; only NumPy array files are used.
        self.files = list(filter(lambda name: name.endswith(".npy"), os.listdir(root_dir)))
        self.resize_transform = transforms.Resize(
            (160, 160), interpolation=transforms.InterpolationMode.BILINEAR
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_name = self.files[idx]
        raw = np.load(os.path.join(self.root_dir, file_name)).astype(np.float32)

        lo, hi = raw.min(), raw.max()
        normalised = (raw - lo) / (hi - lo) if hi > lo else np.zeros_like(raw)

        image_t = torch.tensor(normalised, dtype=torch.float32).unsqueeze(0)
        needs_resize = raw.shape != (160, 160)
        if needs_resize:
            image_t = self.resize_transform(image_t)

        return {
            "image": image_t,
            "resized": needs_resize,
            "file_name": os.path.splitext(file_name)[0],
        }


In [None]:
import torch
import segmentation_models_pytorch as smp
import pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm  


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Utilisation de : {device}")

# Segmentation network: U-Net with a ResNet-34 encoder, 1 input channel, 3 classes.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=1,
    classes=3,
).to(device)

checkpoint_path = "best_model.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Refinement network applied to the U-Net softmax probabilities. map_location
# lets a checkpoint saved on GPU load on a CPU-only machine.
model_post_pro = PostProcessingNet(in_channels=3, out_channels=3, base_filters=32).to(device)
model_post_pro.load_state_dict(torch.load("best_post_pro_model_tho_li_2.pth", map_location=device))

print(" Modèle chargé avec succès !")

test_dataset = TestImageDataset(root_dir="images test")

test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
    num_workers=4,
)

# Both networks must run in inference mode.
model.eval()
model_post_pro.eval()

batch_frames = []
n_processed = 0

with torch.no_grad():
    loop = tqdm(test_loader, desc="Traitement des images de test", unit="batch")

    for batch in loop:
        images = batch["image"].to(device, non_blocking=True)
        resized = batch["resized"]
        file_names = batch["file_name"]

        probabilities = torch.softmax(model(images), dim=1)
        preds = model_post_pro(probabilities)

        df_preds = postprocess_predictions(preds, resized, file_names)
        batch_frames.append(df_preds)
        n_processed += len(df_preds)

        loop.set_postfix(Images_traitees=n_processed)

# One concat at the end instead of re-copying the accumulated frame every batch.
all_results = pd.concat(batch_frames) if batch_frames else pd.DataFrame()
all_results.to_csv("predictions_test.csv")
print("Prédictions sauvegardées dans 'predictions_test.csv'")


In [None]:
import matplotlib.pyplot as plt
import random
import numpy as np
import os

# Save side-by-side figures (normalised input image, coloured predicted
# segmentation) for up to 10 random rows of predictions_test.csv.
vis_dir = "visualizations"
os.makedirs(vis_dir, exist_ok=True)

predictions_df = pd.read_csv("predictions_test.csv", index_col=0)

num_vis = min(10, len(predictions_df))
random_samples = random.sample(list(predictions_df.index), num_vis)

# RGB colour for each class label (list index = class id).
colors = [
    [0, 0, 0],
    [255, 0, 0],
    [0, 255, 0],
]

print(f"\nGénération de {num_vis} visualisations aléatoires...")

for i, file_name in enumerate(random_samples):
    img_path = os.path.join("images test", f"{file_name}.npy")
    orig_image = np.load(img_path).astype(np.float32)

    image_min, image_max = orig_image.min(), orig_image.max()
    if image_max > image_min:
        orig_image = (orig_image - image_min) / (image_max - image_min)

    prediction = predictions_df.loc[file_name].to_numpy()
    # Rows of non-resized images were flattened at the image's own size and
    # then right-padded with -1. Reshaping the padded row straight to 160x272
    # would interleave pixels from different image rows (and -1 wraps to 255
    # in uint8). Strip the padding and restore the image's own shape instead.
    valid = prediction[prediction != -1]
    if valid.size == orig_image.size:
        prediction = valid.reshape(orig_image.shape)
    else:
        prediction = prediction.reshape(160, 272)
    prediction = prediction.astype(np.uint8)

    pred_colored = np.zeros((prediction.shape[0], prediction.shape[1], 3), dtype=np.uint8)
    for class_idx, color in enumerate(colors):
        pred_colored[prediction == class_idx] = color

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    axes[0].imshow(orig_image, cmap='gray')
    axes[0].set_title(f'Image originale: {file_name}')
    axes[0].axis('off')

    axes[1].imshow(pred_colored)
    axes[1].set_title('Segmentation prédite')
    axes[1].axis('off')

    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, color=[c / 255 for c in color], label=f'Classe {class_idx}')
        for class_idx, color in enumerate(colors)
    ]
    axes[1].legend(handles=legend_elements, loc='lower right')

    plt.tight_layout()
    plt.savefig(os.path.join(vis_dir, f'sample_{i}_{file_name}.png'), dpi=150)
    plt.close(fig)

    print(f"Visualisation {i+1}/{num_vis} générée")

print(f"{num_vis} visualisations générées avec succès dans '{vis_dir}'!")