# Research Question 4: Grad-CAM Explainability Analysis
## U-Net Nuclei Segmentation Interpretability with Stain Normalization

**Research Question**: Do lightweight explainability techniques—like Grad-CAM—enhance interpretability of U-Net-based nuclei segmentation on PanNuke, and does stain normalization improve this further?

### Research Hypotheses:
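- **H₀ (Null)**: Stain normalization does not improve the alignment of Grad-CAM attention maps with ground-truth nuclei masks
- **H₁ (Alternative)**: Stain normalization significantly improves the alignment of Grad-CAM attention maps with ground-truth nuclei masks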

### Experimental Design:
1. **Model Variants**:
   - U-Net without Grad-CAM (baseline)
   - U-Net with Grad-CAM (original images)
   - U-Net with Grad-CAM (stain normalized images)

2. **Evaluation Framework**:
   - Grad-CAM spatial accuracy metrics
   - Biological relevance assessment
   - Statistical analysis (ANOVA)
   - Visual interpretability comparison

3. **Key Metrics**:
   - Grad-CAM localization accuracy
   - Attention map quality scores
   - Segmentation performance correlation
   - Clinical interpretability measures

### Expected Outcomes:
- Quantified effect (improvement or no change) of stain normalization on Grad-CAM explainability
- Statistical validation of Grad-CAM effectiveness
- Clinical insights for interpretable AI in histopathology

---
**🔬 Explainability Research | Grad-CAM Analysis | Clinical Interpretability**


In [None]:
# Install requirements
# Project dependencies (path is relative to the notebook's working directory),
# plus the Grad-CAM library (PyPI "grad-cam", imported as pytorch_grad_cam in the
# next cell) and Captum. These are IPython %pip magics, not plain Python.
%pip install -r ../requirements.txt
%pip install grad-cam
%pip install captum


In [None]:
# =============================================================================
# RQ4 IMPORTS AND SETUP
# =============================================================================

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from PIL import Image
from pathlib import Path
import os
import sys
import json
import time
import warnings
from collections import defaultdict
from typing import Dict, List, Tuple, Optional, Union
import gc
from tqdm import tqdm
from sklearn.metrics import jaccard_score, f1_score, precision_score, recall_score
from scipy.stats import wilcoxon, ttest_rel, ttest_ind, shapiro, levene, mannwhitneyu, f_oneway
from scipy import stats
import statsmodels.api as sm
from statsmodels.stats.power import ttest_power
from statsmodels.stats.anova import anova_lm
from statsmodels.formula.api import ols
import logging
import shutil

# Grad-CAM and explainability imports
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, XGradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import captum
from captum.attr import IntegratedGradients, Saliency, GuidedGradCam
from captum.attr import visualization as viz

# Set up paths. The default is the original author's checkout; on any other
# machine set the HISTO_PROJECT_ROOT environment variable instead of editing here.
project_root = Path(os.environ.get(
    'HISTO_PROJECT_ROOT',
    '/Users/shubhangmalviya/Documents/Projects/Walsh College/HistoPathologyResearch/',
)).expanduser()
# Guard against duplicate sys.path entries when the cell is re-run.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Import custom modules
from preprocessing.vahadane_gpu import GPUVahadaneNormalizer
from models.unet_rq3 import UNetRQ3, create_unet_rq3
from models.unet import UNet  # Import the original UNet class
from utils.metrics import calculate_segmentation_metrics


In [None]:
# Configure logging, plotting style, random seeds and the compute device.
import random

log_dir = project_root / 'artifacts' / 'rq4_gradcam' / 'logs'
log_dir.mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_dir / 'rq4_gradcam_pipeline.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Plot style: prefer the versioned seaborn style name available in newer
# Matplotlib releases, falling back to the legacy 'seaborn' name otherwise.
if 'seaborn-v0_8' in plt.style.available:
    plt.style.use('seaborn-v0_8')
else:
    plt.style.use('seaborn')
sns.set_palette('husl')
# NOTE(review): this silences *every* warning for the whole session (including
# PyTorch deprecation warnings); consider narrowing the filter.
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility. Python's own `random` module is seeded
# as well, since notebook code may shuffle with it.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)
    torch.cuda.manual_seed_all(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🔬 RQ4 Grad-CAM Explainability Pipeline initialized on {device}")
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name()}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


In [None]:
# Build the RQ4 artifacts tree under project_root: one sub-folder per output
# category. Nested entries such as 'models/original' also create their parents,
# and existing folders are left untouched.
artifacts_dir = project_root / 'artifacts' / 'rq4_gradcam'
subdirs = [
    'checkpoints', 'results', 'plots', 'logs',
    'gradcam_visualizations', 'attention_maps', 'comparisons',
    'statistical_analysis', 'explainability_metrics',
    'models/original', 'models/normalized',
    'evaluation/baseline', 'evaluation/gradcam_original', 'evaluation/gradcam_normalized',
]

for output_dir in map(artifacts_dir.joinpath, subdirs):
    output_dir.mkdir(parents=True, exist_ok=True)

logger.info("RQ4 Grad-CAM Explainability Pipeline - Initialized Successfully")
print(f"📁 Artifacts directory: {artifacts_dir}")


## 1. Enhanced Grad-CAM Implementation for U-Net

### 1.1 Custom Grad-CAM Wrapper for Segmentation Models


In [None]:
class SegmentationGradCAM:
    """
    Grad-CAM for U-Net style segmentation models.

    Forward hooks record the output of every module in ``target_layers``.
    Full backward hooks record the gradient of the target score with respect
    to that output. One CAM is computed per target layer from batch element 0,
    upsampled to the input resolution, and min-max scaled to [0, 1]. The
    per-layer CAMs are then averaged and rescaled.

    The hooks stay attached to ``model`` until :meth:`remove_hooks` is called.
    Call it before wrapping the same model again, so hooks do not accumulate.
    """

    def __init__(self, model, target_layers, device='cuda'):
        self.model = model
        self.target_layers = list(target_layers)
        self.device = device
        self.model.eval()

        # One storage slot per target layer (index i <-> target_layers[i]).
        # Each layer's hooks fire separately, and backward hooks fire in
        # reverse order. A single shared slot would therefore pair one layer's
        # activations with another layer's gradients.
        n_layers = len(self.target_layers)
        self._layer_activations = [None] * n_layers
        self._layer_gradients = [None] * n_layers
        # Tensors of the last target layer from the most recent call, kept
        # under these attribute names for code that inspects them directly.
        self.gradients = None
        self.activations = None
        self._hook_handles = []

        self._register_hooks()

    def _register_hooks(self):
        """Attach a forward hook and a full backward hook to every target layer."""
        for idx, layer in enumerate(self.target_layers):
            self._hook_handles.append(
                layer.register_forward_hook(self._make_forward_hook(idx))
            )
            # register_full_backward_hook replaces the deprecated
            # register_backward_hook. The deprecated hook's grad_output can be
            # incomplete for modules that run several autograd ops.
            self._hook_handles.append(
                layer.register_full_backward_hook(self._make_backward_hook(idx))
            )

    def _make_forward_hook(self, idx):
        """Return a forward hook that stores the layer output in slot ``idx``."""
        def forward_hook(module, inputs, output):
            self._layer_activations[idx] = output.detach()
        return forward_hook

    def _make_backward_hook(self, idx):
        """Return a backward hook that stores d(score)/d(layer output) in slot ``idx``."""
        def backward_hook(module, grad_input, grad_output):
            self._layer_gradients[idx] = grad_output[0].detach()
        return backward_hook

    def remove_hooks(self):
        """Detach every hook this wrapper registered on the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    @staticmethod
    def _min_max(cam):
        """Shift ``cam`` to start at 0 and scale its maximum to 1 (all-zero maps stay zero)."""
        cam = cam - cam.min()
        peak = cam.max()
        return cam / peak if peak > 0 else cam

    @staticmethod
    def _target_score(output, class_idx):
        """
        Reduce the model output to the scalar whose gradient drives Grad-CAM.

        Args:
            output: Model output, (B, C, H, W) for segmentation or (B, C) for
                classification.
            class_idx: An int selects one channel/class everywhere.
                ``None`` selects the argmax over dim 1, i.e. per pixel
                (segmentation) or per sample (classification). An index
                tensor of shape (B, H, W) or (B,) selects explicitly.

        Returns:
            Scalar tensor: the sum of the selected logits. With a single
            output channel this is the sum of all logits.
        """
        if isinstance(class_idx, (int, np.integer)):
            return output[:, int(class_idx)].sum()
        if class_idx is None:
            class_idx = output.argmax(dim=1)
        index = torch.as_tensor(class_idx, device=output.device).long()
        if output.dim() == 4:
            # One-hot mask over the channel axis picks the target logit per pixel.
            target = torch.zeros_like(output)
            target.scatter_(1, index.unsqueeze(1), 1)
            return (output * target).sum()
        return output.gather(1, index.view(-1, 1)).sum()

    def generate_cam(self, input_tensor, class_idx=None):
        """
        Generate a Grad-CAM heat map for the first image of the batch.

        Args:
            input_tensor: Input image tensor (B, C, H, W). The caller's tensor
                is not modified.
            class_idx: Target class: ``None`` (argmax), an int, or an index
                tensor (see ``_target_score``).

        Returns:
            cam: numpy array of shape (H, W) with values in [0, 1].

        Raises:
            RuntimeError: if no target layer received both an activation and a
                gradient during the pass.
        """
        # Work on a detached copy so input gradients are enabled without
        # flipping requires_grad on the caller's tensor.
        input_tensor = input_tensor.to(self.device).detach().requires_grad_(True)
        n_layers = len(self.target_layers)
        self._layer_activations = [None] * n_layers
        self._layer_gradients = [None] * n_layers

        output = self.model(input_tensor)
        self.model.zero_grad()
        self._target_score(output, class_idx).backward()

        height, width = input_tensor.shape[-2:]
        layer_cams = []
        for acts, grads in zip(self._layer_activations, self._layer_gradients):
            if acts is None or grads is None:
                continue
            # Assumes 4-D conv feature maps (B, C, h, w). Channel weights are
            # spatially averaged gradients; the CAM is the ReLU of the
            # weighted channel sum.
            weights = grads[0].mean(dim=(1, 2))
            cam = F.relu((weights[:, None, None] * acts[0]).sum(dim=0))
            cam = F.interpolate(
                cam[None, None], size=(int(height), int(width)),
                mode='bilinear', align_corners=False,
            )[0, 0]
            layer_cams.append(self._min_max(cam))

        if not layer_cams:
            raise RuntimeError(
                "Grad-CAM hooks captured no activations/gradients; check target_layers"
            )

        self.activations = self._layer_activations[-1]
        self.gradients = self._layer_gradients[-1]

        cam = self._min_max(torch.stack(layer_cams).mean(dim=0))
        return cam.cpu().numpy()

    def generate_attention_map(self, input_tensor, class_idx=None):
        """
        Generate an attention map at the input resolution for visualization.

        Args:
            input_tensor: Input image tensor (B, C, H, W)
            class_idx: Class index for attention map (see ``generate_cam``)

        Returns:
            attention_map: Normalized attention map (H, W)
        """
        cam = self.generate_cam(input_tensor, class_idx)

        # generate_cam already upsamples to the input size; resize defensively.
        target_hw = tuple(int(s) for s in input_tensor.shape[2:])
        if cam.shape != target_hw:
            cam = cv2.resize(cam, (target_hw[1], target_hw[0]))

        return cam

class ExplainabilityMetrics:
    """
    Metrics for scoring attention maps (e.g. Grad-CAMs) against masks or each other.

    All methods are static, take 2-D numpy arrays of identical shape, and
    return plain Python floats.
    """

    @staticmethod
    def localization_accuracy(cam, ground_truth_mask, threshold=0.5):
        """
        Calculate the IoU between the thresholded CAM and the ground truth.

        Args:
            cam: Class activation map (H, W)
            ground_truth_mask: Ground-truth mask (H, W); any non-zero value is foreground
            threshold: CAM values strictly above this count as attended

        Returns:
            accuracy: Intersection-over-union in [0, 1]; 0.0 when both the
            thresholded CAM and the mask are empty.
        """
        cam_binary = np.asarray(cam) > threshold
        gt_binary = np.asarray(ground_truth_mask).astype(bool)

        intersection = np.logical_and(cam_binary, gt_binary).sum()
        union = np.logical_or(cam_binary, gt_binary).sum()

        if union == 0:
            return 0.0

        return float(intersection / union)

    @staticmethod
    def attention_consistency(cam1, cam2):
        """
        Calculate the consistency between two attention maps.

        Args:
            cam1, cam2: Two attention maps to compare (same number of elements)

        Returns:
            consistency: Pearson correlation of the flattened maps; 0.0 when it
            is undefined (a constant map, or NaNs in the input).
        """
        flat1 = np.asarray(cam1, dtype=np.float64).ravel()
        flat2 = np.asarray(cam2, dtype=np.float64).ravel()

        # A zero-variance map makes the correlation undefined. Short-circuit
        # here instead of letting np.corrcoef divide by zero and warn.
        if flat1.std() == 0 or flat2.std() == 0:
            return 0.0

        correlation = np.corrcoef(flat1, flat2)[0, 1]
        return float(correlation) if not np.isnan(correlation) else 0.0

    @staticmethod
    def spatial_coherence(cam, window_size=5):
        """
        Calculate the spatial coherence of an attention map.

        Args:
            cam: Attention map (H, W)
            window_size: Side length of the local box window

        Returns:
            coherence: 1 / (1 + mean local variance); 1.0 for a locally constant map.
        """
        # Work in float64. Integer maps (e.g. uint8 heat-map images) would
        # overflow in cam**2. Rounding in E[x^2] - E[x]^2 can also yield tiny
        # negative variances, so those are clipped to zero.
        cam = np.asarray(cam, dtype=np.float64)
        kernel = np.ones((window_size, window_size), dtype=np.float64) / (window_size ** 2)
        local_mean = cv2.filter2D(cam, -1, kernel)
        local_variance = np.clip(cv2.filter2D(cam ** 2, -1, kernel) - local_mean ** 2, 0.0, None)

        return float(1.0 / (1.0 + local_variance.mean()))

    @staticmethod
    def biological_relevance(cam, nuclei_mask, background_mask):
        """
        Calculate the biological relevance of an attention map.

        Args:
            cam: Attention map (H, W)
            nuclei_mask: Mask of nuclei regions (H, W); pixels > 0 are nuclei
            background_mask: Mask of background regions (H, W); pixels > 0 are background

        Returns:
            relevance: Mean attention on nuclei divided by mean attention on
            background. When background attention is 0 (or the background
            mask is empty), the mean nuclei attention is returned unchanged.
        """
        nuclei_attention = cam[nuclei_mask > 0].mean() if nuclei_mask.sum() > 0 else 0
        background_attention = cam[background_mask > 0].mean() if background_mask.sum() > 0 else 0

        if background_attention > 0:
            relevance = nuclei_attention / background_attention
        else:
            relevance = nuclei_attention

        return float(relevance)

print("✅ Grad-CAM implementation and explainability metrics loaded successfully")


## 2. Dataset Loading and Model Setup

### 2.1 Load Dataset and Identify Top Tissues


In [None]:
# =============================================================================
# DATASET ANALYSIS - TOP 5 TISSUES FOR RQ4
# =============================================================================

dataset_path = project_root / 'dataset_tissues'
logger.info("Analyzing dataset for RQ4 Grad-CAM explainability analysis...")

# Count images per tissue and collect all valid (image, mask) pairs.
# Directory listings are sorted so the pair order does not depend on
# filesystem enumeration order. Neither do the tissue ranking ties or any
# downstream split or sample built from these lists.
# NOTE(review): pairs from the dataset's own train/test/val folders are pooled
# here and later re-split; confirm that ignoring the predefined splits is intended.
tissue_data = {}
tissue_counts = {}

for tissue_dir in sorted(p for p in dataset_path.iterdir() if p.is_dir()):
    tissue_name = tissue_dir.name
    all_pairs = []

    for split in ['train', 'test', 'val']:
        images_dir = tissue_dir / split / 'images'
        masks_dir = tissue_dir / split / 'sem_masks'  # Semantic masks
        if not (images_dir.exists() and masks_dir.exists()):
            continue

        # Mask files share the image file name with 'img_' replaced by 'sem_'.
        for img_file in sorted(images_dir.glob('*.png')):
            mask_file = masks_dir / img_file.name.replace('img_', 'sem_')
            if mask_file.exists():
                all_pairs.append((img_file, mask_file))

    tissue_data[tissue_name] = all_pairs
    tissue_counts[tissue_name] = len(all_pairs)

# Rank tissues by pair count once; the stable sort keeps name order for ties.
ranked_tissues = sorted(tissue_counts.items(), key=lambda x: x[1], reverse=True)
top_5_tissues = ranked_tissues[:5]
selected_tissues = [tissue for tissue, count in top_5_tissues]

print("🔍 RQ4 Dataset Analysis Results:")
print("=" * 60)
for tissue, count in ranked_tissues[:10]:
    marker = "✅" if tissue in selected_tissues else "  "
    print(f"{marker} {tissue:15}: {count:,} images")

print(f"\n🎯 Selected Top 5 Tissues for RQ4 Grad-CAM Analysis:")
for i, (tissue, count) in enumerate(top_5_tissues, 1):
    print(f"{i}. {tissue}: {count:,} images")

total_selected = sum(count for _, count in top_5_tissues)
print(f"\n📊 Total images in top 5 tissues: {total_selected:,}")

# Configuration for RQ4 analysis
TESTING_MODE = False  # Set to True to run on a small per-tissue sample
SAMPLE_SIZE_PER_TISSUE = 20 if TESTING_MODE else None  # Sample for explainability analysis

print(f"\n🧪 Running in {'TESTING' if TESTING_MODE else 'PRODUCTION'} mode")
if TESTING_MODE:
    print(f"   Sample size: {SAMPLE_SIZE_PER_TISSUE} images per tissue")


In [None]:
# =============================================================================
# DATALOADER CREATION AND MODEL SETUP
# =============================================================================

# Create train/validation/test dataloaders for evaluation
def create_dataloaders(tissue_pairs, batch_size=8, num_workers=2, seed=42):
    """
    Shuffle (image, mask) pairs, split them 80/10/10, and wrap each split in a DataLoader.

    Args:
        tissue_pairs: Sequence of (image_path, mask_path) tuples. Not modified.
        batch_size: Batch size for all three loaders.
        num_workers: Worker processes per loader.
        seed: Seed for the shuffle that precedes the split (default matches
            the notebook's RANDOM_SEED), so the split is reproducible.

    Returns:
        (train_dataloader, val_dataloader, test_dataloader). Only the train
        loader reshuffles every epoch. The test split receives the remainder
        left after the int-truncated 80% / 10% slices.

    NOTE: relies on the notebook globals ``RQ4Dataset`` and ``transform``,
    which are defined in a later cell; run that cell before calling this.
    """
    import random

    # Shuffle a copy with a private seeded RNG. The split is then reproducible
    # regardless of the global `random` state, and the caller's list is untouched.
    shuffled_pairs = list(tissue_pairs)
    random.Random(seed).shuffle(shuffled_pairs)

    total_samples = len(shuffled_pairs)
    train_size = int(0.8 * total_samples)
    val_size = int(0.1 * total_samples)

    train_pairs = shuffled_pairs[:train_size]
    val_pairs = shuffled_pairs[train_size:train_size + val_size]
    test_pairs = shuffled_pairs[train_size + val_size:]

    train_dataset = RQ4Dataset(train_pairs, transform=transform, normalize=False)
    val_dataset = RQ4Dataset(val_pairs, transform=transform, normalize=False)
    test_dataset = RQ4Dataset(test_pairs, transform=transform, normalize=False)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_dataloader, val_dataloader, test_dataloader

# Pool the (image, mask) pairs of the selected tissues and build the loaders.
# NOTE(review): create_dataloaders() instantiates RQ4Dataset and reads the
# `transform` global, both defined in a later cell. On a fresh top-to-bottom
# run, that cell must be executed first or this cell raises NameError.
print("🔄 Creating dataloaders for RQ4 analysis...")

# In testing mode each tissue contributes at most SAMPLE_SIZE_PER_TISSUE pairs;
# otherwise the slice bound is None and every pair is kept.
per_tissue_limit = SAMPLE_SIZE_PER_TISSUE if TESTING_MODE and SAMPLE_SIZE_PER_TISSUE else None
all_tissue_pairs = [
    pair
    for tissue in selected_tissues
    for pair in tissue_data[tissue][:per_tissue_limit]
]

train_dataloader, val_dataloader, test_dataloader = create_dataloaders(
    all_tissue_pairs, batch_size=8, num_workers=2
)

print(f"✅ Dataloaders created:")
for split_label, split_loader in (
    ("Train", train_dataloader),
    ("Validation", val_dataloader),
    ("Test", test_dataloader),
):
    print(f"   • {split_label}: {len(split_loader.dataset)} samples")

print("\n🔄 Creating model instances...")

# Baseline UNet (built with num_classes=6) used without Grad-CAM.
# NOTE(review): RQ4Dataset binarises the masks; confirm 6 output classes is intended.
baseline_model = UNet(in_channels=3, num_classes=6).to(device)
total_params = sum(p.numel() for p in baseline_model.parameters())
print(f"✅ Baseline UNet model created: {total_params:,} parameters")

# Two independently created single-class UNetRQ3 models (original / normalized arms).
original_model = create_unet_rq3(n_channels=3, n_classes=1, device=device, verbose=False)
normalized_model = create_unet_rq3(n_channels=3, n_classes=1, device=device, verbose=False)

orig_params, norm_params = (
    sum(p.numel() for p in rq3_model.parameters())
    for rq3_model in (original_model, normalized_model)
)
print(f"✅ RQ3 models created:")
print(f"   • Original: {orig_params:,} parameters")
print(f"   • Normalized: {norm_params:,} parameters")

print("\n✅ All models and dataloaders ready for RQ4 analysis!")


In [None]:
# =============================================================================
# CUSTOM DATASET CLASS FOR RQ4 GRAD-CAM ANALYSIS
# =============================================================================

class RQ4Dataset(Dataset):
    """
    Image/mask dataset for the RQ4 Grad-CAM explainability analysis.

    Each item is a dict containing:
      - the RGB image, optionally stain-normalized and then passed through
        ``transform``;
      - a binary nuclei mask as a float tensor, 1.0 where the semantic mask
        is non-zero; shape (1, H, W) unless ``mask_transform`` says otherwise;
      - the two source paths.

    ``transform`` is applied to the image only. Photometric steps such as
    ``transforms.Normalize`` with 3-channel statistics cannot be applied to a
    1-channel mask, and interpolating a label mask blurs its boundaries. So by
    default the mask is resized with nearest-neighbour sampling to the spatial
    size of the transformed image (when that image is a tensor). Pass
    ``mask_transform`` to control mask preprocessing yourself; its output is
    binarised with ``> 0``.
    """

    def __init__(self, tissue_pairs, transform=None, normalize=False, normalizer=None,
                 mask_transform=None):
        self.tissue_pairs = tissue_pairs
        self.transform = transform
        self.normalize = normalize
        self.normalizer = normalizer
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.tissue_pairs)

    def __getitem__(self, idx):
        img_path, mask_path = self.tissue_pairs[idx]

        # cv2.imread signals a missing/unreadable file by returning None.
        image = cv2.imread(str(img_path))
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # Best-effort stain normalization: on failure a warning is logged and
        # the un-normalized image is used.
        # NOTE(review): if normalization fails systematically, the "normalized"
        # arm silently becomes the original images; check the log.
        # NOTE(review): Image.fromarray below assumes normalize() returns an
        # HxWx3 uint8 array; confirm for GPUVahadaneNormalizer.
        if self.normalize and self.normalizer is not None:
            try:
                image = self.normalizer.normalize(image)
            except Exception as e:
                logger.warning(f"Normalization failed for {img_path}: {e}")

        image = Image.fromarray(image)
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'mask': self._prepare_mask(Image.fromarray(mask), image),
            'image_path': str(img_path),
            'mask_path': str(mask_path)
        }

    def _prepare_mask(self, mask, image):
        """Return ``mask`` as a float {0, 1} tensor spatially aligned with ``image``."""
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
        elif torch.is_tensor(image) and image.dim() >= 2:
            height, width = (int(s) for s in image.shape[-2:])
            if mask.size != (width, height):
                # Nearest-neighbour keeps label values intact (no interpolation).
                mask = mask.resize((width, height), resample=Image.NEAREST)

        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.array(mask))
            if mask.dim() == 2:
                mask = mask.unsqueeze(0)  # add channel axis -> (1, H, W)

        return (mask > 0).float()

# Image preprocessing pipeline: fixed 256x256 resize, conversion to a [0, 1]
# float tensor, then per-channel standardisation with the ImageNet mean/std.
_image_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_image_steps)

print("✅ RQ4 Dataset class and transforms defined successfully")


### 2.2 Model Loading and Grad-CAM Setup


In [None]:
# =============================================================================
# MODEL LOADING AND GRAD-CAM SETUP
# =============================================================================

def load_pretrained_model(model_path, device='cuda'):
    """
    Build a UNetRQ3 (3 input channels, 1 output class) and load checkpoint weights into it.

    Args:
        model_path: Path (str or Path) to a file saved with ``torch.save``:
            either a dict holding the weights under ``'model_state_dict'`` or
            a bare state dict.
        device: Device to map the checkpoint to and move the model onto.

    Returns:
        model: The model on ``device`` in eval mode. If ``model_path`` does not
        exist, a warning is logged and the randomly initialised model is returned.
    """
    model = create_unet_rq3(n_channels=3, n_classes=1)

    if os.path.exists(model_path):
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On PyTorch >= 2.6 the default weights_only=True may
        # reject checkpoints that store non-tensor objects; confirm.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint  # bare state dict
        model.load_state_dict(state_dict)
        logger.info(f"Loaded model from {model_path}")
    else:
        logger.warning(f"Model checkpoint not found at {model_path}. Using random weights.")

    model = model.to(device)
    model.eval()
    return model

def setup_gradcam_models():
    """
    Load the original- and normalized-image UNetRQ3 models and wrap each in Grad-CAM.

    Checkpoints are read from ``artifacts/rq3_enhanced/checkpoints`` under
    ``project_root``. A missing checkpoint falls back to a freshly created,
    randomly initialised model, with a warning.

    Returns:
        models: dict with ``'original'`` and ``'normalized'`` (the models), plus a
        ``'<name>_gradcam'`` SegmentationGradCAM for every model whose target
        layer was found.
    """
    models = {}

    checkpoint_dir = project_root / 'artifacts' / 'rq3_enhanced' / 'checkpoints'
    original_model_path = checkpoint_dir / 'best_model_original.pth'
    normalized_model_path = checkpoint_dir / 'best_model_normalized.pth'

    if original_model_path.exists():
        models['original'] = load_pretrained_model(original_model_path, device)
        logger.info("✅ Loaded pretrained original model")
    else:
        models['original'] = create_unet_rq3(n_channels=3, n_classes=1).to(device)
        logger.warning("⚠️ Using random weights for original model")

    if normalized_model_path.exists():
        models['normalized'] = load_pretrained_model(normalized_model_path, device)
        logger.info("✅ Loaded pretrained normalized model")
    else:
        models['normalized'] = create_unet_rq3(n_channels=3, n_classes=1).to(device)
        logger.warning("⚠️ Using random weights for normalized model")

    # Snapshot the items: '<name>_gradcam' entries are added inside the loop.
    for model_name, model in list(models.items()):
        # Grad-CAM target: the last Conv2d whose qualified name contains
        # 'decoder' (a single layer, as intended), not every decoder conv.
        decoder_convs = [
            module for name, module in model.named_modules()
            if 'decoder' in name and isinstance(module, nn.Conv2d)
        ]
        if decoder_convs:
            target_layers = decoder_convs[-1:]
        else:
            # Fallback: the first Conv2d in the model.
            # NOTE(review): that is an early, low-level layer; confirm it is an
            # acceptable Grad-CAM target for models with no 'decoder' modules.
            target_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)][:1]

        if target_layers:
            models[f'{model_name}_gradcam'] = SegmentationGradCAM(
                model, target_layers, device
            )
            logger.info(f"✅ Setup Grad-CAM for {model_name} model")
        else:
            logger.error(f"❌ Could not find target layers for {model_name} model")

    return models

# Build the model dictionary (models + Grad-CAM wrappers) and list its entries.
print("🔄 Setting up models for RQ4 Grad-CAM analysis...")
models = setup_gradcam_models()

print(f"\n📊 Available model configurations:")
for key in models:
    print(f"   - {key}")

# Stain normalizer for the normalized-image arm.
# NOTE(review): no reference image is fitted here; confirm whether
# GPUVahadaneNormalizer must be fitted before normalize() is called.
normalizer = GPUVahadaneNormalizer(device=device)
print(f"\n✅ Stain normalizer initialized on {device}")


## 3. Visualization Framework for Grad-CAM Analysis

### 3.1 Comprehensive Visualization Functions


In [None]:
# =============================================================================
# CORRECTED MODEL LOADING AND GRAD-CAM SETUP
# =============================================================================

def load_pretrained_model(model_path, device='cuda'):
    """
    Build a UNetRQ3 (3 input channels, 1 output class) and load checkpoint weights into it.

    Args:
        model_path: Path (str or Path) to a file saved with ``torch.save``:
            either a dict holding the weights under ``'model_state_dict'`` or
            a bare state dict.
        device: Device to map the checkpoint to and move the model onto.

    Returns:
        model: The model on ``device`` in eval mode. If ``model_path`` does not
        exist, a warning is logged and the randomly initialised model is returned.
    """
    model = create_unet_rq3(n_channels=3, n_classes=1)

    if os.path.exists(model_path):
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On PyTorch >= 2.6 the default weights_only=True may
        # reject checkpoints that store non-tensor objects; confirm.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint  # bare state dict
        model.load_state_dict(state_dict)
        logger.info(f"Loaded model from {model_path}")
    else:
        logger.warning(f"Model checkpoint not found at {model_path}. Using random weights.")

    model = model.to(device)
    model.eval()
    return model

def setup_gradcam_models():
    """
    Collect the original, normalized and baseline models; wrap the first two in Grad-CAM.

    The original/normalized UNetRQ3 weights are loaded from
    ``artifacts/rq3_enhanced/checkpoints`` under ``project_root`` when present.
    Otherwise the notebook-level ``original_model`` / ``normalized_model``
    instances are used as-is, with a warning. ``baseline_model`` is added
    without a Grad-CAM wrapper.

    NOTE(review): when the notebook-level instances are reused, every call
    registers another set of Grad-CAM hooks on the same model objects.

    Returns:
        models: dict with ``'original'``, ``'normalized'``, ``'baseline'``, plus a
        ``'<name>_gradcam'`` SegmentationGradCAM for each non-baseline model
        whose target layer was found.
    """
    models = {}

    checkpoint_dir = project_root / 'artifacts' / 'rq3_enhanced' / 'checkpoints'
    original_model_path = checkpoint_dir / 'best_model_original.pth'
    normalized_model_path = checkpoint_dir / 'best_model_normalized.pth'

    if original_model_path.exists():
        models['original'] = load_pretrained_model(original_model_path, device)
        logger.info("✅ Loaded pretrained original model")
    else:
        # Use the already created model
        models['original'] = original_model
        logger.warning("⚠️ Using random weights for original model")

    if normalized_model_path.exists():
        models['normalized'] = load_pretrained_model(normalized_model_path, device)
        logger.info("✅ Loaded pretrained normalized model")
    else:
        # Use the already created model
        models['normalized'] = normalized_model
        logger.warning("⚠️ Using random weights for normalized model")

    models['baseline'] = baseline_model

    # Snapshot the items: '<name>_gradcam' entries are added inside the loop.
    for model_name, model in list(models.items()):
        if model_name == 'baseline':
            continue  # the baseline arm is evaluated without Grad-CAM

        # Grad-CAM target: the last Conv2d whose qualified name contains
        # 'decoder' (a single layer, as intended), not every decoder conv.
        decoder_convs = [
            module for name, module in model.named_modules()
            if 'decoder' in name and isinstance(module, nn.Conv2d)
        ]
        if decoder_convs:
            target_layers = decoder_convs[-1:]
        else:
            # Fallback: the first Conv2d in the model.
            # NOTE(review): that is an early, low-level layer; confirm it is an
            # acceptable Grad-CAM target for models with no 'decoder' modules.
            target_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)][:1]

        if target_layers:
            models[f'{model_name}_gradcam'] = SegmentationGradCAM(
                model, target_layers, device
            )
            logger.info(f"✅ Setup Grad-CAM for {model_name} model")
        else:
            logger.error(f"❌ Could not find target layers for {model_name} model")

    return models

# Build the model dictionary (models + Grad-CAM wrappers) and list its entries.
print("🔄 Setting up models for RQ4 Grad-CAM analysis...")
models = setup_gradcam_models()

print(f"\n📊 Available model configurations:")
for key in models:
    print(f"   - {key}")

# Stain normalizer for the normalized-image arm.
# NOTE(review): no reference image is fitted here; confirm whether
# GPUVahadaneNormalizer must be fitted before normalize() is called.
normalizer = GPUVahadaneNormalizer(device=device)
print(f"\n✅ Stain normalizer initialized on {device}")


In [None]:
# =============================================================================
# CORRECTED MODEL LOADING AND GRAD-CAM SETUP
# =============================================================================

def load_pretrained_model(model_path, device='cuda'):
    """
    Build a UNetRQ3 (3 input channels, 1 output class) and load checkpoint weights into it.

    Args:
        model_path: Path (str or Path) to a file saved with ``torch.save``:
            either a dict holding the weights under ``'model_state_dict'`` or
            a bare state dict.
        device: Device to map the checkpoint to and move the model onto.

    Returns:
        model: The model on ``device`` in eval mode. If ``model_path`` does not
        exist, a warning is logged and the randomly initialised model is returned.
    """
    model = create_unet_rq3(n_channels=3, n_classes=1)

    if os.path.exists(model_path):
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On PyTorch >= 2.6 the default weights_only=True may
        # reject checkpoints that store non-tensor objects; confirm.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint  # bare state dict
        model.load_state_dict(state_dict)
        logger.info(f"Loaded model from {model_path}")
    else:
        logger.warning(f"Model checkpoint not found at {model_path}. Using random weights.")

    model = model.to(device)
    model.eval()
    return model

def setup_gradcam_models():
    """
    Load the original- and normalized-image UNetRQ3 models and wrap each in Grad-CAM.

    Checkpoints are read from ``artifacts/rq3_enhanced/checkpoints`` under
    ``project_root``. A missing checkpoint falls back to a freshly created,
    randomly initialised model, with a warning.

    Returns:
        models: dict with ``'original'`` and ``'normalized'`` (the models), plus a
        ``'<name>_gradcam'`` SegmentationGradCAM for every model whose target
        layer was found.
    """
    models = {}

    checkpoint_dir = project_root / 'artifacts' / 'rq3_enhanced' / 'checkpoints'
    original_model_path = checkpoint_dir / 'best_model_original.pth'
    normalized_model_path = checkpoint_dir / 'best_model_normalized.pth'

    if original_model_path.exists():
        models['original'] = load_pretrained_model(original_model_path, device)
        logger.info("✅ Loaded pretrained original model")
    else:
        models['original'] = create_unet_rq3(n_channels=3, n_classes=1).to(device)
        logger.warning("⚠️ Using random weights for original model")

    if normalized_model_path.exists():
        models['normalized'] = load_pretrained_model(normalized_model_path, device)
        logger.info("✅ Loaded pretrained normalized model")
    else:
        models['normalized'] = create_unet_rq3(n_channels=3, n_classes=1).to(device)
        logger.warning("⚠️ Using random weights for normalized model")

    # Iterate over a snapshot: '<name>_gradcam' entries are inserted inside the
    # loop, and inserting while iterating models.items() directly raises
    # "RuntimeError: dictionary changed size during iteration".
    for model_name, model in list(models.items()):
        # Grad-CAM target: the last Conv2d whose qualified name contains
        # 'decoder' (a single layer, as intended), not every decoder conv.
        decoder_convs = [
            module for name, module in model.named_modules()
            if 'decoder' in name and isinstance(module, nn.Conv2d)
        ]
        if decoder_convs:
            target_layers = decoder_convs[-1:]
        else:
            # Fallback: the first Conv2d in the model.
            # NOTE(review): that is an early, low-level layer; confirm it is an
            # acceptable Grad-CAM target for models with no 'decoder' modules.
            target_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)][:1]

        if target_layers:
            models[f'{model_name}_gradcam'] = SegmentationGradCAM(
                model, target_layers, device
            )
            logger.info(f"✅ Setup Grad-CAM for {model_name} model")
        else:
            logger.error(f"❌ Could not find target layers for {model_name} model")

    return models

# Build the model dictionary (models + Grad-CAM wrappers) and list its entries.
print("🔄 Setting up models for RQ4 Grad-CAM analysis...")
models = setup_gradcam_models()

print(f"\n📊 Available model configurations:")
for key in models:
    print(f"   - {key}")

# Stain normalizer for the normalized-image arm.
# NOTE(review): no reference image is fitted here; confirm whether
# GPUVahadaneNormalizer must be fitted before normalize() is called.
normalizer = GPUVahadaneNormalizer(device=device)
print(f"\n✅ Stain normalizer initialized on {device}")


In [None]:
# =============================================================================
# FIXED MODEL LOADING AND GRAD-CAM SETUP (CORRECTED VERSION)
# =============================================================================

def _find_gradcam_target_layers(model):
    """
    Pick the convolution layer whose activations Grad-CAM should explain.

    Preference order:
      1. the last ``nn.Conv2d`` whose qualified module name contains 'decoder';
      2. otherwise, the last ``nn.Conv2d`` anywhere in the model.

    ``named_modules()`` yields submodules in registration order, so "last" is
    the most recently registered conv. It is assumed to be the deepest decoder
    layer; confirm against the architecture if module names are unusual.

    Returns:
        A list holding at most one module (empty if the model has no Conv2d).
    """
    decoder_convs = []
    all_convs = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            all_convs.append(module)
            if 'decoder' in name:
                decoder_convs.append(module)
    # Slicing keeps the list type the Grad-CAM wrapper expects; empty stays empty.
    return (decoder_convs or all_convs)[-1:]


def setup_gradcam_models_fixed():
    """
    Setup Grad-CAM for different model variants - FIXED VERSION.

    Loads the RQ3 checkpoints for the 'original' and 'normalized' models when
    they exist, otherwise falls back to the in-memory ``original_model`` /
    ``normalized_model`` objects. It also registers ``baseline_model`` under
    'baseline', and wraps every non-baseline model in ``SegmentationGradCAM``
    under '<name>_gradcam'. Each wrapper targets a single layer: the last
    decoder conv, or the last conv in the model if no module name contains
    'decoder'.

    Returns:
        models: Dictionary containing different model configurations
    """
    models = {}

    # Try to load pretrained models from RQ3 artifacts
    checkpoint_dir = project_root / 'artifacts' / 'rq3_enhanced' / 'checkpoints'
    original_model_path = checkpoint_dir / 'best_model_original.pth'
    normalized_model_path = checkpoint_dir / 'best_model_normalized.pth'

    # Load or create models
    if original_model_path.exists():
        models['original'] = load_pretrained_model(original_model_path, device)
        logger.info("✅ Loaded pretrained original model")
    else:
        # Use the already created model
        models['original'] = original_model
        logger.warning("⚠️ Using random weights for original model")

    if normalized_model_path.exists():
        models['normalized'] = load_pretrained_model(normalized_model_path, device)
        logger.info("✅ Loaded pretrained normalized model")
    else:
        # Use the already created model
        models['normalized'] = normalized_model
        logger.warning("⚠️ Using random weights for normalized model")

    # Add baseline model (segmentation only, no Grad-CAM wrapper)
    models['baseline'] = baseline_model

    # Snapshot the names first: the loop adds '<name>_gradcam' entries to `models`.
    gradcam_candidates = [name for name in models if name != 'baseline']

    for model_name in gradcam_candidates:
        model = models[model_name]
        target_layers = _find_gradcam_target_layers(model)

        if target_layers:
            models[f'{model_name}_gradcam'] = SegmentationGradCAM(
                model, target_layers, device
            )
            logger.info(f"✅ Setup Grad-CAM for {model_name} model")
        else:
            logger.error(f"❌ Could not find target layers for {model_name} model")

    return models

# Rebuild `models` with the corrected setup; this replaces the previous cell's dictionary.
print("🔄 Setting up models for RQ4 Grad-CAM analysis (FIXED VERSION)...")
models = setup_gradcam_models_fixed()

print("\n📊 Available model configurations:")
for key in models:
    print(f"   - {key}")

# (Re)create the stain normalizer on the active device.
normalizer = GPUVahadaneNormalizer(device=device)
print(f"\n✅ Stain normalizer initialized on {device}")


In [None]:
# =============================================================================
# CORRECTED EVALUATION PIPELINE FOR RQ4
# =============================================================================

class RQ4Evaluator:
    """
    Comprehensive evaluation framework for RQ4 Grad-CAM explainability analysis.

    Each image is scored under three variants, and the per-image records are
    appended to ``self.results[<variant>]``:

    * ``'baseline'``: segmentation only, using ``models['baseline']``;
    * ``'gradcam_original'``: Grad-CAM from ``models['original_gradcam']`` on
      the input image;
    * ``'gradcam_normalized'``: Grad-CAM from ``models['normalized_gradcam']``
      on the stain-normalized image.
    """

    # Scalar metrics reported by generate_summary_statistics, in output order.
    SUMMARY_METRICS = (
        'localization_accuracy',
        'biological_relevance',
        'spatial_coherence',
        'attention_consistency',
        'iou',
        'dice',
    )

    def __init__(self, models, normalizer, device='cuda'):
        """
        Args:
            models: dict of segmentation models and Grad-CAM wrappers
            normalizer: stain normalizer; ``normalize`` is called on a uint8 HWC array
            device: device that input batches are moved to
        """
        self.models = models
        self.normalizer = normalizer
        self.device = device
        self.metrics = ExplainabilityMetrics()

        # Results storage: one list of per-image records per variant
        self.results = {
            'baseline': [],
            'gradcam_original': [],
            'gradcam_normalized': []
        }

    def _segment(self, model_key, image_batch, mask_np, nuclei_mask):
        """
        Run ``self.models[model_key]`` on *image_batch* and score the prediction.

        Args:
            model_key: key of a segmentation model in ``self.models``
            image_batch: input batch of one image, already on ``self.device``
            mask_np: ground-truth mask as a numpy array
            nuclei_mask: binary (0/1) version of *mask_np*

        Returns:
            (probabilities, iou, dice): the sigmoid output as a numpy array,
            ``metrics.localization_accuracy`` of it, and the F1/Dice score of
            the 0.5-thresholded prediction.
        """
        # NOTE(review): the model is not switched to eval() here; confirm the loader does.
        with torch.no_grad():
            logits = self.models[model_key](image_batch)
            # (B, C, H, W) output: keep only the first output channel
            if logits.dim() == 4:
                logits = logits[:, 0, :, :]
            probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
        prediction_binary = (probabilities > 0.5).astype(np.uint8)

        iou = self.metrics.localization_accuracy(probabilities, mask_np)
        # Score against the binarised mask. A label/instance mask with values
        # other than 0/1 would make f1_score(average='binary') raise.
        dice = f1_score(nuclei_mask.flatten(), prediction_binary.flatten(), average='binary')
        return probabilities, iou, dice

    def evaluate_single_image(self, image, mask, tissue_type, image_path):
        """
        Evaluate a single image across all model variants.

        Args:
            image: Input image tensor (C, H, W), assumed ImageNet-normalized
                (see denormalize_image)
            mask: Ground truth mask (H, W), tensor or array; values > 0 are nuclei
            tissue_type: Type of tissue
            image_path: Path to the image file

        Returns:
            results: dict with 'tissue_type', 'image_path', 'image_id' and one
            record per variant under 'baseline', 'gradcam_original' and
            'gradcam_normalized'. A Grad-CAM variant is None if its evaluation
            failed. Every record also carries 'tissue_type' and 'image_id'.
        """
        image_id = Path(image_path).stem
        results = {
            'tissue_type': tissue_type,
            'image_path': image_path,
            'image_id': image_id
        }
        # evaluate_dataset stores only the per-variant records, so the
        # identifiers must travel with each record for per-tissue analysis.
        record_meta = {'tissue_type': tissue_type, 'image_id': image_id}

        # Prepare input
        image_batch = image.unsqueeze(0).to(self.device)
        mask_np = mask.cpu().numpy() if isinstance(mask, torch.Tensor) else np.asarray(mask)

        # Binary masks: nuclei (> 0) and background (its inverse)
        background_mask = (mask_np == 0).astype(np.uint8)
        nuclei_mask = (mask_np > 0).astype(np.uint8)

        # 1. Baseline evaluation (without Grad-CAM)
        baseline_pred, baseline_iou, baseline_dice = self._segment(
            'baseline', image_batch, mask_np, nuclei_mask
        )
        results['baseline'] = {
            'iou': baseline_iou,
            'dice': baseline_dice,
            'prediction': baseline_pred,
            **record_meta
        }

        # 2. Grad-CAM with original model
        original_cam = None
        try:
            original_cam = self.models['original_gradcam'].generate_attention_map(image_batch)

            # Calculate explainability metrics
            orig_localization = self.metrics.localization_accuracy(original_cam, mask_np)
            orig_biological = self.metrics.biological_relevance(original_cam, nuclei_mask, background_mask)
            orig_coherence = self.metrics.spatial_coherence(original_cam)

            # Segmentation scores of the model being explained; fall back to
            # the baseline's scores only if that model is not registered.
            if 'original' in self.models:
                _, orig_iou, orig_dice = self._segment('original', image_batch, mask_np, nuclei_mask)
            else:
                orig_iou, orig_dice = baseline_iou, baseline_dice

            results['gradcam_original'] = {
                'localization_accuracy': orig_localization,
                'biological_relevance': orig_biological,
                'spatial_coherence': orig_coherence,
                'attention_map': original_cam,
                'iou': orig_iou,
                'dice': orig_dice,
                **record_meta
            }
        except Exception as e:
            logger.warning(f"Grad-CAM evaluation failed for original model: {e}")
            results['gradcam_original'] = None

        # 3. Grad-CAM with normalized model
        try:
            # Back to uint8 HWC for the stain normalizer
            image_denorm = self.denormalize_image(image.detach().cpu())
            image_np = image_denorm.permute(1, 2, 0).numpy()
            image_np = (image_np * 255).astype(np.uint8)

            # Stain-normalize, then re-apply the dataset transform
            normalized_image = self.normalizer.normalize(image_np)
            normalized_image = transform(Image.fromarray(normalized_image))
            normalized_image_batch = normalized_image.unsqueeze(0).to(self.device)

            # Generate Grad-CAM
            normalized_cam = self.models['normalized_gradcam'].generate_attention_map(normalized_image_batch)

            # Calculate explainability metrics
            norm_localization = self.metrics.localization_accuracy(normalized_cam, mask_np)
            norm_biological = self.metrics.biological_relevance(normalized_cam, nuclei_mask, background_mask)
            norm_coherence = self.metrics.spatial_coherence(normalized_cam)

            # Calculate attention consistency with original
            if results['gradcam_original'] is not None:
                consistency = self.metrics.attention_consistency(original_cam, normalized_cam)
            else:
                # Nothing to compare against: record "missing" (NaN, skipped by
                # pandas aggregates) instead of a fake 0.0 that would bias means.
                consistency = float('nan')

            if 'normalized' in self.models:
                _, norm_iou, norm_dice = self._segment(
                    'normalized', normalized_image_batch, mask_np, nuclei_mask
                )
            else:
                norm_iou, norm_dice = baseline_iou, baseline_dice

            results['gradcam_normalized'] = {
                'localization_accuracy': norm_localization,
                'biological_relevance': norm_biological,
                'spatial_coherence': norm_coherence,
                'attention_consistency': consistency,
                'attention_map': normalized_cam,
                'iou': norm_iou,
                'dice': norm_dice,
                **record_meta
            }
        except Exception as e:
            logger.warning(f"Grad-CAM evaluation failed for normalized model: {e}")
            results['gradcam_normalized'] = None

        return results

    def denormalize_image(self, tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Undo per-channel ``(x - mean) / std`` normalization on a copy and clamp to [0, 1]."""
        tensor = tensor.clone()
        for t, m, s in zip(tensor, mean, std):
            t.mul_(s).add_(m)
        return torch.clamp(tensor, 0, 1)

    def evaluate_dataset(self, dataset, max_samples=None):
        """
        Evaluate entire dataset across all model variants.

        Results accumulate in ``self.results`` across calls.

        Args:
            dataset: RQ4Dataset instance
            max_samples: Maximum number of samples to evaluate (None = all)

        Returns:
            results: Comprehensive evaluation results
        """
        if max_samples is None:
            num_samples = len(dataset)
        else:
            num_samples = max(0, min(len(dataset), max_samples))
        logger.info(
            f"Starting comprehensive evaluation of {num_samples} samples "
            f"(dataset size: {len(dataset)})..."
        )

        for i in tqdm(range(num_samples), desc="Evaluating samples"):
            try:
                sample = dataset[i]
                image = sample['image']
                mask = sample['mask']
                image_path = sample['image_path']

                # Tissue type is taken from the directory three levels above the file
                tissue_type = Path(image_path).parent.parent.parent.name

                # Evaluate single image
                results = self.evaluate_single_image(image, mask, tissue_type, image_path)

                # Store results (failed Grad-CAM variants are None and skipped)
                self.results['baseline'].append(results['baseline'])
                for variant in ('gradcam_original', 'gradcam_normalized'):
                    if results[variant] is not None:
                        self.results[variant].append(results[variant])

            except Exception as e:
                # Best-effort: log with traceback and continue with the next sample
                logger.exception(f"Evaluation failed for sample {i}: {e}")
                continue

        logger.info(f"Evaluation completed. Processed {len(self.results['baseline'])} samples.")
        return self.results

    @staticmethod
    def _describe_metric(records, metric):
        """Return mean/std/median of *metric* over *records* (all None if no record has it)."""
        if not any(metric in record for record in records):
            return {'mean': None, 'std': None, 'median': None}
        # Records lacking the key become NaN, which pandas aggregates skip.
        values = pd.Series([record.get(metric) for record in records], dtype=float)
        return {'mean': values.mean(), 'std': values.std(), 'median': values.median()}

    def generate_summary_statistics(self):
        """
        Generate summary statistics for all model variants.

        Returns:
            dict mapping variant name to {'count': n, <metric>: {'mean', 'std',
            'median'}} for every metric in SUMMARY_METRICS. Variants without
            records are omitted.
        """
        summary = {}

        for model_type, records in self.results.items():
            if not records:
                continue

            model_summary = {'count': len(records)}
            for metric in self.SUMMARY_METRICS:
                model_summary[metric] = self._describe_metric(records, metric)
            summary[model_type] = model_summary

        return summary

# Wire the models and stain normalizer into the RQ4 evaluator.
evaluator = RQ4Evaluator(models, normalizer, device=device)
print("✅ RQ4 evaluation framework initialized")


In [None]:
# =============================================================================
# VISUALIZATION FRAMEWORK FOR GRAD-CAM ANALYSIS
# =============================================================================

class GradCAMVisualizer:
    """
    Comprehensive visualization framework for Grad-CAM analysis.

    Images may be passed either as CHW tensors, which are denormalized with the
    defaults of ``denormalize_image``, or as HWC numpy arrays, which are used
    as-is.
    """

    def __init__(self, save_dir):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def denormalize_image(self, tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Undo per-channel ``(x - mean) / std`` normalization on a copy and clamp to [0, 1]."""
        tensor = tensor.clone()
        for t, m, s in zip(tensor, mean, std):
            t.mul_(s).add_(m)
        return torch.clamp(tensor, 0, 1)

    def _to_display_image(self, image):
        """Convert a CHW tensor to a denormalized HWC numpy array; pass anything else through."""
        if isinstance(image, torch.Tensor):
            return self.denormalize_image(image.detach().cpu()).permute(1, 2, 0).numpy()
        return image

    def _finalize(self, fig, save_path):
        """Lay out *fig*, optionally save it, show it, then release it."""
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        # Close so repeated calls (e.g. in a loop) don't accumulate open figures.
        plt.close(fig)

    def create_gradcam_overlay(self, image, cam, alpha=0.4):
        """
        Create Grad-CAM overlay on original image.

        Args:
            image: CHW tensor, or HWC numpy array (float in [0, 1] or integer 0-255)
            cam: 2-D attention map (extra singleton dims are squeezed);
                min-max scaled here
            alpha: weight of the heatmap in the blend

        Returns:
            (overlay, heatmap): the blend is float in [0, 1], or uint8 when
            *image* is an integer array. The heatmap is the uint8 RGB JET
            colouring of *cam*, resized to the image.
        """
        image = self._to_display_image(image)
        image_is_integer = np.issubdtype(np.asarray(image).dtype, np.integer)

        # Normalize CAM to 0-1
        cam = np.squeeze(np.asarray(cam, dtype=np.float32))
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

        # Create heatmap (OpenCV produces BGR)
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Resize heatmap to match image
        if heatmap.shape[:2] != image.shape[:2]:
            heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))

        # Blend in float [0, 1]. The denormalized image is float while the
        # heatmap is uint8, and cv2.addWeighted rejects inputs of different depth.
        image_f = np.asarray(image, dtype=np.float32)
        if image_is_integer:
            image_f = image_f / 255.0
        heatmap_f = heatmap.astype(np.float32) / 255.0
        overlay = np.clip((1.0 - alpha) * image_f + alpha * heatmap_f, 0.0, 1.0)
        if image_is_integer:
            overlay = np.round(overlay * 255.0).astype(np.uint8)
        return overlay, heatmap

    def plot_comparison_grid(self, images, masks, cams, titles, save_path=None):
        """
        Create comparison grid showing original, mask, and Grad-CAM.

        One row per image: image, ground-truth mask, Grad-CAM heatmap, overlay.

        Raises:
            ValueError: if *images* is empty.
        """
        n_images = len(images)
        if n_images == 0:
            raise ValueError("plot_comparison_grid requires at least one image")
        # squeeze=False keeps `axes` 2-D even for a single row
        fig, axes = plt.subplots(n_images, 4, figsize=(16, 4 * n_images), squeeze=False)

        for i, image in enumerate(images):
            # Original image
            img = self._to_display_image(image)
            axes[i, 0].imshow(img)
            axes[i, 0].set_title(f"{titles[i]}\nOriginal Image")
            axes[i, 0].axis('off')

            # Ground truth mask
            mask = masks[i]
            if isinstance(mask, torch.Tensor):
                mask = mask.detach().cpu().numpy()
            axes[i, 1].imshow(mask, cmap='gray')
            axes[i, 1].set_title("Ground Truth\nMask")
            axes[i, 1].axis('off')

            # Grad-CAM heatmap
            cam = cams[i]
            im = axes[i, 2].imshow(cam, cmap='jet')
            axes[i, 2].set_title("Grad-CAM\nHeatmap")
            axes[i, 2].axis('off')
            fig.colorbar(im, ax=axes[i, 2], fraction=0.046, pad=0.04)

            # Overlay
            overlay, _ = self.create_gradcam_overlay(img, cam)
            axes[i, 3].imshow(overlay)
            axes[i, 3].set_title("Grad-CAM\nOverlay")
            axes[i, 3].axis('off')

        self._finalize(fig, save_path)

    def plot_attention_comparison(self, original_cam, normalized_cam, image, save_path=None,
                                  normalized_image=None):
        """
        Compare attention maps between original and normalized models.

        Args:
            original_cam: Grad-CAM map of the original model
            normalized_cam: Grad-CAM map of the normalized model
            image: input image shown in the top row
            save_path: optional output path
            normalized_image: stain-normalized input shown in the bottom row;
                if omitted, the original image is shown there and the panel
                title says so.
        """
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        img = self._to_display_image(image)
        if normalized_image is not None:
            norm_img = self._to_display_image(normalized_image)
            norm_title = "Normalized Image"
        else:
            # Don't label the original image as the normalized one.
            norm_img = img
            norm_title = "Normalized Image\n(not provided - showing original)"

        rows = (
            (0, "Original Image", img, original_cam, "Original Model"),
            (1, norm_title, norm_img, normalized_cam, "Normalized Model"),
        )
        for row, image_title, row_img, cam, model_label in rows:
            # Input image
            axes[row, 0].imshow(row_img)
            axes[row, 0].set_title(image_title)
            axes[row, 0].axis('off')

            # Grad-CAM overlay on that row's input
            overlay, _ = self.create_gradcam_overlay(row_img, cam)
            axes[row, 1].imshow(overlay)
            axes[row, 1].set_title(f"{model_label}\nGrad-CAM")
            axes[row, 1].axis('off')

            # Raw heatmap
            im = axes[row, 2].imshow(cam, cmap='jet')
            axes[row, 2].set_title(f"{model_label}\nHeatmap")
            axes[row, 2].axis('off')
            fig.colorbar(im, ax=axes[row, 2], fraction=0.046, pad=0.04)

        self._finalize(fig, save_path)

    def plot_metrics_comparison(self, metrics_data, save_path=None):
        """
        Plot comparison of explainability metrics.

        Args:
            metrics_data: long-format DataFrame with a 'model_type' column and
                one column per metric plotted below.
            save_path: optional output path
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        panels = (
            ((0, 0), 'localization_accuracy', 'Localization Accuracy', 'IoU Score'),
            ((0, 1), 'biological_relevance', 'Biological Relevance', 'Relevance Score'),
            ((1, 0), 'spatial_coherence', 'Spatial Coherence', 'Coherence Score'),
            ((1, 1), 'attention_consistency', 'Attention Consistency', 'Correlation Score'),
        )
        for (r, c), column, title, ylabel in panels:
            sns.boxplot(data=metrics_data, x='model_type', y=column, ax=axes[r, c])
            axes[r, c].set_title(title)
            axes[r, c].set_ylabel(ylabel)

        self._finalize(fig, save_path)

# All Grad-CAM figures are written under this artifacts sub-directory.
visualizer = GradCAMVisualizer(save_dir=artifacts_dir / 'gradcam_visualizations')
print("✅ Grad-CAM visualization framework initialized")


## 4. Evaluation Pipeline for RQ4

### 4.1 Comprehensive Evaluation Framework


In [None]:
# =============================================================================
# COMPREHENSIVE EVALUATION PIPELINE FOR RQ4
# =============================================================================

class RQ4Evaluator:
    """
    Comprehensive evaluation framework for RQ4 Grad-CAM explainability analysis.

    Each image is scored under three variants, and the per-image records are
    appended to ``self.results[<variant>]``:

    * ``'baseline'``: segmentation only, using ``models['baseline']``;
    * ``'gradcam_original'``: Grad-CAM from ``models['original_gradcam']`` on
      the input image;
    * ``'gradcam_normalized'``: Grad-CAM from ``models['normalized_gradcam']``
      on the stain-normalized image.
    """

    # Scalar metrics reported by generate_summary_statistics, in output order.
    SUMMARY_METRICS = (
        'localization_accuracy',
        'biological_relevance',
        'spatial_coherence',
        'attention_consistency',
        'iou',
        'dice',
    )

    def __init__(self, models, normalizer, device='cuda'):
        """
        Args:
            models: dict of segmentation models and Grad-CAM wrappers
            normalizer: stain normalizer; ``normalize`` is called on a uint8 HWC array
            device: device that input batches are moved to
        """
        self.models = models
        self.normalizer = normalizer
        self.device = device
        self.metrics = ExplainabilityMetrics()

        # Results storage: one list of per-image records per variant
        self.results = {
            'baseline': [],
            'gradcam_original': [],
            'gradcam_normalized': []
        }

    def _segment(self, model_key, image_batch, mask_np, nuclei_mask):
        """
        Run ``self.models[model_key]`` on *image_batch* and score the prediction.

        Args:
            model_key: key of a segmentation model in ``self.models``
            image_batch: input batch of one image, already on ``self.device``
            mask_np: ground-truth mask as a numpy array
            nuclei_mask: binary (0/1) version of *mask_np*

        Returns:
            (probabilities, iou, dice): the sigmoid output as a numpy array,
            ``metrics.localization_accuracy`` of it, and the F1/Dice score of
            the 0.5-thresholded prediction.
        """
        # NOTE(review): the model is not switched to eval() here; confirm the loader does.
        with torch.no_grad():
            logits = self.models[model_key](image_batch)
            # (B, C, H, W) output: keep only the first output channel
            if logits.dim() == 4:
                logits = logits[:, 0, :, :]
            probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
        prediction_binary = (probabilities > 0.5).astype(np.uint8)

        iou = self.metrics.localization_accuracy(probabilities, mask_np)
        # Score against the binarised mask. A label/instance mask with values
        # other than 0/1 would make f1_score(average='binary') raise.
        dice = f1_score(nuclei_mask.flatten(), prediction_binary.flatten(), average='binary')
        return probabilities, iou, dice

    def evaluate_single_image(self, image, mask, tissue_type, image_path):
        """
        Evaluate a single image across all model variants.

        Args:
            image: Input image tensor (C, H, W), assumed ImageNet-normalized
                (see denormalize_image)
            mask: Ground truth mask (H, W), tensor or array; values > 0 are nuclei
            tissue_type: Type of tissue
            image_path: Path to the image file

        Returns:
            results: dict with 'tissue_type', 'image_path', 'image_id' and one
            record per variant under 'baseline', 'gradcam_original' and
            'gradcam_normalized'. A Grad-CAM variant is None if its evaluation
            failed. Every record also carries 'tissue_type' and 'image_id'.
        """
        image_id = Path(image_path).stem
        results = {
            'tissue_type': tissue_type,
            'image_path': image_path,
            'image_id': image_id
        }
        # evaluate_dataset stores only the per-variant records, so the
        # identifiers must travel with each record for per-tissue analysis.
        record_meta = {'tissue_type': tissue_type, 'image_id': image_id}

        # Prepare input
        image_batch = image.unsqueeze(0).to(self.device)
        mask_np = mask.cpu().numpy() if isinstance(mask, torch.Tensor) else np.asarray(mask)

        # Binary masks: nuclei (> 0) and background (its inverse)
        background_mask = (mask_np == 0).astype(np.uint8)
        nuclei_mask = (mask_np > 0).astype(np.uint8)

        # 1. Baseline evaluation (without Grad-CAM)
        baseline_pred, baseline_iou, baseline_dice = self._segment(
            'baseline', image_batch, mask_np, nuclei_mask
        )
        results['baseline'] = {
            'iou': baseline_iou,
            'dice': baseline_dice,
            'prediction': baseline_pred,
            **record_meta
        }

        # 2. Grad-CAM with original model
        original_cam = None
        try:
            original_cam = self.models['original_gradcam'].generate_attention_map(image_batch)

            # Calculate explainability metrics
            orig_localization = self.metrics.localization_accuracy(original_cam, mask_np)
            orig_biological = self.metrics.biological_relevance(original_cam, nuclei_mask, background_mask)
            orig_coherence = self.metrics.spatial_coherence(original_cam)

            # Segmentation scores of the model being explained; fall back to
            # the baseline's scores only if that model is not registered.
            if 'original' in self.models:
                _, orig_iou, orig_dice = self._segment('original', image_batch, mask_np, nuclei_mask)
            else:
                orig_iou, orig_dice = baseline_iou, baseline_dice

            results['gradcam_original'] = {
                'localization_accuracy': orig_localization,
                'biological_relevance': orig_biological,
                'spatial_coherence': orig_coherence,
                'attention_map': original_cam,
                'iou': orig_iou,
                'dice': orig_dice,
                **record_meta
            }
        except Exception as e:
            logger.warning(f"Grad-CAM evaluation failed for original model: {e}")
            results['gradcam_original'] = None

        # 3. Grad-CAM with normalized model
        try:
            # Back to uint8 HWC for the stain normalizer
            image_denorm = self.denormalize_image(image.detach().cpu())
            image_np = image_denorm.permute(1, 2, 0).numpy()
            image_np = (image_np * 255).astype(np.uint8)

            # Stain-normalize, then re-apply the dataset transform
            normalized_image = self.normalizer.normalize(image_np)
            normalized_image = transform(Image.fromarray(normalized_image))
            normalized_image_batch = normalized_image.unsqueeze(0).to(self.device)

            # Generate Grad-CAM
            normalized_cam = self.models['normalized_gradcam'].generate_attention_map(normalized_image_batch)

            # Calculate explainability metrics
            norm_localization = self.metrics.localization_accuracy(normalized_cam, mask_np)
            norm_biological = self.metrics.biological_relevance(normalized_cam, nuclei_mask, background_mask)
            norm_coherence = self.metrics.spatial_coherence(normalized_cam)

            # Calculate attention consistency with original
            if results['gradcam_original'] is not None:
                consistency = self.metrics.attention_consistency(original_cam, normalized_cam)
            else:
                # Nothing to compare against: record "missing" (NaN, skipped by
                # pandas aggregates) instead of a fake 0.0 that would bias means.
                consistency = float('nan')

            if 'normalized' in self.models:
                _, norm_iou, norm_dice = self._segment(
                    'normalized', normalized_image_batch, mask_np, nuclei_mask
                )
            else:
                norm_iou, norm_dice = baseline_iou, baseline_dice

            results['gradcam_normalized'] = {
                'localization_accuracy': norm_localization,
                'biological_relevance': norm_biological,
                'spatial_coherence': norm_coherence,
                'attention_consistency': consistency,
                'attention_map': normalized_cam,
                'iou': norm_iou,
                'dice': norm_dice,
                **record_meta
            }
        except Exception as e:
            logger.warning(f"Grad-CAM evaluation failed for normalized model: {e}")
            results['gradcam_normalized'] = None

        return results

    def denormalize_image(self, tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Undo per-channel ``(x - mean) / std`` normalization on a copy and clamp to [0, 1]."""
        tensor = tensor.clone()
        for t, m, s in zip(tensor, mean, std):
            t.mul_(s).add_(m)
        return torch.clamp(tensor, 0, 1)

    def evaluate_dataset(self, dataset, max_samples=None):
        """
        Evaluate entire dataset across all model variants.

        Results accumulate in ``self.results`` across calls.

        Args:
            dataset: RQ4Dataset instance
            max_samples: Maximum number of samples to evaluate (None = all)

        Returns:
            results: Comprehensive evaluation results
        """
        if max_samples is None:
            num_samples = len(dataset)
        else:
            num_samples = max(0, min(len(dataset), max_samples))
        logger.info(
            f"Starting comprehensive evaluation of {num_samples} samples "
            f"(dataset size: {len(dataset)})..."
        )

        for i in tqdm(range(num_samples), desc="Evaluating samples"):
            try:
                sample = dataset[i]
                image = sample['image']
                mask = sample['mask']
                image_path = sample['image_path']

                # Tissue type is taken from the directory three levels above the file
                tissue_type = Path(image_path).parent.parent.parent.name

                # Evaluate single image
                results = self.evaluate_single_image(image, mask, tissue_type, image_path)

                # Store results (failed Grad-CAM variants are None and skipped)
                self.results['baseline'].append(results['baseline'])
                for variant in ('gradcam_original', 'gradcam_normalized'):
                    if results[variant] is not None:
                        self.results[variant].append(results[variant])

            except Exception as e:
                # Best-effort: log with traceback and continue with the next sample
                logger.exception(f"Evaluation failed for sample {i}: {e}")
                continue

        logger.info(f"Evaluation completed. Processed {len(self.results['baseline'])} samples.")
        return self.results

    @staticmethod
    def _describe_metric(records, metric):
        """Return mean/std/median of *metric* over *records* (all None if no record has it)."""
        if not any(metric in record for record in records):
            return {'mean': None, 'std': None, 'median': None}
        # Records lacking the key become NaN, which pandas aggregates skip.
        values = pd.Series([record.get(metric) for record in records], dtype=float)
        return {'mean': values.mean(), 'std': values.std(), 'median': values.median()}

    def generate_summary_statistics(self):
        """
        Generate summary statistics for all model variants.

        Returns:
            dict mapping variant name to {'count': n, <metric>: {'mean', 'std',
            'median'}} for every metric in SUMMARY_METRICS. Variants without
            records are omitted.
        """
        summary = {}

        for model_type, records in self.results.items():
            if not records:
                continue

            model_summary = {'count': len(records)}
            for metric in self.SUMMARY_METRICS:
                model_summary[metric] = self._describe_metric(records, metric)
            summary[model_type] = model_summary

        return summary

# Rebuild the evaluator so it uses the class definition from this cell.
evaluator = RQ4Evaluator(models, normalizer, device=device)
print("✅ RQ4 evaluation framework initialized")


## 5. Statistical Analysis Framework

### 5.1 ANOVA and Statistical Tests for Explainability Metrics


In [None]:
# =============================================================================
# STATISTICAL ANALYSIS FRAMEWORK FOR RQ4
# =============================================================================

class RQ4StatisticalAnalyzer:
    """
    Statistical analysis framework for the RQ4 explainability evaluation.

    ``results`` maps a model key (e.g. ``'baseline'``, ``'gradcam_original'``,
    ``'gradcam_normalized'``) to a list of per-sample result dicts. The analyzer
    flattens these into a long-format DataFrame and runs:

    * a one-way ANOVA across model types per metric,
    * pairwise post-hoc t-tests (uncorrected and Bonferroni-corrected),
    * Shapiro-Wilk normality tests, and
    * descriptive statistics.
    """

    # Column layout of the DataFrame returned by prepare_data_for_analysis().
    # Fixing it means that empty results still yield the expected columns.
    ANALYSIS_COLUMNS = [
        'model_type',
        'localization_accuracy',
        'biological_relevance',
        'spatial_coherence',
        'attention_consistency',
        'tissue_type',
    ]

    # Metrics that enter the inferential tests. attention_consistency is kept in
    # the DataFrame but is not tested here.
    TESTED_METRICS = ['localization_accuracy', 'biological_relevance', 'spatial_coherence']

    def __init__(self, results):
        self.results = results
        self.alpha = 0.05  # Significance level

    def prepare_data_for_analysis(self):
        """
        Flatten ``self.results`` into one row per (model, sample).

        Returns:
            pd.DataFrame with columns ``ANALYSIS_COLUMNS``. The columns exist even
            when there are no results, so the analysis methods return empty
            outputs instead of raising ``KeyError``.
        """
        analysis_data = []

        for model_type, results in self.results.items():
            if not results:
                continue

            for result in results:
                if model_type == 'baseline':
                    # NOTE(review): the baseline has no attention map. Its
                    # segmentation IoU is stored as 'localization_accuracy' and
                    # therefore compared against Grad-CAM localization scores in
                    # the ANOVA. These are different quantities; interpret with care.
                    analysis_data.append({
                        'model_type': 'Baseline',
                        'localization_accuracy': result['iou'],
                        'biological_relevance': None,
                        'spatial_coherence': None,
                        'attention_consistency': None,
                        'tissue_type': result.get('tissue_type', 'Unknown')
                    })
                else:
                    analysis_data.append({
                        'model_type': 'GradCAM_Original' if 'original' in model_type else 'GradCAM_Normalized',
                        'localization_accuracy': result['localization_accuracy'],
                        'biological_relevance': result['biological_relevance'],
                        'spatial_coherence': result['spatial_coherence'],
                        'attention_consistency': result.get('attention_consistency', None),
                        'tissue_type': result.get('tissue_type', 'Unknown')
                    })

        return pd.DataFrame(analysis_data, columns=self.ANALYSIS_COLUMNS)

    @staticmethod
    def _sample_std(values):
        """Return the sample standard deviation (ddof=1), or NaN for fewer than two values."""
        if len(values) < 2:
            return float('nan')
        return float(np.std(values, ddof=1))

    def perform_anova_analysis(self, df):
        """
        Run a one-way ANOVA across model types for each tested metric.

        Rows missing a metric are dropped for that metric only.

        Returns:
            dict mapping metric -> result dict with ``f_statistic``, ``p_value``,
            ``significant``, ``groups``, ``group_sizes``, ``group_means`` and
            ``group_stds`` (sample std, ddof=1). The value is ``None`` if the test
            raised. Metrics with fewer than 3 rows or fewer than 2 groups are
            omitted.
        """
        anova_results = {}

        for metric in self.TESTED_METRICS:
            metric_data = df[df[metric].notna()]

            if len(metric_data) < 3:
                logger.warning(f"Insufficient data for {metric} ANOVA analysis")
                continue

            groups = []
            group_names = []

            for model_type in metric_data['model_type'].unique():
                group_data = metric_data[metric_data['model_type'] == model_type][metric].values
                if len(group_data) > 0:
                    groups.append(group_data)
                    group_names.append(model_type)

            if len(groups) < 2:
                logger.warning(f"Not enough groups for {metric} ANOVA")
                continue

            try:
                f_stat, p_value = f_oneway(*groups)
                f_stat, p_value = float(f_stat), float(p_value)
                # Plain bool so the flag serializes as JSON true/false, not "True".
                significant = bool(p_value < self.alpha)

                anova_results[metric] = {
                    'f_statistic': f_stat,
                    'p_value': p_value,
                    'significant': significant,
                    'groups': group_names,
                    'group_sizes': [len(group) for group in groups],
                    'group_means': [float(np.mean(group)) for group in groups],
                    # Sample std (ddof=1), consistent with the pandas .std() used in
                    # the descriptive statistics and with Cohen's d below.
                    'group_stds': [self._sample_std(group) for group in groups]
                }

                logger.info(f"ANOVA for {metric}: F={f_stat:.4f}, p={p_value:.4f}, significant={significant}")

            except Exception as e:
                logger.error(f"ANOVA failed for {metric}: {e}")
                anova_results[metric] = None

        return anova_results

    def perform_post_hoc_tests(self, df, anova_results):
        """
        Run pairwise t-tests for every metric whose ANOVA was significant.

        Each comparison records:

        * the uncorrected ``p_value`` and ``significant``,
        * the Bonferroni-corrected ``p_value_bonferroni`` and
          ``significant_bonferroni``, where p is multiplied by the number of
          comparisons run for that metric and capped at 1,
        * ``mean_diff`` and the Cohen's d ``effect_size``.
        """
        post_hoc_results = {}

        for metric, anova_result in anova_results.items():
            if anova_result is None or not anova_result['significant']:
                continue

            metric_data = df[df[metric].notna()]
            model_types = metric_data['model_type'].unique()
            pairwise_results = {}

            for i, model1 in enumerate(model_types):
                for model2 in model_types[i + 1:]:
                    group1 = metric_data[metric_data['model_type'] == model1][metric].values
                    group2 = metric_data[metric_data['model_type'] == model2][metric].values

                    if len(group1) < 2 or len(group2) < 2:
                        continue

                    try:
                        # NOTE(review): the paired test is chosen only because the
                        # group sizes match. That is valid only if row k of each
                        # group is the same image. After the notna() filter this
                        # alignment is not guaranteed; confirm against the evaluator.
                        if len(group1) == len(group2):
                            t_stat, p_value = ttest_rel(group1, group2)
                        else:
                            t_stat, p_value = ttest_ind(group1, group2)
                        p_value = float(p_value)

                        pairwise_results[f"{model1}_vs_{model2}"] = {
                            't_statistic': float(t_stat),
                            'p_value': p_value,
                            'significant': bool(p_value < self.alpha),
                            'mean_diff': float(np.mean(group1) - np.mean(group2)),
                            'effect_size': self.calculate_cohens_d(group1, group2)
                        }

                    except Exception as e:
                        logger.error(f"Post-hoc test failed for {model1} vs {model2}: {e}")

            # Bonferroni correction over the comparisons actually run for this
            # metric. np.minimum keeps NaN p-values as NaN rather than 1.0.
            n_tests = len(pairwise_results)
            for comparison in pairwise_results.values():
                corrected = float(np.minimum(1.0, comparison['p_value'] * n_tests))
                comparison['p_value_bonferroni'] = corrected
                comparison['significant_bonferroni'] = bool(corrected < self.alpha)

            post_hoc_results[metric] = pairwise_results

        return post_hoc_results

    def calculate_cohens_d(self, group1, group2):
        """
        Compute Cohen's d using the pooled sample standard deviation.

        Returns 0 when the pooled std is 0. Returns NaN when either group has
        fewer than two values, because the sample std is undefined.
        """
        n1, n2 = len(group1), len(group2)
        if n1 < 2 or n2 < 2:
            return float('nan')

        s1, s2 = np.std(group1, ddof=1), np.std(group2, ddof=1)
        pooled_std = np.sqrt(((n1 - 1) * s1**2 + (n2 - 1) * s2**2) / (n1 + n2 - 2))

        if pooled_std == 0:
            return 0

        return (np.mean(group1) - np.mean(group2)) / pooled_std

    def perform_normality_tests(self, df):
        """
        Run a Shapiro-Wilk test per (model type, tested metric).

        Groups with fewer than 3 non-missing values are skipped.
        """
        normality_results = {}

        for model_type in df['model_type'].unique():
            model_data = df[df['model_type'] == model_type]
            normality_results[model_type] = {}

            for metric in self.TESTED_METRICS:
                metric_data = model_data[model_data[metric].notna()][metric].values

                if len(metric_data) < 3:
                    continue

                try:
                    shapiro_stat, shapiro_p = shapiro(metric_data)

                    normality_results[model_type][metric] = {
                        'shapiro_statistic': float(shapiro_stat),
                        'shapiro_p_value': float(shapiro_p),
                        'is_normal': bool(shapiro_p > self.alpha),
                        'sample_size': len(metric_data)
                    }

                except Exception as e:
                    logger.error(f"Normality test failed for {model_type} - {metric}: {e}")

        return normality_results

    def generate_comprehensive_report(self):
        """Build the full report: data summary, ANOVA, post-hoc, normality and descriptives."""
        logger.info("Generating comprehensive statistical analysis report...")

        df = self.prepare_data_for_analysis()

        anova_results = self.perform_anova_analysis(df)
        post_hoc_results = self.perform_post_hoc_tests(df, anova_results)
        normality_results = self.perform_normality_tests(df)

        report = {
            'data_summary': {
                'total_samples': len(df),
                'model_types': df['model_type'].value_counts().to_dict(),
                'tissue_types': df['tissue_type'].value_counts().to_dict()
            },
            'anova_results': anova_results,
            'post_hoc_results': post_hoc_results,
            'normality_results': normality_results,
            'descriptive_statistics': self.generate_descriptive_statistics(df)
        }

        return report

    def generate_descriptive_statistics(self, df):
        """Compute overall and per-model-type descriptive statistics for each tested metric."""
        desc_stats = {}

        for metric in self.TESTED_METRICS:
            metric_data = df[df[metric].notna()]

            if len(metric_data) == 0:
                continue

            desc_stats[metric] = {
                'overall': {
                    'mean': metric_data[metric].mean(),
                    'std': metric_data[metric].std(),
                    'median': metric_data[metric].median(),
                    'min': metric_data[metric].min(),
                    'max': metric_data[metric].max(),
                    'q25': metric_data[metric].quantile(0.25),
                    'q75': metric_data[metric].quantile(0.75)
                }
            }

            for model_type in metric_data['model_type'].unique():
                model_metric_data = metric_data[metric_data['model_type'] == model_type][metric]

                desc_stats[metric][model_type] = {
                    'mean': model_metric_data.mean(),
                    'std': model_metric_data.std(),
                    'median': model_metric_data.median(),
                    'count': len(model_metric_data)
                }

        return desc_stats

    def save_results(self, report, save_path):
        """Write ``report`` as indented JSON to ``save_path``, creating parent folders."""
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # default=str stringifies values json cannot encode natively.
        with open(save_path, 'w') as f:
            json.dump(report, f, indent=2, default=str)

        logger.info(f"Statistical analysis results saved to {save_path}")

# Statistical analyzer bound to the evaluator's current results; the main
# pipeline later replaces `.results` with fresh evaluation output.
statistical_analyzer = RQ4StatisticalAnalyzer(results=evaluator.results)
print("✅ RQ4 statistical analysis framework initialized")


## 6. Main Execution Pipeline

### 6.1 Complete RQ4 Analysis Execution


In [None]:
# =============================================================================
# MAIN EXECUTION PIPELINE FOR RQ4
# =============================================================================

def _rq4_group_mean(anova_result, group_name):
    """Return the ANOVA group mean recorded for ``group_name``, or None if that group is absent."""
    groups = list(anova_result.get('groups') or [])
    means = list(anova_result.get('group_means') or [])
    if group_name in groups:
        idx = groups.index(group_name)
        if idx < len(means):
            return means[idx]
    return None


def run_complete_rq4_analysis():
    """
    Execute the complete RQ4 Grad-CAM explainability analysis pipeline.

    Relies on notebook globals from earlier cells: ``selected_tissues``,
    ``tissue_data``, ``TESTING_MODE``, ``SAMPLE_SIZE_PER_TISSUE``, ``transform``,
    ``normalizer``, ``models``, ``RQ4Dataset``, ``evaluator``, ``visualizer``,
    ``statistical_analyzer``, ``artifacts_dir`` and ``logger``.

    Returns:
        dict with ``evaluation_results``, ``statistical_report``,
        ``summary_statistics`` and ``execution_time`` (seconds).

    Raises:
        Any exception raised by a pipeline step. It is logged with its
        traceback and then re-raised unchanged.
    """
    logger.info("🚀 Starting RQ4 Complete Analysis Pipeline")
    start_time = time.time()

    def _sampled_pairs(tissue):
        # In testing mode only the first SAMPLE_SIZE_PER_TISSUE pairs per tissue are used.
        tissue_pairs = tissue_data[tissue]
        if TESTING_MODE and SAMPLE_SIZE_PER_TISSUE:
            tissue_pairs = tissue_pairs[:SAMPLE_SIZE_PER_TISSUE]
        return tissue_pairs

    try:
        # Step 1: Prepare datasets
        logger.info("📊 Step 1: Preparing datasets...")

        # Per-tissue original and stain-normalized datasets.
        # NOTE(review): `datasets` is not used by later steps (evaluation runs on
        # `eval_dataset`). It is kept so construction side effects are unchanged.
        datasets = {}
        for tissue in selected_tissues:
            tissue_pairs = _sampled_pairs(tissue)

            datasets[f'{tissue}_original'] = RQ4Dataset(
                tissue_pairs, transform=transform, normalize=False, normalizer=None
            )
            datasets[f'{tissue}_normalized'] = RQ4Dataset(
                tissue_pairs, transform=transform, normalize=True, normalizer=normalizer
            )

        logger.info(f"✅ Created {len(datasets)} datasets")

        # Step 2: Run evaluation on test data
        logger.info("🔬 Step 2: Running comprehensive evaluation...")

        all_test_pairs = []
        for tissue in selected_tissues:
            all_test_pairs.extend(_sampled_pairs(tissue))

        # Built without normalization; presumably the evaluator applies its own
        # normalizer for the normalized variant. TODO confirm.
        eval_dataset = RQ4Dataset(all_test_pairs, transform=transform, normalize=False)

        evaluation_results = evaluator.evaluate_dataset(
            eval_dataset,
            max_samples=100 if TESTING_MODE else None
        )

        logger.info(f"✅ Evaluation completed. Processed {len(evaluation_results['baseline'])} samples")

        # Step 3: Generate visualizations
        logger.info("📈 Step 3: Generating visualizations...")

        sample_indices = [0, 1, 2] if len(evaluation_results['baseline']) >= 3 else [0]

        for i, idx in enumerate(sample_indices):
            if idx >= len(evaluation_results['baseline']):
                break

            sample = eval_dataset[idx]
            image = sample['image']
            mask = sample['mask']

            if 'gradcam_original' in models and len(evaluation_results['gradcam_original']) > idx:
                try:
                    original_cam = evaluation_results['gradcam_original'][idx]['attention_map']

                    save_path = artifacts_dir / 'gradcam_visualizations' / f'sample_{i}_comparison.png'
                    visualizer.plot_comparison_grid(
                        [image], [mask], [original_cam],
                        [f'Sample {i+1}'], save_path
                    )

                except Exception as e:
                    logger.warning(f"Visualization failed for sample {i}: {e}")

        # Step 4: Statistical analysis
        logger.info("📊 Step 4: Performing statistical analysis...")

        statistical_analyzer.results = evaluation_results
        statistical_report = statistical_analyzer.generate_comprehensive_report()

        stats_save_path = artifacts_dir / 'statistical_analysis' / 'rq4_statistical_report.json'
        statistical_analyzer.save_results(statistical_report, stats_save_path)

        # Step 5: Generate summary visualizations
        logger.info("📊 Step 5: Generating summary visualizations...")

        df = statistical_analyzer.prepare_data_for_analysis()

        if len(df) > 0:
            metrics_save_path = artifacts_dir / 'plots' / 'rq4_metrics_comparison.png'
            visualizer.plot_metrics_comparison(df, metrics_save_path)

        # Step 6: Generate final report
        logger.info("📋 Step 6: Generating final report...")

        summary_stats = evaluator.generate_summary_statistics()

        summary_save_path = artifacts_dir / 'results' / 'rq4_summary_statistics.json'
        # The results folder may not exist yet; create it before writing.
        summary_save_path.parent.mkdir(parents=True, exist_ok=True)
        with open(summary_save_path, 'w') as f:
            json.dump(summary_stats, f, indent=2, default=str)

        # Print results summary
        print("\n" + "="*80)
        print("🎯 RQ4 GRAD-CAM EXPLAINABILITY ANALYSIS - RESULTS SUMMARY")
        print("="*80)

        print(f"\n📊 Dataset Summary:")
        print(f"   • Total samples evaluated: {len(evaluation_results['baseline'])}")
        print(f"   • Tissues analyzed: {', '.join(selected_tissues)}")
        print(f"   • Model variants: {len([k for k in evaluation_results.keys() if evaluation_results[k]])}")

        anova_summary = statistical_report['anova_results']

        print(f"\n🔬 Statistical Analysis Results:")
        for metric, anova_result in anova_summary.items():
            if anova_result:
                significance = "✅ SIGNIFICANT" if anova_result['significant'] else "❌ NOT SIGNIFICANT"
                print(f"   • {metric}: F={anova_result['f_statistic']:.4f}, p={anova_result['p_value']:.4f} {significance}")

        print(f"\n📈 Key Findings:")
        if 'localization_accuracy' in anova_summary:
            la_result = anova_summary['localization_accuracy']
            if la_result and la_result['significant']:
                # A significant ANOVA only says the variants differ, not which is
                # better. The baseline's value here is its segmentation IoU, so
                # report the top-scoring variant rather than claiming a Grad-CAM gain.
                ranked = sorted(
                    zip(la_result.get('groups', []), la_result.get('group_means', [])),
                    key=lambda pair: pair[1],
                    reverse=True,
                )
                best = ranked[0][0] if ranked else 'n/a'
                print(f"   • Localization accuracy differs significantly across variants (highest mean: {best})")
            else:
                print("   • No significant improvement in localization accuracy")

        if 'biological_relevance' in anova_summary:
            br_result = anova_summary['biological_relevance']
            if br_result and br_result['significant']:
                # Only claim an improvement when the normalized variant actually
                # scored higher than the original one.
                norm_mean = _rq4_group_mean(br_result, 'GradCAM_Normalized')
                orig_mean = _rq4_group_mean(br_result, 'GradCAM_Original')
                if norm_mean is not None and orig_mean is not None and norm_mean > orig_mean:
                    print("   • Stain normalization significantly improves biological relevance")
                else:
                    print("   • Biological relevance differs significantly, but not in favour of stain normalization")
            else:
                print("   • No significant improvement in biological relevance with normalization")

        print(f"\n💾 Results saved to: {artifacts_dir}")
        print(f"   • Visualizations: {artifacts_dir / 'gradcam_visualizations'}")
        print(f"   • Statistical analysis: {artifacts_dir / 'statistical_analysis'}")
        print(f"   • Summary results: {artifacts_dir / 'results'}")

        end_time = time.time()
        duration = end_time - start_time
        print(f"\n⏱️  Total execution time: {duration:.2f} seconds")
        print("="*80)

        logger.info(f"✅ RQ4 analysis completed successfully in {duration:.2f} seconds")

        return {
            'evaluation_results': evaluation_results,
            'statistical_report': statistical_report,
            'summary_statistics': summary_stats,
            'execution_time': duration
        }

    except Exception as e:
        # logger.exception records the traceback; a bare raise re-raises unchanged.
        logger.exception(f"❌ RQ4 analysis failed: {e}")
        raise

print("✅ RQ4 main execution pipeline ready")


### 6.2 Execute RQ4 Analysis

**Ready to run the complete RQ4 Grad-CAM explainability analysis!**

The pipeline will:
1. ✅ Load and prepare datasets from top 5 tissues
2. ✅ Evaluate U-Net models with and without Grad-CAM
3. ✅ Compare original vs stain-normalized images
4. ✅ Generate comprehensive visualizations
5. ✅ Perform statistical analysis (ANOVA)
6. ✅ Generate publication-ready results

**Click the cell below to start the analysis:**


In [None]:
# Execute the complete RQ4 analysis; its return dict is kept in `results`.
print("🚀 Starting RQ4 Grad-CAM Explainability Analysis...", "=" * 60, sep="\n")

# Run the complete analysis
results = run_complete_rq4_analysis()

print(
    "\n🎉 RQ4 Analysis Complete!",
    "Check the artifacts directory for detailed results and visualizations.",
    sep="\n",
)


## 7. Results Interpretation and Clinical Insights

### 7.1 Understanding the Results

The RQ4 analysis provides comprehensive insights into the effectiveness of Grad-CAM explainability techniques for U-Net-based nuclei segmentation:

**Key Metrics Evaluated:**
- **Localization Accuracy**: How well Grad-CAM attention maps align with ground truth nuclei regions
- **Biological Relevance**: The ratio of attention in nuclei vs background regions
- **Spatial Coherence**: The spatial consistency of attention maps
- **Attention Consistency**: Correlation between the attention maps of the original and normalized models (summarized descriptively; not included in the ANOVA)

**Statistical Analysis:**
- **ANOVA**: Tests for significant differences between model variants
- **Post-hoc Tests**: Pairwise comparisons between specific model types
- **Effect Sizes**: Quantify the magnitude of pairwise differences (Cohen's d)

### 7.2 Clinical Implications

**For Medical AI Validation:**
- Grad-CAM provides visual evidence of model decision-making
- Stain normalization may improve attention map quality
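- Statistical testing indicates whether observed differences are unlikely to be due to chance; it supports, but does not by itself guarantee, reliable clinical insights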

**For Model Interpretability:**
- Attention maps help identify model focus areas
- Biological relevance metrics check whether attention falls on nuclei rather than background, which is a prerequisite for clinically meaningful explanations
- Spatial coherence indicates attention map reliability

### 7.3 Future Research Directions

1. **Multi-scale Analysis**: Evaluate attention at different resolution levels
2. **Tissue-specific Studies**: Analyze explainability across different tissue types
3. **Clinical Validation**: Expert evaluation of attention map clinical relevance
4. **Advanced Techniques**: Compare with other explainability methods (LIME, SHAP)

---

**🔬 Research Question 4 Complete | Grad-CAM Explainability Analysis | Publication Ready**


In [None]:
# =============================================================================
# ENHANCED RESULT SAVING AND ORGANIZATION FOR RQ4
# =============================================================================

class RQ4ResultManager:
    """
    Comprehensive result management and saving system for RQ4 analysis.
    """
    
    def __init__(self, artifacts_dir):
        """
        Root all RQ4 outputs under ``<artifacts_dir>/rq4_gradcam`` and build its folder tree.

        Args:
            artifacts_dir: Base artifacts directory (str or path-like).
        """
        base = Path(artifacts_dir)
        self.artifacts_dir = base
        self.rq4_dir = base.joinpath('rq4_gradcam')

        # Create every output sub-folder up front so the save_* methods can write directly.
        self.create_directory_structure()
    
    def create_directory_structure(self):
        """Create every RQ4 output sub-directory beneath ``self.rq4_dir``. Existing folders are left as-is."""
        layout = {
            'results': ('raw_data', 'processed_data', 'statistical_analysis', 'summary_reports'),
            'visualizations': ('gradcam_maps', 'comparison_plots', 'metrics_plots'),
            'models': ('checkpoints', 'gradcam_weights'),
            'logs': ('execution_logs', 'error_logs'),
            'data': ('sample_images', 'attention_maps'),
            'reports': ('publication_ready', 'clinical_insights'),
        }

        for parent, children in layout.items():
            for child in children:
                self.rq4_dir.joinpath(parent, child).mkdir(parents=True, exist_ok=True)

        logger.info(f"✅ RQ4 directory structure created at {self.rq4_dir}")
    
    def save_evaluation_results(self, evaluation_results, tissue_types):
        """
        Persist raw and tabulated evaluation results.

        Writes:

        * ``results/raw_data/evaluation_results.json``,
        * ``results/processed_data/evaluation_metrics.csv``,
        * per-tissue CSV and JSON summaries.

        Returns:
            The processed DataFrame from ``process_evaluation_data``.
        """
        logger.info("💾 Saving evaluation results...")
        results_root = self.rq4_dir / 'results'

        # default=str stringifies anything json cannot encode natively.
        # NOTE(review): if result dicts carry numpy arrays (e.g. attention maps),
        # str() may abbreviate them, so this JSON may not be a lossless copy. Confirm.
        with open(results_root / 'raw_data' / 'evaluation_results.json', 'w') as handle:
            json.dump(evaluation_results, handle, indent=2, default=str)

        table = self.process_evaluation_data(evaluation_results, tissue_types)
        table.to_csv(results_root / 'processed_data' / 'evaluation_metrics.csv', index=False)

        self.save_tissue_specific_results(table, tissue_types)

        logger.info(f"✅ Evaluation results saved to {results_root}")
        return table
    
    def process_evaluation_data(self, evaluation_results, tissue_types):
        """Process evaluation results into structured DataFrame."""
        processed_data = []
        
        for model_type, results in evaluation_results.items():
            if not results:
                continue
                
            for i, result in enumerate(results):
                row = {
                    'model_type': model_type,
                    'sample_id': i,
                    'tissue_type': result.get('tissue_type', 'Unknown'),
                    'image_path': result.get('image_path', ''),
                    'iou': result.get('iou', None),
                    'dice': result.get('dice', None),
                    'localization_accuracy': result.get('localization_accuracy', None),
                    'biological_relevance': result.get('biological_relevance', None),
                    'spatial_coherence': result.get('spatial_coherence', None),
                    'attention_consistency': result.get('attention_consistency', None)
                }
                processed_data.append(row)
        
        return pd.DataFrame(processed_data)
    
    def save_tissue_specific_results(self, processed_data, tissue_types):
        """Save tissue-specific analysis results."""
        for tissue in tissue_types:
            tissue_data = processed_data[processed_data['tissue_type'] == tissue]
            
            if len(tissue_data) > 0:
                # Save tissue-specific CSV
                tissue_path = self.rq4_dir / 'results' / 'processed_data' / f'{tissue}_analysis.csv'
                tissue_data.to_csv(tissue_path, index=False)
                
                # Generate tissue-specific summary
                summary = self.generate_tissue_summary(tissue_data, tissue)
                summary_path = self.rq4_dir / 'results' / 'summary_reports' / f'{tissue}_summary.json'
                with open(summary_path, 'w') as f:
                    json.dump(summary, f, indent=2, default=str)
    
    def generate_tissue_summary(self, tissue_data, tissue_name):
        """Generate summary statistics for specific tissue."""
        summary = {
            'tissue_name': tissue_name,
            'total_samples': len(tissue_data),
            'model_types': tissue_data['model_type'].value_counts().to_dict(),
            'metrics_summary': {}
        }
        
        # Calculate metrics for each model type
        for model_type in tissue_data['model_type'].unique():
            model_data = tissue_data[tissue_data['model_type'] == model_type]
            
            summary['metrics_summary'][model_type] = {
                'sample_count': len(model_data),
                'localization_accuracy': {
                    'mean': model_data['localization_accuracy'].mean() if 'localization_accuracy' in model_data.columns else None,
                    'std': model_data['localization_accuracy'].std() if 'localization_accuracy' in model_data.columns else None,
                    'median': model_data['localization_accuracy'].median() if 'localization_accuracy' in model_data.columns else None
                },
                'biological_relevance': {
                    'mean': model_data['biological_relevance'].mean() if 'biological_relevance' in model_data.columns else None,
                    'std': model_data['biological_relevance'].std() if 'biological_relevance' in model_data.columns else None,
                    'median': model_data['biological_relevance'].median() if 'biological_relevance' in model_data.columns else None
                },
                'spatial_coherence': {
                    'mean': model_data['spatial_coherence'].mean() if 'spatial_coherence' in model_data.columns else None,
                    'std': model_data['spatial_coherence'].std() if 'spatial_coherence' in model_data.columns else None,
                    'median': model_data['spatial_coherence'].median() if 'spatial_coherence' in model_data.columns else None
                }
            }
        
        return summary
    
    def save_statistical_analysis(self, statistical_report):
        """
        Write the statistical report and its parts as JSON.

        Under ``results/statistical_analysis``:

        * the full report,
        * the ANOVA section,
        * the post-hoc section.

        Also writes the publication summary under ``reports/publication_ready``.
        """
        logger.info("💾 Saving statistical analysis results...")
        stats_dir = self.rq4_dir / 'results' / 'statistical_analysis'

        # (file name, report key); None means "the whole report".
        sections = (
            ('complete_statistical_report.json', None),
            ('anova_results.json', 'anova_results'),
            ('posthoc_results.json', 'post_hoc_results'),
        )
        for filename, key in sections:
            with open(stats_dir / filename, 'w') as handle:
                payload = statistical_report if key is None else statistical_report[key]
                json.dump(payload, handle, indent=2, default=str)

        publication_summary = self.create_publication_summary(statistical_report)
        with open(self.rq4_dir / 'reports' / 'publication_ready' / 'statistical_summary.json', 'w') as handle:
            json.dump(publication_summary, handle, indent=2, default=str)

        logger.info(f"✅ Statistical analysis saved to {stats_dir}")
    
    def create_publication_summary(self, statistical_report):
        """Create publication-ready statistical summary."""
        summary = {
            'research_question': 'RQ4: Grad-CAM Explainability Analysis',
            'analysis_date': time.strftime('%Y-%m-%d %H:%M:%S'),
            'sample_size': statistical_report['data_summary']['total_samples'],
            'tissues_analyzed': list(statistical_report['data_summary']['tissue_types'].keys()),
            'model_variants': list(statistical_report['data_summary']['model_types'].keys()),
            'key_findings': {},
            'statistical_significance': {}
        }
        
        # Extract key findings
        for metric, anova_result in statistical_report['anova_results'].items():
            if anova_result:
                summary['statistical_significance'][metric] = {
                    'f_statistic': anova_result['f_statistic'],
                    'p_value': anova_result['p_value'],
                    'significant': anova_result['significant'],
                    'effect_size': 'Large' if anova_result['f_statistic'] > 10 else 'Medium' if anova_result['f_statistic'] > 5 else 'Small'
                }
                
                if anova_result['significant']:
                    summary['key_findings'][metric] = f"Significant difference found (F={anova_result['f_statistic']:.3f}, p={anova_result['p_value']:.3f})"
                else:
                    summary['key_findings'][metric] = f"No significant difference (F={anova_result['f_statistic']:.3f}, p={anova_result['p_value']:.3f})"
        
        return summary
    
    def save_attention_maps(self, evaluation_results, sample_indices=(0, 1, 2)):
        """
        Save selected Grad-CAM attention maps as ``.npy`` files.

        Files go to ``<rq4_dir>/data/attention_maps`` and are named
        ``sample_{i}_{original|normalized}_gradcam.npy``, where ``i`` is the
        position within ``sample_indices`` rather than the sample index itself.

        Args:
            evaluation_results: dict of per-model result lists. Maps are read from
                ``['attention_map']`` of the ``'gradcam_original'`` and
                ``'gradcam_normalized'`` entries.
            sample_indices: indices into the result lists. Indices not below the
                number of baseline results are skipped. The default is a tuple to
                avoid a shared mutable default.
        """
        logger.info("💾 Saving attention maps...")

        attention_dir = self.rq4_dir / 'data' / 'attention_maps'
        attention_dir.mkdir(parents=True, exist_ok=True)

        n_baseline = len(evaluation_results.get('baseline') or [])
        variants = (('gradcam_original', 'original'), ('gradcam_normalized', 'normalized'))

        for i, idx in enumerate(sample_indices):
            # Skip, rather than stop at, out-of-range indices so unsorted index
            # lists still save every valid sample.
            if idx >= n_baseline:
                continue

            for result_key, label in variants:
                variant_results = evaluation_results.get(result_key) or []
                if len(variant_results) > idx:
                    cam = variant_results[idx]['attention_map']
                    np.save(attention_dir / f'sample_{i}_{label}_gradcam.npy', cam)

        logger.info(f"✅ Attention maps saved to {attention_dir}")
    
    def save_visualizations(self, visualizer, evaluation_results, sample_indices=(0, 1, 2)):
        """
        Prepare the visualization output directories under ``<rq4_dir>/visualizations``.

        NOTE(review): despite its name, this method does not render or write any
        figures. ``visualizer``, ``evaluation_results`` and ``sample_indices`` are
        currently unused and kept only for interface compatibility. It ensures
        the ``gradcam_maps``, ``comparison_plots`` and ``metrics_plots`` folders
        exist, and its log messages now say so instead of claiming figures were
        saved.
        """
        vis_root = self.rq4_dir / 'visualizations'
        logger.info("💾 Preparing visualization directories...")

        for sub_dir in ('gradcam_maps', 'comparison_plots', 'metrics_plots'):
            (vis_root / sub_dir).mkdir(parents=True, exist_ok=True)

        logger.info(f"✅ Visualization directories ready at {vis_root}")
    
    def create_final_report(self, evaluation_results, statistical_report, execution_time):
        """Assemble, persist and return the final RQ4 report.

        Writes ``<rq4_dir>/reports/RQ4_Final_Report.json`` (creating the
        ``reports`` directory if needed) and then a companion markdown summary
        via ``create_markdown_summary``.

        Args:
            evaluation_results: Per-variant result lists; ``'baseline'`` is used
                for the evaluated-sample count.
            statistical_report: Statistical analyzer output; its
                ``'anova_results'`` drive the key results and implications.
            execution_time: Pipeline wall-clock duration in seconds.

        Returns:
            The report dict that was written to JSON.

        Note:
            Also reads the notebook globals ``device``, ``selected_tissues``,
            ``tissue_data`` and ``TESTING_MODE``.
        """
        logger.info("📋 Creating final RQ4 report...")

        rq4_dir = self.rq4_dir
        final_report = {
            'research_question': 'RQ4: Grad-CAM Explainability Analysis',
            'execution_info': {
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'execution_time_seconds': execution_time,
                'device_used': str(device),
                'total_samples': len(evaluation_results['baseline'])
            },
            'dataset_info': {
                'tissues_analyzed': selected_tissues,
                # Full per-tissue count from tissue_data; this is not the number
                # of samples actually evaluated (see 'total_samples').
                'total_images': sum(len(tissue_data[tissue]) for tissue in selected_tissues),
                'testing_mode': TESTING_MODE
            },
            'model_variants': {
                'baseline': 'U-Net without Grad-CAM',
                'gradcam_original': 'U-Net with Grad-CAM (original images)',
                'gradcam_normalized': 'U-Net with Grad-CAM (stain normalized images)'
            },
            'key_results': self.extract_key_results(statistical_report),
            'clinical_implications': self.generate_clinical_implications(statistical_report),
            'file_locations': {
                'raw_data': str(rq4_dir / 'results' / 'raw_data'),
                'processed_data': str(rq4_dir / 'results' / 'processed_data'),
                'statistical_analysis': str(rq4_dir / 'results' / 'statistical_analysis'),
                'visualizations': str(rq4_dir / 'visualizations'),
                'attention_maps': str(rq4_dir / 'data' / 'attention_maps'),
                'publication_reports': str(rq4_dir / 'reports' / 'publication_ready')
            }
        }

        # Make sure the target directory exists (sibling save methods create
        # their own directories; this one previously assumed it).
        reports_dir = rq4_dir / 'reports'
        reports_dir.mkdir(parents=True, exist_ok=True)

        report_path = reports_dir / 'RQ4_Final_Report.json'
        with open(report_path, 'w', encoding='utf-8') as f:
            # default=str stringifies any value JSON cannot encode natively.
            json.dump(final_report, f, indent=2, default=str)

        # Create markdown summary
        self.create_markdown_summary(final_report)

        logger.info(f"✅ Final report saved to {report_path}")
        return final_report
    
    def extract_key_results(self, statistical_report):
        """Summarise every available ANOVA result for the final report.

        Metrics whose ANOVA entry is falsy (e.g. ``None`` or empty) are left out.
        Each kept metric maps to its significance flag, F statistic, p-value and
        an interpretation sentence from ``interpret_metric_results``.
        """
        return {
            metric_name: {
                'significant': stats['significant'],
                'f_statistic': stats['f_statistic'],
                'p_value': stats['p_value'],
                'interpretation': self.interpret_metric_results(metric_name, stats),
            }
            for metric_name, stats in statistical_report['anova_results'].items()
            if stats
        }
    
    def interpret_metric_results(self, metric, anova_result):
        """Turn one ANOVA result into a cautious, human-readable sentence.

        An omnibus ANOVA only indicates that *some* model variants differ on
        ``metric``; it does not say which variant is better or in which
        direction. So a significant result is reported with its statistic plus
        the follow-up (a post-hoc pairwise test) needed before drawing
        directional conclusions such as "stain normalization improves ...".

        Args:
            metric: Metric name, e.g. ``'localization_accuracy'``.
            anova_result: Mapping with at least ``'significant'``; when that is
                truthy, ``'f_statistic'`` and ``'p_value'`` are also read.

        Returns:
            A one-sentence interpretation string.
        """
        if not anova_result['significant']:
            return f"No significant differences found in {metric} across model variants"

        # Directional question each known metric is meant to answer; the ANOVA
        # alone cannot answer it, so it is named as the required follow-up.
        follow_ups = {
            'localization_accuracy': 'whether Grad-CAM improves localization accuracy',
            'biological_relevance': 'whether stain normalization improves biological relevance of attention maps',
            'spatial_coherence': 'which variant yields more spatially coherent attention maps',
        }
        follow_up = follow_ups.get(metric, 'which model variants differ')
        return (
            f"Significant differences found in {metric} across model variants "
            f"(F={anova_result['f_statistic']:.3f}, p={anova_result['p_value']:.3f}); "
            f"a post-hoc pairwise test is required to determine {follow_up}"
        )
    
    def generate_clinical_implications(self, statistical_report):
        """Derive cautious clinical-interpretation notes from ANOVA results.

        A significant omnibus ANOVA only shows that the model variants differ on
        a metric, not which variant is better. Every note produced here is
        therefore explicitly conditional on a post-hoc pairwise comparison,
        rather than asserting that Grad-CAM or stain normalization "improves"
        anything.

        Args:
            statistical_report: Dict whose ``'anova_results'`` maps metric names
                to result dicts (or a falsy value when no test was run).

        Returns:
            Dict with ``'model_interpretability'``, ``'clinical_validation'``
            and ``'recommendations'`` lists of strings (possibly empty).
        """
        implications = {
            'model_interpretability': [],
            'clinical_validation': [],
            'recommendations': []
        }
        anova_results = statistical_report['anova_results']

        # Localization accuracy: variants differ, direction still unknown.
        la_result = anova_results.get('localization_accuracy')
        if la_result and la_result['significant']:
            implications['model_interpretability'].append(
                "Grad-CAM localization accuracy differs significantly across model variants (ANOVA); "
                "post-hoc tests are needed to confirm which variant's attention maps align best "
                "with ground truth nuclei regions"
            )
            implications['clinical_validation'].append(
                "Attention maps may support clinical validation of model decisions once post-hoc "
                "tests confirm their alignment with ground truth nuclei regions"
            )

        # Biological relevance: significance alone does not show that stain
        # normalization is the better variant.
        br_result = anova_results.get('biological_relevance')
        if br_result and br_result['significant']:
            implications['model_interpretability'].append(
                "Biological relevance of attention maps differs significantly across model variants (ANOVA); "
                "post-hoc tests are needed to attribute the difference to stain normalization"
            )
            implications['recommendations'].append(
                "Consider stain normalization for clinical interpretability, pending post-hoc "
                "confirmation that it is the better-performing variant"
            )

        return implications
    
    def create_markdown_summary(self, final_report):
        """Create markdown summary for easy reading."""
        md_content = f"""# RQ4: Grad-CAM Explainability Analysis - Results Summary

## Research Question
**Do lightweight explainability techniques—like Grad-CAM—enhance interpretability of U-Net-based nuclei segmentation on PanNuke, and does stain normalization improve this further?**

## Execution Information
- **Date**: {final_report['execution_info']['timestamp']}
- **Execution Time**: {final_report['execution_info']['execution_time_seconds']:.2f} seconds
- **Device**: {final_report['execution_info']['device_used']}
- **Total Samples**: {final_report['execution_info']['total_samples']}

## Dataset Information
- **Tissues Analyzed**: {', '.join(final_report['dataset_info']['tissues_analyzed'])}
- **Total Images**: {final_report['dataset_info']['total_images']:,}
- **Testing Mode**: {final_report['dataset_info']['testing_mode']}

## Model Variants
1. **Baseline**: U-Net without Grad-CAM
2. **GradCAM Original**: U-Net with Grad-CAM (original images)
3. **GradCAM Normalized**: U-Net with Grad-CAM (stain normalized images)

## Key Results
"""
        
        for metric, result in final_report['key_results'].items():
            significance = "✅ SIGNIFICANT" if result['significant'] else "❌ NOT SIGNIFICANT"
            md_content += f"\n### {metric.replace('_', ' ').title()}\n"
            md_content += f"- **Significance**: {significance}\n"
            md_content += f"- **F-statistic**: {result['f_statistic']:.4f}\n"
            md_content += f"- **p-value**: {result['p_value']:.4f}\n"
            md_content += f"- **Interpretation**: {result['interpretation']}\n"
        
        md_content += f"""
## Clinical Implications

### Model Interpretability
"""
        for implication in final_report['clinical_implications']['model_interpretability']:
            md_content += f"- {implication}\n"
        
        md_content += "\n### Clinical Validation\n"
        for implication in final_report['clinical_implications']['clinical_validation']:
            md_content += f"- {implication}\n"
        
        md_content += "\n### Recommendations\n"
        for implication in final_report['clinical_implications']['recommendations']:
            md_content += f"- {implication}\n"
        
        md_content += f"""
## File Locations
- **Raw Data**: `{final_report['file_locations']['raw_data']}`
- **Processed Data**: `{final_report['file_locations']['processed_data']}`
- **Statistical Analysis**: `{final_report['file_locations']['statistical_analysis']}`
- **Visualizations**: `{final_report['file_locations']['visualizations']}`
- **Attention Maps**: `{final_report['file_locations']['attention_maps']}`
- **Publication Reports**: `{final_report['file_locations']['publication_reports']}`

---
*Generated by RQ4 Grad-CAM Explainability Analysis Pipeline*
"""
        
        # Save markdown report
        md_path = self.rq4_dir / 'reports' / 'RQ4_Summary_Report.md'
        with open(md_path, 'w') as f:
            f.write(md_content)
        
        logger.info(f"✅ Markdown summary saved to {md_path}")

# Shared RQ4 result manager used by the execution pipeline in the next cell.
# NOTE(review): constructed from `artifacts_dir`; presumably rq4_dir is derived
# from it inside RQ4ResultManager.__init__ — confirm there.
result_manager = RQ4ResultManager(
    artifacts_dir
)
print("✅ RQ4 result management system initialized")


In [None]:
# =============================================================================
# ENHANCED MAIN EXECUTION PIPELINE WITH COMPREHENSIVE RESULT SAVING
# =============================================================================

def _rq4_sampled_pairs(tissue):
    """Return the entries of ``tissue_data[tissue]`` used for this run.

    When ``TESTING_MODE`` is on and ``SAMPLE_SIZE_PER_TISSUE`` is truthy only
    the first ``SAMPLE_SIZE_PER_TISSUE`` entries are kept; otherwise all
    entries are returned unchanged.
    """
    tissue_pairs = tissue_data[tissue]
    if TESTING_MODE and SAMPLE_SIZE_PER_TISSUE:
        tissue_pairs = tissue_pairs[:SAMPLE_SIZE_PER_TISSUE]
    return tissue_pairs


def _print_rq4_results_summary(evaluation_results, statistical_report, final_report, execution_time):
    """Print the console summary of a finished RQ4 run (stats, findings, file layout)."""
    rq4_dir = result_manager.rq4_dir

    print("\n" + "="*100)
    print("🎯 RQ4 GRAD-CAM EXPLAINABILITY ANALYSIS - COMPREHENSIVE RESULTS")
    print("="*100)

    print("\n📊 Dataset Summary:")
    print(f"   • Total samples evaluated: {len(evaluation_results['baseline'])}")
    print(f"   • Tissues analyzed: {', '.join(selected_tissues)}")
    # Counts only variants that actually produced results.
    print(f"   • Model variants: {sum(1 for variant_results in evaluation_results.values() if variant_results)}")

    print("\n🔬 Statistical Analysis Results:")
    for metric, anova_result in statistical_report['anova_results'].items():
        if anova_result:
            significance = "✅ SIGNIFICANT" if anova_result['significant'] else "❌ NOT SIGNIFICANT"
            print(f"   • {metric}: F={anova_result['f_statistic']:.4f}, p={anova_result['p_value']:.4f} {significance}")

    print("\n📈 Key Findings:")
    for metric, result in final_report['key_results'].items():
        print(f"   • {metric.replace('_', ' ').title()}: {result['interpretation']}")

    print("\n💾 Results Organization:")
    print(f"   📁 Main Directory: {rq4_dir}")
    print(f"   📊 Raw Data: {rq4_dir / 'results' / 'raw_data'}")
    print(f"   📈 Processed Data: {rq4_dir / 'results' / 'processed_data'}")
    print(f"   📊 Statistical Analysis: {rq4_dir / 'results' / 'statistical_analysis'}")
    print(f"   🖼️  Visualizations: {rq4_dir / 'visualizations'}")
    print(f"   🧠 Attention Maps: {rq4_dir / 'data' / 'attention_maps'}")
    print(f"   📋 Reports: {rq4_dir / 'reports'}")

    print("\n📋 Generated Files:")
    print("   • RQ4_Final_Report.json - Complete analysis report")
    print("   • RQ4_Summary_Report.md - Human-readable summary")
    print("   • evaluation_metrics.csv - Processed metrics data")
    print("   • anova_results.json - Statistical test results")
    print("   • *_analysis.csv - Tissue-specific analysis files")
    print("   • sample_*_comparison.png - Grad-CAM visualizations")

    print(f"\n⏱️  Total execution time: {execution_time:.2f} seconds")
    print("="*100)


def run_complete_rq4_analysis_enhanced():
    """
    Execute the complete RQ4 Grad-CAM explainability analysis pipeline with enhanced result saving.

    Steps: build per-tissue datasets, evaluate on the (optionally subsampled)
    pairs of all selected tissues, save evaluation results, render Grad-CAM
    figures for up to three samples, run the statistical analysis, save
    attention maps, plot metric comparisons, write the final reports and print
    a console summary.

    Relies on globals defined elsewhere in the notebook (``selected_tissues``,
    ``tissue_data``, ``TESTING_MODE``, ``SAMPLE_SIZE_PER_TISSUE``, ``transform``,
    ``normalizer``, ``RQ4Dataset``, ``evaluator``, ``visualizer``,
    ``statistical_analyzer``, ``result_manager``, ``logger``, ``time``).

    Returns:
        Dict with ``evaluation_results``, ``statistical_report``,
        ``final_report``, ``processed_data``, ``execution_time`` (seconds) and
        ``results_directory`` (str).

    Raises:
        Re-raises any exception from a pipeline step after logging it with its
        traceback.
    """
    logger.info("🚀 Starting RQ4 Enhanced Analysis Pipeline")
    start_time = time.time()

    try:
        # Step 1: Prepare datasets
        logger.info("📊 Step 1: Preparing datasets...")

        # NOTE(review): these per-tissue datasets are not consumed by the
        # evaluation below (which builds its own combined dataset); kept so the
        # construction/logging behaviour is unchanged.
        datasets = {}
        for tissue in selected_tissues:
            tissue_pairs = _rq4_sampled_pairs(tissue)
            datasets[f'{tissue}_original'] = RQ4Dataset(
                tissue_pairs, transform=transform, normalize=False, normalizer=None
            )
            datasets[f'{tissue}_normalized'] = RQ4Dataset(
                tissue_pairs, transform=transform, normalize=True, normalizer=normalizer
            )

        logger.info(f"✅ Created {len(datasets)} datasets")

        # Step 2: Run evaluation on test data
        logger.info("🔬 Step 2: Running comprehensive evaluation...")

        all_test_pairs = []
        for tissue in selected_tissues:
            all_test_pairs.extend(_rq4_sampled_pairs(tissue))

        eval_dataset = RQ4Dataset(all_test_pairs, transform=transform, normalize=False)

        evaluation_results = evaluator.evaluate_dataset(
            eval_dataset,
            max_samples=100 if TESTING_MODE else None
        )

        n_samples = len(evaluation_results['baseline'])
        logger.info(f"✅ Evaluation completed. Processed {n_samples} samples")

        # Step 3: Save evaluation results
        logger.info("💾 Step 3: Saving evaluation results...")
        processed_data = result_manager.save_evaluation_results(evaluation_results, selected_tissues)

        # Step 4: Generate visualizations
        logger.info("📈 Step 4: Generating visualizations...")

        # The plots below save into these sub-directories; create them here
        # because nothing else in this pipeline does (a missing directory would
        # make savefig fail, and the Step 7 metrics plot is not guarded).
        viz_root = result_manager.rq4_dir / 'visualizations'
        for sub_dir in ('gradcam_maps', 'comparison_plots', 'metrics_plots'):
            (viz_root / sub_dir).mkdir(parents=True, exist_ok=True)

        # Up to three leading samples (all of them when fewer were evaluated).
        sample_indices = list(range(min(3, n_samples)))

        for i, idx in enumerate(sample_indices):
            # Check the evaluation results themselves (not the `models` dict)
            # since those are what get indexed below.
            if 'gradcam_original' not in evaluation_results or len(evaluation_results['gradcam_original']) <= idx:
                continue
            try:
                # Load inside the try so one unreadable sample only skips its
                # figure instead of aborting the whole pipeline.
                sample = eval_dataset[idx]
                image = sample['image']
                mask = sample['mask']

                original_cam = evaluation_results['gradcam_original'][idx]['attention_map']

                save_path = viz_root / 'gradcam_maps' / f'sample_{i}_comparison.png'
                visualizer.plot_comparison_grid(
                    [image], [mask], [original_cam],
                    [f'Sample {i+1}'], save_path
                )

                # Side-by-side attention comparison when normalized maps exist
                if 'gradcam_normalized' in evaluation_results and len(evaluation_results['gradcam_normalized']) > idx:
                    normalized_cam = evaluation_results['gradcam_normalized'][idx]['attention_map']
                    comparison_path = viz_root / 'comparison_plots' / f'sample_{i}_attention_comparison.png'
                    visualizer.plot_attention_comparison(
                        original_cam, normalized_cam, image, comparison_path
                    )

            except Exception as e:
                logger.warning(f"Visualization failed for sample {i}: {e}")

        # Step 5: Statistical analysis
        logger.info("📊 Step 5: Performing statistical analysis...")

        statistical_analyzer.results = evaluation_results
        statistical_report = statistical_analyzer.generate_comprehensive_report()
        result_manager.save_statistical_analysis(statistical_report)

        # Step 6: Save attention maps
        logger.info("💾 Step 6: Saving attention maps...")
        result_manager.save_attention_maps(evaluation_results, sample_indices)

        # Step 7: Generate summary visualizations
        logger.info("📊 Step 7: Generating summary visualizations...")

        df = statistical_analyzer.prepare_data_for_analysis()
        if len(df) > 0:
            metrics_save_path = viz_root / 'metrics_plots' / 'rq4_metrics_comparison.png'
            visualizer.plot_metrics_comparison(df, metrics_save_path)

        # Step 8: Create final comprehensive report
        logger.info("📋 Step 8: Creating final comprehensive report...")

        execution_time = time.time() - start_time

        final_report = result_manager.create_final_report(
            evaluation_results, statistical_report, execution_time
        )

        _print_rq4_results_summary(evaluation_results, statistical_report, final_report, execution_time)

        logger.info(f"✅ RQ4 enhanced analysis completed successfully in {execution_time:.2f} seconds")

        return {
            'evaluation_results': evaluation_results,
            'statistical_report': statistical_report,
            'final_report': final_report,
            'processed_data': processed_data,
            'execution_time': execution_time,
            'results_directory': str(result_manager.rq4_dir)
        }

    except Exception as e:
        # logger.exception records the traceback; bare `raise` re-raises the
        # original exception without adding this frame to it.
        logger.exception(f"❌ RQ4 enhanced analysis failed: {e}")
        raise

print("✅ RQ4 enhanced execution pipeline ready with comprehensive result saving")


In [None]:
# Run the full RQ4 pipeline and print where its key artifacts ended up.
_rq4_rule = "=" * 80
for _rq4_line in (
    "🚀 Starting RQ4 Enhanced Grad-CAM Explainability Analysis...",
    _rq4_rule,
    "📁 Results will be saved in organized structure:",
    f"   Main Directory: {result_manager.rq4_dir}",
    _rq4_rule,
):
    print(_rq4_line)

# Execute the pipeline (defined in the previous cell)
results = run_complete_rq4_analysis_enhanced()

_rq4_root = results['results_directory']
print("\n🎉 RQ4 Enhanced Analysis Complete!")
print(f"📁 All results saved to: {_rq4_root}")
print("\n📋 Quick Access to Key Results:")
# (label, path relative to the results directory) for each key artifact
for _rq4_label, _rq4_rel in (
    ("Final Report", "reports/RQ4_Final_Report.json"),
    ("Summary Report", "reports/RQ4_Summary_Report.md"),
    ("Metrics Data", "results/processed_data/evaluation_metrics.csv"),
    ("Statistical Results", "results/statistical_analysis/anova_results.json"),
    ("Visualizations", "visualizations/"),
    ("Attention Maps", "data/attention_maps/"),
):
    print(f"   • {_rq4_label}: {_rq4_root}/{_rq4_rel}")
