In [None]:
# ============================================================================
# CONFIGURATION - Adjust these parameters as needed
# ============================================================================
# Model: MaxViT-Tiny (maxvit_tiny_tf_512.in1k)
# Pretrained weights: pretrained_models/model_tiny.pth
# Encoder channels: [64, 64, 128, 256, 512]
# ============================================================================

class Config:
    """Notebook-wide hyper-parameters, read as class attributes via the ``config`` instance."""

    # Dataset parameters
    IMG_SIZE = (512, 512)  # Image dimensions (height, width)
    TRAIN_VAL_SPLIT = 0.9  # Train/validation split ratio
    
    # DataLoader parameters
    BATCH_SIZE = 8  # Batch size for training
    NUM_WORKERS = 4  # Set to 0 for Windows, increase for Linux/Mac
    PIN_MEMORY = True  # Set to True if using GPU
    
    # Model architecture parameters
    
    IN_CHANNELS = 3  # RGB input channels
    OUT_CHANNELS = 1  # Binary segmentation output
    CBAM_REDUCTION = 16  # Reduction factor for CBAM attention modules (passed to the model as `r`; the current UNet builds no CBAM module)
    SKIP_PRETRAINED = False  # If True, skip loading pretrained weights (train from scratch or use checkpoint only)
    FREEZE_ENCODER = False  # If True, freeze encoder layers (use pretrained embeddings, save compute)
    FREEZE_ENCODER_STAGES = [1,2,3]  # List of encoder stages to freeze: [1,2,3,4] or [1,2] etc. None = use FREEZE_ENCODER (freeze all or none)
    # Stage mapping: 1=stem+stage1, 2=stage2, 3=stage3, 4=stage4
    # Examples: [1,2] freezes first 2 stages, [1,2,3,4] freezes all, None uses FREEZE_ENCODER flag
    
    # Model Architecture Options
    USE_DEEP_SUPERVISION = False  # Enable deep supervision with auxiliary outputs
    USE_GRADIENT_CHECKPOINTING = False  # Enable gradient checkpointing to save memory (trades compute for VRAM)
    USE_ATROUS_PYRAMID_BOTTLENECK = False  # If True, use atrous (dilated) pyramid bottleneck (ASPP-style) instead of standard bottleneck
    USE_IDENTITY_BOTTLENECK = True # If True, use identity bottleneck (no processing, just pass through); takes precedence over the atrous option
    
    # Training parameters
    NUM_EPOCHS = 40
    LEARNING_RATE = 4e-5
    OPTIMIZER = 'adam'  # 'adam', 'adamw', or 'sgd'
    WEIGHT_DECAY = 0  # L2 regularization weight
    
    # Loss function parameters
    FOCAL_WEIGHT = 0.95  # Weight for Focal loss (segmentation)
    USE_DICE_LOSS = False # Enable/disable Dice loss (if False, only Focal loss is used)
    DICE_WEIGHT = 0.05  # Weight for Dice loss (segmentation) - only used if USE_DICE_LOSS=True
    DICE_SMOOTH = 1e-6  # Smoothing factor for Dice loss (avoid division by zero)
    FOCAL_ALPHA = 0.33  # Alpha parameter for Focal loss (class balancing)
    FOCAL_GAMMA = 2.0  # Gamma parameter for Focal loss (focusing parameter)
    
    # Early stopping parameters
    PATIENCE = 99  # Number of epochs to wait before early stopping
    MIN_DELTA = 0.0  # Minimum change to qualify as improvement
    
    # Output directories
    OUTPUT_DIR = 'outputs_unet'  # Main output directory
    VIZ_DIR = 'outputs_unet/visualizations'  # Directory to save visualizations
    CHECKPOINT_DIR = 'outputs_unet/checkpoints'  # Directory to save model checkpoints
    
    # Visualization parameters
    NUM_VIZ_IMAGES = 4  # Number of images to visualize per epoch
    
    # Model saving parameters
    # NOTE: MODEL_SAVE_NAME lacks the '-tiny' tag used by MODEL_COMPLETE_NAME; kept as-is so
    # previously saved best checkpoints are still discovered.
    MODEL_SAVE_NAME = 'MaxVit-square-Unet-fraud-best.pth'  # Best model checkpoint name
    MODEL_COMPLETE_NAME = 'MaxVit-square-Unet-tiny-fraud-complete.pth'  # Complete model save name
    
    # Dataset paths
    DATASET_ROOT = 'combined_dataset'  # Root folder containing subfolders (casia_copymove, defacto_copymove, science-fraud_copymove)
    # Each subfolder should have: train/ and test/ subfolders, each with images/ (PNG files) and masks/ (NPY files)
    # Structure: combined_dataset/subfolder/train/images/, combined_dataset/subfolder/test/images/, etc.
    
    # Dataset selection options
    # List of dataset subfolder names to include (None = include all available datasets)
    # Examples: ['casia_copymove', 'defacto_copymove'] or ['science-fraud_copymove'] or None (all)
    DATASETS_TO_LOAD = ['science-fraud_copymove','defacto_copymove']  # None = load all datasets, or specify list like ['casia_copymove', 'defacto_copymove']
    # List of dataset subfolder names to exclude (applied after DATASETS_TO_LOAD if both are set)
    # Examples: ['science-fraud_copymove'] to exclude it, or [] to exclude none
    DATASETS_TO_SKIP = []  # Empty list = skip none, or specify list like ['science-fraud_copymove']
    
    # Mask area filtering options
    # Filter image-mask pairs based on mask area fraction (0.01 = 1%, 0.10 = 10%)
    FILTER_BY_MASK_AREA = True  # Enable/disable mask area filtering
    MIN_MASK_AREA_PERCENT = 0.02  # Minimum mask area as fraction (0.02 = 2% of image)
    MAX_MASK_AREA_PERCENT = 0.15  # Maximum mask area as fraction (0.15 = 15% of image)
    # Note: Authentic images (empty masks) are always included regardless of area filter
    
    # Augmentation parameters
    USE_AUGMENTATION = True  # Master switch: Enable/disable all data augmentation
    
    # Spatial augmentation switches (applied to both image and mask)
    AUG_ENABLE_ROTATION = True  # Enable/disable rotation augmentation
    AUG_ROTATION = 10  # Rotation angle in degrees (±) - only used if AUG_ENABLE_ROTATION=True
    AUG_ENABLE_HFLIP = True  # Enable/disable horizontal flip augmentation
    AUG_HFLIP = 0.5  # Probability of horizontal flip - only used if AUG_ENABLE_HFLIP=True
    AUG_ENABLE_VFLIP = True  # Enable/disable vertical flip augmentation
    AUG_VFLIP = 0.5  # Probability of vertical flip - only used if AUG_ENABLE_VFLIP=True
    
    # Color augmentation switches (applied to image only)
    AUG_ENABLE_COLOR = True  # Enable/disable color augmentations (brightness, contrast, saturation)
    AUG_BRIGHTNESS = 0.2  # Brightness adjustment range (±) - only used if AUG_ENABLE_COLOR=True
    AUG_CONTRAST = 0.2  # Contrast adjustment range (±) - only used if AUG_ENABLE_COLOR=True
    AUG_SATURATION = 0.2  # Saturation adjustment range (±) - only used if AUG_ENABLE_COLOR=True
    
    # Mask smoothing parameters
    MASK_GAUSSIAN_BLUR_RADIUS = 1.0 # Gaussian blur radius for mask label smoothing (0.0 = disabled)
    
    # Per-image normalization parameters
    # Normalization flow: percentile clipping (bounds from PER_IMAGE_LOWER/UPPER_PERCENTILE)
    # -> min-max to [0,1] -> optional per-image z-score
    # With the current values (0.1 / 99.9), 0.1% of pixel values are clipped at each end.
    USE_PER_IMAGE_MINMAX = True # Enable per-image min-max normalization to [0,1] (with percentile clipping if PER_IMAGE_PERCENTILE_CLIP)
    USE_PER_IMAGE_ZSCORE = False  # Enable per-image z-score normalization (applied after min-max if enabled)
    PER_IMAGE_PERCENTILE_CLIP = True  # Use percentile clipping for min-max (robust to outliers; bounds set by the two percentiles below)
    PER_IMAGE_LOWER_PERCENTILE = 0.1  # Lower percentile for clipping (0.1 = 0.1th percentile, clips bottom 0.1%)
    PER_IMAGE_UPPER_PERCENTILE = 99.9  # Upper percentile for clipping (99.9 = 99.9th percentile, clips top 0.1%)
    USE_WEIGHTED_LOSS = True  # Enable weighted loss (higher penalty for missing origins)
    
    # Size-based loss scaling (linear inverse: 1% area = 10x loss, 10% area = 1x loss)
    USE_SIZE_BASED_WEIGHTING = True  # Enable size-based loss scaling
    SIZE_WEIGHT_MIN_AREA = 0.01  # Minimum area threshold (1% of image) - gets 10x weight
    SIZE_WEIGHT_MAX_AREA = 0.10  # Maximum area threshold (10% of image) - gets 1x weight
    SIZE_WEIGHT_MIN = 1.0  # Weight for forgeries >= 10% area
    SIZE_WEIGHT_MAX = 10.0  # Weight for forgeries <= 1% area
    
    # Prediction post-processing parameters
    PREDICTION_THRESHOLD = 0.6  # Threshold for binarizing predictions (0.0-1.0)

# Single shared configuration object used by all later cells.
config = Config()

# ----------------------------------------------------------------------------
# Architecture summary
# ----------------------------------------------------------------------------
print("Model Architecture:")
print("  Architecture: UNet with MaxViT blocks (encoder + decoder)")
print("  Encoder: Pretrained MaxViT-Tiny (maxvit_tiny_tf_512.in1k)")
print("  Decoder: MaxViT blocks with skip connections (same blocks as encoder)")
print(f"  Use Deep Supervision: {config.USE_DEEP_SUPERVISION}")
print(f"  Gradient Checkpointing: {config.USE_GRADIENT_CHECKPOINTING}")
# The identity flag wins over the atrous flag when both are set.
bottleneck_type = (
    'Identity (pass-through)' if config.USE_IDENTITY_BOTTLENECK
    else 'Atrous Pyramid (ASPP-style)' if config.USE_ATROUS_PYRAMID_BOTTLENECK
    else 'Standard UNet'
)
print(f"  Bottleneck: {bottleneck_type}")

# ----------------------------------------------------------------------------
# Training configuration summary
# ----------------------------------------------------------------------------
print("\n".join([
    "=" * 60,
    "TRAINING CONFIGURATION",
    "=" * 60,
    f"Image Size: {config.IMG_SIZE}",
    f"Batch Size: {config.BATCH_SIZE}",
    "Model: MaxViT-Tiny UNet (MaxViT blocks in encoder + decoder)",
    "",
    "MaxViT UNet (pretrained encoder, MaxViT decoder blocks)",
    f"Skip Pretrained: {'YES' if config.SKIP_PRETRAINED else 'NO'}",
    f"Freeze Encoder: {'YES' if config.FREEZE_ENCODER else 'NO'}",
][:6] + [
    "MaxViT UNet (pretrained encoder, MaxViT decoder blocks)",
    f"Skip Pretrained: {'YES' if config.SKIP_PRETRAINED else 'NO'}",
    f"Freeze Encoder: {'YES' if config.FREEZE_ENCODER else 'NO'}",
]))
if config.FREEZE_ENCODER_STAGES is not None:
    print(f"Freeze Encoder Stages: {config.FREEZE_ENCODER_STAGES}")
print(f"Epochs: {config.NUM_EPOCHS}")
print(f"Learning Rate: {config.LEARNING_RATE}")
print(f"Optimizer: {config.OPTIMIZER}")
# Shared Focal description, followed by either the Dice term or an OFF marker.
print(
    f"Segmentation Loss: Focal ({config.FOCAL_WEIGHT}, alpha={config.FOCAL_ALPHA}, gamma={config.FOCAL_GAMMA}) "
    + (f"+ Dice ({config.DICE_WEIGHT})" if config.USE_DICE_LOSS else "[Dice Loss: OFF]")
)
print(f"Patience: {config.PATIENCE}")
print(f"Train/Val Split: {config.TRAIN_VAL_SPLIT}")
print()

# ----------------------------------------------------------------------------
# Data selection / filtering summary
# ----------------------------------------------------------------------------
print("Dataset Selection:")
print(
    "  Load All Datasets: YES (all available datasets will be loaded)"
    if config.DATASETS_TO_LOAD is None
    else f"  Datasets to Load: {config.DATASETS_TO_LOAD}"
)
print(
    f"  Datasets to Skip: {config.DATASETS_TO_SKIP}"
    if config.DATASETS_TO_SKIP
    else "  Datasets to Skip: None (all selected datasets will be loaded)"
)
print("Mask Area Filtering:")
if not config.FILTER_BY_MASK_AREA:
    print("  Enabled: NO (all image-mask pairs will be loaded)")
else:
    print("  Enabled: YES")
    print(f"  Range: {config.MIN_MASK_AREA_PERCENT*100:.1f}% - {config.MAX_MASK_AREA_PERCENT*100:.1f}%")
    print("  Note: Authentic images (empty masks) are always included")
print()

# ----------------------------------------------------------------------------
# Augmentation summary
# ----------------------------------------------------------------------------
print("Augmentation Parameters:")
print(f"  Master Switch: {'ON' if config.USE_AUGMENTATION else 'OFF'}")
if not config.USE_AUGMENTATION:
    print("  All augmentations disabled")
else:
    print(f"  Spatial: Rotation={'ON' if config.AUG_ENABLE_ROTATION else 'OFF'} (±{config.AUG_ROTATION}°), "
          f"HFlip={'ON' if config.AUG_ENABLE_HFLIP else 'OFF'} (p={config.AUG_HFLIP}), "
          f"VFlip={'ON' if config.AUG_ENABLE_VFLIP else 'OFF'} (p={config.AUG_VFLIP})")
    print(f"  Color: {'ON' if config.AUG_ENABLE_COLOR else 'OFF'} "
          f"(Brightness=±{config.AUG_BRIGHTNESS}, Contrast=±{config.AUG_CONTRAST}, Saturation=±{config.AUG_SATURATION})")
print("=" * 60)


In [None]:
# ============================================================================
# DOWNLOAD AND SAVE PRETRAINED WEIGHTS
# ============================================================================
# This cell downloads the MaxViT pretrained weights and saves them locally
# to avoid redownloading them each time you run the notebook.

import os
import torch
import timm

# Create pretrained_models directory if it doesn't exist
PRETRAINED_MODELS_DIR = 'pretrained_models'
os.makedirs(PRETRAINED_MODELS_DIR, exist_ok=True)

# Candidate local weight files, checked in priority order (MaxViT-Tiny only).
weight_files = ['model_tiny.pth']
local_weights_path = None
weights_exist = False

for weight_file in weight_files:
    weight_path = os.path.join(PRETRAINED_MODELS_DIR, weight_file)
    if os.path.exists(weight_path):
        local_weights_path = weight_path
        weights_exist = True
        print(f"✓ Pretrained weights already exist at: {local_weights_path}")
        print("  Skipping download. Delete this file if you want to redownload.")
        break

# Download and save weights if they don't exist
if not weights_exist:
    import importlib.util

    print("=" * 60)
    print("DOWNLOADING PRETRAINED WEIGHTS")
    print("=" * 60)

    # Request hf_transfer-accelerated downloads only when the optional package is
    # installed: huggingface_hub raises an error if the flag is set but hf_transfer
    # is missing, which would make every download attempt below fail.
    # NOTE(review): huggingface_hub may read these variables once at import time
    # (timm is imported above), so they may only take effect in a fresh kernel — confirm.
    if importlib.util.find_spec('hf_transfer') is not None:
        os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
    os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '600'  # 10 minutes timeout

    save_path = os.path.join(PRETRAINED_MODELS_DIR, 'model_tiny.pth')
    tmp_path = save_path + '.tmp'
    # The second attempt is a plain retry of the same model (transient network/Hub errors).
    max_attempts = 2

    for attempt in range(1, max_attempts + 1):
        try:
            print("Downloading MaxViT-Tiny pretrained weights from Hugging Face...")
            # num_classes=0 removes the classifier head; only the backbone is needed.
            maxvit = timm.create_model(
                'maxvit_tiny_tf_512.in1k',
                pretrained=True,
                num_classes=0,
            )
            # Save to a temp file and rename atomically so an interrupted save never
            # leaves a truncated model_tiny.pth that the existence check above would accept.
            torch.save(maxvit.state_dict(), tmp_path)
            os.replace(tmp_path, save_path)
        except Exception as e:
            # Best-effort download: report and fall back to random init instead of crashing.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            if attempt < max_attempts:
                print(f"⚠ Warning: Download attempt {attempt} failed: {e}")
                print("Retrying...")
            else:
                print(f"✗ Error: Failed to download pretrained weights: {e}")
                print("  The model will be created without pretrained weights.")
                print("  You can train from scratch or manually download the weights.")
        else:
            local_weights_path = save_path
            weights_exist = True
            print("✓ Successfully downloaded and saved pretrained weights!")
            print(f"  Saved to: {save_path}")
            print(f"  File size: {os.path.getsize(save_path) / (1024**2):.2f} MB")
            break

print("=" * 60)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

class ChannelAttention(nn.Module):
    """CBAM channel-attention gate.

    Both global average pooling and global max pooling squeeze the spatial
    dimensions. Each pooled descriptor goes through the same two-layer 1x1-conv
    MLP (``in_channels -> max(1, in_channels // reduction) -> in_channels``).
    The two MLP outputs are summed and passed through a sigmoid.

    Args:
        in_channels: Number of channels C of the incoming feature map.
        reduction: Divisor for the MLP's hidden width (floored at 1 channel).

    ``forward`` returns per-channel weights of shape (B, C, 1, 1) in (0, 1).
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        hidden = max(1, in_channels // reduction)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # A single MLP whose weights are shared by both pooled descriptors.
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pooled_avg = self.avg_pool(x)
        pooled_max = self.max_pool(x)
        logits = self.fc(pooled_avg) + self.fc(pooled_max)
        return self.sigmoid(logits)

class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate.

    The channel dimension is reduced to a 2-channel descriptor (channel-wise
    mean and channel-wise max). A single ``kernel_size`` conv with no bias maps
    it to 1 channel, followed by a sigmoid.

    Args:
        kernel_size: Side length of the square conv kernel. Padding is
            ``kernel_size // 2``, so odd sizes keep H x W unchanged.

    ``forward`` returns a (B, 1, H, W) map of weights in (0, 1).
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv(descriptor))

class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gating, then spatial gating.

    Args:
        in_channels: Channel count C of the input feature map.
        reduction: Hidden-width divisor for the channel-attention MLP.
        kernel_size: Conv kernel size of the spatial-attention gate.

    ``forward`` returns a tensor with the same shape as its input.
    """

    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        # The spatial gate is computed from the channel-refined features, not the raw input.
        refined = x * self.channel_attention(x)
        return refined * self.spatial_attention(refined)

class AtrousPyramidBottleneck(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP) style bottleneck for multi-scale feature extraction.

    The block has one 3x3 Conv-BN-ReLU branch per dilation rate, plus a global
    average-pooling branch (1x1 Conv-BN-ReLU, bilinearly broadcast back to H x W).
    The branches are concatenated along channels and fused by a 1x1 Conv-BN-ReLU
    down to ``out_channels``.

    NOTE(review): the global-pool branch applies BatchNorm to a 1x1 map. In
    training mode this needs more than one sample per batch (PyTorch raises on a
    batch of 1), so a trailing batch of size 1 would fail — consider drop_last.
    """
    def __init__(self, in_channels, out_channels, dilation_rates=(1, 2, 4, 8)):
        """
        Args:
            in_channels: Input channel count
            out_channels: Output channel count (same as input for bottleneck)
            dilation_rates: Iterable of dilation rates, one parallel 3x3 conv branch per
                rate. The rates are copied, so later mutation of the caller's list (or of
                ``self.dilation_rates``) does not leak into other instances.
        """
        super(AtrousPyramidBottleneck, self).__init__()
        
        # Store a private list copy: the previous mutable list default was shared by
        # every instance, so mutating one instance's rates changed all later defaults.
        self.dilation_rates = list(dilation_rates)
        
        # Parallel convolutions with different dilation rates.
        # For a 3x3 conv with dilation d, padding = d keeps the spatial size unchanged.
        self.conv_blocks = nn.ModuleList()
        for dilation in self.dilation_rates:
            self.conv_blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                             padding=dilation, dilation=dilation, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )
        
        # Global average pooling branch (1x1 output)
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
        # Fusion convolution: concatenate all branches and reduce channels
        # Total channels: out_channels * (len(dilation_rates) + 1) [parallel convs + global pool]
        total_channels = out_channels * (len(self.dilation_rates) + 1)
        self.fusion = nn.Sequential(
            nn.Conv2d(total_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        """
        Args:
            x: Input tensor (B, in_channels, H, W)
        
        Returns:
            Output tensor (B, out_channels, H, W) - same spatial size
        """
        H, W = x.shape[-2:]
        
        # Apply parallel convolutions with different dilation rates
        branch_outputs = [conv_block(x) for conv_block in self.conv_blocks]
        
        # Global average pooling branch, upsampled back to the input resolution
        global_feat = self.global_pool(x)
        global_feat = F.interpolate(global_feat, size=(H, W), mode='bilinear', align_corners=False)
        branch_outputs.append(global_feat)
        
        # Concatenate all branches along channels and fuse down to out_channels
        out = torch.cat(branch_outputs, dim=1)
        return self.fusion(out)

class MaxViTDecoderBlock(nn.Module):
    """
    One UNet decoder level built from timm MaxViT blocks.

    The upsampled decoder feature and the encoder skip feature are concatenated
    along channels. A 1x1 conv + BatchNorm projects them to ``out_channels``.
    The result is refined by ``num_blocks`` ``MaxxVitBlock`` modules with stride 1
    (resolution and width unchanged). All blocks share one conv config (timm
    defaults) and one transformer config with 16x16 window and grid sizes.

    NOTE(review): window/grid partitioning presumably requires each spatial side
    at this level to be divisible by 16. For the 512x512 configuration the decoder
    levels are 32, 64 and 128 pixels, which satisfy this — confirm for other sizes.
    """
    def __init__(self, in_channels, skip_channels, out_channels, num_blocks=2):
        """
        Args:
            in_channels: Channels of the already-upsampled feature from the previous decoder level.
            skip_channels: Channels of the matching encoder skip connection.
            out_channels: Channels produced by this level (kept constant through the MaxViT blocks).
            num_blocks: How many stride-1 MaxViT blocks to stack (default 2).
        """
        super().__init__()

        # Fuse [upsampled, skip] down to the working width of this level.
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels + skip_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        from timm.models.maxxvit import MaxxVitBlock, MaxxVitConvCfg, MaxxVitTransformerCfg

        shared_conv_cfg = MaxxVitConvCfg()
        shared_attn_cfg = MaxxVitTransformerCfg(
            window_size=(16, 16),
            grid_size=(16, 16),
        )
        self.maxvit_blocks = nn.ModuleList(
            MaxxVitBlock(
                dim=out_channels,
                dim_out=out_channels,
                stride=1,
                conv_cfg=shared_conv_cfg,
                transformer_cfg=shared_attn_cfg,
            )
            for _ in range(num_blocks)
        )

    def forward(self, x, skip):
        """
        Args:
            x: Upsampled feature from the previous decoder level.
            skip: Encoder feature at the same spatial resolution as ``x``.
        Returns:
            Feature map with ``out_channels`` channels at the input resolution.
        """
        out = self.projection(torch.cat((x, skip), dim=1))
        for blk in self.maxvit_blocks:
            out = blk(out)
        return out

class MaxViT_CBAM_UNet(nn.Module):
    """
    UNet-style segmentation network with a timm MaxViT-Tiny encoder and MaxViT-block decoder.

    Layout (spatial sizes relative to the input H x W):
        encoder    : stem (64 ch, H/2) -> stage1 (64 ch, H/4) -> stage2 (128 ch, H/8)
                     -> stage3 (256 ch, H/16) -> stage4 (512 ch, H/32)
        bottleneck : nn.Identity, AtrousPyramidBottleneck, or 2x (Conv3x3-BN-ReLU); 512 ch
        decoder    : three MaxViTDecoderBlock levels at H/16, H/8 and H/4, fed by skip
                     connections from stage3, stage2 and stage1
        head       : 4x bilinear upsample + 1x1 conv -> ``out_channels`` logit maps (no sigmoid)

    Args:
        in_channels: Not used to build the network (the timm model is created with its
            default input channels); kept for API compatibility.
        out_channels: Number of output logit channels.
        r: CBAM reduction factor. Currently unused: no CBAM module is instantiated.
        skip_pretrained: If True, do not load ``pretrained_models/model_tiny.pth``.
        use_deep_supervision: Stored as an attribute only; no auxiliary heads are built and
            ``forward`` returns a single tensor.
        use_gradient_checkpointing: In training mode, recompute the decoder blocks during
            backward instead of storing their activations (the encoder is not checkpointed).
        use_atrous_pyramid_bottleneck: Use the ASPP-style bottleneck (ignored when
            ``use_identity_bottleneck`` is True).
        use_identity_bottleneck: Pass stage4 features straight to the decoder (takes precedence).
    """
    def __init__(self, in_channels=3, out_channels=1, r=16, skip_pretrained=False, use_deep_supervision=True, use_gradient_checkpointing=False, use_atrous_pyramid_bottleneck=False, use_identity_bottleneck=False):
        super(MaxViT_CBAM_UNet, self).__init__()
        
        self.use_deep_supervision = use_deep_supervision
        self.use_gradient_checkpointing = use_gradient_checkpointing
        
        import os
        
        PRETRAINED_MODELS_DIR = 'pretrained_models'
        os.makedirs(PRETRAINED_MODELS_DIR, exist_ok=True)
        
        # Check for pre-downloaded model_tiny.pth
        local_weights_path = os.path.join(PRETRAINED_MODELS_DIR, 'model_tiny.pth')
        
        # Create model without pretrained weights (will be random initialization if no weights found)
        maxvit = timm.create_model('maxvit_tiny_tf_512.in1k', pretrained=False, num_classes=0)
        
        # Try to load pre-downloaded weights if they exist and skip_pretrained is False
        if not skip_pretrained and os.path.exists(local_weights_path):
            try:
                print(f"Loading MaxViT from pre-downloaded weights: {local_weights_path}")
                state_dict = torch.load(local_weights_path, map_location='cpu')
                
                # Keep only backbone tensors whose name and shape match this model;
                # classifier-head keys are dropped because num_classes=0.
                model_dict = maxvit.state_dict()
                filtered_dict = {}
                for k, v in state_dict.items():
                    # Skip classifier head keys (head.* or classifier.*)
                    if 'head.' in k or 'classifier.' in k:
                        continue
                    if k in model_dict and v.shape == model_dict[k].shape:
                        filtered_dict[k] = v
                
                missing_keys, unexpected_keys = maxvit.load_state_dict(filtered_dict, strict=False)
                if missing_keys:
                    print(f"  Note: {len(missing_keys)} keys not loaded")
                if unexpected_keys:
                    print(f"  Note: {len(unexpected_keys)} unexpected keys ignored (including classifier head)")
                
                print("✓ Successfully loaded pretrained weights from local file!")
            except Exception as e:
                # Best-effort: fall back to random init rather than aborting model creation.
                print(f"⚠ Failed to load from local file: {e}")
                print("  Initializing model with random weights...")
        elif skip_pretrained:
            print("Creating MaxViT model without pretrained weights (skip_pretrained=True)...")
            print("  Initializing model with random weights...")
        else:
            print("No pre-downloaded model_tiny.pth found in pretrained_models/")
            print("  Initializing model with random weights...")
            print("  Tip: Run the 'Download and Save Pretrained Weights' cell first to download weights.")
        
        # MaxViT-Tiny feature dimensions (for 512x512 input):
        # Stem: 64 channels, H/2 x W/2 (256x256)
        # Stage 1: 64 channels, H/4 x W/4 (128x128)
        # Stage 2: 128 channels, H/8 x W/8 (64x64)
        # Stage 3: 256 channels, H/16 x W/16 (32x32)
        # Stage 4: 512 channels, H/32 x W/32 (16x16)
        
        self.encoder1 = nn.Identity()  # Input: (B, 3, H, W)
        
        # MaxViT structure: stem -> stages[0] -> stages[1] -> stages[2] -> stages[3]
        self.stem = maxvit.stem  # Stem: (B, 3, H, W) -> (B, 64, H/2, W/2)
        self.stage1 = maxvit.stages[0]  # Stage 1: (B, 64, H/2, W/2) -> (B, 64, H/4, W/4)
        self.stage2 = maxvit.stages[1]  # Stage 2: (B, 64, H/4, W/4) -> (B, 128, H/8, W/8)
        self.stage3 = maxvit.stages[2]  # Stage 3: (B, 128, H/8, W/8) -> (B, 256, H/16, W/16)
        self.stage4 = maxvit.stages[3]  # Stage 4: (B, 256, H/16, W/16) -> (B, 512, H/32, W/32)
        
        # Encoder aliases. These wrap the SAME module objects registered above (not copies),
        # so their parameters also appear under these names in state_dict; keep both sets of
        # attribute names so existing checkpoints still load.
        self.encoder2 = nn.Sequential(self.stem, self.stage1)  # Out: (B, 64, H/4, W/4)
        self.encoder3 = self.stage2  # Out: (B, 128, H/8, W/8)
        self.encoder4 = self.stage3  # Out: (B, 256, H/16, W/16)
        self.encoder5 = self.stage4  # Out: (B, 512, H/32, W/32)
        
        # Bottleneck: identity, atrous pyramid (ASPP-style), or standard UNet
        bottleneck_channels = 512  # Same as encoder5 output
        if use_identity_bottleneck:
            # Identity bottleneck: no processing, just pass through
            self.bottleneck = nn.Identity()
        elif use_atrous_pyramid_bottleneck:
            # Atrous pyramid bottleneck: multi-scale feature extraction with parallel dilated convolutions
            self.bottleneck = AtrousPyramidBottleneck(
                in_channels=bottleneck_channels,
                out_channels=bottleneck_channels,
                dilation_rates=[1, 2, 4, 8]  # Different dilation rates for multi-scale features
            )
        else:
            # Standard UNet bottleneck: 2 conv blocks (Conv2d -> BN -> ReLU)
            self.bottleneck = nn.Sequential(
                nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(bottleneck_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(bottleneck_channels),
                nn.ReLU(inplace=True)
            )
        # Output: (B, 512, H/32, W/32)

        # 3-Stage Decoder (matching paper: D1, D2, D3 correspond to S1, S2, S3)
        # Paper: "The decoder is made up of three stages, D1 to D3, matching with S1 to S3 stages of the encoder"
        d3_ch, d2_ch, d1_ch = 256, 128, 64
        
        # Decoder 3 (D3): upsample bottleneck (512, H/32) and concat with encoder4 (256, H/16) [S3]
        self.upconv3 = nn.ConvTranspose2d(512, d3_ch, kernel_size=2, stride=2)
        self.decoder3 = MaxViTDecoderBlock(
            in_channels=d3_ch,  # After upconv from bottleneck
            skip_channels=256,  # encoder4 (S3) output
            out_channels=d3_ch,
            num_blocks=2  # Couple of hybrid MaxViT blocks as per paper
        )
        
        # Decoder 2 (D2): upsample d3 (256, H/16) and concat with encoder3 (128, H/8) [S2]
        self.upconv2 = nn.ConvTranspose2d(d3_ch, d2_ch, kernel_size=2, stride=2)
        self.decoder2 = MaxViTDecoderBlock(
            in_channels=d2_ch,  # After upconv from decoder3
            skip_channels=128,  # encoder3 (S2) output
            out_channels=d2_ch,
            num_blocks=2  # Couple of hybrid MaxViT blocks as per paper
        )
        
        # Decoder 1 (D1): upsample d2 (128, H/8) and concat with encoder2 (64, H/4) [S1]
        self.upconv1 = nn.ConvTranspose2d(d2_ch, d1_ch, kernel_size=2, stride=2)
        self.decoder1 = MaxViTDecoderBlock(
            in_channels=d1_ch,  # After upconv from decoder2
            skip_channels=64,   # encoder2 (S1) output
            out_channels=d1_ch,
            num_blocks=2  # Couple of hybrid MaxViT blocks as per paper
        )
        
        # Final upsampling: D1 is at H/4, need to upsample 4x to H
        # Paper: "feature maps of shape 64 × H/4 × W/4 are up-sampled four times"
        self.final_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        
        # Main output head: reduce channels from 64 to out_channels
        self.conv = nn.Conv2d(d1_ch, out_channels, kernel_size=1)
        
        # Metric history containers (not populated by this class itself)
        self.training_losses = []
        self.eval_losses = []
        self.training_iou = []
        self.eval_iou = []
    
    def forward(self, x):
        """
        Args:
            x: Image batch (B, 3, H, W). NOTE(review): MaxViT windows/grids are 16x16, so H and W
               presumably must keep every stage divisible by 16 (true for 512x512) — confirm.

        Returns:
            Logits of shape (B, out_channels, H, W) for H and W divisible by 32 (no sigmoid).
        """
        from torch.utils.checkpoint import checkpoint
        
        # Encoder path. The stem runs exactly once, inside encoder2: a second, unused
        # stem pass would double its compute and update its BatchNorm running
        # statistics twice per training step.
        e1 = self.encoder1(x)   # (B, 3, H, W) - original input
        e2 = self.encoder2(e1)  # (B, 64, H/4, W/4) - stem + stage1
        e3 = self.encoder3(e2)  # (B, 128, H/8, W/8) - stage2
        e4 = self.encoder4(e3)  # (B, 256, H/16, W/16) - stage3
        e5 = self.encoder5(e4)  # (B, 512, H/32, W/32) - stage4
        
        # Bottleneck
        bottleneck = self.bottleneck(e5)  # (B, 512, H/32, W/32)
        
        # Checkpointing is only worthwhile when gradients are being computed.
        use_ckpt = self.use_gradient_checkpointing and self.training
        
        # Decoder 3 (D3): upsample bottleneck and fuse with encoder4 (e4) [S3]
        d3_up = self.upconv3(bottleneck)  # (B, 256, H/16, W/16)
        if use_ckpt:
            d3 = checkpoint(self.decoder3, d3_up, e4, use_reentrant=False)
        else:
            d3 = self.decoder3(d3_up, e4)
        
        # Decoder 2 (D2): upsample d3 and fuse with encoder3 (e3) [S2]
        d2_up = self.upconv2(d3)  # (B, 128, H/8, W/8)
        if use_ckpt:
            d2 = checkpoint(self.decoder2, d2_up, e3, use_reentrant=False)
        else:
            d2 = self.decoder2(d2_up, e3)
        
        # Decoder 1 (D1): upsample d2 and fuse with encoder2 (e2) [S1]
        d1_up = self.upconv1(d2)  # (B, 64, H/4, W/4)
        if use_ckpt:
            d1 = checkpoint(self.decoder1, d1_up, e2, use_reentrant=False)
        else:
            d1 = self.decoder1(d1_up, e2)
        
        # Final upsampling: 4x from H/4 to H (as per paper)
        d1_upsampled = self.final_upsample(d1)  # (B, 64, H, W)
        
        # Output layer: 1x1 conv to out_channels logits
        return self.conv(d1_upsampled)  # (B, out_channels, H, W)

# Check for existing checkpoint before creating model
checkpoint_dir = config.CHECKPOINT_DIR
os.makedirs(checkpoint_dir, exist_ok=True)

# Checkpoint to use (from config)
checkpoint_path = os.path.join(checkpoint_dir, config.MODEL_SAVE_NAME)

checkpoint_found = False
checkpoint_to_load = None

# Check for checkpoint saved by this notebook
if os.path.exists(checkpoint_path):
    checkpoint_to_load = checkpoint_path
    checkpoint_found = True
    print(f"✓ Found checkpoint: {checkpoint_path}")
    print(f"  Loading from checkpoint instead of pretrained weights...")
    print(f"  Checkpoint file: {checkpoint_to_load}")

# Determine if we should skip pretrained weights
# Skip if: config option is True OR checkpoint is found (checkpoint will overwrite pretrained anyway)
skip_pretrained = config.SKIP_PRETRAINED or checkpoint_found

if config.SKIP_PRETRAINED and not checkpoint_found:
    print(f"  Skipping pretrained weights (SKIP_PRETRAINED=True). Training from scratch or using checkpoint only.")

# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _create_model(skip_pretrained_weights):
    """Build MaxViT_CBAM_UNet from the Config values and move it to ``device``.

    Args:
        skip_pretrained_weights: Passed to the constructor as ``skip_pretrained``.
            Per this cell's messages, False means the pretrained MaxViT encoder
            weights are used and True means they are not loaded.

    Returns:
        The model instance, on the module-level ``device``.
    """
    return MaxViT_CBAM_UNet(
        in_channels=config.IN_CHANNELS,
        out_channels=config.OUT_CHANNELS,
        r=config.CBAM_REDUCTION,
        skip_pretrained=skip_pretrained_weights,
        use_deep_supervision=config.USE_DEEP_SUPERVISION,
        use_gradient_checkpointing=config.USE_GRADIENT_CHECKPOINTING,
        use_atrous_pyramid_bottleneck=config.USE_ATROUS_PYRAMID_BOTTLENECK,
        use_identity_bottleneck=config.USE_IDENTITY_BOTTLENECK
    ).to(device)


# Pass skip_pretrained flag to model initialization
model = _create_model(skip_pretrained)

# Load checkpoint if found (this will overwrite pretrained weights)
if checkpoint_found:
    try:
        if checkpoint_to_load.endswith('.pt'):
            # Complete checkpoint format
            checkpoint = torch.load(checkpoint_to_load, map_location=device)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"✓ Successfully loaded model from complete checkpoint!")
                if 'val_loss' in checkpoint:
                    print(f"  Checkpoint validation loss: {checkpoint['val_loss']:.4f}")
                if 'epoch' in checkpoint:
                    print(f"  Checkpoint epoch: {checkpoint['epoch']}")
            else:
                # Treat as state dict
                model.load_state_dict(checkpoint)
                print(f"✓ Successfully loaded model from checkpoint!")
        else:
            # State dict format
            model.load_state_dict(torch.load(checkpoint_to_load, map_location=device))
            print(f"✓ Successfully loaded model from checkpoint!")
        print(f"  Model weights loaded from: {checkpoint_to_load}")
    except Exception as e:
        print(f"⚠ Warning: Failed to load checkpoint: {e}")
        # The model above was built with skip_pretrained=True only because a
        # checkpoint existed. A strict load_state_dict that raises may also have
        # already copied some matching tensors. Rebuild so the model really is
        # in the state the messages below describe.
        skip_pretrained = config.SKIP_PRETRAINED
        del model  # release the discarded model before allocating a new one
        model = _create_model(skip_pretrained)
        if config.SKIP_PRETRAINED:
            print(f"  Continuing without pretrained weights (SKIP_PRETRAINED=True).")
        else:
            print(f"  Continuing with pretrained weights instead...")
else:
    if config.SKIP_PRETRAINED:
        print(f"No checkpoint found. Training from scratch (SKIP_PRETRAINED=True).")
    else:
        print(f"No checkpoint found. Using pretrained weights from MaxViT encoder.")

print(f"Model created on device: {device}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Freeze encoder layers if enabled
# Support selective stage freezing via FREEZE_ENCODER_STAGES


def _freeze_modules(*modules):
    """Set ``requires_grad = False`` on every parameter of each given module."""
    for module in modules:
        for param in module.parameters():
            param.requires_grad = False


def _count_trainable_frozen(net):
    """Return ``(trainable, frozen)`` parameter-element counts for ``net``."""
    trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in net.parameters() if not p.requires_grad)
    return trainable, frozen


def _print_param_split(trainable, frozen):
    """Print trainable/frozen parameter counts and their share of the total."""
    total = trainable + frozen
    print(f"  Trainable parameters: {trainable:,} ({100*trainable/total:.1f}%)")
    print(f"  Frozen parameters: {frozen:,} ({100*frozen/total:.1f}%)")


if config.FREEZE_ENCODER_STAGES is not None:
    # Selective stage freezing: FREEZE_ENCODER_STAGES = [1,2,3,4] or [1,2] etc.
    # Stage mapping: 1=stem+stage1, 2=stage2, 3=stage3, 4=stage4
    # Each stage freezes both its backbone module(s) and its encoder<k> wrapper.
    stage_freeze_map = {
        1: ("Stage 1 (stem+stage1)", (model.stem, model.stage1, model.encoder2)),
        2: ("Stage 2 (stage2)", (model.stage2, model.encoder3)),
        3: ("Stage 3 (stage3)", (model.stage3, model.encoder4)),
        4: ("Stage 4 (stage4)", (model.stage4, model.encoder5)),
    }
    freeze_stages = set(config.FREEZE_ENCODER_STAGES)

    # Entries outside 1-4 (e.g. 0, 5 or the string "1") used to be ignored silently.
    unknown_stages = freeze_stages - set(stage_freeze_map)
    if unknown_stages:
        print(f"⚠ Warning: ignoring unknown FREEZE_ENCODER_STAGES entries: "
              f"{sorted(unknown_stages, key=str)} (valid stages: 1-4)")

    frozen_components = []
    for stage_id, (label, modules) in stage_freeze_map.items():
        if stage_id in freeze_stages:
            _freeze_modules(*modules)
            frozen_components.append(label)

    # Note: encoder1 (Identity) and Bottleneck are NOT frozen - they should always be trainable

    trainable_params, frozen_params = _count_trainable_frozen(model)
    print(f"\n✓ Encoder stages frozen: {config.FREEZE_ENCODER_STAGES}")
    print(f"  Frozen components: {', '.join(frozen_components)}")
    _print_param_split(trainable_params, frozen_params)

elif config.FREEZE_ENCODER:
    # Legacy behavior: freeze all encoder components (MaxViT backbone)
    _freeze_modules(
        model.encoder1, model.stem,
        model.stage1, model.stage2, model.stage3, model.stage4,
        model.encoder2, model.encoder3, model.encoder4, model.encoder5,
    )
    # Note: Bottleneck is NOT frozen - it should always be trainable

    trainable_params, frozen_params = _count_trainable_frozen(model)
    print(f"\n✓ Encoder frozen (using pretrained embeddings)")
    _print_param_split(trainable_params, frozen_params)
else:
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nEncoder: Trainable (all {trainable_params:,} parameters will be updated)")

print(model)

In [None]:
# ============================================================================
# LOSS FUNCTIONS
# ============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation with optional per-sample weighting.

    ``forward(predictions, targets, weights=None)``:
        predictions: Values expected in [0, 1] (probabilities), with a leading
            batch dimension. Any trailing shape is flattened per sample.
        targets: Ground truth with the same number of elements per sample.
        weights: Optional weight map. It is reduced to ONE weight per sample
            (its mean), and the per-sample Dice losses are averaged with those
            weights. Individual pixels are not weighted.

    Returns:
        Scalar tensor: the (weighted) mean of ``1 - dice`` over the batch.
    """
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, predictions, targets, weights=None):
        batch_size = predictions.shape[0]
        # reshape (not view) so non-contiguous inputs (transposed, permuted or
        # sliced tensors) are accepted. It only copies when it has to.
        predictions_flat = predictions.reshape(batch_size, -1)
        targets_flat = targets.reshape(batch_size, -1)

        # Dice coefficient per sample; `smooth` keeps empty/empty pairs at dice = 1.
        intersection = (predictions_flat * targets_flat).sum(dim=1)
        union = predictions_flat.sum(dim=1) + targets_flat.sum(dim=1)
        dice = (2. * intersection + self.smooth) / (union + self.smooth)

        # Clamp dice to [0, 1] to prevent numerical issues
        dice = torch.clamp(dice, 0.0, 1.0)
        dice_loss = 1 - dice

        if weights is not None:
            # One weight per sample = mean of that sample's weight map.
            sample_weights = weights.reshape(batch_size, -1).mean(dim=1)
            # Weighted mean; epsilon guards an all-zero weight batch.
            dice_loss = (dice_loss * sample_weights).sum() / (sample_weights.sum() + 1e-8)
        else:
            dice_loss = dice_loss.mean()

        return dice_loss

class FocalLoss(nn.Module):
    """Sigmoid focal loss for binary segmentation, with optional spatial weights.

    Per pixel: ``alpha_t * (1 - p_t) ** gamma * BCE(logits, targets)``, where
    ``p_t`` is the predicted probability of the true class and ``alpha_t`` is
    ``alpha`` for positives and ``1 - alpha`` for negatives. If ``weights`` is
    given, it multiplies the per-pixel loss before reduction.

    ``reduction`` is 'mean', 'sum', or anything else for the unreduced tensor.
    """
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets, weights=None):
        p = torch.sigmoid(logits)
        per_pixel_bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')

        negatives = 1 - targets
        prob_of_truth = p * targets + (1 - p) * negatives
        class_balance = self.alpha * targets + (1 - self.alpha) * negatives

        # Down-weight easy pixels (p_t close to 1) by the focusing term.
        modulated = class_balance * (1 - prob_of_truth) ** self.gamma * per_pixel_bce

        if weights is not None:
            modulated = modulated * weights

        if self.reduction == 'sum':
            return modulated.sum()
        if self.reduction == 'mean':
            return modulated.mean()
        return modulated

class CombinedLossWithDeepSupervision(nn.Module):
    """Weighted sum of a Focal loss and an optional Dice loss.

    Despite the name, only the main output is supervised. When ``outputs`` is
    a tuple, its first element is used and the rest are ignored.
    ``deep_supervision_weights`` is stored for API compatibility but is not
    read by ``forward``.

    Returns ``focal_weight * focal + dice_weight * dice`` (the dice term is
    dropped when ``use_dice_loss`` is False). Dice is computed on
    ``sigmoid(outputs)``.
    """
    def __init__(self, focal_weight=0.5, dice_weight=0.5, dice_smooth=1e-6,
                 focal_alpha=0.25, focal_gamma=2.0, deep_supervision_weights=None, use_dice_loss=True):
        super().__init__()
        self.focal_weight = focal_weight
        self.use_dice_loss = use_dice_loss
        self.dice_weight = 0.0 if not use_dice_loss else dice_weight
        self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
        self.dice_loss = DiceLoss(smooth=dice_smooth) if use_dice_loss else None
        # Default weights for the (main, aux5, aux4, aux3, aux2) outputs.
        self.deep_supervision_weights = (
            [1.0, 0.8, 0.6, 0.4, 0.2] if deep_supervision_weights is None
            else deep_supervision_weights
        )

    def forward(self, outputs, mask_targets, weight_map=None):
        """
        Args:
            outputs: Logits tensor, or a tuple whose first element is the main logits
            mask_targets: Ground truth masks
            weight_map: Optional spatial weight map passed to both loss terms
        """
        main_output = outputs[0] if isinstance(outputs, tuple) else outputs

        weighted_focal = self.focal_weight * self.focal_loss(main_output, mask_targets, weight_map)
        if not self.use_dice_loss:
            return weighted_focal

        seg_dice = self.dice_loss(torch.sigmoid(main_output), mask_targets, weight_map)
        return weighted_focal + self.dice_weight * seg_dice

def calculate_iou(logits_or_probs, targets, threshold=0.5, smooth=1e-6):
    """
    Calculate Intersection over Union (IoU) for binary segmentation.
    
    Args:
        logits_or_probs: Model output logits (B, 1, H, W) or probabilities (B, 1, H, W)
                        Can be a tuple for deep supervision (uses first element)
        targets: Ground truth masks tensor (B, 1, H, W) - values in [0, 1]
        threshold: Threshold for binarizing predictions (default: 0.5)
        smooth: Smoothing factor to avoid division by zero
    
    Returns:
        Mean IoU score (scalar tensor)
    """
    # Handle deep supervision outputs (tuple) or single output (tensor)
    if isinstance(logits_or_probs, tuple):
        main_output = logits_or_probs[0]  # Use main output for metrics
    else:
        main_output = logits_or_probs
    
    # Apply sigmoid if logits (values outside [0,1] range), otherwise assume probabilities
    if main_output.min() < 0 or main_output.max() > 1:
        mask_probs = torch.sigmoid(main_output)
    else:
        mask_probs = main_output
    
    # Threshold predictions to binary
    predicted = (mask_probs > threshold).float()
    
    # Calculate IoU efficiently
    intersection = (predicted * targets).sum(dim=(1, 2, 3))
    predicted_sum = predicted.sum(dim=(1, 2, 3))
    gt_sum = targets.sum(dim=(1, 2, 3))
    union = predicted_sum + gt_sum - intersection
    
    # Fast empty check using .any() - both masks are completely empty
    both_empty = ~predicted.any(dim=(1, 2, 3)) & ~targets.any(dim=(1, 2, 3))
    
    # Calculate IoU per image
    iou_per_image = intersection / (union + smooth)
    
    # Set IoU to 0.0 when both masks are empty (authentic images)
    # This makes IoU only measure forgery detection performance
    iou_per_image = torch.where(both_empty, torch.zeros_like(iou_per_image), iou_per_image)
    
    # Return mean IoU
    return iou_per_image.mean()

def calculate_f1_score(predicted, targets, smooth=1e-6):
    """
    Calculate pixel-wise F1 score for binary segmentation.

    All pixels of the whole batch are pooled (micro-averaged) into one score.

    Args:
        predicted: Binary predictions tensor (B, 1, H, W) or (B, H, W) - values in {0, 1}.
                   Bool or integer tensors are accepted and promoted to float.
        targets: Ground truth masks tensor (B, 1, H, W) or (B, H, W) - values in {0, 1}
        smooth: Smoothing factor to avoid division by zero

    Returns:
        F1 score (scalar tensor). With no positives in either input it is 1.0.
    """
    # reshape (not view) also accepts non-contiguous tensors (sliced/permuted).
    predicted_flat = predicted.reshape(-1)
    targets_flat = targets.reshape(-1)

    # `1 - bool_tensor` raises in PyTorch (e.g. for `probs > 0.5` without .float()),
    # so promote non-floating inputs. Floating inputs keep their dtype.
    if not predicted_flat.is_floating_point():
        predicted_flat = predicted_flat.float()
    if not targets_flat.is_floating_point():
        targets_flat = targets_flat.float()

    # True Positives, False Positives, False Negatives
    tp = (predicted_flat * targets_flat).sum()
    fp = (predicted_flat * (1 - targets_flat)).sum()
    fn = ((1 - predicted_flat) * targets_flat).sum()

    precision = (tp + smooth) / (tp + fp + smooth)
    recall = (tp + smooth) / (tp + fn + smooth)

    # Harmonic mean of precision and recall
    f1 = (2 * precision * recall + smooth) / (precision + recall + smooth)

    return f1

def calculate_metrics(logits_or_probs, targets, threshold=0.5, from_logits=None):
    """
    Calculate both IoU and F1 score for binary segmentation.

    Args:
        logits_or_probs: Model output logits (B, 1, H, W) or probabilities (B, 1, H, W).
                        Can be a tuple for deep supervision (uses first element).
        targets: Ground truth masks tensor (B, 1, H, W) - values in [0, 1]
        threshold: Threshold for binarizing predictions (default: 0.5)
        from_logits: True -> apply sigmoid; False -> input is already probabilities
                    in [0, 1]. None (default) applies sigmoid only if some value
                    lies outside [0, 1], which misreads logits that all fall
                    inside [0, 1].

    Returns:
        Tuple of (iou, f1) scores (both scalar tensors)
    """
    # Handle deep supervision outputs (tuple) or single output (tensor)
    main_output = logits_or_probs[0] if isinstance(logits_or_probs, tuple) else logits_or_probs

    # Decide logits vs probabilities once, for both metrics.
    if from_logits is None:
        from_logits = bool(main_output.min() < 0 or main_output.max() > 1)
    mask_probs = torch.sigmoid(main_output) if from_logits else main_output

    # Threshold predictions to binary (used for F1)
    predicted = (mask_probs > threshold).float()

    # Pass the resolved probabilities. Values in [0, 1] are treated as
    # probabilities by calculate_iou's default handling, so both metrics agree.
    iou = calculate_iou(mask_probs, targets, threshold=threshold)

    f1 = calculate_f1_score(predicted, targets)

    return iou, f1

print("✓ Loss functions and metrics defined")


In [None]:
# ============================================================================
# VISUALIZATION FUNCTIONS
# ============================================================================

import numpy as np
import matplotlib.pyplot as plt
import os

def denormalize_imagenet(img_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Undo per-channel normalization (ImageNet statistics by default) for display.

    Args:
        img_tensor: Tensor of shape (B, C, H, W) or (B, H, W, C) or (C, H, W), or a
                    numpy array of shape (H, W, C), with normalization applied.
                    Channel-first input is detected by a size-3 axis 1 (4-D) or
                    axis 0 (3-D). Tensors may require grad; they are detached.
        mean: Per-channel mean used for normalization (default: ImageNet)
        std: Per-channel std used for normalization (default: ImageNet)

    Returns:
        Channels-last numpy array clipped to [0, 1], ready for display. Other
        ranks are only clipped. A numpy input is copied, not modified.
    """
    mean = np.asarray(mean)
    std = np.asarray(std)

    if isinstance(img_tensor, torch.Tensor):
        # detach() so tensors still attached to an autograd graph can be converted.
        img_np = img_tensor.detach().cpu().numpy()
    else:
        img_np = img_tensor.copy()

    if img_np.ndim == 4:  # Batch dimension present
        if img_np.shape[1] == 3:  # (B, C, H, W) -> (B, H, W, C)
            img_np = img_np.transpose(0, 2, 3, 1)
        img_np = img_np * std + mean
    elif img_np.ndim == 3:
        if img_np.shape[0] == 3:  # (C, H, W) -> (H, W, C)
            img_np = img_np.transpose(1, 2, 0)
        img_np = img_np * std + mean

    # Clip to valid range [0, 1]
    return np.clip(img_np, 0, 1)

def save_visualization(images, masks, predictions, epoch, split='train', viz_dir=None):
    """
    Save one 4-panel JPEG (image, ground truth, prediction, overlay) per sample.

    Args:
        images: Tensor of shape (B, C, H, W) - original images (normalized)
        masks: Tensor of shape (B, 1, H, W) - ground truth masks
        predictions: Tensor of shape (B, 1, H, W) - predicted masks (probabilities),
                     binarized with config.PREDICTION_THRESHOLD
        epoch: Current epoch number
        split: 'train' or 'val' (used in the file name)
        viz_dir: Directory to save visualizations (uses config.VIZ_DIR if None);
                 created if missing

    Files are named '{split}_epoch_{epoch:03d}_{index}.jpg' with a 1-based index.
    """
    if viz_dir is None:
        viz_dir = config.VIZ_DIR
    # savefig does not create directories; '' means the current directory.
    if viz_dir:
        os.makedirs(viz_dir, exist_ok=True)

    # detach() so tensors still attached to an autograd graph can be converted.
    images_np = images.detach().cpu().permute(0, 2, 3, 1).float().numpy()  # (B, H, W, C)
    masks_np = masks.detach().cpu().squeeze(1).float().numpy()  # (B, H, W)
    preds_np = predictions.detach().cpu().squeeze(1).float().numpy()  # (B, H, W)

    # Denormalize images for correct display
    images_np = denormalize_imagenet(images_np)

    # Threshold predictions to binary
    preds_binary = (preds_np > config.PREDICTION_THRESHOLD).astype(np.float32)

    for img_idx in range(masks_np.shape[0]):
        fig, axes = plt.subplots(1, 4, figsize=(16, 4))
        try:
            is_forged = masks_np[img_idx].any()
            image_type = "Forged" if is_forged else "Original/Authentic"

            img = images_np[img_idx]
            axes[0].imshow(img)
            axes[0].set_title(f'{image_type} Image')
            axes[0].axis('off')

            axes[1].imshow(masks_np[img_idx], cmap='gray')
            axes[1].set_title('Ground Truth Mask')
            axes[1].axis('off')

            axes[2].imshow(preds_binary[img_idx], cmap='gray')
            axes[2].set_title('Predicted Mask')
            axes[2].axis('off')

            # Overlay: original image with predicted mask in red
            overlay = img.copy()
            mask_colored = np.zeros_like(overlay)
            mask_colored[:, :, 0] = preds_binary[img_idx]  # Red channel
            overlay = np.clip(overlay * 0.6 + mask_colored * 0.4, 0, 1)
            axes[3].imshow(overlay)
            axes[3].set_title('Overlay (Original + Prediction)')
            axes[3].axis('off')

            fig.tight_layout()
            save_path = os.path.join(viz_dir, f'{split}_epoch_{epoch:03d}_{img_idx+1}.jpg')
            fig.savefig(save_path, dpi=150, bbox_inches='tight', format='jpg')
        finally:
            # Close this specific figure even if drawing/saving fails, so long
            # runs do not accumulate open figures.
            plt.close(fig)

print("✓ Visualization functions defined")


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
from torchvision import transforms
import torchvision.transforms.functional as TF
import random
import os
import re
from pathlib import Path


class JointTransform:
    """Random rotation and flips applied identically to an image and its mask.

    The parameters drawn by the most recent ``__call__`` are remembered in
    ``last_angle``, ``last_hflip`` and ``last_vflip`` so that
    ``transform_mask_only`` can replay them on another tensor (e.g. a weight map).
    A disabled transform has its magnitude/probability forced to 0.
    """

    def __init__(self, rotation=15, hflip_prob=0.5, vflip_prob=0.0,
                 enable_rotation=True, enable_hflip=True, enable_vflip=True):
        self.rotation = 0 if not enable_rotation else rotation
        self.hflip_prob = 0.0 if not enable_hflip else hflip_prob
        self.vflip_prob = 0.0 if not enable_vflip else vflip_prob
        # Parameters of the most recent call, replayed by transform_mask_only.
        self.last_angle = 0
        self.last_hflip = False
        self.last_vflip = False

    def __call__(self, img, mask):
        # Rotation: bilinear for the image, nearest for the mask, both filled with 0.
        if self.rotation > 0:
            self.last_angle = random.uniform(-self.rotation, self.rotation)
            img = TF.rotate(img, self.last_angle, interpolation=TF.InterpolationMode.BILINEAR, fill=0)
            mask = TF.rotate(mask, self.last_angle, interpolation=TF.InterpolationMode.NEAREST, fill=0)

        # A random number is only drawn when the corresponding probability is > 0.
        self.last_hflip = bool(self.hflip_prob > 0 and random.random() < self.hflip_prob)
        if self.last_hflip:
            img, mask = TF.hflip(img), TF.hflip(mask)

        self.last_vflip = bool(self.vflip_prob > 0 and random.random() < self.vflip_prob)
        if self.last_vflip:
            img, mask = TF.vflip(img), TF.vflip(mask)

        return img, mask

    def transform_mask_only(self, mask):
        """Replay the last call's rotation/flips on ``mask`` (e.g. a weight map).

        Rotation uses nearest interpolation with fill=1.0, so area rotated in
        from outside the frame gets value 1.0.
        """
        out = mask
        if self.rotation > 0 and self.last_angle != 0:
            out = TF.rotate(out, self.last_angle, interpolation=TF.InterpolationMode.NEAREST, fill=1.0)
        if self.last_hflip:
            out = TF.hflip(out)
        if self.last_vflip:
            out = TF.vflip(out)
        return out

def apply_size_based_weighting(weight_map, mask, min_area=0.01, max_area=0.10, min_weight=1.0, max_weight=10.0):
    """
    Scale weights of each forged connected component inversely to its size.

    Each 8-connected component of ``mask`` gets a factor that falls linearly
    from ``max_weight`` (area ratio <= ``min_area``) to ``min_weight``
    (area ratio >= ``max_area``). The weight_map pixels of that component are
    multiplied by it in place.

    Args:
        weight_map: Weight map to modify in place (numpy array, same H x W as mask)
        mask: Binary mask (numpy array, values in [0, 1])
        min_area: Area ratio at or below which a region gets max_weight
        max_area: Area ratio at or above which a region gets min_weight
        min_weight: Factor for the largest regions
        max_weight: Factor for the smallest regions

    Returns:
        The same ``weight_map`` object. It is returned untouched if ``mask`` has no
        positive pixel.
    """
    if mask.max() == 0:
        return weight_map  # Authentic image: nothing to reweight

    mask_u8 = (mask * 255).astype(np.uint8)
    n_labels, label_img, comp_stats, _ = cv2.connectedComponentsWithStats(mask_u8, connectivity=8)
    image_area = mask.shape[0] * mask.shape[1]

    def _weight_for(ratio):
        # Clamp at both ends, linear ramp in between.
        if ratio <= min_area:
            return max_weight
        if ratio >= max_area:
            return min_weight
        t = (ratio - min_area) / (max_area - min_area)
        return max_weight - (max_weight - min_weight) * t

    # Label 0 is the background; labels 1..n_labels-1 are the forged regions.
    for lbl in range(1, n_labels):
        ratio = comp_stats[lbl, cv2.CC_STAT_AREA] / image_area
        weight_map[label_img == lbl] *= _weight_for(ratio)

    return weight_map

class SyntheticCopyMoveForgery:
    """
    Apply synthetic copy-move forgeries to authentic images using grid-based placement.

    For an image whose mask is empty, this runs with probability ``prob``. It
    picks 1..N shaped source regions ("origins": circle, oval or rectangle)
    and pastes ``min_copies``..``max_copies`` transformed copies ("probes") of
    each elsewhere in the same image.

    Each copy is processed in this order:
    1. optionally rotated (``origin_rotation``);
    2. scaled by a factor in [``min_scale``, ``max_scale``];
    3. optionally rotated again (``probe_rotation``).

    It is then alpha-blended in. Placement is tracked on a coarse occupancy
    grid of ``grid_size`` pixel cells, so origins and probes are placed on
    disjoint cells.

    Calling the instance returns ``(img, mask, weight_map)``:
    - ``mask`` (float32 in [0, 1]) marks every origin and probe pixel;
    - ``weight_map`` (float32) is ``origin_weight`` on origins,
      ``probe_weight`` on probes and 1.0 elsewhere.
    """

    def __init__(self, prob=0.5, min_regions=1, max_regions=3,
                 min_copies=1, max_copies=3, min_width=35, max_width=100,
                 min_aspect=1.0, max_aspect=4.0, min_scale=0.8, max_scale=1.2,
                 origin_weight=3.0, probe_weight=1.0, grid_size=64, blur_border=False,
                 origin_rotation=True, origin_rotation_range=90,
                 probe_rotation=True, probe_rotation_range=90):
        """Store the augmentation parameters; see the class docstring for their roles."""
        self.prob = prob
        self.min_regions = min_regions
        self.max_regions = max_regions
        self.min_copies = min_copies
        self.max_copies = max_copies
        self.min_width = min_width
        self.max_width = max_width
        self.min_aspect = min_aspect
        self.max_aspect = max_aspect
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.origin_weight = origin_weight
        self.probe_weight = probe_weight
        self.grid_size = grid_size
        self.blur_border = blur_border
        self.origin_rotation = origin_rotation
        self.origin_rotation_range = origin_rotation_range
        self.probe_rotation = probe_rotation
        self.probe_rotation_range = probe_rotation_range

    def _create_grid_mask(self, img_height, img_width):
        """Create an all-free (zero) occupancy grid with one uint8 cell per grid_size x grid_size tile."""
        grid_h = (img_height + self.grid_size - 1) // self.grid_size
        grid_w = (img_width + self.grid_size - 1) // self.grid_size
        return np.zeros((grid_h, grid_w), dtype=np.uint8)

    def _grid_window(self, x, y, w, h, img_height, img_width, margin):
        """Return (row_slice, col_slice) of grid cells touched by the box grown by ``margin`` and clipped to the image."""
        x1 = max(0, x - margin)
        y1 = max(0, y - margin)
        x2 = min(img_width, x + w + margin)
        y2 = min(img_height, y + h + margin)
        cell = self.grid_size
        return (slice(y1 // cell, (y2 + cell - 1) // cell),
                slice(x1 // cell, (x2 + cell - 1) // cell))

    def _mark_grid_cells(self, grid_mask, x, y, w, h, img_height, img_width, margin=5):
        """Mark every grid cell touched by the (margin-grown) box as occupied, in place."""
        grid_mask[self._grid_window(x, y, w, h, img_height, img_width, margin)] = 1

    def _check_grid_cells(self, grid_mask, x, y, w, h, img_height, img_width, margin=5):
        """Return True when no grid cell touched by the (margin-grown) box is occupied."""
        return grid_mask[self._grid_window(x, y, w, h, img_height, img_width, margin)].max() == 0

    def _get_random_region_size(self, img_height, img_width, region_type):
        """
        Draw a random (width, height) for a source region. ``region_type`` is currently unused.

        Width is uniform in [min_width, min(max_width, img_width // 3)]. Height
        is the width stretched or squashed by a random aspect ratio, clamped to
        [min_width, img_height // 3]. On images too small for ``min_width``,
        the width bound is raised to ``min_width`` instead of letting
        ``random.randint`` raise. ``_get_random_position`` then rejects regions
        that do not fit.
        """
        max_w = max(self.min_width, min(self.max_width, img_width // 3))
        width = random.randint(self.min_width, max_w)
        aspect_ratio = random.uniform(self.min_aspect, self.max_aspect)
        if random.random() < 0.5:
            height = int(width * aspect_ratio)
        else:
            height = int(width / aspect_ratio)
        height = max(self.min_width, min(height, img_height // 3))
        return width, height

    def _get_random_position(self, img_height, img_width, width, height, grid_mask, margin=5):
        """
        Return a random top-left (x, y) whose box does not touch occupied grid cells.

        Returns (None, None) if none is found within 50 tries, or immediately
        if a ``width`` x ``height`` box plus ``margin`` on each side cannot fit
        in the image at all.
        """
        max_x = img_width - width - margin
        max_y = img_height - height - margin
        if max_x < margin or max_y < margin:
            return None, None  # Region too large for this image
        for _ in range(50):
            x = random.randint(margin, max_x)
            y = random.randint(margin, max_y)
            if self._check_grid_cells(grid_mask, x, y, width, height, img_height, img_width, margin):
                return x, y
        return None, None

    def _create_region_mask(self, img_height, img_width, width, height, x, y, region_type):
        """Create a full-image uint8 mask (0/255) with one filled circle, oval or rectangle in the given box."""
        mask = np.zeros((img_height, img_width), dtype=np.uint8)
        if region_type == 'circle':
            radius = min(width, height) // 2
            center_x = x + width // 2
            center_y = y + height // 2
            cv2.circle(mask, (center_x, center_y), radius, 255, -1)
        elif region_type == 'oval':
            center_x = x + width // 2
            center_y = y + height // 2
            axes = (width // 2, height // 2)
            cv2.ellipse(mask, (center_x, center_y), axes, 0, 0, 360, 255, -1)
        elif region_type == 'rectangle':
            cv2.rectangle(mask, (x, y), (x + width, y + height), 255, -1)
        return mask

    @staticmethod
    def _rotated_size(w, h, angle):
        """Return the (w, h) canvas that ``_rotate_region`` produces for a w x h region rotated by ``angle`` degrees."""
        theta = np.radians(angle)
        cos = abs(np.cos(theta))
        sin = abs(np.sin(theta))
        return int(h * sin + w * cos), int(h * cos + w * sin)

    def _paste_footprint(self, width, height, scale_factor, origin_rot, probe_rot):
        """
        Return an upper bound on the (w, h) of a pasted copy, following ``_apply_copy``'s transform order.

        Filled shapes from ``_create_region_mask`` span at most
        ``width + 1`` x ``height + 1`` pixels, because OpenCV end points are
        inclusive. The copy is then rotated by ``origin_rot``, scaled, and
        rotated by ``probe_rot`` with the same size formulas as
        ``_apply_copy``. The placement margin absorbs float rounding.
        """
        w, h = width + 1, height + 1
        if abs(origin_rot) > 0.1:
            w, h = self._rotated_size(w, h, origin_rot)
        if abs(scale_factor - 1.0) > 0.01:
            w, h = max(1, int(w * scale_factor)), max(1, int(h * scale_factor))
        if abs(probe_rot) > 0.1:
            w, h = self._rotated_size(w, h, probe_rot)
        return w, h

    def _rotate_region(self, region, region_mask, angle):
        """
        Rotate a region and its mask by ``angle`` degrees onto a canvas that holds the whole rotated box.

        Image pixels outside the source use BORDER_REPLICATE, which avoids
        dark fringes. Mask pixels outside the source are 0.
        """
        if abs(angle) < 0.1:
            return region, region_mask
        h, w = region.shape[:2]
        center = (w / 2.0, h / 2.0)
        rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
        cos = np.abs(rotation_matrix[0, 0])
        sin = np.abs(rotation_matrix[0, 1])
        new_w = int((h * sin) + (w * cos))
        new_h = int((h * cos) + (w * sin))
        # Shift so the rotated content is centred on the enlarged canvas
        rotation_matrix[0, 2] += (new_w / 2) - center[0]
        rotation_matrix[1, 2] += (new_h / 2) - center[1]
        region_rotated = cv2.warpAffine(region, rotation_matrix, (new_w, new_h),
                                        flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
        region_mask_rotated = cv2.warpAffine(region_mask, rotation_matrix, (new_w, new_h),
                                             flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
        return region_rotated, region_mask_rotated

    def _apply_copy(self, img, mask, source_mask, target_x, target_y, scale_factor=1.0,
                    region_type=None, width=None, height=None, origin_rotation=0.0, probe_rotation=0.0):
        """
        Copy the ``source_mask`` region of ``img`` to (target_x, target_y) with alpha blending.

        Steps:
        1. The region is rotated by ``origin_rotation``, scaled by
           ``scale_factor``, then rotated by ``probe_rotation``.
        2. It is blended using its (interpolated) mask as alpha.
        3. If ``blur_border`` is set, a 3-pixel inner border is blurred.
        4. Probe pixels (alpha > 0.5) are OR-ed into ``mask`` (uint8 0/255).

        ``img`` and ``mask`` are modified in place. The method returns
        ``(img, mask, success)``; success is False when the source is empty or
        the transformed copy would leave the image. ``region_type``,
        ``width`` and ``height`` are accepted but unused.
        """
        contours, _ = cv2.findContours(source_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            return img, mask, False
        x, y, w, h = cv2.boundingRect(contours[0])
        region = img[y:y+h, x:x+w].copy()
        region_mask = source_mask[y:y+h, x:x+w].copy()
        region_mask = (region_mask > 127).astype(np.uint8) * 255
        if abs(origin_rotation) > 0.1:
            region, region_mask = self._rotate_region(region, region_mask, origin_rotation)
            h, w = region.shape[:2]
        if abs(scale_factor - 1.0) > 0.01:
            new_w = max(1, int(w * scale_factor))
            new_h = max(1, int(h * scale_factor))
            region = cv2.resize(region, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
            region_mask = cv2.resize(region_mask, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
            w, h = new_w, new_h
        if abs(probe_rotation) > 0.1:
            region, region_mask = self._rotate_region(region, region_mask, probe_rotation)
            h, w = region.shape[:2]
        if target_x < 0 or target_y < 0 or target_x + w > img.shape[1] or target_y + h > img.shape[0]:
            return img, mask, False
        alpha = region_mask.astype(np.float32) / 255.0
        alpha[alpha < 0.05] = 0.0  # Drop faint interpolation residue outside the shape
        alpha_3d = alpha[:, :, np.newaxis]
        target_patch = img[target_y:target_y+h, target_x:target_x+w]
        img[target_y:target_y+h, target_x:target_x+w] = alpha_3d * region + (1.0 - alpha_3d) * target_patch
        if self.blur_border and alpha.max() > 0:
            patch = img[target_y:target_y+h, target_x:target_x+w].copy()
            binary_mask = (alpha > 0.5).astype(np.uint8)
            dist = cv2.distanceTransform(binary_mask, cv2.DIST_L2, 3)
            blur_zone = ((dist > 0) & (dist <= 3.0)).astype(np.float32)[:, :, np.newaxis]
            blurred = cv2.GaussianBlur(patch, (5, 5), 1.0)
            img[target_y:target_y+h, target_x:target_x+w] = blur_zone * blurred + (1.0 - blur_zone) * patch
        probe_mask_uint8 = ((alpha > 0.5).astype(np.uint8)) * 255
        mask[target_y:target_y+h, target_x:target_x+w] = np.maximum(
            mask[target_y:target_y+h, target_x:target_x+w], probe_mask_uint8)
        return img, mask, True

    def _plan_forgeries(self, img_height, img_width):
        """
        Choose source regions and non-overlapping copy placements on the occupancy grid.

        Returns a list of dicts with keys 'source_mask', 'source_x',
        'source_y', 'width', 'height', 'region_type' and 'copies'. 'copies' is
        a list of ``(target_x, target_y, scale_factor, origin_rot, probe_rot)``.
        A region with no placeable copy is dropped, and its grid reservation is
        discarded.
        """
        grid_mask = self._create_grid_mask(img_height, img_width)
        region_types = ['circle', 'oval', 'rectangle']
        plan = []
        for _ in range(random.randint(self.min_regions, self.max_regions)):
            region_type = random.choice(region_types)
            width, height = self._get_random_region_size(img_height, img_width, region_type)
            source_x, source_y = self._get_random_position(img_height, img_width, width, height, grid_mask, margin=5)
            if source_x is None or source_y is None:
                continue
            source_mask = self._create_region_mask(img_height, img_width, width, height, source_x, source_y, region_type)
            num_copies = random.randint(self.min_copies, self.max_copies)
            # Reserve on a scratch grid; commit only if at least one copy fits
            trial_grid = grid_mask.copy()
            self._mark_grid_cells(trial_grid, source_x, source_y, width, height, img_height, img_width, margin=5)
            copies = []
            for _ in range(num_copies):
                scale_factor = random.uniform(self.min_scale, self.max_scale)
                origin_rot = 0.0
                probe_rot = 0.0
                if self.origin_rotation:
                    origin_rot = random.uniform(-self.origin_rotation_range, self.origin_rotation_range)
                if self.probe_rotation:
                    probe_rot = random.uniform(-self.probe_rotation_range, self.probe_rotation_range)
                # Reserve the real transformed size, so rotated copies cannot spill into other regions
                paste_w, paste_h = self._paste_footprint(width, height, scale_factor, origin_rot, probe_rot)
                for _ in range(50):
                    target_x, target_y = self._get_random_position(
                        img_height, img_width, paste_w, paste_h, trial_grid, margin=5)
                    if target_x is None or target_y is None:
                        continue
                    if np.hypot(target_x - source_x, target_y - source_y) < 30:
                        continue  # Keep the copy visibly apart from its source
                    copies.append((target_x, target_y, scale_factor, origin_rot, probe_rot))
                    self._mark_grid_cells(trial_grid, target_x, target_y, paste_w, paste_h,
                                          img_height, img_width, margin=5)
                    break
            if copies:
                plan.append({
                    'source_mask': source_mask, 'source_x': source_x, 'source_y': source_y,
                    'width': width, 'height': height, 'copies': copies, 'region_type': region_type})
                grid_mask = trial_grid
        return plan

    def __call__(self, img, mask):
        """
        Apply synthetic copy-move forgeries to image and update mask using grid-based placement.

        Args:
            img: Image array (H, W, C). Copies are written into it in place.
            mask: Mask array (H, W) in [0, 1]. Images whose mask already has
                positive values are returned untouched.

        Returns:
            ``(img, mask, weight_map)``. ``mask`` is float32 in [0, 1].
            ``weight_map`` is float32 and holds ``origin_weight`` on every
            origin that got at least one copy, ``probe_weight`` on the pasted
            copies, and 1.0 elsewhere.
        """
        if mask.max() > 0:
            return img, mask, np.ones_like(mask, dtype=np.float32)
        if random.random() > self.prob:
            return img, mask, np.ones_like(mask, dtype=np.float32)

        img_height, img_width = img.shape[:2]
        mask_uint8 = (mask * 255).astype(np.uint8)
        origin_union = np.zeros((img_height, img_width), dtype=bool)

        for forgery in self._plan_forgeries(img_height, img_width):
            source_mask = forgery['source_mask']
            copies_applied = 0
            for target_x, target_y, scale_factor, origin_rot, probe_rot in forgery['copies']:
                img, mask_uint8, success = self._apply_copy(
                    img, mask_uint8, source_mask, target_x, target_y, scale_factor,
                    region_type=forgery['region_type'], width=forgery['width'], height=forgery['height'],
                    origin_rotation=origin_rot, probe_rotation=probe_rot)
                if success:
                    copies_applied += 1
            if copies_applied > 0:
                mask_uint8 = np.maximum(mask_uint8, source_mask)
                origin_union |= source_mask > 0

        # Weights are assigned once, after all forgeries are rendered.
        # Previously each forgery's probe pass re-tagged every earlier origin
        # as a probe, so only the last origin kept origin_weight.
        weight_map = np.ones((img_height, img_width), dtype=np.float32)
        weight_map[mask_uint8 > 0] = self.probe_weight
        weight_map[origin_union] = self.origin_weight

        mask = mask_uint8.astype(np.float32) / 255.0
        return img, mask, weight_map

class UnifiedDataset(Dataset):
    """
    Unified dataset loader for combined_dataset structure.
    Supports both old structure (subfolder/images/, subfolder/masks/) and 
    new structure (subfolder/train/images/, subfolder/test/images/).

    Each item is ``(img, mask, class_label, weight_map)``:
    - ``img``: a (3, H, W) float tensor;
    - ``mask`` and ``weight_map``: (1, H, W) float tensors;
    - ``class_label``: a scalar float tensor.
    """

    # Split folders read for each ``split`` value (new structure only).
    _SPLIT_FOLDERS = {'train': ('train',), 'test': ('test',), 'all': ('train', 'test')}

    def __init__(self, dataset_root, img_size=(512, 512), transform=None, joint_transform=None, 
                 mask_blur_radius=0.0, split='train', datasets_to_load=None, datasets_to_skip=None,
                 filter_by_mask_area=False, min_mask_area_percent=0.01, max_mask_area_percent=0.10, synthetic_forgery=None):
        """
        Args:
            dataset_root: Root folder containing subfolders (e.g., 'combined_dataset')
                         Each subfolder should have either:
                         - Old: images/ (PNG) and masks/ (NPY) folders. These are
                           loaded for every ``split`` value.
                         - New: train/ and test/ subfolders, each with images/ and masks/
            img_size: Target size for resizing images and masks (height, width)
            transform: Image-only transforms (brightness, contrast, etc.)
            joint_transform: Spatial transforms applied to both image and mask
            mask_blur_radius: Gaussian blur radius for mask label smoothing (0.0 = disabled)
            split: 'train', 'test', or 'all' - which split to load (default: 'train').
                   Any other value loads nothing from new-structure subfolders.
            datasets_to_load: List of dataset subfolder names to include (None = all)
            datasets_to_skip: List of dataset subfolder names to exclude (applied after datasets_to_load)
            filter_by_mask_area: Enable mask area filtering (default: False)
            min_mask_area_percent: Minimum mask area as fraction (0.01 = 1%, default: 0.01)
            max_mask_area_percent: Maximum mask area as fraction (0.10 = 10%, default: 0.10)
            synthetic_forgery: Optional callable ``(img, mask) -> (img, mask, weight_map)``
                applied in ``__getitem__`` before tensor conversion (None = disabled).
        """
        self.dataset_root = Path(dataset_root)
        self.img_size = img_size
        self.transform = transform
        self.joint_transform = joint_transform
        self.mask_blur_radius = mask_blur_radius
        self.synthetic_forgery = synthetic_forgery
        self.split = split
        self.datasets_to_load = datasets_to_load
        self.datasets_to_skip = datasets_to_skip or []

        # Mask area filtering parameters
        self.filter_by_mask_area = filter_by_mask_area
        self.min_mask_area_percent = min_mask_area_percent
        self.max_mask_area_percent = max_mask_area_percent

        self.image_paths = []
        self.mask_paths = []
        self.loaded_datasets = []  # Dataset subfolders that passed the load/skip filters
        self.filtered_out_count = 0  # Pairs rejected by mask area filtering
        contributing_subfolders = set()  # Subfolders that yielded at least one pair

        for subfolder in self.dataset_root.iterdir():
            if not subfolder.is_dir():
                continue
            subfolder_name = subfolder.name
            if self.datasets_to_load is not None and subfolder_name not in self.datasets_to_load:
                continue
            if subfolder_name in self.datasets_to_skip:
                continue
            if subfolder_name not in self.loaded_datasets:
                self.loaded_datasets.append(subfolder_name)

            if (subfolder / 'train').exists() or (subfolder / 'test').exists():
                # New structure: subfolder/<split>/images and subfolder/<split>/masks
                pair_dirs = [(subfolder / name / 'images', subfolder / name / 'masks')
                             for name in self._SPLIT_FOLDERS.get(split, ())]
            else:
                # Old structure: subfolder/images and subfolder/masks (no split folders)
                pair_dirs = [(subfolder / 'images', subfolder / 'masks')]

            for images_dir, masks_dir in pair_dirs:
                if self._collect_pairs(images_dir, masks_dir) > 0:
                    contributing_subfolders.add(subfolder_name)

        # Sort for consistent ordering. This also works when nothing was found
        # (unpacking zip(*[]) used to raise ValueError).
        pairs = sorted(zip(self.image_paths, self.mask_paths))
        self.image_paths = [img_path for img_path, _ in pairs]
        self.mask_paths = [mask_path for _, mask_path in pairs]

        num_subfolders = len(contributing_subfolders)

        # Print dataset loading information
        if self.loaded_datasets:
            datasets_str = ', '.join(sorted(self.loaded_datasets))
            print(f"Found {len(self.image_paths)} image-mask pairs from {num_subfolders} subfolders (split: {split})")
            print(f"  Loaded datasets: {datasets_str}")
            if self.datasets_to_load is not None:
                print(f"  (Filtered to: {self.datasets_to_load})")
            if self.datasets_to_skip:
                print(f"  (Excluded: {self.datasets_to_skip})")
            if self.filter_by_mask_area:
                print(f"  Mask area filter: {self.min_mask_area_percent*100:.1f}% - {self.max_mask_area_percent*100:.1f}%")
                if self.filtered_out_count > 0:
                    print(f"  Filtered out: {self.filtered_out_count} pairs (outside area range)")
            if not self.image_paths:
                print("⚠ Warning: No image-mask pairs found in the loaded datasets.")
        else:
            print(f"⚠ Warning: No datasets loaded! Check DATASETS_TO_LOAD and DATASETS_TO_SKIP settings.")
            print(f"  Available subfolders: {[d.name for d in self.dataset_root.iterdir() if d.is_dir()]}")

    def _collect_pairs(self, images_dir, masks_dir):
        """
        Collect (PNG, same-stem .npy) pairs from one images/masks folder pair.

        Pairs that pass ``_should_include_pair`` are appended to
        ``image_paths``/``mask_paths``; the rest are counted in
        ``filtered_out_count``.

        Returns:
            Number of pairs added. Returns 0 when either folder is missing.
        """
        if not (images_dir.exists() and masks_dir.exists()):
            return 0
        added = 0
        for img_path in images_dir.glob('*.png'):
            mask_path = masks_dir / f"{img_path.stem}.npy"
            if not mask_path.exists():
                continue
            if self._should_include_pair(mask_path):
                self.image_paths.append(str(img_path))
                self.mask_paths.append(str(mask_path))
                added += 1
            else:
                self.filtered_out_count += 1
        return added

    def _calculate_mask_area_percent(self, mask_path):
        """
        Calculate mask area as a fraction of image area.

        Multi-dimensional masks are OR-ed along axis 0. Any value > 0 counts
        as foreground.

        Args:
            mask_path: Path to .npy mask file

        Returns:
            Foreground fraction in [0.0, 1.0] (0.01 = 1%). Returns 0.0 when the
            file cannot be read or processed; such masks are then treated as
            authentic/empty (a deliberate best-effort).
        """
        try:
            mask = np.load(mask_path)
            if mask.ndim > 2:
                mask = np.any(mask > 0, axis=0)
            foreground = mask > 0
            if foreground.size == 0:
                return 0.0
            return float(foreground.sum()) / foreground.size
        except Exception:
            return 0.0

    def _should_include_pair(self, mask_path):
        """
        Check if image-mask pair should be included based on mask area filtering.

        Args:
            mask_path: Path to mask file

        Returns:
            True when filtering is off, the mask is empty (authentic), or its
            area fraction lies within [min_mask_area_percent, max_mask_area_percent].
        """
        if not self.filter_by_mask_area:
            return True  # No filtering, include all

        area_percent = self._calculate_mask_area_percent(mask_path)

        # Authentic images (empty masks) are always included
        if area_percent == 0.0:
            return True

        return self.min_mask_area_percent <= area_percent <= self.max_mask_area_percent

    def _apply_mask_blur(self, mask):
        """Gaussian-blur the mask for label smoothing when ``mask_blur_radius`` > 0; otherwise return it unchanged."""
        if self.mask_blur_radius <= 0:
            return mask
        import math
        # 2*ceil(3*sigma)+1 covers +/-3 sigma and is always odd, as cv2 requires
        kernel_size = int(2 * math.ceil(3 * self.mask_blur_radius) + 1)
        return cv2.GaussianBlur(mask, (kernel_size, kernel_size), self.mask_blur_radius)

    def _load_and_process_mask(self, mask_path):
        """
        Load mask from .npy file, handle multi-channel by ORing, resize to target size.

        Value handling:
        - Any single-channel mask whose maximum is <= 1 (float soft masks,
          bool, or uint8/int {0, 1}) is scaled to 0..255.
        - Other masks are clipped to 0..255.
        Before this, uint8 {0, 1} masks were left as-is and ended up with a
        maximum of 1/255.

        Args:
            mask_path: Path to .npy mask file

        Returns:
            Processed mask as a float32 numpy array of shape ``img_size``, in range [0, 1]
        """
        mask = np.load(mask_path)

        if mask.ndim > 2:
            # OR across axis 0. Assumes channels-first (C, H, W) — TODO confirm for every subset.
            mask = np.any(mask > 0, axis=0).astype(np.uint8) * 255
        elif mask.max() <= 1.0:
            mask = (np.clip(mask, 0, 1) * 255).astype(np.uint8)
        elif mask.dtype != np.uint8:
            mask = np.clip(mask, 0, 255).astype(np.uint8)

        # cv2.resize takes dsize as (width, height); img_size is (height, width)
        height, width = self.img_size
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

        mask = self._apply_mask_blur(mask)
        return mask.astype(np.float32) / 255.0

    def _load_image(self, image_path):
        """
        Read an image as RGB float32 in [0, 1], resized to ``img_size`` (height, width).

        Grayscale and BGRA inputs are converted to RGB. Integer images are
        divided by their dtype's maximum, so 16-bit PNGs (kept as uint16 by
        IMREAD_UNCHANGED) also land in [0, 1].

        Raises:
            ValueError: if OpenCV cannot read the file.
        """
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise ValueError(f"Failed to load image: {image_path}")

        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # INTER_AREA gives better quality when downsampling.
        # dsize is (width, height).
        height, width = self.img_size
        img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

        max_value = float(np.iinfo(img.dtype).max) if np.issubdtype(img.dtype, np.integer) else 255.0
        return img.astype(np.float32) / max_value

    @staticmethod
    def _normalize_image(img):
        """
        Normalise a (3, H, W) image tensor according to ``config``.

        Steps:
        1. If USE_PER_IMAGE_MINMAX: per-channel min-max scaling to [0, 1].
           Percentile bounds are used when PER_IMAGE_PERCENTILE_CLIP is set.
        2. Either per-image z-score (USE_PER_IMAGE_ZSCORE) or ImageNet
           mean/std normalisation.
        3. Clamp to [-3, 3].
        """
        if config.USE_PER_IMAGE_MINMAX:
            img_flat = img.reshape(3, -1)  # (3, H*W)
            if config.PER_IMAGE_PERCENTILE_CLIP:
                num_pixels = img_flat.shape[1]
                k_lower = max(1, int(config.PER_IMAGE_LOWER_PERCENTILE / 100.0 * num_pixels))
                k_upper = max(1, int(config.PER_IMAGE_UPPER_PERCENTILE / 100.0 * num_pixels))
                # kthvalue(k) equals sorted[k - 1] without sorting each whole channel
                min_vals = torch.kthvalue(img_flat, k_lower, dim=1).values
                max_vals = torch.kthvalue(img_flat, k_upper, dim=1).values
            else:
                min_vals = img_flat.min(dim=1).values
                max_vals = img_flat.max(dim=1).values

            min_vals = min_vals.view(3, 1, 1)
            max_vals = max_vals.view(3, 1, 1)
            range_vals = torch.clamp(max_vals - min_vals, min=1e-6)  # Avoid division by zero
            img = torch.clamp((img - min_vals) / range_vals, 0.0, 1.0)

        if config.USE_PER_IMAGE_ZSCORE:
            channels = img.reshape(3, -1)
            mean = channels.mean(dim=1).view(3, 1, 1)
            std = torch.clamp(channels.std(dim=1).view(3, 1, 1), min=1e-6)
            img = (img - mean) / std
        else:
            # ImageNet statistics (standard for pretrained encoders)
            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            img = (img - mean) / std

        # Clamp to [-3, 3] to prevent extreme values
        return torch.clamp(img, min=-3.0, max=3.0)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load, augment and normalise sample ``idx``.

        Returns:
            img: float tensor (3, H, W), normalised and clamped to [-3, 3].
            mask: float tensor (1, H, W) in [0, 1].
            class_label: scalar float tensor; 1.0 when the mask is empty, else 0.0.
            weight_map: float tensor (1, H, W) of per-pixel loss weights.
        """
        img = self._load_image(self.image_paths[idx])
        mask = self._load_and_process_mask(self.mask_paths[idx])
        weight_map = np.ones_like(mask, dtype=np.float32)

        # Synthetic copy-move forgery runs on numpy arrays, before tensor
        # conversion. It returns (img, mask, weight_map).
        if self.synthetic_forgery is not None:
            img, mask, weight_map = self.synthetic_forgery(img, mask)

        if config.USE_SIZE_BASED_WEIGHTING:
            weight_map = apply_size_based_weighting(
                weight_map, mask,
                min_area=config.SIZE_WEIGHT_MIN_AREA,
                max_area=config.SIZE_WEIGHT_MAX_AREA,
                min_weight=config.SIZE_WEIGHT_MIN,
                max_weight=config.SIZE_WEIGHT_MAX
            )

        # (H, W, C) -> (C, H, W); masks gain a leading channel axis
        img = torch.from_numpy(img).permute(2, 0, 1)
        mask = torch.from_numpy(mask).unsqueeze(0)
        weight_map = torch.from_numpy(weight_map).unsqueeze(0)

        if self.joint_transform:
            img, mask = self.joint_transform(img, mask)
            # NOTE(review): assumes transform_mask_only replays the same random spatial
            # parameters as the call above — TODO confirm in JointTransform.
            weight_map = self.joint_transform.transform_mask_only(weight_map)

        if self.transform:
            img = self.transform(img)

        img = self._normalize_image(img)

        # NOTE(review): the label is 1.0 when the mask is EMPTY. _should_include_pair
        # treats empty masks as authentic, so this looks inverted relative to the
        # "forged" wording. Confirm against the classification loss/metrics before
        # changing it.
        mask_is_empty = mask.sum() == 0
        class_label = torch.tensor(1.0 if mask_is_empty else 0.0, dtype=torch.float32)

        return img, mask, class_label, weight_map

# Create augmentation transforms.
# Both splits start with no augmentation. Only the training pipeline gets
# transforms, and only when the master switch is on; validation stays clean.
train_joint_transform = None
train_image_transform = None
val_joint_transform = None
val_image_transform = None

if config.USE_AUGMENTATION:
    # Spatial transforms applied jointly to image and mask. A disabled switch
    # also zeroes the corresponding angle / probability passed in.
    train_joint_transform = JointTransform(
        rotation=(config.AUG_ROTATION if config.AUG_ENABLE_ROTATION else 0),
        hflip_prob=(config.AUG_HFLIP if config.AUG_ENABLE_HFLIP else 0.0),
        vflip_prob=(config.AUG_VFLIP if config.AUG_ENABLE_VFLIP else 0.0),
        enable_rotation=config.AUG_ENABLE_ROTATION,
        enable_hflip=config.AUG_ENABLE_HFLIP,
        enable_vflip=config.AUG_ENABLE_VFLIP,
    )

    # Photometric jitter on the image only (never the mask). Hue is pinned at
    # 0.0: keep hue unchanged for forgery detection.
    if config.AUG_ENABLE_COLOR:
        train_image_transform = transforms.Compose([
            transforms.ColorJitter(
                brightness=config.AUG_BRIGHTNESS,
                contrast=config.AUG_CONTRAST,
                saturation=config.AUG_SATURATION,
                hue=0.0,
            )
        ])

# Print augmentation status
def _on_off(flag):
    """Render a boolean config switch as 'ON' / 'OFF' for the status report."""
    return 'ON' if flag else 'OFF'


print("\nAugmentation Configuration:")
print(f"  Master Switch: {_on_off(config.USE_AUGMENTATION)}")
if config.USE_AUGMENTATION:
    print("  Spatial Augmentations:")
    print(f"    Rotation: {_on_off(config.AUG_ENABLE_ROTATION)} (angle: ±{config.AUG_ROTATION}°)")
    print(f"    Horizontal Flip: {_on_off(config.AUG_ENABLE_HFLIP)} (prob: {config.AUG_HFLIP})")
    print(f"    Vertical Flip: {_on_off(config.AUG_ENABLE_VFLIP)} (prob: {config.AUG_VFLIP})")
    print(f"  Color Augmentations: {_on_off(config.AUG_ENABLE_COLOR)}")
    if config.AUG_ENABLE_COLOR:
        print(f"    Brightness: ±{config.AUG_BRIGHTNESS}")
        print(f"    Contrast: ±{config.AUG_CONTRAST}")
        print(f"    Saturation: ±{config.AUG_SATURATION}")


# Print normalization configuration
print("\nNormalization Configuration:")
print(f"  Per-Image Min-Max: {_on_off(config.USE_PER_IMAGE_MINMAX)}")
if config.USE_PER_IMAGE_MINMAX:
    print(f"    Percentile Clipping: {_on_off(config.PER_IMAGE_PERCENTILE_CLIP)}")
    if config.PER_IMAGE_PERCENTILE_CLIP:
        print(f"    Percentiles: {config.PER_IMAGE_LOWER_PERCENTILE}%-{config.PER_IMAGE_UPPER_PERCENTILE}%")
# Per-image z-scoring and ImageNet normalization are mutually exclusive:
# the z-score branch replaces the ImageNet one.
if config.USE_PER_IMAGE_ZSCORE:
    print("  Per-Image Z-Score: ON")
    print("  ImageNet Normalization: OFF (replaced by per-image z-score normalization)")
else:
    print("  Per-Image Z-Score: OFF")
    print("  ImageNet Normalization: ON")

# Create synthetic copy-move forgery augmentation (training only).
# Each SyntheticCopyMoveForgery keyword `kw` is read from the config attribute
# EXTRA_FORGERIES_<KW> when it exists; otherwise the default below is used.
# The values are resolved once and reused for both construction and the status
# report, so the printed settings can never drift from the ones actually used.
_SYNTHETIC_FORGERY_DEFAULTS = {
    'prob': 0.75,
    'min_regions': 3,
    'max_regions': 5,
    'min_copies': 1,
    'max_copies': 1,
    'min_width': 20,
    'max_width': 80,
    'min_aspect': 0.5,
    'max_aspect': 2.0,
    'min_scale': 0.8,
    'max_scale': 1.2,
    'origin_weight': 3.0,
    'probe_weight': 1.0,
    'blur_border': True,
    'origin_rotation': True,
    'origin_rotation_range': 90,
    'probe_rotation': True,
    'probe_rotation_range': 90,
}
_synthetic_forgery_kwargs = {
    kw: getattr(config, f'EXTRA_FORGERIES_{kw.upper()}', default)
    for kw, default in _SYNTHETIC_FORGERY_DEFAULTS.items()
}

try:
    train_synthetic_forgery = SyntheticCopyMoveForgery(**_synthetic_forgery_kwargs)
except Exception as e:
    # Deliberate best-effort: a construction failure (e.g. an invalid setting)
    # disables this augmentation instead of aborting the notebook, and the
    # reason is reported.
    train_synthetic_forgery = None
    print(f"\nSynthetic Copy-Move Forgery Augmentation: Disabled (error: {e})")
else:
    _sf = _synthetic_forgery_kwargs
    print("\nSynthetic Copy-Move Forgery Augmentation:")
    print("  Enabled for training: YES")
    print(f"  Probability: {_sf['prob']}")
    print(f"  Regions: {_sf['min_regions']}-{_sf['max_regions']}")
    print(f"  Copies per region: {_sf['min_copies']}-{_sf['max_copies']}")
    print(f"  Origin weight: {_sf['origin_weight']}x | Probe weight: {_sf['probe_weight']}x")

# ============================================================================
# DATASET INITIALIZATION
# ============================================================================
# Arguments shared by every split (same sources, same mask-area filtering).
_shared_dataset_kwargs = dict(
    dataset_root=config.DATASET_ROOT,
    img_size=config.IMG_SIZE,
    datasets_to_load=config.DATASETS_TO_LOAD,
    datasets_to_skip=config.DATASETS_TO_SKIP,
    filter_by_mask_area=config.FILTER_BY_MASK_AREA,
    min_mask_area_percent=config.MIN_MASK_AREA_PERCENT,
    max_mask_area_percent=config.MAX_MASK_AREA_PERCENT,
)

# Training dataset with augmentations (loads from train/ subfolders)
train_dataset = UnifiedDataset(
    **_shared_dataset_kwargs,
    transform=train_image_transform,
    joint_transform=train_joint_transform,
    mask_blur_radius=config.MASK_GAUSSIAN_BLUR_RADIUS,
    split='train',  # Load from train/ subfolders
    synthetic_forgery=train_synthetic_forgery,
)

# Validation dataset without augmentations (no synthetic forgeries, loads from test/ subfolders)
val_dataset = UnifiedDataset(
    **_shared_dataset_kwargs,
    transform=val_image_transform,  # None (no color augmentation)
    joint_transform=val_joint_transform,  # None (no spatial augmentation)
    mask_blur_radius=0.0,  # No mask blur for validation
    split='test',  # Load from test/ subfolders
    synthetic_forgery=None,  # No synthetic forgeries for validation
)

# Full dataset for statistics (loads from both train and test, no augmentation)
full_dataset = UnifiedDataset(
    **_shared_dataset_kwargs,
    transform=None,
    joint_transform=None,
    mask_blur_radius=0.0,
    split='all',  # Load from both train/ and test/ subfolders
)

print(f"\nTotal images (train + test): {len(full_dataset)}")
# Count images with / without forged regions straight from the mask files
# (images are never decoded). Mask pixels > 0 mark forged regions (they are the
# 'Forgery' class in the pixel-level confusion matrix), so an image "has
# forgeries" exactly when its mask is non-empty. The previous code counted
# empty masks here, which contradicts the printed labels.
# NOTE(review): the dataset's class_label uses the opposite rule
# (empty mask -> 1.0 "forged"); confirm which convention is intended.
forged_count = sum(
    1 for mask_path in full_dataset.mask_paths if np.any(np.load(mask_path) > 0)
)
original_count = len(full_dataset) - forged_count
print(f"  Images with forgeries: {forged_count}")
print(f"  Images without forgeries: {original_count}")

print(f"\nTrain dataset: {len(train_dataset)} images")
print(f"Val dataset: {len(val_dataset)} images")

# Smoke-test one training sample (shapes / dtypes). Guarded so that an empty
# training split yields a clear message instead of a bare IndexError.
if len(train_dataset) > 0:
    sample_img, sample_mask, sample_class, sample_weight = train_dataset[0]
    print(f"\nSample shape - Image: {sample_img.shape}, Mask: {sample_mask.shape}, Class: {sample_class}, Weight: {sample_weight.shape}")
    print(f"Image dtype: {sample_img.dtype}, Mask dtype: {sample_mask.dtype}, Class dtype: {sample_class.dtype}, Weight dtype: {sample_weight.dtype}")
else:
    print("\n⚠ Train dataset is empty - check DATASET_ROOT, DATASETS_TO_LOAD/SKIP and the mask-area filter settings")

# ============================================================================
# DATASET VISUALIZATION
# ============================================================================
import matplotlib.pyplot as plt

print("\n" + "=" * 60)
print("DATASET VISUALIZATION")
print("=" * 60)

# ImageNet statistics, used to undo the default (non z-score) normalization.
_VIS_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
_VIS_IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def _to_display_rgb(img_hwc):
    """Map a normalized (H, W, C) float image back into [0, 1] for imshow.

    ImageNet normalization is undone with its mean/std. Per-image z-scoring
    does not keep the original statistics, so in that mode each channel is
    min-max stretched instead. Un-normalizing a z-scored image with ImageNet
    statistics would show a misleading picture.
    """
    if config.USE_PER_IMAGE_ZSCORE:
        lo = img_hwc.min(axis=(0, 1), keepdims=True)
        hi = img_hwc.max(axis=(0, 1), keepdims=True)
        img_hwc = (img_hwc - lo) / np.maximum(hi - lo, 1e-6)
    else:
        img_hwc = img_hwc * _VIS_IMAGENET_STD + _VIS_IMAGENET_MEAN
    return np.clip(img_hwc, 0, 1)


# Visualize the first `num_samples` samples of the combined dataset.
num_samples = 5
selected_indices = list(range(min(num_samples, len(full_dataset))))
# One row per selected sample (no blank rows for small datasets);
# squeeze=False keeps `axes` 2-D even when there is a single row.
num_rows = max(len(selected_indices), 1)
fig, axes = plt.subplots(num_rows, 3, figsize=(12, 4 * num_rows), squeeze=False)

for row_idx, dataset_idx in enumerate(selected_indices):
    img, mask, class_label, weight_map = full_dataset[dataset_idx]

    # Convert to numpy for visualization
    img_np = _to_display_rgb(img.permute(1, 2, 0).cpu().numpy())  # (C, H, W) -> (H, W, C)
    mask_np = mask.squeeze().cpu().numpy()  # (1, H, W) -> (H, W)

    # Mask pixels > 0 mark forged regions (the 'Forgery' class of the pixel
    # confusion matrix), so a non-empty mask means the image contains a forgery.
    # NOTE(review): the dataset's class_label uses the opposite rule
    # (empty mask -> forged); confirm which convention is intended.
    is_forged = bool(mask_np.any())
    image_type = "Forged" if is_forged else "Original/Authentic"
    mask_pixels = mask_np.sum()

    # Original image
    axes[row_idx, 0].imshow(img_np)
    axes[row_idx, 0].set_title(f'{image_type} Image\n({mask_pixels:.0f} mask pixels)')
    axes[row_idx, 0].axis('off')

    # Ground truth mask
    axes[row_idx, 1].imshow(mask_np, cmap='gray')
    axes[row_idx, 1].set_title(f'Ground Truth Mask\n({mask_pixels:.0f} pixels)')
    axes[row_idx, 1].axis('off')

    # Overlay: mask painted into the red channel at 30% opacity
    overlay = img_np.copy()
    mask_colored = np.zeros_like(overlay)
    mask_colored[:, :, 0] = mask_np
    overlay = np.clip(overlay * 0.7 + mask_colored * 0.3, 0, 1)
    axes[row_idx, 2].imshow(overlay)
    axes[row_idx, 2].set_title(f'{image_type} + Mask Overlay')
    axes[row_idx, 2].axis('off')

plt.tight_layout()
plt.savefig('dataset_samples_visualization.jpg', dpi=150, bbox_inches='tight', format='jpg')
print("Dataset visualization saved to: dataset_samples_visualization.jpg")
plt.show()

In [None]:
# ============================================================================
# DATASET SPLIT AND SETUP
# ============================================================================

from torch.utils.data import DataLoader
import torch

print("=" * 60)
print("DATASET SPLIT AND SETUP")
print("=" * 60)

# Splits come straight from the folder layout: training samples (augmented,
# with synthetic forgeries) from train/, clean validation samples from test/.
train_dataset_split, val_dataset_split = train_dataset, val_dataset

print("\nDataset Split:")
print(f"  Train: {len(train_dataset_split)} images (from train/ folders with augmentations)")
print(f"  Val: {len(val_dataset_split)} images (from test/ folders, no augmentations)")
print("  Note: Train uses augmented data with synthetic forgeries, Val uses clean test data")

# Both loaders share every setting except shuffling, which is training-only.
# (Per the original note, these may be rebuilt later by curriculum learning.)
_loader_settings = {
    'batch_size': config.BATCH_SIZE,
    'num_workers': config.NUM_WORKERS,
    'pin_memory': config.PIN_MEMORY,
}
train_loader = DataLoader(train_dataset_split, shuffle=True, **_loader_settings)
val_loader = DataLoader(val_dataset_split, shuffle=False, **_loader_settings)

print("\nDataLoaders created:")
print(f"  Train loader: {len(train_loader)} batches")
print(f"  Validation loader: {len(val_loader)} batches")

# Per-epoch metric history: one list per (split, metric) pair, in the order
# train_loss, train_iou, train_f1, val_loss, val_iou, val_f1.
history = {
    f'{split}_{metric}': []
    for split in ('train', 'val')
    for metric in ('loss', 'iou', 'f1')
}

# Early-stopping bookkeeping.
patience = config.PATIENCE
patience_counter, best_val_loss, best_model_state = 0, float('inf'), None

print("\n" + "=" * 60)
print("Setup complete! Ready for training.")
print("=" * 60)



In [None]:
import torch
import torch.optim as optim
from torch.amp import autocast, GradScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import os
import gc

# Compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Mixed-precision dtype: prefer bfloat16, fall back to float16 on GPUs
# without bf16 support, and disable AMP entirely without CUDA.
if not torch.cuda.is_available():
    amp_dtype = None
    print("⚠ CUDA not available - AMP disabled")
elif torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
    print("✓ GPU supports bfloat16 - Using AMP with bfloat16")
else:
    amp_dtype = torch.float16
    print("⚠ GPU does not support bfloat16 - Falling back to float16")

# Gradient scaler for AMP; only created on CUDA.
if device.type == 'cuda':
    scaler = GradScaler('cuda')
else:
    scaler = None

# Output directories (created if missing).
for _out_dir in (config.OUTPUT_DIR, config.VIZ_DIR, config.CHECKPOINT_DIR):
    os.makedirs(_out_dir, exist_ok=True)

print("\nOutput directories created:")
print(f"  Main output: {config.OUTPUT_DIR}/")
print(f"  Visualizations: {config.VIZ_DIR}/")
print(f"  Checkpoints: {config.CHECKPOINT_DIR}/")

# Use loss functions, metrics, and visualization functions from helper cells (Cells 3 and 4)
# Loss functions: CombinedLossWithDeepSupervision (Cell 3)
# Metrics: calculate_iou, calculate_f1_score, calculate_metrics (Cell 3)
# Visualization functions: save_visualization (Cell 4)
    
# Segmentation loss built from the helper cell's CombinedLossWithDeepSupervision;
# all weights and switches come from config.
criterion = CombinedLossWithDeepSupervision(
    focal_weight=config.FOCAL_WEIGHT,
    dice_weight=config.DICE_WEIGHT,
    dice_smooth=config.DICE_SMOOTH,
    focal_alpha=config.FOCAL_ALPHA,
    focal_gamma=config.FOCAL_GAMMA,
    deep_supervision_weights=None,  # Deep supervision disabled
    use_dice_loss=config.USE_DICE_LOSS,
)

# Optimizer chosen by name (case-insensitive) from a small dispatch table.
_OPTIMIZER_BUILDERS = {
    'adam': lambda params: optim.Adam(
        params, lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY),
    'adamw': lambda params: optim.AdamW(
        params, lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY),
    'sgd': lambda params: optim.SGD(
        params, lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY, momentum=0.9),
}
_build_optimizer = _OPTIMIZER_BUILDERS.get(config.OPTIMIZER.lower())
if _build_optimizer is None:
    raise ValueError(f"Unknown optimizer: {config.OPTIMIZER}. Use 'adam', 'adamw', or 'sgd'.")
optimizer = _build_optimizer(model.parameters())

# ============================================================================
# TRAINING LOOP
# ============================================================================

print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)
print(f"Total epochs: {config.NUM_EPOCHS}")
print("=" * 60)

# Autocast is enabled only when an AMP dtype was selected (CUDA present) and it
# always targets the actual device, so CPU runs never request a 'cuda'
# autocast region.
amp_enabled = amp_dtype is not None

# Empty the CUDA cache and run the Python GC every N training batches to keep
# memory from building up during long epochs.
CACHE_CLEAR_EVERY = 10


def _main_head_probs(outputs):
    """Return a detached copy of the main head's sigmoid probabilities.

    With deep supervision the model returns a tuple whose first element is the
    main output; otherwise it returns the logits tensor directly.
    """
    logits = outputs[0] if isinstance(outputs, tuple) else outputs
    return torch.sigmoid(logits.detach().clone())


def _save_history_plot(hist, path):
    """Plot train/val loss and IoU curves side by side and save them to *path*."""
    plt.figure(figsize=(12, 5))
    for position, key, title, ylabel in ((1, 'loss', 'Loss Evolution', 'Loss'),
                                          (2, 'iou', 'IoU Evolution', 'IoU')):
        plt.subplot(1, 2, position)
        plt.plot(hist[f'train_{key}'], label=f'Train {ylabel}', marker='o')
        plt.plot(hist[f'val_{key}'], label=f'Val {ylabel}', marker='o')
        plt.legend()
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(path, dpi=150, bbox_inches='tight', format='jpg')
    plt.close()  # Called every epoch: release the figure's memory


for epoch in range(config.NUM_EPOCHS):
    # ----------------------------------------------------------- Training
    model.train()
    train_loss = 0.0
    train_iou_sum = 0.0
    train_f1_sum = 0.0
    train_batches = 0
    train_images_batch = None
    train_masks_batch = None
    train_outputs_batch = None

    train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.NUM_EPOCHS} [Train]')

    # Batch kept for visualization: the second-to-last one (the last may be a
    # short partial batch), or batch 0 when the loader has only one batch.
    total_train_batches = len(train_loader)
    train_viz_idx = max(total_train_batches - 2, 0)

    for batch_idx, (images, masks, class_labels, weight_maps) in enumerate(train_pbar):
        images = images.to(device)
        masks = masks.to(device)
        weight_maps = weight_maps.to(device)

        optimizer.zero_grad()

        # Forward pass + weighted segmentation loss under AMP (bf16/fp16) when enabled.
        with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
            mask_outputs = model(images)
            loss = criterion(mask_outputs, masks, weight_maps)

        # Backward pass with gradient-norm clipping at 1.0. With a GradScaler
        # the gradients must be unscaled before clipping.
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        if batch_idx == train_viz_idx:
            train_images_batch = images.detach().clone()
            train_masks_batch = masks.detach().clone()
            train_outputs_batch = _main_head_probs(mask_outputs)

        # Metrics via the helper functions
        train_loss += loss.item()
        iou, f1 = calculate_metrics(mask_outputs, masks, threshold=config.PREDICTION_THRESHOLD)
        train_iou_sum += iou.item()
        train_f1_sum += f1.item()
        train_batches += 1

        current_lr = optimizer.param_groups[0]['lr']
        train_pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'iou': f'{iou.item():.4f}',
            'f1': f'{f1.item():.4f}',
            'lr': f'{current_lr:.2e}'
        })

        if (batch_idx + 1) % CACHE_CLEAR_EVERY == 0:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    # Averages over the batches actually processed (0.0 for an empty loader
    # instead of a ZeroDivisionError).
    avg_train_loss = train_loss / train_batches if train_batches > 0 else 0.0
    train_iou = train_iou_sum / train_batches if train_batches > 0 else 0.0
    train_f1 = train_f1_sum / train_batches if train_batches > 0 else 0.0

    history['train_loss'].append(avg_train_loss)
    history['train_iou'].append(train_iou)
    history.setdefault('train_f1', []).append(train_f1)

    if train_images_batch is not None:
        try:
            save_visualization(train_images_batch, train_masks_batch, train_outputs_batch,
                               epoch + 1, split='train')
        except Exception as e:
            print(f"  Warning: Failed to save training visualizations: {e}")
    else:
        print(f"  Warning: No training batch captured for visualization in epoch {epoch + 1}")

    # --------------------------------------------------------- Validation
    model.eval()
    val_loss = 0.0
    val_iou_sum = 0.0
    val_f1_sum = 0.0
    val_batches = 0
    val_images_batch = None
    val_masks_batch = None
    val_outputs_batch = None

    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{config.NUM_EPOCHS} [Val]')

        total_val_batches = len(val_loader)
        val_viz_idx = max(total_val_batches - 2, 0)

        for batch_idx, (images, masks, class_labels, weight_maps) in enumerate(val_pbar):
            images = images.to(device)
            masks = masks.to(device)
            weight_maps = weight_maps.to(device)

            with autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_enabled):
                mask_outputs = model(images)
                loss = criterion(mask_outputs, masks, weight_maps)

            if batch_idx == val_viz_idx:
                val_images_batch = images.clone()
                val_masks_batch = masks.clone()
                val_outputs_batch = _main_head_probs(mask_outputs)

            val_loss += loss.item()
            iou, f1 = calculate_metrics(mask_outputs, masks, threshold=config.PREDICTION_THRESHOLD)
            val_iou_sum += iou.item()
            val_f1_sum += f1.item()
            val_batches += 1

            current_lr = optimizer.param_groups[0]['lr']
            val_pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'iou': f'{iou.item():.4f}',
                'f1': f'{f1.item():.4f}',
                'lr': f'{current_lr:.2e}'
            })

    # An empty validation pass yields an infinite loss so that it can never
    # count as an improvement for early stopping.
    avg_val_loss = val_loss / val_batches if val_batches > 0 else float('inf')
    val_iou = val_iou_sum / val_batches if val_batches > 0 else 0.0
    val_f1 = val_f1_sum / val_batches if val_batches > 0 else 0.0

    history['val_loss'].append(avg_val_loss)
    history['val_iou'].append(val_iou)
    history.setdefault('val_f1', []).append(val_f1)

    if val_images_batch is not None:
        try:
            save_visualization(val_images_batch, val_masks_batch, val_outputs_batch,
                               epoch + 1, split='val')
        except Exception as e:
            print(f"  Warning: Failed to save validation visualizations: {e}")
    else:
        print(f"  Warning: No validation batch captured for visualization in epoch {epoch + 1}")

    print(f'\nEpoch {epoch+1}/{config.NUM_EPOCHS}:')
    print(f'  Train - Loss: {avg_train_loss:.4f}, IoU: {train_iou:.4f}, F1: {train_f1:.4f}')
    print(f'  Val   - Loss: {avg_val_loss:.4f}, IoU: {val_iou:.4f}, F1: {val_f1:.4f}')
    print(f'  Visualizations saved to {config.VIZ_DIR}/')

    # Always save the latest weights (overwriting the previous checkpoint).
    checkpoint_path = os.path.join(config.CHECKPOINT_DIR, config.MODEL_SAVE_NAME)
    torch.save(model.state_dict(), checkpoint_path)
    print(f'  ✓ Checkpoint saved: {checkpoint_path} (Val Loss: {avg_val_loss:.4f})')

    # Track the best validation loss for early stopping / final restore.
    if avg_val_loss < best_val_loss - config.MIN_DELTA:
        best_val_loss = avg_val_loss
        patience_counter = 0
        # Snapshot real copies of the weights. state_dict() returns tensors that
        # share storage with the live parameters, and dict.copy() is shallow,
        # so a plain copy keeps changing with every later optimizer step (the
        # "best" model restored after training would really be the last one).
        # Copies are kept on CPU so the snapshot does not occupy GPU memory.
        best_model_state = {
            name: tensor.detach().to('cpu', copy=True)
            for name, tensor in model.state_dict().items()
        }
        print(f'  ✓ New best validation loss: {best_val_loss:.4f}')
    else:
        patience_counter += 1
        print(f'  No improvement. Patience: {patience_counter}/{patience}')

    # Training-history graph, overwritten each epoch.
    history_plot_path = os.path.join(config.OUTPUT_DIR, 'training_history.jpg')
    _save_history_plot(history, history_plot_path)
    print(f'  ✓ Training history graph saved: {history_plot_path}')

    # Clear caches to prevent performance degradation over time.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    if patience_counter >= patience:
        print(f'\nEarly stopping triggered after {epoch+1} epochs')
        break

# Restore the weights of the best validation epoch, if any epoch improved on
# the initial infinite baseline.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f'\nBest model loaded with validation loss: {best_val_loss:.4f}')

# Final interactive view of the loss and IoU curves, side by side.
plt.figure(figsize=(12, 5))
for _pos, _key, _title, _label in ((1, 'loss', 'Loss Evolution', 'Loss'),
                                   (2, 'iou', 'IoU Evolution', 'IoU')):
    plt.subplot(1, 2, _pos)
    plt.plot(history[f'train_{_key}'], label=f'Train {_label}')
    plt.plot(history[f'val_{_key}'], label=f'Val {_label}')
    plt.legend()
    plt.title(_title)
    plt.xlabel('Epoch')
    plt.ylabel(_label)

plt.tight_layout()
plt.show()

In [None]:
# ============================================================================
# INFERENCE ON TEST DATASET
# ============================================================================
# Load checkpoint from config and evaluate on combined_dataset/science-fraud_copymove/test

import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from tqdm import tqdm

print("=" * 60)
print("INFERENCE ON TEST DATASET")
print("=" * 60)

# Configuration for inference (checkpoint path comes from config)
INFERENCE_CHECKPOINT = os.path.join(config.CHECKPOINT_DIR, config.MODEL_SAVE_NAME)
TEST_DATASET_ROOT = 'combined_dataset'  # Dataset root
TEST_DATASET_NAME = 'science-fraud_copymove'  # Specific dataset
TEST_SPLIT = 'test'  # Test split
PREDICTION_THRESHOLD = 0.5  # Threshold for binary predictions
BATCH_SIZE = 8  # Batch size for inference

if not os.path.exists(INFERENCE_CHECKPOINT):
    raise FileNotFoundError(f"Checkpoint not found: {INFERENCE_CHECKPOINT}")

print(f"\n✓ Checkpoint found: {INFERENCE_CHECKPOINT}")

# Create model with the same architecture flags as training
print("\nCreating model...")
inference_model = MaxViT_CBAM_UNet(
    in_channels=config.IN_CHANNELS,
    out_channels=config.OUT_CHANNELS,
    r=config.CBAM_REDUCTION,
    skip_pretrained=True,  # Weights come from the checkpoint
    use_deep_supervision=config.USE_DEEP_SUPERVISION,
    use_gradient_checkpointing=False,  # Not needed for inference
    use_atrous_pyramid_bottleneck=config.USE_ATROUS_PYRAMID_BOTTLENECK,
    use_identity_bottleneck=config.USE_IDENTITY_BOTTLENECK
).to(device)

print(f"Loading checkpoint: {INFERENCE_CHECKPOINT}")
checkpoint = torch.load(INFERENCE_CHECKPOINT, map_location=device)

# Accept either a bare state dict or a {'model_state_dict': ..., 'epoch': ...} bundle
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model_state_dict = checkpoint['model_state_dict']
    print(f"✓ Loaded complete checkpoint (epoch: {checkpoint.get('epoch', 'N/A')})")
else:
    model_state_dict = checkpoint
    print("✓ Loaded state dict checkpoint")

# strict=False tolerates key mismatches, but ignoring them silently would let a
# checkpoint from a different architecture run with freshly initialised layers
# while still reporting success. Report anything that did not line up.
load_result = inference_model.load_state_dict(model_state_dict, strict=False)
if load_result.missing_keys:
    print(f"⚠ {len(load_result.missing_keys)} model keys not found in checkpoint "
          f"(left at their initial values), e.g. {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
    print(f"⚠ {len(load_result.unexpected_keys)} checkpoint keys not used by the model, "
          f"e.g. {load_result.unexpected_keys[:5]}")
inference_model.eval()
print(f"✓ Model loaded successfully on {device}")

# Test dataset: no augmentations, no synthetic forgeries, no mask blur/filtering
print(f"\nLoading test dataset: {TEST_DATASET_ROOT}/{TEST_DATASET_NAME}/{TEST_SPLIT}")
test_dataset = UnifiedDataset(
    dataset_root=TEST_DATASET_ROOT,
    img_size=config.IMG_SIZE,
    transform=None,
    joint_transform=None,
    mask_blur_radius=0.0,
    split=TEST_SPLIT,
    datasets_to_load=[TEST_DATASET_NAME],  # Only load the specified dataset
    datasets_to_skip=[],
    filter_by_mask_area=False,
    synthetic_forgery=None
)

print(f"✓ Test dataset loaded: {len(test_dataset)} images")

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=torch.cuda.is_available()
)

# Run inference
print("\nRunning inference...")
all_predictions = []
all_targets = []
all_probs = []
iou_per_image = []
f1_per_image = []

with torch.no_grad():
    for images, masks, _, _ in tqdm(test_loader, desc="Inference"):
        images = images.to(device)
        masks = masks.to(device)

        outputs = inference_model(images)
        # Deep-supervision models may return (main, aux, ...). Score the main
        # head, matching how the training loop treats tuple outputs.
        if isinstance(outputs, (tuple, list)):
            outputs = outputs[0]

        probs = torch.sigmoid(outputs)
        predictions = (probs > PREDICTION_THRESHOLD).float()

        all_predictions.append(predictions.cpu())
        all_targets.append(masks.cpu())
        all_probs.append(probs.cpu())

        # Per-image metrics for the whole batch at once, in float64, with no
        # per-image .item() host syncs. Each image is flattened to one row.
        pred_b = predictions.flatten(1).double()
        target_b = masks.flatten(1).double()
        tp_b = (pred_b * target_b).sum(dim=1)
        fp_b = (pred_b * (1 - target_b)).sum(dim=1)
        fn_b = ((1 - pred_b) * target_b).sum(dim=1)
        union_b = pred_b.sum(dim=1) + target_b.sum(dim=1) - tp_b

        # IoU: an image with an empty prediction AND an empty target is a
        # perfect result and scores 1.0. This agrees with the per-image F1
        # below, which also scores that case as ~1; previously it scored 0.0.
        iou_b = torch.where(union_b > 0, tp_b / (union_b + 1e-6), torch.ones_like(union_b))

        precision_b = (tp_b + 1e-6) / (tp_b + fp_b + 1e-6)
        recall_b = (tp_b + 1e-6) / (tp_b + fn_b + 1e-6)
        f1_b = (2 * precision_b * recall_b) / (precision_b + recall_b + 1e-6)

        iou_per_image.extend(iou_b.tolist())
        f1_per_image.extend(f1_b.tolist())

# Concatenate all results
all_predictions = torch.cat(all_predictions, dim=0)
all_targets = torch.cat(all_targets, dim=0)
all_probs = torch.cat(all_probs, dim=0)

print(f"\n✓ Inference complete: {len(all_predictions)} images processed")

# ============================================================================
# CALCULATE METRICS
# ============================================================================

print("\n" + "=" * 60)
print("CALCULATING METRICS")
print("=" * 60)

# 1. Forgery IoU: mean of the per-image IoU values collected during inference.
forgery_iou = np.mean(iou_per_image)
print(f"\n1. Forgery IoU: {forgery_iou:.4f}")

# 2. Pixel F1: a single F1 over every test pixel pooled together.
pred_flat = all_predictions.view(-1).numpy()
target_flat = all_targets.view(-1).numpy()


def _count_pixels(pred, target, pred_value, target_value):
    """Count pixels where pred == pred_value and target == target_value."""
    return np.sum((pred == pred_value) & (target == target_value))


tp = _count_pixels(pred_flat, target_flat, 1, 1)
fp = _count_pixels(pred_flat, target_flat, 1, 0)
fn = _count_pixels(pred_flat, target_flat, 0, 1)
tn = _count_pixels(pred_flat, target_flat, 0, 0)

# The +1e-6 terms keep the ratios defined when a denominator is zero.
pixel_precision = (tp + 1e-6) / (tp + fp + 1e-6)
pixel_recall = (tp + 1e-6) / (tp + fn + 1e-6)
pixel_f1 = (2 * pixel_precision * pixel_recall) / (pixel_precision + pixel_recall + 1e-6)
for _line in (f"2. Pixel F1: {pixel_f1:.4f}",
              f"   Pixel Precision: {pixel_precision:.4f}",
              f"   Pixel Recall: {pixel_recall:.4f}"):
    print(_line)

# 3. Macro Image F1: per-image F1 first, then the mean over images.
macro_image_f1 = np.mean(f1_per_image)
print(f"3. Macro Image F1: {macro_image_f1:.4f}")

# 4. Pixel-level confusion matrix (rows = actual, columns = predicted).
print("\n4. Confusion Matrix (Pixel-level):")
print(f"   True Positives (TP): {tp:,}")
print(f"   False Positives (FP): {fp:,}")
print(f"   False Negatives (FN): {fn:,}")
print(f"   True Negatives (TN): {tn:,}")

cm = np.array([[tn, fp],
               [fn, tp]])

print("\n   Confusion Matrix:")
print(f"   {'':>15} {'Predicted 0':>15} {'Predicted 1':>15}")
print(f"   {'Actual 0':>15} {tn:>15,} {fp:>15,}")
print(f"   {'Actual 1':>15} {fn:>15,} {tp:>15,}")

# ============================================================================
# VISUALIZE CONFUSION MATRIX
# ============================================================================

print("\n" + "=" * 60)
print("VISUALIZING RESULTS")
print("=" * 60)

# Plot confusion matrix
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Confusion matrix heatmap
ax1 = axes[0]
cm_display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Background', 'Forgery'])
cm_display.plot(ax=ax1, cmap='Blues', values_format='d')
ax1.set_title('Confusion Matrix (Pixel-level)', fontsize=14, fontweight='bold')
ax1.set_xlabel('Predicted', fontsize=12)
ax1.set_ylabel('Actual', fontsize=12)

# Metrics summary
ax2 = axes[1]
ax2.axis('off')
metrics_text = f"""
METRICS SUMMARY

Forgery IoU:     {forgery_iou:.4f}
Pixel F1:        {pixel_f1:.4f}
Macro Image F1:  {macro_image_f1:.4f}

Pixel-level Statistics:
  Precision:     {pixel_precision:.4f}
  Recall:        {pixel_recall:.4f}
  
Confusion Matrix:
  TP: {tp:,}
  FP: {fp:,}
  FN: {fn:,}
  TN: {tn:,}
  
Total Pixels:    {len(pred_flat):,}
Forgery Pixels:  {np.sum(target_flat):,} ({100*np.sum(target_flat)/len(target_flat):.2f}%)
"""
ax2.text(0.1, 0.5, metrics_text, fontsize=11, family='monospace',
         verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.savefig('inference_results.png', dpi=150, bbox_inches='tight')
print(f"\n✓ Results saved to: inference_results.png")
plt.show()

# Print final summary
print("\n" + "=" * 60)
print("FINAL SUMMARY")
print("=" * 60)
print(f"Dataset: {TEST_DATASET_ROOT}/{TEST_DATASET_NAME}/{TEST_SPLIT}")
print(f"Checkpoint: {INFERENCE_CHECKPOINT}")
print(f"Total Images: {len(test_dataset)}")
print(f"\nMetrics:")
print(f"  Forgery IoU:     {forgery_iou:.4f}")
print(f"  Pixel F1:        {pixel_f1:.4f}")
print(f"  Macro Image F1:  {macro_image_f1:.4f}")
print("=" * 60)


In [None]:
# ============================================================================
# COMPREHENSIVE INFERENCE ON ALL 3 TEST DATASETS
# ============================================================================

import os
import cv2
import numpy as np
import torch
import pandas as pd
from pathlib import Path
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import matplotlib
from tqdm import tqdm
matplotlib.use('Agg')  # Use non-interactive backend for saving

print("=" * 80)
print("COMPREHENSIVE INFERENCE ON ALL 3 TEST DATASETS")
print("=" * 80)

# Test datasets to process
TEST_DATASETS = [
    'casia_copymove',
    'defacto_copymove',
    'science-fraud_copymove'
]

# Load model for inference
INFERENCE_CHECKPOINT = os.path.join(config.CHECKPOINT_DIR, config.MODEL_SAVE_NAME)
if not os.path.exists(INFERENCE_CHECKPOINT):
    raise FileNotFoundError(f"Checkpoint not found: {INFERENCE_CHECKPOINT}")

print(f"\n✓ Loading checkpoint: {INFERENCE_CHECKPOINT}")

# Create inference model. skip_pretrained=True means every weight starts at its
# random initialisation; the checkpoint below is the only source of weights.
inference_model = MaxViT_CBAM_UNet(
    in_channels=config.IN_CHANNELS,
    out_channels=config.OUT_CHANNELS,
    r=config.CBAM_REDUCTION,
    skip_pretrained=True,
    use_gradient_checkpointing=False,
    use_atrous_pyramid_bottleneck=config.USE_ATROUS_PYRAMID_BOTTLENECK,
    use_identity_bottleneck=config.USE_IDENTITY_BOTTLENECK
).to(device)

# Load checkpoint: either a training checkpoint dict holding 'model_state_dict'
# or a bare state dict.
checkpoint = torch.load(INFERENCE_CHECKPOINT, map_location=device)
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model_state_dict = checkpoint['model_state_dict']
else:
    model_state_dict = checkpoint

# strict=False tolerates partial mismatches, but ignoring its result would let a
# key mismatch (e.g. a 'module.' or '_orig_mod.' prefix) silently evaluate the
# randomly initialised network. Report every mismatch, and refuse to continue
# when not a single model key was found in the checkpoint.
load_result = inference_model.load_state_dict(model_state_dict, strict=False)
expected_keys = inference_model.state_dict().keys()
if expected_keys and len(load_result.missing_keys) == len(expected_keys):
    raise RuntimeError(
        f"No weights from {INFERENCE_CHECKPOINT} matched the model "
        f"(first checkpoint keys: {list(model_state_dict)[:5]}). "
        "Check for key prefixes such as 'module.' or '_orig_mod.'."
    )
if load_result.missing_keys:
    print(f"⚠️ {len(load_result.missing_keys)} model keys missing from checkpoint "
          f"(left at initialisation), e.g. {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
    print(f"⚠️ {len(load_result.unexpected_keys)} checkpoint keys unused by the model, "
          f"e.g. {load_result.unexpected_keys[:5]}")

inference_model.eval()
print("✓ Model loaded and set to eval mode")

# Create output directory for visualizations
VIZ_OUTPUT_DIR = Path('outputs_unet/test_inference_strips')
VIZ_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Function to create image strip (image + gt + pred + overlay)
def create_image_strip(img_array, gt_mask, pred_mask, output_path, img_name, alpha=0.4):
    """
    Create a 4-panel strip: Original | GT Mask | Pred Mask | Overlay

    Args:
        img_array: Image as an (H, W, 3) RGB array, an (H, W, 4) RGBA array
            (alpha channel dropped), or an (H, W) / (H, W, 1) grayscale array.
            Integer arrays (e.g. uint8 from cv2.imread) are scaled to [0, 1] by
            1/255; float arrays are scaled only if their maximum exceeds 1.0.
        gt_mask: (H, W) ground-truth mask; values > 0.5 are treated as forgery.
        pred_mask: (H, W) predicted mask; values > 0.5 are treated as forgery.
        output_path: Destination path; the figure is written as JPEG.
        img_name: Title shown above the strip.
        alpha: Weight of the colored masks in the overlay panel. The image is
            dimmed by ``1 - alpha`` everywhere, then GT is added in red and the
            prediction in blue.
    """
    img = np.asarray(img_array)
    # Decide the scale from the dtype first: a uint8 image whose pixels happen
    # to be <= 1 is still on a 0-255 scale and must not be read as [0, 1].
    if np.issubdtype(img.dtype, np.integer) or img.max() > 1.0:
        img = img.astype(np.float32) / 255.0
    else:
        img = img.astype(np.float32)

    # Normalise the channel layout to RGB so the overlay can color channels 0/2.
    if img.ndim == 2:
        img = img[..., np.newaxis]
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]

    # Ensure masks are binary [0, 1]
    gt = (np.asarray(gt_mask) > 0.5).astype(np.float32)
    pred = (np.asarray(pred_mask) > 0.5).astype(np.float32)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    try:
        # Panel 1: Original image
        axes[0].imshow(img)
        axes[0].set_title('Original', fontsize=12, fontweight='bold')

        # Panel 2: Ground truth mask
        axes[1].imshow(gt, cmap='gray', vmin=0, vmax=1)
        axes[1].set_title(f'GT Mask\n({int(gt.sum())} px)', fontsize=12, fontweight='bold')

        # Panel 3: Predicted mask
        axes[2].imshow(pred, cmap='gray', vmin=0, vmax=1)
        axes[2].set_title(f'Pred Mask\n({int(pred.sum())} px)', fontsize=12, fontweight='bold')

        # Panel 4: Overlay (GT=red, Pred=blue)
        overlay = img * (1 - alpha)
        overlay[..., 0] += alpha * gt
        overlay[..., 2] += alpha * pred
        axes[3].imshow(np.clip(overlay, 0, 1))
        axes[3].set_title('Overlay\n(Red=GT, Blue=Pred)', fontsize=12, fontweight='bold')

        for ax in axes:
            ax.axis('off')

        fig.suptitle(f'{img_name}', fontsize=14, fontweight='bold', y=1.02)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight', format='jpg')
    finally:
        # Always release the figure. Callers catch per-image exceptions and keep
        # going, so a figure leaked on error would accumulate across calls.
        plt.close(fig)

def _per_image_iou_f1(predictions, masks, eps=1e-6):
    """Return per-image IoU and F1 lists for one batch.

    Args:
        predictions: Binary (0/1) prediction tensor with the batch on dim 0.
        masks: Ground-truth tensor with the same number of elements per image.
        eps: Smoothing term used in the IoU / precision / recall / F1 formulas.

    Returns:
        (ious, f1s): two lists of Python floats, one entry per image.

    TP/FP/FN are reduced on the tensors' device and copied to the host once per
    batch instead of three ``.item()`` syncs per image. An image whose
    prediction and ground truth are both empty is a perfect match. The smoothed
    F1 below evaluates to ~1.0 for it, so its IoU is reported as 1.0 as well
    (it used to be 0.0, which contradicted the F1 for the same image).
    """
    pred = predictions.flatten(1)
    target = masks.flatten(1)
    tp_b = (pred * target).sum(dim=1)
    fp_b = (pred * (1 - target)).sum(dim=1)
    fn_b = ((1 - pred) * target).sum(dim=1)

    ious, f1s = [], []
    for tp_i, fp_i, fn_i in zip(tp_b.tolist(), fp_b.tolist(), fn_b.tolist()):
        union_i = tp_i + fp_i + fn_i  # == |pred| + |target| - |pred ∩ target|
        ious.append(tp_i / (union_i + eps) if union_i > 0 else 1.0)
        precision_i = (tp_i + eps) / (tp_i + fp_i + eps)
        recall_i = (tp_i + eps) / (tp_i + fn_i + eps)
        f1s.append((2 * precision_i * recall_i) / (precision_i + recall_i + eps))
    return ious, f1s


# Batch size for test inference. Per-image results are looked up by their
# global index in the concatenated tensors, so nothing else depends on it.
INFERENCE_BATCH_SIZE = 8

# Process each test dataset
all_results = {}

for dataset_name in TEST_DATASETS:
    print("\n" + "=" * 80)
    print(f"PROCESSING DATASET: {dataset_name}")
    print("=" * 80)

    # Create test dataset (no augmentation, no mask blur, no mask-area filtering)
    test_dataset = UnifiedDataset(
        dataset_root=config.DATASET_ROOT,
        img_size=config.IMG_SIZE,
        transform=None,
        joint_transform=None,
        mask_blur_radius=0.0,
        split='test',
        datasets_to_load=[dataset_name],
        datasets_to_skip=[],
        filter_by_mask_area=False
    )

    if len(test_dataset) == 0:
        print(f"⚠️ No images found in {dataset_name}/test, skipping...")
        continue

    print(f"✓ Found {len(test_dataset)} test images")

    # shuffle=False is essential: image_paths and every per-image lookup below
    # assume the loader yields images in dataset order.
    test_loader = DataLoader(
        test_dataset,
        batch_size=INFERENCE_BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY
    )

    # Run inference
    print("Running inference...")
    all_predictions = []
    all_targets = []
    all_probs = []
    image_paths = []
    iou_per_image = []
    f1_per_image = []
    n_seen = 0  # images consumed so far == global index of the next batch's first image

    with torch.no_grad():
        for images, masks, *_ in tqdm(test_loader, desc=f"Inference {dataset_name}"):
            images = images.to(device)
            masks = masks.to(device)

            # Forward pass; binarize probabilities at 0.5
            outputs = inference_model(images)
            probs = torch.sigmoid(outputs)
            predictions = (probs > 0.5).float()

            # Store results on CPU
            all_predictions.append(predictions.cpu())
            all_targets.append(masks.cpu())
            all_probs.append(probs.cpu())

            # Paths of the images in this batch (slicing clamps at the end)
            batch_size = images.shape[0]
            image_paths.extend(test_dataset.image_paths[n_seen:n_seen + batch_size])
            n_seen += batch_size

            # Per-image metrics
            batch_iou, batch_f1 = _per_image_iou_f1(predictions, masks)
            iou_per_image.extend(batch_iou)
            f1_per_image.extend(batch_f1)

    # Concatenate once; image-level labels and strips index these directly.
    predictions_all = torch.cat(all_predictions, dim=0)
    targets_all = torch.cat(all_targets, dim=0)
    pred_flat = predictions_all.view(-1).numpy()
    target_flat = targets_all.view(-1).numpy()

    # Pixel-level confusion counts (exact integers)
    tp = np.count_nonzero((pred_flat == 1) & (target_flat == 1))
    fp = np.count_nonzero((pred_flat == 1) & (target_flat == 0))
    fn = np.count_nonzero((pred_flat == 0) & (target_flat == 1))
    tn = np.count_nonzero((pred_flat == 0) & (target_flat == 0))

    pixel_precision = tp / (tp + fp + 1e-6)
    pixel_recall = tp / (tp + fn + 1e-6)
    pixel_f1 = 2 * (pixel_precision * pixel_recall) / (pixel_precision + pixel_recall + 1e-6)
    pixel_accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-6)

    # Dataset-level IoU (forgery class) from the integer counts: |P ∩ T| / |P ∪ T|
    intersection = tp
    union = tp + fp + fn
    forgery_iou = (intersection / (union + 1e-6)) if union > 0 else 0.0

    # Image-level means of the per-image metrics
    macro_image_f1 = np.mean(f1_per_image) if len(f1_per_image) > 0 else 0.0
    mean_iou = np.mean(iou_per_image) if len(iou_per_image) > 0 else 0.0

    # Image-level decision: an image counts as forged if it has any forgery pixel
    image_pred_forgery = predictions_all.flatten(1).sum(dim=1).numpy() > 0
    image_gt_forgery = targets_all.flatten(1).sum(dim=1).numpy() > 0
    cm_image = confusion_matrix(image_gt_forgery, image_pred_forgery, labels=[False, True])
    tn_i, fp_i, fn_i, tp_i = cm_image.ravel()
    image_accuracy = (tp_i + tn_i) / (tp_i + tn_i + fp_i + fn_i + 1e-6)
    image_precision = tp_i / (tp_i + fp_i + 1e-6)
    image_recall = tp_i / (tp_i + fn_i + 1e-6)
    image_f1 = 2 * (image_precision * image_recall) / (image_precision + image_recall + 1e-6)

    # Store results
    all_results[dataset_name] = {
        'pixel_f1': pixel_f1,
        'pixel_precision': pixel_precision,
        'pixel_recall': pixel_recall,
        'pixel_accuracy': pixel_accuracy,
        'forgery_iou': forgery_iou,
        'macro_image_f1': macro_image_f1,
        'mean_iou': mean_iou,
        'image_accuracy': image_accuracy,
        'image_precision': image_precision,
        'image_recall': image_recall,
        'image_f1': image_f1,
        'tp': int(tp),
        'fp': int(fp),
        'fn': int(fn),
        'tn': int(tn),
        'tp_i': int(tp_i),
        'fp_i': int(fp_i),
        'fn_i': int(fn_i),
        'tn_i': int(tn_i),
        'cm_image': cm_image,
        'total_images': len(test_dataset),
        'iou_per_image': iou_per_image,
        'f1_per_image': f1_per_image,
        'image_paths': image_paths,
        'predictions': all_predictions,
        'targets': all_targets
    }

    # Print comprehensive review
    print("\n" + "=" * 80)
    print(f"COMPREHENSIVE REVIEW: {dataset_name.upper()}")
    print("=" * 80)
    print(f"\nDataset: {dataset_name}/test")
    print(f"Total Images: {len(test_dataset)}")
    print(f"\n📊 OVERALL METRICS:")
    print(f"  Forgery IoU:        {forgery_iou:.4f}")
    print(f"  Pixel-wise F1:      {pixel_f1:.4f}")
    print(f"  Macro Image F1:     {macro_image_f1:.4f}")
    print(f"  Mean IoU per Image: {mean_iou:.4f}")
    print(f"\n📈 PIXEL-LEVEL STATISTICS:")
    print(f"  Accuracy:           {pixel_accuracy:.4f}")
    print(f"  Precision:          {pixel_precision:.4f}")
    print(f"  Recall:             {pixel_recall:.4f}")
    print(f"\n🔢 CONFUSION MATRIX (Pixel-level):")
    print(f"  {'':>15} {'Predicted 0':>15} {'Predicted 1':>15}")
    print(f"  {'Actual 0':>15} {tn:>15,} {fp:>15,}")
    print(f"  {'Actual 1':>15} {fn:>15,} {tp:>15,}")
    print(f"\n  Total Pixels:       {len(pred_flat):,}")
    print(f"  Forgery Pixels:     {int(target_flat.sum()):,} ({100*target_flat.sum()/len(target_flat):.2f}%)")
    print(f"\n📊 IMAGE-LEVEL STATISTICS:")
    print(f"  Accuracy:           {image_accuracy:.4f}")
    print(f"  Precision:          {image_precision:.4f}")
    print(f"  Recall:             {image_recall:.4f}")
    print(f"  F1 Score:           {image_f1:.4f}")
    print(f"\n🔢 CONFUSION MATRIX (Image-level):")
    print(f"  {'':>15} {'Predicted 0':>15} {'Predicted 1':>15}")
    print(f"  {'Actual 0':>15} {tn_i:>15,} {fp_i:>15,}")
    print(f"  {'Actual 1':>15} {fn_i:>15,} {tp_i:>15,}")
    print("=" * 80)

    # Create output directory (shared by confusion matrices and image strips)
    cm_output_dir = VIZ_OUTPUT_DIR / dataset_name
    cm_output_dir.mkdir(parents=True, exist_ok=True)

    # Save Image-level confusion matrix
    fig_image, ax_image = plt.subplots(figsize=(8, 6))
    disp_image = ConfusionMatrixDisplay(confusion_matrix=cm_image,
                                        display_labels=['Authentic', 'Forgery'])
    disp_image.plot(ax=ax_image, cmap='Blues', values_format='d')
    ax_image.set_title(f'{dataset_name.upper()} - Image-Level Confusion Matrix\nAccuracy: {image_accuracy:.4f}, F1: {image_f1:.4f}',
                       fontsize=14, fontweight='bold', pad=20)
    ax_image.set_xlabel('Predicted', fontsize=12, fontweight='bold')
    ax_image.set_ylabel('Actual', fontsize=12, fontweight='bold')

    # Add metrics text box
    metrics_text_image = f"""Image-Level Metrics:
    
Accuracy:  {image_accuracy:.4f}
Precision: {image_precision:.4f}
Recall:    {image_recall:.4f}
F1 Score:  {image_f1:.4f}

Total Images: {len(test_dataset)}"""

    props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
    ax_image.text(1.02, 0.5, metrics_text_image, transform=ax_image.transAxes, fontsize=10,
                  verticalalignment='center', bbox=props, family='monospace')

    plt.tight_layout()
    cm_image_path = cm_output_dir / f'{dataset_name}_image_level_confusion_matrix.png'
    plt.savefig(cm_image_path, dpi=150, bbox_inches='tight', format='png')
    plt.close(fig_image)
    print(f"\n💾 Image-level confusion matrix saved to: {cm_image_path}")

    # Save Pixel-level confusion matrix
    cm_pixel = np.array([[tn, fp], [fn, tp]])
    fig_pixel, ax_pixel = plt.subplots(figsize=(8, 6))
    cm_display = ConfusionMatrixDisplay(confusion_matrix=cm_pixel, display_labels=['Authentic', 'Forgery'])
    cm_display.plot(ax=ax_pixel, cmap='Blues', values_format='d')
    ax_pixel.set_title(f'{dataset_name.upper()} - Pixel-Level Confusion Matrix\nIoU: {forgery_iou:.4f}, F1: {pixel_f1:.4f}',
                       fontsize=14, fontweight='bold', pad=20)
    ax_pixel.set_xlabel('Predicted', fontsize=12, fontweight='bold')
    ax_pixel.set_ylabel('Actual', fontsize=12, fontweight='bold')

    # Add metrics text box
    metrics_text_pixel = f"""Pixel-Level Metrics:
    
Forgery IoU:     {forgery_iou:.4f}
Pixel F1:        {pixel_f1:.4f}
Macro Image F1:  {macro_image_f1:.4f}
Mean IoU:        {mean_iou:.4f}

Pixel-level:
  Accuracy:      {pixel_accuracy:.4f}
  Precision:     {pixel_precision:.4f}
  Recall:        {pixel_recall:.4f}
  
Total Pixels:    {len(pred_flat):,}
Forgery Pixels:  {int(target_flat.sum()):,} ({100*target_flat.sum()/len(target_flat):.2f}%)"""

    ax_pixel.text(1.02, 0.5, metrics_text_pixel, transform=ax_pixel.transAxes, fontsize=10,
                  verticalalignment='center', bbox=props, family='monospace')

    plt.tight_layout()
    cm_pixel_path = cm_output_dir / f'{dataset_name}_pixel_level_confusion_matrix.png'
    plt.savefig(cm_pixel_path, dpi=150, bbox_inches='tight', format='png')
    plt.close(fig_pixel)
    print(f"💾 Pixel-level confusion matrix saved to: {cm_pixel_path}")

    # Save first 10 image strips
    print(f"\n💾 Saving first 10 image strips for {dataset_name}...")
    dataset_viz_dir = cm_output_dir  # same per-dataset folder, already created

    num_strips = min(10, len(test_dataset))
    saved_count = 0

    for img_idx in range(num_strips):
        try:
            # Results are indexed by global image position (loader is unshuffled)
            img_path = image_paths[img_idx]
            pred_tensor = predictions_all[img_idx]
            target_tensor = targets_all[img_idx]

            # Load original image
            img_bgr = cv2.imread(str(img_path))
            if img_bgr is None:
                continue
            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

            # Convert masks to numpy
            pred_mask = pred_tensor.squeeze().numpy()
            gt_mask = target_tensor.squeeze().numpy()

            # Resize masks to original image size if needed
            if pred_mask.shape != img_rgb.shape[:2]:
                size_wh = (img_rgb.shape[1], img_rgb.shape[0])
                pred_mask = cv2.resize(pred_mask, size_wh, interpolation=cv2.INTER_NEAREST)
                gt_mask = cv2.resize(gt_mask, size_wh, interpolation=cv2.INTER_NEAREST)

            # Create and save strip
            img_name = Path(img_path).stem
            output_path = dataset_viz_dir / f"{img_idx+1:02d}_{img_name}.jpg"
            create_image_strip(img_rgb, gt_mask, pred_mask, output_path, img_name)
            saved_count += 1

        except Exception as e:
            # Best effort: a single unreadable/odd image must not abort the run
            print(f"  ⚠️ Error saving strip {img_idx+1}: {e}")
            continue

    print(f"  ✓ Saved {saved_count} image strips to {dataset_viz_dir}")

# Print overall summary
# One table row per dataset that produced results (in TEST_DATASETS order),
# followed by unweighted cross-dataset averages when there is more than one.
print("\n" + "=" * 80)
print("OVERALL SUMMARY - ALL DATASETS")
print("=" * 80)

# (column header, all_results key) for the metric columns, in table order;
# each value is rendered as a 4-decimal string.
_SUMMARY_METRIC_COLUMNS = (
    ('Mean IoU', 'mean_iou'),
    ('Pixel F1', 'pixel_f1'),
    ('Image F1', 'image_f1'),
    ('Pixel Prec', 'pixel_precision'),
    ('Pixel Rec', 'pixel_recall'),
)

summary_data = []
for dataset_name in TEST_DATASETS:
    if dataset_name not in all_results:
        continue
    r = all_results[dataset_name]
    row = {'Dataset': dataset_name, 'Images': r['total_images']}
    for column, key in _SUMMARY_METRIC_COLUMNS:
        row[column] = f"{r[key]:.4f}"
    summary_data.append(row)

if summary_data:
    summary_df = pd.DataFrame(summary_data)
    print("\n" + summary_df.to_string(index=False))

    # Averages only make sense with at least two datasets
    if len(summary_data) > 1:
        reported = [all_results[d] for d in TEST_DATASETS if d in all_results]
        avg_mean_iou = np.mean([res['mean_iou'] for res in reported])
        avg_f1 = np.mean([res['pixel_f1'] for res in reported])
        avg_image_f1 = np.mean([res['image_f1'] for res in reported])

        print(f"\n📊 AVERAGE ACROSS ALL DATASETS:")
        print(f"  Average Mean IoU:     {avg_mean_iou:.4f}")
        print(f"  Average Pixel F1:     {avg_f1:.4f}")
        print(f"  Average Image F1:     {avg_image_f1:.4f}")

# Create combined confusion matrix visualization for all datasets.
# Columns are assigned by position among the datasets that actually have
# results. Enumerating TEST_DATASETS directly would push the columns after a
# skipped (empty) dataset past the end of the axes grid.
plotted_datasets = [d for d in TEST_DATASETS if d in all_results]
if plotted_datasets:
    print("\n💾 Creating combined confusion matrix visualization...")
    n_cols = len(plotted_datasets)
    # Create 2 rows: image-level (top) and pixel-level (bottom).
    # squeeze=False keeps `axes` 2-D even when there is a single column.
    fig, axes = plt.subplots(2, n_cols, figsize=(6 * n_cols, 12), squeeze=False)

    for idx, dataset_name in enumerate(plotted_datasets):
        r = all_results[dataset_name]

        # Image-level confusion matrix (top row)
        cm_image = r.get('cm_image', np.array([[r['tn_i'], r['fp_i']], [r['fn_i'], r['tp_i']]]))
        disp_image = ConfusionMatrixDisplay(confusion_matrix=cm_image,
                                            display_labels=['Authentic', 'Forgery'])
        disp_image.plot(ax=axes[0, idx], cmap='Blues', values_format='d')
        axes[0, idx].set_title(f'{dataset_name}\nImage-Level (F1: {r.get("image_f1", 0):.3f})',
                               fontsize=11, fontweight='bold')
        axes[0, idx].set_xlabel('Predicted', fontsize=10)
        axes[0, idx].set_ylabel('Actual', fontsize=10)

        # Pixel-level confusion matrix (bottom row); rows = actual, cols = predicted
        cm_pixel = np.array([[r['tn'], r['fp']], [r['fn'], r['tp']]])
        cm_display = ConfusionMatrixDisplay(confusion_matrix=cm_pixel, display_labels=['Authentic', 'Forgery'])
        cm_display.plot(ax=axes[1, idx], cmap='Blues', values_format='d')
        axes[1, idx].set_title(f'{dataset_name}\nPixel-Level (F1: {r["pixel_f1"]:.3f})',
                               fontsize=11, fontweight='bold')
        axes[1, idx].set_xlabel('Predicted', fontsize=10)
        axes[1, idx].set_ylabel('Actual', fontsize=10)

    plt.suptitle('Combined Confusion Matrices - All Test Datasets', fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()

    # Save combined confusion matrix
    combined_cm_path = VIZ_OUTPUT_DIR / 'combined_confusion_matrices.png'
    plt.savefig(combined_cm_path, dpi=150, bbox_inches='tight', format='png')
    plt.close(fig)
    print(f"  ✓ Combined confusion matrix saved to: {combined_cm_path}")

print(f"\n💾 Visualizations saved to: {VIZ_OUTPUT_DIR}")
print("=" * 80)


COMPREHENSIVE INFERENCE ON ALL 3 TEST DATASETS

✓ Loading checkpoint: outputs_unet/checkpoints/MaxVit-square-Unet-fraud-best.pth
Creating MaxViT model without pretrained weights (skip_pretrained=True)...
  Initializing model with random weights...
✓ Model loaded and set to eval mode

PROCESSING DATASET: casia_copymove
Found 115 image-mask pairs from 1 subfolders (split: test)
  Loaded datasets: casia_copymove
  (Filtered to: ['casia_copymove'])
✓ Found 115 test images
Running inference...


Inference casia_copymove: 100%|██████████| 15/15 [00:05<00:00,  2.99it/s]



COMPREHENSIVE REVIEW: CASIA_COPYMOVE

Dataset: casia_copymove/test
Total Images: 115

📊 OVERALL METRICS:
  Forgery IoU:        0.1420
  Pixel-wise F1:      0.2487
  Macro Image F1:     0.2331
  Mean IoU per Image: 0.1245

📈 PIXEL-LEVEL STATISTICS:
  Accuracy:           0.9285
  Precision:          0.3559
  Recall:             0.1912

🔢 CONFUSION MATRIX (Pixel-level):
                      Predicted 0     Predicted 1
         Actual 0      27,632,929         646,177
         Actual 1       1,510,427         357,027

  Total Pixels:       30,146,560
  Forgery Pixels:     1,867,454 (6.19%)

📊 IMAGE-LEVEL STATISTICS:
  Accuracy:           0.6000
  Precision:          0.5980
  Recall:             0.9242
  F1 Score:           0.7262

🔢 CONFUSION MATRIX (Image-level):
                      Predicted 0     Predicted 1
         Actual 0               8              41
         Actual 1               5              61

💾 Image-level confusion matrix saved to: outputs_unet/test_inference_strips/

Inference defacto_copymove: 100%|██████████| 33/33 [00:10<00:00,  3.23it/s]



COMPREHENSIVE REVIEW: DEFACTO_COPYMOVE

Dataset: defacto_copymove/test
Total Images: 264

📊 OVERALL METRICS:
  Forgery IoU:        0.6961
  Pixel-wise F1:      0.8208
  Macro Image F1:     0.7818
  Mean IoU per Image: 0.6907

📈 PIXEL-LEVEL STATISTICS:
  Accuracy:           0.9713
  Precision:          0.9011
  Recall:             0.7536

🔢 CONFUSION MATRIX (Pixel-level):
                      Predicted 0     Predicted 1
         Actual 0      62,673,175         498,892
         Actual 1       1,486,682       4,547,267

  Total Pixels:       69,206,016
  Forgery Pixels:     6,033,949 (8.72%)

📊 IMAGE-LEVEL STATISTICS:
  Accuracy:           0.9848
  Precision:          1.0000
  Recall:             0.9848
  F1 Score:           0.9924

🔢 CONFUSION MATRIX (Image-level):
                      Predicted 0     Predicted 1
         Actual 0               0               0
         Actual 1               4             260

💾 Image-level confusion matrix saved to: outputs_unet/test_inference_str

Inference science-fraud_copymove: 100%|██████████| 27/27 [00:08<00:00,  3.26it/s]



COMPREHENSIVE REVIEW: SCIENCE-FRAUD_COPYMOVE

Dataset: science-fraud_copymove/test
Total Images: 213

📊 OVERALL METRICS:
  Forgery IoU:        0.2557
  Pixel-wise F1:      0.4072
  Macro Image F1:     0.4504
  Mean IoU per Image: 0.1264

📈 PIXEL-LEVEL STATISTICS:
  Accuracy:           0.9780
  Precision:          0.6136
  Recall:             0.3047

🔢 CONFUSION MATRIX (Pixel-level):
                      Predicted 0     Predicted 1
         Actual 0      54,183,570         266,141
         Actual 1         964,290         422,671

  Total Pixels:       55,836,672
  Forgery Pixels:     1,386,961 (2.48%)

📊 IMAGE-LEVEL STATISTICS:
  Accuracy:           0.6667
  Precision:          0.6860
  Recall:             0.7155
  F1 Score:           0.7004

🔢 CONFUSION MATRIX (Image-level):
                      Predicted 0     Predicted 1
         Actual 0              59              38
         Actual 1              33              83

💾 Image-level confusion matrix saved to: outputs_unet/test_i