# Research Question 3: Complete GPU-Accelerated Pipeline
## Stain Normalization Impact on U-Net Nuclei Segmentation

**Research Question**: Does stain normalization improve U-Net-based nuclei instance segmentation on the PanNuke dataset compared to unnormalized data?

### Hypotheses:
- **H₀ (Null)**: No significant improvement due to normalization  
- **H₁ (Alternative)**: Significant improvement due to normalization

### Methodology:
1. **Dataset**: Top 5 tissues from PanNuke dataset
2. **Normalization**: Vahadane method with GPU acceleration
3. **Models**: U-Net trained on normalized vs unnormalized data
4. **Evaluation**: Paired statistical analysis per image
5. **Statistics**: Wilcoxon signed-rank test, paired t-test

### Expected Outcomes:
- Comprehensive EDA before/after normalization
- Model performance comparison
- Statistical significance testing
- Complete artifacts for reproducibility

---
**⚡ GPU-Accelerated Pipeline | Production Ready | Cloud Deployment**


In [None]:
# =============================================================================
# IMPORTS AND SETUP
# =============================================================================

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from PIL import Image
from pathlib import Path
import os
import sys
import json
import time
import warnings
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
import gc
from tqdm import tqdm
from sklearn.metrics import jaccard_score, f1_score, precision_score, recall_score
from scipy.stats import wilcoxon, ttest_rel
import logging

# Setup paths
# The project root defaults to the original author's checkout, but can be
# overridden with the RQ3_PROJECT_ROOT environment variable so the notebook
# runs unchanged on another machine (for example a cloud GPU instance).
project_root = Path(os.environ.get(
    'RQ3_PROJECT_ROOT',
    '/Users/shubhangmalviya/Documents/Projects/Walsh College/HistoPathologyResearch',
))

# Make the project's `src` package importable. Guarded so that re-running the
# cell does not keep appending duplicate entries to sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Import custom modules
from src.preprocessing.vahadane_gpu import GPUVahadaneNormalizer
from src.models.unet_rq3 import UNetRQ3, create_unet_rq3  # RQ3-specific U-Net (separate from RQ2)
from src.utils.metrics import calculate_segmentation_metrics

# Create artifacts directories
# This must happen BEFORE logging is configured: logging.FileHandler opens its
# target file as soon as it is constructed and fails if the directory is missing.
artifacts_dir = project_root / 'artifacts' / 'rq3'
for subdir in ['checkpoints', 'results', 'datasets', 'plots', 'logs']:
    (artifacts_dir / subdir).mkdir(parents=True, exist_ok=True)

# Configure logging
# The log file is anchored to artifacts_dir (not the current working directory)
# so it lands next to the other artifacts regardless of where the kernel starts.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(artifacts_dir / 'logs' / 'rq3_pipeline.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Set style and warnings
# Use the 'seaborn-v0_8' style when this matplotlib build lists it, otherwise
# fall back to the 'seaborn' style name.
if 'seaborn-v0_8' in plt.style.available:
    plt.style.use('seaborn-v0_8')
else:
    plt.style.use('seaborn')
sns.set_palette('husl')
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 RQ3 Pipeline initialized on {device}")
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name()}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

logger.info("RQ3 Complete Pipeline - GPU Accelerated - Initialized Successfully")


## 1. Dataset Analysis and Preparation

### 1.1 Identify Top 5 Tissues by Sample Count


In [None]:
# =============================================================================
# DATASET ANALYSIS - TOP 5 TISSUES
# =============================================================================

dataset_path = project_root / 'dataset_tissues'
logger.info("Analyzing dataset to identify top 5 tissues...")

# Fail early with an actionable message instead of an error from iterdir().
if not dataset_path.is_dir():
    raise FileNotFoundError(f"Dataset directory not found: {dataset_path}")

# Count images per tissue
tissue_counts = {}
tissue_paths = {}

# Directory listings come back in arbitrary, filesystem-dependent order.
# Sorting both the tissue directories and the image files keeps the later
# "[:N]" sampling and the top-5 tie-breaking reproducible across machines.
for tissue_dir in sorted(dataset_path.iterdir()):
    if not tissue_dir.is_dir():
        continue

    tissue_name = tissue_dir.name
    paths = {'train': [], 'test': [], 'val': []}

    for split in ['train', 'test', 'val']:
        images_dir = tissue_dir / split / 'images'
        masks_dir = tissue_dir / split / 'sem_masks'  # Semantic masks

        if not (images_dir.exists() and masks_dir.exists()):
            continue

        # Only count images that have corresponding masks. The mask name is
        # derived from the image name by swapping the 'img_' prefix for 'sem_'.
        valid_pairs = []
        for img_file in sorted(images_dir.glob('*.png')):
            mask_file = masks_dir / img_file.name.replace('img_', 'sem_')
            if mask_file.exists():
                valid_pairs.append((img_file, mask_file))

        paths[split] = valid_pairs

    tissue_counts[tissue_name] = sum(len(pairs) for pairs in paths.values())
    tissue_paths[tissue_name] = paths

# Sort tissues by count and select top 5. The sort is stable, so tissues with
# equal counts keep the alphabetical order established above.
ranked_tissues = sorted(tissue_counts.items(), key=lambda x: x[1], reverse=True)
top_5_tissues = ranked_tissues[:5]
selected_tissues = [tissue for tissue, count in top_5_tissues]

print("🔍 Dataset Analysis Results:")
print("=" * 50)
for tissue, count in ranked_tissues[:10]:
    marker = "✅" if tissue in selected_tissues else "  "
    print(f"{marker} {tissue:15}: {count:,} images")

print(f"\n🎯 Selected Top 5 Tissues for RQ3:")
for i, (tissue, count) in enumerate(top_5_tissues, 1):
    print(f"{i}. {tissue}: {count:,} images")

total_selected = sum(count for _, count in top_5_tissues)
print(f"\n📊 Total images in top 5 tissues: {total_selected:,}")

# Save tissue selection metadata
tissue_metadata = {
    'selected_tissues': selected_tissues,
    'tissue_counts': dict(top_5_tissues),
    'total_images': total_selected,
    'selection_date': time.strftime('%Y-%m-%d %H:%M:%S'),
    'selection_criteria': 'Top 5 tissues by image count'
}

with open(artifacts_dir / 'results' / 'tissue_selection.json', 'w') as f:
    json.dump(tissue_metadata, f, indent=2)

logger.info(f"Selected {len(selected_tissues)} tissues with {total_selected:,} total images")


### 1.2 Create Small Sample for Testing

Before processing the full dataset, let's create a small sample to test the pipeline and ensure everything works correctly.


In [None]:
# =============================================================================
# CREATE TEST SAMPLE
# =============================================================================

# Switch between a quick smoke test and a full run.
TESTING_MODE = True  # Set to False for full dataset processing
SAMPLE_SIZE_PER_TISSUE = 5 if TESTING_MODE else None  # None means use all images

print(f"🧪 Running in {'TESTING' if TESTING_MODE else 'PRODUCTION'} mode")
if TESTING_MODE:
    print(f"   Sample size: {SAMPLE_SIZE_PER_TISSUE} images per tissue per split")

# Build the working set: for every selected tissue and split, either the first
# SAMPLE_SIZE_PER_TISSUE (image, mask) pairs or the complete list.
sample_data = {}
total_sample_images = 0
take_subset = bool(TESTING_MODE and SAMPLE_SIZE_PER_TISSUE)

for tissue in selected_tissues:
    per_split = {}

    for split in ['train', 'test', 'val']:
        available_pairs = tissue_paths[tissue][split]
        selected_pairs = (
            available_pairs[:SAMPLE_SIZE_PER_TISSUE] if take_subset else available_pairs
        )

        per_split[split] = selected_pairs
        total_sample_images += len(selected_pairs)

        print(f"📁 {tissue:15} {split:5}: {len(selected_pairs):3} images")

    sample_data[tissue] = per_split

print(f"\n📊 Total sample images: {total_sample_images:,}")
print(f"🎯 Expected processing time: {total_sample_images * 0.001 / 60:.1f} minutes (GPU)")

# Persist the sampling configuration alongside the other run artifacts.
sample_config = dict(
    testing_mode=TESTING_MODE,
    sample_size_per_tissue=SAMPLE_SIZE_PER_TISSUE,
    total_sample_images=total_sample_images,
    tissues=selected_tissues,
    splits=['train', 'test', 'val'],
)

with open(artifacts_dir / 'results' / 'sample_config.json', 'w') as f:
    json.dump(sample_config, f, indent=2)

logger.info(f"Sample configuration created: {total_sample_images} images across {len(selected_tissues)} tissues")


## 2. GPU-Accelerated Vahadane Stain Normalization

### 2.1 Initialize GPU Normalizer and Select Target Image


In [None]:
# =============================================================================
# GPU VAHADANE STAIN NORMALIZATION SETUP
# =============================================================================

logger.info("Initializing GPU-accelerated Vahadane normalizer...")

# Constructor settings for the project's GPUVahadaneNormalizer, collected in
# one place so they are easy to tweak.
# NOTE(review): the meaning of threshold/lambda1/max_iter is defined in
# src.preprocessing.vahadane_gpu, which is not visible here — confirm there.
_normalizer_settings = {
    'batch_size': 16,  # Adjust based on GPU memory
    'device': device,
    'memory_efficient': True,
    'threshold': 0.8,
    'lambda1': 0.1,
    'max_iter': 1000,
}
gpu_normalizer = GPUVahadaneNormalizer(**_normalizer_settings)

print(f"✅ GPU Vahadane normalizer initialized on {device}")

# Function to evaluate image quality for target selection
def evaluate_image_quality_gpu(image_tensor):
    """Score an image tensor for use as a stain-normalization target.

    The last axis is treated as the colour channels. A 3-D tensor gets a
    leading batch axis added before scoring. Values are divided by 255, so
    the input is assumed to be on a 0-255 scale.

    Returns:
        A dict of Python floats with the keys ``quality_score``,
        ``tissue_coverage``, ``contrast`` and ``color_balance``.
    """
    batch = image_tensor.unsqueeze(0) if image_tensor.dim() == 3 else image_tensor
    scaled = batch.float() / 255.0

    # Channel-averaged intensity; pixels darker than 0.8 count as tissue.
    intensity = scaled.mean(dim=-1)
    tissue_coverage = (intensity < 0.8).float().mean()

    # Spread of the intensity values over the whole input.
    contrast = intensity.std()

    # Per-channel standard deviation over axes 1 and 2, then averaged.
    color_balance = scaled.std(dim=(1, 2)).mean()

    # Weighted sum: 40 for coverage, 30 each for contrast and colour spread.
    quality_score = tissue_coverage * 40 + contrast * 30 + color_balance * 30

    named_scores = (
        ('quality_score', quality_score),
        ('tissue_coverage', tissue_coverage),
        ('contrast', contrast),
        ('color_balance', color_balance),
    )
    return {name: value.item() for name, value in named_scores}

# Select best target image from sample data
# NOTE(review): candidates are drawn from the val and test splits as well as
# train. Confirm this is acceptable for the study design, since the chosen
# image defines the stain target that every image is normalized towards.
print("🎯 Evaluating images for optimal target selection...")
target_candidates = []

for tissue in selected_tissues:
    for split in ['train', 'test', 'val']:
        for img_path, mask_path in sample_data[tissue][split][:2]:  # Check first 2 from each split
            try:
                img = cv2.imread(str(img_path))
                if img is None:
                    # cv2.imread signals an unreadable file by returning None
                    # rather than raising; report it instead of skipping silently.
                    logger.warning(f"Failed to evaluate {img_path}: image could not be read")
                    continue

                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img_tensor = torch.from_numpy(img).to(device)

                quality_metrics = evaluate_image_quality_gpu(img_tensor)

                target_candidates.append({
                    'tissue': tissue,
                    'split': split,
                    'path': img_path,
                    'image': img,
                    **quality_metrics
                })
            except Exception as e:
                # Best effort: one bad candidate must not abort target selection.
                logger.warning(f"Failed to evaluate {img_path}: {e}")

# Select best target
if not target_candidates:
    raise ValueError("No valid target candidates found!")

# Highest quality score first; the sort is stable, so ties keep scan order.
target_candidates.sort(key=lambda x: x['quality_score'], reverse=True)
best_target = target_candidates[0]

print(f"\n🏆 Selected target image:")
print(f"   Tissue: {best_target['tissue']}")
print(f"   Split: {best_target['split']}")
print(f"   Quality score: {best_target['quality_score']:.2f}")
print(f"   Tissue coverage: {best_target['tissue_coverage']:.2f}")
print(f"   Contrast: {best_target['contrast']:.2f}")

target_image = best_target['image']
target_info = best_target

# Save target image
target_save_path = artifacts_dir / 'results' / 'target_image.png'
Image.fromarray(target_image).save(target_save_path)

# Save target metadata (everything except the pixel data; Path -> str for JSON)
target_metadata = {k: v for k, v in target_info.items() if k != 'image'}
target_metadata['path'] = str(target_metadata['path'])

with open(artifacts_dir / 'results' / 'target_metadata.json', 'w') as f:
    json.dump(target_metadata, f, indent=2)

logger.info(f"Target image selected: {best_target['tissue']} with quality score {best_target['quality_score']:.2f}")

# Test normalizer with target
print(f"\n🔧 Fitting normalizer to target image...")
start_time = time.time()
gpu_normalizer.fit(target_image)
fit_time = time.time() - start_time

print(f"✅ Normalizer fitted successfully in {fit_time:.3f}s")
logger.info(f"GPU normalizer fitted in {fit_time:.3f}s")


### 2.2 Apply Normalization to Sample Dataset

Apply GPU-accelerated Vahadane stain normalization to all sample images, processing them in batches.


In [None]:
# =============================================================================
# BATCH NORMALIZATION PROCESSING
# =============================================================================

logger.info("Starting batch normalization of sample dataset...")

# Create directories for normalized datasets
normalized_base = artifacts_dir / 'datasets' / 'normalized'
original_base = artifacts_dir / 'datasets' / 'original'

for base_dir in [normalized_base, original_base]:
    for tissue in selected_tissues:
        for split in ['train', 'test', 'val']:
            (base_dir / tissue / split / 'images').mkdir(parents=True, exist_ok=True)
            (base_dir / tissue / split / 'masks').mkdir(parents=True, exist_ok=True)

# Process images in batches
normalization_results = {}
processing_stats = {
    'total_processed': 0,
    'total_failed': 0,
    'processing_time': 0,
    'tissues_processed': {}
}

start_time = time.time()

for tissue in selected_tissues:
    print(f"\n🔄 Processing {tissue}...")
    tissue_results = {'train': [], 'test': [], 'val': []}
    tissue_stats = {'processed': 0, 'failed': 0}

    for split in ['train', 'test', 'val']:
        split_pairs = sample_data[tissue][split]
        if not split_pairs:
            continue

        print(f"   {split}: {len(split_pairs)} images")

        # Process in batches for GPU efficiency
        batch_size = min(8, len(split_pairs))  # Adjust based on memory

        for i in range(0, len(split_pairs), batch_size):
            batch_pairs = split_pairs[i:i+batch_size]
            batch_images = []
            batch_masks = []
            batch_metadata = []

            # Load batch
            for img_path, mask_path in batch_pairs:
                try:
                    img = cv2.imread(str(img_path))
                    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                    # cv2.imread returns None instead of raising when a file is
                    # unreadable. An undetected None mask would otherwise fail
                    # the whole batch later, at save time.
                    if img is None or mask is None:
                        raise OSError("image or mask could not be read")
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                except (OSError, cv2.error) as e:
                    logger.warning(f"Failed to load {img_path}: {e}")
                    tissue_stats['failed'] += 1
                    processing_stats['total_failed'] += 1
                    continue

                batch_images.append(img)
                batch_masks.append(mask)
                batch_metadata.append({
                    'tissue': tissue,
                    'split': split,
                    'original_img_path': str(img_path),
                    'original_mask_path': str(mask_path),
                    'filename': img_path.name
                })

            if not batch_images:
                continue

            # Images fully written so far, so that a mid-batch failure is not
            # counted as both "processed" and "failed".
            saved_in_batch = 0

            try:
                # Apply normalization to batch
                batch_start = time.time()
                normalized_batch = gpu_normalizer.transform_batch(batch_images)
                batch_time = time.time() - batch_start

                # Save results
                for j, (orig_img, norm_img, mask, metadata) in enumerate(
                    zip(batch_images, normalized_batch, batch_masks, batch_metadata)
                ):
                    # The same base name is used in both output trees, so an
                    # original image and its normalized twin stay matched.
                    base_name = f"{tissue}_{split}_{i+j:04d}"

                    # Save original
                    orig_img_path = original_base / tissue / split / 'images' / f"{base_name}.png"
                    orig_mask_path = original_base / tissue / split / 'masks' / f"{base_name}.png"

                    Image.fromarray(orig_img).save(orig_img_path)
                    Image.fromarray(mask).save(orig_mask_path)

                    # Save normalized
                    norm_img_path = normalized_base / tissue / split / 'images' / f"{base_name}.png"
                    norm_mask_path = normalized_base / tissue / split / 'masks' / f"{base_name}.png"

                    Image.fromarray(norm_img).save(norm_img_path)
                    Image.fromarray(mask).save(norm_mask_path)  # Same mask for both

                    # Store metadata
                    result_metadata = {
                        **metadata,
                        'normalized_img_path': str(norm_img_path),
                        'normalized_mask_path': str(norm_mask_path),
                        'original_saved_img_path': str(orig_img_path),
                        'original_saved_mask_path': str(orig_mask_path),
                        'processing_time': batch_time / len(batch_images),
                        'base_name': base_name
                    }

                    tissue_results[split].append(result_metadata)
                    tissue_stats['processed'] += 1
                    processing_stats['total_processed'] += 1
                    saved_in_batch += 1

                print(f"      Batch {i//batch_size + 1}: {len(batch_images)} images in {batch_time:.3f}s")

            except Exception as e:
                # Pipeline boundary: the normalizer's failure modes are not
                # known here, so log and continue with the next batch.
                logger.error(f"Batch normalization failed: {e}")
                unsaved = len(batch_images) - saved_in_batch
                tissue_stats['failed'] += unsaved
                processing_stats['total_failed'] += unsaved

    normalization_results[tissue] = tissue_results
    processing_stats['tissues_processed'][tissue] = tissue_stats

    print(f"   ✅ {tissue}: {tissue_stats['processed']} processed, {tissue_stats['failed']} failed")

total_time = time.time() - start_time
processing_stats['processing_time'] = total_time

# Guard the throughput figure against a zero elapsed time (e.g. empty sample).
images_per_sec = processing_stats['total_processed'] / total_time if total_time > 0 else 0.0

# Summary
print(f"\n🎉 Normalization Complete!")
print(f"   📊 Processed: {processing_stats['total_processed']} images")
print(f"   ❌ Failed: {processing_stats['total_failed']} images")
print(f"   ⏱️  Total time: {total_time:.2f}s")
print(f"   ⚡ Speed: {images_per_sec:.1f} images/sec")

# Save processing results
with open(artifacts_dir / 'results' / 'normalization_results.json', 'w') as f:
    json.dump(normalization_results, f, indent=2, default=str)

with open(artifacts_dir / 'results' / 'processing_stats.json', 'w') as f:
    json.dump(processing_stats, f, indent=2)

logger.info(f"Normalization completed: {processing_stats['total_processed']} images processed in {total_time:.2f}s")


## 3. Model Training and Evaluation

### 3.1 Create Dataset Classes and DataLoaders


In [None]:
# =============================================================================
# DATASET CLASSES AND DATALOADERS
# =============================================================================

class SegmentationDataset(Dataset):
    """Paired image/mask dataset read from two directories of PNG files.

    Images and masks are paired by position after sorting each directory's
    ``*.png`` files, so both directories must sort into the same order
    (for example by using identical file names).

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the label masks.
        transform: Optional callable applied to the RGB image array only.
            Without it the image is returned as a float tensor scaled to
            [0, 1] with the channel axis moved to the front.

    Raises:
        ValueError: If the two directories hold a different number of PNGs.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform

        # Sorted so that index i refers to the same sample in both lists.
        self.image_files = sorted(self.image_dir.glob('*.png'))
        self.mask_files = sorted(self.mask_dir.glob('*.png'))

        # An explicit raise (not `assert`) so the check survives `python -O`.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Mismatch: {len(self.image_files)} images, {len(self.mask_files)} masks"
            )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``; the mask is an int64 tensor.

        Raises:
            OSError: If the image or mask file cannot be read.
        """
        img_path = self.image_files[idx]
        mask_path = self.mask_files[idx]

        # cv2.imread returns None instead of raising on unreadable files;
        # surface that here rather than as a cryptic downstream error.
        image = cv2.imread(str(img_path))
        if image is None:
            raise OSError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Could not read mask: {mask_path}")

        if self.transform:
            image = self.transform(image)
        else:
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

        # The mask is never passed through `transform`; it is only cast to long.
        mask = torch.from_numpy(mask).long()

        return image, mask

# Define transforms
# Training and validation currently share the same steps (no augmentation):
# array -> PIL image -> float tensor -> per-channel standardisation.
# NOTE(review): the mean/std values look like the commonly used ImageNet
# statistics — confirm they are intended for this dataset.
def _make_image_transform():
    """Build one independent instance of the image preprocessing pipeline."""
    steps = [
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


train_transform = _make_image_transform()
val_transform = _make_image_transform()

# Create datasets and dataloaders for both original and normalized data
def create_dataloaders(data_type='original', batch_size=4, num_workers=2):
    """Create train/val dataloaders for one data type ('original' or 'normalized').

    For each split, the (image, mask) pairs of all selected tissues are
    collected in a deterministic order and exposed through a flat temporary
    directory (symlinks where possible, copies otherwise) that
    ``SegmentationDataset`` reads.

    The order is identical for 'original' and 'normalized' because both trees
    use the same file names. This lets sample ``i`` of one loader correspond
    to sample ``i`` of the other, as the paired evaluation assumes.

    Args:
        data_type: Sub-directory of ``artifacts_dir / 'datasets'`` to read.
        batch_size: Batch size for every loader.
        num_workers: Worker processes per loader.

    Returns:
        Dict mapping split name ('train', 'val') to a ``DataLoader``. Splits
        with no images are omitted.
    """
    import shutil

    base_dir = artifacts_dir / 'datasets' / data_type
    dataloaders = {}

    for split in ['train', 'val']:  # Skip test for now, use val for evaluation
        # Combine all tissues for the split
        all_pairs = []

        for tissue in selected_tissues:
            tissue_img_dir = base_dir / tissue / split / 'images'
            tissue_mask_dir = base_dir / tissue / split / 'masks'

            if not (tissue_img_dir.exists() and tissue_mask_dir.exists()):
                continue

            # Pair each image with the mask of the SAME file name. Two
            # independent, unsorted globs zipped by position can mis-pair
            # images and masks.
            for img_file in sorted(tissue_img_dir.glob('*.png')):
                mask_file = tissue_mask_dir / img_file.name
                if mask_file.exists():
                    all_pairs.append((img_file, mask_file))
                else:
                    logger.warning(f"No mask found for {img_file}; skipping")

        if not all_pairs:
            continue

        print(f"   {data_type} {split}: {len(all_pairs)} images")

        # Create temporary combined directory structure. It is rebuilt from
        # scratch on every call so entries left over from an earlier run
        # (different sample size or ordering) cannot leak into the dataset.
        temp_dir = artifacts_dir / 'temp' / data_type / split
        if temp_dir.exists():
            shutil.rmtree(temp_dir)
        temp_img_dir = temp_dir / 'images'
        temp_mask_dir = temp_dir / 'masks'
        temp_img_dir.mkdir(parents=True, exist_ok=True)
        temp_mask_dir.mkdir(parents=True, exist_ok=True)

        for i, (img_file, mask_file) in enumerate(all_pairs):
            links = (
                (img_file, temp_img_dir / f"{i:04d}.png"),
                (mask_file, temp_mask_dir / f"{i:04d}.png"),
            )
            for src, dst in links:
                try:
                    dst.symlink_to(src.absolute())
                except (OSError, NotImplementedError):
                    # Fallback to copying if symlinks fail (e.g. no privilege)
                    shutil.copy2(src, dst)

        # Create dataset
        transform = train_transform if split == 'train' else val_transform
        dataset = SegmentationDataset(temp_img_dir, temp_mask_dir, transform=transform)

        # Create dataloader; only the training split is shuffled.
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split == 'train'),
            num_workers=num_workers,
            pin_memory=True
        )

        dataloaders[split] = dataloader

    return dataloaders

# Build one set of loaders per data variant, then report their sizes.
print("📊 Creating dataloaders...")
original_dataloaders = create_dataloaders('original', batch_size=4)
normalized_dataloaders = create_dataloaders('normalized', batch_size=4)

print(f"✅ Dataloaders created:")
_loaders_by_type = {
    'original': original_dataloaders,
    'normalized': normalized_dataloaders,
}
for data_type, loaders in _loaders_by_type.items():
    print(f"   {data_type}:")
    for split in loaders:
        loader = loaders[split]
        print(f"     {split}: {len(loader)} batches, {len(loader.dataset)} images")

logger.info("Dataloaders created successfully")


### 3.2 Train U-Net Models

Train separate U-Net models on original and normalized data for comparison.


In [None]:
# =============================================================================
# MODEL TRAINING
# =============================================================================

from src.models.unet_rq3 import create_unet_rq3
from src.utils.metrics import calculate_batch_metrics, evaluate_model_on_dataset

def _average_batch_metrics(metrics_list):
    """Average each metric over the batches, ignoring NaN entries.

    Returns an empty dict for an empty list, and NaN for a metric whose every
    entry is NaN (avoiding numpy's mean-of-empty warning).
    """
    if not metrics_list:
        return {}

    averaged = {}
    for key in metrics_list[0]:
        values = [m[key] for m in metrics_list if not np.isnan(m[key])]
        averaged[key] = np.mean(values) if values else float('nan')
    return averaged


def train_model(dataloaders, model_name, epochs=5, learning_rate=1e-4):
    """Train a U-Net model and return it together with its training history.

    Each epoch runs one pass over ``dataloaders['train']`` and one evaluation
    pass over ``dataloaders['val']``. The weights with the lowest validation
    loss are written to ``<model_name>_best.pth``, and the last-epoch weights
    to ``<model_name>_final.pth``. The history is saved as
    ``<model_name>_history.json``.

    Args:
        dataloaders: Dict with non-empty 'train' and 'val' dataloaders that
            yield ``(images, masks)`` batches.
        model_name: Prefix for the checkpoint and history file names.
        epochs: Number of training epochs.
        learning_rate: Initial Adam learning rate, halved every 2 epochs.

    Returns:
        ``(model, history)``. ``model`` holds the LAST epoch's weights, not
        the best checkpoint. ``history`` has the keys 'train_loss',
        'val_loss', 'train_metrics' and 'val_metrics', one entry per epoch.

    Raises:
        ValueError: If the 'train' or 'val' dataloader is missing or empty.
    """
    # Fail with a clear message instead of a KeyError or ZeroDivisionError
    # deep inside the epoch loop.
    for required_split in ('train', 'val'):
        if required_split not in dataloaders or len(dataloaders[required_split]) == 0:
            raise ValueError(
                f"train_model requires a non-empty '{required_split}' dataloader"
            )

    print(f"\n🚀 Training {model_name} model...")

    # Create RQ3-specific model (separate from RQ2)
    model = create_unet_rq3(n_channels=3, n_classes=6, device=device)

    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=255)  # Ignore unknown pixels
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_metrics': [],
        'val_metrics': []
    }

    best_val_loss = float('inf')

    for epoch in range(epochs):
        print(f"\n📅 Epoch {epoch+1}/{epochs}")

        # Training phase
        model.train()
        train_loss = 0.0
        train_metrics_list = []

        for batch_idx, (images, masks) in enumerate(dataloaders['train']):
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Calculate metrics for this batch (no gradients needed)
            with torch.no_grad():
                predictions = torch.argmax(outputs, dim=1)
                batch_metrics = calculate_batch_metrics(predictions, masks, num_classes=6)
                train_metrics_list.append(batch_metrics)

            if batch_idx % 5 == 0:  # Print every 5 batches
                print(f"   Batch {batch_idx+1}/{len(dataloaders['train'])}: Loss = {loss.item():.4f}")

        # Average training metrics
        avg_train_loss = train_loss / len(dataloaders['train'])
        avg_train_metrics = _average_batch_metrics(train_metrics_list)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_metrics_list = []

        with torch.no_grad():
            for images, masks in dataloaders['val']:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()

                # Calculate metrics
                predictions = torch.argmax(outputs, dim=1)
                batch_metrics = calculate_batch_metrics(predictions, masks, num_classes=6)
                val_metrics_list.append(batch_metrics)

        avg_val_loss = val_loss / len(dataloaders['val'])
        avg_val_metrics = _average_batch_metrics(val_metrics_list)

        # Update history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_metrics'].append(avg_train_metrics)
        history['val_metrics'].append(avg_val_metrics)

        # Print epoch summary
        print(f"   📊 Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        print(f"   📈 Train Dice: {avg_train_metrics.get('avg_dice', 0):.4f}, Val Dice: {avg_val_metrics.get('avg_dice', 0):.4f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), artifacts_dir / 'checkpoints' / f'{model_name}_best.pth')
            print(f"   💾 Best model saved (Val Loss: {best_val_loss:.4f})")

        scheduler.step()

    # Save final model
    torch.save(model.state_dict(), artifacts_dir / 'checkpoints' / f'{model_name}_final.pth')

    # Save training history (default=str covers values json cannot encode natively)
    with open(artifacts_dir / 'results' / f'{model_name}_history.json', 'w') as f:
        json.dump(history, f, indent=2, default=str)

    print(f"✅ {model_name} training completed!")
    return model, history

# Training configuration
EPOCHS = 3 if TESTING_MODE else 10  # Fewer epochs for testing
LEARNING_RATE = 1e-4

# Train both models with identical hyperparameters so that the data variant
# (original vs. normalized) is the only difference between them.
models = {}
histories = {}

for data_type, loaders in (
    ('original', original_dataloaders),
    ('normalized', normalized_dataloaders),
):
    # train_model reads both the 'train' and the 'val' loader, so check both.
    missing_splits = [s for s in ('train', 'val') if s not in loaders]
    if missing_splits:
        logger.warning(
            f"Skipping {data_type} model: missing dataloader(s) for {missing_splits}"
        )
        continue

    models[data_type], histories[data_type] = train_model(
        loaders,
        f'unet_{data_type}',
        epochs=EPOCHS,
        learning_rate=LEARNING_RATE
    )

# Keep the per-variant names defined only when that model was trained.
if 'original' in models:
    model_original, history_original = models['original'], histories['original']
if 'normalized' in models:
    model_normalized, history_normalized = models['normalized'], histories['normalized']

if len(models) == 2:
    print(f"\n🎉 All models trained successfully!")
else:
    print(f"\n⚠️  Only {len(models)} of 2 models were trained - see warnings above")
logger.info(f"Model training completed: {len(models)} models trained")


### 3.3 Paired Evaluation and Statistical Analysis

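Evaluate both models on the same held-out images (currently the validation split, since the dataloaders do not yet include the test split) and compare their per-image metrics statistically.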


In [None]:
# =============================================================================
# PAIRED EVALUATION AND STATISTICAL ANALYSIS
# =============================================================================

# Pairs of (short name used in the result column names, key read from the dict
# returned by calculate_segmentation_metrics). Order defines the column order.
_PAIRED_METRIC_KEYS = (
    ('dice', 'avg_dice'),
    ('iou', 'avg_iou'),
    ('pixel_acc', 'pixel_accuracy'),
    ('precision', 'avg_precision'),
    ('recall', 'avg_recall'),
    ('f1', 'avg_f1'),
)


def _build_paired_record(batch_idx, image_idx, orig_metrics, norm_metrics):
    """Flatten two metric dicts into one row with original_*/normalized_* keys."""
    record = {'batch_idx': batch_idx, 'image_idx': image_idx}
    for short_name, metric_key in _PAIRED_METRIC_KEYS:
        record[f'original_{short_name}'] = orig_metrics[metric_key]
        record[f'normalized_{short_name}'] = norm_metrics[metric_key]
    return record


def evaluate_models_paired(models, dataloaders, num_classes=6):
    """Evaluate both models batch-by-batch and collect per-image paired metrics.

    Args:
        models: Mapping with keys 'original' and 'normalized' holding the two
            trained models.
        dataloaders: Mapping with keys 'original' and 'normalized'; each value
            is a mapping that may contain a 'val' dataloader yielding
            ``(images, masks)`` batches.
        num_classes: Class count forwarded to ``calculate_segmentation_metrics``
            (defaults to 6, the value previously hard-coded here).

    Returns:
        A list with one dict per evaluated image containing 'batch_idx',
        'image_idx' and an original_*/normalized_* entry for each metric.
        Returns an empty list when either validation dataloader is missing
        or empty.

    Note:
        Pairing is purely positional: image ``i`` of batch ``k`` from one
        loader is compared with image ``i`` of batch ``k`` from the other.
        This assumes both loaders yield the same images in the same order
        (no shuffling) — TODO confirm against the dataloader construction.
    """

    print("🔍 Performing paired evaluation...")

    # Get validation dataloaders
    original_val = dataloaders['original'].get('val')
    normalized_val = dataloaders['normalized'].get('val')

    if original_val is None or normalized_val is None:
        print("⚠️  Validation dataloaders not available for paired evaluation")
        return []

    # Only the batches both loaders can provide are evaluated
    batch_count = min(len(original_val), len(normalized_val))
    if batch_count == 0:
        print("⚠️  Validation dataloaders not available for paired evaluation")
        return []

    # Ensure both models are in evaluation mode
    models['original'].eval()
    models['normalized'].eval()

    paired_results = []
    failed_batches = 0

    with torch.no_grad():
        orig_iter = iter(original_val)
        norm_iter = iter(normalized_val)

        for batch_idx in range(batch_count):
            try:
                # Get paired batches
                orig_images, orig_masks = next(orig_iter)
                norm_images, norm_masks = next(norm_iter)

                # Only the images need to be on the model's device; masks are
                # consumed as NumPy arrays on the CPU.
                orig_outputs = models['original'](orig_images.to(device))
                norm_outputs = models['normalized'](norm_images.to(device))

                # Class with the highest score along dim 1 is the prediction
                orig_preds = torch.argmax(orig_outputs, dim=1).cpu().numpy()
                norm_preds = torch.argmax(norm_outputs, dim=1).cpu().numpy()
                orig_targets = orig_masks.cpu().numpy()
                norm_targets = norm_masks.cpu().numpy()

                # Guard against the two loaders yielding different batch sizes,
                # which would otherwise raise IndexError mid-batch.
                n_images = min(len(orig_preds), len(norm_preds))
                if len(orig_preds) != len(norm_preds):
                    logger.warning(
                        f"Batch {batch_idx}: size mismatch "
                        f"({len(orig_preds)} original vs {len(norm_preds)} normalized); "
                        f"evaluating first {n_images} images"
                    )

                for i in range(n_images):
                    orig_metrics = calculate_segmentation_metrics(
                        orig_preds[i], orig_targets[i], num_classes=num_classes
                    )
                    norm_metrics = calculate_segmentation_metrics(
                        norm_preds[i], norm_targets[i], num_classes=num_classes
                    )
                    paired_results.append(
                        _build_paired_record(batch_idx, i, orig_metrics, norm_metrics)
                    )

                if batch_idx % 5 == 0:
                    print(f"   Processed batch {batch_idx+1}/{batch_count}")

            except StopIteration:
                break
            except Exception as e:
                # Best-effort: skip the failing batch, but keep the traceback in
                # the log and count it so the failure is visible in the summary.
                failed_batches += 1
                logger.warning(f"Error in batch {batch_idx}: {e}", exc_info=True)
                continue

    print(f"✅ Paired evaluation completed: {len(paired_results)} image pairs")
    if failed_batches:
        print(f"⚠️  {failed_batches}/{batch_count} batches failed and were skipped (see log)")
    return paired_results

def _significance_stars(p_value):
    """Return '***', '**' or '*' for p < 0.001, < 0.01, < 0.05, else '' (NaN gives '')."""
    if p_value < 0.001:
        return '***'
    if p_value < 0.01:
        return '**'
    if p_value < 0.05:
        return '*'
    return ''


# Perform paired evaluation (requires both the original and normalized model)
if len(models) == 2:
    paired_results = evaluate_models_paired(
        models,
        {'original': original_dataloaders, 'normalized': normalized_dataloaders},
    )

    if paired_results:
        # Convert to DataFrame for analysis: one row per evaluated image pair
        results_df = pd.DataFrame(paired_results)

        # Save paired results
        results_df.to_csv(artifacts_dir / 'results' / 'paired_evaluation_results.csv', index=False)

        # Statistical Analysis
        print("\n📊 Statistical Analysis:")
        print("=" * 50)

        metrics_to_test = ['dice', 'iou', 'pixel_acc', 'precision', 'recall', 'f1']
        statistical_results = {}

        # NOTE(review): six metrics x two tests are each checked at alpha=0.05
        # with no multiple-comparison correction — confirm this is intended.
        for metric in metrics_to_test:
            original_col = f'original_{metric}'
            normalized_col = f'normalized_{metric}'

            if original_col not in results_df.columns or normalized_col not in results_df.columns:
                continue

            # Drop incomplete rows *jointly* so every remaining original value
            # is still matched with the normalized value of the same image.
            paired = results_df[[original_col, normalized_col]].dropna()
            if paired.empty:
                continue

            original_vals = paired[original_col].to_numpy(dtype=float)
            normalized_vals = paired[normalized_col].to_numpy(dtype=float)
            sample_size = len(paired)

            # Calculate basic statistics
            orig_mean = float(np.mean(original_vals))
            norm_mean = float(np.mean(normalized_vals))
            # Relative change is undefined when the baseline mean is zero
            if orig_mean != 0:
                improvement = ((norm_mean - orig_mean) / orig_mean) * 100
            else:
                improvement = float('nan')

            # Wilcoxon signed-rank test (non-parametric paired test), two-sided;
            # the direction is taken from the sign of `improvement`.
            try:
                wilcoxon_stat, wilcoxon_p = wilcoxon(normalized_vals, original_vals, alternative='two-sided')
                wilcoxon_stat, wilcoxon_p = float(wilcoxon_stat), float(wilcoxon_p)
            except Exception as exc:
                logger.warning(f"Wilcoxon test failed for {metric}: {exc}")
                wilcoxon_stat, wilcoxon_p = np.nan, np.nan

            # Paired t-test (parametric paired test)
            try:
                ttest_stat, ttest_p = ttest_rel(normalized_vals, original_vals)
                ttest_stat, ttest_p = float(ttest_stat), float(ttest_p)
            except Exception as exc:
                logger.warning(f"Paired t-test failed for {metric}: {exc}")
                ttest_stat, ttest_p = np.nan, np.nan

            # Store results as plain Python floats/bools so json.dump writes
            # real JSON values instead of stringifying NumPy scalars.
            statistical_results[metric] = {
                'original_mean': orig_mean,
                'normalized_mean': norm_mean,
                'improvement_percent': improvement,
                'wilcoxon_statistic': wilcoxon_stat,
                'wilcoxon_p_value': wilcoxon_p,
                'ttest_statistic': ttest_stat,
                'ttest_p_value': ttest_p,
                'sample_size': sample_size,
                'significant_wilcoxon': bool(not np.isnan(wilcoxon_p) and wilcoxon_p < 0.05),
                'significant_ttest': bool(not np.isnan(ttest_p) and ttest_p < 0.05),
            }

            # Print results
            print(f"\n{metric.upper()}:")
            print(f"  Original:    {orig_mean:.4f}")
            print(f"  Normalized:  {norm_mean:.4f}")
            print(f"  Improvement: {improvement:+.2f}%")
            print(f"  Wilcoxon p:  {wilcoxon_p:.6f} {_significance_stars(wilcoxon_p)}")
            print(f"  T-test p:    {ttest_p:.6f} {_significance_stars(ttest_p)}")

        # Save statistical results (one row per metric in the CSV)
        stats_df = pd.DataFrame(statistical_results).T
        stats_df.to_csv(artifacts_dir / 'results' / 'statistical_analysis.csv')

        with open(artifacts_dir / 'results' / 'statistical_results.json', 'w') as f:
            json.dump(statistical_results, f, indent=2, default=str)

        # Summary of hypothesis testing
        print("\n🧪 Hypothesis Testing Summary:")
        print("=" * 40)

        significant_metrics = []  # every metric with a significant difference
        improved_metrics = []     # significant and normalized mean is higher
        degraded_metrics = []     # significant and normalized mean is not higher
        for metric, results in statistical_results.items():
            if results['significant_wilcoxon'] or results['significant_ttest']:
                significant_metrics.append(metric)
                improvement = results['improvement_percent']
                if improvement > 0:
                    direction = "improvement"
                    improved_metrics.append(metric)
                else:
                    direction = "degradation"
                    degraded_metrics.append(metric)
                print(f"✅ {metric.upper()}: Significant {direction} ({improvement:+.2f}%)")
            else:
                print(f"❌ {metric.upper()}: No significant difference")

        # Final conclusion: H₁ claims an *improvement*, so only significant
        # positive changes count as evidence against H₀.
        print("\n🎯 RESEARCH QUESTION 3 RESULTS:")
        print("=" * 35)

        if improved_metrics:
            print("🎉 REJECT NULL HYPOTHESIS (H₀)")
            print("   Stain normalization shows significant improvement in:")
            for metric in improved_metrics:
                improvement = statistical_results[metric]['improvement_percent']
                print(f"   - {metric.upper()}: {improvement:+.2f}%")
        else:
            print("📊 FAIL TO REJECT NULL HYPOTHESIS (H₀)")
            print("   No significant improvement found with stain normalization")

        if degraded_metrics:
            print("   ⚠️  Significant degradation observed in:")
            for metric in degraded_metrics:
                improvement = statistical_results[metric]['improvement_percent']
                print(f"   - {metric.upper()}: {improvement:+.2f}%")

        logger.info(f"Statistical analysis completed: {len(statistical_results)} metrics analyzed")

else:
    print("⚠️  Both models not available for paired evaluation")

print("\n✅ RQ3 Analysis Pipeline Completed!")


## 4. Summary and Artifacts

### Complete RQ3 Pipeline Results


In [None]:
# =============================================================================
# FINAL SUMMARY AND ARTIFACT DOCUMENTATION
# =============================================================================

print("📋 RQ3 Complete Pipeline Summary")
print("=" * 50)

# Pipeline execution summary. Values produced by earlier cells are looked up
# defensively via locals() so this cell still runs if a stage was skipped.
execution_summary = {
    'pipeline_completed': True,
    'testing_mode': TESTING_MODE,
    'total_tissues': len(selected_tissues),
    'selected_tissues': selected_tissues,
    'total_sample_images': total_sample_images if 'total_sample_images' in locals() else 0,
    'models_trained': len(models) if 'models' in locals() else 0,
    'statistical_analysis_completed': 'statistical_results' in locals(),
    'gpu_acceleration_used': device.type == 'cuda',
    'execution_timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
}

print(f"🎯 Pipeline Status: {'✅ COMPLETED' if execution_summary['pipeline_completed'] else '❌ INCOMPLETE'}")
print(f"🧪 Mode: {'TESTING' if execution_summary['testing_mode'] else 'PRODUCTION'}")
print(f"🔬 Tissues: {execution_summary['total_tissues']} ({', '.join(execution_summary['selected_tissues'])})")
print(f"📊 Images: {execution_summary['total_sample_images']:,}")
print(f"🤖 Models: {execution_summary['models_trained']} trained")
print(f"📈 Statistics: {'✅ Completed' if execution_summary['statistical_analysis_completed'] else '❌ Pending'}")
print(f"⚡ GPU: {'✅ Used' if execution_summary['gpu_acceleration_used'] else '❌ CPU only'}")

# Document all artifacts
print("\n📁 Generated Artifacts:")
print("=" * 30)

artifact_categories = {
    'datasets': 'Processed image datasets (original & normalized)',
    'checkpoints': 'Trained model weights',
    'results': 'Analysis results and statistics',
    'plots': 'Visualizations and plots',
    'logs': 'Execution logs'
}

for category, description in artifact_categories.items():
    category_path = artifacts_dir / category
    if category_path.exists():
        # rglob('*') also yields sub-directories; count regular files only so
        # the number matches the "files" label.
        file_count = sum(1 for entry in category_path.rglob('*') if entry.is_file())
        print(f"📂 {category:12}: {file_count:3} files - {description}")

# Key result files expected under artifacts_dir / 'results'
key_files = [
    'tissue_selection.json',
    'sample_config.json',
    'target_metadata.json',
    'normalization_results.json',
    'processing_stats.json',
    'paired_evaluation_results.csv',
    'statistical_analysis.csv',
    'statistical_results.json'
]

print("\n📋 Key Result Files:")
for filename in key_files:
    filepath = artifacts_dir / 'results' / filename
    status = '✅' if filepath.exists() else '❌'
    print(f"   {status} {filename}")

# Research question conclusion
print("\n🎯 Research Question 3 Conclusion:")
print("=" * 40)

if 'statistical_results' in locals() and statistical_results:
    total_metrics = len(statistical_results)

    # Metrics where at least one of the two paired tests was significant
    significant_changes = [
        (metric, results['improvement_percent'])
        for metric, results in statistical_results.items()
        if results.get('significant_wilcoxon', False) or results.get('significant_ttest', False)
    ]
    # H₁ claims an improvement, so a significant *drop* must not count as
    # support for it; split the significant metrics by direction.
    improvements = [(metric, change) for metric, change in significant_changes if change > 0]
    degradations = [(metric, change) for metric, change in significant_changes if change < 0]

    if improvements:
        print("🎉 CONCLUSION: REJECT NULL HYPOTHESIS (H₀)")
        print(f"   📊 {len(improvements)}/{total_metrics} metrics show significant improvement")
        print("   📈 Stain normalization DOES improve U-Net segmentation performance")

        # Show top improvements, largest gain first
        improvements.sort(key=lambda item: item[1], reverse=True)
        print("   🏆 Top improvements:")
        for metric, improvement in improvements[:3]:
            print(f"      - {metric.upper()}: {improvement:+.2f}%")
    else:
        print("📊 CONCLUSION: FAIL TO REJECT NULL HYPOTHESIS (H₀)")
        print("   ❌ No significant improvement found with stain normalization")
        print("   📉 Current evidence does not support H₁")

    if degradations:
        print(f"   ⚠️  {len(degradations)}/{total_metrics} metrics show significant degradation:")
        for metric, change in sorted(degradations, key=lambda item: item[1]):
            print(f"      - {metric.upper()}: {change:+.2f}%")
else:
    print("⚠️  Statistical analysis not completed - cannot draw conclusion")

# Save execution summary
with open(artifacts_dir / 'results' / 'execution_summary.json', 'w') as f:
    json.dump(execution_summary, f, indent=2, default=str)

# Performance summary (only when an earlier cell produced processing_stats)
if 'processing_stats' in locals():
    print("\n⚡ Performance Summary:")
    print("=" * 25)
    total_time = processing_stats.get('processing_time', 0)
    total_processed = processing_stats.get('total_processed', 0)

    print(f"📊 Images processed: {total_processed:,}")
    print(f"⏱️  Processing time: {total_time:.2f}s")
    if total_time > 0:
        print(f"🚀 Processing speed: {total_processed/total_time:.1f} images/sec")
    else:
        print("🚀 Processing speed: N/A")

    if device.type == 'cuda':
        # 5072 is presumably the image count of the full dataset — TODO confirm
        estimated_full_time = (5072 * total_time / total_processed) / 60 if total_processed > 0 else 0
        print(f"📈 Full dataset estimate: {estimated_full_time:.1f} minutes")

print("\n🎉 RQ3 GPU-Accelerated Pipeline Complete!")
print(f"📁 All artifacts saved to: {artifacts_dir}")
print("🔬 Ready for publication and further analysis")

logger.info("RQ3 complete pipeline execution finished successfully")
