# 🎭 Face Concern Detector - Complete Interactive Demo

## 🚀 Project Overview

This notebook implements an **end-to-end face concern detection system** using deep learning to detect and visualize facial skin concerns including:

- 🔴 **Acne**: Inflammatory skin conditions and blemishes
- 👁️ **Dark Circles**: Discoloration under the eyes  
- 🟡 **Redness**: Skin irritation, rosacea, or inflammation
- 📏 **Wrinkles**: Fine lines and aging signs

## ⭐ Key Features

- **Multi-label Classification**: Detects multiple concerns simultaneously
- **Face Detection**: Automatic face detection and alignment using MTCNN
- **Explainable AI**: GradCAM visualizations showing prediction reasoning
- **Mac Optimized**: MPS acceleration for Apple Silicon (M1/M2)
- **Dual Dataset**: Combined training from multiple Kaggle sources
- **Visual Overlays**: Semi-transparent heatmaps and confidence scores (0-100%)

## 🎯 Technology Stack

- **Model**: ResNet18 with pretrained ImageNet weights
- **Face Detection**: MTCNN for robust face alignment
- **Visualization**: GradCAM for explainable predictions
- **Framework**: PyTorch with MPS support
- **Dataset**: Combined Kaggle datasets for comprehensive training

---

**Let's build an amazing face concern detection system! 🚀**

## 📦 Section 1: Import Required Libraries and Setup Environment

Let's start by importing all the necessary libraries and checking our system capabilities.

In [None]:
# Import essential libraries
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Add project root to path
sys.path.append('..')

# Deep learning and computer vision
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Computer vision and image processing
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import seaborn as sns

# Face detection
from mtcnn import MTCNN

# Data processing
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split

# Progress bars and utilities
from tqdm import tqdm
import time
import json
import shutil
from pathlib import Path

# Dataset download
import kagglehub

print("✅ All libraries imported successfully!")

# Check PyTorch version and device capabilities
print(f"\n🔧 System Information:")
print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")

# Device detection with detailed info
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print(f"🚀 Using Apple Silicon MPS acceleration!")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"🚀 Using CUDA GPU: {torch.cuda.get_device_name()}")
else:
    device = torch.device("cpu")
    print(f"💻 Using CPU")

print(f"Selected device: {device}")

# Create directory structure
directories = [
    'data/raw',
    'data/processed', 
    'data/sample_images',
    'models/saved_weights',
    'outputs/visualizations',
    'outputs/logs',
    'outputs/results'
]

for directory in directories:
    os.makedirs(directory, exist_ok=True)
    
print(f"\n📁 Created {len(directories)} directories for project organization")

## 📥 Section 2: Download and Combine Multiple Kaggle Datasets

We'll download and combine two complementary Kaggle datasets to create a comprehensive training set:

1. **Acne-Wrinkles-Spots Classification** - 600 images covering acne, wrinkles, and spots
2. **Skin Defects Dataset** - Additional images for acne, redness, and dark circles

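Neither dataset covers all four concerns on its own: the first has no dark-circle images and the second has no wrinkle images. Combining them gives every concern at least one source of positive examples. Images from the "spots" folder are mapped to the redness label.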
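
The folder-to-label step can be sketched in isolation. This is a minimal illustration, assuming a folder-per-class layout; `to_multi_hot` and `merge_labels` are hypothetical helper names, not functions from the notebook. Because each folder maps to a single concern, every image starts with exactly one positive label, and the other three default to 0 even if the concern is visible in the photo.

```python
# Label order matches the notebook's four concerns.
CONCERNS = ["acne", "dark_circles", "redness", "wrinkles"]

# Folder name -> concern. Mirrors the notebook's mapping, including the
# 'spots' -> 'redness' simplification.
FOLDER_TO_CONCERN = {
    "acne": "acne",
    "wrinkles": "wrinkles",
    "spots": "redness",
    "bags_under_eyes": "dark_circles",
}


def to_multi_hot(folder_name):
    """Return a {concern: 0/1} dict with a single 1 for the mapped concern."""
    concern = FOLDER_TO_CONCERN.get(folder_name)
    if concern is None:
        raise KeyError(f"Unmapped folder: {folder_name!r}")
    return {name: int(name == concern) for name in CONCERNS}


def merge_labels(*label_dicts):
    """Union several label dicts, e.g. for an image filed under two folders."""
    return {name: max(d[name] for d in label_dicts) for name in CONCERNS}


print(to_multi_hot("spots"))
print(merge_labels(to_multi_hot("acne"), to_multi_hot("wrinkles")))
```

A helper like `merge_labels` only matters if the same image can be matched across folders; the notebook itself does not attempt that.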

In [None]:
def download_and_combine_datasets():
    """
    Download both Kaggle datasets and combine them into a unified format
    """
    print("🔄 Starting dataset download and preparation...")
    
    # Dataset 1: Acne-Wrinkles-Spots Classification
    print("\n📥 Downloading Acne-Wrinkles-Spots dataset...")
    try:
        acne_spots_path = kagglehub.dataset_download(
            "ranvijaybalbir/acne-wrinkles-spots-classification"
        )
        print(f"✅ Dataset 1 downloaded to: {acne_spots_path}")
    except Exception as e:
        print(f"❌ Error downloading dataset 1: {e}")
        return None, None
    
    # Dataset 2: Skin Defects Dataset
    print("\n📥 Downloading Skin Defects dataset...")
    try:
        skin_defects_path = kagglehub.dataset_download(
            "trainingdatapro/skin-defects-acne-redness-and-bags-under-the-eyes"
        )
        print(f"✅ Dataset 2 downloaded to: {skin_defects_path}")
    except Exception as e:
        print(f"❌ Error downloading dataset 2: {e}")
        return acne_spots_path, None
    
    return acne_spots_path, skin_defects_path

def process_combined_datasets(acne_spots_path, skin_defects_path, output_dir='data/processed'):
    """
    Process and combine both datasets into unified format
    """
    os.makedirs(output_dir, exist_ok=True)
    annotations = []
    
    print("\n🔄 Processing Dataset 1: Acne-Wrinkles-Spots...")
    
    # Process first dataset
    dataset1_mapping = {
        'acne': 'acne',
        'wrinkles': 'wrinkles',
        'spots': 'redness'  # Map spots to redness
    }
    
    for kaggle_cat, our_cat in dataset1_mapping.items():
        src_dir = os.path.join(acne_spots_path, kaggle_cat)
        if os.path.exists(src_dir):
            image_files = [f for f in os.listdir(src_dir) 
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            
            print(f"  Found {len(image_files)} images for {kaggle_cat}")
            
            for img_file in image_files:
                # Create unique filename
                unique_name = f"ds1_{kaggle_cat}_{img_file}"
                
                # Copy image
                src_path = os.path.join(src_dir, img_file)
                dst_path = os.path.join(output_dir, unique_name)
                shutil.copy2(src_path, dst_path)
                
                # Create annotation
                labels = {
                    'acne': 1 if our_cat == 'acne' else 0,
                    'dark_circles': 0,  # Not in this dataset
                    'redness': 1 if our_cat == 'redness' else 0,
                    'wrinkles': 1 if our_cat == 'wrinkles' else 0
                }
                
                annotations.append({
                    'image_name': unique_name,
                    'source': 'acne_spots',
                    'original_category': kaggle_cat,
                    **labels
                })
    
    print("\n🔄 Processing Dataset 2: Skin Defects...")
    
    # Process second dataset (try multiple directory structures)
    dataset2_mapping = {
        'acne': 'acne',
        'redness': 'redness', 
        'bags_under_eyes': 'dark_circles',
        'bags-under-eyes': 'dark_circles',
        'dark_circles': 'dark_circles'
    }
    
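    # If the second download failed, skip it instead of crashing on a None path
    if skin_defects_path is None:
        print("  ⚠️ Dataset 2 unavailable, skipping")
        dataset2_mapping = {}
    
    # Several mapping keys are aliases for one folder; import each folder only once
    seen_dirs = set()
    
    for kaggle_cat, our_cat in dataset2_mapping.items():
        # Try different possible directory structures
        possible_paths = [
            os.path.join(skin_defects_path, kaggle_cat),
            os.path.join(skin_defects_path, kaggle_cat.replace('_', '-')),
            os.path.join(skin_defects_path, 'data', kaggle_cat),
            os.path.join(skin_defects_path, 'images', kaggle_cat),
            os.path.join(skin_defects_path, kaggle_cat.title())
        ]
        
        src_dir = None
        for path in possible_paths:
            if os.path.exists(path):
                src_dir = path
                break
        
        if src_dir and os.path.realpath(src_dir) in seen_dirs:
            continue  # Already imported under an alias key
        
        if src_dir:
            seen_dirs.add(os.path.realpath(src_dir))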
            image_files = [f for f in os.listdir(src_dir) 
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            
            print(f"  Found {len(image_files)} images for {kaggle_cat}")
            
            for img_file in image_files:
                # Create unique filename
                unique_name = f"ds2_{kaggle_cat}_{img_file}"
                
                # Copy image
                src_path = os.path.join(src_dir, img_file)
                dst_path = os.path.join(output_dir, unique_name)
                shutil.copy2(src_path, dst_path)
                
                # Create annotation
                labels = {
                    'acne': 1 if our_cat == 'acne' else 0,
                    'dark_circles': 1 if our_cat == 'dark_circles' else 0,
                    'redness': 1 if our_cat == 'redness' else 0,
                    'wrinkles': 0  # Not in this dataset
                }
                
                annotations.append({
                    'image_name': unique_name,
                    'source': 'skin_defects',
                    'original_category': kaggle_cat,
                    **labels
                })
        else:
            print(f"  ⚠️ No directory found for {kaggle_cat}")
    
    # Save combined annotations
    df = pd.DataFrame(annotations)
    output_csv = os.path.join(output_dir, 'combined_annotations.csv')
    df.to_csv(output_csv, index=False)
    
    # Print statistics
    print(f"\n📊 Combined Dataset Statistics:")
    print(f"Total images: {len(df)}")
    print(f"From Dataset 1 (Acne-Wrinkles-Spots): {len(df[df['source'] == 'acne_spots'])}")
    print(f"From Dataset 2 (Skin Defects): {len(df[df['source'] == 'skin_defects'])}")
    print(f"\n🏷️ Concern Distribution:")
    
    concern_stats = {}
    for concern in ['acne', 'dark_circles', 'redness', 'wrinkles']:
        count = df[concern].sum()
        concern_stats[concern] = count
        print(f"  {concern.replace('_', ' ').title()}: {count} images")
    
    print(f"\n✅ Combined dataset saved to: {output_dir}")
    print(f"✅ Annotations saved to: {output_csv}")
    
    return df, concern_stats

# Download and process datasets
acne_spots_path, skin_defects_path = download_and_combine_datasets()

if acne_spots_path:
    combined_df, stats = process_combined_datasets(acne_spots_path, skin_defects_path)
    
    # Visualize dataset distribution
    plt.figure(figsize=(12, 5))
    
    # Subplot 1: Concern distribution
    plt.subplot(1, 2, 1)
    concerns = list(stats.keys())
    counts = list(stats.values())
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
    
    plt.bar(concerns, counts, color=colors, alpha=0.8)
    plt.title('Skin Concern Distribution', fontsize=14, fontweight='bold')
    plt.xlabel('Skin Concerns')
    plt.ylabel('Number of Images')
    plt.xticks(rotation=45)
    
    for i, v in enumerate(counts):
        plt.text(i, v + 5, str(v), ha='center', fontweight='bold')
    
    # Subplot 2: Dataset sources
    plt.subplot(1, 2, 2)
    source_counts = combined_df['source'].value_counts()
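    # Derive labels from the data: value_counts() orders by frequency, and one source may be missing
    source_names = {'acne_spots': 'Acne-Wrinkles-Spots', 'skin_defects': 'Skin Defects'}
    pie_labels = [source_names.get(s, s) for s in source_counts.index]
    plt.pie(source_counts.values, labels=pie_labels,
            autopct='%1.1f%%', colors=['#FF9999', '#66B2FF'][:len(pie_labels)], startangle=90)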
    plt.title('Dataset Sources', fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig('outputs/visualizations/dataset_distribution.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"\n🎯 Dataset preparation complete! Ready for training with {len(combined_df)} total images.")
else:
    print("❌ Failed to download datasets. Please check your internet connection and kagglehub setup.")

## 👤 Section 3: Face Detection and Preprocessing Pipeline

Now let's implement the face detection and preprocessing pipeline. MTCNN (Multi-task Cascaded Convolutional Networks) locates faces; the best detection is then cropped with a margin and resized to the model's input size.
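
The two geometric steps of the pipeline, picking one detection and padding its box, can be sketched without MTCNN. This is a minimal illustration on hand-written detections; `expand_box` and `pick_best_face` are hypothetical names, and the dict shape (`'box'` as `[x, y, w, h]`, `'confidence'`) is the one the notebook reads from the detector.

```python
def expand_box(box, margin, img_width, img_height):
    """Pad an (x, y, w, h) box by `margin` pixels, clamped to the image."""
    x, y, w, h = box
    x1 = max(0, x - margin)
    y1 = max(0, y - margin)
    x2 = min(img_width, x + w + margin)
    y2 = min(img_height, y + h + margin)
    return x1, y1, x2, y2


def pick_best_face(faces, min_size=40, min_confidence=0.8):
    """Keep detections that pass both filters, then rank by confidence * area."""
    valid = [
        f for f in faces
        if f["box"][2] >= min_size and f["box"][3] >= min_size
        and f["confidence"] > min_confidence
    ]
    if not valid:
        return None
    return max(valid, key=lambda f: f["confidence"] * f["box"][2] * f["box"][3])


# Made-up detections for illustration only.
faces = [
    {"box": [10, 5, 100, 120], "confidence": 0.99},
    {"box": [300, 40, 30, 30], "confidence": 0.999},  # rejected: too small
    {"box": [150, 20, 200, 220], "confidence": 0.85},
]
best = pick_best_face(faces)
print(best["box"])
print(expand_box(best["box"], margin=20, img_width=400, img_height=230))
```

Multiplying confidence by area means a large, slightly less certain face can outrank a small, very certain one. Dividing the area by a constant, as the notebook does, rescales the scores but does not change the ranking.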

In [None]:
class FacePreprocessor:
    """
    Advanced face detection and preprocessing pipeline using MTCNN
    """
    
    def __init__(self, image_size=224, margin=20, min_face_size=40):
        """
        Initialize face preprocessor
        
        Args:
            image_size: Target size for face images (224x224 for ResNet)
            margin: Margin around detected face in pixels
            min_face_size: Minimum face size to consider valid
        """
        # Default constructor for portability: mtcnn 0.1.x does not accept a `device` argument
        self.detector = MTCNN()
        self.image_size = image_size
        self.margin = margin
        self.min_face_size = min_face_size
        
        print(f"🎯 FacePreprocessor initialized:")
        print(f"   Image size: {image_size}x{image_size}")
        print(f"   Margin: {margin}px")
        print(f"   Min face size: {min_face_size}px")
    
    def detect_faces(self, image):
        """
        Detect all faces in image with confidence scores
        
        Args:
            image: PIL Image or numpy array
            
        Returns:
            List of detected faces with bounding boxes and confidence
        """
        if isinstance(image, Image.Image):
            image_array = np.array(image)
        else:
            image_array = image
        
        try:
            results = self.detector.detect_faces(image_array)
            
            # Filter by minimum face size and confidence
            valid_faces = []
            for face in results:
                bbox = face['box']
                confidence = face['confidence']
                
                # Check minimum size
                if bbox[2] >= self.min_face_size and bbox[3] >= self.min_face_size:
                    # Check confidence: keep detections above 0.8 (clear faces usually score above 0.9)
                    if confidence > 0.8:
                        valid_faces.append(face)
            
            return valid_faces
            
        except Exception as e:
            print(f"Face detection error: {e}")
            return []
    
    def get_best_face(self, faces):
        """
        Select the best face from detected faces (highest confidence + largest size)
        """
        if not faces:
            return None
        
        # Score faces based on confidence and size
        scored_faces = []
        for face in faces:
            bbox = face['box']
            confidence = face['confidence']
            face_area = bbox[2] * bbox[3]  # width * height
            
            # Combined score: confidence * normalized_area
            score = confidence * (face_area / 10000)  # Normalize area
            scored_faces.append((score, face))
        
        # Return face with highest score
        return max(scored_faces, key=lambda x: x[0])[1]
    
    def crop_and_align_face(self, image, face_info):
        """
        Crop the detected face with a margin and resize it to the target size.
        Despite the method name, no landmark-based rotation is applied.
        """
        if isinstance(image, Image.Image):
            img_array = np.array(image)
        else:
            img_array = image
        
        bbox = face_info['box']
        x, y, w, h = bbox
        
        # Add margin while staying within image bounds
        img_height, img_width = img_array.shape[:2]
        
        x1 = max(0, x - self.margin)
        y1 = max(0, y - self.margin)
        x2 = min(img_width, x + w + self.margin)
        y2 = min(img_height, y + h + self.margin)
        
        # Crop face
        face_crop = img_array[y1:y2, x1:x2]
        
        # Convert back to PIL Image
        face_image = Image.fromarray(face_crop)
        
        # Resize to target size
        face_resized = face_image.resize((self.image_size, self.image_size), Image.LANCZOS)
        
        return face_resized, (x1, y1, x2, y2)
    
    def preprocess_image(self, image_path):
        """
        Complete preprocessing pipeline: detect -> crop -> resize
        
        Returns:
            Tuple: (processed_face_image, detection_info) or (None, error_msg)
        """
        try:
            # Load image
            if isinstance(image_path, (str, os.PathLike)):  # accept pathlib.Path as well as str
                image = Image.open(image_path).convert('RGB')
            else:
                image = image_path.convert('RGB')
            
            # Detect faces
            faces = self.detect_faces(image)
            
            if not faces:
                return None, "No face detected"
            
            # Get best face
            best_face = self.get_best_face(faces)
            
            if not best_face:
                return None, "No suitable face found"
            
            # Crop and align
            face_image, crop_coords = self.crop_and_align_face(image, best_face)
            
            detection_info = {
                'confidence': best_face['confidence'],
                'bbox': best_face['box'],
                'crop_coords': crop_coords,
                'num_faces_detected': len(faces)
            }
            
            return face_image, detection_info
            
        except Exception as e:
            return None, f"Preprocessing error: {str(e)}"

def create_data_transforms(train=True, image_size=224):
    """
    Create data augmentation transforms optimized for face images
    """
    if train:
        # Training augmentations (careful not to distort facial features too much)
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=10, fill=0),
            transforms.ColorJitter(
                brightness=0.2,
                contrast=0.2, 
                saturation=0.2,
                hue=0.1
            ),
            transforms.RandomAffine(
                degrees=0,
                translate=(0.05, 0.05),
                scale=(0.95, 1.05),
                fill=0
            ),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # ImageNet pretrained means
                std=[0.229, 0.224, 0.225]    # ImageNet pretrained stds
            )
        ])
    else:
        # Validation/test transforms (no augmentation)
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
    
    return transform

# Initialize face preprocessor
face_preprocessor = FacePreprocessor(image_size=224, margin=20)

# Test face detection on a sample (if any processed images exist)
processed_dir = 'data/processed'
if os.path.exists(processed_dir):
    sample_images = [f for f in os.listdir(processed_dir) 
                    if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    
    if sample_images:
        # Test on first few images
        print(f"\n🧪 Testing face detection on sample images...")
        
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.flatten()
        
        test_count = min(6, len(sample_images))
        
        for i in range(test_count):
            img_path = os.path.join(processed_dir, sample_images[i])
            
            # Process image
            face_img, detection_info = face_preprocessor.preprocess_image(img_path)
            
            if face_img is not None:
                axes[i].imshow(face_img)
                axes[i].set_title(
                    f"✅ Face Detected\nConf: {detection_info['confidence']:.2f}", 
                    color='green', fontsize=10
                )
            else:
                # Show original image if face detection failed
                original = Image.open(img_path)
                axes[i].imshow(original)
                axes[i].set_title(f"❌ {detection_info}", color='red', fontsize=10)
            
            axes[i].axis('off')
        
        # Hide unused subplots
        for i in range(test_count, len(axes)):
            axes[i].axis('off')
        
        plt.suptitle('Face Detection Test Results', fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.savefig('outputs/visualizations/face_detection_test.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        print(f"✅ Face detection tested on {test_count} images")
    else:
        print("⚠️ No sample images found for face detection testing")
else:
    print("⚠️ No processed directory found yet")

## 🏷️ Section 4: Dataset Preparation and Multi-Label Annotation

Let's create our custom dataset class for multi-label skin concern classification and properly split our data.
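
A plain shuffle-and-slice split keeps per-concern proportions only approximately, which matters when one concern has few images. The sketch below shows one way to stratify, assuming each image carries a single label (true here, since labels come from folder names). `stratified_split` is a hypothetical helper, not part of the notebook, and the label counts are made up.

```python
import random
from collections import defaultdict


def stratified_split(labels, train_ratio=0.7, val_ratio=0.15, seed=42):
    """Split indices so each label keeps roughly the same train/val/test ratio.

    labels: one hashable label per sample (e.g. the single positive concern).
    Returns three sorted lists of sample indices.
    """
    by_label = defaultdict(list)
    for index, label in enumerate(labels):
        by_label[label].append(index)

    rng = random.Random(seed)
    train, val, test = [], [], []
    for label in sorted(by_label):  # sorted so the result is reproducible
        indices = by_label[label]
        rng.shuffle(indices)
        n_train = int(len(indices) * train_ratio)
        n_val = int(len(indices) * val_ratio)
        train += indices[:n_train]
        val += indices[n_train:n_train + n_val]
        test += indices[n_train + n_val:]  # remainder goes to test
    return sorted(train), sorted(val), sorted(test)


# Toy, imbalanced label list (not the real dataset).
labels = ["acne"] * 100 + ["wrinkles"] * 40 + ["dark_circles"] * 20
train, val, test = stratified_split(labels)
print(len(train), len(val), len(test))
```

The returned index lists could be passed to `DataFrame.iloc` to build the three annotation frames. For images with several positive labels, a dedicated multi-label stratification method would be needed instead.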

In [None]:
class SkinConcernDataset(Dataset):
    """
    Custom PyTorch Dataset for multi-label skin concern classification
    """
    
    def __init__(self, data_dir, annotations_df=None, transform=None, face_preprocessor=None):
        """
        Args:
            data_dir: Directory containing images
            annotations_df: DataFrame with image annotations
            transform: Data augmentation transforms
            face_preprocessor: FacePreprocessor instance for face detection
        """
        self.data_dir = data_dir
        self.transform = transform
        self.face_preprocessor = face_preprocessor
        
        if annotations_df is not None:
            self.annotations = annotations_df.reset_index(drop=True)
            self.concern_labels = ['acne', 'dark_circles', 'redness', 'wrinkles']
        else:
            # Load all images if no annotations provided
            self.annotations = None
            self.image_files = [f for f in os.listdir(data_dir) 
                               if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    
    def __len__(self):
        if self.annotations is not None:
            return len(self.annotations)
        else:
            return len(self.image_files)
    
    def __getitem__(self, idx):
        """
        Get a sample from the dataset
        
        Returns:
            If annotations available: (image_tensor, label_tensor, metadata)
            If no annotations: (image_tensor, image_name)
        """
        if self.annotations is not None:
            # Training/validation mode with labels
            row = self.annotations.iloc[idx]
            img_name = row['image_name']
            img_path = os.path.join(self.data_dir, img_name)
            
            # Load and preprocess image
            if self.face_preprocessor:
                face_img, detection_info = self.face_preprocessor.preprocess_image(img_path)
                if face_img is None:
                    # Fallback to original image if face detection fails
                    face_img = Image.open(img_path).convert('RGB')
                    detection_info = {'confidence': 0, 'fallback': True}
            else:
                face_img = Image.open(img_path).convert('RGB')
                detection_info = {'no_preprocessing': True}
            
            # Apply transforms
            if self.transform:
                image_tensor = self.transform(face_img)
            else:
                image_tensor = transforms.ToTensor()(face_img)
            
            # Get multi-label targets
            labels = torch.tensor([
                row['acne'],
                row['dark_circles'], 
                row['redness'],
                row['wrinkles']
            ], dtype=torch.float32)
            
            metadata = {
                'image_name': img_name,
                'source': row.get('source', 'unknown'),
                'detection_info': detection_info
            }
            
            return image_tensor, labels, metadata
        
        else:
            # Inference mode without labels
            img_name = self.image_files[idx]
            img_path = os.path.join(self.data_dir, img_name)
            
            # Load image
            image = Image.open(img_path).convert('RGB')
            
            # Apply transforms
            if self.transform:
                image_tensor = self.transform(image)
            else:
                image_tensor = transforms.ToTensor()(image)
            
            return image_tensor, img_name

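def split_dataset_smart(df, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42):
    """
    Shuffle the dataset and slice it into train/val/test splits.

    The split is random, not stratified: per-concern balance is only
    approximate and is printed below for inspection. The test split receives
    whatever remains after train and val, so `test_ratio` is informational.
    """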
    print(f"📊 Splitting dataset: {train_ratio:.0%} train, {val_ratio:.0%} val, {test_ratio:.0%} test")
    
    # Shuffle the dataset
    df_shuffled = df.sample(frac=1, random_state=random_state).reset_index(drop=True)
    
    # Calculate split sizes
    n_total = len(df_shuffled)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    
    # Split the data
    train_df = df_shuffled[:n_train].copy()
    val_df = df_shuffled[n_train:n_train + n_val].copy()
    test_df = df_shuffled[n_train + n_val:].copy()
    
    # Verify splits
    print(f"✅ Split completed:")
    print(f"   Train: {len(train_df)} images ({len(train_df)/n_total:.1%})")
    print(f"   Val:   {len(val_df)} images ({len(val_df)/n_total:.1%})")
    print(f"   Test:  {len(test_df)} images ({len(test_df)/n_total:.1%})")
    
    # Check concern distribution in each split
    concerns = ['acne', 'dark_circles', 'redness', 'wrinkles']
    
    print(f"\n📈 Concern distribution across splits:")
    for concern in concerns:
        train_count = train_df[concern].sum()
        val_count = val_df[concern].sum()
        test_count = test_df[concern].sum()
        total_count = train_count + val_count + test_count
        
        print(f"   {concern.replace('_', ' ').title()}:")
        print(f"     Train: {train_count} ({train_count/total_count:.1%})")
        print(f"     Val:   {val_count} ({val_count/total_count:.1%})")
        print(f"     Test:  {test_count} ({test_count/total_count:.1%})")
    
    return train_df, val_df, test_df

def save_dataset_splits(train_df, val_df, test_df, base_path='data/processed'):
    """
    Save dataset splits to CSV files
    """
    train_path = os.path.join(base_path, 'train_annotations.csv')
    val_path = os.path.join(base_path, 'val_annotations.csv') 
    test_path = os.path.join(base_path, 'test_annotations.csv')
    
    train_df.to_csv(train_path, index=False)
    val_df.to_csv(val_path, index=False)
    test_df.to_csv(test_path, index=False)
    
    print(f"💾 Dataset splits saved:")
    print(f"   Train: {train_path}")
    print(f"   Val:   {val_path}")
    print(f"   Test:  {test_path}")
    
    return train_path, val_path, test_path

# Prepare dataset splits (if combined_df exists from previous section)
if 'combined_df' in locals() and combined_df is not None:
    print("🔄 Preparing dataset splits...")
    
    # Split the dataset
    train_df, val_df, test_df = split_dataset_smart(
        combined_df, 
        train_ratio=0.7, 
        val_ratio=0.15, 
        test_ratio=0.15
    )
    
    # Save splits
    train_path, val_path, test_path = save_dataset_splits(train_df, val_df, test_df)
    
    # Create dataset instances
    print(f"\n🏗️ Creating dataset instances...")
    
    # Create transforms
    train_transform = create_data_transforms(train=True)
    val_transform = create_data_transforms(train=False)
    
    # Create datasets (without face preprocessing for now to test basic functionality)
    train_dataset = SkinConcernDataset(
        data_dir='data/processed',
        annotations_df=train_df,
        transform=train_transform,
        face_preprocessor=None  # We'll add this later for training
    )
    
    val_dataset = SkinConcernDataset(
        data_dir='data/processed',
        annotations_df=val_df,
        transform=val_transform,
        face_preprocessor=None
    )
    
    test_dataset = SkinConcernDataset(
        data_dir='data/processed',
        annotations_df=test_df,
        transform=val_transform,
        face_preprocessor=None
    )
    
    print(f"✅ Dataset instances created:")
    print(f"   Train dataset: {len(train_dataset)} samples")
    print(f"   Val dataset:   {len(val_dataset)} samples") 
    print(f"   Test dataset:  {len(test_dataset)} samples")
    
    # Test dataset loading
    print(f"\n🧪 Testing dataset loading...")
    
    # Test one sample from training set
    sample_img, sample_labels, sample_metadata = train_dataset[0]
    
    print(f"Sample image shape: {sample_img.shape}")
    print(f"Sample labels: {sample_labels}")
    print(f"Sample metadata: {sample_metadata}")
    
    # Visualize a few samples
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    
    concern_names = ['Acne', 'Dark Circles', 'Redness', 'Wrinkles'] 
    colors = ['red', 'blue', 'orange', 'purple']
    
    for i in range(min(8, len(train_dataset))):  # avoid IndexError on very small datasets
        row = i // 4
        col = i % 4
        
        # Get sample
        img_tensor, labels, metadata = train_dataset[i]
        
        # Convert tensor to displayable image
        img_display = img_tensor.permute(1, 2, 0)
        img_display = img_display * torch.tensor([0.229, 0.224, 0.225]) + torch.tensor([0.485, 0.456, 0.406])
        img_display = torch.clamp(img_display, 0, 1)
        
        axes[row, col].imshow(img_display)
        
        # Create title with detected concerns
        detected = [concern_names[j] for j, label in enumerate(labels) if label == 1]
        title = f"Sample {i+1}\n{', '.join(detected) if detected else 'None'}"
        axes[row, col].set_title(title, fontsize=10)
        axes[row, col].axis('off')
    
    plt.suptitle('Training Dataset Samples', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig('outputs/visualizations/dataset_samples.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    print("✅ Dataset preparation complete!")
    
else:
    print("⚠️ Combined dataset not found. Please run the dataset download section first.")

## 🧠 Section 5: Model Architecture - ResNet18 Multi-Label Classifier

Let's implement our ResNet18-based multi-label classifier optimized for skin concern detection with Mac MPS acceleration.
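
What makes the classifier multi-label is the output stage: each logit passes through its own sigmoid, so the four probabilities are independent and need not sum to 1. The sketch below illustrates that stage and the standard positive-weighted binary cross-entropy in plain Python. It is not the notebook's `MultiLabelLoss`; the function names are hypothetical and the logits are made up.

```python
import math

CONCERNS = ["acne", "dark_circles", "redness", "wrinkles"]


def sigmoid(z):
    """Map one logit to an independent probability in (0, 1).

    Not numerically stable for extreme negative logits; fine for a demo.
    """
    return 1.0 / (1.0 + math.exp(-z))


def predict_concerns(logits, threshold=0.5):
    """Return the concerns whose probability is strictly above the threshold."""
    probs = [sigmoid(z) for z in logits]
    return [name for name, p in zip(CONCERNS, probs) if p > threshold]


def weighted_bce(logits, targets, pos_weights=None):
    """Mean binary cross-entropy over labels; pos_weights scales the positive term."""
    pos_weights = pos_weights or [1.0] * len(logits)
    total = 0.0
    for z, y, w in zip(logits, targets, pos_weights):
        p = sigmoid(z)
        total += -(w * y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)


logits = [2.0, -1.0, 0.0, 3.0]  # made-up network outputs
print(predict_concerns(logits))
print(weighted_bce(logits, targets=[1, 0, 0, 1]))
```

A common choice for a label's positive weight is the ratio of negative to positive training examples, which raises the cost of missing a rare concern.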

In [None]:
class SkinConcernDetector(nn.Module):
    """
    ResNet18-based multi-label classifier for skin concern detection
    Optimized for Mac with MPS support
    """
    
    def __init__(self, num_classes=4, pretrained=True, dropout_rate=0.5):
        """
        Args:
            num_classes: Number of skin concerns (4: acne, dark_circles, redness, wrinkles)
            pretrained: Use ImageNet pretrained weights
            dropout_rate: Dropout rate for regularization
        """
        super(SkinConcernDetector, self).__init__()
        
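        # Load ResNet18; `weights` replaces the deprecated `pretrained` flag (torchvision >= 0.13)
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet18(weights=weights)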
        
        # Get number of features from the last layer
        num_features = self.backbone.fc.in_features  # 512 for ResNet18
        
        # Replace the final fully connected layer with multi-label classifier
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes)  # No activation here
        )
        
        # Sigmoid for multi-label classification (applied in forward)
        self.sigmoid = nn.Sigmoid()
        
        # Store config
        self.num_classes = num_classes
        self.concern_labels = ['acne', 'dark_circles', 'redness', 'wrinkles']
        
        print(f"🏗️ SkinConcernDetector initialized:")
        print(f"   Backbone: ResNet18 (pretrained: {pretrained})")
        print(f"   Input features: {num_features}")
        print(f"   Output classes: {num_classes}")
        print(f"   Dropout rate: {dropout_rate}")
    
    def forward(self, x):
        """
        Forward pass through the network
        
        Args:
            x: Input tensor [batch_size, 3, 224, 224]
            
        Returns:
            Probabilities for each concern [batch_size, num_classes]
        """
        # Get logits from backbone
        logits = self.backbone(x)
        
        # Apply sigmoid for multi-label probabilities
        probabilities = self.sigmoid(logits)
        
        return probabilities
    
    def get_features(self, x):
        """
        Extract intermediate features for GradCAM visualization
        
        Returns:
            Features from layer4 (before global average pooling)
        """
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        features = self.backbone.layer4(x)  # [batch_size, 512, 7, 7]
        
        return features
    
    def get_predictions(self, x, threshold=0.5):
        """
        Get binary predictions and confidence scores
        
        Args:
            x: Input tensor
            threshold: Classification threshold
            
        Returns:
            Dictionary with predictions and scores
        """
        with torch.no_grad():
            probabilities = self.forward(x)
            
            # Convert to binary predictions
            binary_preds = (probabilities > threshold).float()
            
            # Convert to lists for easier handling
            probs_list = probabilities.cpu().numpy()
            preds_list = binary_preds.cpu().numpy()
            
            results = {
                'probabilities': probs_list,
                'predictions': preds_list,
                'threshold': threshold
            }
            
            return results

class MultiLabelLoss(nn.Module):
    """
    Binary Cross Entropy Loss for multi-label classification with class weighting
    """
    
    def __init__(self, pos_weights=None, reduction='mean'):
        """
        Args:
            pos_weights: Tensor of positive class weights for handling imbalance
            reduction: Loss reduction method ('mean', 'sum', 'none')
        """
        super(MultiLabelLoss, self).__init__()
        
        if pos_weights is not None:
            self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights, reduction=reduction)
        else:
            self.criterion = nn.BCELoss(reduction=reduction)
        
        self.use_logits = pos_weights is not None
        
    def forward(self, predictions, targets):
        """
        Calculate multi-label loss
        
        Args:
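            predictions: Model probabilities after sigmoid [batch_size, num_classes]
            targets: Ground truth labels [batch_size, num_classes]
            
        Returns:
            Loss value
        """
        if self.use_logits:
            # SkinConcernDetector.forward() already applies sigmoid, but
            # BCEWithLogitsLoss expects raw logits, so invert the sigmoid first.
            eps = 1e-6
            probs = predictions.clamp(eps, 1 - eps)
            logits = torch.log(probs) - torch.log1p(-probs)
            return self.criterion(logits, targets)
        else:
            # BCELoss operates directly on probabilities
            return self.criterion(predictions, targets)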

def calculate_class_weights(train_df, concerns=('acne', 'dark_circles', 'redness', 'wrinkles')):
    """
    Calculate class weights to handle imbalanced data
    """
    pos_counts = []
    neg_counts = []
    
    for concern in concerns:
        pos_count = train_df[concern].sum()
        neg_count = len(train_df) - pos_count
        pos_counts.append(pos_count)
        neg_counts.append(neg_count)
    
    # Calculate positive weights (higher weight for less frequent classes)
    pos_weights = []
    for pos_count, neg_count in zip(pos_counts, neg_counts):
        if pos_count > 0:
            weight = neg_count / pos_count
        else:
            weight = 1.0
        pos_weights.append(weight)
    
    pos_weights_tensor = torch.tensor(pos_weights, dtype=torch.float32)
    
    print(f"📊 Class weights calculated:")
    for i, (concern, weight) in enumerate(zip(concerns, pos_weights)):
        print(f"   {concern}: {weight:.2f} (pos: {pos_counts[i]}, neg: {neg_counts[i]})")
    
    return pos_weights_tensor

# Create model instance
print("🏗️ Creating SkinConcernDetector model...")

model = SkinConcernDetector(
    num_classes=4,
    pretrained=True,
    dropout_rate=0.5
).to(device)

# Calculate class weights if training data is available
if 'train_df' in locals():
    class_weights = calculate_class_weights(train_df)
    class_weights = class_weights.to(device)
    
    # Create loss function with class weights
    criterion = MultiLabelLoss(pos_weights=class_weights)
else:
    # Default loss without weights
    criterion = MultiLabelLoss()
    class_weights = None

# Move loss to device
criterion = criterion.to(device)

# Print model summary
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params = count_parameters(model)
print(f"\n📈 Model Summary:")
print(f"   Total trainable parameters: {total_params:,}")
print(f"   Model size: ~{total_params * 4 / (1024**2):.1f} MB (FP32)")
print(f"   Device: {device}")

# Test forward pass
print(f"\n🧪 Testing forward pass...")
dummy_input = torch.randn(2, 3, 224, 224).to(device)

with torch.no_grad():
    dummy_output = model(dummy_input)
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {dummy_output.shape}")
    print(f"Output range: [{dummy_output.min():.3f}, {dummy_output.max():.3f}]")

# Visualize model architecture
def visualize_model_architecture():
    """Create a simple visualization of the model architecture"""
    
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    
    # Define architecture components
    components = [
        ("Input Image", "3x224x224", "#FF9999"),
        ("ResNet18 Backbone", "Feature Extraction", "#66B2FF"),  
        ("Conv Layers", "64→128→256→512", "#99FF99"),
        ("Global Avg Pool", "512x7x7 → 512", "#FFCC99"),
        ("Dropout + FC", "512 → 256", "#FF99CC"),
        ("ReLU + BatchNorm", "Activation + Normalization", "#99FFCC"),
        ("Final FC", "256 → 4", "#CCCCFF"),
        ("Sigmoid", "Multi-label Output", "#FFCCCC"),
        ("Output", "4 Probabilities", "#CCFFCC")
    ]
    
    # Draw architecture
    y_positions = np.linspace(0.9, 0.1, len(components))
    
    for i, (name, desc, color) in enumerate(components):
        # Draw component box
        rect = patches.FancyBboxPatch(
            (0.1, y_positions[i] - 0.03), 0.8, 0.06,
            boxstyle="round,pad=0.01",
            facecolor=color,
            edgecolor='black',
            linewidth=1
        )
        ax.add_patch(rect)
        
        # Add text
        ax.text(0.15, y_positions[i], name, fontsize=12, fontweight='bold', va='center')
        ax.text(0.85, y_positions[i], desc, fontsize=10, va='center', ha='right')
        
        # Draw arrows (except for last component)
        if i < len(components) - 1:
            ax.arrow(0.5, y_positions[i] - 0.04, 0, -0.03, 
                    head_width=0.02, head_length=0.01, fc='black', ec='black')
    
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('SkinConcernDetector Architecture', fontsize=16, fontweight='bold', pad=20)
    ax.axis('off')
    
    # Add concern labels
    concern_text = "Output Classes: " + " | ".join(model.concern_labels)
    ax.text(0.5, 0.02, concern_text, ha='center', fontsize=11, 
           bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue"))
    
    plt.tight_layout()
    os.makedirs('outputs/visualizations', exist_ok=True)  # savefig fails if the directory is missing
    plt.savefig('outputs/visualizations/model_architecture.png', dpi=300, bbox_inches='tight')
    plt.show()

# Visualize the architecture
visualize_model_architecture()

print("✅ Model architecture created and tested successfully!")

## 🚀 Section 6: Training Pipeline with Mac Optimization

Now let's implement the complete training loop with Mac MPS optimization, learning rate scheduling, and comprehensive monitoring.
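
The interaction between learning-rate reduction and early stopping is easier to see in isolation. The sketch below replays a list of validation losses through a simplified plateau rule: halve the learning rate once more than `patience` consecutive epochs pass without improvement, and stop after `stop_after` epochs without a new best. `simulate_schedule` is a hypothetical helper written for illustration; it omits the improvement threshold and cooldown that PyTorch's `ReduceLROnPlateau` supports, so treat it as an approximation rather than the library's exact behavior.

```python
def simulate_schedule(val_losses, lr=1e-4, factor=0.5, patience=3,
                      min_lr=1e-7, stop_after=8):
    """Replay validation losses through a simplified plateau rule.

    Hypothetical helper for illustration only; not the PyTorch implementation.
    Returns the learning rate in effect after each epoch and the epoch at
    which early stopping triggered (None if it never did).
    """
    best = float("inf")
    bad_for_lr = 0     # epochs without improvement since the last LR cut
    bad_for_stop = 0   # epochs without improvement since the last best loss
    lrs = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            bad_for_lr = 0
            bad_for_stop = 0
        else:
            bad_for_lr += 1
            bad_for_stop += 1
        # Cut the LR once the plateau has lasted longer than `patience` epochs
        if bad_for_lr > patience:
            lr = max(lr * factor, min_lr)
            bad_for_lr = 0
        lrs.append(lr)
        if bad_for_stop >= stop_after:
            return lrs, epoch
    return lrs, None


# Three improving epochs followed by a long plateau
losses = [0.70, 0.60, 0.55] + [0.56] * 10
lrs, stopped_at = simulate_schedule(losses)
print(lrs)
print(stopped_at)
```

With `patience=3` and an 8-epoch stopping window, the learning rate is cut twice before training stops, which is why the scheduler patience is kept well below the early-stopping window.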

In [None]:
class SkinConcernTrainer:
    """
    Comprehensive training manager for SkinConcernDetector with Mac optimization
    """
    
    def __init__(self, model, train_loader, val_loader, criterion, device, config=None):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.device = device
        
        # Default training configuration
        self.config = config or {
            'learning_rate': 1e-4,
            'weight_decay': 1e-4,
            'patience': 5,
            'min_lr': 1e-7,
            'num_epochs': 25,
            'save_frequency': 5
        }
        
        # Initialize optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=self.config['learning_rate'],
            weight_decay=self.config['weight_decay']
        )
        
        # Learning rate scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            patience=self.config['patience'],
            factor=0.5,
            min_lr=self.config['min_lr']  # `verbose` dropped: deprecated in recent PyTorch; train() prints the LR each epoch
        )
        
        # Training history
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'learning_rates': []
        }
        
        self.best_val_loss = float('inf')
        self.best_val_acc = 0.0
        self.epochs_without_improvement = 0
        
        print(f"🏋️ Trainer initialized:")
        print(f"   Optimizer: AdamW (lr={self.config['learning_rate']}, wd={self.config['weight_decay']})")
        print(f"   Scheduler: ReduceLROnPlateau (patience={self.config['patience']})")
        print(f"   Device: {device}")
    
    def train_epoch(self):
        """Train for one epoch"""
        self.model.train()
        running_loss = 0.0
        running_corrects = 0
        total_samples = 0
        
        # Progress bar
        pbar = tqdm(self.train_loader, desc='Training', leave=False)
        
        for batch_idx, (images, targets, metadata) in enumerate(pbar):
            images = images.to(self.device)
            targets = targets.to(self.device)
            
            # Zero gradients
            self.optimizer.zero_grad()
            
            # Forward pass
            outputs = self.model(images)
            loss = self.criterion(outputs, targets)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            # Optimizer step
            self.optimizer.step()
            
            # Statistics
            running_loss += loss.item() * images.size(0)
            
            # Calculate accuracy (percentage of correctly predicted labels)
            preds = (outputs > 0.5).float()
            batch_accuracy = (preds == targets).float().mean()
            running_corrects += batch_accuracy * images.size(0)
            total_samples += images.size(0)
            
            # Update progress bar
            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{batch_accuracy:.3f}'
            })

        epoch_loss = running_loss / total_samples
        epoch_acc = running_corrects / total_samples
        
        return epoch_loss, epoch_acc.item()
    
    def validate_epoch(self):
        """Validate for one epoch"""
        self.model.eval()
        running_loss = 0.0
        running_corrects = 0
        total_samples = 0
        
        all_outputs = []
        all_targets = []
        
        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc='Validation', leave=False)
            
            for images, targets, metadata in pbar:
                images = images.to(self.device)
                targets = targets.to(self.device)
                
                # Forward pass
                outputs = self.model(images)
                loss = self.criterion(outputs, targets)
                
                # Statistics
                running_loss += loss.item() * images.size(0)
                
                # Accuracy
                preds = (outputs > 0.5).float()
                batch_accuracy = (preds == targets).float().mean()
                running_corrects += batch_accuracy * images.size(0)
                total_samples += images.size(0)
                
                # Store for detailed metrics
                all_outputs.extend(outputs.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
                
                # Update progress bar
                pbar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Acc': f'{batch_accuracy:.3f}'
                })
        
        epoch_loss = running_loss / total_samples
        epoch_acc = running_corrects / total_samples
        
        return epoch_loss, epoch_acc.item(), np.array(all_outputs), np.array(all_targets)
    
    def save_checkpoint(self, epoch, is_best=False):
        """Save model checkpoint"""
        os.makedirs('models/saved_weights', exist_ok=True)
        
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'best_val_acc': self.best_val_acc,
            'history': self.history,
            'config': self.config
        }
        
        # Save latest checkpoint
        latest_path = 'models/saved_weights/latest_checkpoint.pth'
        torch.save(checkpoint, latest_path)
        
        # Save best checkpoint
        if is_best:
            best_path = 'models/saved_weights/best_model.pth'
            torch.save(checkpoint, best_path)
            print(f"💾 Best model saved! Val Loss: {self.best_val_loss:.4f}, Val Acc: {self.best_val_acc:.3f}")
    
    def plot_training_progress(self):
        """Plot training progress"""
        if not self.history['train_loss']:
            return
        
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))
        
        epochs = range(1, len(self.history['train_loss']) + 1)
        
        # Loss plot
        axes[0].plot(epochs, self.history['train_loss'], 'b-', label='Training Loss', linewidth=2)
        axes[0].plot(epochs, self.history['val_loss'], 'r-', label='Validation Loss', linewidth=2)
        axes[0].set_title('Training and Validation Loss', fontweight='bold')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # Accuracy plot
        axes[1].plot(epochs, self.history['train_acc'], 'b-', label='Training Accuracy', linewidth=2)
        axes[1].plot(epochs, self.history['val_acc'], 'r-', label='Validation Accuracy', linewidth=2)
        axes[1].set_title('Training and Validation Accuracy', fontweight='bold')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Accuracy')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        
        # Learning rate plot
        if self.history['learning_rates']:
            axes[2].semilogy(epochs, self.history['learning_rates'], 'g-', linewidth=2)
            axes[2].set_title('Learning Rate Schedule', fontweight='bold')
            axes[2].set_xlabel('Epoch')
            axes[2].set_ylabel('Learning Rate (log scale)')
            axes[2].grid(True, alpha=0.3)
        
        plt.tight_layout()
        os.makedirs('outputs/visualizations', exist_ok=True)  # savefig fails if the directory is missing
        plt.savefig('outputs/visualizations/training_progress.png', dpi=300, bbox_inches='tight')
        plt.show()
    
    def train(self, num_epochs=None):
        """Main training loop"""
        if num_epochs is None:
            num_epochs = self.config['num_epochs']
        
        print(f"🚀 Starting training for {num_epochs} epochs...")
        print(f"📊 Dataset sizes: Train={len(self.train_loader.dataset)}, Val={len(self.val_loader.dataset)}")
        
        start_time = time.time()
        
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 50)
            
            # Train
            train_loss, train_acc = self.train_epoch()
            
            # Validate
            val_loss, val_acc, val_outputs, val_targets = self.validate_epoch()
            
            # Update learning rate
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]['lr']
            
            # Store history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_acc'].append(val_acc)
            self.history['learning_rates'].append(current_lr)
            
            # Check for improvement
            improved = False
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_val_acc = val_acc
                self.epochs_without_improvement = 0
                improved = True
            else:
                self.epochs_without_improvement += 1
            
            # Print epoch results
            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f}")
            print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.3f}")
            print(f"Learning Rate: {current_lr:.2e}")
            print(f"Best Val Loss: {self.best_val_loss:.4f} | Best Val Acc: {self.best_val_acc:.3f}")
            
            if improved:
                print("🎉 New best model!")
            
            # Save checkpoint
            if (epoch + 1) % self.config['save_frequency'] == 0 or improved:
                self.save_checkpoint(epoch, is_best=improved)
            
            # Early stopping check
            if self.epochs_without_improvement >= 8:
                print(f"\n🛑 Early stopping after {epoch+1} epochs (no improvement for 8 epochs)")
                break
            
            # Plot progress every few epochs
            if (epoch + 1) % 5 == 0:
                self.plot_training_progress()
        
        training_time = time.time() - start_time
        print(f"\n✅ Training completed in {training_time/60:.1f} minutes!")
        print(f"📈 Final Results:")
        print(f"   Best Validation Loss: {self.best_val_loss:.4f}")
        print(f"   Best Validation Accuracy: {self.best_val_acc:.3f}")
        
        # Final plot
        self.plot_training_progress()
        
        return self.history

# Training configuration optimized for Mac
training_config = {
    'learning_rate': 1e-4,      # Conservative for pretrained model
    'weight_decay': 1e-4,       # Regularization
    'patience': 3,              # LR scheduler patience  
    'min_lr': 1e-7,            # Minimum learning rate
    'num_epochs': 20,           # Reasonable for small dataset
    'save_frequency': 3         # Save every 3 epochs
}

# Create data loaders (if datasets exist)
if 'train_dataset' in locals() and 'val_dataset' in locals():
    print("🔄 Creating data loaders...")
    
    # Optimized batch sizes for Mac
    batch_size = 16 if device.type == 'mps' else 32
    num_workers = 2 if device.type == 'mps' else 4
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(device.type == 'cuda'),  # pinned memory only helps CUDA transfers; MPS does not use it
        persistent_workers=(num_workers > 0)
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=(device.type == 'cuda'),  # pinned memory only helps CUDA transfers; MPS does not use it
        persistent_workers=(num_workers > 0)
    )
    
    print(f"✅ Data loaders created:")
    print(f"   Batch size: {batch_size}")
    print(f"   Num workers: {num_workers}")
    print(f"   Train batches: {len(train_loader)}")
    print(f"   Val batches: {len(val_loader)}")
    
    # Initialize trainer
    trainer = SkinConcernTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        device=device,
        config=training_config
    )
    
    print(f"\n🎯 Ready to start training! Run the next cell to begin.")
    
else:
    print("⚠️ Dataset not available. Please run previous sections to prepare data.")

# Training execution cell (separate for user control)
def start_training():
    """Start the training process"""
    if 'trainer' in globals():  # locals() inside a function never holds notebook-level names
        print("🚀 Starting training process...")
        history = trainer.train()
        return history
    else:
        print("❌ Trainer not initialized. Please run the setup cells first.")
        return None

print("\n" + "="*60)
print("🎯 TRAINING SETUP COMPLETE!")
print("="*60)
print("To start training, run: history = start_training()")
print("Rough training-time estimate on Mac M1/M2: ~15-30 minutes (depends on dataset size)")
print("="*60)

## 📊 Section 7: Model Evaluation and Metrics Calculation

Let's evaluate our trained model and calculate comprehensive performance metrics for each skin concern.
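
Each concern is scored as its own binary problem, so the per-class numbers all derive from four confusion counts. As a minimal worked example (the helper name `binary_metrics` and the toy labels are illustrative assumptions, not part of the notebook's pipeline), the same formulas can be checked by hand:

```python
def binary_metrics(y_true, y_pred):
    """Compute per-class metrics from 0/1 labels and 0/1 predictions.

    Hypothetical helper for illustration; zero denominators fall back to 0.0.
    """
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    accuracy = (tp + tn) / len(pairs) if pairs else 0.0
    return {
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "accuracy": accuracy, "precision": precision,
        "recall": recall, "specificity": specificity, "f1": f1,
    }


# Toy example: 3 positives and 5 negatives for a single concern
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0, 0, 0]
m = binary_metrics(y_true, y_pred)
print(m)
```

Here accuracy (0.75) looks better than precision and recall (both 2/3) because negatives dominate, which is the reason for reporting per-class precision, recall, and F1 alongside accuracy.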

In [None]:
class ModelEvaluator:
    """
    Comprehensive evaluation metrics for multi-label skin concern detection
    """
    
    def __init__(self, model, device, concern_labels=None):
        self.model = model
        self.device = device
        self.concern_labels = concern_labels or ['acne', 'dark_circles', 'redness', 'wrinkles']
    
    def evaluate_model(self, data_loader, threshold=0.5):
        """
        Comprehensive evaluation on a dataset
        """
        self.model.eval()
        
        all_outputs = []
        all_targets = []
        all_metadata = []
        
        print(f"🔍 Evaluating model on {len(data_loader.dataset)} samples...")
        
        with torch.no_grad():
            for images, targets, metadata in tqdm(data_loader, desc="Evaluating"):
                images = images.to(self.device)
                targets = targets.to(self.device)
                
                outputs = self.model(images)
                
                all_outputs.extend(outputs.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
                all_metadata.extend(metadata)
        
        # Convert to numpy arrays
        outputs_np = np.array(all_outputs)
        targets_np = np.array(all_targets)
        
        # Calculate metrics
        metrics = self.calculate_comprehensive_metrics(outputs_np, targets_np, threshold)
        
        return metrics, outputs_np, targets_np, all_metadata
    
    def calculate_comprehensive_metrics(self, outputs, targets, threshold=0.5):
        """
        Calculate comprehensive metrics for multi-label classification
        """
        # Binary predictions
        predictions = (outputs > threshold).astype(int)
        
        metrics = {}
        
        # Overall metrics
        overall_accuracy = (predictions == targets).mean()
        metrics['overall_accuracy'] = overall_accuracy
        
        # Per-class metrics
        per_class_metrics = {}
        
        for i, concern in enumerate(self.concern_labels):
            y_true = targets[:, i]
            y_pred = predictions[:, i]
            y_scores = outputs[:, i]
            
            # Basic metrics
            tp = np.sum((y_true == 1) & (y_pred == 1))
            tn = np.sum((y_true == 0) & (y_pred == 0))
            fp = np.sum((y_true == 0) & (y_pred == 1))
            fn = np.sum((y_true == 1) & (y_pred == 0))
            
            # Calculate metrics
            accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
            
            # AUC-ROC (undefined when only one class is present; fall back to chance level)
            try:
                from sklearn.metrics import roc_auc_score
                auc = roc_auc_score(y_true, y_scores)
            except (ImportError, ValueError):
                auc = 0.5
            
            per_class_metrics[concern] = {
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'specificity': specificity,
                'f1_score': f1,
                'auc_roc': auc,
                'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn,
                'support': int(np.sum(y_true))  # plain int so counts print as 12, not 12.0
            }
        
        metrics['per_class'] = per_class_metrics
        
        # Macro averages
        metrics['macro_precision'] = np.mean([m['precision'] for m in per_class_metrics.values()])
        metrics['macro_recall'] = np.mean([m['recall'] for m in per_class_metrics.values()])
        metrics['macro_f1'] = np.mean([m['f1_score'] for m in per_class_metrics.values()])
        metrics['macro_auc'] = np.mean([m['auc_roc'] for m in per_class_metrics.values()])
        
        return metrics
    
    def plot_confusion_matrices(self, outputs, targets, threshold=0.5):
        """
        Plot confusion matrices for each concern
        """
        predictions = (outputs > threshold).astype(int)
        
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        axes = axes.flatten()
        
        for i, concern in enumerate(self.concern_labels):
            y_true = targets[:, i]
            y_pred = predictions[:, i]
            
            # Create confusion matrix; fixing labels keeps it 2x2 even if one class is absent
            cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
            
            # Plot
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                       xticklabels=['No ' + concern, concern.title()],
                       yticklabels=['No ' + concern, concern.title()],
                       ax=axes[i])
            
            axes[i].set_title(f'{concern.replace("_", " ").title()} Confusion Matrix')
            axes[i].set_xlabel('Predicted')
            axes[i].set_ylabel('Actual')
        
        plt.tight_layout()
        os.makedirs('outputs/visualizations', exist_ok=True)  # savefig fails if the directory is missing
        plt.savefig('outputs/visualizations/confusion_matrices.png', dpi=300, bbox_inches='tight')
        plt.show()
    
    def plot_metrics_summary(self, metrics):
        """
        Plot comprehensive metrics summary
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        concerns = list(metrics['per_class'].keys())
        
        # Accuracy by concern
        accuracies = [metrics['per_class'][c]['accuracy'] for c in concerns]
        axes[0,0].bar(concerns, accuracies, color='skyblue', alpha=0.8)
        axes[0,0].set_title('Accuracy by Concern', fontweight='bold')
        axes[0,0].set_ylabel('Accuracy')
        axes[0,0].set_ylim(0, 1)
        for i, v in enumerate(accuracies):
            axes[0,0].text(i, v + 0.01, f'{v:.3f}', ha='center', fontweight='bold')
        
        # Precision, Recall, F1
        precisions = [metrics['per_class'][c]['precision'] for c in concerns]
        recalls = [metrics['per_class'][c]['recall'] for c in concerns]
        f1_scores = [metrics['per_class'][c]['f1_score'] for c in concerns]
        
        x = np.arange(len(concerns))
        width = 0.25
        
        axes[0,1].bar(x - width, precisions, width, label='Precision', alpha=0.8)
        axes[0,1].bar(x, recalls, width, label='Recall', alpha=0.8)
        axes[0,1].bar(x + width, f1_scores, width, label='F1-Score', alpha=0.8)
        
        axes[0,1].set_title('Precision, Recall, F1-Score by Concern', fontweight='bold')
        axes[0,1].set_ylabel('Score')
        axes[0,1].set_xticks(x)
        axes[0,1].set_xticklabels([c.replace('_', ' ').title() for c in concerns])
        axes[0,1].legend()
        axes[0,1].set_ylim(0, 1)
        
        # AUC-ROC
        aucs = [metrics['per_class'][c]['auc_roc'] for c in concerns]
        axes[1,0].bar(concerns, aucs, color='lightcoral', alpha=0.8)
        axes[1,0].set_title('AUC-ROC by Concern', fontweight='bold')
        axes[1,0].set_ylabel('AUC-ROC')
        axes[1,0].set_ylim(0, 1)
        axes[1,0].axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='Random')
        for i, v in enumerate(aucs):
            axes[1,0].text(i, v + 0.01, f'{v:.3f}', ha='center', fontweight='bold')
        
        # Support (number of positive samples)
        supports = [metrics['per_class'][c]['support'] for c in concerns]
        axes[1,1].bar(concerns, supports, color='lightgreen', alpha=0.8)
        axes[1,1].set_title('Positive Samples by Concern', fontweight='bold')
        axes[1,1].set_ylabel('Count')
        for i, v in enumerate(supports):
            axes[1,1].text(i, v + 1, str(v), ha='center', fontweight='bold')
        
        plt.tight_layout()
        os.makedirs('outputs/visualizations', exist_ok=True)  # savefig fails if the directory is missing
        plt.savefig('outputs/visualizations/metrics_summary.png', dpi=300, bbox_inches='tight')
        plt.show()
    
    def print_evaluation_report(self, metrics):
        """
        Print comprehensive evaluation report
        """
        print("\n" + "="*80)
        print("🎯 COMPREHENSIVE EVALUATION REPORT")
        print("="*80)
        
        print(f"\n📊 Overall Performance:")
        print(f"   Overall Accuracy: {metrics['overall_accuracy']:.3f}")
        print(f"   Macro Precision:  {metrics['macro_precision']:.3f}")
        print(f"   Macro Recall:     {metrics['macro_recall']:.3f}")
        print(f"   Macro F1-Score:   {metrics['macro_f1']:.3f}")
        print(f"   Macro AUC-ROC:    {metrics['macro_auc']:.3f}")
        
        print(f"\n📋 Per-Class Performance:")
        print(f"{'Concern':<15} {'Acc':<6} {'Prec':<6} {'Rec':<6} {'F1':<6} {'AUC':<6} {'Supp':<6}")
        print("-" * 65)
        
        for concern, metrics_dict in metrics['per_class'].items():
            print(f"{concern.replace('_', ' ').title():<15} "
                  f"{metrics_dict['accuracy']:<6.3f} "
                  f"{metrics_dict['precision']:<6.3f} "
                  f"{metrics_dict['recall']:<6.3f} "
                  f"{metrics_dict['f1_score']:<6.3f} "
                  f"{metrics_dict['auc_roc']:<6.3f} "
                  f"{metrics_dict['support']:<6}")
        
        print("\n" + "="*80)

# Load best model if available
def load_best_model():
    """Load the best trained model"""
    model_path = 'models/saved_weights/best_model.pth'
    
    if os.path.exists(model_path):
        print(f"📥 Loading best model from {model_path}...")
        checkpoint = torch.load(model_path, map_location=device)
        
        # Load model state
        model.load_state_dict(checkpoint['model_state_dict'])
        
        print(f"✅ Model loaded successfully!")
        print(f"   Trained for {checkpoint['epoch']+1} epochs")
        print(f"   Best validation loss: {checkpoint['best_val_loss']:.4f}")
        print(f"   Best validation accuracy: {checkpoint['best_val_acc']:.3f}")
        
        return True
    else:
        print(f"⚠️ No trained model found at {model_path}")
        print(f"   Please train the model first or check the path.")
        return False

# Evaluation execution
def run_comprehensive_evaluation():
    """Run comprehensive model evaluation"""
    
    # Load best model
    if not load_best_model():
        return None, None  # callers unpack two values: (val_metrics, evaluator)
    
    # Create evaluator
    evaluator = ModelEvaluator(model, device, model.concern_labels)
    
    # Evaluate on validation set
    if 'val_loader' in globals():  # locals() inside a function never holds notebook-level names
        print(f"\n🔍 Evaluating on validation set...")
        val_metrics, val_outputs, val_targets, val_metadata = evaluator.evaluate_model(val_loader)
        
        # Print report
        evaluator.print_evaluation_report(val_metrics)
        
        # Plot visualizations
        print(f"\n📊 Creating evaluation visualizations...")
        evaluator.plot_confusion_matrices(val_outputs, val_targets)
        evaluator.plot_metrics_summary(val_metrics)
        
        return val_metrics, evaluator
    else:
        print("⚠️ Validation data loader not available")
        return None, evaluator

print("\n" + "="*60)
print("📊 EVALUATION SETUP COMPLETE!")
print("="*60)
print("To run comprehensive evaluation, execute:")
print("val_metrics, evaluator = run_comprehensive_evaluation()")
print("="*60)

## 🔥 Section 8: GradCAM Implementation for Explainable AI

Now let's implement GradCAM (Gradient-weighted Class Activation Mapping). For each concern, it weights the feature maps of the last convolutional block by the gradient of that concern's score, producing a heatmap of the facial regions that most influenced the prediction.
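
Before the full class, it may help to see the core GradCAM arithmetic on toy numbers. The sketch below uses plain Python lists instead of tensors, and the helper name `grad_cam_toy` and its inputs are illustrative assumptions rather than part of the notebook's code. It averages each channel's gradients into a weight, sums the weighted activation maps, applies ReLU, and min-max normalizes the result.

```python
def grad_cam_toy(activations, gradients):
    """Toy Grad-CAM on nested lists shaped [C][H][W] (hypothetical helper)."""
    height, width = len(activations[0]), len(activations[0][0])

    # 1. Channel weights: spatial average of each channel's gradient map
    weights = [
        sum(sum(row) for row in grad_map) / (height * width)
        for grad_map in gradients
    ]

    # 2. Weighted sum of activation maps, then ReLU to keep positive evidence only
    cam = [
        [
            max(0.0, sum(w * act[i][j] for w, act in zip(weights, activations)))
            for j in range(width)
        ]
        for i in range(height)
    ]

    # 3. Min-max normalization to [0, 1]; a flat map becomes all zeros
    low = min(min(row) for row in cam)
    high = max(max(row) for row in cam)
    if high - low < 1e-8:
        return [[0.0] * width for _ in range(height)]
    return [[(value - low) / (high - low) for value in row] for row in cam]


# Two 2x2 channels: the first supports the class, the second argues against it
activations = [[[1.0, 0.0], [0.0, 0.0]],
               [[0.0, 0.0], [0.0, 2.0]]]
gradients = [[[2.0, 0.0], [0.0, 2.0]],      # mean gradient = 1.0
             [[-1.0, -1.0], [0.0, 0.0]]]    # mean gradient = -0.5

heatmap = grad_cam_toy(activations, gradients)
print(heatmap)  # [[1.0, 0.0], [0.0, 0.0]]
```

The top-left pixel stays hot because its channel has a positive weight, while the bottom-right activation ends up negative after weighting and is zeroed by ReLU.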

In [None]:
class MultiLabelGradCAM:
    """
    GradCAM implementation for multi-label skin concern classification
    """
    
    def __init__(self, model, target_layer_name='layer4'):
        """
        Initialize GradCAM
        
        Args:
            model: Trained SkinConcernDetector model
            target_layer_name: Name of the target layer for gradient extraction
        """
        self.model = model
        self.model.eval()
        
        # Get the target layer (ResNet18's layer4)
        self.target_layer = getattr(self.model.backbone, target_layer_name)
        
        # Hooks for storing gradients and activations
        self.gradients = {}
        self.activations = {}
        
        # Register hooks
        self.target_layer.register_forward_hook(self._save_activation)
        self.target_layer.register_full_backward_hook(self._save_gradient)
        
        self.concern_labels = model.concern_labels
        
        print(f"🔥 GradCAM initialized for {target_layer_name}")
        print(f"   Target layer output shape: Will be determined during forward pass")
        print(f"   Concerns: {', '.join(self.concern_labels)}")
    
    def _save_activation(self, module, input, output):
        """Hook to save forward activations"""
        self.activations['value'] = output.detach()
    
    def _save_gradient(self, module, grad_input, grad_output):
        """Hook to save backward gradients"""
        self.gradients['value'] = grad_output[0].detach()
    
    def generate_cam(self, input_tensor, target_class_idx):
        """
        Generate Class Activation Map for a specific concern
        
        Args:
            input_tensor: Input image tensor [1, 3, 224, 224]
            target_class_idx: Index of target concern (0-3)
            
        Returns:
            CAM heatmap as numpy array [H, W]
        """
        # Forward pass
        self.model.zero_grad()
        output = self.model(input_tensor)
        
        # Get the score for target class
        class_score = output[0, target_class_idx]
        
        # Backward pass (each call runs its own forward pass, so the graph need not be retained)
        class_score.backward()
        
        # Get gradients and activations
        gradients = self.gradients['value'][0]  # [C, H, W]
        activations = self.activations['value'][0]  # [C, H, W]
        
        # Global average pooling of gradients to get per-channel weights
        weights = torch.mean(gradients, dim=(1, 2))  # [C]
        
        # Weighted combination of activation maps, computed on the same device
        # as the activations (a CPU buffer would fail when the model is on MPS/CUDA)
        cam = (weights[:, None, None] * activations).sum(dim=0)  # [H, W]
        
        # Apply ReLU (only positive influences)
        cam = F.relu(cam)
        
        # Normalize CAM to [0, 1]; a constant map would divide by zero, so zero it instead
        cam_min, cam_max = cam.min(), cam.max()
        if cam_max - cam_min > 1e-8:
            cam = (cam - cam_min) / (cam_max - cam_min)
        else:
            cam = torch.zeros_like(cam)
        
        return cam.cpu().numpy()
    
    def generate_all_cams(self, input_tensor, threshold=0.5):
        """
        Generate CAMs for all concerns with predictions and scores
        
        Args:
            input_tensor: Input image tensor [1, 3, 224, 224]
            threshold: Classification threshold
            
        Returns:
            Dictionary with CAMs, scores, and predictions for each concern
        """
        # Get predictions
        with torch.no_grad():
            predictions = self.model(input_tensor)
            scores = predictions[0].cpu().numpy()
        
        results = {}
        
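        for i, concern in enumerate(self.concern_labels):
            # Generate CAM for this concern
            cam = self.generate_cam(input_tensor, i)
            
            # Distance from the 0.5 decision boundary drives the confidence label
            margin = abs(float(scores[i]) - 0.5)
            confidence_level = 'High' if margin > 0.3 else 'Medium' if margin > 0.1 else 'Low'
            
            # Store results (plain Python types so they print and serialize cleanly)
            results[concern] = {
                'cam': cam,
                'score': float(scores[i]),
                'predicted': bool(scores[i] > threshold),
                'confidence_level': confidence_level
            }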
        
        return results
    
    def apply_colormap_and_overlay(self, original_image, cam, alpha=0.4, colormap=cv2.COLORMAP_JET):
        """
        Apply colormap to CAM and overlay on original image
        
        Args:
            original_image: Original PIL image
            cam: CAM heatmap [H, W]
            alpha: Overlay transparency
            colormap: OpenCV colormap
            
        Returns:
            Overlayed image as numpy array
        """
        # Resize CAM to match image size
        h, w = original_image.size[1], original_image.size[0]  # PIL size is (width, height)
        cam_resized = cv2.resize(cam, (w, h))
        
        # Apply colormap
        cam_colored = cv2.applyColorMap(np.uint8(255 * cam_resized), colormap)
        cam_colored = cv2.cvtColor(cam_colored, cv2.COLOR_BGR2RGB)
        
        # Convert original image to numpy
        original_np = np.array(original_image)
        
        # Overlay
        overlayed = cam_colored * alpha + original_np * (1 - alpha)
        
        return np.uint8(overlayed)
    
    def visualize_all_concerns(self, original_image, input_tensor, threshold=0.5, save_path=None):
        """
        Create comprehensive visualization for all concerns
        
        Args:
            original_image: Original PIL image
            input_tensor: Preprocessed input tensor
            threshold: Classification threshold
            save_path: Optional path to save visualization
            
        Returns:
            Tuple of (Matplotlib figure, per-concern results dictionary)
        """
        # Generate CAMs
        results = self.generate_all_cams(input_tensor, threshold)
        
        # Create visualization
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        
        # Original image
        axes[0, 0].imshow(original_image)
        axes[0, 0].set_title('Original Image', fontsize=14, fontweight='bold')
        axes[0, 0].axis('off')
        
        # Summary of predictions
        detected_concerns = [concern for concern, data in results.items() if data['predicted']]
        summary_text = f"Detected Concerns: {len(detected_concerns)}\n"
        
        for concern, data in results.items():
            status = "✅ DETECTED" if data['predicted'] else "❌ Not Detected"
            confidence = data['confidence_level']
            summary_text += f"{concern.replace('_', ' ').title()}: {data['score']:.1%} ({confidence}) {status}\n"
        
        axes[0, 1].text(0.1, 0.5, summary_text, fontsize=11, verticalalignment='center',
                        bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue"))
        axes[0, 1].set_xlim(0, 1)
        axes[0, 1].set_ylim(0, 1)
        axes[0, 1].set_title('Prediction Summary', fontsize=14, fontweight='bold')
        axes[0, 1].axis('off')
        
        # Individual concern visualizations (one remaining grid cell per concern)
        positions = [(0, 2), (1, 0), (1, 1), (1, 2)]
        
        for idx, (concern, data) in enumerate(results.items()):
            row, col = positions[idx]
            
            # Create overlay
            overlay = self.apply_colormap_and_overlay(
                original_image, 
                data['cam'], 
                alpha=0.4,
                colormap=cv2.COLORMAP_JET
            )
            
            axes[row, col].imshow(overlay)
            
            # Title with prediction info
            title_color = 'green' if data['predicted'] else 'red'
            confidence_emoji = "🔥" if data['confidence_level'] == 'High' else "🟡" if data['confidence_level'] == 'Medium' else "🔵"
            
            title = f"{confidence_emoji} {concern.replace('_', ' ').title()}\n{data['score']:.1%} ({data['confidence_level']})"
            
            axes[row, col].set_title(title, fontsize=12, fontweight='bold', color=title_color)
            axes[row, col].axis('off')
        
        plt.suptitle('Face Concern Detection with GradCAM Visualization', 
                    fontsize=16, fontweight='bold', y=0.95)
        plt.tight_layout()
        
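        if save_path:
            # Create the output directory on first use so savefig does not fail
            os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"💾 Visualization saved to {save_path}")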
        
        return fig, results

class FaceConcernPredictor:
    """
    Complete end-to-end predictor with face detection, classification, and visualization
    """
    
    def __init__(self, model_path, device, face_preprocessor=None):
        """
        Initialize predictor with trained model
        """
        self.device = device
        
        # Load model
        self.model = SkinConcernDetector(num_classes=4).to(device)
        
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"✅ Model loaded from {model_path}")
        else:
            print(f"⚠️ Model file not found: {model_path}")
            print("   Using randomly initialized weights")
        
        self.model.eval()
        
        # Face preprocessor
        self.face_preprocessor = face_preprocessor or FacePreprocessor()
        
        # Data transform (same as validation)
        self.transform = create_data_transforms(train=False)
        
        # GradCAM
        self.gradcam = MultiLabelGradCAM(self.model)
        
        print(f"🎯 FaceConcernPredictor initialized and ready!")
    
    def predict_single_image(self, image_path, threshold=0.5, visualize=True, save_path=None):
        """
        Complete prediction pipeline for a single image
        
        Args:
            image_path: Path to image file
            threshold: Classification threshold
            visualize: Whether to create GradCAM visualization
            save_path: Optional path to save results
            
        Returns:
            Dictionary with predictions and visualizations
        """
        try:
            # Load and preprocess image
            original_image = Image.open(image_path).convert('RGB')
            
            # Face detection and preprocessing
            face_image, detection_info = self.face_preprocessor.preprocess_image(original_image)
            
            if face_image is None:
                return {
                    'error': f'Face detection failed: {detection_info}',
                    'image_path': image_path
                }
            
            # Prepare input tensor
            input_tensor = self.transform(face_image).unsqueeze(0).to(self.device)
            
            # Get predictions
            with torch.no_grad():
                predictions = self.model(input_tensor)
                scores = predictions[0].cpu().numpy()
            
            # Prepare results
            results = {
                'image_path': image_path,
                'face_detection': detection_info,
                'predictions': {},
                'detected_concerns': []
            }
            
            for i, concern in enumerate(self.model.concern_labels):
                score = float(scores[i])
                predicted = score > threshold
                
                results['predictions'][concern] = {
                    'score': score,
                    'predicted': predicted,
                    'confidence': score
                }
                
                if predicted:
                    results['detected_concerns'].append(concern)
            
            # Generate visualization if requested
            if visualize:
                fig, gradcam_results = self.gradcam.visualize_all_concerns(
                    face_image, input_tensor, threshold, save_path
                )
                
                results['visualization'] = fig
                results['gradcam_results'] = gradcam_results
            
            return results
            
        except Exception as e:
            return {
                'error': f'Prediction failed: {str(e)}',
                'image_path': image_path
            }

# Test GradCAM functionality
def test_gradcam():
    """Test GradCAM on a sample image"""
    
    # Create sample input
    sample_input = torch.randn(1, 3, 224, 224).to(device)
    
    # Initialize GradCAM
    gradcam = MultiLabelGradCAM(model)
    
    print("🧪 Testing GradCAM functionality...")
    
    # Test CAM generation
    try:
        cam = gradcam.generate_cam(sample_input, target_class_idx=0)
        print(f"✅ CAM generation successful!")
        print(f"   CAM shape: {cam.shape}")
        print(f"   CAM range: [{cam.min():.3f}, {cam.max():.3f}]")
        
        # Test all CAMs
        results = gradcam.generate_all_cams(sample_input)
        print(f"✅ Multi-label CAM generation successful!")
        print(f"   Generated CAMs for: {list(results.keys())}")
        
        return True
        
    except Exception as e:
        print(f"❌ GradCAM test failed: {e}")
        return False

# Test the implementation
print("🔥 Testing GradCAM implementation...")
if test_gradcam():
    print("\n" + "="*60)
    print("🎯 GRADCAM SETUP COMPLETE!")
    print("="*60)
    print("GradCAM is ready for generating explainable visualizations!")
    print("Use the FaceConcernPredictor class for end-to-end inference.")
    print("="*60)
else:
    print("❌ GradCAM setup failed. Please check the model and try again.")

## 🚀 Section 9: Inference Pipeline with Visual Overlays

Let's create the complete inference pipeline that processes new images and generates confidence scores with semi-transparent overlay masks.
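
The semi-transparent overlay is a per-pixel alpha blend of the heatmap color over the face image. As a rough sketch (the helper name `blend_pixel` is hypothetical, and it rounds where the notebook's `np.uint8` cast truncates), the arithmetic for a single RGB pixel looks like this:

```python
def blend_pixel(heat_rgb, image_rgb, alpha=0.4):
    """Alpha-blend one heatmap pixel over one image pixel (0-255 RGB tuples)."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be between 0.0 and 1.0")
    # alpha is the heatmap's share; the image keeps the remaining (1 - alpha)
    return tuple(
        int(round(alpha * heat + (1 - alpha) * pixel))
        for heat, pixel in zip(heat_rgb, image_rgb)
    )


red_heat = (255, 0, 0)       # a "hot" colormap value
gray_skin = (100, 100, 100)  # stand-in for an image pixel

print(blend_pixel(red_heat, gray_skin))             # (162, 60, 60)
print(blend_pixel(red_heat, gray_skin, alpha=0.0))  # (100, 100, 100)
```

Raising `alpha` makes the heatmap dominate, while a value around 0.4 (the default used in this notebook) keeps facial detail visible underneath.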

In [None]:
import shutil  # used by create_sample_test_images() below

# Initialize the predictor
model_path = 'models/saved_weights/best_model.pth'

# Create predictor instance
if os.path.exists(model_path):
    predictor = FaceConcernPredictor(
        model_path=model_path,
        device=device,
        face_preprocessor=face_preprocessor
    )
    print("✅ Predictor initialized with trained model")
else:
    # Use untrained model for demonstration
    predictor = FaceConcernPredictor(
        model_path='',  # Empty path will use random weights
        device=device,
        face_preprocessor=face_preprocessor
    )
    print("⚠️ Using untrained model for demonstration")

def create_sample_test_images():
    """
    Create sample test images from our dataset for demonstration
    """
    sample_dir = 'data/sample_images'
    os.makedirs(sample_dir, exist_ok=True)
    
    # Copy a few sample images from processed dataset
    processed_dir = 'data/processed'
    
    if os.path.exists(processed_dir):
        # Get some sample images
        all_images = [f for f in os.listdir(processed_dir) 
                     if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        
        # Sample images from different categories
        sample_images = []
        
        # Try to get diverse samples
        for prefix in ['ds1_acne', 'ds1_wrinkles', 'ds1_spots', 'ds2_acne', 'ds2_redness']:
            matching = [img for img in all_images if img.startswith(prefix)]
            if matching:
                sample_images.append(matching[0])
        
        # Add a few more random samples
        remaining = [img for img in all_images if img not in sample_images]
        sample_images.extend(remaining[:3])
        
        # Copy to sample directory
        copied_samples = []
        for i, img in enumerate(sample_images[:6]):  # Limit to 6 samples
            src_path = os.path.join(processed_dir, img)
            dst_path = os.path.join(sample_dir, f'sample_{i+1}_{img}')
            
            if os.path.exists(src_path):
                shutil.copy2(src_path, dst_path)
                copied_samples.append(dst_path)
        
        print(f"📸 Created {len(copied_samples)} sample test images in {sample_dir}")
        return copied_samples
    else:
        print("⚠️ No processed images available for sampling")
        return []

def batch_inference_demo(image_paths, threshold=0.5):
    """
    Demonstrate batch inference on multiple images
    """
    print(f"🔍 Running batch inference on {len(image_paths)} images...")
    
    results_summary = []
    
    # Process each image
    for i, img_path in enumerate(image_paths):
        print(f"\nProcessing image {i+1}/{len(image_paths)}: {os.path.basename(img_path)}")
        
        # Run prediction
        result = predictor.predict_single_image(
            img_path, 
            threshold=threshold, 
            visualize=True,
            save_path=f'outputs/visualizations/inference_result_{i+1}.png'
        )
        
        if 'error' in result:
            print(f"❌ {result['error']}")
            continue
        
        # Summary
        detected = result['detected_concerns']
        total_score = sum(data['score'] for data in result['predictions'].values())
        
        summary = {
            'image': os.path.basename(img_path),
            'detected_concerns': len(detected),
            'concerns_list': detected,
            'average_confidence': total_score / len(result['predictions']),
            'face_confidence': result['face_detection']['confidence']
        }
        
        results_summary.append(summary)
        
        print(f"✅ Detected {len(detected)} concerns: {', '.join(detected)}")
        print(f"   Average confidence: {summary['average_confidence']:.1%}")
        print(f"   Face detection confidence: {summary['face_confidence']:.2f}")
    
    return results_summary

def create_batch_summary_visualization(results_summary):
    """
    Create summary visualization of batch results
    """
    if not results_summary:
        print("No results to visualize")
        return
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Number of detected concerns per image
    images = [r['image'] for r in results_summary]
    concern_counts = [r['detected_concerns'] for r in results_summary]
    
    axes[0,0].bar(range(len(images)), concern_counts, color='skyblue', alpha=0.8)
    axes[0,0].set_title('Detected Concerns per Image', fontweight='bold')
    axes[0,0].set_ylabel('Number of Concerns')
    axes[0,0].set_xticks(range(len(images)))
    axes[0,0].set_xticklabels([f'Img {i+1}' for i in range(len(images))], rotation=45)
    
    for i, v in enumerate(concern_counts):
        axes[0,0].text(i, v + 0.1, str(v), ha='center', fontweight='bold')
    
    # Average confidence scores
    avg_confidences = [r['average_confidence'] for r in results_summary]
    axes[0,1].bar(range(len(images)), avg_confidences, color='lightcoral', alpha=0.8)
    axes[0,1].set_title('Average Confidence Scores', fontweight='bold')
    axes[0,1].set_ylabel('Confidence')
    axes[0,1].set_ylim(0, 1)
    axes[0,1].set_xticks(range(len(images)))
    axes[0,1].set_xticklabels([f'Img {i+1}' for i in range(len(images))], rotation=45)
    
    # Face detection confidence
    face_confidences = [r['face_confidence'] for r in results_summary]
    axes[1,0].bar(range(len(images)), face_confidences, color='lightgreen', alpha=0.8)
    axes[1,0].set_title('Face Detection Confidence', fontweight='bold')
    axes[1,0].set_ylabel('Confidence')
    axes[1,0].set_ylim(0, 1)
    axes[1,0].set_xticks(range(len(images)))
    axes[1,0].set_xticklabels([f'Img {i+1}' for i in range(len(images))], rotation=45)
    
    # Concern distribution across all images
    all_concerns = []
    for r in results_summary:
        all_concerns.extend(r['concerns_list'])
    
    from collections import Counter
    # Separate name so the per-image list `concern_counts` above is not shadowed
    concern_frequency = Counter(all_concerns)
    
    if concern_frequency:
        concerns = list(concern_frequency.keys())
        counts = list(concern_frequency.values())
        
        axes[1,1].bar(concerns, counts, color='gold', alpha=0.8)
        axes[1,1].set_title('Overall Concern Distribution', fontweight='bold')
        axes[1,1].set_ylabel('Frequency')
        axes[1,1].tick_params(axis='x', rotation=45)
        
        for i, v in enumerate(counts):
            axes[1,1].text(i, v + 0.1, str(v), ha='center', fontweight='bold')
    else:
        axes[1,1].text(0.5, 0.5, 'No concerns detected\nacross all images',
                      ha='center', va='center', fontsize=12)
        axes[1,1].set_xlim(0, 1)
        axes[1,1].set_ylim(0, 1)
        axes[1,1].set_title('Overall Concern Distribution', fontweight='bold')
    
    plt.tight_layout()
    os.makedirs('outputs/visualizations', exist_ok=True)  # ensure the output directory exists
    plt.savefig('outputs/visualizations/batch_inference_summary.png', dpi=300, bbox_inches='tight')
    plt.show()

def interactive_single_prediction():
    """
    Interactive function for single image prediction
    """
    print("🎯 Interactive Single Image Prediction")
    print("=" * 50)
    
    # Get available sample images
    sample_dir = 'data/sample_images'
    if os.path.exists(sample_dir):
        sample_images = [f for f in os.listdir(sample_dir) 
                        if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        
        if sample_images:
            print(f"Available sample images:")
            for i, img in enumerate(sample_images):
                print(f"  {i+1}. {img}")
            
            try:
                choice = int(input(f"\nSelect image (1-{len(sample_images)}): ")) - 1
                if 0 <= choice < len(sample_images):
                    selected_image = os.path.join(sample_dir, sample_images[choice])
                    
                    threshold = float(input("Enter threshold (0.0-1.0, default 0.5): ") or "0.5")
                    if not 0.0 <= threshold <= 1.0:
                        print("Threshold must be between 0.0 and 1.0")
                        return
                    
                    print(f"\n🔍 Processing {sample_images[choice]}...")
                    
                    result = predictor.predict_single_image(
                        selected_image,
                        threshold=threshold,
                        visualize=True,
                        save_path='outputs/visualizations/interactive_result.png'
                    )
                    
                    if 'error' in result:
                        print(f"❌ {result['error']}")
                        return
                    
                    # Display results
                    print(f"\n🎉 Results for {sample_images[choice]}:")
                    print(f"Face Detection Confidence: {result['face_detection']['confidence']:.2f}")
                    print("\nSkin Concern Predictions:")
                    
                    for concern, data in result['predictions'].items():
                        status = "✅ DETECTED" if data['predicted'] else "❌ Not Detected"
                        print(f"  {concern.replace('_', ' ').title()}: {data['score']:.1%} {status}")
                    
                    if result['detected_concerns']:
                        print(f"\n🔥 Summary: Detected {len(result['detected_concerns'])} concerns")
                        print(f"Concerns: {', '.join(result['detected_concerns'])}")
                    else:
                        print("\n✨ Summary: No skin concerns detected above threshold")
                    
                    plt.show()  # Show the visualization
                    
                else:
                    print("Invalid selection")
                    
            except ValueError:
                print("Invalid input")
        else:
            print("No sample images available")
    else:
        print("Sample images directory not found")

# Create sample images for testing
sample_paths = create_sample_test_images()

print("\n" + "="*60)
print("🚀 INFERENCE PIPELINE READY!")
print("="*60)
print("Available functions:")
print("1. batch_inference_demo(sample_paths) - Run batch inference")
print("2. interactive_single_prediction() - Interactive single prediction")
print("3. predictor.predict_single_image(path) - Direct prediction")
print("="*60)

# Demo execution
if sample_paths:
    print(f"\n🎯 Running batch inference demo on {len(sample_paths)} sample images...")
    
    # Run batch inference
    results = batch_inference_demo(sample_paths, threshold=0.5)
    
    # Create summary visualization
    if results:
        print("\n📊 Creating batch summary visualization...")
        create_batch_summary_visualization(results)
        
        print("\n✅ Batch inference complete!")
        print(f"   Processed: {len(results)} images")
        print(f"   Average concerns per image: {np.mean([r['detected_concerns'] for r in results]):.1f}")
        print(f"   Results saved to: outputs/visualizations/")
    
else:
    print("\n⚠️ No sample images available for demonstration")
    print("Please ensure datasets are downloaded and processed first")

## 🎮 Section 10: Interactive Demo & Real-Time Analysis

This final section provides an interactive demonstration of the complete face concern detection system with various testing scenarios and real-time capabilities.