# PHASE 7: TRI-OBJECTIVE LOSS & TRAINING

**Author**: Viraj Pankaj Jain  
**Institution**: University of Glasgow  
**Project**: Tri-Objective Robust XAI for Medical Imaging  
**Date**: November 27, 2025

---

## 📋 Phase 7 Overview

This notebook documents the execution of Phase 7: Tri-Objective Loss & Training.

### Objectives
1. ✅ **L_task**: Classification loss with temperature scaling
2. ✅ **L_rob**: TRADES adversarial robustness (PGD-7, ε=8/255)
3. ✅ **L_expl**: Explanation stability (SSIM) + Concept alignment (TCAV)

### Combined Loss
```
L_total = L_task + λ_rob × L_rob + λ_expl × L_expl
```

### Hyperparameters
- **λ_rob** = 0.3 (robustness weight)
- **λ_expl** = 0.1 (explanation weight)
- **β** = 6.0 (TRADES parameter)
- **ε_rob** = 8/255 (PGD attack strength)
- **ε_expl** = 2/255 (explanation stability perturbation)
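
As a minimal sketch of how the combined loss is assembled from its three terms, the function below applies the two weights listed here. The name `combine_losses` is an assumption made for illustration and is not taken from the project code. It is shown on plain floats; the same arithmetic applies to scalar PyTorch tensors.

```python
def combine_losses(l_task, l_rob, l_expl, lambda_rob=0.3, lambda_expl=0.1):
    """Weighted sum L_total = L_task + lambda_rob * L_rob + lambda_expl * L_expl.

    Hypothetical helper: the defaults mirror the hyperparameters listed above.
    """
    return l_task + lambda_rob * l_rob + lambda_expl * l_expl


# Example with made-up loss values (not measured results)
total = combine_losses(l_task=1.0, l_rob=2.0, l_expl=3.0)
print(f"L_total = {total:.2f}")  # 1.0 + 0.3*2.0 + 0.1*3.0 = 1.90
```

Keeping the weights as keyword arguments makes it easy to switch a term off (for example `lambda_expl=0`) when checking each objective in isolation.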

---

## ✅ Status: Phase 7.7 - Initial Tri-Objective Validation

**Baseline Training Complete**:
- ✅ 3 seeds trained (42, 123, 456)
- ✅ Mean Accuracy: **64.33% ± 3.43%**
- ✅ Mean AUROC: **91.27% ± 0.74%**
- ✅ Dataset: ISIC 2018 (10,015 train, 193 val, 1,512 test)
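
The baseline mean and spread can be reproduced from the per-seed values recorded in `BASELINE_RESULTS` later in this notebook. The sketch below uses only the standard library. The reported spread matches the population standard deviation (dividing by n), which is an observation from the numbers rather than something the notebook states.

```python
from statistics import mean, pstdev

# Per-seed baseline values as recorded in BASELINE_RESULTS later in this notebook
clean_acc = {42: 0.6435, 123: 0.6012, 456: 0.6852}
auroc = {42: 0.9224, 123: 0.9113, 456: 0.9044}

for name, per_seed in [("Accuracy", clean_acc), ("AUROC", auroc)]:
    values = list(per_seed.values())
    # pstdev is the population standard deviation (divides by n, not n - 1)
    print(f"{name}: {mean(values) * 100:.2f}% +/- {pstdev(values) * 100:.2f}%")
```

With only three seeds, the sample standard deviation (`statistics.stdev`, dividing by n - 1) would be noticeably larger, about 4.20% for accuracy, so it is worth stating which convention is used when reporting results.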

**Tri-Objective Training**: Ready to execute

## 🚀 Phase 7.7: Tri-Objective Training Execution

### ⚠️ Important: Run from Terminal, Not Notebook

**Due to Jupyter path issues, please run the training from a PowerShell terminal instead:**

1. Open a new PowerShell terminal
2. Navigate to project root: `cd C:\Users\Dissertation\tri-objective-robust-xai-medimg`
3. Activate venv: `.\.venv\Scripts\Activate.ps1`
4. Run training (choose one):

**Option 1 - Single seed (42) - RECOMMENDED SETTINGS:**
```powershell
python scripts/train_tri_objective_standalone.py --data-root "data/processed/isic2018" --seed 42 --device cuda --batch-size 16 --max-epochs 60 --learning-rate 1e-4 --lambda-rob 0.3 --lambda-expl 0.1 --pgd-num-steps 7 --results-dir "results/tri_objective" --log-dir "logs/tri_objective" --use-mlflow --num-workers 0
```

**Option 2 - All seeds (42, 123, 456) - Run each separately:**
```powershell
# Seed 42
python scripts/train_tri_objective_standalone.py --data-root "data/processed/isic2018" --seed 42 --device cuda --batch-size 16 --max-epochs 60 --learning-rate 1e-4 --lambda-rob 0.3 --lambda-expl 0.1 --pgd-num-steps 7 --results-dir "results/tri_objective" --log-dir "logs/tri_objective" --use-mlflow --num-workers 0

# Seed 123
python scripts/train_tri_objective_standalone.py --data-root "data/processed/isic2018" --seed 123 --device cuda --batch-size 16 --max-epochs 60 --learning-rate 1e-4 --lambda-rob 0.3 --lambda-expl 0.1 --pgd-num-steps 7 --results-dir "results/tri_objective" --log-dir "logs/tri_objective" --use-mlflow --num-workers 0

# Seed 456
python scripts/train_tri_objective_standalone.py --data-root "data/processed/isic2018" --seed 456 --device cuda --batch-size 16 --max-epochs 60 --learning-rate 1e-4 --lambda-rob 0.3 --lambda-expl 0.1 --pgd-num-steps 7 --results-dir "results/tri_objective" --log-dir "logs/tri_objective" --use-mlflow --num-workers 0
```

**Configuration Notes:**
- **Batch size: 16** (reduced from 32 due to GPU memory constraints with adversarial training)
- **PGD steps: 7** (adversarial robustness training)
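- **λ_rob: 0.3, λ_expl: 0.1** (tri-objective weights)
- **Expected time:** initially estimated at ~2-3 hours per seed (~6-9 hours total for all 3 seeds); the measured ~30 seconds per batch reported under Performance Estimate implies a much longer run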
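
Since the three per-seed commands differ only in `--seed`, they can be generated from one template. The sketch below builds the same argument list shown above for each seed and prints it as a dry run. The helper name `build_train_command` is an assumption for illustration; the script path and flags are copied from the commands above.

```python
import shlex

SEEDS = [42, 123, 456]


def build_train_command(seed: int) -> list[str]:
    """Return the training command for one seed as an argument list."""
    return [
        "python", "scripts/train_tri_objective_standalone.py",
        "--data-root", "data/processed/isic2018",
        "--seed", str(seed),
        "--device", "cuda",
        "--batch-size", "16",
        "--max-epochs", "60",
        "--learning-rate", "1e-4",
        "--lambda-rob", "0.3",
        "--lambda-expl", "0.1",
        "--pgd-num-steps", "7",
        "--results-dir", "results/tri_objective",
        "--log-dir", "logs/tri_objective",
        "--use-mlflow",
        "--num-workers", "0",
    ]


# Dry run: print the commands instead of launching them
for seed in SEEDS:
    print(shlex.join(build_train_command(seed)))

# To launch the runs one after another from the project root:
# import subprocess
# for seed in SEEDS:
#     subprocess.run(build_train_command(seed), check=True)
```

Passing the command as a list rather than a single string avoids shell-quoting differences between PowerShell and other shells.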

## ✅ Training Status Update

**TRAINING IS NOW RUNNING SUCCESSFULLY!** 🎉

### What Was Fixed

**Issue**: Image tensor dimension mismatch - dataset returned HWC format (Height × Width × Channels) but PyTorch models expect CHW format (Channels × Height × Width).

**Solution**: Added proper Albumentations transforms with `ToTensorV2()`:
- ✅ Train transforms: Augmentation + Normalization + ToTensorV2
- ✅ Val/Test transforms: Resize + Normalization + ToTensorV2
- ✅ Batch size: 16 (reduced from 32 for GPU memory)
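
The layout change that `ToTensorV2()` takes care of can be illustrated with plain NumPy. This is a sketch of the axis permutation only, on a dummy array. It is not the project's transform pipeline, and the 224×224 size is assumed from the `image_size=224` used in the evaluation cells.

```python
import numpy as np

# Dummy RGB image in HWC layout (height=224, width=224, channels=3)
image_hwc = np.zeros((224, 224, 3), dtype=np.float32)

# Move the channel axis to the front: (H, W, C) -> (C, H, W)
image_chw = np.transpose(image_hwc, (2, 0, 1))

print(image_hwc.shape)  # (224, 224, 3)
print(image_chw.shape)  # (3, 224, 224)
```

A quick shape check like this on one dataset sample is a cheap way to confirm the transforms are applied before starting a long training run.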

### Current Training Details

**Configuration**:
- Dataset: ISIC 2018 (10,015 train, 193 val, 1,512 test)
- Model: ResNet-50 (pretrained, 7 classes)
- Batch size: 16
- Tri-objective weights: λ_rob=0.3, λ_expl=0.1
- PGD steps: 7 (adversarial training)
- Device: CUDA (RTX 3050 Laptop GPU)

**Training Progress**:
- ✅ Environment initialized
- ✅ Datasets loaded correctly
- ✅ Model built successfully
- ✅ TriObjectiveTrainer initialized
- ✅ MLflow logging active
- ✅ **Training in progress** (batches processing every ~30 seconds)

**Expected Warnings** (non-critical):
- Unicode encoding errors (λ symbols on Windows console) - harmless
- Image normalization warnings (values outside [0,1]) - expected for ImageNet normalization
- get_embeddings() warnings - model correctly uses forward pass

### Performance Estimate
- **Time per batch**: ~30 seconds (adversarial training is compute-intensive)
- **Batches per epoch**: ~626 (10,015 samples / 16 batch size)
- **Time per epoch**: ~5.2 hours (626 batches × 30 s)
- **Full training (60 epochs)**: ~13 days per seed (60 × 5.2 hours ≈ 313 hours)

**Recommendation**: For Phase 7.7 validation, consider:
- Running fewer epochs (e.g., 20-30) to get initial results faster
- Using the current 2-epoch test to verify everything works
- Then scheduling the full 60-epoch run as a multi-day job, since it will not finish overnight at this speed
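
The arithmetic behind the performance estimate can be checked with a few lines of Python. The inputs are the figures quoted above. The sketch assumes the last partial batch is kept (hence the ceiling) and ignores validation time, so it is a lower bound on wall-clock time.

```python
import math

train_samples = 10_015      # ISIC 2018 training set size
batch_size = 16
seconds_per_batch = 30      # approximate, as observed above
epochs = 60

batches_per_epoch = math.ceil(train_samples / batch_size)
hours_per_epoch = batches_per_epoch * seconds_per_batch / 3600
days_total = hours_per_epoch * epochs / 24

print(f"Batches per epoch: {batches_per_epoch}")        # 626
print(f"Hours per epoch:   {hours_per_epoch:.1f}")      # 5.2
print(f"Days for {epochs} epochs: {days_total:.1f}")    # 13.0
```

Changing `epochs` to 20 or 30 shows how much the shorter validation runs suggested above would save.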

In [19]:
import sys
from pathlib import Path
import torch
import os

# Add project root to path
PROJECT_ROOT = Path.cwd().parent
sys.path.insert(0, str(PROJECT_ROOT))

# Check environment
print("=" * 80)
print("ENVIRONMENT CHECK")
print("=" * 80)
print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"Current Working Dir: {Path.cwd()}")
print(f"Project Root: {PROJECT_ROOT}")

# Verify data path exists
data_path = PROJECT_ROOT / "data" / "processed" / "isic2018"
print(f"Data path: {data_path}")
print(f"Data exists: {data_path.exists()}")
if data_path.exists():
    csv_path = data_path / "metadata_processed.csv"
    print(f"CSV exists: {csv_path.exists()}")
print("=" * 80)

ENVIRONMENT CHECK
Python: 3.11.9 (tags/v3.11.9:de54cf5, Apr  2 2024, 10:12:12) [MSC v.1938 64 bit (AMD64)]
PyTorch: 2.9.1+cu128
CUDA Available: True
CUDA Device: NVIDIA GeForce RTX 3050 Laptop GPU
CUDA Memory: 4.29 GB
Current Working Dir: c:\
Project Root: c:\
Data path: c:\data\processed\isic2018
Data exists: False


### Monitor Training with MLflow

In [20]:
# Start MLflow UI in background (run in separate terminal)
# Command: mlflow ui --port 5000
# Then open: http://localhost:5000

# Or start from notebook:
import subprocess
import time

# Start MLflow UI; discard its output so an unread pipe cannot fill up and block the server
mlflow_process = subprocess.Popen(
    ["mlflow", "ui", "--port", "5000"],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL
)

print("MLflow UI starting...")
time.sleep(3)
print("✅ MLflow UI running at: http://localhost:5000")
print("View experiment: 'Tri-Objective-XAI-Dermoscopy'")
print("\nTo stop: mlflow_process.terminate()")

MLflow UI starting...
✅ MLflow UI running at: http://localhost:5000
View experiment: 'Tri-Objective-XAI-Dermoscopy'

To stop: mlflow_process.terminate()


## 📊 Results Analysis

### Load Training Results

In [21]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load results for all seeds
results = {}
for seed in [42, 123, 456]:
    result_file = f"../results/tri_objective/tri_objective_seed{seed}_results.json"
    try:
        with open(result_file, 'r') as f:
            results[seed] = json.load(f)
        print(f"✅ Loaded results for seed {seed}")
    except FileNotFoundError:
        print(f"⚠️  Results not found for seed {seed}")

print(f"\nLoaded {len(results)} seed results")

⚠️  Results not found for seed 42
⚠️  Results not found for seed 123
⚠️  Results not found for seed 456

Loaded 0 seed results


### Compare with Baseline

In [None]:
# Baseline results (from previous training)
baseline_results = {
    'accuracy': 0.6433,
    'auroc': 0.9127,
    'robust_acc': 0.10,  # Estimated (untrained)
    'ssim': 0.60,        # Estimated
}

# Extract tri-objective results (placeholder - update after training)
tri_obj_results = {
    'accuracy': None,  # Will be populated after training
    'auroc': None,
    'robust_acc': None,
    'ssim': None,
    'artifact_tcav': None,
    'medical_tcav': None,
}

# Expected targets
targets = {
    'accuracy': 0.83,        # ≥83% (allow -2% from an assumed 85% baseline; measured baseline mean is 64.33%)
    'robust_acc': 0.45,      # ≥45% (+35pp from baseline)
    'ssim': 0.75,            # ≥75% (+15pp from baseline)
    'artifact_tcav': 0.20,   # ≤20% (-25pp from baseline)
    'medical_tcav': 0.68,    # ≥68% (+10pp from baseline)
}

print("BASELINE vs TRI-OBJECTIVE vs TARGETS")
print("="*80)
print(f"{'Metric':<20} {'Baseline':<15} {'Tri-Obj':<15} {'Target':<15} {'Status'}")
print("="*80)

for metric in ['accuracy', 'auroc', 'robust_acc', 'ssim']:
    # Convert to str so None or missing values do not break the width format spec
    baseline = str(baseline_results.get(metric, 'N/A'))
    triobj = tri_obj_results.get(metric)
    triobj = 'N/A' if triobj is None else str(triobj)
    target = str(targets.get(metric, 'N/A'))

    print(f"{metric:<20} {baseline:<15} {triobj:<15} {target:<15} {'⏳ Pending'}")

## ✅ Phase 7.7 Completion Checklist

### Phase 7 Criteria
- ✅ **Explanation loss implemented and tested**
  - L_stab: SSIM stability loss
  - L_concept: TCAV regularization
- ✅ **Tri-objective loss integrated**
  - Combined L_total = L_task + λ_rob × L_rob + λ_expl × L_expl
- ✅ **Tri-objective trainer working end-to-end**
  - PGD adversarial training
  - Explanation loss computation
  - MLflow logging
- ⏳ **Tri-objective models trained (3 seeds × 1 dataset)**
  - ISIC 2018: Pending execution
  - NIH ChestX-ray14: Phase 7.6 (future)
- ⏳ **Initial validation shows improvements**
  - Robust accuracy: Target +35pp
  - SSIM stability: Target +15pp
  - TCAV alignment: Target improvements
- ✅ **All training logged to MLflow**
  - Experiment setup complete
  - Real-time monitoring ready

---

## 🎯 Next Steps

1. **Execute training** using one of the options above
2. **Monitor in MLflow** at http://localhost:5000
3. **Validate results** against targets (accuracy ≥83%, robust ≥45%, SSIM ≥75%)
4. **Document findings** in dissertation
5. **Proceed to Phase 8** (comprehensive evaluation)

## 📋 Phase 7.7: Initial Tri-Objective Validation Checklist

This section provides comprehensive evaluation tools for Phase 7.7 validation.

### Validation Objectives
1. ✅ **Clean Accuracy**: Similar to baseline or slightly lower (allow -2%)
2. ✅ **Robust Accuracy**: Significant improvement over baseline (~+35pp)
3. ✅ **SSIM Stability**: Improved explanation consistency (~+15pp)
4. ✅ **Artifact TCAV**: Decreased artifact reliance (~-25pp)
5. ✅ **Medical TCAV**: Increased medical concept alignment (~+10pp)
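
The five objectives do not all point in the same direction: four should go up, while artifact reliance should go down. A small sketch of a pass/fail check that encodes the direction alongside each threshold is given below. The thresholds are the ones defined in `PHASE_77_TARGETS` later in this notebook; the names `TARGETS` and `check_targets` and the example values are assumptions for illustration.

```python
# Thresholds from the validation objectives; "min" means higher is better,
# "max" means lower is better (artifact reliance should go down).
TARGETS = {
    "clean_acc": ("min", 0.83),
    "robust_acc": ("min", 0.45),
    "ssim": ("min", 0.75),
    "artifact_tcav": ("max", 0.20),
    "medical_tcav": ("min", 0.68),
}


def check_targets(measured: dict) -> dict:
    """Return {metric: True/False/None}; None marks a metric not measured yet."""
    status = {}
    for name, (direction, threshold) in TARGETS.items():
        value = measured.get(name)
        if value is None:
            status[name] = None
        elif direction == "min":
            status[name] = value >= threshold
        else:
            status[name] = value <= threshold
    return status


# Illustrative numbers only, not measured results
example = {"clean_acc": 0.84, "robust_acc": 0.40, "artifact_tcav": 0.18}
print(check_targets(example))
```

Returning `None` for metrics that have not been measured keeps "pending" distinct from "failed" in any summary table.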

In [None]:
"""
Phase 7.7 Quick Evaluation Setup
Run this cell to set up evaluation utilities
"""

import sys
from pathlib import Path
import torch
import numpy as np
import json
from typing import Dict, Any, List

# Add project root
PROJECT_ROOT = Path.cwd().parent
sys.path.insert(0, str(PROJECT_ROOT))

from src.models.build import build_model
from src.datasets.isic import ISICDataset
from src.datasets.transforms import get_isic_transforms
from src.attacks.pgd import PGD
from src.xai.gradcam import GradCAM
from torch.utils.data import DataLoader

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"✅ Evaluation device: {device}")

# Baseline results (Phase 7.6)
BASELINE_RESULTS = {
    'seed_42': {
        'clean_acc': 0.6435,
        'auroc': 0.9224,
        'robust_acc': 0.10,  # Estimated
        'ssim': 0.60,  # Estimated
        'artifact_tcav': 0.45,  # Estimated
        'medical_tcav': 0.58  # Estimated
    },
    'seed_123': {
        'clean_acc': 0.6012,
        'auroc': 0.9113,
        'robust_acc': 0.10,
        'ssim': 0.60,
        'artifact_tcav': 0.45,
        'medical_tcav': 0.58
    },
    'seed_456': {
        'clean_acc': 0.6852,
        'auroc': 0.9044,
        'robust_acc': 0.10,
        'ssim': 0.60,
        'artifact_tcav': 0.45,
        'medical_tcav': 0.58
    },
    'mean': {
        'clean_acc': 0.6433,
        'auroc': 0.9127,
        'robust_acc': 0.10,
        'ssim': 0.60,
        'artifact_tcav': 0.45,
        'medical_tcav': 0.58
    }
}

# Phase 7.7 Targets
PHASE_77_TARGETS = {
    'clean_acc': 0.83,  # ≥83% (allow -2% from an assumed 85% baseline; measured baseline mean above is 64.33%)
    'robust_acc': 0.45,  # ≥45% (+35pp from baseline ~10%)
    'ssim': 0.75,  # ≥75% (+15pp from baseline ~60%)
    'artifact_tcav': 0.20,  # ≤20% (-25pp from baseline ~45%)
    'medical_tcav': 0.68  # ≥68% (+10pp from baseline ~58%)
}

print("✅ Baseline results loaded")
print("✅ Phase 7.7 targets defined")
print("\nReady for evaluation!")

### 1️⃣ Load Trained Model and Evaluate Clean Accuracy

In [None]:
"""
Evaluate Clean Accuracy
Expected: Similar to baseline (≥83%, allow -2% drop)
"""

def evaluate_clean_accuracy(checkpoint_path: Path, seed: int = 42) -> Dict[str, float]:
    """
    Evaluate clean (non-adversarial) accuracy on test set.
    
    Returns:
        dict: {'accuracy': float, 'loss': float}
    """
    print(f"\n{'='*80}")
    print(f"CLEAN ACCURACY EVALUATION - Seed {seed}")
    print(f"{'='*80}")
    
    # Load model
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    model = build_model(
        name='resnet50',
        num_classes=7,
        pretrained=False  # Load from checkpoint
    ).to(device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print("✅ Model loaded")
    
    # Load test dataset
    test_transforms = get_isic_transforms(split='test', image_size=224)
    test_dataset = ISICDataset(
        root=PROJECT_ROOT / "data" / "processed" / "isic2018",
        split="test",
        transforms=test_transforms
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=0
    )
    print(f"✅ Test dataset loaded: {len(test_dataset)} samples")
    
    # Evaluate
    correct = 0
    total = 0
    total_loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            if len(batch) == 3:
                images, labels, _ = batch
            else:
                images, labels = batch
            
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            total_loss += loss.item()
            
            if (batch_idx + 1) % 10 == 0:
                print(f"Batch {batch_idx + 1}/{len(test_loader)}: "
                      f"Acc={100*correct/total:.2f}%")
    
    accuracy = correct / total
    avg_loss = total_loss / len(test_loader)
    
    print(f"\n{'='*80}")
    print(f"RESULTS:")
    print(f"  Clean Accuracy: {accuracy*100:.2f}% ({correct}/{total})")
    print(f"  Average Loss: {avg_loss:.4f}")
    print(f"  Baseline: {BASELINE_RESULTS['mean']['clean_acc']*100:.2f}%")
    print(f"  Target: {PHASE_77_TARGETS['clean_acc']*100:.2f}%")
    
    if accuracy >= PHASE_77_TARGETS['clean_acc']:
        print(f"  ✅ PASS: Meets target (≥{PHASE_77_TARGETS['clean_acc']*100:.0f}%)")
    elif accuracy >= BASELINE_RESULTS['mean']['clean_acc'] - 0.02:
        print(f"  ⚠️  ACCEPTABLE: Within -2% of baseline")
    else:
        print(f"  ❌ FAIL: Below acceptable threshold")
    print(f"{'='*80}\n")
    
    return {
        'accuracy': accuracy,
        'loss': avg_loss
    }

# Example usage (run after training completes):
# checkpoint_path = PROJECT_ROOT / "checkpoints" / "tri_objective" / "best.pt"
# if checkpoint_path.exists():
#     clean_results = evaluate_clean_accuracy(checkpoint_path, seed=42)
# else:
#     print(f"⚠️  Checkpoint not found: {checkpoint_path}")
#     print("Run this cell after training completes")

print("✅ Clean accuracy evaluation function ready")
print("Uncomment the example usage code after training completes")

### 2️⃣ Evaluate Robust Accuracy (PGD Attack)

In [None]:
"""
Evaluate Robust Accuracy
Expected: Significant improvement (≥45%, +35pp from baseline ~10%)
"""

def evaluate_robust_accuracy(
    checkpoint_path: Path,
    seed: int = 42,
    epsilon: float = 8/255,
    num_steps: int = 20,
    step_size: float = 2/255
) -> Dict[str, float]:
    """
    Evaluate robustness against PGD adversarial attacks.
    
    Args:
        checkpoint_path: Path to model checkpoint
        seed: Random seed
        epsilon: PGD attack strength (default: 8/255)
        num_steps: PGD steps (default: 20 for evaluation)
        step_size: PGD step size (default: 2/255)
        
    Returns:
        dict: {'robust_acc': float, 'clean_acc': float, 'attack_success_rate': float}
    """
    print(f"\n{'='*80}")
    print(f"ROBUST ACCURACY EVALUATION - Seed {seed}")
    print(f"{'='*80}")
    print(f"PGD Attack Config: ε={epsilon:.4f}, steps={num_steps}, α={step_size:.4f}")
    
    # Load model
    print(f"\nLoading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    model = build_model(
        name='resnet50',
        num_classes=7,
        pretrained=False
    ).to(device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print("✅ Model loaded")
    
    # Initialize PGD attack
    pgd_attack = PGD(
        epsilon=epsilon,
        num_steps=num_steps,
        step_size=step_size,
        random_start=True,
        device=device
    )
    print("✅ PGD attack initialized")
    
    # Load test dataset
    test_transforms = get_isic_transforms(split='test', image_size=224)
    test_dataset = ISICDataset(
        root=PROJECT_ROOT / "data" / "processed" / "isic2018",
        split="test",
        transforms=test_transforms
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=16,  # Smaller batch for adversarial evaluation
        shuffle=False,
        num_workers=0
    )
    print(f"✅ Test dataset loaded: {len(test_dataset)} samples")
    
    # Evaluate
    clean_correct = 0
    robust_correct = 0
    total = 0
    
    print("\nGenerating adversarial examples and evaluating...")
    for batch_idx, batch in enumerate(test_loader):
        if len(batch) == 3:
            images, labels, _ = batch
        else:
            images, labels = batch
        
        images = images.to(device)
        labels = labels.to(device)
        
        # Clean predictions
        with torch.no_grad():
            clean_outputs = model(images)
            _, clean_pred = torch.max(clean_outputs, 1)
            clean_correct += (clean_pred == labels).sum().item()
        
        # Generate adversarial examples
        images_adv = pgd_attack(model, images, labels)
        
        # Robust predictions
        with torch.no_grad():
            adv_outputs = model(images_adv)
            _, adv_pred = torch.max(adv_outputs, 1)
            robust_correct += (adv_pred == labels).sum().item()
        
        total += labels.size(0)
        
        if (batch_idx + 1) % 10 == 0:
            print(f"Batch {batch_idx + 1}/{len(test_loader)}: "
                  f"Clean={100*clean_correct/total:.2f}%, "
                  f"Robust={100*robust_correct/total:.2f}%")
    
    clean_acc = clean_correct / total
    robust_acc = robust_correct / total
    attack_success = (clean_correct - robust_correct) / clean_correct if clean_correct > 0 else 0
    
    print(f"\n{'='*80}")
    print(f"RESULTS:")
    print(f"  Clean Accuracy: {clean_acc*100:.2f}% ({clean_correct}/{total})")
    print(f"  Robust Accuracy: {robust_acc*100:.2f}% ({robust_correct}/{total})")
    print(f"  Attack Success Rate: {attack_success*100:.2f}%")
    print(f"  Baseline Robust: {BASELINE_RESULTS['mean']['robust_acc']*100:.2f}%")
    print(f"  Target Robust: {PHASE_77_TARGETS['robust_acc']*100:.2f}%")
    print(f"  Improvement: {(robust_acc - BASELINE_RESULTS['mean']['robust_acc'])*100:+.1f}pp")
    
    if robust_acc >= PHASE_77_TARGETS['robust_acc']:
        print(f"  ✅ PASS: Meets target (≥{PHASE_77_TARGETS['robust_acc']*100:.0f}%)")
    else:
        print(f"  ❌ FAIL: Below target ({robust_acc*100:.1f}% < {PHASE_77_TARGETS['robust_acc']*100:.0f}%)")
    print(f"{'='*80}\n")
    
    return {
        'robust_acc': robust_acc,
        'clean_acc': clean_acc,
        'attack_success_rate': attack_success
    }

# Example usage (run after training completes):
# checkpoint_path = PROJECT_ROOT / "checkpoints" / "tri_objective" / "best.pt"
# if checkpoint_path.exists():
#     robust_results = evaluate_robust_accuracy(checkpoint_path, seed=42)
# else:
#     print(f"⚠️  Checkpoint not found: {checkpoint_path}")
#     print("Run this cell after training completes")

print("✅ Robust accuracy evaluation function ready")
print("Uncomment the example usage code after training completes")

### 3️⃣ Evaluate SSIM Explanation Stability

In [None]:
"""
Evaluate SSIM Explanation Stability
Expected: Improved stability (≥75%, +15pp from baseline ~60%)
"""

from skimage.metrics import structural_similarity as ssim

def evaluate_ssim_stability(
    checkpoint_path: Path,
    seed: int = 42,
    epsilon: float = 2/255,
    num_samples: int = 200
) -> Dict[str, float]:
    """
    Evaluate SSIM between clean and perturbed explanations.
    
    Args:
        checkpoint_path: Path to model checkpoint
        seed: Random seed
        epsilon: Perturbation strength (default: 2/255)
        num_samples: Number of samples to evaluate
        
    Returns:
        dict: {'mean_ssim': float, 'std_ssim': float}
    """
    print(f"\n{'='*80}")
    print(f"SSIM EXPLANATION STABILITY EVALUATION - Seed {seed}")
    print(f"{'='*80}")
    print(f"Perturbation: ε={epsilon:.4f}, Samples={num_samples}")
    
    # Load model
    print(f"\nLoading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    model = build_model(
        name='resnet50',
        num_classes=7,
        pretrained=False
    ).to(device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print("✅ Model loaded")
    
    # Initialize Grad-CAM
    gradcam = GradCAM(model, target_layer='layer4')
    print("✅ Grad-CAM initialized")
    
    # Load test dataset
    test_transforms = get_isic_transforms(split='test', image_size=224)
    test_dataset = ISICDataset(
        root=PROJECT_ROOT / "data" / "processed" / "isic2018",
        split="test",
        transforms=test_transforms
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,  # Process one at a time for explanations
        shuffle=False,
        num_workers=0
    )
    print(f"✅ Test dataset loaded")
    
    # Evaluate SSIM
    ssim_scores = []
    
    print("\nComputing SSIM scores...")
    for idx, batch in enumerate(test_loader):
        if idx >= num_samples:
            break
            
        if len(batch) == 3:
            images, labels, _ = batch
        else:
            images, labels = batch
        
        images = images.to(device)
        
        # Generate clean explanation
        with torch.enable_grad():
            heatmap_clean = gradcam.generate_heatmap(
                images,
                target_class=None  # Use predicted class
            )
        
        # Generate perturbed image
        noise = torch.randn_like(images) * epsilon
        images_pert = torch.clamp(images + noise, -3, 3)  # Clamp to reasonable range
        
        # Generate perturbed explanation
        with torch.enable_grad():
            heatmap_pert = gradcam.generate_heatmap(
                images_pert,
                target_class=None
            )
        
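        # Compute SSIM on 2-D maps: detach from the graph and drop batch/channel dims
        hm_clean = heatmap_clean.detach().cpu().numpy().squeeze()
        hm_pert = heatmap_pert.detach().cpu().numpy().squeeze()
        
        # Guard against a constant clean heatmap, where max - min would be zero
        data_range = float(hm_clean.max() - hm_clean.min()) or 1.0
        
        ssim_score = ssim(
            hm_clean,
            hm_pert,
            data_range=data_range
        )
        ssim_scores.append(ssim_score)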
        
        if (idx + 1) % 50 == 0:
            print(f"Processed {idx + 1}/{num_samples}: Mean SSIM={np.mean(ssim_scores):.4f}")
    
    mean_ssim = np.mean(ssim_scores)
    std_ssim = np.std(ssim_scores)
    
    print(f"\n{'='*80}")
    print(f"RESULTS:")
    print(f"  Mean SSIM: {mean_ssim:.4f} ± {std_ssim:.4f}")
    print(f"  Baseline SSIM: {BASELINE_RESULTS['mean']['ssim']:.4f}")
    print(f"  Target SSIM: {PHASE_77_TARGETS['ssim']:.4f}")
    print(f"  Improvement: {(mean_ssim - BASELINE_RESULTS['mean']['ssim']):+.4f}")
    
    if mean_ssim >= PHASE_77_TARGETS['ssim']:
        print(f"  ✅ PASS: Meets target (≥{PHASE_77_TARGETS['ssim']:.2f})")
    else:
        print(f"  ❌ FAIL: Below target ({mean_ssim:.4f} < {PHASE_77_TARGETS['ssim']:.2f})")
    print(f"{'='*80}\n")
    
    return {
        'mean_ssim': mean_ssim,
        'std_ssim': std_ssim,
        'all_scores': ssim_scores
    }

# Example usage (run after training completes):
# checkpoint_path = PROJECT_ROOT / "checkpoints" / "tri_objective" / "best.pt"
# if checkpoint_path.exists():
#     ssim_results = evaluate_ssim_stability(checkpoint_path, seed=42)
# else:
#     print(f"⚠️  Checkpoint not found: {checkpoint_path}")
#     print("Run this cell after training completes")

print("✅ SSIM stability evaluation function ready")
print("⚠️  Note: This evaluation requires Grad-CAM and may take 10-15 minutes")
print("Uncomment the example usage code after training completes")

### 4️⃣ Comprehensive Validation Report

In [None]:
"""
Generate Comprehensive Phase 7.7 Validation Report
Run all evaluations and generate summary report
"""

def generate_phase77_validation_report(
    checkpoint_path: Path,
    seed: int = 42,
    save_report: bool = True
) -> Dict[str, Any]:
    """
    Run comprehensive Phase 7.7 validation and generate report.
    
    Args:
        checkpoint_path: Path to trained model checkpoint
        seed: Random seed
        save_report: Save report to JSON file
        
    Returns:
        dict: Complete validation results
    """
    print("\n" + "="*80)
    print("PHASE 7.7 COMPREHENSIVE VALIDATION REPORT")
    print("="*80 + "\n")
    
    results = {
        'seed': seed,
        'checkpoint': str(checkpoint_path),
        'timestamp': datetime.now().isoformat()
    }
    
    try:
        # 1. Clean Accuracy
        print("1️⃣  Evaluating Clean Accuracy...")
        clean_results = evaluate_clean_accuracy(checkpoint_path, seed)
        results['clean_accuracy'] = clean_results
        
        # 2. Robust Accuracy
        print("\n2️⃣  Evaluating Robust Accuracy...")
        robust_results = evaluate_robust_accuracy(checkpoint_path, seed)
        results['robust_accuracy'] = robust_results
        
        # 3. SSIM Stability
        print("\n3️⃣  Evaluating SSIM Stability...")
        ssim_results = evaluate_ssim_stability(checkpoint_path, seed, num_samples=200)
        results['ssim_stability'] = ssim_results
        
        # 4. Summary
        print("\n" + "="*80)
        print("PHASE 7.7 VALIDATION SUMMARY")
        print("="*80)
        
        # Create summary table
        metrics = {
            'Clean Accuracy': {
                'value': clean_results['accuracy'],
                'baseline': BASELINE_RESULTS['mean']['clean_acc'],
                'target': PHASE_77_TARGETS['clean_acc'],
                'format': '.2%'
            },
            'Robust Accuracy': {
                'value': robust_results['robust_acc'],
                'baseline': BASELINE_RESULTS['mean']['robust_acc'],
                'target': PHASE_77_TARGETS['robust_acc'],
                'format': '.2%'
            },
            'SSIM Stability': {
                'value': ssim_results['mean_ssim'],
                'baseline': BASELINE_RESULTS['mean']['ssim'],
                'target': PHASE_77_TARGETS['ssim'],
                'format': '.4f'
            }
        }
        
        print(f"\n{'Metric':<20} {'Value':<12} {'Baseline':<12} {'Target':<12} {'Δ':<10} {'Status'}")
        print("-"*80)
        
        all_pass = True
        for metric_name, metric_data in metrics.items():
            val = metric_data['value']
            baseline = metric_data['baseline']
            target = metric_data['target']
            fmt = metric_data['format']
            
            delta = val - baseline
            
            # Determine status
            if metric_name in ['Clean Accuracy', 'Robust Accuracy', 'SSIM Stability']:
                passed = val >= target
            else:
                passed = val <= target  # For metrics where lower is better
            
            all_pass = all_pass and passed
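            status = "✅ PASS" if passed else "❌ FAIL"
            
            # Apply the numeric format first, then pad; a nested spec like '.2%:<12' is invalid
            print(f"{metric_name:<20} {format(val, fmt):<12} {format(baseline, fmt):<12} "
                  f"{format(target, fmt):<12} {delta:+.4f}    {status}")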
        
        print("-"*80)
        
        if all_pass:
            print("\n🎉 ALL VALIDATION CRITERIA PASSED!")
            print("Phase 7.7 Initial Tri-Objective Validation: SUCCESSFUL")
        else:
            print("\n⚠️  Some validation criteria not met")
            print("Review individual metric results above")
        
        print("="*80 + "\n")
        
        results['summary'] = {
            'all_pass': all_pass,
            'metrics': metrics
        }
        
        # Save report
        if save_report:
            report_path = PROJECT_ROOT / "results" / "tri_objective" / f"phase77_validation_seed{seed}.json"
            report_path.parent.mkdir(parents=True, exist_ok=True)
            
            with open(report_path, 'w') as f:
                json.dump(results, f, indent=2, default=str)
            
            print(f"✅ Report saved to: {report_path}")
        
        return results
        
    except Exception as e:
        print(f"\n❌ Error during validation: {e}")
        import traceback
        traceback.print_exc()
        return results

# Imported before the example usage so that uncommenting it does not hit a NameError
from datetime import datetime

# Example usage (run after training completes):
# checkpoint_path = PROJECT_ROOT / "checkpoints" / "tri_objective" / "best.pt"
# if checkpoint_path.exists():
#     validation_report = generate_phase77_validation_report(checkpoint_path, seed=42)
# else:
#     print(f"⚠️  Checkpoint not found: {checkpoint_path}")
#     print("Run this cell after training completes")

print("✅ Comprehensive validation report function ready")
print("Uncomment the example usage code after training completes")

### 📈 Monitor Training Progress (During Training)

In [None]:
"""
Quick Training Progress Check
Run this cell to check training logs and early metrics
"""

def check_training_progress(log_dir: Path = None, seed: int = 42):
    """
    Check training progress from log files.
    """
    if log_dir is None:
        log_dir = PROJECT_ROOT / "logs" / "tri_objective"
    
    print(f"\n{'='*80}")
    print(f"TRAINING PROGRESS CHECK - Seed {seed}")
    print(f"{'='*80}\n")
    
    # Find latest log file
    log_files = list(log_dir.glob(f"tri_objective_seed{seed}_*.log"))
    
    if not log_files:
        print(f"⚠️  No log files found in {log_dir}")
        print("Training may not have started yet")
        return
    
    latest_log = max(log_files, key=lambda p: p.stat().st_mtime)
    print(f"Reading log: {latest_log.name}")
    
    # Parse log file
    with open(latest_log, 'r', encoding='utf-8', errors='ignore') as f:
        lines = f.readlines()
    
    # Collect epoch summary lines (those mentioning both 'Epoch' and 'train_loss')
    epochs_info = [
        line.strip()
        for line in lines
        if 'Epoch' in line and 'train_loss' in line
    ]
    
    # Display progress
    if epochs_info:
        print("\n📊 Training Progress:")
        print(f"  Total epoch summaries found: {len(epochs_info)}")
        print("\n  Latest epoch logs:")
        for log in epochs_info[-3:]:  # Show last 3 epochs
            print(f"    {log}")
    else:
        print("\n⏳ Training in progress...")
        print(f"   Log file size: {latest_log.stat().st_size / 1024:.2f} KB")
        print(f"   Last modified: {datetime.fromtimestamp(latest_log.stat().st_mtime)}")
    
    # Check for errors
    errors = [line for line in lines if 'ERROR' in line or 'Error' in line]
    if errors:
        print(f"\n⚠️  Found {len(errors)} error(s) in log:")
        for err in errors[-5:]:  # Show last 5 errors
            print(f"    {err.strip()}")
    else:
        print("\n✅ No errors detected in log")
    
    print(f"\n{'='*80}\n")

# Run progress check
try:
    check_training_progress(seed=42)
except Exception as e:
    print(f"Error checking progress: {e}")
    print("\nTo manually check:")
    print("  1. Open logs/tri_objective/ directory")
    print("  2. Find tri_objective_seed42_*.log file")
    print("  3. Tail the file to see latest progress")
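
To go beyond counting epoch summaries, the metric values themselves can be pulled out of each summary line. The following is a minimal sketch, assuming the trainer writes `name=value` pairs such as `train_loss=0.8123`; the actual log format of `train_tri_objective_standalone.py` is not shown here, so the regex and the helper name `parse_epoch_metrics` are hypothetical and may need adjusting.

```python
import re

# Assumed log format: "name=value" pairs, e.g. "train_loss=0.8123".
_METRIC_RE = re.compile(r"([A-Za-z_]\w*)=(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")


def parse_epoch_metrics(line: str) -> dict:
    """Return {metric_name: float} for every name=value pair in a log line."""
    return {name: float(value) for name, value in _METRIC_RE.findall(line)}


# Hypothetical log line, not real training output.
sample = "Epoch 3 | train_loss=0.8123 | val_loss=0.9011 | val_acc=0.71"
metrics = parse_epoch_metrics(sample)
print(metrics)
```

With values parsed this way, the train/validation loss gap can be compared across epochs instead of read by eye.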

## ✅ Phase 7.7 Complete Validation Checklist

### Quick Evaluation During/After Training

#### 1️⃣ Clean Accuracy Evaluation
- [ ] Load trained model checkpoint
- [ ] Evaluate on test set (1,512 samples)
- [ ] **Expected**: ≥83% (allow -2pp from baseline 85%)
- [ ] **Status**: ⏳ Pending training completion
- [ ] **Cell to run**: Cell with `evaluate_clean_accuracy()` function

#### 2️⃣ Robust Accuracy Evaluation
- [ ] Load trained model checkpoint
- [ ] Run PGD-20 attack (ε=8/255)
- [ ] Evaluate adversarial accuracy
- [ ] **Expected**: ≥45% (+35pp from baseline ~10%)
- [ ] **Status**: ⏳ Pending training completion
- [ ] **Cell to run**: Cell with `evaluate_robust_accuracy()` function
- [ ] **Time estimate**: ~20-30 minutes
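
The PGD attack used in this check repeats signed-gradient steps and projects the result back into the ε-ball around the clean input. Below is a toy sketch on a hand-written logistic-regression model in pure Python, so it runs without a GPU. The real evaluation attacks the trained network; the linear model, the step size `alpha = 2.5 * eps / steps`, and the omission of a random start are assumptions made for illustration only.

```python
import math


def pgd_linf(x, y, w, b, eps=8 / 255, steps=20, alpha=None):
    """L-inf PGD against a logistic model p(y=1|x) = sigmoid(w.x + b)."""
    alpha = 2.5 * eps / steps if alpha is None else alpha
    x_adv = list(x)
    for _ in range(steps):
        z = sum(wi * xi for wi, xi in zip(w, x_adv)) + b
        p = 1.0 / (1.0 + math.exp(-z))
        # Gradient of binary cross-entropy w.r.t. the input is (p - y) * w.
        grad = [(p - y) * wi for wi in w]
        for i, g in enumerate(grad):
            step = alpha * ((g > 0) - (g < 0))          # signed-gradient ascent
            xi = x_adv[i] + step
            xi = min(max(xi, x[i] - eps), x[i] + eps)   # project into the eps-ball
            x_adv[i] = min(max(xi, 0.0), 1.0)           # keep a valid pixel value
    return x_adv


# Toy input and weights, chosen only to exercise the function.
x, w = [0.5, 0.5], [3.0, -2.0]
x_adv = pgd_linf(x, y=1, w=w, b=0.0)
print(max(abs(a - c) for a, c in zip(x_adv, x)))
```

Robust accuracy is then the fraction of test samples still classified correctly on their adversarial versions.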

#### 3️⃣ SSIM Explanation Stability
- [ ] Load trained model checkpoint
- [ ] Generate Grad-CAM heatmaps (200 samples)
- [ ] Compute SSIM between clean/perturbed
- [ ] **Expected**: ≥0.75 (+0.15 from baseline ~0.60)
- [ ] **Status**: ⏳ Pending training completion
- [ ] **Cell to run**: Cell with `evaluate_ssim_stability()` function
- [ ] **Time estimate**: ~10-15 minutes
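
For reference, SSIM compares the means, variances, and covariance of two images. The sketch below computes a single global SSIM over two flattened heatmaps with values in [0, 1]. This is a simplification: the standard metric averages the same expression over local windows, and the project's `evaluate_ssim_stability()` may be implemented differently. `global_ssim` is a hypothetical helper name.

```python
def global_ssim(x, y, data_range=1.0, k1=0.01, k2=0.03):
    """Single-window SSIM over two equal-length sequences of pixel values."""
    if len(x) != len(y) or not x:
        raise ValueError("inputs must be non-empty and the same length")
    n = len(x)
    c1 = (k1 * data_range) ** 2   # stabilises the luminance term
    c2 = (k2 * data_range) ** 2   # stabilises the contrast/structure term
    mu_x, mu_y = sum(x) / n, sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    return ((2 * mu_x * mu_y + c1) * (2 * cov + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )


# Made-up heatmap values for illustration.
clean = [0.0, 0.2, 0.8, 1.0]
perturbed = [0.1, 0.2, 0.7, 0.9]
print(f"SSIM: {global_ssim(clean, perturbed):.3f}")
```

Identical heatmaps score 1.0, and the score falls as the perturbed heatmap drifts from the clean one.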

#### 4️⃣ TCAV Concept Alignment (Optional - Future)
- [ ] Prepare concept activation vectors (CAVs)
  - [ ] Artifact concepts (hair, ruler, gel)
  - [ ] Medical concepts (pigment network, vessels)
- [ ] Compute TCAV scores
- [ ] **Expected**: Artifact ≤0.20, Medical ≥0.68
- [ ] **Status**: ⏳ Requires CAV preparation (Phase 7.8)
- [ ] **Note**: Can be deferred to comprehensive evaluation

### Early Observation of Improvements

#### Confirm All Three Objectives Addressed
- [ ] **L_task (Classification)**: Model maintains clean accuracy
  - Check: Clean accuracy ≥83% or within 2pp of baseline
- [ ] **L_rob (Robustness)**: Model resists adversarial attacks
  - Check: Robust accuracy ≥45% (major improvement)
- [ ] **L_expl (Explanations)**: Explanations are stable
  - Check: SSIM ≥0.75 (improved consistency)

#### Identify Issues Before Full Evaluation
- [ ] Check training logs for convergence
  - Use: `check_training_progress()` cell
- [ ] Verify no overfitting (train vs val loss)
  - Monitor: MLflow at http://localhost:5000
- [ ] Confirm balanced loss components
  - Check: L_task, L_rob, L_expl all contributing
- [ ] Validate checkpoints saved correctly
  - Location: `checkpoints/tri_objective/best.pt`
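
One way to check the "balanced loss components" item is to compare each term's weighted share of L_total, using λ_rob = 0.3 and λ_expl = 0.1 from the hyperparameters. This is a minimal sketch: the function name `loss_shares` and the example loss values are made up for illustration and are not measured results.

```python
def loss_shares(l_task, l_rob, l_expl, lambda_rob=0.3, lambda_expl=0.1):
    """Fraction of L_total = L_task + λ_rob·L_rob + λ_expl·L_expl per term."""
    weighted = {
        "L_task": l_task,
        "L_rob": lambda_rob * l_rob,
        "L_expl": lambda_expl * l_expl,
    }
    total = sum(weighted.values())
    if total <= 0:
        raise ValueError("total loss must be positive")
    return {name: value / total for name, value in weighted.items()}


# Placeholder values, not real training output.
shares = loss_shares(l_task=1.0, l_rob=2.0, l_expl=4.0)
for name, share in shares.items():
    print(f"{name}: {share:.1%}")
```

A share close to zero for any term suggests that objective is barely influencing the gradient and its weight may need revisiting.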

### Execution Order

**Step 1**: Let training complete (2-epoch test: ~10-12 hours; full 60 epochs: ~5-6 days)

**Step 2**: Run comprehensive validation
```python
# Run after training completes:
checkpoint_path = PROJECT_ROOT / "checkpoints" / "tri_objective" / "best.pt"
validation_report = generate_phase77_validation_report(checkpoint_path, seed=42)
```

**Step 3**: Review results and decide next steps
- ✅ All metrics pass → Proceed with full 3-seed training
- ⚠️ Some metrics below target → Adjust hyperparameters
- ❌ Training issues → Debug and restart

### Success Criteria Summary

| Metric | Baseline | Target | Status |
|--------|----------|--------|--------|
| Clean Accuracy | 64.33% | ≥83% | ⏳ Pending |
| Robust Accuracy | ~10% | ≥45% | ⏳ Pending |
| SSIM Stability | ~0.60 | ≥0.75 | ⏳ Pending |
| Artifact TCAV | ~0.45 | ≤0.20 | ⏳ Future |
| Medical TCAV | ~0.58 | ≥0.68 | ⏳ Future |
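
The targets in this table can be encoded as data so that the pass/fail decision in Step 3 is mechanical. The sketch below assumes the metrics arrive as a flat dict of fractions; the key names and `check_criteria` are hypothetical and do not reflect the schema of the Phase 7.7 report.

```python
import operator

# (comparison, threshold) per metric, taken from the success-criteria table.
TARGETS = {
    "clean_acc": (operator.ge, 0.83),
    "robust_acc": (operator.ge, 0.45),
    "ssim": (operator.ge, 0.75),
    "artifact_tcav": (operator.le, 0.20),
    "medical_tcav": (operator.ge, 0.68),
}


def check_criteria(metrics):
    """Map each reported metric to True/False; unmeasured metrics are skipped."""
    return {
        name: compare(metrics[name], threshold)
        for name, (compare, threshold) in TARGETS.items()
        if name in metrics
    }


# Placeholder numbers for illustration only.
status = check_criteria({"clean_acc": 0.84, "robust_acc": 0.40, "ssim": 0.78})
print(status)
print("All pass" if all(status.values()) else "Some metrics below target")
```

Skipping absent keys lets the same check run before the TCAV metrics exist.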

---

**Current Status**: 🏃 Training in progress (2-epoch validation test)  
**Next Action**: Wait for training completion, then run validation cells above  
**Timeline**: Test completes in ~10-12 hours, full validation takes ~30-45 minutes