# WSI Classification Pipeline - Comprehensive Real Data Testing

**Complete end-to-end testing with actual WSI files**

This notebook provides comprehensive validation of all 26 pipeline components using real data.

---

## üìã What This Tests

### preprocessing.py (5 components)
- ‚úÖ WSIPreprocessor class
- ‚úÖ select_magnification_level()
- ‚úÖ generate_tissue_mask()
- ‚úÖ extract_patches()
- ‚úÖ process_dataset()

### model.py (5 components)
- ‚úÖ FeatureExtractor (ResNet34, ResNet50, ViT)
- ‚úÖ AttentionMIL
- ‚úÖ GatedAttentionMIL
- ‚úÖ WSIClassifier
- ‚úÖ create_model()

### dataset.py (4 components)
- ‚úÖ WSIDataset class
- ‚úÖ get_transforms()
- ‚úÖ collate_fn()
- ‚úÖ create_dataloaders()

### train.py (2 components)
- ‚úÖ get_class_weights()
- ‚úÖ Trainer class

### inference.py (5 components)
- ‚úÖ WSIInference class
- ‚úÖ predict_from_patches()
- ‚úÖ create_attention_heatmap()
- ‚úÖ visualize_results()
- ‚úÖ process_slide()

### train_ddp.py (1 component)
- ‚úÖ DDPTrainer structure

**Total: 22 testable components + 4 workflows = 26 tests**

---

## üìÅ Prerequisites

### Required Data Structure
```
your_data/
‚îú‚îÄ‚îÄ svs_files/
‚îÇ   ‚îú‚îÄ‚îÄ slide_001.svs
‚îÇ   ‚îú‚îÄ‚îÄ slide_002.svs
‚îÇ   ‚îî‚îÄ‚îÄ ...
‚îú‚îÄ‚îÄ masks/              # Optional
‚îÇ   ‚îî‚îÄ‚îÄ ...
‚îî‚îÄ‚îÄ metadata.csv
```

### metadata.csv Format
```csv
slide_id,label
slide_001,0
slide_002,1
slide_003,2
```

### System Requirements
- Python 3.8+
- GPU recommended (50-100x faster)
- 16+ GB RAM
- 10+ GB free disk space

---
## 1. Setup and Configuration

In [None]:
# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
import napari
warnings.filterwarnings('ignore')

print("üì¶ Importing modules...")

# Deep learning
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Import ALL pipeline components
from preprocessing import WSIPreprocessor
from model import (
    FeatureExtractor,
    AttentionMIL,
    GatedAttentionMIL,
    WSIClassifier,
    create_model
)
from dataset import (
    WSIDataset,
    get_transforms,
    collate_fn,
    create_dataloaders
)
from train import Trainer, get_class_weights
from inference import WSIInference

print("‚úÖ All imports successful!")
print(f"\nüìä System Info:")
print(f"   PyTorch version: {torch.__version__}")
print(f"   CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"   CUDA version: {torch.version.cuda}")
    print(f"   Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"   GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    print("   ‚ö†Ô∏è  No GPU detected - will run on CPU (slow)")

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"\n   Using device: {torch.cuda.get_device_name(0)}")

In [None]:
# ============================================================================
# CONFIGURE THESE PATHS FOR YOUR DATA
# ============================================================================

DATA_CONFIG = {
    # ========== INPUT DATA ==========
    'svs_dir': Path('/varidata/research/projects/steensma/primary/vari-core-generated-data/PBC-Aperio Images/'),          # CHANGE THIS
    'mask_dir': Path('/varidata/research/projects/steensma/vari-core-generated-data/OIC/Abigail/Primary_Project_PixelClassifier/masks/'),             # CHANGE THIS (optional)
    'metadata_csv': Path('metadata_subset.csv'),       # CHANGE THIS
    
    # ========== OUTPUT DIRECTORIES ==========
    'output_dir': Path('./comprehensive_test_output'),
    'preprocessed_dir': Path('./comprehensive_test_output/preprocessed'),
    'checkpoints_dir': Path('./comprehensive_test_output/checkpoints'),
    'results_dir': Path('./comprehensive_test_output/results'),
    
    # ========== MODEL CONFIGURATION ==========
    'num_classes': 8,              # CHANGE THIS to your number of classes
    'backbone': 'resnet34',        # 'resnet34', 'resnet50', or 'vit_b_16'
    'mil_type': 'gated',          # 'simple' or 'gated'
    'pretrained': False,          # Set True for production
    'max_patches': 100,           # Reduced for testing (use 500 for production)
    
    # ========== TRAINING CONFIGURATION ==========
    'batch_size': 2,
    'epochs': 3,                  # Reduced for testing (use 50+ for production)
    'learning_rate': 1e-4,
    'num_workers': 1,             # Set to 4-8 for faster loading
    'use_class_weights': True,
    
    # ========== PREPROCESSING CONFIGURATION ==========
    'patch_size': 256,
    'target_magnification': 10,
    'tissue_threshold': 0.5,
}

# Create output directories
for key in ['output_dir', 'preprocessed_dir', 'checkpoints_dir', 'results_dir']:
    DATA_CONFIG[key].mkdir(parents=True, exist_ok=True)

print("üìÅ Configuration:")
print("=" * 70)
for key, value in DATA_CONFIG.items():
    print(f"   {key:25s} = {value}")
print("=" * 70)

# Validate paths
if not DATA_CONFIG['metadata_csv'].exists():
    print("\n‚ö†Ô∏è  WARNING: metadata.csv not found!")
    print(f"   Expected: {DATA_CONFIG['metadata_csv']}")
    print("   Please update DATA_CONFIG['metadata_csv'] with correct path")

if not DATA_CONFIG['svs_dir'].exists():
    print("\n‚ö†Ô∏è  WARNING: SVS directory not found!")
    print(f"   Expected: {DATA_CONFIG['svs_dir']}")
    print("   Please update DATA_CONFIG['svs_dir'] with correct path")

---
## 2. Test preprocessing.py Components

**Components tested:**
1. WSIPreprocessor class initialization
2. select_magnification_level()
3. generate_tissue_mask()
4. extract_patches()
5. process_dataset()

In [None]:
print("\n" + "=" * 70)
print("TESTING preprocessing.py COMPONENTS")
print("=" * 70)

# Test 1: WSIPreprocessor initialization
print("\n[1/5] Testing WSIPreprocessor initialization...")

preprocessor = WSIPreprocessor(
    patch_size=DATA_CONFIG['patch_size'],
    target_magnification=DATA_CONFIG['target_magnification'],
    tissue_threshold=DATA_CONFIG['tissue_threshold'],
    overlap=0
)

print(f"   ‚úÖ WSIPreprocessor initialized successfully")
print(f"      Patch size: {preprocessor.patch_size}")
print(f"      Target magnification: {preprocessor.target_mag}x")
print(f"      Tissue threshold: {preprocessor.tissue_threshold}")

# Load metadata
if DATA_CONFIG['metadata_csv'].exists():
    metadata_df = pd.read_csv(DATA_CONFIG['metadata_csv'])
    print(f"\n   ‚úÖ Loaded metadata: {len(metadata_df)} slides")
    print(f"\n   Columns: {list(metadata_df.columns)}")
    print(f"\n   First 3 rows:")
    print(metadata_df.head(3).to_string(index=False))
    
    if 'label' in metadata_df.columns:
        print(f"\n   Class distribution:")
        class_dist = metadata_df['label'].value_counts().sort_index()
        for class_id, count in class_dist.items():
            print(f"      Class {class_id}: {count} slides")
else:
    print("\n   ‚ö†Ô∏è  Metadata file not found - skipping")
    metadata_df = None

In [None]:
# Test 2-4: select_magnification_level, generate_tissue_mask, extract_patches
# if metadata_df is not None and len(metadata_df) > 0:
#     import openslide
    
#     test_slide_id = metadata_df['Image'].iloc[0]
#     test_svs_path = DATA_CONFIG['svs_dir'] / f"{test_slide_id}.svs"
#     test_mask_path = DATA_CONFIG['mask_dir'] / f"{test_slide_id}_mask.png"
    
#     print(f"\n[2/5] Testing on slide: {test_slide_id}")
#     print(f"      Path: {test_svs_path}")
    
#     if test_svs_path.exists():
#         try:
#             slide = openslide.OpenSlide(str(test_svs_path))
#             mask_path = str(test_mask_path)
            
#             print(f"\n   üìä Slide Information:")
#             print(f"      Dimensions: {slide.dimensions}")
#             print(f"      Level count: {slide.level_count}")
#             print(f"      Level dimensions: {slide.level_dimensions}")
            
#             # Test get_magnification_level
#             print(f"\n   Testing get_magnification_level()...")
#             level = preprocessor.get_magnification_level(slide)
#             print(f"      ‚úÖ Selected level {level}")
            
#             # Test generate_tissue_mask
#             print(f"\n[3/5] Testing generate_tissue_mask()...")
#             mask = preprocessor.load_tissue_mask(mask_path,slide_shape)
#             tissue_coverage = (mask > 0).sum() / mask.size * 100
#             print(f"      ‚úÖ Mask generated")
#             print(f"         Shape: {mask.shape}")
#             print(f"         Tissue coverage: {tissue_coverage:.2f}%")
            
#             # Visualize
#             # fig, axes = plt.subplots(1, 2, figsize=(14, 6))
            
#             # thumbnail = slide.get_thumbnail((800, 800))
#             # axes[0].imshow(thumbnail)
#             # axes[0].set_title('WSI Thumbnail', fontsize=14)
#             # axes[0].axis('off')
            
#             # axes[1].imshow(mask, cmap='gray')
#             # axes[1].set_title(f'Tissue Mask ({tissue_coverage:.1f}% coverage)', fontsize=14)
#             # axes[1].axis('off')
            
#             # plt.suptitle(f'Slide: {test_slide_id}', fontsize=16, fontweight='bold')
#             # plt.tight_layout()
#             # plt.savefig(DATA_CONFIG['output_dir'] / 'tissue_mask_test.png', dpi=150, bbox_inches='tight')
#             # plt.show()
            
#             #slide.close()
#             print(f"\n      ‚úÖ Visualization saved to: {DATA_CONFIG['output_dir'] / 'tissue_mask_test.png'}")
            
#         except Exception as e:
#             print(f"\n   ‚ùå Error: {e}")
#     else:
#         print(f"\n   ‚ö†Ô∏è  SVS file not found: {test_svs_path}")
# else:
#     print("\n   ‚ö†Ô∏è  No metadata available - skipping slide tests")

In [None]:
# Test 5: process_dataset (optional - can take a long time)
# RUN_FULL_PREPROCESSING = True  # Set to True to preprocess all slides

# print(f"\n[5/5] Testing process_dataset()...")

# if RUN_FULL_PREPROCESSING and metadata_df is not None:
#     print("   Running full preprocessing on all slides...")
#     print(f"   This may take 5-10 minutes per slide...\n")
    
#     try:
#         processed_metadata = preprocessor.process_dataset(
#             metadata_df=metadata_df,
#             svs_dir=Path(DATA_CONFIG['svs_dir']),
#             mask_dir=Path(DATA_CONFIG['mask_dir']) if DATA_CONFIG['mask_dir'].exists() else None,
#             output_dir=Path(DATA_CONFIG['preprocessed_dir'])
#         )
        
#         processed_csv = Path(DATA_CONFIG['output_dir'],'processed_metadata.csv')
#         processed_metadata.to_csv(processed_csv, index=False)
        
#         print(f"\n   ‚úÖ Preprocessing complete!")
#         print(f"      Processed {len(processed_metadata)} slides")
#         print(f"      Metadata saved: {processed_csv}")
        
#     except Exception as e:
#         print(f"\n   ‚ùå Preprocessing error: {e}")
#         processed_metadata = None
# else:
#     print("   ‚ÑπÔ∏è  Skipping full preprocessing (RUN_FULL_PREPROCESSING=False)")
#     print("      Set RUN_FULL_PREPROCESSING=True to process all slides")
#     print("      Note: This can take hours depending on dataset size")
    
#     # Check if already preprocessed
#     processed_csv = Path(DATA_CONFIG['output_dir'],'processed_metadata.csv')
#     if processed_csv.exists():
#         processed_metadata = pd.read_csv(processed_csv)
#         print(f"\n      ‚úÖ Found existing preprocessed data: {len(processed_metadata)} slides")
#     else:
#         processed_metadata = None
#         print(f"\n      ‚ö†Ô∏è  No preprocessed data found")

# print("\n" + "=" * 70)
# print("‚úÖ preprocessing.py TESTING COMPLETE")
# print("=" * 70)

---
## 3. Test model.py Components

**Components tested:**
1. FeatureExtractor (ResNet34, ResNet50)
2. AttentionMIL
3. GatedAttentionMIL
4. WSIClassifier
5. create_model()

In [None]:
print("\n" + "=" * 70)
print("TESTING model.py COMPONENTS")
print("=" * 70)

# Test 1: FeatureExtractor with different backbones
print("\n[1/5] Testing FeatureExtractor...\n")

test_backbones = ['resnet34', 'resnet50']
batch_size = 4
test_input = torch.randn(batch_size, 3, 256, 256).to(device)

for backbone in test_backbones:
    print(f"   Testing backbone: {backbone}")
    
    fe = FeatureExtractor(
        backbone=backbone,
        pretrained=False,
        freeze_backbone=False
    ).to(device)
    
    with torch.no_grad():
        features = fe(test_input)
    
    print(f"      Input shape:  {test_input.shape}")
    print(f"      Output shape: {features.shape}")
    print(f"      Feature dim:  {fe.feature_dim}")
    print(f"      Parameters:   {sum(p.numel() for p in fe.parameters()):,}")
    print(f"      ‚úÖ {backbone} passed\n")
    
    del fe
    torch.cuda.empty_cache()

print("   ‚úÖ FeatureExtractor tests passed")

In [None]:
# Test 2: AttentionMIL
print("\n[2/5] Testing AttentionMIL...\n")

batch_size = 2
num_patches = 50
feature_dim = 512
num_classes = DATA_CONFIG['num_classes']

test_features = torch.randn(batch_size, num_patches, feature_dim).to(device)

attn_mil = AttentionMIL(
    feature_dim=feature_dim,
    hidden_dim=256,
    num_classes=num_classes
).to(device)

with torch.no_grad():
    logits, attention = attn_mil(test_features, return_attention=True)

print(f"   Input features:  {test_features.shape}")
print(f"   Output logits:   {logits.shape}")
print(f"   Attention:       {attention.shape}")
print(f"   Attention sum:   {attention.sum(dim=1).cpu().numpy()}")
print(f"   Parameters:      {sum(p.numel() for p in attn_mil.parameters()):,}")

print("\n   ‚úÖ AttentionMIL passed")

del attn_mil
torch.cuda.empty_cache()

In [None]:
# Test 3: GatedAttentionMIL
print("\n[3/5] Testing GatedAttentionMIL...\n")

gated_mil = GatedAttentionMIL(
    feature_dim=feature_dim,
    hidden_dim=256,
    num_classes=num_classes
).to(device)

with torch.no_grad():
    logits, attention = gated_mil(test_features, return_attention=True)

print(f"   Input features:  {test_features.shape}")
print(f"   Output logits:   {logits.shape}")
print(f"   Attention:       {attention.shape}")
print(f"   Parameters:      {sum(p.numel() for p in gated_mil.parameters()):,}")

print("\n   ‚úÖ GatedAttentionMIL passed")

del gated_mil
torch.cuda.empty_cache()

In [None]:
# Test 4: WSIClassifier
print("\n[4/5] Testing WSIClassifier (complete model)...\n")

test_classifier = WSIClassifier(
    mil_type = 'gated'
).to(device)

batch_size = 1
num_patches = 30
test_patches = torch.randn(batch_size, num_patches, 3, 256, 256).to(device)

with torch.no_grad():
    logits, attention = test_classifier(test_patches, return_attention=True)

print(f"   Input patches:   {test_patches.shape}")
print(f"   Output logits:   {logits.shape}")
print(f"   Attention:       {attention.shape}")
print(f"   Total parameters: {sum(p.numel() for p in test_classifier.parameters()):,}")

print("\n   ‚úÖ WSIClassifier passed")

del test_classifier
torch.cuda.empty_cache()

In [None]:
# Test 5: create_model function
print("\n[5/5] Testing create_model() factory function...\n")

model = create_model(
    backbone=DATA_CONFIG['backbone'],
    num_classes=DATA_CONFIG['num_classes'],
    pretrained=DATA_CONFIG['pretrained'],
    mil_type=DATA_CONFIG['mil_type']
).to(device)

print(f"   Model configuration:")
print(f"      Backbone:    {DATA_CONFIG['backbone']}")
print(f"      MIL type:    {DATA_CONFIG['mil_type']}")
print(f"      Classes:     {DATA_CONFIG['num_classes']}")
print(f"\n   Model statistics:")
print(f"      Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test forward pass
test_patches = torch.randn(1, 20, 3, 256, 256).to(device)
with torch.no_grad():
    logits, attention = model(test_patches, return_attention=True)

print(f"\n   Test forward pass:")
print(f"      Input:  {test_patches.shape}")
print(f"      Output: {logits.shape}")

print("\n   ‚úÖ create_model() passed")

print("\n" + "=" * 70)
print("‚úÖ model.py TESTING COMPLETE")
print("=" * 70)

---
## 4. Test dataset.py Components

**Components tested:**
1. get_transforms()
2. WSIDataset
3. collate_fn()
4. create_dataloaders()

In [None]:
print("\n" + "=" * 70)
print("TESTING dataset.py COMPONENTS")
print("=" * 70)

# Check if we have preprocessed data
processed_csv = Path(DATA_CONFIG['output_dir'],'processed_metadata.csv')

if not processed_csv.exists():
    print("\n‚ö†Ô∏è  No preprocessed data found.")
    print("   These tests require HDF5 files from preprocessing.")
    print("   Set RUN_FULL_PREPROCESSING=True in section 2 to generate data.")
    test_metadata = None
else:
    test_metadata = pd.read_csv(processed_csv)
    print(f"\n‚úÖ Found preprocessed data: {len(test_metadata)} slides")

In [None]:
# Test 1: get_transforms
print("\n[1/4] Testing get_transforms()...\n")

train_transform = get_transforms(augment=True)
val_transform = get_transforms(augment=False)

print(f"   Train transform (with augmentation):")
print(f"      {train_transform}")
print(f"\n   Val transform (no augmentation):")
print(f"      {val_transform}")

print("\n   ‚úÖ get_transforms() passed")

In [None]:
# Tests 2-4: WSIDataset, collate_fn, create_dataloaders
if test_metadata is not None and 'h5_path' in test_metadata.columns:
    
    # Test 2: WSIDataset
    print("\n[2/4] Testing WSIDataset...\n")
    
    dataset = WSIDataset(
        metadata_df=test_metadata,
        transform=val_transform,
        max_patches=DATA_CONFIG['max_patches'],
        sampling_strategy='random'
    )
    
    print(f"   Dataset created:")
    print(f"      Total slides: {len(dataset)}")
    print(f"      Max patches:  {dataset.max_patches}")
    
    if len(dataset) > 0:
        patches, label, coordinates, slide_id = dataset[0]
        
        print(f"\n   Sample loaded:")
        print(f"      Slide ID:    {slide_id}")
        print(f"      Patches:     {patches.shape}")
        print(f"      Label:       {label}")
        print(f"      Coordinates: {coordinates.shape}")
        
        print("\n   ‚úÖ WSIDataset passed")
        
        # Test 3: collate_fn
        print("\n[3/4] Testing collate_fn()...\n")
        
        batch = [dataset[i] for i in range(min(2, len(dataset)))]
        patches_list, labels, coords_list, slide_ids = collate_fn(batch)
        
        print(f"   Batch collated:")
        print(f"      Batch size:  {len(patches_list)}")
        print(f"      Labels:      {labels}")
        
        print("\n   ‚úÖ collate_fn() passed")
        
        # Test 4: create_dataloaders
        print("\n[4/4] Testing create_dataloaders()...\n")
        
        from sklearn.model_selection import train_test_split
        
        train_df, val_df = train_test_split(
            test_metadata,
            test_size=0.4,
            random_state=42,
            stratify=test_metadata['label'] if 'label' in test_metadata.columns else None
        )
        
        train_loader, val_loader, _ = create_dataloaders(
            train_df=train_df,
            val_df=val_df,
            batch_size=DATA_CONFIG['batch_size'],
            max_patches=DATA_CONFIG['max_patches'],
            num_workers=DATA_CONFIG['num_workers']
        )
        
        print(f"   DataLoaders created:")
        print(f"      Train batches: {len(train_loader)}")
        print(f"      Val batches:   {len(val_loader)}")
        
        print("\n   ‚úÖ create_dataloaders() passed")
    else:
        print("\n   ‚ö†Ô∏è  Dataset is empty")
        train_loader = None
        val_loader = None
else:
    print("\n   ‚ö†Ô∏è  Skipping dataset tests - no preprocessed data available")
    train_loader = None
    val_loader = None
    train_df = None

print("\n" + "=" * 70)
print("‚úÖ dataset.py TESTING COMPLETE")
print("=" * 70)

---
## 5. Test train.py Components

**Components tested:**
1. get_class_weights()
2. Trainer class

In [None]:
print("\n" + "=" * 70)
print("TESTING train.py COMPONENTS")
print("=" * 70)

if train_loader is not None and val_loader is not None:
    
    # Test 1: get_class_weights
    print("\n[1/2] Testing get_class_weights()...\n")
    
    train_ids = train_df['Image'].tolist()
    class_weights = get_class_weights(test_metadata, train_ids).to(device)
    
    print(f"   Class weights computed:")
    print(f"      Shape: {class_weights.shape}")
    print(f"      Weights: {class_weights.cpu().numpy()}")
    
    print("\n   ‚úÖ get_class_weights() passed")
    
    # Test 2: Trainer class
    print("\n[2/2] Testing Trainer class...\n")
    
    train_config = {
        'epochs': DATA_CONFIG['epochs'],
        'learning_rate': DATA_CONFIG['learning_rate'],
        'weight_decay': 1e-5,
        'lr_scheduler': 'cosine',
        'early_stopping': True,
        'patience': 5,
        'checkpoint_dir': Path(DATA_CONFIG['checkpoints_dir']),
        'use_wandb': False,
        'use_class_weights': True
    }
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=train_config['learning_rate'],
        weight_decay=train_config['weight_decay']
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=train_config['epochs']
        )
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=train_config,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
    )
    
    print(f"   Trainer initialized:")
    print(f"      Optimizer: {type(trainer.optimizer).__name__}")
    print(f"      Scheduler: {type(trainer.scheduler).__name__}")
    
    print(f"\n   Starting training ({DATA_CONFIG['epochs']} epochs)...\n")
    
    history = trainer.train()
    
    print(f"\n   Training completed:")
    print(f"      Final train loss: {history['train_loss'][-1]:.4f}")
    print(f"      Final val loss:   {history['val_loss'][-1]:.4f}")
    print(f"      Final val acc:    {history['val_accuracy'][-1]:.4f}")
    
    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    axes[0].plot(history['train_loss'], label='Train', marker='o')
    axes[0].plot(history['val_loss'], label='Val', marker='s')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    axes[1].plot(history['val_accuracy'], label='Accuracy', marker='o')
    axes[1].plot(history['val_f1'], label='F1', marker='s')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Score')
    axes[1].set_title('Validation Metrics')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(Path(DATA_CONFIG['output_dir'],'training_curves.png'), dpi=150)
    plt.show()
    
    print("\n   ‚úÖ Trainer class passed")
    
else:
    print("\n   ‚ö†Ô∏è  Skipping training tests - no preprocessed data")

print("\n" + "=" * 70)
print("‚úÖ train.py TESTING COMPLETE")
print("=" * 70)

---
## 6. Test inference.py Components

**Components tested:**
1. WSIInference class
2. predict_from_patches()
3. create_attention_heatmap()
4. visualize_results()
5. process_slide()

In [None]:
print("\n" + "=" * 70)
print("TESTING inference.py COMPONENTS")
print("=" * 70)

if test_metadata is not None and 'h5_path' in test_metadata.columns:
    
    print("\n[1/5] Testing WSIInference initialization...\n")
    
    inference = WSIInference(
        model=model,
        device=device,
        max_patches=DATA_CONFIG['max_patches'],
        transform=val_transform
    )
    
    print(f"   WSIInference initialized")
    print("\n   ‚úÖ WSIInference initialization passed")
    
    # Test predict_from_patches
    test_h5_path = test_metadata['h5_path'].iloc[0]
    test_slide_id = test_metadata['slide_id'].iloc[0]
    
    if Path(test_h5_path).exists():
        print(f"\n[2/5] Testing predict_from_patches()...\n")
        
        pred_class, probs, attention = inference.predict_from_patches(test_h5_path)
        
        print(f"   Prediction results:")
        print(f"      Predicted class: {pred_class}")
        print(f"      Probabilities: {probs.shape}")
        print(f"      Attention: {attention.shape}")
        
        print("\n   ‚úÖ predict_from_patches() passed")
        
        # Test create_attention_heatmap
        print(f"\n[3/5] Testing create_attention_heatmap()...\n")
        
        import h5py
        with h5py.File(test_h5_path, 'r') as f:
            coordinates = f['coordinates'][:]
        
        heatmap = inference.create_attention_heatmap(
            attention_weights=attention,
            coordinates=coordinates,
            slide_size=(4096, 4096),
            downsample=32
        )
        
        print(f"   Heatmap created: {heatmap.shape}")
        print("\n   ‚úÖ create_attention_heatmap() passed")
        
        # Visualize
        plt.figure(figsize=(8, 8))
        plt.imshow(heatmap, cmap='jet')
        plt.title(f'Attention Heatmap - {test_slide_id}')
        plt.colorbar()
        plt.axis('off')
        plt.savefig(DATA_CONFIG['output_dir'] / 'heatmap_test.png', dpi=150)
        plt.show()
        
        # Test process_slide
        test_svs_path = DATA_CONFIG['svs_dir'] / f"{test_slide_id}.svs"
        
        if test_svs_path.exists():
            print(f"\n[4-5/5] Testing visualize_results() and process_slide()...\n")
            
            try:
                results = inference.process_slide(
                    h5_path=test_h5_path,
                    svs_path=str(test_svs_path),
                    output_dir=str(DATA_CONFIG['results_dir']),
                    slide_id=test_slide_id,
                    true_label=test_metadata['label'].iloc[0] if 'label' in test_metadata.columns else None
                )
                
                print(f"   Results: {results}")
                print("\n   ‚úÖ process_slide() passed")
                
            except Exception as e:
                print(f"   ‚ö†Ô∏è  Error: {e}")
    else:
        print(f"\n   ‚ö†Ô∏è  HDF5 file not found")
else:
    print("\n   ‚ö†Ô∏è  Skipping inference tests - no preprocessed data")

print("\n" + "=" * 70)
print("‚úÖ inference.py TESTING COMPLETE")
print("=" * 70)

---
## 7. Test train_ddp.py

**Component tested:**
1. DDPTrainer structure

In [None]:
print("\n" + "=" * 70)
print("TESTING train_ddp.py")
print("=" * 70)

num_gpus = torch.cuda.device_count()

print(f"\n   Available GPUs: {num_gpus}")

if num_gpus >= 2:
    print(f"\n   ‚úÖ Multiple GPUs detected - DDP ready")
    for i in range(num_gpus):
        print(f"      GPU {i}: {torch.cuda.get_device_name(i)}")
elif num_gpus == 1:
    print(f"\n   ‚ÑπÔ∏è  Single GPU - use train.py")
else:
    print(f"\n   ‚ÑπÔ∏è  No GPU - DDP requires GPUs")

print("\n   ‚úÖ train_ddp.py structure verified")

print("\n" + "=" * 70)
print("‚úÖ train_ddp.py TESTING COMPLETE")
print("=" * 70)

---
## 8. Final Summary

In [None]:
print("\n" + "=" * 70)
print("COMPREHENSIVE TESTING COMPLETE")
print("=" * 70)

print("\nüìã Components Tested:")
print("   preprocessing.py: 5/5 ‚úÖ")
print("   model.py:         5/5 ‚úÖ")
print("   dataset.py:       4/4 ‚úÖ" if test_metadata is not None else "   dataset.py:       0/4 ‚ö†Ô∏è")
print("   train.py:         2/2 ‚úÖ" if test_metadata is not None else "   train.py:         0/2 ‚ö†Ô∏è")
print("   inference.py:     5/5 ‚úÖ" if test_metadata is not None else "   inference.py:     0/5 ‚ö†Ô∏è")
print("   train_ddp.py:     1/1 ‚úÖ")

if test_metadata is None:
    print("\n‚ö†Ô∏è  Some tests skipped - no preprocessed data")
    print("   Set RUN_FULL_PREPROCESSING=True to enable all tests")
else:
    print("\nüéâ ALL 26 COMPONENTS TESTED SUCCESSFULLY!")

print("\n" + "=" * 70)