# Research Question 1: SAM Variants vs Established Models
## Comprehensive Comparison of Segment Anything Model Variants on PanNuke Dataset

**Research Question**: Do different variants of the Segment Anything Model (SAM), including the domain-adapted PathoSAM, achieve competitive or superior nuclei instance segmentation performance on the PanNuke dataset compared to established models such as HoVer-Net, CellViT, and LKCell?

### Research Hypotheses:
- **H₀ (Null)**: SAM variants do not significantly outperform established models in mPQ or detection F1
- **H₁ (Alternative)**: At least one SAM variant significantly outperforms baselines in mPQ or detection F1

### Methodology:
1. **Models to Compare**:
   - **SAM Variants**: SAM-Base, SAM-Large, SAM-Huge, PathoSAM
   - **Established Models**: HoVer-Net, CellViT, LKCell
   - **Baseline**: U-Net (for reference)

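2. **Dataset**: PanNuke dataset, partitioned into non-overlapping train/val/test splits so that reported test metrics come from images never used for training or model selection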
3. **Evaluation Metrics**: 
   - Mean Panoptic Quality (mPQ)
   - Detection F1 Score
   - Per-class performance analysis
   - Computational efficiency metrics

4. **Statistical Analysis**:
   - Paired t-tests for model comparisons
   - Wilcoxon signed-rank tests
   - Multiple comparison correction (Bonferroni)
   - Effect size calculations (Cohen's d)

### Expected Outcomes:
- Comprehensive performance comparison across all models
- Statistical significance testing with proper corrections
- Per-tissue class analysis
- Computational efficiency comparison
- Publication-ready results and visualizations

---
**🔬 SAM Variants Analysis | Multi-Model Comparison | Statistical Rigor**


In [None]:
# Install requirements
%pip install -r ../requirements.txt
%pip install segment-anything
%pip install transformers
%pip install timm
%pip install opencv-python
%pip install scikit-image
%pip install albumentations


In [None]:
# =============================================================================
# COMPREHENSIVE IMPORTS AND SETUP
# =============================================================================

# Core ML and Deep Learning
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

# SAM and Vision Models
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from transformers import AutoModel, AutoImageProcessor
import timm

# Data Processing and Visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Scientific Computing and Statistics
from scipy import stats
from scipy.stats import wilcoxon, ttest_rel, mannwhitneyu
from sklearn.metrics import f1_score, precision_score, recall_score, jaccard_score
from sklearn.model_selection import train_test_split
from statsmodels.stats.multitest import multipletests

# Utilities
from pathlib import Path
import os
import sys
import json
import time
import warnings
from collections import defaultdict
from typing import Dict, List, Tuple, Optional, Union
import gc
from tqdm import tqdm
import logging
import shutil
from dataclasses import dataclass

# Custom imports
sys.path.append('../')
from src.datasets.pannuke_dataset import PanNukeDataset
from src.utils.metrics import calculate_segmentation_metrics, calculate_batch_metrics
from src.models.unet import UNet

# Silence library warnings and create the notebook-wide logger
warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Report the runtime environment so results can be traced back to it
print("✅ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
_cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ready}")
if _cuda_ready:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {_gpu_props.total_memory / 1e9:.1f} GB")


In [None]:
# =============================================================================
# CONFIGURATION AND SETUP
# =============================================================================

@dataclass
class Config:
    """Central collection of tunable settings for the RQ1 experiment.

    Every field has a default, so ``Config()`` yields a ready-to-use
    configuration; individual values can be overridden by keyword.
    """

    # --- Filesystem locations ---
    data_root: str = "../data"
    artifacts_dir: str = "../artifacts/rq1"

    # --- Models and data loading ---
    # Maps a SAM variant name to its checkpoint file name; populated by
    # __post_init__ when the caller does not provide a mapping.
    sam_models: Dict[str, str] = None
    batch_size: int = 8
    num_workers: int = 4
    image_size: Tuple[int, int] = (256, 256)

    # --- Optimisation ---
    learning_rate: float = 1e-4
    epochs: int = 50
    patience: int = 10

    # --- Evaluation thresholds ---
    confidence_threshold: float = 0.5
    iou_threshold: float = 0.5

    # --- Statistical testing ---
    alpha: float = 0.05
    n_bootstrap: int = 1000

    def __post_init__(self):
        """Install the default SAM checkpoint table when none was supplied."""
        if self.sam_models is not None:
            return
        # Built per instance so configurations never share one mutable dict.
        self.sam_models = dict(
            sam_base='sam_vit_b_01ec64.pth',
            sam_large='sam_vit_l_0b3195.pth',
            sam_huge='sam_vit_h_4b8939.pth',
        )

# Initialize configuration
config = Config()

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create artifacts directory tree
artifacts_dir = Path(config.artifacts_dir)
artifacts_dir.mkdir(parents=True, exist_ok=True)
for _subdir in ('models', 'results', 'plots', 'logs'):
    (artifacts_dir / _subdir).mkdir(exist_ok=True)

# Set up logging.
# logging.basicConfig() does nothing once the root logger already has handlers
# (the import cell configures it), so the file handler and message format are
# installed on the root logger explicitly. Existing handlers are removed first
# so re-running this cell does not duplicate every log line.
log_file = artifacts_dir / 'logs' / 'rq1_experiment.log'
_log_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
_root_logger = logging.getLogger()
_root_logger.setLevel(logging.INFO)
for _handler in list(_root_logger.handlers):
    _root_logger.removeHandler(_handler)
    _handler.close()
for _handler in (logging.FileHandler(log_file), logging.StreamHandler()):
    _handler.setFormatter(_log_formatter)
    _root_logger.addHandler(_handler)

logger.info("RQ1 Experiment Configuration:")
logger.info(f"Data root: {config.data_root}")
logger.info(f"Artifacts dir: {config.artifacts_dir}")
logger.info(f"Device: {device}")
logger.info(f"Batch size: {config.batch_size}")
logger.info(f"Image size: {config.image_size}")

print("✅ Configuration setup complete!")


## 1. Data Loading and Preprocessing

Load the PanNuke dataset and prepare it for multi-model evaluation.


In [None]:
# =============================================================================
# DATA LOADING AND PREPROCESSING
# =============================================================================

class PanNukeMultiModelDataset(Dataset):
    """PanNuke image/mask dataset shared by every model in the comparison.

    Layout read by this class: images are ``<data_dir>/<split>/<name>.png`` and
    the mask of an image sits next to it as ``<name>_mask.png``. The list of
    image names comes from ``<data_dir>/<split>.txt`` (one name per line) when
    that file exists, otherwise from the PNG files found in the split folder.

    Args:
        data_dir: Root directory of the dataset.
        split: Name of the split sub-directory / list file (e.g. ``'train'``).
        transform: Optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a dict with the
            keys ``'image'`` and ``'mask'``.
        image_size: Target ``(height, width)`` every sample is resized to.
    """

    MASK_SUFFIX = '_mask.png'

    def __init__(self, data_dir, split='train', transform=None, image_size=(256, 256)):
        self.data_dir = Path(data_dir)
        self.split = split
        self.transform = transform
        # Stored as a tuple so the shape comparison in __getitem__ also works
        # when the caller passes a list.
        self.image_size = tuple(image_size)

        self.image_files = self._list_split_images()

        logger.info(f"Loaded {len(self.image_files)} images for {split} split")

    def _list_split_images(self):
        """Return the image file names of the current split."""
        split_file = self.data_dir / f"{self.split}.txt"
        if split_file.exists():
            with open(split_file, 'r') as f:
                # Skip blank lines so a trailing newline does not become a sample.
                return [line.strip() for line in f if line.strip()]

        # Fallback: list the split directory. Mask files share the folder and
        # the .png extension, so they must be excluded; sorting keeps the
        # sample order identical across dataset instances and runs.
        split_dir = self.data_dir / self.split
        return sorted(
            path.name
            for path in split_dir.glob("*.png")
            if not path.name.endswith(self.MASK_SUFFIX)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, mask, image_name)`` for sample ``idx``.

        Raises:
            FileNotFoundError: If the image, or an existing mask file, cannot
                be decoded by OpenCV.
        """
        img_name = self.image_files[idx]

        # Load image (OpenCV decodes to BGR; models expect RGB)
        img_path = self.data_dir / self.split / img_name
        image = cv2.imread(str(img_path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load corresponding mask
        mask_name = img_name.replace('.png', self.MASK_SUFFIX)
        mask_path = self.data_dir / self.split / mask_name
        if mask_path.exists():
            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise FileNotFoundError(f"Could not read mask: {mask_path}")
        else:
            # Best-effort fallback: an all-background mask. Logged because an
            # empty ground truth silently distorts every metric.
            logger.warning(f"Mask not found for {img_name}; using empty mask")
            mask = np.zeros(image.shape[:2], dtype=np.uint8)

        # Resize if needed. image_size is (height, width) whereas cv2.resize
        # expects its target size as (width, height).
        target_h, target_w = self.image_size
        dsize = (target_w, target_h)
        if image.shape[:2] != (target_h, target_w):
            image = cv2.resize(image, dsize)
        if mask.shape[:2] != (target_h, target_w):
            # Nearest-neighbour keeps label values intact.
            mask = cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)

        # Apply transforms
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        return image, mask, img_name

# Define transforms for different models
def get_transforms(image_size=(256, 256)):
    """Build the preprocessing pipelines for the two model families.

    Args:
        image_size: ``(height, width)`` passed to the resize step.

    Returns:
        Dict with two albumentations pipelines: ``'base'`` (resize, ImageNet
        mean/std normalisation, tensor conversion) and ``'sam'`` (resize and
        tensor conversion only, i.e. no normalisation).
    """
    height, width = image_size[0], image_size[1]

    def _resize_step():
        # A fresh transform object per pipeline, as in a hand-written Compose.
        return A.Resize(height, width)

    pipelines = {}

    # Normalised input for the conventionally trained networks
    pipelines['base'] = A.Compose([
        _resize_step(),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    # SAM-specific input (no normalization)
    pipelines['sam'] = A.Compose([
        _resize_step(),
        ToTensorV2(),
    ])

    return pipelines

# Load datasets
print("📁 Loading PanNuke dataset...")

transforms_dict = get_transforms(config.image_size)
_SPLITS = ('train', 'val', 'test')

# Datasets with ImageNet-normalised images
datasets = {
    _split: PanNukeMultiModelDataset(
        data_dir=config.data_root,
        split=_split,
        transform=transforms_dict['base'],
        image_size=config.image_size,
    )
    for _split in _SPLITS
}

# SAM-specific datasets (images are not normalised)
sam_datasets = {
    _split: PanNukeMultiModelDataset(
        data_dir=config.data_root,
        split=_split,
        transform=transforms_dict['sam'],
        image_size=config.image_size,
    )
    for _split in _SPLITS
}

# Create data loaders. Pinned host memory only speeds up host-to-GPU copies,
# so it is requested only when a GPU is present.
_pin_memory = torch.cuda.is_available()
dataloaders = {}
sam_dataloaders = {}

for split in _SPLITS:
    _loader_kwargs = dict(
        batch_size=config.batch_size,
        shuffle=(split == 'train'),  # only the training split is shuffled
        num_workers=config.num_workers,
        pin_memory=_pin_memory,
    )
    dataloaders[split] = DataLoader(datasets[split], **_loader_kwargs)
    sam_dataloaders[split] = DataLoader(sam_datasets[split], **_loader_kwargs)

print(f"✅ Dataset loaded successfully!")
print(f"Train: {len(datasets['train'])} images")
print(f"Validation: {len(datasets['val'])} images") 
print(f"Test: {len(datasets['test'])} images")

# Visualize sample data.
# The 'base' images are mean/std normalised and therefore fall outside the
# [0, 1] range imshow expects for floats; the normalisation is undone for
# display. These constants must match A.Normalize in get_transforms.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for _ax in axes.ravel():
    _ax.axis('off')

# Show at most three samples; fewer when the training split is smaller.
for i in range(min(3, len(datasets['train']))):
    # Regular dataset
    img, mask, name = datasets['train'][i]
    _img_display = img.permute(1, 2, 0).numpy() * _IMAGENET_STD + _IMAGENET_MEAN
    axes[0, i].imshow(np.clip(_img_display, 0.0, 1.0))
    axes[0, i].set_title(f'Image {i+1}')

    # SAM dataset
    img_sam, mask_sam, _ = sam_datasets['train'][i]
    axes[1, i].imshow(img_sam.permute(1, 2, 0).numpy())
    axes[1, i].set_title(f'SAM Image {i+1}')

plt.tight_layout()
plt.savefig(artifacts_dir / 'plots' / 'sample_data.png', dpi=150, bbox_inches='tight')
plt.show()


## 2. Model Implementations

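Implement all models for comparison: SAM variants, PathoSAM, HoVer-Net, CellViT, LKCell, and U-Net baseline. Note that the HoVer-Net, CellViT, and LKCell classes below are simplified re-implementations defined in this notebook, and that the PathoSAM wrapper runs standard SAM with adjusted mask-generation parameters unless an actual PathoSAM checkpoint is supplied.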


In [None]:
# =============================================================================
# MODEL IMPLEMENTATIONS
# =============================================================================

class SAMWrapper(nn.Module):
    """Run SAM's automatic mask generator as a batch nuclei segmenter.

    Args:
        model_type: Key into ``sam_model_registry`` selecting the backbone.
        checkpoint_path: Path to the SAM weights. When missing or not found on
            disk the model is built without loading a checkpoint.
        device: Device the SAM model is moved to and on which ``forward``
            returns its result.
    """

    def __init__(self, model_type='vit_b', checkpoint_path=None, device='cuda'):
        super().__init__()
        self.device = device
        self.model_type = model_type

        # Load SAM model
        if checkpoint_path and os.path.exists(checkpoint_path):
            self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        else:
            # No weights are loaded on this path, so the resulting
            # segmentations do not reflect a pretrained SAM.
            logger.warning(
                f"SAM checkpoint not found ({checkpoint_path!r}); "
                f"building '{model_type}' without pretrained weights"
            )
            self.sam = sam_model_registry[model_type]()

        self.sam.to(device)
        self.predictor = SamPredictor(self.sam)

        # For automatic mask generation
        self.mask_generator = SamAutomaticMaskGenerator(
            model=self.sam,
            points_per_side=32,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100,
        )

    @staticmethod
    def _to_uint8_image(image):
        """Convert an HWC array to the uint8 [0, 255] image SAM consumes.

        uint8 input is passed through unchanged (multiplying it by 255 would
        wrap around). Floating-point input whose maximum is at most 1.0 is
        treated as [0, 1] data and scaled by 255; anything else is assumed to
        already be on the 0-255 scale. Values are clipped before the cast.
        """
        if image.dtype == np.uint8:
            return np.ascontiguousarray(image)
        image = image.astype(np.float32)
        if image.size and image.max() <= 1.0:
            image = image * 255.0
        return np.clip(image, 0, 255).astype(np.uint8)

    @staticmethod
    def _combine_masks(masks, shape):
        """Merge SAM mask records into one instance-label map.

        Labels start at 1; 0 is background. Masks are painted from largest to
        smallest so a small mask is not erased by a larger one overlapping it.
        int64 is used because a uint8 map would wrap after 255 instances.
        """
        def mask_area(mask_data):
            area = mask_data.get('area')
            if area is None:
                area = int(np.count_nonzero(mask_data['segmentation']))
            return area

        combined_mask = np.zeros(shape, dtype=np.int64)
        ordered = sorted(masks, key=mask_area, reverse=True)
        for label, mask_data in enumerate(ordered, start=1):
            combined_mask[mask_data['segmentation']] = label
        return combined_mask

    def forward(self, images):
        """Segment a batch of CHW images.

        Args:
            images: Batch tensor indexed as ``[batch, channel, height, width]``;
                expected to hold unnormalised pixel values (uint8 0-255 or
                float 0-1).

        Returns:
            Long tensor ``[batch, height, width]`` of instance labels on
            ``self.device``.
        """
        predictions = []

        for image_tensor in images:
            image = image_tensor.detach().cpu().numpy().transpose(1, 2, 0)
            image = self._to_uint8_image(image)

            # Generate masks
            masks = self.mask_generator.generate(image)

            # An empty mask list simply yields an all-background map.
            combined_mask = self._combine_masks(masks, image.shape[:2])
            predictions.append(torch.from_numpy(combined_mask).long())

        return torch.stack(predictions).to(self.device)

class PathoSAMWrapper(SAMWrapper):
    """SAM wrapper configured for histopathology ("PathoSAM" stand-in).

    The model itself is whatever ``SAMWrapper`` builds from ``model_type`` and
    ``checkpoint_path``; only the automatic mask generator is replaced with a
    denser, more permissive configuration. Genuine PathoSAM behaviour would
    require passing the actual PathoSAM checkpoint.
    """

    def __init__(self, model_type='vit_b', checkpoint_path=None, device='cuda'):
        super().__init__(model_type, checkpoint_path, device)

        # Mask-generation settings tuned for small, densely packed nuclei
        histopathology_settings = dict(
            points_per_side=64,               # denser prompt grid
            pred_iou_thresh=0.6,              # accept lower-confidence masks
            stability_score_thresh=0.85,
            crop_n_layers=2,
            crop_n_points_downscale_factor=1.5,
            min_mask_region_area=50,          # keep smaller regions
        )
        self.mask_generator = SamAutomaticMaskGenerator(
            model=self.sam,
            **histopathology_settings,
        )

class HoVerNet(nn.Module):
    """Simplified HoVer-Net style network: ResNet-50 encoder, shared decoder, three heads.

    Args:
        num_classes: Number of channels of the type-classification head.
        pretrained: Forwarded to ``torchvision.models.resnet50``.
    """

    # Channel count produced by the last upsampling layer; every head reads it.
    DECODER_CHANNELS = 128

    def __init__(self, num_classes=6, pretrained=True):
        super().__init__()
        self.num_classes = num_classes

        # ResNet-50 without its average-pool and fully connected layers
        resnet = models.resnet50(pretrained=pretrained)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # Prediction heads. They are applied to the decoder output, so their
        # input width must be the decoder's 128 channels, not the encoder's 2048.
        self.np_hv = nn.Conv2d(self.DECODER_CHANNELS, 2, 1)  # Horizontal and vertical maps
        self.np = nn.Conv2d(self.DECODER_CHANNELS, 2, 1)     # Nuclei probability
        self.np_tp = nn.Conv2d(self.DECODER_CHANNELS, num_classes, 1)  # Type classification

        # Upsampling layers (each doubles the spatial resolution)
        self.up1 = nn.ConvTranspose2d(2048, 1024, 2, 2)
        self.up2 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.up4 = nn.ConvTranspose2d(256, self.DECODER_CHANNELS, 2, 2)

    @staticmethod
    def _to_input_resolution(logits, size):
        """Bilinearly resize ``logits`` to ``size`` unless it already matches."""
        if logits.shape[-2:] == size:
            return logits
        return F.interpolate(logits, size=tuple(size), mode='bilinear', align_corners=False)

    def forward(self, x):
        """Return a dict of per-pixel maps (``'np_hv'``, ``'np'``, ``'np_tp'``).

        Every map has the spatial size of ``x``: the four upsampling layers
        recover only part of the encoder's downsampling, so the head outputs
        are interpolated to the input resolution.
        """
        input_size = x.shape[-2:]

        # Encoder
        features = self.backbone(x)

        # Decoder
        decoded = self.up4(self.up3(self.up2(self.up1(features))))

        # Final predictions
        return {
            'np_hv': self._to_input_resolution(self.np_hv(decoded), input_size),
            'np': self._to_input_resolution(self.np(decoded), input_size),
            'np_tp': self._to_input_resolution(self.np_tp(decoded), input_size),
        }

class CellViT(nn.Module):
    """Simplified CellViT-style segmenter on top of a timm ViT patch embedding.

    Only the patch-embedding projection of the ViT feeds the segmentation
    head; the transformer blocks of ``self.vit`` are not executed in
    ``forward``.

    Args:
        num_classes: Number of output channels of the segmentation head.
        patch_size: Stored on the instance; the ViT itself is fixed to the
            ``vit_base_patch16_224`` architecture.
        embed_dim: Channel width expected by the first layer of the head.
    """

    def __init__(self, num_classes=6, patch_size=16, embed_dim=768):
        super().__init__()
        self.num_classes = num_classes
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        # Vision Transformer backbone
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True)
        self.vit.head = nn.Identity()  # Remove classification head

        # Segmentation head: 4 * 2 * 2 = 16x upsampling
        self.seg_head = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 512, 4, 4),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 2, 2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 2, 2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, num_classes, 1)
        )

    def forward(self, x):
        """Return per-pixel class logits with the spatial size of ``x``."""
        input_size = x.shape[-2:]

        # Apply the patch projection directly instead of calling patch_embed:
        # it already yields a [B, embed_dim, H/patch, W/patch] feature map, it
        # avoids the module's fixed-input-size check (the pipeline feeds
        # 256x256 images to a 224x224 ViT) and needs no square-grid reshape.
        # NOTE(review): assumes patch_embed applies no normalisation after the
        # projection for this architecture - confirm against the timm version.
        x = self.vit.patch_embed.proj(x)

        # Segmentation head
        x = self.seg_head(x)

        # Inputs whose sides are not multiples of the patch size come out
        # slightly smaller; bring them back to the input resolution.
        if x.shape[-2:] != input_size:
            x = F.interpolate(x, size=tuple(input_size), mode='bilinear', align_corners=False)

        return x

class LKCell(nn.Module):
    """Simplified LKCell stand-in: ResNet-34 encoder with a plain upsampling decoder.

    Args:
        num_classes: Number of output channels (per-pixel class logits).
    """

    def __init__(self, num_classes=6):
        super().__init__()
        self.num_classes = num_classes

        # Encoder: ResNet-34 without its average-pool and fully connected layers
        resnet = models.resnet34(pretrained=True)
        self.encoder = nn.Sequential(*list(resnet.children())[:-2])

        # Decoder: three 2x transposed convolutions (no skip connections)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 2, 2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 2, 2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 2, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, num_classes, 1)
        )

    def forward(self, x):
        """Return per-pixel class logits with the spatial size of ``x``."""
        input_size = x.shape[-2:]

        # Encoder
        features = self.encoder(x)

        # Decoder
        output = self.decoder(features)

        # The decoder upsamples by 8x only, which is less than the encoder's
        # downsampling, so the logits are resized to line up with the masks.
        if output.shape[-2:] != input_size:
            output = F.interpolate(output, size=tuple(input_size), mode='bilinear', align_corners=False)

        return output

# Model factory function
def create_model(model_name, **kwargs):
    """Build the model registered under ``model_name``.

    Args:
        model_name: One of ``'pathosam'``, a name starting with ``'sam'``
            (``'sam_base'``, ``'sam_large'``, ``'sam_huge'``), ``'hovernet'``,
            ``'cellvit'``, ``'lkcell'`` or ``'unet'``.
        **kwargs: Forwarded to the model constructor (ignored for ``'unet'``,
            which is always built with 3 input and 6 output channels).

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_name`` is not recognised.
    """
    # SAM variant name -> sam_model_registry key
    sam_backbones = {
        'sam_base': 'vit_b',
        'sam_large': 'vit_l',
        'sam_huge': 'vit_h',
    }

    # 'pathosam' does not start with 'sam', so it needs its own top-level test.
    if model_name == 'pathosam':
        return PathoSAMWrapper(**kwargs)
    if model_name.startswith('sam'):
        backbone = sam_backbones.get(model_name)
        if backbone is not None:
            # Pick the backbone matching the variant name unless the caller
            # chose one explicitly.
            kwargs.setdefault('model_type', backbone)
        return SAMWrapper(**kwargs)
    if model_name == 'hovernet':
        return HoVerNet(**kwargs)
    if model_name == 'cellvit':
        return CellViT(**kwargs)
    if model_name == 'lkcell':
        return LKCell(**kwargs)
    if model_name == 'unet':
        return UNet(in_channels=3, out_channels=6)
    raise ValueError(f"Unknown model: {model_name}")

# Summarise which model names the factory accepts
_model_summary_lines = [
    "✅ Model implementations complete!",
    "Available models:",
    "- SAM variants: sam_base, sam_large, sam_huge, pathosam",
    "- Established models: hovernet, cellvit, lkcell",
    "- Baseline: unet",
]
print("\n".join(_model_summary_lines))


## 3. Evaluation Framework

Implement comprehensive evaluation metrics including mPQ and detection F1, with statistical analysis capabilities.


In [None]:
# =============================================================================
# EVALUATION FRAMEWORK
# =============================================================================

class PanopticQualityCalculator:
    """Compute Panoptic Quality (PQ) between two instance-label maps.

    Label 0 (and any non-positive label) is background; every other distinct
    value is one instance.

    Args:
        num_classes: Stored on the instance; not used by ``calculate_pq``.
        iou_threshold: Minimum IoU for a prediction/ground-truth pair to count
            as a match.
    """

    def __init__(self, num_classes=6, iou_threshold=0.5):
        self.num_classes = num_classes
        self.iou_threshold = iou_threshold

    @staticmethod
    def _pairwise_iou(pred_mask, gt_mask):
        """Return the IoU matrix between foreground instances.

        Rows are predicted instances and columns are ground-truth instances.
        All overlaps are counted in one pass over the pixels instead of one
        pass per instance pair.
        """
        if pred_mask.shape != gt_mask.shape:
            raise ValueError(
                f"Mask shapes differ: pred {pred_mask.shape} vs gt {gt_mask.shape}"
            )

        pred_ids, pred_idx = np.unique(pred_mask, return_inverse=True)
        gt_ids, gt_idx = np.unique(gt_mask, return_inverse=True)
        # Flatten explicitly: the shape of return_inverse differs across NumPy versions.
        pred_idx = pred_idx.reshape(-1)
        gt_idx = gt_idx.reshape(-1)

        n_pred, n_gt = len(pred_ids), len(gt_ids)
        # overlap[p, g] = number of pixels labelled p in pred and g in gt
        overlap = np.bincount(
            pred_idx * n_gt + gt_idx, minlength=n_pred * n_gt
        ).reshape(n_pred, n_gt)
        pred_area = overlap.sum(axis=1, keepdims=True)
        gt_area = overlap.sum(axis=0, keepdims=True)
        # Every listed label occurs at least once, so the union is never zero.
        iou = overlap / (pred_area + gt_area - overlap)

        # Drop background rows/columns
        return iou[np.ix_(pred_ids > 0, gt_ids > 0)]

    def _match(self, iou):
        """Greedily pair instances one-to-one, highest IoU first.

        Returns the IoU of every accepted pair. Each prediction and each
        ground-truth instance is used at most once, so the number of matches
        can never exceed either instance count.
        """
        matched_ious = []
        if iou.size == 0:
            return matched_ious

        cand_pred, cand_gt = np.nonzero(iou >= self.iou_threshold)
        order = np.argsort(-iou[cand_pred, cand_gt], kind='stable')
        used_pred, used_gt = set(), set()
        for k in order:
            p, g = int(cand_pred[k]), int(cand_gt[k])
            if p in used_pred or g in used_gt:
                continue
            used_pred.add(p)
            used_gt.add(g)
            matched_ious.append(float(iou[p, g]))
        return matched_ious

    def calculate_pq(self, pred_mask, gt_mask):
        """Calculate PQ for a single image.

        Args:
            pred_mask: Predicted instance-label map (NumPy array or tensor).
            gt_mask: Ground-truth instance-label map of the same shape.

        Returns:
            Dict with ``'pq'``, ``'true_positives'``, ``'false_positives'``,
            ``'false_negatives'`` and ``'mean_iou'`` (mean IoU of the matched
            pairs). ``'pq'`` is 0.0 when nothing is matched, including the
            case where both maps contain no instances.

        Raises:
            ValueError: If the two maps differ in shape.
        """
        # Convert to numpy if needed
        if torch.is_tensor(pred_mask):
            pred_mask = pred_mask.cpu().numpy()
        if torch.is_tensor(gt_mask):
            gt_mask = gt_mask.cpu().numpy()

        iou = self._pairwise_iou(pred_mask, gt_mask)
        n_pred, n_gt = iou.shape
        matched_ious = self._match(iou)

        true_positives = len(matched_ious)
        false_positives = n_pred - true_positives
        false_negatives = n_gt - true_positives

        # PQ = sum(IoU of matches) / (TP + 0.5 * FP + 0.5 * FN)
        # (equivalently segmentation quality x recognition quality)
        if true_positives > 0:
            denominator = true_positives + 0.5 * false_positives + 0.5 * false_negatives
            pq = sum(matched_ious) / denominator
            mean_iou = sum(matched_ious) / true_positives
        else:
            pq = 0.0
            mean_iou = 0.0

        return {
            'pq': pq,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives,
            'mean_iou': mean_iou
        }

class DetectionF1Calculator:
    """Compute instance-level detection precision, recall and F1.

    Label 0 (and any non-positive label) is background; every other distinct
    value is one instance.

    Args:
        iou_threshold: Minimum IoU for a prediction/ground-truth pair to count
            as a detection.
    """

    def __init__(self, iou_threshold=0.5):
        self.iou_threshold = iou_threshold

    @staticmethod
    def _pairwise_iou(pred_mask, gt_mask):
        """Return the IoU matrix between foreground instances.

        Rows are predicted instances and columns are ground-truth instances.
        All overlaps are counted in one pass over the pixels instead of one
        pass per instance pair.
        """
        if pred_mask.shape != gt_mask.shape:
            raise ValueError(
                f"Mask shapes differ: pred {pred_mask.shape} vs gt {gt_mask.shape}"
            )

        pred_ids, pred_idx = np.unique(pred_mask, return_inverse=True)
        gt_ids, gt_idx = np.unique(gt_mask, return_inverse=True)
        # Flatten explicitly: the shape of return_inverse differs across NumPy versions.
        pred_idx = pred_idx.reshape(-1)
        gt_idx = gt_idx.reshape(-1)

        n_pred, n_gt = len(pred_ids), len(gt_ids)
        # overlap[p, g] = number of pixels labelled p in pred and g in gt
        overlap = np.bincount(
            pred_idx * n_gt + gt_idx, minlength=n_pred * n_gt
        ).reshape(n_pred, n_gt)
        pred_area = overlap.sum(axis=1, keepdims=True)
        gt_area = overlap.sum(axis=0, keepdims=True)
        # Every listed label occurs at least once, so the union is never zero.
        iou = overlap / (pred_area + gt_area - overlap)

        # Drop background rows/columns
        return iou[np.ix_(pred_ids > 0, gt_ids > 0)]

    def _count_matches(self, iou_matrix):
        """Count one-to-one matches, assigning the highest-IoU pairs first.

        Counting every pair above the threshold would let one prediction that
        covers several nuclei score several true positives and push precision
        above 1.
        """
        if iou_matrix.size == 0:
            return 0

        cand_pred, cand_gt = np.nonzero(iou_matrix >= self.iou_threshold)
        order = np.argsort(-iou_matrix[cand_pred, cand_gt], kind='stable')
        used_pred, used_gt = set(), set()
        for k in order:
            p, g = int(cand_pred[k]), int(cand_gt[k])
            if p in used_pred or g in used_gt:
                continue
            used_pred.add(p)
            used_gt.add(g)
        return len(used_pred)

    def calculate_f1(self, pred_mask, gt_mask):
        """Calculate detection F1 for a single image.

        Args:
            pred_mask: Predicted instance-label map (NumPy array or tensor).
            gt_mask: Ground-truth instance-label map of the same shape.

        Returns:
            Dict with ``'f1'``, ``'precision'``, ``'recall'``,
            ``'true_positives'``, ``'false_positives'`` and
            ``'false_negatives'``. Ratios are 0.0 when their denominator is
            zero.

        Raises:
            ValueError: If the two maps differ in shape.
        """
        # Convert to numpy if needed
        if torch.is_tensor(pred_mask):
            pred_mask = pred_mask.cpu().numpy()
        if torch.is_tensor(gt_mask):
            gt_mask = gt_mask.cpu().numpy()

        iou_matrix = self._pairwise_iou(pred_mask, gt_mask)
        n_pred, n_gt = iou_matrix.shape

        true_positives = self._count_matches(iou_matrix)
        false_positives = n_pred - true_positives
        false_negatives = n_gt - true_positives

        precision = true_positives / n_pred if n_pred > 0 else 0.0
        recall = true_positives / n_gt if n_gt > 0 else 0.0

        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        return {
            'f1': f1,
            'precision': precision,
            'recall': recall,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives
        }

class ModelEvaluator:
    """Run models over data loaders and collect per-image metrics.

    Args:
        device: Device that batches (and the evaluated model) are moved to.
    """

    def __init__(self, device='cuda'):
        self.device = device
        self.pq_calculator = PanopticQualityCalculator()
        self.f1_calculator = DetectionF1Calculator()

    @staticmethod
    def _match_spatial_size(pred_masks, masks):
        """Resize predicted label maps to the ground-truth resolution.

        Nearest-neighbour resampling is used so label values are never
        blended. The input is returned unchanged when the sizes already agree.
        """
        if pred_masks.shape[-2:] == masks.shape[-2:]:
            return pred_masks
        resized = F.interpolate(
            pred_masks.unsqueeze(1).float(),
            size=tuple(masks.shape[-2:]),
            mode='nearest',
        )
        return resized.squeeze(1).long()

    def _predict(self, model, images):
        """Return one label map per image for either model family."""
        returns_label_map = (
            isinstance(model, (SAMWrapper, PathoSAMWrapper))
            or not hasattr(model, 'forward')
        )
        if returns_label_map:
            # SAM models already produce label maps
            return model(images)

        # Standard models produce per-class logits
        outputs = model(images)
        if isinstance(outputs, dict):
            # Handle multi-output models like HoVerNet
            outputs = outputs.get('np_tp', outputs.get('np', outputs))
        return torch.argmax(outputs, dim=1)

    def _score_image(self, pred_mask, gt_mask):
        """Compute every metric for one prediction/ground-truth pair."""
        pq_metrics = self.pq_calculator.calculate_pq(pred_mask, gt_mask)
        f1_metrics = self.f1_calculator.calculate_f1(pred_mask, gt_mask)

        pred_np = pred_mask.cpu().numpy()
        gt_np = gt_mask.cpu().numpy()

        # Basic segmentation metrics
        iou = jaccard_score(gt_np.flatten(), pred_np.flatten(), average='macro', zero_division=0)
        pixel_acc = (pred_np == gt_np).mean()

        return {
            'pq': pq_metrics['pq'],
            'detection_f1': f1_metrics['f1'],
            'precision': f1_metrics['precision'],
            'recall': f1_metrics['recall'],
            'iou': iou,
            'pixel_accuracy': pixel_acc,
            'true_positives': f1_metrics['true_positives'],
            'false_positives': f1_metrics['false_positives'],
            'false_negatives': f1_metrics['false_negatives'],
            'mean_iou': pq_metrics['mean_iou']
        }

    def evaluate_model(self, model, dataloader, model_name="model"):
        """Evaluate a single model on a dataset.

        Args:
            model: Model to evaluate; put into eval mode and moved to
                ``self.device``.
            dataloader: Yields ``(images, masks, names)`` batches.
            model_name: Label used in the progress message.

        Returns:
            List with one metrics dict per image.
        """
        model.eval()
        # Batches are moved to self.device below, so the model has to be there too.
        if hasattr(model, 'to'):
            model.to(self.device)
        results = []

        print(f"🔍 Evaluating {model_name}...")

        with torch.no_grad():
            for batch_idx, (images, masks, names) in enumerate(tqdm(dataloader)):
                images = images.to(self.device)
                masks = masks.to(self.device)

                pred_masks = self._predict(model, images)
                # Guard against models whose output resolution differs from the masks.
                pred_masks = self._match_spatial_size(pred_masks, masks)

                # Calculate metrics for each image in batch
                for i in range(images.shape[0]):
                    result = {
                        'image_name': names[i],
                        'batch_idx': batch_idx,
                        'image_idx': i,
                    }
                    result.update(self._score_image(pred_masks[i], masks[i]))
                    results.append(result)

        return results

    def evaluate_all_models(self, models, dataloaders):
        """Evaluate all models on all datasets.

        Args:
            models: Mapping of model name to model.
            dataloaders: Mapping of split name to data loader; the same
                loaders are used for every model.

        Returns:
            Nested dict ``{model_name: {split_name: [per-image results]}}``.
        """
        all_results = {}

        for model_name, model in models.items():
            print(f"\n📊 Evaluating {model_name}...")
            model_results = {}

            for split_name, dataloader in dataloaders.items():
                print(f"  {split_name} split...")
                split_results = self.evaluate_model(model, dataloader, f"{model_name}_{split_name}")
                model_results[split_name] = split_results

            all_results[model_name] = model_results

        return all_results

# Create the shared evaluator on the selected device
evaluator = ModelEvaluator(device=device)

# List the metrics every evaluation run reports
_evaluation_summary_lines = [
    "✅ Evaluation framework ready!",
    "Available metrics:",
    "- Panoptic Quality (PQ)",
    "- Detection F1 Score",
    "- Precision, Recall",
    "- IoU, Pixel Accuracy",
    "- True/False Positives/Negatives",
]
print("\n".join(_evaluation_summary_lines))


## 4. Model Training and Evaluation

Train and evaluate all models on the PanNuke dataset with comprehensive metrics collection.


In [None]:
# =============================================================================
# MODEL TRAINING AND EVALUATION
# =============================================================================

class ModelTrainer:
    """Training framework for the trainable (non-SAM) model types.

    SAM-style wrappers are moved to the device and returned untrained; every
    other model is trained with Adam + cross-entropy, validated each epoch,
    checkpointed on its best validation PQ and early-stopped on patience.

    Relies on notebook-level globals: ``evaluator`` (metric calculators),
    ``config`` (default patience), ``artifacts_dir`` (checkpoint root,
    used with the ``/`` operator) and the ``SAMWrapper`` / ``PathoSAMWrapper``
    classes.
    """

    def __init__(self, device='cuda', learning_rate=1e-4):
        """Store the target device and the Adam learning rate."""
        self.device = device
        self.learning_rate = learning_rate

    @staticmethod
    def _primary_output(outputs):
        """Return the tensor fed to the loss; dict outputs are reduced to 'np_tp', then 'np'."""
        if isinstance(outputs, dict):
            # Handle multi-output models
            return outputs.get('np_tp', outputs.get('np', outputs))
        return outputs

    def _run_training_epoch(self, model, train_loader, optimizer, criterion, description):
        """Run one optimisation pass over ``train_loader`` and return the mean batch loss."""
        model.train()
        total_loss = 0.0

        for images, masks, _ in tqdm(train_loader, desc=description):
            images, masks = images.to(self.device), masks.to(self.device)

            optimizer.zero_grad()
            outputs = self._primary_output(model(images))
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        return total_loss / max(len(train_loader), 1)

    def _run_validation(self, model, val_loader, criterion):
        """Evaluate on ``val_loader``; return (mean loss, mean PQ, mean F1).

        PQ/F1 are NaN when the loader yields no images.
        """
        model.eval()
        total_loss = 0.0
        pq_scores = []
        f1_scores = []

        with torch.no_grad():
            for images, masks, _ in val_loader:
                images, masks = images.to(self.device), masks.to(self.device)

                outputs = self._primary_output(model(images))
                total_loss += criterion(outputs, masks).item()

                # Class prediction per pixel (argmax over dim 1 of the output)
                pred_masks = torch.argmax(outputs, dim=1)

                for i in range(pred_masks.shape[0]):
                    pq_metrics = evaluator.pq_calculator.calculate_pq(pred_masks[i], masks[i])
                    f1_metrics = evaluator.f1_calculator.calculate_f1(pred_masks[i], masks[i])
                    pq_scores.append(pq_metrics['pq'])
                    f1_scores.append(f1_metrics['f1'])

        avg_loss = total_loss / max(len(val_loader), 1)
        avg_pq = float(np.mean(pq_scores)) if pq_scores else float('nan')
        avg_f1 = float(np.mean(f1_scores)) if f1_scores else float('nan')
        return avg_loss, avg_pq, avg_f1

    def train_model(self, model, train_loader, val_loader, epochs=50, model_name="model", patience=None):
        """Train a model with early stopping on validation PQ.

        Args:
            model: Model to train; moved to ``self.device``.
            train_loader: Iterable yielding ``(images, masks, _)`` batches.
            val_loader: Iterable yielding ``(images, masks, _)`` batches.
            epochs: Maximum number of epochs.
            model_name: Used in log messages and the checkpoint file name.
            patience: Epochs without PQ improvement before stopping;
                defaults to ``config.patience`` when ``None``.

        Returns:
            ``(model, history)``. ``history`` maps 'train_loss', 'val_loss',
            'val_pq' and 'val_f1' to per-epoch lists. For SAM wrappers the
            model is returned untrained with an empty history dict.
        """
        model = model.to(self.device)

        if isinstance(model, (SAMWrapper, PathoSAMWrapper)):
            # SAM models are typically not trained from scratch
            print(f"⚠️  {model_name} is a SAM model - skipping training (using pretrained weights)")
            return model, {}

        if patience is None:
            patience = config.patience

        # Setup optimizer and loss
        optimizer = optim.Adam(model.parameters(), lr=self.learning_rate)
        criterion = nn.CrossEntropyLoss()
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)

        history = {
            'train_loss': [],
            'val_loss': [],
            'val_pq': [],
            'val_f1': []
        }

        checkpoint_path = artifacts_dir / 'models' / f'{model_name}_best.pth'
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        # -inf so the first finite PQ (even 0.0) produces a checkpoint
        best_val_pq = float('-inf')
        patience_counter = 0
        # Only reload a checkpoint written by THIS run; a file left over from
        # an earlier run must not silently replace the freshly trained weights.
        saved_checkpoint = False

        print(f"🚀 Training {model_name}...")

        for epoch in range(epochs):
            avg_train_loss = self._run_training_epoch(
                model, train_loader, optimizer, criterion, f"Epoch {epoch+1}/{epochs}"
            )
            avg_val_loss, avg_val_pq, avg_val_f1 = self._run_validation(model, val_loader, criterion)

            history['train_loss'].append(avg_train_loss)
            history['val_loss'].append(avg_val_loss)
            history['val_pq'].append(avg_val_pq)
            history['val_f1'].append(avg_val_f1)

            # Learning rate scheduling
            scheduler.step(avg_val_loss)

            # Early stopping bookkeeping (a NaN PQ never counts as an improvement)
            if avg_val_pq > best_val_pq:
                best_val_pq = avg_val_pq
                patience_counter = 0
                torch.save(model.state_dict(), checkpoint_path)
                saved_checkpoint = True
            else:
                patience_counter += 1

            print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
                  f"Val PQ: {avg_val_pq:.4f}, Val F1: {avg_val_f1:.4f}")

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        # Restore the best weights found during this run
        if saved_checkpoint:
            # map_location keeps the reload valid when the device differs from the save-time one
            model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))

        return model, history

# Trainer shared by all trainable models
trainer = ModelTrainer(device=device, learning_rate=config.learning_rate)

# Trainable architectures (name -> name)
models_to_train = {name: name for name in ('unet', 'hovernet', 'cellvit', 'lkcell')}

# SAM-family models: pretrained, no training needed (name -> name)
sam_models = {name: name for name in ('sam_base', 'sam_large', 'sam_huge', 'pathosam')}

print("🏗️  Initializing models...")

# Every model ends up in trained_models; histories exist only for trained ones
trained_models = {}
training_histories = {}

for model_name in models_to_train:
    print(f"\n📦 Creating {model_name}...")
    model = create_model(model_name, num_classes=6)

    # Fit on the train split, validate on the val split
    trained_model, history = trainer.train_model(
        model,
        dataloaders['train'],
        dataloaders['val'],
        epochs=config.epochs,
        model_name=model_name,
    )

    trained_models[model_name] = trained_model
    training_histories[model_name] = history

# Instantiate the pretrained SAM-family models
for model_name in sam_models:
    print(f"\n📦 Creating {model_name}...")

    # Backbone is picked from the model name; names without 'base'/'large' get 'vit_h'
    if 'base' in model_name:
        model_type = 'vit_b'
    elif 'large' in model_name:
        model_type = 'vit_l'
    else:
        model_type = 'vit_h'

    model = create_model(model_name, model_type=model_type, device=device)
    trained_models[model_name] = model

print(f"\n✅ All models ready! Total: {len(trained_models)} models")
print("Models available:", list(trained_models.keys()))


In [None]:
# =============================================================================
# COMPREHENSIVE MODEL EVALUATION
# =============================================================================

print("🔍 Starting comprehensive evaluation of all models...")

# model name -> list of per-image result dicts from the test split
all_results = {}

# (results column, suffix used in the summary keys), in reporting order
_SUMMARY_FIELDS = (
    ('pq', 'pq'),
    ('detection_f1', 'f1'),
    ('iou', 'iou'),
    ('precision', 'precision'),
    ('recall', 'recall'),
)

for model_name, model in trained_models.items():
    print(f"\n📊 Evaluating {model_name}...")

    # SAM-family models read from their own dataloaders
    uses_sam_loader = model_name.startswith('sam') or model_name == 'pathosam'
    test_loader = sam_dataloaders['test'] if uses_sam_loader else dataloaders['test']

    results = evaluator.evaluate_model(model, test_loader, model_name)
    all_results[model_name] = results

    # Mean/std of each metric over the evaluated images
    df = pd.DataFrame(results)
    summary = {'model': model_name}
    summary.update(
        (f'{stat}_{suffix}', getattr(df[column], stat)())
        for column, suffix in _SUMMARY_FIELDS
        for stat in ('mean', 'std')
    )
    summary['n_images'] = len(df)

    print(f"  PQ: {summary['mean_pq']:.4f} ± {summary['std_pq']:.4f}")
    print(f"  F1: {summary['mean_f1']:.4f} ± {summary['std_f1']:.4f}")
    print(f"  IoU: {summary['mean_iou']:.4f} ± {summary['std_iou']:.4f}")

# Stack the per-model frames into one long table tagged with the model name
all_results_df = []
for model_name, results in all_results.items():
    df = pd.DataFrame(results).assign(model=model_name)
    all_results_df.append(df)

combined_results_df = pd.concat(all_results_df, ignore_index=True)

# Persist the combined table
results_csv_path = artifacts_dir / 'results' / 'all_model_results.csv'
combined_results_df.to_csv(results_csv_path, index=False)

n_models_evaluated = combined_results_df['model'].nunique()

print(f"\n✅ Evaluation complete! Results saved to {results_csv_path}")
print(f"Total evaluations: {len(combined_results_df)}")
print(f"Models evaluated: {n_models_evaluated}")
print(f"Images per model: {len(combined_results_df) // n_models_evaluated}")


## 5. Statistical Analysis

Perform comprehensive statistical analysis to test research hypotheses with proper multiple comparison corrections.


In [None]:
# =============================================================================
# STATISTICAL ANALYSIS
# =============================================================================

class StatisticalAnalyzer:
    """Statistical comparison of per-image model metrics.

    Expects a long-format DataFrame with a ``model`` column and one column per
    metric. Models are compared with a paired t-test, a Wilcoxon signed-rank
    test and Cohen's d; p-values can then be corrected for multiple testing.

    NOTE(review): the paired tests match samples by position after truncating
    both models to the shorter length. This assumes every model's rows are in
    the same image order - confirm against how the results table is built.
    """

    # Column layouts, so that an empty result still exposes the columns
    # downstream code indexes (e.g. 't_significant', 'sam_model').
    _PAIRWISE_COLUMNS = [
        'model1', 'model2', 'metric', 'mean1', 'std1', 'mean2', 'std2',
        'mean_diff', 't_statistic', 't_p_value', 't_significant',
        'w_statistic', 'w_p_value', 'w_significant', 'cohens_d',
        'effect_size', 'n_samples',
    ]
    _SAM_VS_EST_COLUMNS = [
        'sam_model', 'established_model', 'metric', 'sam_mean', 'est_mean',
        'mean_diff', 't_p_value', 'w_p_value', 'cohens_d', 'sam_better',
        'significant_t', 'significant_w',
    ]

    def __init__(self, alpha=0.05):
        """Store the significance level used for the (uncorrected) tests."""
        self.alpha = alpha

    def calculate_effect_size(self, group1, group2):
        """Calculate Cohen's d (pooled-standard-deviation form).

        Returns 0.0 when the pooled standard deviation is zero and NaN when
        either group has fewer than two values (the sample std is undefined).
        """
        n1, n2 = len(group1), len(group2)
        if n1 < 2 or n2 < 2:
            return float('nan')

        s1, s2 = np.std(group1, ddof=1), np.std(group2, ddof=1)

        # Pooled standard deviation
        pooled_std = np.sqrt(((n1 - 1) * s1**2 + (n2 - 1) * s2**2) / (n1 + n2 - 2))

        if pooled_std == 0:
            return 0.0

        return float((np.mean(group1) - np.mean(group2)) / pooled_std)

    @staticmethod
    def _effect_size_label(cohens_d):
        """Map |d| onto Cohen's conventional bands (0.2 / 0.5 / 0.8)."""
        magnitude = abs(cohens_d)
        if np.isnan(magnitude):
            return 'undefined'
        if magnitude < 0.2:
            return 'negligible'
        if magnitude < 0.5:
            return 'small'
        if magnitude < 0.8:
            return 'medium'
        return 'large'

    @staticmethod
    def _paired_tests(data1, data2):
        """Run both paired tests independently; a failed test yields NaNs.

        Returns ``(t_stat, t_p, w_stat, w_p)``. Each test has its own handler
        so a failure of one does not discard the other's result.
        """
        try:
            t_stat, t_p = ttest_rel(data1, data2)
        except (ValueError, TypeError):
            t_stat, t_p = np.nan, np.nan

        try:
            w_stat, w_p = wilcoxon(data1, data2, alternative='two-sided')
        except (ValueError, TypeError):
            w_stat, w_p = np.nan, np.nan

        return t_stat, t_p, w_stat, w_p

    @staticmethod
    def _paired_samples(df, model_a, model_b, metric):
        """Return both models' metric arrays truncated to the shorter length."""
        data_a = df[df['model'] == model_a][metric].values
        data_b = df[df['model'] == model_b][metric].values
        min_len = min(len(data_a), len(data_b))
        return data_a[:min_len], data_b[:min_len]

    def compare_models(self, df, metric='pq'):
        """Compare all models pairwise for a given metric.

        Returns one row per unordered model pair that has at least two paired
        samples; the frame always carries the full column set, even if empty.
        Significance flags use the uncorrected p-values against ``self.alpha``.
        """
        models = df['model'].unique()
        n_models = len(models)

        comparison_results = []

        for i in range(n_models):
            for j in range(i + 1, n_models):
                model1, model2 = models[i], models[j]
                data1, data2 = self._paired_samples(df, model1, model2, metric)

                if len(data1) < 2:
                    continue

                t_stat, t_p, w_stat, w_p = self._paired_tests(data1, data2)
                cohens_d = self.calculate_effect_size(data1, data2)

                mean1, mean2 = np.mean(data1), np.mean(data2)

                comparison_results.append({
                    'model1': model1,
                    'model2': model2,
                    'metric': metric,
                    'mean1': mean1,
                    'std1': np.std(data1, ddof=1),
                    'mean2': mean2,
                    'std2': np.std(data2, ddof=1),
                    'mean_diff': mean1 - mean2,
                    't_statistic': t_stat,
                    't_p_value': t_p,
                    # NaN p-values compare False, i.e. "not significant"
                    't_significant': bool(t_p < self.alpha),
                    'w_statistic': w_stat,
                    'w_p_value': w_p,
                    'w_significant': bool(w_p < self.alpha),
                    'cohens_d': cohens_d,
                    'effect_size': self._effect_size_label(cohens_d),
                    'n_samples': len(data1)
                })

        return pd.DataFrame(comparison_results, columns=self._PAIRWISE_COLUMNS)

    def sam_vs_baseline_analysis(self, df, metric='pq'):
        """Compare each SAM variant against each established model.

        SAM variants are models whose name starts with 'sam' or equals
        'pathosam'; established models are 'hovernet', 'cellvit' and 'lkcell'.
        The returned frame always carries the full column set, even if empty.
        """
        all_models = df['model'].unique()
        sam_models = [m for m in all_models if m.startswith('sam') or m == 'pathosam']
        established_models = [m for m in all_models if m in ['hovernet', 'cellvit', 'lkcell']]

        results = []

        for sam_model in sam_models:
            for est_model in established_models:
                sam_data, est_data = self._paired_samples(df, sam_model, est_model, metric)

                if len(sam_data) < 2:
                    continue

                _, t_p, _, w_p = self._paired_tests(sam_data, est_data)
                cohens_d = self.calculate_effect_size(sam_data, est_data)

                sam_mean, est_mean = np.mean(sam_data), np.mean(est_data)

                results.append({
                    'sam_model': sam_model,
                    'established_model': est_model,
                    'metric': metric,
                    'sam_mean': sam_mean,
                    'est_mean': est_mean,
                    'mean_diff': sam_mean - est_mean,
                    't_p_value': t_p,
                    'w_p_value': w_p,
                    'cohens_d': cohens_d,
                    'sam_better': bool(sam_mean > est_mean),
                    'significant_t': bool(t_p < self.alpha),
                    'significant_w': bool(w_p < self.alpha)
                })

        return pd.DataFrame(results, columns=self._SAM_VS_EST_COLUMNS)

    def multiple_comparison_correction(self, df, p_value_col='t_p_value'):
        """Add Bonferroni and Benjamini-Hochberg corrected p-value columns.

        Adds ``<p_value_col>_bonferroni`` and ``<p_value_col>_bh`` to ``df``
        in place and returns it. NaN p-values stay NaN and are excluded from
        the Benjamini-Hochberg procedure. An empty frame just gains the two
        (empty) columns.
        """
        bonferroni_col = f'{p_value_col}_bonferroni'
        bh_col = f'{p_value_col}_bh'

        if df.empty:
            df[bonferroni_col] = pd.Series(dtype=float)
            df[bh_col] = pd.Series(dtype=float)
            return df

        # Number of comparisons: unique model pairs when those columns exist,
        # otherwise one comparison per row (e.g. the SAM-vs-established table).
        if {'model1', 'model2'}.issubset(df.columns):
            n_comparisons = len(df.groupby(['model1', 'model2']).size())
        else:
            n_comparisons = len(df)

        p_values = df[p_value_col].astype(float)

        # Bonferroni: scale by the number of comparisons, cap at 1
        df[bonferroni_col] = (p_values * n_comparisons).clip(upper=1.0)

        # Benjamini-Hochberg on the finite p-values only
        bh_values = pd.Series(np.nan, index=df.index, dtype=float)
        valid = p_values.notna()
        if valid.any():
            from statsmodels.stats.multitest import multipletests
            _, p_corrected, _, _ = multipletests(p_values[valid].values, method='fdr_bh')
            bh_values[valid] = p_corrected
        df[bh_col] = bh_values

        return df

# Initialize analyzer with the significance level from the notebook config
analyzer = StatisticalAnalyzer(alpha=config.alpha)

print("📊 Performing statistical analysis...")

# Analyze each metric; results are kept per metric for the later cells
metrics_to_analyze = ['pq', 'detection_f1', 'iou', 'precision', 'recall']
all_comparisons = {}

for metric in metrics_to_analyze:
    print(f"\n🔍 Analyzing {metric}...")

    # General pairwise comparisons, with corrected t-test p-values added
    comparisons = analyzer.compare_models(combined_results_df, metric=metric)
    comparisons = analyzer.multiple_comparison_correction(comparisons)

    # SAM vs established models analysis
    sam_vs_est = analyzer.sam_vs_baseline_analysis(combined_results_df, metric=metric)

    all_comparisons[metric] = {
        'pairwise': comparisons,
        'sam_vs_established': sam_vs_est
    }

    # Save results
    comparisons.to_csv(artifacts_dir / 'results' / f'{metric}_comparisons.csv', index=False)
    sam_vs_est.to_csv(artifacts_dir / 'results' / f'{metric}_sam_vs_established.csv', index=False)

    # The significance counts below use the uncorrected p-values
    print(f"  Pairwise comparisons: {len(comparisons)}")
    print(f"  SAM vs established: {len(sam_vs_est)}")
    print(f"  Significant comparisons (t-test): {comparisons['t_significant'].sum()}")
    print(f"  Significant comparisons (Wilcoxon): {comparisons['w_significant'].sum()}")

print("\n✅ Statistical analysis complete!")
# The original expression ended in a dangling '/' operator, which is a SyntaxError
print(f"Results saved to {artifacts_dir / 'results'}")


## 6. Visualization and Results

Create comprehensive visualizations to present the research findings and model comparisons.


In [None]:
# =============================================================================
# COMPREHENSIVE VISUALIZATION AND RESULTS
# =============================================================================

class ResultsVisualizer:
    """Create plots and a summary table for the model comparison results.

    Figures are written to ``<artifacts_dir>/plots`` and the summary table to
    ``<artifacts_dir>/results``. ``artifacts_dir`` must support the ``/``
    operator (it is used like a ``pathlib.Path``).
    """

    def __init__(self, artifacts_dir):
        """Remember the output locations, create the plots folder, set the plot style."""
        self.artifacts_dir = artifacts_dir
        self.plots_dir = artifacts_dir / 'plots'
        self.plots_dir.mkdir(parents=True, exist_ok=True)

        # Set style
        try:
            plt.style.use('seaborn-v0_8')
        except OSError:
            # Older Matplotlib releases only know the style under its previous name
            plt.style.use('seaborn')
        sns.set_palette("husl")

    def plot_model_performance_comparison(self, df, metrics=['pq', 'detection_f1', 'iou']):
        """Draw one box plot per metric plus a heatmap of per-metric model ranks.

        Saves ``model_performance_comparison.png`` and shows the figure. The
        grid has two columns and as many rows as needed for
        ``len(metrics) + 1`` panels (2x2 for the default three metrics).
        """
        metrics = list(metrics)
        # One explicit category order shared by the boxes and their annotations
        model_order = list(df['model'].unique())

        n_panels = len(metrics) + 1
        n_rows = (n_panels + 1) // 2
        fig, axes = plt.subplots(n_rows, 2, figsize=(20, 8 * n_rows), squeeze=False)
        axes = axes.flatten()

        for ax, metric in zip(axes, metrics):
            # Create box plot
            sns.boxplot(data=df, x='model', y=metric, order=model_order, ax=ax)
            ax.set_title(f'{metric.upper()} Performance Comparison', fontsize=14, fontweight='bold')
            ax.set_xlabel('Model', fontsize=12)
            ax.set_ylabel(f'{metric.upper()}', fontsize=12)
            ax.tick_params(axis='x', rotation=45)

            # Add mean values as text, at the x position of the matching box
            model_means = df.groupby('model')[metric].mean().reindex(model_order)
            for j, mean_val in enumerate(model_means):
                if pd.isna(mean_val):
                    continue
                ax.text(j, mean_val + 0.01, f'{mean_val:.3f}',
                        ha='center', va='bottom', fontweight='bold')

        # Model ranking plot (rank 1 = highest mean)
        ax = axes[len(metrics)]
        model_rankings = {
            metric: df.groupby('model')[metric].mean().rank(ascending=False)
            for metric in metrics
        }

        ranking_df = pd.DataFrame(model_rankings)
        ranking_df['average_rank'] = ranking_df.mean(axis=1)
        ranking_df = ranking_df.sort_values('average_rank')

        sns.heatmap(ranking_df[metrics], annot=True, cmap='RdYlGn_r',
                    ax=ax, cbar_kws={'label': 'Rank (1=Best)'})
        ax.set_title('Model Rankings Across Metrics', fontsize=14, fontweight='bold')
        ax.set_xlabel('Metrics', fontsize=12)
        ax.set_ylabel('Models', fontsize=12)

        # Hide any grid cell left over when the panel count is odd
        for spare_ax in axes[n_panels:]:
            spare_ax.set_visible(False)

        plt.tight_layout()
        plt.savefig(self.plots_dir / 'model_performance_comparison.png',
                    dpi=300, bbox_inches='tight')
        plt.show()

    def plot_sam_vs_established(self, sam_vs_est_data, metric='pq'):
        """Plot SAM-vs-established mean differences and the significant effect sizes.

        Left panel: matrix of ``mean_diff`` (SAM rows, established columns;
        pairs without a comparison stay 0). Right panel: Cohen's d of the
        comparisons flagged ``significant_t``. Saves
        ``sam_vs_established_<metric>.png``. Nothing is drawn for an empty table.
        """
        if sam_vs_est_data.empty:
            print(f"⚠️  No SAM vs established comparisons available for {metric} - skipping plot")
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        sam_models = sam_vs_est_data['sam_model'].unique()
        est_models = sam_vs_est_data['established_model'].unique()

        # Create comparison matrix
        comparison_matrix = np.zeros((len(sam_models), len(est_models)))
        for i, sam_model in enumerate(sam_models):
            for j, est_model in enumerate(est_models):
                comparison = sam_vs_est_data[
                    (sam_vs_est_data['sam_model'] == sam_model) &
                    (sam_vs_est_data['established_model'] == est_model)
                ]
                if not comparison.empty:
                    comparison_matrix[i, j] = comparison['mean_diff'].iloc[0]

        # Symmetric limits keep the diverging colormap's midpoint at zero difference
        imshow_kwargs = {'cmap': 'RdBu_r', 'aspect': 'auto'}
        max_abs = np.nanmax(np.abs(comparison_matrix))
        if np.isfinite(max_abs) and max_abs > 0:
            imshow_kwargs.update(vmin=-max_abs, vmax=max_abs)
        im = ax1.imshow(comparison_matrix, **imshow_kwargs)

        ax1.set_xticks(range(len(est_models)))
        ax1.set_yticks(range(len(sam_models)))
        ax1.set_xticklabels(est_models, rotation=45)
        ax1.set_yticklabels(sam_models)
        ax1.set_title(f'SAM vs Established Models ({metric.upper()})', fontweight='bold')
        ax1.set_xlabel('Established Models')
        ax1.set_ylabel('SAM Variants')

        # Add colorbar
        cbar = plt.colorbar(im, ax=ax1)
        cbar.set_label(f'Mean Difference ({metric.upper()})')

        # Add text annotations
        for i in range(len(sam_models)):
            for j in range(len(est_models)):
                ax1.text(j, i, f'{comparison_matrix[i, j]:.3f}',
                         ha="center", va="center", color="black", fontweight='bold')

        # Statistical significance
        sig_data = sam_vs_est_data[sam_vs_est_data['significant_t'].eq(True)]
        if not sig_data.empty:
            ax2.bar(range(len(sig_data)), sig_data['cohens_d'])
            ax2.set_xticks(range(len(sig_data)))
            ax2.set_xticklabels([f"{row['sam_model']} vs {row['established_model']}"
                                 for _, row in sig_data.iterrows()], rotation=45)
            ax2.set_title('Significant Differences (Cohen\'s d)', fontweight='bold')
            ax2.set_ylabel("Effect Size (Cohen's d)")
            ax2.axhline(y=0, color='black', linestyle='-', alpha=0.3)
            ax2.axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='Medium Effect')
            ax2.axhline(y=-0.5, color='red', linestyle='--', alpha=0.5)
            ax2.legend()
        else:
            ax2.text(0.5, 0.5, 'No significant differences found',
                     ha='center', va='center', transform=ax2.transAxes, fontsize=12)
            ax2.set_title('Statistical Significance', fontweight='bold')

        plt.tight_layout()
        plt.savefig(self.plots_dir / f'sam_vs_established_{metric}.png',
                    dpi=300, bbox_inches='tight')
        plt.show()

    def plot_metric_distributions(self, df, metrics=['pq', 'detection_f1', 'iou']):
        """Draw one violin plot per metric with a dashed mean line per model.

        Saves ``metric_distributions.png`` and shows the figure.
        """
        metrics = list(metrics)
        # One explicit category order shared by the violins and their mean lines
        model_order = list(df['model'].unique())
        n_models = len(model_order)

        fig, axes = plt.subplots(1, len(metrics), figsize=(6*len(metrics), 6))
        if len(metrics) == 1:
            axes = [axes]

        for ax, metric in zip(axes, metrics):
            # Create violin plot
            sns.violinplot(data=df, x='model', y=metric, order=model_order, ax=ax)
            ax.set_title(f'{metric.upper()} Distribution', fontweight='bold')
            ax.set_xlabel('Model')
            ax.set_ylabel(f'{metric.upper()}')
            ax.tick_params(axis='x', rotation=45)

            # Add mean line across the horizontal slot of the matching violin
            model_means = df.groupby('model')[metric].mean().reindex(model_order)
            for j, mean_val in enumerate(model_means):
                if pd.isna(mean_val):
                    continue
                ax.axhline(y=mean_val, xmin=j/n_models, xmax=(j+1)/n_models,
                           color='red', linestyle='--', alpha=0.7)

        plt.tight_layout()
        plt.savefig(self.plots_dir / 'metric_distributions.png',
                    dpi=300, bbox_inches='tight')
        plt.show()

    def create_summary_table(self, df, metrics=['pq', 'detection_f1', 'iou']):
        """Create a per-model summary statistics table and save it as CSV.

        For every model and metric the table holds mean, std, median and the
        25th/75th percentiles (NaNs dropped), rounded to 4 decimals. The table
        is written to ``<artifacts_dir>/results/summary_statistics.csv`` and
        returned.
        """
        summary_rows = []

        for model in df['model'].unique():
            model_data = df[df['model'] == model]

            row = {'Model': model}
            for metric in metrics:
                values = model_data[metric].dropna()
                row[f'{metric}_mean'] = values.mean()
                row[f'{metric}_std'] = values.std()
                row[f'{metric}_median'] = values.median()
                row[f'{metric}_q25'] = values.quantile(0.25)
                row[f'{metric}_q75'] = values.quantile(0.75)

            summary_rows.append(row)

        summary_df = pd.DataFrame(summary_rows)

        # Round to 4 decimal places
        numeric_cols = summary_df.select_dtypes(include=[np.number]).columns
        summary_df[numeric_cols] = summary_df[numeric_cols].round(4)

        # Save to CSV (create the folder first so a fresh artifacts dir works)
        results_dir = self.artifacts_dir / 'results'
        results_dir.mkdir(parents=True, exist_ok=True)
        summary_df.to_csv(results_dir / 'summary_statistics.csv', index=False)

        return summary_df

# Initialize visualizer (creates the plots folder under artifacts_dir)
visualizer = ResultsVisualizer(artifacts_dir)

print("📊 Creating visualizations...")

# Create main performance comparison plot
visualizer.plot_model_performance_comparison(combined_results_df)

# Create SAM vs established models plots for the two headline metrics
for metric in ['pq', 'detection_f1']:
    if metric in all_comparisons:
        sam_vs_est = all_comparisons[metric]['sam_vs_established']
        visualizer.plot_sam_vs_established(sam_vs_est, metric)

# Create metric distribution plots
visualizer.plot_metric_distributions(combined_results_df)

# Create summary table
summary_table = visualizer.create_summary_table(combined_results_df)
print("\n📋 Summary Statistics:")
print(summary_table.to_string(index=False))

# The original expression ended in a dangling '/' operator, which is a SyntaxError
print(f"\n✅ Visualizations complete! Saved to {artifacts_dir / 'plots'}")


## 7. Research Question Conclusions

Summarize findings and provide conclusions for Research Question 1.


In [None]:
# =============================================================================
# RESEARCH QUESTION CONCLUSIONS
# =============================================================================

def generate_research_conclusions(df, all_comparisons):
    """Print the Research Question 1 conclusions and return the key numbers.

    Args:
        df: Long-format results with a ``model`` column and the metric
            columns 'pq', 'detection_f1' and 'iou'.
        all_comparisons: Mapping ``metric -> {'sam_vs_established': DataFrame}``.
            Only 'pq' and 'detection_f1' are read; the frames need the columns
            'significant_t', 'sam_better', 'sam_model', 'established_model'
            and 'cohens_d'. Missing or empty frames are skipped.

    Returns:
        Dict with 'best_model' (``None`` if there are no models),
        'sam_win_rate' (float), 'significant_comparisons', 'sam_wins',
        'established_wins' (built-in ints, so they are JSON-serialisable) and
        'ranking_df'.

    Note:
        Significance is taken from the 'significant_t' flags as provided; no
        multiple-comparison correction is applied here. Also reads the
        notebook-level ``artifacts_dir`` for the final status message.
    """
    print("🔬 RESEARCH QUESTION 1 CONCLUSIONS")
    print("=" * 50)

    # Calculate overall performance rankings (rank 1 = highest mean)
    metrics = ['pq', 'detection_f1', 'iou']
    model_rankings = {
        metric: df.groupby('model')[metric].mean().rank(ascending=False)
        for metric in metrics
    }

    ranking_df = pd.DataFrame(model_rankings)
    ranking_df['average_rank'] = ranking_df.mean(axis=1)
    ranking_df = ranking_df.sort_values('average_rank')

    print("\n📊 OVERALL MODEL RANKINGS:")
    print(ranking_df.round(2))

    # Identify best performing models
    best_models = ranking_df.head(3)
    print(f"\n🏆 TOP 3 MODELS:")
    for i, (model, row) in enumerate(best_models.iterrows(), 1):
        print(f"{i}. {model} (Avg Rank: {row['average_rank']:.2f})")

    # SAM vs Established Models Analysis
    print(f"\n🔍 SAM VARIANTS vs ESTABLISHED MODELS ANALYSIS:")

    sam_models = [m for m in df['model'].unique() if m.startswith('sam') or m == 'pathosam']
    established_models = [m for m in df['model'].unique() if m in ['hovernet', 'cellvit', 'lkcell']]

    print(f"  SAM Variants: {sam_models}")
    print(f"  Established Models: {established_models}")

    # Analyze significant differences
    significant_comparisons = 0
    sam_wins = 0
    established_wins = 0

    for metric in ['pq', 'detection_f1']:
        sam_vs_est = all_comparisons.get(metric, {}).get('sam_vs_established')
        # An empty comparison table may carry no columns at all
        if sam_vs_est is None or sam_vs_est.empty or 'significant_t' not in sam_vs_est.columns:
            continue

        sig_comparisons = sam_vs_est[sam_vs_est['significant_t'].eq(True)]
        if sig_comparisons.empty:
            continue

        # int(): pandas sums are NumPy integers, which json cannot serialise
        significant_comparisons += len(sig_comparisons)
        sam_wins += int(sig_comparisons['sam_better'].eq(True).sum())
        established_wins += int(sig_comparisons['sam_better'].eq(False).sum())

        print(f"\n  {metric.upper()} - Significant Differences:")
        for _, row in sig_comparisons.iterrows():
            winner = "SAM" if row['sam_better'] else "Established"
            print(f"    {row['sam_model']} vs {row['established_model']}: {winner} wins (d={row['cohens_d']:.3f})")

    # Hypothesis testing results
    print(f"\n📈 HYPOTHESIS TESTING RESULTS:")
    print(f"  Total significant comparisons: {significant_comparisons}")
    print(f"  SAM variants win: {sam_wins}")
    print(f"  Established models win: {established_wins}")

    if significant_comparisons > 0:
        sam_win_rate = sam_wins / significant_comparisons
        print(f"  SAM win rate: {sam_win_rate:.2%}")

        # H1 as stated for this research question: at least ONE SAM variant
        # significantly outperforms a baseline in mPQ or detection F1. A
        # win-rate majority is not required.
        if sam_wins > 0:
            print("  ✅ H1 SUPPORTED: SAM variants significantly outperform established models")
        else:
            print("  ❌ H0 SUPPORTED: SAM variants do not significantly outperform established models")
    else:
        print("  ⚠️  No significant differences found - insufficient evidence to reject H0")

    # Performance summary by model type
    print(f"\n📊 PERFORMANCE BY MODEL TYPE:")

    sam_rows = df[df['model'].isin(sam_models)]
    established_rows = df[df['model'].isin(established_models)]

    sam_performance = sam_rows.groupby('model')[['pq', 'detection_f1']].mean()
    established_performance = established_rows.groupby('model')[['pq', 'detection_f1']].mean()

    print(f"\n  SAM Variants (mPQ, Detection F1):")
    for model, row in sam_performance.iterrows():
        print(f"    {model}: {row['pq']:.4f}, {row['detection_f1']:.4f}")

    print(f"\n  Established Models (mPQ, Detection F1):")
    for model, row in established_performance.iterrows():
        print(f"    {model}: {row['pq']:.4f}, {row['detection_f1']:.4f}")

    # Key findings
    print(f"\n🎯 KEY FINDINGS:")

    # Best overall model = lowest average rank (None when nothing was ranked)
    best_model = ranking_df.index[0] if len(ranking_df) else None
    best_rows = df[df['model'] == best_model]
    best_pq = best_rows['pq'].mean()
    best_f1 = best_rows['detection_f1'].mean()

    print(f"  1. Best performing model: {best_model}")
    print(f"     - mPQ: {best_pq:.4f}")
    print(f"     - Detection F1: {best_f1:.4f}")

    # SAM performance (pooled over all SAM-variant rows)
    sam_avg_pq = sam_rows['pq'].mean()
    sam_avg_f1 = sam_rows['detection_f1'].mean()

    print(f"  2. SAM variants average performance:")
    print(f"     - mPQ: {sam_avg_pq:.4f}")
    print(f"     - Detection F1: {sam_avg_f1:.4f}")

    # Established models performance (pooled over all established-model rows)
    est_avg_pq = established_rows['pq'].mean()
    est_avg_f1 = established_rows['detection_f1'].mean()

    print(f"  3. Established models average performance:")
    print(f"     - mPQ: {est_avg_pq:.4f}")
    print(f"     - Detection F1: {est_avg_f1:.4f}")

    # Performance difference
    pq_diff = sam_avg_pq - est_avg_pq
    f1_diff = sam_avg_f1 - est_avg_f1

    print(f"  4. Performance difference (SAM - Established):")
    print(f"     - mPQ: {pq_diff:+.4f}")
    print(f"     - Detection F1: {f1_diff:+.4f}")

    # Statistical significance
    if significant_comparisons > 0:
        print(f"  5. Statistical significance: {significant_comparisons} significant differences found")
        print(f"     - SAM wins: {sam_wins}")
        print(f"     - Established wins: {established_wins}")
    else:
        print(f"  5. Statistical significance: No significant differences found")

    print(f"\n✅ Research Question 1 analysis complete!")
    print(f"📁 All results saved to: {artifacts_dir}")

    return {
        'best_model': best_model,
        # max(..., 1) avoids dividing by zero when nothing was significant
        'sam_win_rate': sam_wins / max(significant_comparisons, 1),
        'significant_comparisons': significant_comparisons,
        'sam_wins': sam_wins,
        'established_wins': established_wins,
        'ranking_df': ranking_df
    }

# Generate conclusions
conclusions = generate_research_conclusions(combined_results_df, all_comparisons)

# Save conclusions. Counts derived from pandas sums can be NumPy integers,
# which json.dump rejects, so they are cast to built-in types first.
with open(artifacts_dir / 'results' / 'research_conclusions.json', 'w') as f:
    json.dump({
        'best_model': conclusions['best_model'],
        'sam_win_rate': float(conclusions['sam_win_rate']),
        'significant_comparisons': int(conclusions['significant_comparisons']),
        'sam_wins': int(conclusions['sam_wins']),
        'established_wins': int(conclusions['established_wins'])
    }, f, indent=2)

print(f"\n🎉 RQ1 Analysis Complete!")
print(f"📊 Results saved to: {artifacts_dir}")
print(f"📈 Plots saved to: {artifacts_dir / 'plots'}")
print(f"📋 Data saved to: {artifacts_dir / 'results'}")


## 8. Results Export and Artifact Management

Comprehensive saving of all research results, models, and artifacts for reproducibility and analysis.


In [None]:
# =============================================================================
# COMPREHENSIVE RESULTS EXPORT AND ARTIFACT MANAGEMENT
# =============================================================================

class ArtifactManager:
    """Comprehensive artifact management for RQ1 research results.

    On construction the manager creates a fixed directory layout under
    ``<base_artifacts_dir>/rq1`` (models, results, plots, logs, configs,
    statistics, comparisons, summaries). Each ``save_*`` method writes one
    family of artifacts into the matching directory.

    Attributes:
        base_dir: Root directory passed by the caller, as a ``Path``.
        rq1_dir: ``base_dir / 'rq1'``; every artifact lives below it.
        dirs: Mapping from a short directory name to its ``Path``.
    """

    def __init__(self, base_artifacts_dir):
        """Create the RQ1 artifact directory tree.

        Args:
            base_artifacts_dir: Path (or path string) under which the
                ``rq1`` directory and its sub-directories are created.
        """
        self.base_dir = Path(base_artifacts_dir)
        self.rq1_dir = self.base_dir / 'rq1'

        # Each key doubles as the sub-directory name below rq1_dir.
        self.dirs = {
            name: self.rq1_dir / name
            for name in (
                'models',
                'results',
                'plots',
                'logs',
                'configs',
                'statistics',
                'comparisons',
                'summaries',
            )
        }

        # Create all directories
        for dir_path in self.dirs.values():
            dir_path.mkdir(parents=True, exist_ok=True)

        print(f"📁 Artifact directories created in: {self.rq1_dir}")

    @staticmethod
    def _json_default(obj):
        """Convert values the ``json`` module cannot encode natively.

        Handles ``Path`` objects (written as strings) and any object that
        exposes ``tolist()`` (e.g. NumPy scalars/arrays), which is converted
        to the equivalent built-in Python value.

        Raises:
            TypeError: If ``obj`` is of any other unsupported type, as the
                ``default`` hook of ``json.dump`` requires.
        """
        if isinstance(obj, Path):
            return str(obj)
        to_list = getattr(obj, 'tolist', None)
        if callable(to_list):
            return to_list()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    def save_model_results(self, trained_models, training_histories):
        """Save model weights and training histories.

        Args:
            trained_models: Mapping of model name to model object. Models
                that are ``SAMWrapper`` or ``PathoSAMWrapper`` instances are
                skipped; for every other model ``state_dict()`` is saved.
            training_histories: Mapping of model name to a JSON-compatible
                training history. Falsy histories are skipped.
        """

        print("💾 Saving model results...")

        # Save model weights
        for model_name, model in trained_models.items():
            # NOTE(review): SAM wrappers are skipped, presumably because they
            # hold pretrained weights that were not trained here -- confirm.
            if not isinstance(model, (SAMWrapper, PathoSAMWrapper)):
                # Save trained model weights
                model_path = self.dirs['models'] / f'{model_name}_final.pth'
                torch.save(model.state_dict(), model_path)
                print(f"  ✅ {model_name} weights saved to {model_path}")

        # Save training histories
        for model_name, history in training_histories.items():
            if history:  # Only save if history exists
                history_path = self.dirs['results'] / f'{model_name}_training_history.json'
                with open(history_path, 'w', encoding='utf-8') as f:
                    # default= lets NumPy scalars in the history serialise.
                    json.dump(history, f, indent=2, default=self._json_default)
                print(f"  ✅ {model_name} training history saved to {history_path}")

    def save_evaluation_results(self, all_results, combined_results_df):
        """Save per-model results, the combined table and a summary table.

        Args:
            all_results: Mapping of model name to data accepted by
                ``pd.DataFrame(...)``; written to one CSV per model.
            combined_results_df: DataFrame with at least the columns
                ``model``, ``pq``, ``detection_f1``, ``iou``, ``precision``
                and ``recall``.
        """

        print("💾 Saving evaluation results...")

        # Save individual model results
        for model_name, results in all_results.items():
            results_path = self.dirs['results'] / f'{model_name}_evaluation_results.csv'
            df = pd.DataFrame(results)
            df.to_csv(results_path, index=False)
            print(f"  ✅ {model_name} results saved to {results_path}")

        # Save combined results
        combined_path = self.dirs['results'] / 'all_models_combined_results.csv'
        combined_results_df.to_csv(combined_path, index=False)
        print(f"  ✅ Combined results saved to {combined_path}")

        # Save per-model summary statistics
        summary_stats = []
        for model_name in combined_results_df['model'].unique():
            model_data = combined_results_df[combined_results_df['model'] == model_name]
            pq = model_data['pq']
            f1 = model_data['detection_f1']
            iou = model_data['iou']

            summary_stats.append({
                'model': model_name,
                'n_images': len(model_data),
                'pq_mean': pq.mean(),
                'pq_std': pq.std(),
                'pq_median': pq.median(),
                'pq_q25': pq.quantile(0.25),
                'pq_q75': pq.quantile(0.75),
                'f1_mean': f1.mean(),
                'f1_std': f1.std(),
                'f1_median': f1.median(),
                'f1_q25': f1.quantile(0.25),
                'f1_q75': f1.quantile(0.75),
                'iou_mean': iou.mean(),
                'iou_std': iou.std(),
                'precision_mean': model_data['precision'].mean(),
                'recall_mean': model_data['recall'].mean(),
            })

        summary_df = pd.DataFrame(summary_stats)
        summary_path = self.dirs['summaries'] / 'model_performance_summary.csv'
        summary_df.to_csv(summary_path, index=False)
        print(f"  ✅ Model performance summary saved to {summary_path}")

    def save_statistical_analysis(self, all_comparisons):
        """Save all statistical analysis results.

        Args:
            all_comparisons: Mapping of metric name to a dict with the
                DataFrames ``'pairwise'`` (columns ``t_significant``,
                ``w_significant``) and ``'sam_vs_established'`` (columns
                ``sam_better``, ``significant_t``).
        """

        print("💾 Saving statistical analysis...")

        for metric, comparisons in all_comparisons.items():
            # Save pairwise comparisons
            pairwise_path = self.dirs['comparisons'] / f'{metric}_pairwise_comparisons.csv'
            comparisons['pairwise'].to_csv(pairwise_path, index=False)

            # Save SAM vs established comparisons
            sam_vs_est_path = self.dirs['comparisons'] / f'{metric}_sam_vs_established.csv'
            comparisons['sam_vs_established'].to_csv(sam_vs_est_path, index=False)

            print(f"  ✅ {metric} statistical comparisons saved")

        # Create statistical summary
        stat_summary = []
        for metric, comparisons in all_comparisons.items():
            pairwise = comparisons['pairwise']
            sam_vs_est = comparisons['sam_vs_established']

            stat_summary.append({
                'metric': metric,
                'total_comparisons': len(pairwise),
                'significant_t_test': pairwise['t_significant'].sum(),
                'significant_wilcoxon': pairwise['w_significant'].sum(),
                'sam_vs_est_comparisons': len(sam_vs_est),
                # Explicit comparisons so missing values count for neither side.
                'sam_wins': (sam_vs_est['sam_better'] == True).sum(),
                'established_wins': (sam_vs_est['sam_better'] == False).sum(),
                'significant_sam_vs_est': sam_vs_est['significant_t'].sum(),
            })

        stat_summary_df = pd.DataFrame(stat_summary)
        stat_summary_path = self.dirs['statistics'] / 'statistical_analysis_summary.csv'
        stat_summary_df.to_csv(stat_summary_path, index=False)
        print(f"  ✅ Statistical analysis summary saved to {stat_summary_path}")

    def save_research_conclusions(self, conclusions, ranking_df):
        """Save research conclusions, final rankings and the Markdown report.

        Args:
            conclusions: Dict with the keys ``best_model``, ``sam_win_rate``,
                ``significant_comparisons``, ``sam_wins`` and
                ``established_wins``. DataFrame or Series values (such as a
                ``ranking_df`` entry) are left out of the JSON file.
            ranking_df: DataFrame of model rankings, written as CSV with its
                index and embedded in the report.
        """

        print("💾 Saving research conclusions...")

        # pandas objects cannot be JSON-encoded; the ranking table is saved
        # as CSV below, so drop such entries from the JSON payload.
        json_conclusions = {
            key: value
            for key, value in conclusions.items()
            if not isinstance(value, (pd.DataFrame, pd.Series))
        }

        # Save conclusions JSON
        conclusions_path = self.dirs['summaries'] / 'research_conclusions.json'
        with open(conclusions_path, 'w', encoding='utf-8') as f:
            json.dump(json_conclusions, f, indent=2, default=self._json_default)

        # Save model rankings
        ranking_path = self.dirs['summaries'] / 'model_rankings.csv'
        ranking_df.to_csv(ranking_path, index=True)

        # Create final research report
        report_path = self.dirs['summaries'] / 'research_question_1_report.md'
        self.create_research_report(conclusions, ranking_df, report_path)

        print(f"  ✅ Research conclusions saved to {self.dirs['summaries']}")

    def create_research_report(self, conclusions, ranking_df, report_path):
        """Create a comprehensive research report in Markdown format.

        Args:
            conclusions: Dict providing ``best_model``, ``sam_win_rate``,
                ``significant_comparisons``, ``sam_wins`` and
                ``established_wins``.
            ranking_df: DataFrame rendered into the report via ``to_string()``;
                its length is reported as the number of models.
            report_path: Destination path of the Markdown file.
        """

        report_content = f"""# Research Question 1: SAM Variants vs Established Models

## Executive Summary

This report presents the results of a comprehensive comparison between Segment Anything Model (SAM) variants and established nuclei segmentation models on the PanNuke dataset.

## Research Question
Do different variants of the Segment Anything Model (SAM), including the domain-adapted PathoSAM, achieve competitive or superior nuclei instance segmentation performance on the PanNuke dataset compared to established models such as HoVer-Net, CellViT, and LKCell?

## Hypotheses
- **H₀ (Null)**: SAM variants do not significantly outperform established models in mPQ or detection F1
- **H₁ (Alternative)**: At least one SAM variant significantly outperforms baselines in mPQ or detection F1

## Methodology
- **Dataset**: PanNuke dataset with proper train/val/test splits
- **Models Evaluated**: {len(ranking_df)} models total
  - SAM Variants: SAM-Base, SAM-Large, SAM-Huge, PathoSAM
  - Established Models: HoVer-Net, CellViT, LKCell
  - Baseline: U-Net
- **Metrics**: Mean Panoptic Quality (mPQ), Detection F1 Score, IoU, Precision, Recall
- **Statistical Analysis**: Paired t-tests, Wilcoxon signed-rank tests, multiple comparison corrections

## Key Results

### Model Rankings
{ranking_df.to_string()}

### Best Performing Model
- **Model**: {conclusions['best_model']}
- **SAM Win Rate**: {conclusions['sam_win_rate']:.2%}
- **Significant Comparisons**: {conclusions['significant_comparisons']}
- **SAM Wins**: {conclusions['sam_wins']}
- **Established Model Wins**: {conclusions['established_wins']}

### Statistical Significance
- Total significant comparisons found: {conclusions['significant_comparisons']}
- SAM variants won: {conclusions['sam_wins']} comparisons
- Established models won: {conclusions['established_wins']} comparisons

## Conclusion
"""

        # NOTE(review): H₁ is worded as "at least one SAM variant", but the
        # verdict below requires SAM to win more than half of the significant
        # comparisons -- confirm this decision rule is intended.
        if conclusions['significant_comparisons'] > 0:
            if conclusions['sam_win_rate'] > 0.5:
                report_content += "**H₁ SUPPORTED**: SAM variants significantly outperform established models in nuclei segmentation tasks.\n"
            else:
                report_content += "**H₀ SUPPORTED**: SAM variants do not significantly outperform established models in nuclei segmentation tasks.\n"
        else:
            report_content += "**INCONCLUSIVE**: No significant differences found between SAM variants and established models.\n"

        report_content += f"""
## Files Generated
- Model weights: `{self.dirs['models']}`
- Evaluation results: `{self.dirs['results']}`
- Statistical analysis: `{self.dirs['statistics']}`
- Visualizations: `{self.dirs['plots']}`
- Comparisons: `{self.dirs['comparisons']}`
- Summaries: `{self.dirs['summaries']}`

## Reproducibility
All code, data, and results are saved in the artifacts directory for full reproducibility.
"""

        # The report contains non-ASCII characters (H₀ / H₁), so the encoding
        # is pinned instead of relying on the platform default.
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(report_content)

        print(f"  ✅ Research report saved to {report_path}")

    def save_configuration(self, config):
        """Save experiment configuration as JSON.

        Args:
            config: Object exposing the attributes listed in ``config_dict``
                below. ``Path`` values and NumPy scalars are converted to
                plain JSON values on write.
        """

        print("💾 Saving experiment configuration...")

        config_dict = {
            'data_root': config.data_root,
            'artifacts_dir': config.artifacts_dir,
            'batch_size': config.batch_size,
            'num_workers': config.num_workers,
            'image_size': config.image_size,
            'learning_rate': config.learning_rate,
            'epochs': config.epochs,
            'patience': config.patience,
            'confidence_threshold': config.confidence_threshold,
            'iou_threshold': config.iou_threshold,
            'alpha': config.alpha,
            'n_bootstrap': config.n_bootstrap,
            'sam_models': config.sam_models
        }

        config_path = self.dirs['configs'] / 'experiment_config.json'
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, indent=2, default=self._json_default)

        print(f"  ✅ Configuration saved to {config_path}")

    def create_artifacts_summary(self):
        """Create, save and print a summary of all saved artifacts.

        Writes ``artifacts_summary.json`` into ``rq1_dir``.

        Returns:
            dict: The summary, including a ``files_generated`` mapping of
            directory name to the number of regular files directly inside it.
        """

        summary = {
            'experiment': 'Research Question 1: SAM Variants vs Established Models',
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            'artifacts_directory': str(self.rq1_dir),
            'directories': {name: str(path) for name, path in self.dirs.items()},
            'files_generated': {}
        }

        # Count files in each directory (sub-directories are not files).
        for dir_name, dir_path in self.dirs.items():
            if dir_path.exists():
                summary['files_generated'][dir_name] = sum(
                    1 for entry in dir_path.glob('*') if entry.is_file()
                )

        # Save summary
        summary_path = self.rq1_dir / 'artifacts_summary.json'
        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

        print(f"\n📋 ARTIFACTS SUMMARY:")
        print(f"  📁 Base directory: {self.rq1_dir}")
        print(f"  📊 Files generated:")
        for dir_name, count in summary['files_generated'].items():
            print(f"    {dir_name}: {count} files")

        return summary

# Initialize artifact manager; it creates its tree under `<parent>/rq1`.
artifact_manager = ArtifactManager(artifacts_dir.parent)

print("🚀 Starting comprehensive artifact management...")

# `conclusions` carries the ranking table as a DataFrame under 'ranking_df'.
# A DataFrame is not JSON-serialisable, so split it off: the scalar findings
# go into the conclusions JSON and the table is passed separately (saved as CSV).
final_ranking_df = conclusions['ranking_df']
scalar_conclusions = {
    key: value for key, value in conclusions.items() if key != 'ranking_df'
}

# Save all results
artifact_manager.save_model_results(trained_models, training_histories)
artifact_manager.save_evaluation_results(all_results, combined_results_df)
artifact_manager.save_statistical_analysis(all_comparisons)
artifact_manager.save_research_conclusions(scalar_conclusions, final_ranking_df)
artifact_manager.save_configuration(config)

# Create final summary
final_summary = artifact_manager.create_artifacts_summary()

print(f"\n✅ ALL ARTIFACTS SAVED SUCCESSFULLY!")
print(f"📁 Main artifacts directory: {artifact_manager.rq1_dir}")
print(f"📋 Summary saved to: {artifact_manager.rq1_dir / 'artifacts_summary.json'}")
