In [1]:
import os
from dotenv import load_dotenv
from huggingface_hub import HfApi

env_path = os.path.join(os.getcwd(), '.env')

In [2]:
hf_token = os.getenv('HF_TOKEN)')


In [33]:
import requests
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from huggingface_hub import InferenceClient
import io
from dotenv import load_dotenv
load_dotenv()

class ClothesSegmentationAPI:
    def __init__(self, hf_token):
        os.environ["HF_TOKEN"] = hf_token
        self.client = InferenceClient(
            provider="hf-inference",
            api_key=hf_token
        )
        self.model = "sayeed99/segformer_b3_clothes"
        
        self.label_map = {
            0: "Background", 1: "Hat", 2: "Hair", 3: "Sunglasses",
            4: "Upper-clothes", 5: "Skirt", 6: "Pants", 7: "Dress",
            8: "Belt", 9: "Left-shoe", 10: "Right-shoe", 11: "Face",
            12: "Left-leg", 13: "Right-leg", 14: "Left-arm", 15: "Right-arm",
            16: "Bag", 17: "Scarf"
        }
        
        self.colors = plt.cm.tab20(np.linspace(0, 1, len(self.label_map)))
    
    def segment_image(self, image_path):
        output = self.client.image_segmentation(image_path, model=self.model)
        
        # Convertir le r√©sultat en masque
        mask_image = output[0]['mask']  # R√©cup√©rer le masque
        return mask_image
    
    def visualize_comparison(self, image_path, mask_pred, mask_true, save_path=None):
        original = Image.open(image_path).convert('RGB')
        mask_pred_array = np.array(mask_pred)
        mask_true_array = np.array(mask_true)

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))

        # Ligne 1: Pr√©diction
        axes[0, 0].imshow(original)
        axes[0, 0].set_title("Image Originale", fontsize=14, fontweight='bold')
        axes[0, 0].axis('off')

        # Masque pr√©dit
        colored_pred = np.zeros((*mask_pred_array.shape, 3))
        for label_id in np.unique(mask_pred_array):
            if label_id < len(self.colors):
                colored_pred[mask_pred_array == label_id] = self.colors[label_id][:3]

        axes[0, 1].imshow(colored_pred)
        axes[0, 1].set_title("Masque PR√âDIT", fontsize=14, fontweight='bold', color='blue')
        axes[0, 1].axis('off')

        axes[0, 2].imshow(original)
        axes[0, 2].imshow(colored_pred, alpha=0.6)
        axes[0, 2].set_title("Overlay Pr√©dit", fontsize=14, fontweight='bold')
        axes[0, 2].axis('off')

        # Ligne 2: Ground Truth
        axes[1, 0].axis('off')  # Vide

        # Masque r√©el
        colored_true = np.zeros((*mask_true_array.shape, 3))
        for label_id in np.unique(mask_true_array):
            if label_id < len(self.colors):
                colored_true[mask_true_array == label_id] = self.colors[label_id][:3]

        axes[1, 1].imshow(colored_true)
        axes[1, 1].set_title("Masque R√âEL (Ground Truth)", fontsize=14, fontweight='bold', color='green')
        axes[1, 1].axis('off')

        # Diff√©rence (erreurs)
        difference = (mask_pred_array != mask_true_array).astype(float)
        axes[1, 2].imshow(difference, cmap='Reds')
        axes[1, 2].set_title("Diff√©rences (erreurs en rouge)", fontsize=14, fontweight='bold', color='red')
        axes[1, 2].axis('off')

        # Calculer m√©triques
        metrics = SegmentationMetrics()
        iou = metrics.calculate_iou(mask_pred_array, mask_true_array, 18).mean()
        dice = metrics.calculate_dice(mask_pred_array, mask_true_array, 18).mean()
        pixel_acc = metrics.calculate_pixel_accuracy(mask_pred_array, mask_true_array)

        plt.suptitle(f"{Path(image_path).name}\nIoU: {iou:.3f} | Dice: {dice:.3f} | Accuracy: {pixel_acc:.3f}", 
                     fontsize=16, fontweight='bold')

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')

        plt.show()

    
    def show_color_guide(self, save_path="segmentation_results/guide_couleurs.png"):
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.axis('off')

        from matplotlib.patches import Rectangle

        y_pos = 0.95
        for label_id, label_name in self.label_map.items():
            if label_name != "Background":
                color = self.colors[label_id][:3]

                # Rectangle color√©
                rect = Rectangle((0.1, y_pos-0.04), 0.1, 0.03, 
                               facecolor=color, edgecolor='black', linewidth=1)
                ax.add_patch(rect)

                # Texte
                ax.text(0.25, y_pos-0.025, label_name, 
                       fontsize=14, va='center')

                y_pos -= 0.055

        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_title("Code Couleur des V√™tements", 
                    fontsize=18, fontweight='bold', pad=20)

        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.show()
        print(f"‚úÖ Guide sauvegard√©: {save_path}")
    
    def _print_statistics(self, mask_array):
        unique_labels = np.unique(mask_array)
        total_pixels = mask_array.size
        
        print("\nüìä Statistiques de segmentation:")
        print("-" * 60)
        
        for label_id in unique_labels:
            if label_id in self.label_map:
                count = np.sum(mask_array == label_id)
                percentage = (count / total_pixels) * 100
                label_name = self.label_map[label_id]
                
                if label_name != "Background" and percentage > 1:
                    print(f"  {label_name:15s}: {percentage:5.2f}% ({count:,} pixels)")

def find_images(folder_path):
        folder = Path(folder_path)

        if not folder.exists():
            print(f"‚ùå Le dossier '{folder_path}' n'existe pas!")
            return []

        # Trouver tous les fichiers .png (insensible √† la casse)
        image_files = []

        for file in folder.iterdir():
            if file.is_file():
                # V√©rifier l'extension en minuscules
                if file.suffix.lower() == '.png':
                    image_files.append(file)

        return sorted(image_files)
    
    

def process_with_comparison(folder_path, mask_folder, hf_token, save_results=False, max_images=None):
        
        api = ClothesSegmentationAPI(hf_token)
        image_files = find_images(folder_path)

        if max_images:
            image_files = image_files[:max_images]

        print(f"üîç {len(image_files)} image(s) √† traiter avec comparaison\n")

        results_folder = None
        if save_results:
            results_folder = Path("segmentation_results/comparisons")
            results_folder.mkdir(parents=True, exist_ok=True)

        success_count = 0
        for idx, image_path in enumerate(image_files, 1):
            print(f"\n{'='*70}")
            print(f"üñºÔ∏è  Image {idx}/{len(image_files)}: {image_path.name}")
            print(f"{'='*70}")

            # Chercher le masque correspondant
            mask_path = Path(mask_folder) / image_path.name

            if not mask_path.exists():
                print(f"‚ö†Ô∏è Masque non trouv√©: {mask_path.name}")
                continue

            try:
                # Pr√©diction
                print("‚è≥ Segmentation...")
                mask_pred = api.segment_image(str(image_path))

                # Charger masque r√©el
                mask_true = Image.open(mask_path)

                print("‚úÖ Comparaison...")

                save_path = None
                if save_results:
                    save_path = results_folder / f"comparison_{image_path.stem}.png"

                api.visualize_comparison(str(image_path), mask_pred, mask_true, save_path)
                success_count += 1

            except Exception as e:
                print(f"‚ùå Erreur: {e}")
                continue

        print(f"\n‚úÖ {success_count}/{len(image_files)} comparaisons effectu√©es")


class SegmentationMetrics:
    @staticmethod
    def calculate_iou(pred_mask, true_mask, num_classes):
        ious = []
        for cls in range(num_classes):
            pred_cls = (pred_mask == cls)
            true_cls = (true_mask == cls)
            intersection = np.logical_and(pred_cls, true_cls).sum()
            union = np.logical_or(pred_cls, true_cls).sum()
            iou = intersection / union if union > 0 else 0
            ious.append(iou)
        return np.array(ious)

    @staticmethod
    def calculate_dice(pred_mask, true_mask, num_classes):
        dice_scores = []
        for cls in range(num_classes):
            pred_cls = (pred_mask == cls)
            true_cls = (true_mask == cls)
            intersection = np.logical_and(pred_cls, true_cls).sum()
            total = pred_cls.sum() + true_cls.sum()
            dice = (2 * intersection) / total if total > 0 else 0
            dice_scores.append(dice)
        return np.array(dice_scores)

    @staticmethod
    def calculate_pixel_accuracy(pred_mask, true_mask):
        correct = (pred_mask == true_mask).sum()
        total = pred_mask.size
        return correct / total

# ============================================================================
# EX√âCUTION PRINCIPALE
# ============================================================================

# token API Hugging Face
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("Le token HF_TOKEN n'a pas √©t√© trouv√© dans le fichier .env")
    
# Chemin vers vos images
IMAGES_FOLDER = "data/images/IMG"

# Lancer l'analyse
print("üöÄ D√©marrage de l'analyse de segmentation de v√™tements\n")
process_with_comparison(
    folder_path="data/images/IMG",
    mask_folder="data/images/Mask",
    hf_token=HF_TOKEN,
    save_results=True  # Mettre False pour ne pas sauvegarder
)

üöÄ D√©marrage de l'analyse de segmentation de v√™tements

üîç 50 image(s) √† traiter avec comparaison


üñºÔ∏è  Image 1/50: image_0.png
‚ö†Ô∏è Masque non trouv√©: image_0.png

üñºÔ∏è  Image 2/50: image_1.png
‚ö†Ô∏è Masque non trouv√©: image_1.png

üñºÔ∏è  Image 3/50: image_10.png
‚ö†Ô∏è Masque non trouv√©: image_10.png

üñºÔ∏è  Image 4/50: image_11.png
‚ö†Ô∏è Masque non trouv√©: image_11.png

üñºÔ∏è  Image 5/50: image_12.png
‚ö†Ô∏è Masque non trouv√©: image_12.png

üñºÔ∏è  Image 6/50: image_13.png
‚ö†Ô∏è Masque non trouv√©: image_13.png

üñºÔ∏è  Image 7/50: image_14.png
‚ö†Ô∏è Masque non trouv√©: image_14.png

üñºÔ∏è  Image 8/50: image_15.png
‚ö†Ô∏è Masque non trouv√©: image_15.png

üñºÔ∏è  Image 9/50: image_16.png
‚ö†Ô∏è Masque non trouv√©: image_16.png

üñºÔ∏è  Image 10/50: image_17.png
‚ö†Ô∏è Masque non trouv√©: image_17.png

üñºÔ∏è  Image 11/50: image_18.png
‚ö†Ô∏è Masque non trouv√©: image_18.png

üñºÔ∏è  Image 12/50: image_19.png
‚ö†Ô∏è Masque non trouv√©: image_19

In [34]:
# V√©rifier les noms dans le dossier Mask
mask_folder = Path("data/images/Mask")
mask_files = sorted([f.name for f in mask_folder.glob("*.png")])

print("üìÑ Premiers masques trouv√©s:")
for m in mask_files[:10]:
    print(f"  {m}")

üìÑ Premiers masques trouv√©s:
  mask_0.png
  mask_1.png
  mask_10.png
  mask_11.png
  mask_12.png
  mask_13.png
  mask_14.png
  mask_15.png
  mask_16.png
  mask_17.png
