# 🛡️ Phase 4: Adversarial Robustness Evaluation

## Tri-Objective Robust XAI for Medical Imaging

**Author:** Viraj Pankaj Jain  
**Institution:** University of Glasgow  
**Date:** November 2025

---

## 📋 Overview

This notebook evaluates the **adversarial robustness** of baseline ResNet-50 models trained on ISIC 2018 dermoscopy images. We systematically assess vulnerability to:

| Attack | Type | Characteristics | Use Case |
|--------|------|-----------------|----------|
| **FGSM** | Gradient-based | Fast, single-step | Real-time threat assessment |
| **PGD** | Iterative | Strong, multi-step | Reliable robustness benchmark |
| **C&W** | Optimization | Minimal-norm (L2) perturbation, typically the strongest of the three | Worst-case security analysis |

## 🎯 Research Questions Addressed

- **RQ1:** How vulnerable are standard CNNs to adversarial attacks in medical imaging?
- **RQ2:** How does attack strength (ε) affect model accuracy degradation?
- **RQ3:** Which skin lesion classes are most vulnerable to adversarial perturbations?

## 📊 Evaluation Protocol

```
┌─────────────────────────────────────────────────────────────────┐
│  ADVERSARIAL EVALUATION PIPELINE                                │
├─────────────────────────────────────────────────────────────────┤
│  1. Load trained baseline models (Seeds: 42, 123, 456)          │
│  2. Evaluate clean accuracy (sanity check)                      │
│  3. Generate adversarial examples at ε ∈ {2/255, 4/255, 8/255}  │
│  4. Measure robust accuracy under each attack                   │
│  5. Analyze per-class vulnerability                             │
│  6. Visualize perturbations and decision boundaries             │
│  7. Statistical analysis across seeds                           │
└─────────────────────────────────────────────────────────────────┘
```
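
Step 4 of the pipeline reduces to simple bookkeeping: robust accuracy is the share of adversarial inputs that are still classified correctly, and the accuracy drop is its gap to clean accuracy. The sketch below illustrates this with made-up labels and predictions; `robustness_summary` is a hypothetical helper for illustration, not a function from the project repository.

```python
def robustness_summary(labels, clean_preds, adv_preds):
    """Return clean accuracy, robust accuracy and their gap, all in percent."""
    if not labels or not (len(labels) == len(clean_preds) == len(adv_preds)):
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(labels)
    # Fraction of samples whose prediction matches the true label
    clean_acc = 100.0 * sum(p == y for p, y in zip(clean_preds, labels)) / n
    robust_acc = 100.0 * sum(p == y for p, y in zip(adv_preds, labels)) / n
    return {
        "clean_accuracy": clean_acc,
        "robust_accuracy": robust_acc,
        "accuracy_drop": clean_acc - robust_acc,  # in percentage points
    }


# Made-up example: 5 samples, 3 classes
labels = [0, 1, 2, 2, 1]
clean_preds = [0, 1, 2, 0, 1]   # 4 of 5 correct
adv_preds = [0, 2, 2, 0, 0]     # 2 of 5 correct after the attack
summary = robustness_summary(labels, clean_preds, adv_preds)
print(summary)
```

Note that the drop is measured in percentage points of accuracy, not as a relative change.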

## ⚡ Hardware Requirements

- **Recommended:** NVIDIA A100 (40GB) - Full evaluation ~15 minutes
- **Minimum:** NVIDIA T4 (16GB) - Full evaluation ~45 minutes

---

In [None]:
# ============================================================================
# CELL 1: ENVIRONMENT SETUP
# ============================================================================
import sys
import os
from pathlib import Path

print("=" * 70)
print("PHASE 4: ADVERSARIAL ROBUSTNESS EVALUATION")
print("=" * 70)

# Mount Google Drive
try:
    from google.colab import drive
    drive.mount('/content/drive')
    IN_COLAB = True
    print("✅ Google Colab detected, Drive mounted")
except ImportError:
    IN_COLAB = False
    print("✅ Local environment detected")

# Clone/update repository
if IN_COLAB:
    REPO_PATH = Path('/content/tri-objective-robust-xai-medimg')
    if not REPO_PATH.exists():
        !git clone https://github.com/viraj1011JAIN/tri-objective-robust-xai-medimg.git {REPO_PATH}
        print("✅ Repository cloned")
    else:
        os.chdir(REPO_PATH)
        !git pull origin main
        print("✅ Repository updated")

    os.chdir(REPO_PATH)
    sys.path.insert(0, str(REPO_PATH))
    PROJECT_ROOT = REPO_PATH
else:
    PROJECT_ROOT = Path.cwd().parent
    sys.path.insert(0, str(PROJECT_ROOT))

print(f"📁 Project root: {PROJECT_ROOT}")

In [None]:
# ============================================================================
# CELL 2: INSTALL DEPENDENCIES
# ============================================================================
!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu121
!pip install -q timm albumentations scikit-learn pandas matplotlib seaborn tqdm mlflow
!pip install -q plotly kaleido scipy statsmodels
print("✅ Dependencies installed")

In [None]:
# ============================================================================
# CELL 3: IMPORTS
# ============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm.auto import tqdm
from pathlib import Path
import json
import time
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
import warnings
warnings.filterwarnings('ignore')

# Albumentations for transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

# Metrics
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, f1_score,
    confusion_matrix, roc_auc_score
)
from scipy import stats

# Project imports - Attacks
from src.attacks.fgsm import FGSM, FGSMConfig
from src.attacks.pgd import PGD, PGDConfig
from src.attacks.cw import CarliniWagner, CWConfig

# Project imports - Data & Model
from src.datasets.isic import ISICDataset
from src.models.build import build_model
from src.utils.reproducibility import set_global_seed

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"✅ Using device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    # Enable TF32 for A100
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

In [None]:
# ============================================================================
# CELL 4: CONFIGURATION
# ============================================================================
print("=" * 70)
print("CONFIGURATION")
print("=" * 70)

CONFIG = {
    # Data paths (Google Drive)
    'data_root': Path('/content/drive/MyDrive/data/data/isic_2018'),
    'checkpoint_dir': Path('/content/drive/MyDrive/checkpoints/baseline'),
    'results_dir': Path('/content/drive/MyDrive/results/phase4'),

    # Model
    'model_name': 'resnet50',
    'num_classes': 7,

    # Evaluation settings
    'batch_size': 64,
    'num_workers': 4,
    'image_size': 224,

    # Seeds to evaluate
    'seeds': [42, 123, 456],

    # Attack configurations
    'epsilons': [2/255, 4/255, 8/255],
    'pgd_steps': 40,
    'cw_iterations': 100,

    # Class names
    'class_names': ['AKIEC', 'BCC', 'BKL', 'DF', 'MEL', 'NV', 'VASC'],
}

# Class descriptions for visualization labels
CLASS_DESCRIPTIONS = {
    'AKIEC': 'Actinic Keratoses (pre-cancerous)',
    'BCC': 'Basal Cell Carcinoma (cancerous)',
    'BKL': 'Benign Keratosis (benign)',
    'DF': 'Dermatofibroma (benign)',
    'MEL': 'Melanoma (malignant)',
    'NV': 'Melanocytic Nevi (moles)',
    'VASC': 'Vascular Lesions (blood vessel)',
}

# ImageNet normalization
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Create output directories
CONFIG['results_dir'].mkdir(parents=True, exist_ok=True)
(CONFIG['results_dir'] / 'figures').mkdir(exist_ok=True)

print(f"📊 Model: {CONFIG['model_name']}")
print(f"📊 Seeds: {CONFIG['seeds']}")
print(f"📊 Epsilons: {[f'{int(e*255)}/255' for e in CONFIG['epsilons']]}")
print(f"📊 Batch size: {CONFIG['batch_size']}")
print(f"📁 Data: {CONFIG['data_root']}")
print(f"📁 Checkpoints: {CONFIG['checkpoint_dir']}")
print(f"📁 Results: {CONFIG['results_dir']}")
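
The notebook labels each budget with `int(e*255)` and reuses that label as a dictionary key in later cells. `int()` truncates toward zero, so if a float product landed just below a whole number the label would come out one too low. A small sketch of a rounding-based label, where `eps_label` is a hypothetical helper and not part of the project code:

```python
def eps_label(eps: float, denom: int = 255) -> str:
    """Format a perturbation budget such as 8/255 as the string '8/255'."""
    # round() picks the nearest integer, so tiny floating-point error cannot change the label
    return f"{round(eps * denom)}/{denom}"


labels = [eps_label(e) for e in (2 / 255, 4 / 255, 8 / 255)]
print(labels)

# Generic illustration of the truncation pitfall with IEEE-754 doubles:
# 0.57 * 100 evaluates to a value slightly below 57.
truncated = int(0.57 * 100)   # truncates toward zero
rounded = round(0.57 * 100)   # nearest integer
print(truncated, rounded)
```

Using one helper for both the printed label and the result keys would keep the two from drifting apart if the list of budgets changes.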

In [None]:
# ============================================================================
# CELL 5: DATA PREPARATION
# ============================================================================
print("=" * 70)
print("DATA PREPARATION")
print("=" * 70)

# Load and fix metadata
metadata_path = CONFIG['data_root'] / 'metadata.csv'
print(f"📄 Loading metadata: {metadata_path}")

df = pd.read_csv(metadata_path)
print(f"   Total samples: {len(df)}")

# Fix path separators
if 'image_path' in df.columns:
    df['image_path'] = df['image_path'].str.replace('\\', '/', regex=False)
    print("   ✅ Fixed path separators")

# Save fixed metadata
fixed_path = CONFIG['data_root'] / 'metadata_fixed.csv'
df.to_csv(fixed_path, index=False)

# Show test split info
test_df = df[df['split'] == 'test']
print(f"\n📊 Test samples: {len(test_df)}")
print("📊 Test class distribution:")
print(test_df['label'].value_counts())

In [None]:
# ============================================================================
# CELL 6: CREATE TEST DATASET
# ============================================================================
print("=" * 70)
print("CREATING TEST DATASET")
print("=" * 70)

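# Test transforms: no ImageNet normalization here, because the attacks operate in [0, 1] pixel space.
# A.Normalize with mean=0 and std=1 only divides by max_pixel_value (default 255);
# the ImageNet statistics are applied later by get_normalizer(), just before the forward pass.
test_transforms = A.Compose([
    A.Resize(CONFIG['image_size'], CONFIG['image_size']),
    A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),  # Rescale uint8 pixels to [0, 1]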
    ToTensorV2()
])

# Create test dataset
test_dataset = ISICDataset(
    root=str(CONFIG['data_root']),
    split='test',
    transforms=test_transforms,
    csv_path=str(fixed_path),
    image_column='image_path',
    label_column='label'
)

# Create dataloader
test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=True
)

print(f"✅ Test samples: {len(test_dataset)}")
print(f"✅ Batches: {len(test_loader)}")
print(f"✅ Classes: {CONFIG['class_names']}")
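
Keeping the loader output in [0, 1] means every ε in this notebook is expressed in pixel units, with ImageNet normalization applied inside the forward pass. If an attack were instead run on already-normalized tensors, the same pixel budget would correspond to ε divided by the channel's standard deviation, which differs per channel. A small sketch of that conversion (the variable names are illustrative):

```python
IMAGENET_STD = [0.229, 0.224, 0.225]

eps_pixel = 8 / 255  # budget in [0, 1] pixel space

# Normalization maps x to (x - mean) / std, so a pixel-space shift of eps
# becomes a shift of eps / std in normalized space (the mean cancels out).
eps_normalized = [eps_pixel / s for s in IMAGENET_STD]

for channel, eps_n in zip("RGB", eps_normalized):
    print(f"{channel}: pixel eps = {eps_pixel:.5f} -> normalized eps = {eps_n:.5f}")
```

Attacking in pixel space and normalizing afterwards avoids this per-channel bookkeeping and lets the valid image range be enforced with a plain clip to [0, 1].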

In [None]:
# ============================================================================
# CELL 7: HELPER FUNCTIONS
# ============================================================================

def get_normalizer(device):
    """Create ImageNet normalization function."""
    mean = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1).to(device)
    std = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1).to(device)
    
    def normalize(x):
        return (x - mean) / std
    return normalize

def evaluate_clean(model, dataloader, device, normalize_fn):
    """Evaluate model on clean data."""
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Clean Eval', leave=False):
            if len(batch) == 3:
                images, labels, _ = batch
            else:
                images, labels = batch
            
            images = images.to(device)
            labels = labels.to(device)
            
            # Normalize and predict
            outputs = model(normalize_fn(images))
            probs = F.softmax(outputs, dim=1)
            preds = outputs.argmax(dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)
    
    return {
        'accuracy': accuracy_score(all_labels, all_preds) * 100,
        'balanced_accuracy': balanced_accuracy_score(all_labels, all_preds) * 100,
        'f1_macro': f1_score(all_labels, all_preds, average='macro') * 100,
        'auroc': roc_auc_score(all_labels, all_probs, multi_class='ovr') * 100,
        'predictions': all_preds,
        'labels': all_labels,
        'probs': all_probs
    }

def evaluate_attack(model, dataloader, device, normalize_fn, attack_fn, desc='Attack'):
    """Evaluate model under adversarial attack."""
    model.eval()
    all_clean_preds, all_adv_preds, all_labels = [], [], []
    
    for batch in tqdm(dataloader, desc=desc, leave=False):
        if len(batch) == 3:
            images, labels, _ = batch
        else:
            images, labels = batch
        
        images = images.to(device)
        labels = labels.to(device)
        
        # Clean predictions
        with torch.no_grad():
            clean_preds = model(normalize_fn(images)).argmax(dim=1)
        
        # Generate adversarial examples
        x_adv = attack_fn(images, labels)
        
        # Adversarial predictions
        with torch.no_grad():
            adv_preds = model(normalize_fn(x_adv)).argmax(dim=1)
        
        all_clean_preds.extend(clean_preds.cpu().numpy())
        all_adv_preds.extend(adv_preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    
    all_clean_preds = np.array(all_clean_preds)
    all_adv_preds = np.array(all_adv_preds)
    all_labels = np.array(all_labels)
    
    clean_acc = accuracy_score(all_labels, all_clean_preds) * 100
    robust_acc = accuracy_score(all_labels, all_adv_preds) * 100
    
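    # Per-class robust accuracy. Passing a fixed label set keeps one row per class,
    # so the result always lines up with CONFIG['class_names'].
    num_classes = len(CONFIG['class_names'])
    cm = confusion_matrix(all_labels, all_adv_preds, labels=np.arange(num_classes))
    support = cm.sum(axis=1)
    # Classes with no test samples get NaN instead of a divide-by-zero
    per_class_acc = np.where(support > 0, cm.diagonal() / np.maximum(support, 1), np.nan) * 100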
    
    return {
        'clean_accuracy': clean_acc,
        'robust_accuracy': robust_acc,
        'accuracy_drop': clean_acc - robust_acc,
        'per_class_robust_acc': dict(zip(CONFIG['class_names'], per_class_acc)),
        'confusion_matrix': cm
    }

print("✅ Helper functions defined")
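
Per-class robust accuracy is the diagonal of the confusion matrix divided by each row's sample count. A NumPy-only sketch of that computation, which keeps one row per class and reports NaN for classes that have no samples; `per_class_accuracy` is a hypothetical name and the data is made up:

```python
import numpy as np


def per_class_accuracy(labels, preds, num_classes):
    """Return (per-class accuracy in percent, confusion matrix)."""
    labels = np.asarray(labels, dtype=np.int64)
    preds = np.asarray(preds, dtype=np.int64)
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # Rows are true classes, columns are predicted classes
    np.add.at(cm, (labels, preds), 1)
    support = cm.sum(axis=1)
    acc = np.full(num_classes, np.nan)
    present = support > 0
    acc[present] = 100.0 * cm.diagonal()[present] / support[present]
    return acc, cm


# Made-up example with 4 classes; class 2 never occurs in the labels
labels = [0, 0, 1, 1, 1, 3]
preds = [0, 1, 1, 1, 0, 3]
acc, cm = per_class_accuracy(labels, preds, num_classes=4)
print(acc)
```

Reporting NaN for an empty class makes the gap visible instead of silently shifting later classes onto the wrong names.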

In [None]:
# ============================================================================
# CELL 8: LOAD MODELS AND VERIFY CLEAN ACCURACY
# ============================================================================
print("=" * 70)
print("LOADING MODELS & VERIFYING CLEAN ACCURACY")
print("=" * 70)

normalize = get_normalizer(device)
models = {}
clean_results = {}

for seed in CONFIG['seeds']:
    print(f"\n📥 Loading seed {seed}...")
    
    # Find checkpoint
    checkpoint_path = CONFIG['checkpoint_dir'] / f'seed_{seed}' / 'best.pt'
    if not checkpoint_path.exists():
        print(f"   ❌ Checkpoint not found: {checkpoint_path}")
        continue
    
    # Load model
    model = build_model(
        architecture=CONFIG['model_name'],
        num_classes=CONFIG['num_classes'],
        pretrained=False
    ).to(device)
    
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    models[seed] = model
    
    # Verify clean accuracy
    result = evaluate_clean(model, test_loader, device, normalize)
    clean_results[seed] = result
    
    print(f"   ✅ Clean Accuracy: {result['accuracy']:.2f}%")
    print(f"   ✅ AUROC: {result['auroc']:.2f}%")

# Summary
print(f"\n{'='*70}")
print("📊 CLEAN ACCURACY SUMMARY")
print(f"{'='*70}")
accs = [clean_results[s]['accuracy'] for s in clean_results]
print(f"Mean: {np.mean(accs):.2f}% ± {np.std(accs):.2f}%")
for seed in clean_results:
    print(f"   Seed {seed}: {clean_results[seed]['accuracy']:.2f}%")
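
The summary reports the spread with `np.std`, which by default divides by n (population standard deviation). With only three seeds, the sample standard deviation (dividing by n - 1) is noticeably larger, so it helps to state which one a table uses. A standard-library sketch with made-up accuracies:

```python
import math
import statistics

accs = [80.0, 82.0, 84.0]  # made-up per-seed accuracies, in percent

mean_acc = statistics.mean(accs)
pop_std = statistics.pstdev(accs)    # divides by n, same as np.std(accs)
sample_std = statistics.stdev(accs)  # divides by n - 1, same as np.std(accs, ddof=1)

print(f"mean = {mean_acc:.2f}%")
print(f"population std = {pop_std:.3f}, sample std = {sample_std:.3f}")
print(f"ratio = {sample_std / pop_std:.4f} (sqrt(n / (n - 1)) = {math.sqrt(3 / 2):.4f})")
```

For n = 3 the two estimates differ by a factor of sqrt(3/2), roughly 22%.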

In [None]:
# ============================================================================
# CELL 9: RUN ADVERSARIAL EVALUATION - ALL ATTACKS
# ============================================================================
print("=" * 70)
print("ADVERSARIAL EVALUATION - ALL ATTACKS")
print("=" * 70)
print(f"⏱️  Start time: {datetime.now().strftime('%H:%M:%S')}")

all_results = {}

for seed in CONFIG['seeds']:
    if seed not in models:
        continue
    
    print(f"\n{'='*70}")
    print(f"SEED {seed}")
    print(f"{'='*70}")
    
    model = models[seed]
    seed_results = {'clean': clean_results[seed]}
    
    # ==================== FGSM ====================
    print("\n🔥 FGSM Attacks:")
    for eps in CONFIG['epsilons']:
        eps_str = f"{int(eps*255)}/255"
        
        # Create FGSM attack
        fgsm_config = FGSMConfig(
            epsilon=eps,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False
        )
        fgsm = FGSM(fgsm_config)
        
        def fgsm_attack_fn(x, y):
            return fgsm.generate(model, x, y, loss_fn=nn.CrossEntropyLoss(), normalize=normalize)
        
        result = evaluate_attack(model, test_loader, device, normalize, fgsm_attack_fn, f"FGSM ε={eps_str}")
        seed_results[f'fgsm_{eps_str}'] = result
        print(f"   ε={eps_str}: {result['robust_accuracy']:.2f}% (drop: {result['accuracy_drop']:.2f}%)")
    
    # ==================== PGD ====================
    print("\n🔥 PGD Attacks:")
    for eps in CONFIG['epsilons']:
        eps_str = f"{int(eps*255)}/255"
        step_size = eps / 4
        
        # Create PGD attack
        pgd_config = PGDConfig(
            epsilon=eps,
            num_steps=CONFIG['pgd_steps'],
            step_size=step_size,
            random_start=True,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False
        )
        pgd = PGD(pgd_config)
        
        def pgd_attack_fn(x, y):
            return pgd.generate(model, x, y, loss_fn=nn.CrossEntropyLoss(), normalize=normalize)
        
        result = evaluate_attack(model, test_loader, device, normalize, pgd_attack_fn, f"PGD ε={eps_str}")
        seed_results[f'pgd_{eps_str}'] = result
        print(f"   ε={eps_str}: {result['robust_accuracy']:.2f}% (drop: {result['accuracy_drop']:.2f}%)")
    
    # ==================== C&W ====================
    print("\n🔥 Carlini-Wagner Attack:")
    cw_config = CWConfig(
        confidence=0.0,
        learning_rate=0.01,
        max_iterations=CONFIG['cw_iterations'],
        binary_search_steps=5,
        clip_min=0.0,
        clip_max=1.0,
        targeted=False
    )
    cw = CarliniWagner(cw_config)
    
    def cw_attack_fn(x, y):
        return cw.generate(model, x, y, normalize=normalize)
    
    result = evaluate_attack(model, test_loader, device, normalize, cw_attack_fn, "C&W L2")
    seed_results['cw'] = result
    print(f"   C&W: {result['robust_accuracy']:.2f}% (drop: {result['accuracy_drop']:.2f}%)")
    
    all_results[seed] = seed_results

print(f"\n⏱️  End time: {datetime.now().strftime('%H:%M:%S')}")
print("✅ Evaluation complete!")
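
To make the FGSM and PGD settings above concrete, here is a minimal NumPy sketch of both attacks under an L∞ budget. It is not the implementation in `src.attacks`: a toy linear softmax classifier with an analytic input gradient stands in for ResNet-50 and autograd, the ImageNet normalization step is omitted, and all names are illustrative. It does mirror the notebook's choices of a [0, 1] pixel range, a random start, 40 steps, and a step size of ε/4.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def loss_and_grad(W, b, x, y):
    """Per-sample cross-entropy of a linear softmax model and its gradient w.r.t. x."""
    n = x.shape[0]
    p = softmax(x @ W + b)                      # (n, k) class probabilities
    loss = -np.log(p[np.arange(n), y] + 1e-12)
    p[np.arange(n), y] -= 1.0                   # d(loss)/d(logits) = p - onehot(y)
    return loss, p @ W.T                        # chain rule through the logits


def fgsm(W, b, x, y, eps):
    """Single signed-gradient step of size eps, clipped to the valid pixel range."""
    _, g = loss_and_grad(W, b, x, y)
    return np.clip(x + eps * np.sign(g), 0.0, 1.0)


def pgd(W, b, x, y, eps, step_size, num_steps, rng):
    """Iterated signed-gradient ascent with projection onto the L-inf ball around x."""
    x_adv = np.clip(x + rng.uniform(-eps, eps, size=x.shape), 0.0, 1.0)  # random start
    for _ in range(num_steps):
        _, g = loss_and_grad(W, b, x_adv, y)
        x_adv = x_adv + step_size * np.sign(g)
        x_adv = np.clip(x_adv, x - eps, x + eps)  # stay within eps of the clean input
        x_adv = np.clip(x_adv, 0.0, 1.0)          # stay a valid image
    return x_adv


rng = np.random.default_rng(0)
d, k, n = 12, 3, 8                       # toy sizes: features, classes, samples
W, b = rng.normal(size=(d, k)), np.zeros(k)
x = rng.uniform(0.2, 0.8, size=(n, d))   # "images" already scaled to [0, 1]
y = rng.integers(0, k, size=n)
eps = 8 / 255

x_fgsm = fgsm(W, b, x, y, eps)
x_pgd = pgd(W, b, x, y, eps, step_size=eps / 4, num_steps=40, rng=rng)

clean_loss, _ = loss_and_grad(W, b, x, y)
fgsm_loss, _ = loss_and_grad(W, b, x_fgsm, y)
pgd_loss, _ = loss_and_grad(W, b, x_pgd, y)
print(f"mean loss: clean={clean_loss.mean():.4f} fgsm={fgsm_loss.mean():.4f} pgd={pgd_loss.mean():.4f}")
print(f"max |delta|: fgsm={np.abs(x_fgsm - x).max():.5f} pgd={np.abs(x_pgd - x).max():.5f}")
```

The two clips in the PGD loop are the projection step: the first enforces the ε budget and the second keeps pixel values valid.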

In [None]:
# ============================================================================
# CELL 10: ADVANCED RESULTS TABLE WITH STYLED OUTPUT
# ============================================================================
print("=" * 70)
print("📊 COMPREHENSIVE ADVERSARIAL EVALUATION RESULTS")
print("=" * 70)

# ===================== BUILD COMPREHENSIVE RESULTS TABLE =====================
rows = []
attacks = ['clean'] + [f'fgsm_{int(e*255)}/255' for e in CONFIG['epsilons']] + \
          [f'pgd_{int(e*255)}/255' for e in CONFIG['epsilons']] + ['cw']

for attack in attacks:
    if attack == 'clean':
        accs = [all_results[s]['clean']['accuracy'] for s in all_results]
        attack_name = '🟢 Clean (No Attack)'
        attack_type = 'Baseline'
        severity = '-'
    elif attack == 'cw':
        accs = [all_results[s]['cw']['robust_accuracy'] for s in all_results]
        attack_name = '🔴 Carlini-Wagner L2'
        attack_type = 'Optimization'
        severity = 'Maximum'
    elif 'fgsm' in attack:
        accs = [all_results[s][attack]['robust_accuracy'] for s in all_results]
        eps = int(attack.split('_')[1].split('/')[0])
        attack_name = f'🔵 FGSM (ε={eps}/255)'
        attack_type = 'Gradient (1-step)'
        severity = 'Weak' if eps == 2 else 'Medium' if eps == 4 else 'Strong'
    else:  # pgd
        accs = [all_results[s][attack]['robust_accuracy'] for s in all_results]
        eps = int(attack.split('_')[1].split('/')[0])
        attack_name = f'🟠 PGD-40 (ε={eps}/255)'
        attack_type = 'Iterative (40-step)'
        severity = 'Weak' if eps == 2 else 'Medium' if eps == 4 else 'Strong'
    
    mean_acc = np.mean(accs)
    clean_acc = np.mean([all_results[s]['clean']['accuracy'] for s in all_results])
    drop = clean_acc - mean_acc if attack != 'clean' else 0
    
    # Map each accuracy to its seed so a skipped checkpoint cannot shift values into the wrong column
    seed_accs = dict(zip(all_results, accs))
    rows.append({
        'Attack': attack_name,
        'Type': attack_type,
        'Severity': severity,
        'Mean Acc (%)': mean_acc,
        'Std (%)': np.std(accs),
        'Drop (%)': drop,
        'Seed 42': seed_accs.get(42, np.nan),
        'Seed 123': seed_accs.get(123, np.nan),
        'Seed 456': seed_accs.get(456, np.nan),
    })

results_df = pd.DataFrame(rows)

# ===================== STYLED TABLE DISPLAY =====================
from IPython.display import display, HTML

def color_accuracy(val):
    """Color code accuracy values."""
    if pd.isna(val): return ''
    if val >= 70: return 'background-color: #a8e6cf; color: black'
    elif val >= 40: return 'background-color: #ffd3b6; color: black'
    else: return 'background-color: #ffaaa5; color: black'

def color_drop(val):
    """Color code accuracy drop."""
    if pd.isna(val) or val == 0: return ''
    if val <= 20: return 'background-color: #dcedc1; color: black'
    elif val <= 50: return 'background-color: #ffeead; color: black'
    else: return 'background-color: #ff6f69; color: white'

# Format for display
display_df = results_df.copy()
display_df['Mean Acc (%)'] = display_df['Mean Acc (%)'].apply(lambda x: f'{x:.2f}')
display_df['Std (%)'] = display_df['Std (%)'].apply(lambda x: f'{x:.2f}')
display_df['Drop (%)'] = display_df['Drop (%)'].apply(lambda x: f'{x:.2f}' if x > 0 else '-')
display_df['Seed 42'] = display_df['Seed 42'].apply(lambda x: f'{x:.2f}')
display_df['Seed 123'] = display_df['Seed 123'].apply(lambda x: f'{x:.2f}')
display_df['Seed 456'] = display_df['Seed 456'].apply(lambda x: f'{x:.2f}')

print("\n📋 DETAILED RESULTS TABLE:")
print(display_df.to_string(index=False))

# ===================== EXECUTIVE SUMMARY STATISTICS =====================
clean_mean = np.mean([all_results[s]['clean']['accuracy'] for s in all_results])
fgsm8_mean = np.mean([all_results[s]['fgsm_8/255']['robust_accuracy'] for s in all_results])
pgd8_mean = np.mean([all_results[s]['pgd_8/255']['robust_accuracy'] for s in all_results])
cw_mean = np.mean([all_results[s]['cw']['robust_accuracy'] for s in all_results])

print("\n" + "=" * 70)
print("📈 EXECUTIVE SUMMARY - KEY STATISTICS")
print("=" * 70)
summary_table = f"""
┌───────────────────────────────────────────────────────────────────┐
│  BASELINE MODEL VULNERABILITY ASSESSMENT                          │
├───────────────────────────────────────────────────────────────────┤
│  Clean Accuracy (Baseline):     {clean_mean:>6.2f}%                         │
│  ─────────────────────────────────────────────────────────────    │
│  FGSM ε=8/255:                  {fgsm8_mean:>6.2f}%  (↓ {clean_mean-fgsm8_mean:>5.2f}% drop)         │
│  PGD-40 ε=8/255:                {pgd8_mean:>6.2f}%  (↓ {clean_mean-pgd8_mean:>5.2f}% drop)         │
│  Carlini-Wagner L2:             {cw_mean:>6.2f}%  (↓ {clean_mean-cw_mean:>5.2f}% drop)         │
│  ─────────────────────────────────────────────────────────────    │
│  Average Robustness Degradation: {np.mean([clean_mean-fgsm8_mean, clean_mean-pgd8_mean, clean_mean-cw_mean]):>5.2f}% under strong attacks  │
└───────────────────────────────────────────────────────────────────┘
"""
print(summary_table)

# Save to CSV with full precision
results_df.to_csv(CONFIG['results_dir'] / 'adversarial_results.csv', index=False, float_format='%.4f')
print(f"✅ Results saved to: {CONFIG['results_dir'] / 'adversarial_results.csv'}")

In [None]:
# ============================================================================
# CELL 11: ADVANCED VISUALIZATION - PUBLICATION-QUALITY ROBUSTNESS ANALYSIS
# ============================================================================
print("=" * 70)
print("📊 GENERATING ADVANCED VISUALIZATIONS")
print("=" * 70)

# ===================== STYLE CONFIGURATION =====================
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': 11,
    'axes.titlesize': 14,
    'axes.titleweight': 'bold',
    'axes.labelsize': 12,
    'axes.labelweight': 'bold',
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'legend.framealpha': 0.9,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# Custom color palette - professional & accessible
COLORS = {
    'clean': '#2E8B57',      # Sea green
    'fgsm': '#4169E1',       # Royal blue
    'pgd': '#DC143C',        # Crimson
    'cw': '#9400D3',         # Dark violet
    'seeds': ['#FF6B6B', '#4ECDC4', '#45B7D1'],  # Coral, Teal, Sky blue
    'gradient': ['#00C853', '#FFD600', '#FF6D00', '#D50000']  # Green to red
}

# ===================== FIGURE 1: COMPREHENSIVE ROBUSTNESS DASHBOARD =====================
fig = plt.figure(figsize=(18, 14))
gs = fig.add_gridspec(3, 3, hspace=0.35, wspace=0.3, height_ratios=[1.2, 1, 1])

# ------ Panel A: Robustness Degradation Curves (FGSM vs PGD) ------
ax1 = fig.add_subplot(gs[0, :2])

eps_vals = [0] + [e*255 for e in CONFIG['epsilons']]
markers = ['o', 's', '^']
linestyles = ['-', '--', ':']

# Plot for each seed
for i, seed in enumerate(CONFIG['seeds']):
    if seed not in all_results:
        continue
    
    # FGSM curve
    fgsm_accs = [all_results[seed]['clean']['accuracy']]
    for eps in CONFIG['epsilons']:
        fgsm_accs.append(all_results[seed][f'fgsm_{int(eps*255)}/255']['robust_accuracy'])
    ax1.plot(eps_vals, fgsm_accs, marker=markers[i], linestyle='-', 
             color=COLORS['fgsm'], linewidth=2.5, markersize=10, 
             label=f'FGSM (Seed {seed})', alpha=0.7 + 0.1*i)
    
    # PGD curve
    pgd_accs = [all_results[seed]['clean']['accuracy']]
    for eps in CONFIG['epsilons']:
        pgd_accs.append(all_results[seed][f'pgd_{int(eps*255)}/255']['robust_accuracy'])
    ax1.plot(eps_vals, pgd_accs, marker=markers[i], linestyle='--',
             color=COLORS['pgd'], linewidth=2.5, markersize=10,
             label=f'PGD-40 (Seed {seed})', alpha=0.7 + 0.1*i)

# Compute the per-ε mean and std across seeds for the shaded ±1 std bands
fgsm_means = [np.mean([all_results[s]['clean']['accuracy'] for s in all_results])]
fgsm_stds = [np.std([all_results[s]['clean']['accuracy'] for s in all_results])]
pgd_means = [np.mean([all_results[s]['clean']['accuracy'] for s in all_results])]
pgd_stds = [np.std([all_results[s]['clean']['accuracy'] for s in all_results])]

for eps in CONFIG['epsilons']:
    fgsm_vals = [all_results[s][f'fgsm_{int(eps*255)}/255']['robust_accuracy'] for s in all_results]
    fgsm_means.append(np.mean(fgsm_vals))
    fgsm_stds.append(np.std(fgsm_vals))
    pgd_vals = [all_results[s][f'pgd_{int(eps*255)}/255']['robust_accuracy'] for s in all_results]
    pgd_means.append(np.mean(pgd_vals))
    pgd_stds.append(np.std(pgd_vals))

# Shade ±1 std bands across seeds (a spread indicator, not a formal confidence interval)
ax1.fill_between(eps_vals, np.array(fgsm_means)-np.array(fgsm_stds), 
                 np.array(fgsm_means)+np.array(fgsm_stds), alpha=0.15, color=COLORS['fgsm'])
ax1.fill_between(eps_vals, np.array(pgd_means)-np.array(pgd_stds),
                 np.array(pgd_means)+np.array(pgd_stds), alpha=0.15, color=COLORS['pgd'])

# Styling
ax1.set_xlabel('Perturbation Budget ε (×255)', fontsize=13, fontweight='bold')
ax1.set_ylabel('Accuracy (%)', fontsize=13, fontweight='bold')
ax1.set_title('A) Adversarial Robustness Degradation: FGSM vs PGD-40', fontsize=15, fontweight='bold', pad=15)
ax1.set_xticks(eps_vals)
ax1.set_xticklabels(['0\n(Clean)', '2', '4', '8'])
ax1.set_ylim(0, 100)
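# Uniform guessing over 7 classes gives 100/7 (about 14.3%), not 50%
chance_level = 100 / CONFIG['num_classes']
ax1.axhline(y=chance_level, color='gray', linestyle=':', alpha=0.5, label=f'Uniform Random Guess ({chance_level:.1f}%)')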
ax1.legend(loc='upper right', ncol=2, frameon=True, fancybox=True, shadow=True)
ax1.grid(True, alpha=0.3, linestyle='--')

# Add annotations for key drops
clean_mean = np.mean([all_results[s]['clean']['accuracy'] for s in all_results])
pgd8_mean = np.mean([all_results[s]['pgd_8/255']['robust_accuracy'] for s in all_results])
ax1.annotate(f'↓{clean_mean - pgd8_mean:.1f}%', xy=(8, pgd8_mean), xytext=(8.5, pgd8_mean + 15),
             fontsize=12, fontweight='bold', color=COLORS['pgd'],
             arrowprops=dict(arrowstyle='->', color=COLORS['pgd'], lw=2))

# ------ Panel B: Attack Severity Radar Chart ------
ax2 = fig.add_subplot(gs[0, 2], projection='polar')

# Radar data: smallest and largest FGSM/PGD budgets (ε=2/255 and ε=8/255) plus C&W
categories = ['FGSM\nε=2/255', 'FGSM\nε=8/255', 'PGD\nε=2/255', 'PGD\nε=8/255', 'C&W\nL2']
attack_keys = ['fgsm_2/255', 'fgsm_8/255', 'pgd_2/255', 'pgd_8/255', 'cw']

values = []
for key in attack_keys:
    accs = [all_results[s][key]['robust_accuracy'] for s in all_results]
    values.append(np.mean(accs))

# Normalize to 0-100 and invert (higher = more vulnerable)
vulnerability = [100 - v for v in values]
vulnerability.append(vulnerability[0])  # Close the polygon

angles = np.linspace(0, 2*np.pi, len(categories), endpoint=False).tolist()
angles.append(angles[0])

ax2.plot(angles, vulnerability, 'o-', linewidth=2.5, color='#E74C3C', markersize=8)
ax2.fill(angles, vulnerability, alpha=0.25, color='#E74C3C')
ax2.set_xticks(angles[:-1])
ax2.set_xticklabels(categories, size=9)
ax2.set_ylim(0, 100)
ax2.set_title('B) Vulnerability Profile\n(Higher = More Vulnerable)', fontsize=13, fontweight='bold', pad=20)

# ------ Panel C: Per-Seed Comparison ------
ax3 = fig.add_subplot(gs[1, 0])

# Only plot seeds that were actually evaluated (a missing checkpoint is skipped in Cell 8)
evaluated_seeds = [s for s in CONFIG['seeds'] if s in all_results]
seed_data = []
for seed in evaluated_seeds:
    seed_data.append({
        'Seed': str(seed),
        'Clean': all_results[seed]['clean']['accuracy'],
        'FGSM-8': all_results[seed]['fgsm_8/255']['robust_accuracy'],
        'PGD-8': all_results[seed]['pgd_8/255']['robust_accuracy'],
        'C&W': all_results[seed]['cw']['robust_accuracy']
    })

seed_df = pd.DataFrame(seed_data)
x = np.arange(len(evaluated_seeds))
width = 0.2

bars1 = ax3.bar(x - 1.5*width, seed_df['Clean'], width, label='Clean', color=COLORS['clean'], edgecolor='white', linewidth=1.5)
bars2 = ax3.bar(x - 0.5*width, seed_df['FGSM-8'], width, label='FGSM ε=8', color=COLORS['fgsm'], edgecolor='white', linewidth=1.5)
bars3 = ax3.bar(x + 0.5*width, seed_df['PGD-8'], width, label='PGD ε=8', color=COLORS['pgd'], edgecolor='white', linewidth=1.5)
bars4 = ax3.bar(x + 1.5*width, seed_df['C&W'], width, label='C&W', color=COLORS['cw'], edgecolor='white', linewidth=1.5)

ax3.set_ylabel('Accuracy (%)', fontweight='bold')
ax3.set_xlabel('Random Seed', fontweight='bold')
ax3.set_title('C) Cross-Seed Consistency', fontsize=13, fontweight='bold')
ax3.set_xticks(x)
ax3.set_xticklabels([f'Seed {s}' for s in evaluated_seeds])
ax3.legend(loc='upper right', fontsize=9)
ax3.set_ylim(0, 100)
ax3.grid(True, alpha=0.3, axis='y')

# Add value labels
for bars in [bars1, bars2, bars3, bars4]:
    for bar in bars:
        height = bar.get_height()
        ax3.annotate(f'{height:.0f}', xy=(bar.get_x() + bar.get_width()/2, height),
                     xytext=(0, 3), textcoords='offset points', ha='center', va='bottom', fontsize=8, fontweight='bold')

# ------ Panel D: Accuracy Drop by Attack (horizontal bars) ------
ax4 = fig.add_subplot(gs[1, 1])

attacks_order = ['FGSM\nε=2', 'FGSM\nε=4', 'FGSM\nε=8', 'PGD\nε=2', 'PGD\nε=4', 'PGD\nε=8', 'C&W']
attack_keys_order = ['fgsm_2/255', 'fgsm_4/255', 'fgsm_8/255', 'pgd_2/255', 'pgd_4/255', 'pgd_8/255', 'cw']

drops = []
for key in attack_keys_order:
    accs = [all_results[s][key]['robust_accuracy'] for s in all_results]
    drop = clean_mean - np.mean(accs)
    drops.append(drop)

colors_drop = [COLORS['fgsm']]*3 + [COLORS['pgd']]*3 + [COLORS['cw']]
bars = ax4.barh(attacks_order, drops, color=colors_drop, edgecolor='white', linewidth=1.5, alpha=0.85)

ax4.set_xlabel('Accuracy Drop (%)', fontweight='bold')
ax4.set_title('D) Robustness Degradation by Attack', fontsize=13, fontweight='bold')
ax4.axvline(x=np.mean(drops), color='red', linestyle='--', linewidth=2, label=f'Mean Drop: {np.mean(drops):.1f}%')
ax4.legend(loc='lower right')
ax4.grid(True, alpha=0.3, axis='x')

# Add value labels
for bar, drop in zip(bars, drops):
    ax4.annotate(f'{drop:.1f}%', xy=(bar.get_width(), bar.get_y() + bar.get_height()/2),
                 xytext=(5, 0), textcoords='offset points', va='center', fontsize=10, fontweight='bold')

# ------ Panel E: Statistical Significance ------
ax5 = fig.add_subplot(gs[1, 2])

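# Calculate p-values (paired t-test across seeds: clean vs each attack).
# Note: there is one accuracy value per seed, so n = len(all_results) (3 seeds, 2 degrees of freedom).
# The p-values are not corrected for the three comparisons and should be read as indicative only.
from scipy import stats as scipy_stats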

p_values = []
attack_labels = ['FGSM-8', 'PGD-8', 'C&W']
for key in ['fgsm_8/255', 'pgd_8/255', 'cw']:
    clean_accs = [all_results[s]['clean']['accuracy'] for s in all_results]
    attack_accs = [all_results[s][key]['robust_accuracy'] for s in all_results]
    _, p = scipy_stats.ttest_rel(clean_accs, attack_accs)
    p_values.append(p)

colors_p = ['green' if p < 0.05 else 'orange' for p in p_values]
bars = ax5.barh(attack_labels, [-np.log10(p) for p in p_values], color=colors_p, edgecolor='white', linewidth=1.5)
ax5.axvline(x=-np.log10(0.05), color='red', linestyle='--', linewidth=2, label='p=0.05 threshold')
ax5.set_xlabel('-log₁₀(p-value)', fontweight='bold')
ax5.set_title('E) Statistical Significance\n(Paired t-test vs Clean)', fontsize=13, fontweight='bold')
ax5.legend(loc='lower right')

for bar, p in zip(bars, p_values):
    ax5.annotate(f'p={p:.4f}', xy=(bar.get_width(), bar.get_y() + bar.get_height()/2),
                 xytext=(5, 0), textcoords='offset points', va='center', fontsize=10, fontweight='bold')

# ------ Panel F: Key Findings Summary ------
ax6 = fig.add_subplot(gs[2, :])
ax6.axis('off')

findings_text = f"""
╔══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
║                                    🔬 KEY FINDINGS: ADVERSARIAL ROBUSTNESS ANALYSIS                                  ║
╠══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                                                      ║
║  📊 BASELINE MODEL PERFORMANCE                           │  ⚠️  VULNERABILITY ASSESSMENT                            ║
║  ─────────────────────────────────────────────────────── │ ────────────────────────────────────────────────────────  ║
║  • Clean Accuracy: {clean_mean:.2f}% (Mean ± {np.std([all_results[s]['clean']['accuracy'] for s in all_results]):.2f}%)         │  • Single-step FGSM reduces accuracy by {clean_mean - fgsm8_mean:.1f}% at ε=8/255      ║
║  • Model: ResNet-50 (ImageNet pretrained)                │  • Iterative PGD-40 causes {clean_mean - pgd8_mean:.1f}% drop at ε=8/255          ║
║  • Dataset: ISIC 2018 (7 dermoscopy classes)             │  • Optimization-based C&W achieves {clean_mean - cw_mean:.1f}% degradation    ║
║  • Seeds evaluated: {', '.join(map(str, CONFIG['seeds']))}                        │  • Maximum observed drop: {max([clean_mean - fgsm8_mean, clean_mean - pgd8_mean, clean_mean - cw_mean]):.1f}%                           ║
║                                                          │                                                           ║
╠══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                                                      ║
║  💡 IMPLICATIONS FOR MEDICAL IMAGING                                                                                 ║
║  ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────── ║
║  1. Standard CNNs are HIGHLY VULNERABLE to adversarial perturbations - poses serious risk in clinical deployment     ║
║  2. Even small perturbations (ε=2/255) cause measurable accuracy degradation - imperceptible to human eye            ║
║  3. Cross-seed consistency shows vulnerability is SYSTEMATIC, not random - fundamental model weakness                ║
║  4. ADVERSARIAL TRAINING (Phase 5) is ESSENTIAL before clinical deployment                                           ║
║                                                                                                                      ║
╚══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╝
"""

ax6.text(0.5, 0.5, findings_text, transform=ax6.transAxes, fontsize=10, 
         fontfamily='monospace', verticalalignment='center', horizontalalignment='center',
         bbox=dict(boxstyle='round,pad=0.5', facecolor='#f8f9fa', edgecolor='#dee2e6', linewidth=2))

plt.suptitle('Phase 4: Adversarial Robustness Evaluation - ResNet-50 Baseline on ISIC 2018', 
             fontsize=18, fontweight='bold', y=0.98)
plt.savefig(CONFIG['results_dir'] / 'figures' / 'robustness_dashboard.png', dpi=300, bbox_inches='tight', 
            facecolor='white', edgecolor='none')
plt.show()
print("✅ Saved: robustness_dashboard.png (300 DPI)")

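Panel E above runs a paired t-test on one accuracy value per seed. As a rough, standalone illustration of what that test computes with three seeds, the sketch below uses only the standard library. The accuracy values are made-up placeholders, not results from this notebook, and `paired_t_three_seeds` and `two_sided_p_df2` are hypothetical helper names. With n = 3 the t distribution has 2 degrees of freedom, for which the CDF has a closed form, so no SciPy call is needed here.

```python
import math
import statistics

def two_sided_p_df2(t_stat):
    """Two-sided p-value for Student's t with 2 degrees of freedom.

    For df = 2 the CDF is F(t) = 1/2 + t / (2 * sqrt(t^2 + 2)),
    so the two-sided p-value is 1 - |t| / sqrt(t^2 + 2).
    """
    return 1.0 - abs(t_stat) / math.sqrt(t_stat ** 2 + 2.0)

def paired_t_three_seeds(clean, attacked):
    """Paired t-test for exactly three paired observations (df = 2).

    Returns (t_statistic, two_sided_p). Illustrative helper, not part of the notebook.
    """
    if len(clean) != 3 or len(attacked) != 3:
        raise ValueError("the closed-form p-value used here is only valid for n = 3")
    diffs = [c - a for c, a in zip(clean, attacked)]
    mean_diff = statistics.mean(diffs)
    sd_diff = statistics.stdev(diffs)  # sample standard deviation (ddof = 1)
    if sd_diff == 0:
        raise ValueError("identical differences: the t statistic is undefined")
    t_stat = mean_diff / (sd_diff / math.sqrt(len(diffs)))
    return t_stat, two_sided_p_df2(t_stat)

# Placeholder accuracies (%), one per seed; NOT results from this notebook.
clean_accs = [82.0, 83.0, 81.5]
pgd_accs = [10.0, 12.0, 9.0]

t_stat, p_value = paired_t_three_seeds(clean_accs, pgd_accs)
print(f"t = {t_stat:.2f}, p = {p_value:.2e}, -log10(p) = {-math.log10(p_value):.2f}")
```

At 2 degrees of freedom the two-sided 5% critical value is roughly |t| = 4.30, so only a large and consistent drop across all three seeds clears the threshold drawn in Panel E.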
In [None]:
# ============================================================================
# CELL 12: ADVANCED PER-CLASS VULNERABILITY ANALYSIS
# ============================================================================
print("=" * 70)
print("📊 PER-CLASS VULNERABILITY HEATMAP & ANALYSIS")
print("=" * 70)

# ===================== COMPREHENSIVE HEATMAP DATA =====================
attacks_to_show = [
    ('clean', 'Clean\n(No Attack)'),
    ('fgsm_2/255', 'FGSM\nε=2/255'),
    ('fgsm_4/255', 'FGSM\nε=4/255'),
    ('fgsm_8/255', 'FGSM\nε=8/255'),
    ('pgd_2/255', 'PGD\nε=2/255'),
    ('pgd_4/255', 'PGD\nε=4/255'),
    ('pgd_8/255', 'PGD\nε=8/255'),
    ('cw', 'C&W\nL2'),
]

# Build comprehensive heatmap matrix
heatmap_matrix = []
for attack_key, attack_label in attacks_to_show:
    row = []
    for cls in CONFIG['class_names']:
        if attack_key == 'clean':
            # For clean, use predictions vs labels to compute per-class accuracy
            all_preds = np.concatenate([all_results[s]['clean']['predictions'] for s in all_results])
            all_labels = np.concatenate([all_results[s]['clean']['labels'] for s in all_results])
            cls_idx = CONFIG['class_names'].index(cls)
            cls_mask = all_labels == cls_idx
            if cls_mask.sum() > 0:
                cls_acc = (all_preds[cls_mask] == all_labels[cls_mask]).mean() * 100
            else:
                cls_acc = 0
            row.append(cls_acc)
        else:
            accs = [all_results[s][attack_key]['per_class_robust_acc'][cls] for s in all_results]
            row.append(np.mean(accs))
    heatmap_matrix.append(row)

heatmap_array = np.array(heatmap_matrix)

# ===================== FIGURE: MULTI-PANEL VULNERABILITY ANALYSIS =====================
fig = plt.figure(figsize=(18, 12))
gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.25, height_ratios=[1.3, 1])

# ------ Panel A: Full Vulnerability Heatmap ------
ax1 = fig.add_subplot(gs[0, :])

# Create custom colormap (green = robust, red = vulnerable)
from matplotlib.colors import LinearSegmentedColormap
colors_cmap = ['#D32F2F', '#FF5722', '#FF9800', '#FFC107', '#CDDC39', '#8BC34A', '#4CAF50', '#2E7D32']
cmap = LinearSegmentedColormap.from_list('vulnerability', colors_cmap)

im = ax1.imshow(heatmap_array, cmap=cmap, aspect='auto', vmin=0, vmax=100)

# Styling
ax1.set_xticks(range(len(CONFIG['class_names'])))
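# Look up the optional class descriptions via globals(): a bare dir() inside a comprehension
# inspects the comprehension's own scope on Python versions before 3.12, so that check never matches there.
class_descriptions = globals().get('CLASS_DESCRIPTIONS', {})
ax1.set_xticklabels([f'{name}\n({class_descriptions.get(name, name).split("(")[0].strip()[:15]})'
                      for name in CONFIG['class_names']],
                    fontsize=10, fontweight='bold')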
ax1.set_yticks(range(len(attacks_to_show)))
ax1.set_yticklabels([label for _, label in attacks_to_show], fontsize=10)

# Add text annotations with adaptive coloring
for i in range(len(attacks_to_show)):
    for j in range(len(CONFIG['class_names'])):
        val = heatmap_array[i, j]
        text_color = 'white' if val < 40 else 'black'
        ax1.text(j, i, f'{val:.1f}%', ha='center', va='center', 
                 color=text_color, fontsize=9, fontweight='bold')

ax1.set_title('A) Per-Class Robustness Under All Attack Conditions\n(Mean across 3 seeds)', 
              fontsize=14, fontweight='bold', pad=15)

# Colorbar with labels
cbar = plt.colorbar(im, ax=ax1, shrink=0.8, pad=0.02)
cbar.set_label('Accuracy (%)', fontsize=11, fontweight='bold')
cbar.ax.set_yticks([0, 25, 50, 75, 100])
cbar.ax.set_yticklabels(['0%\n(Failed)', '25%', '50%', '75%', '100%\n(Robust)'])

# Add class vulnerability ranking annotation
most_vulnerable_idx = np.argmin(heatmap_array[-2, :])  # PGD ε=8/255
most_robust_idx = np.argmax(heatmap_array[-2, :])
ax1.annotate('Most Vulnerable', xy=(most_vulnerable_idx, 6), xytext=(most_vulnerable_idx, 8),
             fontsize=10, color='red', fontweight='bold', ha='center',
             arrowprops=dict(arrowstyle='->', color='red', lw=2))

# ------ Panel B: Class Vulnerability Ranking ------
ax2 = fig.add_subplot(gs[1, 0])

# Calculate mean robustness across all attacks for each class
class_robustness = []
for j, cls in enumerate(CONFIG['class_names']):
    mean_rob = np.mean(heatmap_array[1:, j])  # Exclude clean
    class_robustness.append((cls, mean_rob))

class_robustness.sort(key=lambda x: x[1])  # Sort by robustness (ascending = most vulnerable first)

cls_names = [c[0] for c in class_robustness]
cls_values = [c[1] for c in class_robustness]

# Color bars by vulnerability
colors_bars = [plt.cm.RdYlGn(v/100) for v in cls_values]
bars = ax2.barh(cls_names, cls_values, color=colors_bars, edgecolor='white', linewidth=1.5)

ax2.set_xlabel('Mean Robust Accuracy (%)', fontweight='bold')
ax2.set_title('B) Class Vulnerability Ranking\n(Lower = More Vulnerable)', fontsize=13, fontweight='bold')
ax2.axvline(x=np.mean(cls_values), color='navy', linestyle='--', linewidth=2, label=f'Mean: {np.mean(cls_values):.1f}%')
ax2.legend(loc='lower right')
ax2.set_xlim(0, 100)
ax2.grid(True, alpha=0.3, axis='x')

for bar, val in zip(bars, cls_values):
    ax2.annotate(f'{val:.1f}%', xy=(val, bar.get_y() + bar.get_height()/2),
                 xytext=(5, 0), textcoords='offset points', va='center', fontsize=10, fontweight='bold')

# ------ Panel C: Attack Strength Impact by Class ------
ax3 = fig.add_subplot(gs[1, 1])

# Line plot: accuracy degradation by epsilon for each class
x_positions = [0, 2, 4, 8]  # Epsilon values * 255
colors_class = plt.cm.tab10(np.linspace(0, 1, len(CONFIG['class_names'])))

for j, cls in enumerate(CONFIG['class_names']):
    class_accs = [heatmap_array[0, j]]  # Clean
    class_accs.append(heatmap_array[4, j])  # PGD ε=2/255
    class_accs.append(heatmap_array[5, j])  # PGD ε=4/255
    class_accs.append(heatmap_array[6, j])  # PGD ε=8/255
    ax3.plot(x_positions, class_accs, 'o-', color=colors_class[j], 
             label=cls, linewidth=2, markersize=8, alpha=0.8)

ax3.set_xlabel('Perturbation ε (×255)', fontweight='bold')
ax3.set_ylabel('Accuracy (%)', fontweight='bold')
ax3.set_title('C) PGD Attack Impact by Class\n(Accuracy vs Epsilon)', fontsize=13, fontweight='bold')
ax3.set_xticks(x_positions)
ax3.set_xticklabels(['0\n(Clean)', '2', '4', '8'])
ax3.legend(loc='upper right', ncol=2, fontsize=9)
ax3.set_ylim(0, 100)
ax3.grid(True, alpha=0.3)

plt.suptitle('Per-Class Adversarial Vulnerability Analysis - ISIC 2018 Dermoscopy', 
             fontsize=16, fontweight='bold', y=0.98)
plt.savefig(CONFIG['results_dir'] / 'figures' / 'class_vulnerability_analysis.png', dpi=300, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.show()
print("✅ Saved: class_vulnerability_analysis.png (300 DPI)")

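The heatmap's clean row pools predictions from all seeds before computing per-class accuracy, while the attack rows average per-seed values. A minimal sketch of the per-class computation follows, using plain lists and made-up labels rather than this notebook's data. `per_class_accuracy` is a hypothetical helper, the class names are toy placeholders, and returning `None` for a class with no samples is a choice made here (the cell above uses 0 instead).

```python
def per_class_accuracy(preds, labels, class_names):
    """Return {class_name: accuracy in %}, or None for a class with no samples."""
    result = {}
    for cls_idx, name in enumerate(class_names):
        hits = [p == y for p, y in zip(preds, labels) if y == cls_idx]
        result[name] = 100.0 * sum(hits) / len(hits) if hits else None
    return result

# Toy data: three classes, the same six test labels evaluated under two seeds.
class_names = ["MEL", "NV", "BCC"]
labels = [0, 0, 1, 1, 1, 1]
preds_by_seed = {
    42: [0, 1, 1, 1, 1, 0],
    123: [0, 0, 1, 1, 0, 0],
}

per_seed = {
    seed: per_class_accuracy(preds, labels, class_names)
    for seed, preds in preds_by_seed.items()
}

# Pooling all seeds gives the same number as averaging the per-seed accuracies
# when every seed is evaluated on the same test set.
pooled_preds = [p for preds in preds_by_seed.values() for p in preds]
pooled_labels = labels * len(preds_by_seed)
pooled = per_class_accuracy(pooled_preds, pooled_labels, class_names)

print(per_seed)
print(pooled)
```

Because the pooled and seed-averaged values coincide only when all seeds share one test set, the clean row and the attack rows of the heatmap are comparable under that assumption.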
In [None]:
# ============================================================================
# CELL 13: INTERACTIVE PLOTLY VISUALIZATIONS
# ============================================================================
print("=" * 70)
print("📊 GENERATING INTERACTIVE PLOTLY VISUALIZATIONS")
print("=" * 70)

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ===================== FIGURE 1: INTERACTIVE ATTACK COMPARISON =====================
# Prepare data for Plotly
attack_data = []
attack_keys = ['clean', 'fgsm_2/255', 'fgsm_4/255', 'fgsm_8/255', 
               'pgd_2/255', 'pgd_4/255', 'pgd_8/255', 'cw']
attack_labels = ['Clean', 'FGSM ε=2/255', 'FGSM ε=4/255', 'FGSM ε=8/255',
                 'PGD ε=2/255', 'PGD ε=4/255', 'PGD ε=8/255', 'C&W L2']
attack_types = ['Baseline', 'FGSM', 'FGSM', 'FGSM', 'PGD', 'PGD', 'PGD', 'C&W']
attack_colors = ['#2E8B57', '#4169E1', '#4169E1', '#4169E1', '#DC143C', '#DC143C', '#DC143C', '#9400D3']

for seed in all_results:
    for key, label, atype, color in zip(attack_keys, attack_labels, attack_types, attack_colors):
        if key == 'clean':
            acc = all_results[seed]['clean']['accuracy']
        else:
            acc = all_results[seed][key]['robust_accuracy']
        attack_data.append({
            'Seed': f'Seed {seed}',
            'Attack': label,
            'Attack Type': atype,
            'Accuracy': acc,
            'Color': color
        })

attack_df = pd.DataFrame(attack_data)

# Create grouped bar chart
fig1 = px.bar(
    attack_df, 
    x='Attack', 
    y='Accuracy', 
    color='Seed',
    barmode='group',
    title='<b>Interactive Attack Comparison: Baseline Model Robustness</b><br><sup>Click legend to toggle seeds | Hover for details</sup>',
    labels={'Accuracy': 'Accuracy (%)', 'Attack': 'Attack Configuration'},
    color_discrete_sequence=['#FF6B6B', '#4ECDC4', '#45B7D1'],
    template='plotly_white'
)

fig1.update_layout(
    font=dict(family='Arial', size=12),
    title_font_size=18,
    title_x=0.5,
    legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
    xaxis_tickangle=-45,
    yaxis=dict(range=[0, 100], title_font_size=14),
    xaxis=dict(title_font_size=14),
    hovermode='x unified',
    bargap=0.15,
    bargroupgap=0.1
)

# Add mean line
mean_clean = attack_df[attack_df['Attack'] == 'Clean']['Accuracy'].mean()
fig1.add_hline(y=mean_clean, line_dash='dash', line_color='gray', 
               annotation_text=f'Clean Baseline: {mean_clean:.1f}%', annotation_position='top right')

fig1.show()
fig1.write_html(str(CONFIG['results_dir'] / 'figures' / 'attack_comparison_interactive.html'))
print("✅ Saved: attack_comparison_interactive.html")

# ===================== FIGURE 2: ROBUSTNESS DEGRADATION SURFACE =====================
# Create 3D surface for PGD attack analysis
fig2 = make_subplots(
    rows=1, cols=2,
    specs=[[{'type': 'scatter3d'}, {'type': 'heatmap'}]],
    subplot_titles=('3D Robustness Surface', 'Attack Intensity Heatmap'),
    horizontal_spacing=0.1
)

# Prepare 3D data
epsilons = [0, 2, 4, 8]
seeds_list = list(all_results.keys())

X, Y = np.meshgrid(epsilons, range(len(seeds_list)))
Z_fgsm = np.zeros_like(X, dtype=float)
Z_pgd = np.zeros_like(X, dtype=float)

for i, seed in enumerate(seeds_list):
    Z_fgsm[i, 0] = all_results[seed]['clean']['accuracy']
    Z_pgd[i, 0] = all_results[seed]['clean']['accuracy']
    for j, eps in enumerate([2, 4, 8]):
        Z_fgsm[i, j+1] = all_results[seed][f'fgsm_{eps}/255']['robust_accuracy']
        Z_pgd[i, j+1] = all_results[seed][f'pgd_{eps}/255']['robust_accuracy']

# Add 3D surface for PGD
fig2.add_trace(
    go.Surface(
        x=X, y=Y, z=Z_pgd,
        colorscale='RdYlGn',
        showscale=True,
        colorbar=dict(title='Accuracy (%)', x=0.45),
        name='PGD Robustness',
        hovertemplate='ε=%{x}/255<br>Seed=%{y}<br>Accuracy=%{z:.1f}%<extra></extra>'
    ),
    row=1, col=1
)

# Add heatmap for both attacks
combined_heatmap = np.vstack([Z_fgsm, Z_pgd])
heatmap_labels = [f'FGSM-S{s}' for s in seeds_list] + [f'PGD-S{s}' for s in seeds_list]

fig2.add_trace(
    go.Heatmap(
        z=combined_heatmap,
        x=['ε=0', 'ε=2/255', 'ε=4/255', 'ε=8/255'],
        y=heatmap_labels,
        colorscale='RdYlGn',
        showscale=False,
        text=np.round(combined_heatmap, 1),
        texttemplate='%{text}%',
        textfont=dict(size=10, color='black'),
        hovertemplate='%{y}<br>%{x}<br>Accuracy: %{z:.1f}%<extra></extra>'
    ),
    row=1, col=2
)

fig2.update_layout(
    title='<b>3D Robustness Analysis: PGD Attack Surface</b><br><sup>Drag to rotate | Scroll to zoom</sup>',
    title_font_size=16,
    title_x=0.5,
    font=dict(family='Arial', size=11),
    template='plotly_white',
    scene=dict(
        xaxis_title='Epsilon (×255)',
        yaxis_title='Seed Index',
        zaxis_title='Accuracy (%)',
        zaxis=dict(range=[0, 100]),
        camera=dict(eye=dict(x=1.5, y=1.5, z=1))
    ),
    height=500
)

fig2.show()
fig2.write_html(str(CONFIG['results_dir'] / 'figures' / 'robustness_3d_surface.html'))
print("✅ Saved: robustness_3d_surface.html")

# ===================== FIGURE 3: ANIMATED ROBUSTNESS DEGRADATION =====================
# Create animation data
animation_data = []
for eps_idx, eps in enumerate([0, 2, 4, 8]):
    for seed in all_results:
        for cls_idx, cls in enumerate(CONFIG['class_names']):
            if eps == 0:
                # Clean accuracy per class
                all_preds = all_results[seed]['clean']['predictions']
                all_labels = all_results[seed]['clean']['labels']
                cls_mask = all_labels == cls_idx
                acc = (all_preds[cls_mask] == all_labels[cls_mask]).mean() * 100 if cls_mask.sum() > 0 else 0
            else:
                acc = all_results[seed][f'pgd_{eps}/255']['per_class_robust_acc'][cls]
            
            animation_data.append({
                'Epsilon': f'ε={eps}/255' if eps > 0 else 'Clean',
                'Epsilon_num': eps,
                'Class': cls,
                'Accuracy': acc,
                'Seed': f'Seed {seed}'
            })

anim_df = pd.DataFrame(animation_data)

fig3 = px.bar(
    anim_df,
    x='Class',
    y='Accuracy',
    color='Seed',
    animation_frame='Epsilon',
    title='<b>Animated: Per-Class Robustness Under Increasing PGD Attack Strength</b><br><sup>Press play to see degradation | Each bar = seed performance</sup>',
    labels={'Accuracy': 'Accuracy (%)', 'Class': 'Skin Lesion Class'},
    color_discrete_sequence=['#FF6B6B', '#4ECDC4', '#45B7D1'],
    template='plotly_white',
    barmode='group'
)

fig3.update_layout(
    font=dict(family='Arial', size=12),
    title_font_size=16,
    title_x=0.5,
    yaxis=dict(range=[0, 100]),
    legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
    updatemenus=[dict(
        type='buttons',
        showactive=False,
        y=0,
        x=0.1,
        xanchor='right',
        yanchor='top',
        buttons=[
            dict(label='▶ Play', method='animate', args=[None, {'frame': {'duration': 1000, 'redraw': True}, 'fromcurrent': True}]),
            dict(label='⏸ Pause', method='animate', args=[[None], {'frame': {'duration': 0, 'redraw': False}, 'mode': 'immediate'}])
        ]
    )]
)

fig3.show()
fig3.write_html(str(CONFIG['results_dir'] / 'figures' / 'animated_robustness.html'))
print("✅ Saved: animated_robustness.html")

# ===================== FIGURE 4: SUNBURST VULNERABILITY BREAKDOWN =====================
sunburst_data = []
for attack_type in ['FGSM', 'PGD', 'C&W']:
    for eps in ['Weak (ε=2)', 'Medium (ε=4)', 'Strong (ε=8)'] if attack_type != 'C&W' else ['Optimization']:
        for cls in CONFIG['class_names']:
            if attack_type == 'FGSM':
                key = f'fgsm_{eps.split("=")[1].split(")")[0]}/255' if eps != 'Optimization' else None
            elif attack_type == 'PGD':
                key = f'pgd_{eps.split("=")[1].split(")")[0]}/255' if eps != 'Optimization' else None
            else:
                key = 'cw'
            
            if key:
                accs = [all_results[s][key]['per_class_robust_acc'][cls] for s in all_results]
                vulnerability = 100 - np.mean(accs)
                sunburst_data.append({
                    'Attack Type': attack_type,
                    'Strength': eps,
                    'Class': cls,
                    'Vulnerability': vulnerability,
                    'Path': f'{attack_type}/{eps}/{cls}'
                })

sunburst_df = pd.DataFrame(sunburst_data)

fig4 = px.sunburst(
    sunburst_df,
    path=['Attack Type', 'Strength', 'Class'],
    values='Vulnerability',
    color='Vulnerability',
    color_continuous_scale='Reds',
    title='<b>Hierarchical Vulnerability Breakdown</b><br><sup>Click to drill down | Size = Vulnerability (100% - Accuracy)</sup>',
)

fig4.update_layout(
    font=dict(family='Arial', size=12),
    title_font_size=16,
    title_x=0.5,
)

fig4.show()
fig4.write_html(str(CONFIG['results_dir'] / 'figures' / 'vulnerability_sunburst.html'))
print("✅ Saved: vulnerability_sunburst.html")

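The Plotly cell above looks up results with string keys such as `fgsm_8/255`, `pgd_2/255` and `cw`, and the sunburst loop recovers ε by parsing a display label. One alternative, sketched here on the assumption that those key names are the only ones in use, is a small helper that builds the key from the attack name and the integer ε. `result_key` and `STRENGTH_LABELS` are hypothetical names introduced for illustration.

```python
# Display label -> integer epsilon on the 0-255 pixel scale, as shown in the sunburst chart.
STRENGTH_LABELS = {"Weak (ε=2)": 2, "Medium (ε=4)": 4, "Strong (ε=8)": 8}

def result_key(attack_type, eps=None):
    """Build a results-dictionary key such as 'pgd_8/255' or 'cw'."""
    if attack_type == "C&W":
        return "cw"  # the C&W entry carries no epsilon in its key
    if attack_type not in ("FGSM", "PGD"):
        raise ValueError(f"unknown attack type: {attack_type!r}")
    if eps is None:
        raise ValueError(f"{attack_type} requires an epsilon")
    return f"{attack_type.lower()}_{eps}/255"

keys = [result_key(attack, eps)
        for attack in ("FGSM", "PGD")
        for eps in STRENGTH_LABELS.values()]
keys.append(result_key("C&W"))
print(keys)
```

Keeping ε as an integer in one mapping avoids coupling the lookup to the exact wording of a chart label, which matters if the labels are later restyled.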
In [None]:
# ============================================================================
# CELL 14: COMPREHENSIVE RESULTS EXPORT & DISSERTATION-READY FIGURES
# ============================================================================
print("=" * 70)
print("💾 SAVING ALL RESULTS & GENERATING DISSERTATION FIGURES")
print("=" * 70)

# ===================== DISSERTATION-QUALITY SUMMARY FIGURE =====================
fig, axes = plt.subplots(2, 2, figsize=(16, 14))

# ------ Panel A: Clean vs Robust Accuracy Comparison ------
ax = axes[0, 0]
clean_mean = np.mean([all_results[s]['clean']['accuracy'] for s in all_results])
attacks_compare = [
    ('Clean', clean_mean, '#2E8B57'),
    ('FGSM ε=8/255', np.mean([all_results[s]['fgsm_8/255']['robust_accuracy'] for s in all_results]), '#4169E1'),
    ('PGD-40 ε=8/255', np.mean([all_results[s]['pgd_8/255']['robust_accuracy'] for s in all_results]), '#DC143C'),
    ('C&W L2', np.mean([all_results[s]['cw']['robust_accuracy'] for s in all_results]), '#9400D3'),
]

names = [a[0] for a in attacks_compare]
values = [a[1] for a in attacks_compare]
colors = [a[2] for a in attacks_compare]

bars = ax.bar(names, values, color=colors, edgecolor='white', linewidth=2, alpha=0.85)
ax.set_ylabel('Accuracy (%)', fontsize=13, fontweight='bold')
ax.set_title('(a) Model Accuracy Under Strongest Attacks', fontsize=14, fontweight='bold')
ax.set_ylim(0, 100)
ax.axhline(y=50, color='gray', linestyle=':', alpha=0.5)

for bar, val in zip(bars, values):
    ax.annotate(f'{val:.1f}%', xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                xytext=(0, 5), textcoords='offset points', ha='center', fontsize=12, fontweight='bold')

# Add drop annotations
for i, (name, val, color) in enumerate(attacks_compare[1:], 1):
    drop = clean_mean - val
    ax.annotate(f'↓{drop:.1f}%', xy=(i, val/2), ha='center', va='center',
                fontsize=10, color='white', fontweight='bold')

ax.grid(True, alpha=0.3, axis='y')

# ------ Panel B: Epsilon Sensitivity ------
ax = axes[0, 1]
eps_vals = [2, 4, 8]

fgsm_means = [np.mean([all_results[s][f'fgsm_{e}/255']['robust_accuracy'] for s in all_results]) for e in eps_vals]
pgd_means = [np.mean([all_results[s][f'pgd_{e}/255']['robust_accuracy'] for s in all_results]) for e in eps_vals]
fgsm_stds = [np.std([all_results[s][f'fgsm_{e}/255']['robust_accuracy'] for s in all_results]) for e in eps_vals]
pgd_stds = [np.std([all_results[s][f'pgd_{e}/255']['robust_accuracy'] for s in all_results]) for e in eps_vals]

ax.errorbar(eps_vals, fgsm_means, yerr=fgsm_stds, marker='o', markersize=12, linewidth=3,
            color='#4169E1', label='FGSM (1-step)', capsize=5, capthick=2)
ax.errorbar(eps_vals, pgd_means, yerr=pgd_stds, marker='s', markersize=12, linewidth=3,
            color='#DC143C', label='PGD-40 (iterative)', capsize=5, capthick=2)
ax.axhline(y=clean_mean, color='#2E8B57', linestyle='--', linewidth=2, label=f'Clean ({clean_mean:.1f}%)')

ax.set_xlabel('Perturbation Budget ε (×255)', fontsize=13, fontweight='bold')
ax.set_ylabel('Robust Accuracy (%)', fontsize=13, fontweight='bold')
ax.set_title('(b) Impact of Perturbation Strength', fontsize=14, fontweight='bold')
ax.set_xticks(eps_vals)
ax.set_ylim(0, 100)
ax.legend(loc='upper right', fontsize=11)
ax.grid(True, alpha=0.3)

# Fill area between curves
ax.fill_between(eps_vals, fgsm_means, pgd_means, alpha=0.1, color='gray')
ax.annotate('Gap', xy=(5, (fgsm_means[1] + pgd_means[1])/2), fontsize=10, style='italic')

# ------ Panel C: Per-Class Vulnerability Radar ------
ax = axes[1, 0]
ax.axis('off')

# Create radar subplot
ax_radar = fig.add_subplot(2, 2, 3, projection='polar')

# Data for strongest attack (PGD ε=8/255)
values_radar = []
for cls in CONFIG['class_names']:
    accs = [all_results[s]['pgd_8/255']['per_class_robust_acc'][cls] for s in all_results]
    values_radar.append(np.mean(accs))

# Close the loop
values_radar_closed = values_radar + [values_radar[0]]
angles = np.linspace(0, 2*np.pi, len(CONFIG['class_names']), endpoint=False).tolist()
angles += [angles[0]]

ax_radar.plot(angles, values_radar_closed, 'o-', linewidth=2.5, color='#DC143C', markersize=8)
ax_radar.fill(angles, values_radar_closed, alpha=0.25, color='#DC143C')
ax_radar.set_xticks(angles[:-1])
ax_radar.set_xticklabels(CONFIG['class_names'], fontsize=11, fontweight='bold')
ax_radar.set_ylim(0, 100)
ax_radar.set_title('(c) Per-Class Robustness Profile\n(PGD-40 ε=8/255)', fontsize=14, fontweight='bold', pad=20)

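# ------ Panel D: Per-Class Accuracy Drop Under PGD ------
ax = axes[1, 1]

# Calculate the accuracy drop per class (clean accuracy minus PGD ε=8/255 robust accuracy).
# This is a difference in accuracy, not the fraction of correctly classified samples that flip.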
success_rates = []
for cls in CONFIG['class_names']:
    clean_acc = []
    attack_acc = []
    for seed in all_results:
        all_preds = all_results[seed]['clean']['predictions']
        all_labels = all_results[seed]['clean']['labels']
        cls_idx = CONFIG['class_names'].index(cls)
        cls_mask = all_labels == cls_idx
        if cls_mask.sum() > 0:
            clean_acc.append((all_preds[cls_mask] == all_labels[cls_mask]).mean() * 100)
            attack_acc.append(all_results[seed]['pgd_8/255']['per_class_robust_acc'][cls])
    
    if clean_acc:
        drop = np.mean(clean_acc) - np.mean(attack_acc)
        success_rates.append(drop)
    else:
        success_rates.append(0)

sorted_indices = np.argsort(success_rates)[::-1]
sorted_classes = [CONFIG['class_names'][i] for i in sorted_indices]
sorted_rates = [success_rates[i] for i in sorted_indices]

colors_sr = plt.cm.Reds(np.linspace(0.3, 0.9, len(sorted_classes)))
bars = ax.barh(sorted_classes, sorted_rates, color=colors_sr, edgecolor='white', linewidth=1.5)

ax.set_xlabel('Accuracy Drop Under PGD-40 ε=8/255 (%)', fontsize=13, fontweight='bold')
ax.set_title('(d) Class-wise Accuracy Drop', fontsize=14, fontweight='bold')
ax.axvline(x=np.mean(sorted_rates), color='navy', linestyle='--', linewidth=2, label=f'Mean: {np.mean(sorted_rates):.1f}%')
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3, axis='x')

for bar, rate in zip(bars, sorted_rates):
    ax.annotate(f'{rate:.1f}%', xy=(rate, bar.get_y() + bar.get_height()/2),
                xytext=(5, 0), textcoords='offset points', va='center', fontsize=11, fontweight='bold')

plt.suptitle('Phase 4: Adversarial Robustness Evaluation Summary\nResNet-50 Baseline on ISIC 2018 Dermoscopy Dataset',
             fontsize=18, fontweight='bold', y=0.98)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(CONFIG['results_dir'] / 'figures' / 'dissertation_figure_robustness.png', dpi=300, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.savefig(CONFIG['results_dir'] / 'figures' / 'dissertation_figure_robustness.pdf', dpi=300, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.show()
print("✅ Saved: dissertation_figure_robustness.png (300 DPI)")
print("✅ Saved: dissertation_figure_robustness.pdf (Vector)")

# ===================== EXPORT ALL RESULTS =====================
# Prepare comprehensive JSON export
export_results = {
    'metadata': {
        'experiment': 'Phase 4 Adversarial Robustness Evaluation',
        'model': 'ResNet-50 (ImageNet pretrained)',
        'dataset': 'ISIC 2018 Dermoscopy',
        'num_classes': 7,
        'class_names': CONFIG['class_names'],
        'seeds_evaluated': list(all_results.keys()),
        'attacks_evaluated': ['FGSM', 'PGD-40', 'Carlini-Wagner L2'],
        'epsilons': [2/255, 4/255, 8/255],
        'timestamp': datetime.now().isoformat(),
    },
    'summary': {
        'clean_accuracy': {
            'mean': float(np.mean([all_results[s]['clean']['accuracy'] for s in all_results])),
            'std': float(np.std([all_results[s]['clean']['accuracy'] for s in all_results])),
            'per_seed': {str(s): float(all_results[s]['clean']['accuracy']) for s in all_results}
        },
        'robust_accuracy_fgsm_8': {
            'mean': float(np.mean([all_results[s]['fgsm_8/255']['robust_accuracy'] for s in all_results])),
            'std': float(np.std([all_results[s]['fgsm_8/255']['robust_accuracy'] for s in all_results])),
        },
        'robust_accuracy_pgd_8': {
            'mean': float(np.mean([all_results[s]['pgd_8/255']['robust_accuracy'] for s in all_results])),
            'std': float(np.std([all_results[s]['pgd_8/255']['robust_accuracy'] for s in all_results])),
        },
        'robust_accuracy_cw': {
            'mean': float(np.mean([all_results[s]['cw']['robust_accuracy'] for s in all_results])),
            'std': float(np.std([all_results[s]['cw']['robust_accuracy'] for s in all_results])),
        },
    },
    'detailed_results': {}
}

for seed in all_results:
    export_results['detailed_results'][str(seed)] = {}
    for attack_key, attack_result in all_results[seed].items():
        export_results['detailed_results'][str(seed)][attack_key] = {
            k: v.tolist() if isinstance(v, np.ndarray) else v
            for k, v in attack_result.items()
            if k not in ['predictions', 'labels', 'probs', 'confusion_matrix']
        }

# Save JSON
results_file = CONFIG['results_dir'] / 'adversarial_results_full.json'
with open(results_file, 'w') as f:
    json.dump(export_results, f, indent=2)
print(f"✅ Full results saved to: {results_file}")

# ===================== LIST ALL SAVED FILES =====================
print(f"\n{'='*70}")
print(f"📁 ALL SAVED FILES IN {CONFIG['results_dir']}:")
print(f"{'='*70}")

for f in sorted(CONFIG['results_dir'].glob('**/*')):
    if f.is_file():
        size_kb = f.stat().st_size / 1024
        rel_path = f.relative_to(CONFIG['results_dir'])
        icon = '📊' if f.suffix in ['.png', '.pdf'] else '📄' if f.suffix == '.html' else '💾'
        print(f"   {icon} {rel_path} ({size_kb:.1f} KB)")
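
The summary above aggregates per-seed metrics with `np.std`, which defaults to the population standard deviation (`ddof=0`). With only three seeds, the sample standard deviation (`ddof=1`) is the more conservative figure to report. The sketch below shows the difference using only the standard library; the helper name `summarize_across_seeds` and the accuracy values are illustrative assumptions, not results from this notebook.

```python
import statistics

def summarize_across_seeds(per_seed):
    """Return mean, population std, and sample std for a {seed: value} mapping."""
    values = list(per_seed.values())
    return {
        "mean": statistics.fmean(values),
        "std_population": statistics.pstdev(values),  # same convention as np.std(values)
        "std_sample": statistics.stdev(values),        # same convention as np.std(values, ddof=1)
        "n_seeds": len(values),
    }

# Hypothetical per-seed accuracies, for illustration only.
example = {42: 0.80, 123: 0.82, 456: 0.84}
summary = summarize_across_seeds(example)
print(summary)
```

If the sample convention is preferred, passing `ddof=1` to the existing `np.std` calls would give the same numbers as `std_sample` here.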

In [None]:
# ============================================================================
# CELL 15: EXECUTIVE SUMMARY & NEXT STEPS
# ============================================================================
print("\n")
print("╔" + "═"*78 + "╗")
print("║" + "🎯 PHASE 4 COMPLETE".center(78) + "║")
print("║" + "ADVERSARIAL ROBUSTNESS EVALUATION SUMMARY".center(78) + "║")
print("╚" + "═"*78 + "╝")

# ===================== KEY METRICS =====================
clean_mean = np.mean([all_results[s]['clean']['accuracy'] for s in all_results])
clean_std = np.std([all_results[s]['clean']['accuracy'] for s in all_results])
fgsm8_mean = np.mean([all_results[s]['fgsm_8/255']['robust_accuracy'] for s in all_results])
pgd8_mean = np.mean([all_results[s]['pgd_8/255']['robust_accuracy'] for s in all_results])
cw_mean = np.mean([all_results[s]['cw']['robust_accuracy'] for s in all_results])

# Calculate average drop
avg_drop = np.mean([clean_mean - fgsm8_mean, clean_mean - pgd8_mean, clean_mean - cw_mean])

print(f"""
┌─────────────────────────────────────────────────────────────────────────────┐
│                      📊 BASELINE MODEL PERFORMANCE                          │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│   🎯 Clean Accuracy:         {clean_mean:>6.2f}% ± {clean_std:.2f}%                            │
│                                                                             │
│   ⚔️  Adversarial Robustness (Strong Attacks):                              │
│   ├── FGSM (ε=8/255):        {fgsm8_mean:>6.2f}%  │  Drop: {clean_mean - fgsm8_mean:>5.2f}%                 │
│   ├── PGD-40 (ε=8/255):      {pgd8_mean:>6.2f}%  │  Drop: {clean_mean - pgd8_mean:>5.2f}%                 │
│   └── Carlini-Wagner L2:     {cw_mean:>6.2f}%  │  Drop: {clean_mean - cw_mean:>5.2f}%                 │
│                                                                             │
│   📉 Average Robustness Drop: {avg_drop:.2f}%                                       │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
""")

# ===================== KEY FINDINGS =====================
most_vulnerable = CONFIG['class_names'][np.argmin([
    np.mean([all_results[s]['pgd_8/255']['per_class_robust_acc'][c] for s in all_results])
    for c in CONFIG['class_names']
])]
most_robust = CONFIG['class_names'][np.argmax([
    np.mean([all_results[s]['pgd_8/255']['per_class_robust_acc'][c] for s in all_results])
    for c in CONFIG['class_names']
])]

print(f"""
┌─────────────────────────────────────────────────────────────────────────────┐
│                         🔬 KEY RESEARCH FINDINGS                            │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  1️⃣  CRITICAL VULNERABILITY                                                 │
│      Standard CNNs show SEVERE vulnerability to adversarial attacks         │
│      • Up to {max([clean_mean - fgsm8_mean, clean_mean - pgd8_mean, clean_mean - cw_mean]):.1f}% accuracy degradation under imperceptible perturbations    │
│      • Perturbations invisible to human eye (ε ≤ 8/255)                     │
│                                                                             │
│  2️⃣  ATTACK COMPARISON                                                      │
│      • PGD-40 is more effective than single-step FGSM                       │
│      • C&W finds minimum perturbation for misclassification                 │
│      • Iterative attacks reveal true model fragility                        │
│                                                                             │
│  3️⃣  CLASS-WISE ANALYSIS                                                    │
│      • Most vulnerable class: {most_vulnerable:<8}                                      │
│      • Most robust class: {most_robust:<8}                                          │
│      • Vulnerability varies significantly across classes                    │
│                                                                             │
│  4️⃣  CLINICAL IMPLICATIONS                                                  │
│      ⚠️  Baseline models are NOT SAFE for clinical deployment               │
│      ⚠️  Adversarial training is ESSENTIAL before real-world use            │
│      ⚠️  All skin lesion classes show significant vulnerability             │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
""")

# ===================== VISUALIZATIONS GENERATED =====================
print(f"""
┌─────────────────────────────────────────────────────────────────────────────┐
│                      📊 VISUALIZATIONS GENERATED                            │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Static Figures (PNG/PDF, 300 DPI):                                         │
│  ├── 📈 robustness_dashboard.png          - Multi-panel analysis            │
│  ├── 🔥 class_vulnerability_analysis.png  - Per-class breakdown             │
│  └── 📄 dissertation_figure_robustness.pdf - Publication-ready              │
│                                                                             │
│  Interactive Figures (HTML):                                                │
│  ├── 🖱️  attack_comparison_interactive.html                                 │
│  ├── 🌐 robustness_3d_surface.html        - 3D rotatable surface            │
│  ├── ▶️  animated_robustness.html          - Epsilon animation              │
│  └── 🌳 vulnerability_sunburst.html       - Hierarchical breakdown          │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
""")

# ===================== NEXT STEPS =====================
print(f"""
┌─────────────────────────────────────────────────────────────────────────────┐
│                          🚀 NEXT STEPS                                      │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ✅ PHASE 4 COMPLETE: Adversarial Robustness Evaluation                     │
│                                                                             │
│  📌 PHASE 5: Tri-Objective Robust Training                                  │
│     • Implement adversarial training with PGD-AT                            │
│     • Add explainability preservation objective                             │
│     • Multi-objective optimization (Accuracy + Robustness + XAI)            │
│                                                                             │
│  📌 PHASE 6: Explainability Analysis                                        │
│     • Grad-CAM visualization comparison                                     │
│     • SHAP analysis for feature importance                                  │
│     • XAI consistency under adversarial perturbations                       │
│                                                                             │
│  📌 PHASE 7: Comparative Evaluation                                         │
│     • Baseline vs Robust model comparison                                   │
│     • Trade-off analysis (Accuracy-Robustness-Explainability)               │
│     • Statistical significance testing                                      │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
""")

print("═"*80)
print("   ✅ PHASE 4 ADVERSARIAL ROBUSTNESS EVALUATION SUCCESSFULLY COMPLETED!")
print("═"*80)
print(f"\n📁 All results saved to: {CONFIG['results_dir']}")
print("📊 Total figures generated: 7 (3 static + 4 interactive)")
print("💾 Data files: adversarial_results.csv, adversarial_results_full.json")
print("\n🔗 Run Phase 5 notebook to continue with adversarial training!")
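
The executive summary derives each accuracy drop, the average drop, and the worst case by repeating `clean_mean - <attack>_mean` inline. A small helper can compute all three from one mapping, so adding an attack does not require touching several expressions. This is a sketch under assumptions: `summarize_drops` is a hypothetical name and the accuracies below are made-up percentages, not measured results.

```python
def summarize_drops(clean_acc, robust_acc):
    """Compute per-attack accuracy drops, their mean, and the worst-case attack."""
    drops = {name: clean_acc - acc for name, acc in robust_acc.items()}
    worst_attack = max(drops, key=drops.get)
    return {
        "drops": drops,
        "average_drop": sum(drops.values()) / len(drops),
        "worst_attack": worst_attack,
        "max_drop": drops[worst_attack],
    }

# Hypothetical accuracies in percent, for illustration only.
report = summarize_drops(80.0, {"FGSM": 30.0, "PGD-40": 10.0, "C&W": 20.0})
for name, drop in report["drops"].items():
    print(f"{name:<8} drop: {drop:5.2f}%")
print(f"Average drop: {report['average_drop']:.2f}%  (worst: {report['worst_attack']})")
```

The inputs need a consistent scale: if the stored accuracies are fractions in [0, 1], the `%` suffix in the printed summary would be misleading unless the values are multiplied by 100 first.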

In [None]:
#@title 🔧 Cell 2: Environment Setup & Dependencies
#@markdown **Run this cell first to install all required packages**

import subprocess
import sys

def install_packages():
    """Install required packages for adversarial evaluation."""
    packages = [
        "torch>=2.0.0",
        "torchvision>=0.15.0",
        "timm>=0.9.0",
        "albumentations>=1.3.0",
        "scikit-learn>=1.3.0",
        "pandas>=2.0.0",
        "numpy>=1.24.0",
        "matplotlib>=3.7.0",
        "seaborn>=0.12.0",
        "plotly>=5.15.0",
        "kaleido",  # For plotly static export
        "tqdm>=4.65.0",
        "mlflow>=2.5.0",
        "scipy>=1.11.0",
    ]
    
    print("📦 Installing packages...")
    for pkg in packages:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])
    print("✅ All packages installed!")

# Check if running in Colab
try:
    import google.colab
    IN_COLAB = True
    print("🌐 Running in Google Colab")
    install_packages()
except ImportError:
    IN_COLAB = False
    print("💻 Running locally")

# Core imports
import os
import gc
import json
import time
import warnings
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ML utilities
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, f1_score,
    confusion_matrix, classification_report, roc_auc_score
)
from tqdm.auto import tqdm

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

# Set visualization style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.dpi': 100,
})

# GPU Configuration
print("\n" + "="*60)
print("🖥️  HARDWARE CONFIGURATION")
print("="*60)

if torch.cuda.is_available():
    device = torch.device("cuda")
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✅ GPU: {gpu_name}")
    print(f"✅ VRAM: {gpu_mem:.1f} GB")
    
    # Enable TF32 on Ampere or newer GPUs (compute capability >= 8.0)
    if "A100" in gpu_name or torch.cuda.get_device_capability()[0] >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print("✅ TF32 enabled (Ampere or newer GPU)")
    
    torch.backends.cudnn.benchmark = True
    print("✅ cuDNN benchmark enabled")
else:
    device = torch.device("cpu")
    print("⚠️ No GPU found, using CPU (will be slow)")

print(f"✅ PyTorch version: {torch.__version__}")
print(f"✅ Device: {device}")
print("="*60)
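
The setup cell installs packages only in Colab, so a local run can fail later on a missing import. One possible pre-flight check is to ask `importlib` which modules are discoverable without importing them. The sketch below is an assumption-laden illustration: `missing_modules` and the `IMPORT_NAMES` mapping are hypothetical, and the mapping is needed because pip distribution names and import names can differ (for example `scikit-learn` is imported as `sklearn`).

```python
import importlib.util

# Assumed mapping from pip distribution name to top-level import name.
IMPORT_NAMES = {
    "scikit-learn": "sklearn",
    "torch": "torch",
    "timm": "timm",
    "albumentations": "albumentations",
}

def missing_modules(import_names):
    """Return the top-level import names that cannot be located in this environment."""
    return [name for name in import_names if importlib.util.find_spec(name) is None]

to_check = list(IMPORT_NAMES.values())
print("Missing:", missing_modules(to_check) or "none")
```

`find_spec` only locates a module, so the check stays cheap even for heavy packages such as `torch`.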

In [None]:
#@title 🗂️ Cell 3: Mount Google Drive & Configure Paths
#@markdown **Configure data and checkpoint paths**

# Mount Google Drive (Colab only)
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    print("✅ Google Drive mounted")

# ============================================================================
# PATH CONFIGURATION - Adjust these paths as needed
# ============================================================================

@dataclass
class PathConfig:
    """Central path configuration for the evaluation pipeline."""
    
    # Base paths
    drive_base: Path = Path("/content/drive/MyDrive")
    
    # Data paths
    data_root: Path = field(default=None)
    train_dir: Path = field(default=None)
    val_dir: Path = field(default=None)
    test_dir: Path = field(default=None)
    
    # Checkpoint paths
    checkpoint_dir: Path = field(default=None)
    
    # Output paths
    results_dir: Path = field(default=None)
    figures_dir: Path = field(default=None)
    
    def __post_init__(self):
        """Initialize derived paths."""
        if self.data_root is None:
            self.data_root = self.drive_base / "data" / "data" / "isic_2018"
        if self.train_dir is None:
            self.train_dir = self.data_root / "train"
        if self.val_dir is None:
            self.val_dir = self.data_root / "val"
        if self.test_dir is None:
            self.test_dir = self.data_root / "test"
        if self.checkpoint_dir is None:
            self.checkpoint_dir = self.drive_base / "checkpoints" / "baseline"
        if self.results_dir is None:
            self.results_dir = self.drive_base / "results" / "phase4_adversarial"
        if self.figures_dir is None:
            self.figures_dir = self.results_dir / "figures"
    
    def validate(self) -> bool:
        """Validate that required paths exist."""
        required = [self.data_root, self.test_dir, self.checkpoint_dir]
        missing = [p for p in required if not p.exists()]
        
        if missing:
            print("❌ Missing paths:")
            for p in missing:
                print(f"   - {p}")
            return False
        return True
    
    def create_output_dirs(self):
        """Create output directories if they don't exist."""
        self.results_dir.mkdir(parents=True, exist_ok=True)
        self.figures_dir.mkdir(parents=True, exist_ok=True)
        print(f"✅ Results directory: {self.results_dir}")
        print(f"✅ Figures directory: {self.figures_dir}")

# Initialize paths
paths = PathConfig()

# Validate paths
print("\n" + "="*60)
print("📁 PATH VALIDATION")
print("="*60)

if paths.validate():
    print(f"✅ Data root: {paths.data_root}")
    print(f"✅ Test directory: {paths.test_dir}")
    print(f"✅ Checkpoint directory: {paths.checkpoint_dir}")
    paths.create_output_dirs()
else:
    print("\n⚠️ Please update PathConfig with correct paths!")
    
# List available checkpoints
print("\n📦 Available Checkpoints:")
# Report "none found" both when the directory is missing and when it is empty
checkpoints = sorted(paths.checkpoint_dir.glob("*.pt")) if paths.checkpoint_dir.exists() else []
if checkpoints:
    for ckpt in checkpoints:
        size_mb = ckpt.stat().st_size / 1e6
        print(f"   - {ckpt.name} ({size_mb:.1f} MB)")
else:
    print("   ⚠️ No checkpoints found!")

print("="*60)
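
`PathConfig.validate()` reduces to "which of the required paths do not exist". That logic can be exercised without Google Drive by pointing it at a temporary directory, as in the sketch below. The helper `find_missing` and the directory names are hypothetical stand-ins, not part of the notebook.

```python
import tempfile
from pathlib import Path

def find_missing(required):
    """Return the subset of `required` paths that do not exist on disk."""
    return [p for p in required if not Path(p).exists()]

with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    # Create only the data directories; leave the checkpoint directory absent.
    (base / "data" / "test").mkdir(parents=True)
    required = [base / "data", base / "data" / "test", base / "checkpoints"]
    missing = find_missing(required)
    missing_names = [p.name for p in missing]
    print("Missing:", missing_names)
```

For a local run, the same idea applies to the real class: because every derived path is built from `drive_base` in `__post_init__`, constructing `PathConfig(drive_base=Path(...))` with a local directory should relocate all defaults at once.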

In [None]:
#@title 📥 Cell 4: Clone Repository & Import Attack Classes
#@markdown **Clone the project repository and import custom modules**

import os
import sys

# Clone repository (Colab only)
REPO_URL = "https://github.com/viraj1011JAIN/tri-objective-robust-xai-medimg.git"
REPO_DIR = "/content/tri-objective-robust-xai-medimg"

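if IN_COLAB:
    import subprocess  # stdlib; imported here so this cell does not rely on Cell 2

    if not os.path.exists(REPO_DIR):
        print("📥 Cloning repository...")
        # check=True raises if the clone fails, so success is not reported falsely
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
        print("✅ Repository cloned!")
    else:
        print("📁 Repository already exists, pulling latest...")
        pull = subprocess.run(["git", "-C", REPO_DIR, "pull"])
        if pull.returncode != 0:
            print("⚠️ git pull failed, continuing with the existing checkout")
    
    # Add to Python path
    if REPO_DIR not in sys.path:
        sys.path.insert(0, REPO_DIR)
    print(f"✅ Added {REPO_DIR} to Python path")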
else:
    # Local development - find project root
    current_dir = Path.cwd()
    if "notebooks" in str(current_dir):
        project_root = current_dir.parent
    else:
        project_root = current_dir
    
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
    print(f"✅ Using local project: {project_root}")

# Import custom attack classes
print("\n🔧 Importing attack modules...")

try:
    from src.attacks.fgsm import FGSM, FGSMConfig, fgsm_attack
    from src.attacks.pgd import PGD, PGDConfig, pgd_attack
    from src.attacks.cw import CarliniWagner, CWConfig, cw_attack
    from src.attacks.base import AttackConfig, AttackResult
    print("✅ FGSM attack imported")
    print("✅ PGD attack imported")
    print("✅ Carlini-Wagner attack imported")
    print("✅ Base attack classes imported")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("⚠️ Please ensure the repository is properly cloned")
    raise

# Import dataset utilities
try:
    from src.datasets.isic import ISICDataset
    print("✅ ISICDataset imported")
except ImportError:
    print("⚠️ ISICDataset not found, will use custom implementation")
    ISICDataset = None

# Import model utilities
import timm
print(f"✅ timm version: {timm.__version__}")

print("\n" + "="*60)
print("✅ ALL MODULES IMPORTED SUCCESSFULLY")
print("="*60)
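
For local runs, the cell above infers the project root from whether the string `"notebooks"` appears anywhere in the working directory, and then steps up exactly one level. A more tolerant alternative is to walk upward until a marker directory is found. The sketch below assumes the repository root is the directory that contains `src/` (consistent with the `src.attacks` imports); `find_project_root` is a hypothetical helper, not part of the repository.

```python
import tempfile
from pathlib import Path

def find_project_root(start, marker="src"):
    """Walk upward from `start` and return the first directory containing `marker`."""
    start = Path(start).resolve()
    for candidate in [start, *start.parents]:
        if (candidate / marker).is_dir():
            return candidate
    return start  # Fall back to the starting directory if no marker is found.

with tempfile.TemporaryDirectory() as tmp:
    # Build a throwaway layout: project/src and project/notebooks/phase4
    root = Path(tmp).resolve() / "project"
    (root / "src").mkdir(parents=True)
    nested = root / "notebooks" / "phase4"
    nested.mkdir(parents=True)
    found = find_project_root(nested)
    is_root = found == root
    print(found.name)
```

This also works when the notebook is launched from a subfolder more than one level below the root.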

In [None]:
#@title 📊 Cell 5: Dataset & Model Loading Utilities
#@markdown **Define dataset wrapper and model loading functions**

import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset, DataLoader

# ============================================================================
# ISIC 2018 CLASS INFORMATION
# ============================================================================

CLASS_NAMES = [
    "AKIEC",  # Actinic Keratoses
    "BCC",    # Basal Cell Carcinoma
    "BKL",    # Benign Keratosis
    "DF",     # Dermatofibroma
    "MEL",    # Melanoma
    "NV",     # Melanocytic Nevi
    "VASC"    # Vascular Lesions
]

CLASS_DESCRIPTIONS = {
    "AKIEC": "Actinic Keratoses (pre-cancerous)",
    "BCC": "Basal Cell Carcinoma (cancerous)",
    "BKL": "Benign Keratosis (non-cancerous)",
    "DF": "Dermatofibroma (benign)",
    "MEL": "Melanoma (malignant, dangerous)",
    "NV": "Melanocytic Nevi (common moles)",
    "VASC": "Vascular Lesions (blood vessel)"
}

NUM_CLASSES = len(CLASS_NAMES)

# ImageNet normalization (used by pretrained models)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# ============================================================================
# DATASET CLASS
# ============================================================================

class ISICTestDataset(Dataset):
    """
    ISIC 2018 Test Dataset for adversarial evaluation.
    
    Returns unnormalized images in [0, 1] range for adversarial attacks.
    Normalization is applied separately during model inference.
    """
    
    def __init__(
        self,
        root_dir: Path,
        transform: Optional[A.Compose] = None,
        max_samples: Optional[int] = None
    ):
        self.root_dir = Path(root_dir)
        self.transform = transform or self._default_transform()
        
        # Collect samples
        self.samples = []
        self.class_to_idx = {name: idx for idx, name in enumerate(CLASS_NAMES)}
        
        for class_name in CLASS_NAMES:
            class_dir = self.root_dir / class_name
            if class_dir.exists():
                images = list(class_dir.glob("*.jpg")) + list(class_dir.glob("*.png"))
                for img_path in images:
                    self.samples.append((img_path, self.class_to_idx[class_name]))
        
        # Limit samples if specified
        if max_samples and len(self.samples) > max_samples:
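            # Balanced (not proportional) subsampling: keep the first
            # max_samples // NUM_CLASSES images of each class, in glob order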
            from collections import defaultdict
            by_class = defaultdict(list)
            for path, label in self.samples:
                by_class[label].append((path, label))
            
            per_class = max_samples // NUM_CLASSES
            self.samples = []
            for label, items in by_class.items():
                self.samples.extend(items[:per_class])
        
        print(f"📊 Loaded {len(self.samples)} test samples from {self.root_dir}")
        
        # Print class distribution
        class_counts = {}
        for _, label in self.samples:
            class_counts[label] = class_counts.get(label, 0) + 1
        for idx, name in enumerate(CLASS_NAMES):
            print(f"   {name}: {class_counts.get(idx, 0)} samples")
    
    def _default_transform(self) -> A.Compose:
        """Default test transform: resize and convert to tensor."""
        return A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),  # Keep in [0, 1]
            ToTensorV2()
        ])
    
    def __len__(self) -> int:
        return len(self.samples)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, dict]:
        img_path, label = self.samples[idx]
        
        # Load image
        image = np.array(Image.open(img_path).convert("RGB"))
        
        # Apply transforms
        if self.transform:
            transformed = self.transform(image=image)
            image = transformed["image"]
        
        # Metadata
        metadata = {
            "image_path": str(img_path),
            "class_name": CLASS_NAMES[label]
        }
        
        return image.float(), label, metadata

# ============================================================================
# MODEL LOADING
# ============================================================================

def create_model(num_classes: int = NUM_CLASSES, pretrained: bool = False) -> nn.Module:
    """Create ResNet-50 model for ISIC classification."""
    model = timm.create_model(
        "resnet50",
        pretrained=pretrained,
        num_classes=num_classes
    )
    return model

def load_checkpoint(model: nn.Module, checkpoint_path: Path, device: torch.device) -> dict:
    """
    Load model checkpoint and return metadata.
    
    Args:
        model: PyTorch model
        checkpoint_path: Path to checkpoint file
        device: Target device
    
    Returns:
        Checkpoint metadata dictionary
    """
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    
    # Handle different checkpoint formats
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
        metadata = {
            "epoch": checkpoint.get("epoch", "unknown"),
            "val_acc": checkpoint.get("val_acc", checkpoint.get("best_val_acc", "unknown")),
            "seed": checkpoint.get("seed", "unknown")
        }
    elif "state_dict" in checkpoint:
        model.load_state_dict(checkpoint["state_dict"])
        metadata = {"epoch": "unknown", "val_acc": "unknown", "seed": "unknown"}
    else:
        # Direct state dict
        model.load_state_dict(checkpoint)
        metadata = {"epoch": "unknown", "val_acc": "unknown", "seed": "unknown"}
    
    model.to(device)
    model.eval()
    
    return metadata

def get_normalizer(device: torch.device):
    """Get ImageNet normalization function."""
    mean = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1).to(device)
    std = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1).to(device)
    
    def normalize(x: torch.Tensor) -> torch.Tensor:
        return (x - mean) / std
    
    return normalize

print("✅ Dataset and model utilities defined")
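
When `max_samples` is set, `ISICTestDataset` keeps the first `max_samples // NUM_CLASSES` files of each class in whatever order `glob` returns them. That gives an equal per-class cap, but the chosen subset depends on file-system ordering. The sketch below shows one way to make the pick reproducible with a seeded shuffle. It is an illustration under assumptions: `balanced_subsample` and the toy file names are hypothetical.

```python
import random
from collections import defaultdict

def balanced_subsample(samples, max_samples, num_classes, seed=0):
    """Pick up to max_samples // num_classes items per label, shuffled reproducibly."""
    by_class = defaultdict(list)
    for path, label in samples:
        by_class[label].append((path, label))

    per_class = max_samples // num_classes
    rng = random.Random(seed)
    subset = []
    for label in sorted(by_class):
        items = by_class[label][:]   # copy so the input list is not mutated
        rng.shuffle(items)
        subset.extend(items[:per_class])
    return subset

# Toy sample list: 10 items of class 0 and 3 items of class 1 (hypothetical file names).
toy = [(f"img_{i}.jpg", 0) for i in range(10)] + [(f"img_{i}.jpg", 1) for i in range(10, 13)]
subset = balanced_subsample(toy, max_samples=8, num_classes=2)
counts = {label: sum(1 for _, l in subset if l == label) for label in (0, 1)}
print(counts)
```

Rare classes with fewer than `per_class` images contribute everything they have, so the subset can end up smaller than `max_samples`.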

In [None]:
#@title ⚙️ Cell 6: Evaluation Configuration
#@markdown **Configure attack parameters and evaluation settings**

@dataclass
class EvaluationConfig:
    """Configuration for adversarial evaluation."""
    
    # Seeds to evaluate
    seeds: List[int] = field(default_factory=lambda: [42, 123, 456])
    
    # Epsilon values (perturbation budgets)
    epsilons: List[float] = field(default_factory=lambda: [2/255, 4/255, 8/255])
    
    # Attack configurations
    fgsm_enabled: bool = True
    pgd_enabled: bool = True
    pgd_steps: int = 40
    pgd_step_size: Optional[float] = None  # Auto: epsilon/4
    
    cw_enabled: bool = True
    cw_iterations: int = 100  # Reduced for speed (default 1000)
    cw_confidence: float = 0.0
    cw_learning_rate: float = 0.01
    
    # Evaluation settings
    batch_size: int = 64  # Increase for A100
    num_workers: int = 4
    max_test_samples: Optional[int] = None  # None = all samples
    
    # Output settings
    save_adversarial_examples: bool = True
    num_examples_to_save: int = 50
    
    def __post_init__(self):
        """Adjust settings based on hardware."""
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
            
            # Optimize batch size for GPU memory
            if gpu_mem >= 35:  # A100 40GB
                self.batch_size = 128
                self.num_workers = 8
                self.cw_iterations = 200
                print(f"⚡ A100 optimizations: batch={self.batch_size}, C&W iters={self.cw_iterations}")
            elif gpu_mem >= 14:  # T4/V100
                self.batch_size = 64
                self.num_workers = 4
                self.cw_iterations = 100
                print(f"⚡ T4/V100 settings: batch={self.batch_size}")
            else:  # Smaller GPU
                self.batch_size = 32
                self.num_workers = 2
                self.cw_iterations = 50
                print(f"⚠️ Limited GPU: batch={self.batch_size}")
    
    def get_epsilon_str(self, eps: float) -> str:
        """Convert epsilon to readable string."""
        # round() rather than int(): eps * 255 can land just below the integer
        return f"{round(eps * 255)}/255"
    
    def summary(self) -> str:
        """Get configuration summary."""
        lines = [
            "="*60,
            "📋 EVALUATION CONFIGURATION",
            "="*60,
            f"Seeds: {self.seeds}",
            f"Epsilons: {[self.get_epsilon_str(e) for e in self.epsilons]}",
            "",
            "Attacks:",
            f"  FGSM: {'✓' if self.fgsm_enabled else '✗'}",
            f"  PGD:  {'✓' if self.pgd_enabled else '✗'} (steps={self.pgd_steps})",
            f"  C&W:  {'✓' if self.cw_enabled else '✗'} (iters={self.cw_iterations})",
            "",
            f"Batch size: {self.batch_size}",
            f"Max samples: {self.max_test_samples or 'all'}",
            "="*60
        ]
        return "\n".join(lines)

# Initialize configuration
config = EvaluationConfig()
print(config.summary())

# Epsilon display helper
EPSILON_LABELS = {
    2/255: "ε=2/255 (weak)",
    4/255: "ε=4/255 (medium)",
    8/255: "ε=8/255 (strong)"
}

print("\n✅ Configuration ready")
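
Two small pieces of arithmetic sit behind this configuration: turning a floating-point budget such as `8/255` back into a readable label, and the automatic PGD step size that the `pgd_step_size` comment describes as `epsilon/4`. The sketch below shows both in isolation. `epsilon_label` and `pgd_step_size` are hypothetical helper names, and the divisor of 4 is taken from that comment rather than from the attack implementation itself.

```python
def epsilon_label(eps):
    """Format a pixel-space budget such as 8/255 as the string '8/255'."""
    # round() guards against eps * 255 landing a hair below the intended integer.
    return f"{round(eps * 255)}/255"

def pgd_step_size(eps, divisor=4.0):
    """Default PGD step size when none is configured (assumed: epsilon / divisor)."""
    return eps / divisor

for eps in (2 / 255, 4 / 255, 8 / 255):
    print(epsilon_label(eps), f"step={pgd_step_size(eps):.5f}")
```

With 40 steps of size ε/4, the cumulative step length is ten times the budget, so the iterate can cross the ε-ball several times before projection. This is one common reason for choosing a step of that order.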

In [None]:
#@title 🎯 Cell 7: Adversarial Attack Engine
#@markdown **Core attack generation and evaluation functions**

class AdversarialEvaluator:
    """
    Unified adversarial evaluation engine.
    
    Supports FGSM, PGD, and Carlini-Wagner attacks with batch processing
    and detailed metrics collection.
    """
    
    def __init__(
        self,
        model: nn.Module,
        device: torch.device,
        normalize_fn: callable
    ):
        self.model = model
        self.device = device
        self.normalize = normalize_fn
        self.model.eval()
        
        # Results storage
        self.results = {}
    
    def evaluate_clean(
        self,
        dataloader: DataLoader,
        desc: str = "Clean Evaluation"
    ) -> Dict[str, Any]:
        """Evaluate model on clean (unperturbed) data."""
        all_preds = []
        all_labels = []
        all_probs = []
        
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=desc, leave=False):
                # Handle 3-tuple (image, label, metadata)
                images, labels = batch[0], batch[1]
                images = images.to(self.device)
                labels = labels.to(self.device)
                
                # Forward pass with normalization
                logits = self.model(self.normalize(images))
                probs = F.softmax(logits, dim=1)
                preds = logits.argmax(dim=1)
                
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
        
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        all_probs = np.array(all_probs)
        
        # Calculate metrics
        accuracy = accuracy_score(all_labels, all_preds)
        balanced_acc = balanced_accuracy_score(all_labels, all_preds)
        f1_macro = f1_score(all_labels, all_preds, average='macro')
        f1_weighted = f1_score(all_labels, all_preds, average='weighted')
        
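        # Per-class accuracy; fixing `labels` keeps rows aligned with CLASS_NAMES
        cm = confusion_matrix(all_labels, all_preds, labels=list(range(NUM_CLASSES)))
        class_totals = cm.sum(axis=1)
        # Classes with no test samples get 0.0 instead of a division-by-zero NaN
        per_class_acc = cm.diagonal() / np.maximum(class_totals, 1)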
        
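        # AUROC (one-vs-rest); undefined when a class is missing from the labels,
        # so report NaN rather than a misleading 0.0
        try:
            auroc = roc_auc_score(all_labels, all_probs, multi_class='ovr')
        except ValueError:
            auroc = float("nan")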
        
        return {
            "accuracy": accuracy,
            "balanced_accuracy": balanced_acc,
            "f1_macro": f1_macro,
            "f1_weighted": f1_weighted,
            "auroc": auroc,
            "per_class_accuracy": dict(zip(CLASS_NAMES, per_class_acc)),
            "confusion_matrix": cm,
            "predictions": all_preds,
            "labels": all_labels,
            "probabilities": all_probs
        }
    
    def evaluate_attack(
        self,
        dataloader: DataLoader,
        attack_name: str,
        epsilon: float,
        attack_fn: callable,
        desc: str = None
    ) -> Dict[str, Any]:
        """
        Evaluate model under adversarial attack.
        
        Args:
            dataloader: Test data loader
            attack_name: Name of attack (FGSM, PGD, CW)
            epsilon: Perturbation budget
            attack_fn: Function that generates adversarial examples
            desc: Progress bar description
        
        Returns:
            Dictionary with attack results and metrics
        """
        if desc is None:
            desc = f"{attack_name} ε={config.get_epsilon_str(epsilon)}"
        
        all_preds_clean = []
        all_preds_adv = []
        all_labels = []
        all_l2_dists = []
        all_linf_dists = []
        successful_attacks = 0
        total_samples = 0
        
        # Store some examples for visualization
        saved_examples = []
        
        for batch in tqdm(dataloader, desc=desc, leave=False):
            images, labels = batch[0], batch[1]
            images = images.to(self.device)
            labels = labels.to(self.device)
            
            # Clean predictions
            with torch.no_grad():
                clean_logits = self.model(self.normalize(images))
                clean_preds = clean_logits.argmax(dim=1)
            
            # Generate adversarial examples
            try:
                x_adv = attack_fn(images, labels)
            except Exception as e:
                print(f"⚠️ Attack failed on batch: {e}")
                continue
            
            # Adversarial predictions
            with torch.no_grad():
                adv_logits = self.model(self.normalize(x_adv))
                adv_preds = adv_logits.argmax(dim=1)
            
            # Calculate perturbation norms
            delta = (x_adv - images).view(images.size(0), -1)
            l2_dist = torch.norm(delta, p=2, dim=1)
            linf_dist = torch.norm(delta, p=float('inf'), dim=1)
            
            # Track successful attacks (correctly classified → misclassified)
            was_correct = (clean_preds == labels)
            is_wrong = (adv_preds != labels)
            successful = was_correct & is_wrong
            
            successful_attacks += successful.sum().item()
            total_samples += was_correct.sum().item()  # Only count correctly classified
            
            # Store results
            all_preds_clean.extend(clean_preds.cpu().numpy())
            all_preds_adv.extend(adv_preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_l2_dists.extend(l2_dist.cpu().numpy())
            all_linf_dists.extend(linf_dist.cpu().numpy())
            
            # Save examples for visualization
            if len(saved_examples) < config.num_examples_to_save:
                for i in range(min(5, images.size(0))):
                    if len(saved_examples) >= config.num_examples_to_save:
                        break
                    saved_examples.append({
                        "clean": images[i].cpu(),
                        "adversarial": x_adv[i].cpu(),
                        "perturbation": (x_adv[i] - images[i]).cpu(),
                        "true_label": labels[i].item(),
                        "clean_pred": clean_preds[i].item(),
                        "adv_pred": adv_preds[i].item()
                    })
        
        # Convert to arrays
        all_preds_clean = np.array(all_preds_clean)
        all_preds_adv = np.array(all_preds_adv)
        all_labels = np.array(all_labels)
        
        # Calculate metrics
        clean_acc = accuracy_score(all_labels, all_preds_clean)
        robust_acc = accuracy_score(all_labels, all_preds_adv)
        attack_success_rate = successful_attacks / max(total_samples, 1)
        
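        # Per-class robust accuracy; fixing the label set keeps rows aligned with
        # CLASS_NAMES, and the clamp avoids division by zero for absent classes.
        cm_adv = confusion_matrix(all_labels, all_preds_adv, labels=np.arange(len(CLASS_NAMES)))
        per_class_robust_acc = cm_adv.diagonal() / np.maximum(cm_adv.sum(axis=1), 1)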
        
        return {
            "attack_name": attack_name,
            "epsilon": epsilon,
            "epsilon_str": config.get_epsilon_str(epsilon),
            "clean_accuracy": clean_acc,
            "robust_accuracy": robust_acc,
            "accuracy_drop": clean_acc - robust_acc,
            "attack_success_rate": attack_success_rate,
            "mean_l2_dist": np.mean(all_l2_dists),
            "mean_linf_dist": np.mean(all_linf_dists),
            "per_class_robust_accuracy": dict(zip(CLASS_NAMES, per_class_robust_acc)),
            "confusion_matrix": cm_adv,
            "saved_examples": saved_examples,
            "predictions_clean": all_preds_clean,
            "predictions_adv": all_preds_adv,
            "labels": all_labels
        }
    
    def create_fgsm_attack(self, epsilon: float) -> callable:
        """Create FGSM attack function."""
        fgsm_config = FGSMConfig(
            epsilon=epsilon,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False,
            device=str(self.device)
        )
        attack = FGSM(fgsm_config)
        
        def attack_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return attack.generate(
                self.model, x, y,
                loss_fn=nn.CrossEntropyLoss(),
                normalize=self.normalize
            )
        return attack_fn
    
    def create_pgd_attack(self, epsilon: float, num_steps: int = 40) -> callable:
        """Create PGD attack function."""
        step_size = epsilon / 4  # Heuristic: a quarter of the budget per step
        
        pgd_config = PGDConfig(
            epsilon=epsilon,
            num_steps=num_steps,
            step_size=step_size,
            random_start=True,
            early_stop=False,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False,
            device=str(self.device)
        )
        attack = PGD(pgd_config)
        
        def attack_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return attack.generate(
                self.model, x, y,
                loss_fn=nn.CrossEntropyLoss(),
                normalize=self.normalize
            )
        return attack_fn
    
    def create_cw_attack(
        self,
        confidence: float = 0.0,
        max_iterations: int = 100,
        learning_rate: float = 0.01
    ) -> callable:
        """Create Carlini-Wagner L2 attack function."""
        cw_config = CWConfig(
            confidence=confidence,
            learning_rate=learning_rate,
            max_iterations=max_iterations,
            binary_search_steps=5,  # Reduced for speed
            initial_c=1e-3,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False,
            device=str(self.device)
        )
        attack = CarliniWagner(cw_config)
        
        def attack_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return attack.generate(
                self.model, x, y,
                normalize=self.normalize
            )
        return attack_fn

print("✅ AdversarialEvaluator class defined")
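
`evaluate_attack` reports two related numbers that are easy to conflate. Robust accuracy is measured over every test sample, whereas the attack success rate is measured only over samples the model classified correctly before the attack. The sketch below recomputes both on a toy set of predictions; the helper name `summarize_attack` and the sample values are illustrative assumptions, not part of the notebook.

```python
from typing import Dict, Sequence


def summarize_attack(
    labels: Sequence[int],
    clean_preds: Sequence[int],
    adv_preds: Sequence[int],
) -> Dict[str, float]:
    """Recompute the headline metrics of evaluate_attack on plain lists (illustrative helper)."""
    n = len(labels)
    clean_correct = [c == y for c, y in zip(clean_preds, labels)]
    adv_correct = [a == y for a, y in zip(adv_preds, labels)]

    # Robust accuracy: fraction of ALL samples still correct after the attack.
    clean_acc = sum(clean_correct) / n
    robust_acc = sum(adv_correct) / n

    # Attack success rate: among samples that were correct before the attack,
    # the fraction that the attack flipped to a wrong prediction.
    flipped = sum(c and not a for c, a in zip(clean_correct, adv_correct))
    success_rate = flipped / max(sum(clean_correct), 1)

    return {
        "clean_accuracy": clean_acc,
        "robust_accuracy": robust_acc,
        "accuracy_drop": clean_acc - robust_acc,
        "attack_success_rate": success_rate,
    }


# Toy data: 5 samples, 4 correct on clean inputs, 2 of those flipped by the attack.
toy = summarize_attack(
    labels=[0, 1, 2, 1, 0],
    clean_preds=[0, 1, 2, 1, 2],
    adv_preds=[0, 2, 2, 0, 2],
)
print(toy)
```

In this toy case the accuracy drop is 40 points while the success rate is 50%, because the success-rate denominator excludes the one sample that was already misclassified.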

In [None]:
#@title 📦 Cell 8: Load Test Dataset and Models
#@markdown **Load test data and all seed checkpoints**

print("="*60)
print("📊 LOADING TEST DATASET")
print("="*60)

# Create test dataset
test_dataset = ISICTestDataset(
    root_dir=paths.test_dir,
    max_samples=config.max_test_samples
)

# Create dataloader
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True
)

print(f"\n✅ Test samples: {len(test_dataset)}")
print(f"✅ Batches: {len(test_loader)}")
print(f"✅ Batch size: {config.batch_size}")

# ============================================================================
# LOAD MODELS FOR ALL SEEDS
# ============================================================================

print("\n" + "="*60)
print("🔧 LOADING MODEL CHECKPOINTS")
print("="*60)

models = {}
normalize = get_normalizer(device)

for seed in config.seeds:
    checkpoint_path = paths.checkpoint_dir / f"baseline_seed_{seed}.pt"
    
    if not checkpoint_path.exists():
        # Try alternative naming
        alt_path = paths.checkpoint_dir / f"seed_{seed}_best.pt"
        if alt_path.exists():
            checkpoint_path = alt_path
        else:
            print(f"⚠️ Checkpoint not found for seed {seed}")
            continue
    
    print(f"\n📥 Loading seed {seed}...")
    model = create_model(num_classes=NUM_CLASSES, pretrained=False)
    metadata = load_checkpoint(model, checkpoint_path, device)
    
    models[seed] = model
    
    print(f"   ✅ Loaded: {checkpoint_path.name}")
    print(f"   📈 Validation accuracy: {metadata.get('val_acc', 'N/A')}")

print(f"\n✅ Loaded {len(models)} models")

# ============================================================================
# SANITY CHECK: VERIFY CLEAN ACCURACY
# ============================================================================

print("\n" + "="*60)
print("🔍 SANITY CHECK: CLEAN ACCURACY")
print("="*60)

clean_results = {}

for seed, model in models.items():
    print(f"\n🧪 Evaluating seed {seed}...")
    evaluator = AdversarialEvaluator(model, device, normalize)
    result = evaluator.evaluate_clean(test_loader, desc=f"Clean eval (seed {seed})")
    clean_results[seed] = result
    
    print(f"   ✅ Accuracy: {result['accuracy']*100:.2f}%")
    print(f"   ✅ Balanced Accuracy: {result['balanced_accuracy']*100:.2f}%")
    print(f"   ✅ F1 (macro): {result['f1_macro']*100:.2f}%")
    print(f"   ✅ AUROC: {result['auroc']*100:.2f}%")

# Summary statistics
mean_acc = np.mean([r['accuracy'] for r in clean_results.values()])
std_acc = np.std([r['accuracy'] for r in clean_results.values()])

print(f"\n📊 Mean Clean Accuracy: {mean_acc*100:.2f}% ± {std_acc*100:.2f}%")
print("="*60)
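
With only a handful of seeds, the choice of standard-deviation estimator changes the reported spread noticeably. `np.std` defaults to the population form (`ddof=0`), whereas pandas' `Series.std` defaults to the sample form (`ddof=1`). A minimal sketch using only the standard library, with made-up accuracies standing in for real results:

```python
import statistics

# Hypothetical clean accuracies for three seeds (placeholders, not measured results).
seed_accuracies = [0.80, 0.82, 0.84]

mean_acc = statistics.mean(seed_accuracies)
pop_std = statistics.pstdev(seed_accuracies)    # divides by n, like np.std(..., ddof=0)
sample_std = statistics.stdev(seed_accuracies)  # divides by n - 1, like np.std(..., ddof=1)

print(f"mean           = {mean_acc * 100:.2f}%")
print(f"population std = {pop_std * 100:.2f} pp")
print(f"sample std     = {sample_std * 100:.2f} pp")
```

For n = 3 the sample estimate is larger by a factor of sqrt(3/2), roughly 1.22, so it is worth stating which convention a reported ± value uses.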

In [None]:
#@title 🚀 Cell 9: Run Full Adversarial Evaluation
#@markdown **Execute FGSM, PGD, and C&W attacks across all seeds and epsilons**

print("="*70)
print("🚀 STARTING FULL ADVERSARIAL EVALUATION")
print("="*70)
print(f"⏱️  Start time: {datetime.now().strftime('%H:%M:%S')}")
print(f"🎯 Seeds: {config.seeds}")
print(f"🎯 Epsilons: {[config.get_epsilon_str(e) for e in config.epsilons]}")
print(f"🎯 Attacks: FGSM={'✓' if config.fgsm_enabled else '✗'}, "
      f"PGD={'✓' if config.pgd_enabled else '✗'}, "
      f"C&W={'✓' if config.cw_enabled else '✗'}")
print("="*70)

# Results storage
all_results = {
    "clean": clean_results,
    "fgsm": {},
    "pgd": {},
    "cw": {}
}

start_time = time.time()

for seed in config.seeds:
    if seed not in models:
        continue
    
    model = models[seed]
    evaluator = AdversarialEvaluator(model, device, normalize)
    
    print(f"\n{'='*60}")
    print(f"🔬 EVALUATING SEED {seed}")
    print(f"{'='*60}")
    
    # Initialize seed results
    all_results["fgsm"][seed] = {}
    all_results["pgd"][seed] = {}
    all_results["cw"][seed] = {}
    
    # ========================================================================
    # FGSM ATTACKS
    # ========================================================================
    if config.fgsm_enabled:
        print(f"\n⚡ FGSM Attacks (Seed {seed})")
        print("-" * 40)
        
        for eps in config.epsilons:
            attack_fn = evaluator.create_fgsm_attack(eps)
            result = evaluator.evaluate_attack(
                test_loader,
                attack_name="FGSM",
                epsilon=eps,
                attack_fn=attack_fn
            )
            all_results["fgsm"][seed][eps] = result
            
            print(f"   ε={config.get_epsilon_str(eps):>7}: "
                  f"Robust Acc = {result['robust_accuracy']*100:5.2f}% "
                  f"(↓{result['accuracy_drop']*100:5.2f}pp)")
    
    # ========================================================================
    # PGD ATTACKS
    # ========================================================================
    if config.pgd_enabled:
        print(f"\n🔄 PGD-{config.pgd_steps} Attacks (Seed {seed})")
        print("-" * 40)
        
        for eps in config.epsilons:
            attack_fn = evaluator.create_pgd_attack(eps, num_steps=config.pgd_steps)
            result = evaluator.evaluate_attack(
                test_loader,
                attack_name=f"PGD-{config.pgd_steps}",
                epsilon=eps,
                attack_fn=attack_fn
            )
            all_results["pgd"][seed][eps] = result
            
            print(f"   ε={config.get_epsilon_str(eps):>7}: "
                  f"Robust Acc = {result['robust_accuracy']*100:5.2f}% "
                  f"(↓{result['accuracy_drop']*100:5.2f}pp)")
    
    # ========================================================================
    # C&W ATTACK (epsilon-free, so it runs once per seed)
    # ========================================================================
    if config.cw_enabled:
        print(f"\n🎯 C&W L2 Attack (Seed {seed})")
        print("-" * 40)
        
        # C&W minimizes the L2 norm directly, so there is no epsilon sweep
        attack_fn = evaluator.create_cw_attack(
            confidence=config.cw_confidence,
            max_iterations=config.cw_iterations,
            learning_rate=config.cw_learning_rate
        )
        result = evaluator.evaluate_attack(
            test_loader,
            attack_name="C&W-L2",
            epsilon=0.0,  # C&W minimizes L2 directly
            attack_fn=attack_fn,
            desc="C&W L2 Attack"
        )
        all_results["cw"][seed]["l2"] = result
        
        print(f"   Robust Acc = {result['robust_accuracy']*100:5.2f}% "
              f"(↓{result['accuracy_drop']*100:5.2f}pp)")
        print(f"   Mean L2 perturbation: {result['mean_l2_dist']:.4f}")
    
    # Memory cleanup
    gc.collect()
    torch.cuda.empty_cache()

# Timing
elapsed = time.time() - start_time
print(f"\n{'='*70}")
print("✅ EVALUATION COMPLETE")
print(f"⏱️  Total time: {elapsed/60:.1f} minutes")
print(f"{'='*70}")
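
The per-sample L2 and L∞ distances collected in `evaluate_attack` allow a cheap sanity check on the run above. For FGSM and PGD, no perturbation should exceed the L∞ budget ε, while C&W has no such bound and is summarized by its L2 norm instead. A minimal sketch on flat Python lists; the helper names and the tolerance are assumptions for illustration:

```python
import math
from typing import Sequence, Tuple


def perturbation_norms(clean: Sequence[float], adv: Sequence[float]) -> Tuple[float, float]:
    """Return (L2, L-infinity) norms of adv - clean for one flattened image."""
    delta = [a - c for a, c in zip(adv, clean)]
    l2 = math.sqrt(sum(d * d for d in delta))
    linf = max(abs(d) for d in delta)
    return l2, linf


def within_budget(linf: float, epsilon: float, tol: float = 1e-6) -> bool:
    """True if an L-infinity perturbation respects the budget, up to float tolerance."""
    return linf <= epsilon + tol


epsilon = 8 / 255
clean_pixels = [0.20, 0.50, 0.90, 1.00]
# FGSM-style step: move each pixel by +/- epsilon, then clip to the valid range [0, 1].
signs = [+1, -1, +1, +1]
adv_pixels = [min(max(p + s * epsilon, 0.0), 1.0) for p, s in zip(clean_pixels, signs)]

l2, linf = perturbation_norms(clean_pixels, adv_pixels)
print(f"L2 = {l2:.4f}, Linf = {linf:.4f}, within budget: {within_budget(linf, epsilon)}")
```

The last pixel is already at the upper clip bound, so its perturbation is zero; clipping can only shrink a perturbation, never push it past ε.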

In [None]:
#@title 📊 Cell 10: Results Summary Table
#@markdown **Generate comprehensive results summary with statistics**

print("="*70)
print("📊 ADVERSARIAL ROBUSTNESS RESULTS SUMMARY")
print("="*70)

# ============================================================================
# BUILD RESULTS DATAFRAME
# ============================================================================

results_data = []

for seed in config.seeds:
    if seed not in clean_results:
        continue
    
    # Clean accuracy
    clean_acc = clean_results[seed]["accuracy"]
    
    # FGSM results
    if config.fgsm_enabled and seed in all_results["fgsm"]:
        for eps in config.epsilons:
            if eps in all_results["fgsm"][seed]:
                r = all_results["fgsm"][seed][eps]
                results_data.append({
                    "Seed": seed,
                    "Attack": "FGSM",
                    "Epsilon": config.get_epsilon_str(eps),
                    "Epsilon_Val": eps,
                    "Clean_Acc": clean_acc * 100,
                    "Robust_Acc": r["robust_accuracy"] * 100,
                    "Acc_Drop": r["accuracy_drop"] * 100,
                    "Attack_Success": r["attack_success_rate"] * 100,
                    "Mean_Linf": r["mean_linf_dist"]
                })
    
    # PGD results
    if config.pgd_enabled and seed in all_results["pgd"]:
        for eps in config.epsilons:
            if eps in all_results["pgd"][seed]:
                r = all_results["pgd"][seed][eps]
                results_data.append({
                    "Seed": seed,
                    "Attack": f"PGD-{config.pgd_steps}",
                    "Epsilon": config.get_epsilon_str(eps),
                    "Epsilon_Val": eps,
                    "Clean_Acc": clean_acc * 100,
                    "Robust_Acc": r["robust_accuracy"] * 100,
                    "Acc_Drop": r["accuracy_drop"] * 100,
                    "Attack_Success": r["attack_success_rate"] * 100,
                    "Mean_Linf": r["mean_linf_dist"]
                })
    
    # C&W results
    if config.cw_enabled and seed in all_results["cw"]:
        if "l2" in all_results["cw"][seed]:
            r = all_results["cw"][seed]["l2"]
            results_data.append({
                "Seed": seed,
                "Attack": "C&W-L2",
                "Epsilon": "N/A",
                "Epsilon_Val": 0,
                "Clean_Acc": clean_acc * 100,
                "Robust_Acc": r["robust_accuracy"] * 100,
                "Acc_Drop": r["accuracy_drop"] * 100,
                "Attack_Success": r["attack_success_rate"] * 100,
                "Mean_L2": r["mean_l2_dist"]
            })

df_results = pd.DataFrame(results_data)

# ============================================================================
# AGGREGATE STATISTICS (Mean ± Std across seeds)
# ============================================================================

print("\n📈 AGGREGATED RESULTS (Mean ± Std across seeds)")
print("-" * 70)

summary_data = []

# Group by Attack and Epsilon
for attack in df_results["Attack"].unique():
    for eps in df_results[df_results["Attack"] == attack]["Epsilon"].unique():
        subset = df_results[(df_results["Attack"] == attack) & (df_results["Epsilon"] == eps)]
        
        if len(subset) > 0:
            summary_data.append({
                "Attack": attack,
                "Epsilon": eps,
                "Clean_Acc": f"{subset['Clean_Acc'].mean():.2f} ± {subset['Clean_Acc'].std():.2f}",
                "Robust_Acc": f"{subset['Robust_Acc'].mean():.2f} ± {subset['Robust_Acc'].std():.2f}",
                "Acc_Drop": f"{subset['Acc_Drop'].mean():.2f} ± {subset['Acc_Drop'].std():.2f}",
                "Attack_Success": f"{subset['Attack_Success'].mean():.2f} ± {subset['Attack_Success'].std():.2f}"
            })

df_summary = pd.DataFrame(summary_data)
print(df_summary.to_string(index=False))

# ============================================================================
# DETAILED PER-SEED TABLE
# ============================================================================

print("\n\n📋 DETAILED RESULTS (Per Seed)")
print("-" * 70)

display_cols = ["Seed", "Attack", "Epsilon", "Clean_Acc", "Robust_Acc", "Acc_Drop", "Attack_Success"]
print(df_results[display_cols].to_string(index=False))

# ============================================================================
# KEY FINDINGS
# ============================================================================

print("\n\n🔑 KEY FINDINGS")
print("=" * 70)

# Mean robust accuracy under PGD at ε=8/255
if config.pgd_enabled:
    pgd_8 = df_results[(df_results["Attack"] == f"PGD-{config.pgd_steps}") &
                       (df_results["Epsilon"] == "8/255")]
    if len(pgd_8) > 0:
        mean_robust = pgd_8["Robust_Acc"].mean()
        mean_drop = pgd_8["Acc_Drop"].mean()
        print(f"• PGD-{config.pgd_steps} (ε=8/255): {mean_robust:.2f}% robust accuracy "
              f"(↓{mean_drop:.2f}pp from clean)")

# FGSM vs PGD comparison
if config.fgsm_enabled and config.pgd_enabled:
    fgsm_8 = df_results[(df_results["Attack"] == "FGSM") & (df_results["Epsilon"] == "8/255")]
    pgd_8 = df_results[(df_results["Attack"] == f"PGD-{config.pgd_steps}") & 
                       (df_results["Epsilon"] == "8/255")]
    if len(fgsm_8) > 0 and len(pgd_8) > 0:
        fgsm_robust = fgsm_8["Robust_Acc"].mean()
        pgd_robust = pgd_8["Robust_Acc"].mean()
        print(f"• FGSM vs PGD gap at ε=8/255: {abs(fgsm_robust - pgd_robust):.2f}pp")

# C&W results
if config.cw_enabled:
    cw_results = df_results[df_results["Attack"] == "C&W-L2"]
    if len(cw_results) > 0:
        mean_cw_robust = cw_results["Robust_Acc"].mean()
        print(f"• C&W L2 attack: {mean_cw_robust:.2f}% robust accuracy")

print("=" * 70)

In [None]:
#@title 📈 Cell 11: PhD-Level Visualization - Robustness Degradation Curves
#@markdown **Publication-quality robustness vs perturbation strength plots**

def create_robustness_curves(df_results: pd.DataFrame, config: EvaluationConfig) -> go.Figure:
    """
    Create interactive robustness degradation curves.
    
    Shows how accuracy degrades with increasing perturbation budget ε.
    """
    # Prepare data for plotting
    attacks = ["FGSM", f"PGD-{config.pgd_steps}"]
    colors = {"FGSM": "#FF6B6B", f"PGD-{config.pgd_steps}": "#4ECDC4"}
    markers = {"FGSM": "circle", f"PGD-{config.pgd_steps}": "square"}
    
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=(
            "<b>Robustness Degradation by Attack Type</b>",
            "<b>Accuracy Drop Severity</b>"
        ),
        horizontal_spacing=0.12
    )
    
    # Add clean accuracy reference line
    clean_acc = df_results["Clean_Acc"].mean()
    
    for attack in attacks:
        attack_data = df_results[df_results["Attack"] == attack].copy()
        if len(attack_data) == 0:
            continue
        
        # Group by epsilon
        grouped = attack_data.groupby("Epsilon_Val").agg({
            "Robust_Acc": ["mean", "std"],
            "Acc_Drop": ["mean", "std"]
        }).reset_index()
        grouped.columns = ["Epsilon", "Robust_Mean", "Robust_Std", "Drop_Mean", "Drop_Std"]
        grouped = grouped.sort_values("Epsilon")
        
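        # Convert epsilon to string labels for x-axis; round before casting because
        # e*255 is not guaranteed to be an exact integer in floating point
        epsilon_labels = [f"{int(round(e * 255))}/255" for e in grouped["Epsilon"]]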
        
        # Plot 1: Robustness curves with confidence bands
        fig.add_trace(
            go.Scatter(
                x=epsilon_labels,
                y=grouped["Robust_Mean"],
                mode="lines+markers",
                name=attack,
                line=dict(color=colors[attack], width=3),
                marker=dict(size=12, symbol=markers[attack]),
                error_y=dict(
                    type="data",
                    array=grouped["Robust_Std"],
                    visible=True,
                    color=colors[attack],
                    thickness=2
                ),
                legendgroup=attack,
                showlegend=True
            ),
            row=1, col=1
        )
        
        # Plot 2: Accuracy drop bars
        fig.add_trace(
            go.Bar(
                x=epsilon_labels,
                y=grouped["Drop_Mean"],
                name=attack,
                marker_color=colors[attack],
                error_y=dict(
                    type="data",
                    array=grouped["Drop_Std"],
                    visible=True
                ),
                legendgroup=attack,
                showlegend=False
            ),
            row=1, col=2
        )
    
    # Add clean accuracy reference
    fig.add_hline(
        y=clean_acc, 
        line_dash="dash", 
        line_color="gray",
        annotation_text=f"Clean: {clean_acc:.1f}%",
        row=1, col=1
    )
    
    # Update layout
    fig.update_layout(
        height=500,
        width=1100,
        title=dict(
            text="<b>Adversarial Robustness Analysis: Baseline ResNet-50 on ISIC 2018</b>",
            font=dict(size=18),
            x=0.5
        ),
        font=dict(family="Arial", size=12),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.08,
            xanchor="center",
            x=0.5
        ),
        paper_bgcolor="white",
        plot_bgcolor="white"
    )
    
    # Axis labels
    fig.update_xaxes(title_text="Perturbation Budget (ε)", row=1, col=1, gridcolor="lightgray")
    fig.update_xaxes(title_text="Perturbation Budget (ε)", row=1, col=2, gridcolor="lightgray")
    fig.update_yaxes(title_text="Robust Accuracy (%)", row=1, col=1, gridcolor="lightgray", range=[0, 100])
    fig.update_yaxes(title_text="Accuracy Drop (pp)", row=1, col=2, gridcolor="lightgray")
    
    return fig

# Create and display
fig_robustness = create_robustness_curves(df_results, config)
fig_robustness.show()

# Save figure
if paths.figures_dir.exists():
    fig_robustness.write_html(paths.figures_dir / "robustness_curves.html")
    fig_robustness.write_image(paths.figures_dir / "robustness_curves.png", scale=2)
    print(f"✅ Saved to {paths.figures_dir / 'robustness_curves.png'}")
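
Axis labels such as `8/255` are rebuilt from float values of ε. Since `k / 255` is generally not exactly representable, multiplying back by 255 is not guaranteed to return exactly `k`; if the product lands marginally below `k`, truncating with `int()` would yield `k - 1`. Rounding first avoids that. A minimal sketch, where `epsilon_label` is a hypothetical helper and not the notebook's `config.get_epsilon_str`:

```python
def epsilon_label(epsilon: float, denominator: int = 255) -> str:
    """Format a perturbation budget as 'k/255', rounding to the nearest integer k."""
    return f"{int(round(epsilon * denominator))}/{denominator}"


# The budgets listed in this notebook's evaluation protocol.
labels = [epsilon_label(k / 255) for k in (2, 4, 8)]
print(labels)
```

Keeping the formatting in one helper also guarantees that plot labels and the string keys used to filter the results table stay consistent.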

In [None]:
#@title 🔥 Cell 12: PhD-Level Visualization - Per-Class Vulnerability Heatmap
#@markdown **Detailed heatmap showing which skin lesion classes are most vulnerable**

def create_vulnerability_heatmap(all_results: dict, clean_results: dict, config: EvaluationConfig) -> go.Figure:
    """
    Create per-class vulnerability heatmap across attacks and epsilons.
    """
    # Build vulnerability matrix
    attack_configs = []
    
    # Add FGSM configs
    if config.fgsm_enabled:
        for eps in config.epsilons:
            attack_configs.append(("FGSM", eps, f"FGSM ε={config.get_epsilon_str(eps)}"))
    
    # Add PGD configs
    if config.pgd_enabled:
        for eps in config.epsilons:
            attack_configs.append((f"PGD-{config.pgd_steps}", eps, f"PGD-{config.pgd_steps} ε={config.get_epsilon_str(eps)}"))
    
    # Add C&W
    if config.cw_enabled:
        attack_configs.append(("C&W-L2", "l2", "C&W L2"))
    
    # Calculate mean per-class robust accuracy across seeds
    vulnerability_matrix = []
    
    for class_name in CLASS_NAMES:
        row = []
        for attack_type, eps_key, label in attack_configs:
            # Get attack dict key
            attack_key = "fgsm" if attack_type == "FGSM" else ("pgd" if "PGD" in attack_type else "cw")
            
            accuracies = []
            for seed in config.seeds:
                if seed not in all_results[attack_key]:
                    continue
                if eps_key not in all_results[attack_key][seed]:
                    continue
                
                result = all_results[attack_key][seed][eps_key]
                if class_name in result.get("per_class_robust_accuracy", {}):
                    accuracies.append(result["per_class_robust_accuracy"][class_name] * 100)
            
            if accuracies:
                row.append(np.mean(accuracies))
            else:
                row.append(np.nan)
        
        vulnerability_matrix.append(row)
    
    vulnerability_matrix = np.array(vulnerability_matrix)
    
    # Create heatmap
    fig = go.Figure(data=go.Heatmap(
        z=vulnerability_matrix,
        x=[c[2] for c in attack_configs],
        y=[f"{name}<br>({CLASS_DESCRIPTIONS[name].split(' ')[0]})" for name in CLASS_NAMES],  # Plotly breaks lines on <br>, not \n
        colorscale=[
            [0, "#d73027"],      # Red - low accuracy (vulnerable)
            [0.25, "#fc8d59"],   # Orange
            [0.5, "#fee08b"],    # Yellow
            [0.75, "#91cf60"],   # Light green
            [1, "#1a9850"]       # Dark green - high accuracy (robust)
        ],
        colorbar=dict(
            # `titleside` is deprecated in recent Plotly releases; set the side on the title itself
            title=dict(text="Robust<br>Accuracy (%)", side="right"),
            ticksuffix="%"
        ),
        text=np.round(vulnerability_matrix, 1),
        texttemplate="%{text:.1f}%",
        textfont=dict(size=11, color="black"),
        hoverongaps=False,
        hovertemplate="<b>%{y}</b><br>Attack: %{x}<br>Robust Acc: %{z:.2f}%<extra></extra>"
    ))
    
    # Layout: title, axis labels, and figure size
    fig.update_layout(
        title=dict(
            text="<b>Per-Class Adversarial Vulnerability Analysis</b><br>"
                 "<sup>Red = Vulnerable (low accuracy) | Green = Robust (high accuracy)</sup>",
            font=dict(size=16),
            x=0.5
        ),
        xaxis=dict(
            title="Attack Configuration",
            tickangle=45,
            side="bottom"
        ),
        yaxis=dict(
            title="Skin Lesion Class",
            autorange="reversed"
        ),
        height=550,
        width=1000,
        font=dict(family="Arial", size=12),
        paper_bgcolor="white",
        plot_bgcolor="white"
    )
    
    return fig

# Create and display
fig_heatmap = create_vulnerability_heatmap(all_results, clean_results, config)
fig_heatmap.show()

# Save
if paths.figures_dir.exists():
    fig_heatmap.write_html(paths.figures_dir / "vulnerability_heatmap.html")
    fig_heatmap.write_image(paths.figures_dir / "vulnerability_heatmap.png", scale=2)
    print(f"✅ Saved to {paths.figures_dir / 'vulnerability_heatmap.png'}")

# ============================================================================
# IDENTIFY MOST VULNERABLE CLASSES
# ============================================================================

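print(f"\n🎯 CLASS VULNERABILITY RANKING (Under PGD-{config.pgd_steps} ε=8/255)")
print("-" * 60)

# Get PGD per-class results at ε=8/255; match the stored float key by tolerance
eps_8 = next((e for e in config.epsilons if np.isclose(e, 8 / 255)), 8 / 255)
class_vuln = {}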

for seed in config.seeds:
    if seed not in all_results["pgd"]:
        continue
    if eps_8 not in all_results["pgd"][seed]:
        continue
    
    result = all_results["pgd"][seed][eps_8]
    for class_name, acc in result.get("per_class_robust_accuracy", {}).items():
        if class_name not in class_vuln:
            class_vuln[class_name] = []
        class_vuln[class_name].append(acc * 100)

# Sort by vulnerability (lowest accuracy = most vulnerable)
sorted_vuln = sorted(
    [(name, np.mean(accs), np.std(accs)) for name, accs in class_vuln.items()],
    key=lambda x: x[1]
)

for rank, (name, mean_acc, std_acc) in enumerate(sorted_vuln, 1):
    status = "🔴 CRITICAL" if mean_acc < 20 else "🟡 MODERATE" if mean_acc < 50 else "🟢 ROBUST"
    print(f"{rank}. {name:6} ({CLASS_DESCRIPTIONS[name]:30}): {mean_acc:5.1f}% ± {std_acc:4.1f}% {status}")
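
The ranking above rests on per-class robust accuracy, i.e. the diagonal of the confusion matrix divided by each row's sample count. A small test subset can leave a class with no samples, which makes that ratio undefined. The sketch below computes the ratio on nested lists and marks empty classes as NaN so they drop out of the ranking; the class abbreviations and all counts are invented for illustration.

```python
import math
from typing import Dict, List, Sequence


def per_class_accuracy(
    confusion: Sequence[Sequence[int]], class_names: Sequence[str]
) -> Dict[str, float]:
    """Diagonal over row sum for each class; NaN when a class has no samples."""
    result = {}
    for i, name in enumerate(class_names):
        row_total = sum(confusion[i])
        result[name] = confusion[i][i] / row_total if row_total else math.nan
    return result


def rank_by_vulnerability(acc: Dict[str, float]) -> List[str]:
    """Most vulnerable first (lowest accuracy); classes without samples are dropped."""
    valid = {k: v for k, v in acc.items() if not math.isnan(v)}
    return sorted(valid, key=valid.get)


# Rows = true class, columns = predicted class (invented counts).
names = ["MEL", "NV", "BCC", "DF"]
cm_adv = [
    [2, 8, 0, 0],    # MEL: 2 of 10 still correct under attack
    [5, 45, 0, 0],   # NV: 45 of 50 still correct
    [3, 3, 4, 0],    # BCC: 4 of 10 still correct
    [0, 0, 0, 0],    # DF: absent from this subset
]

acc = per_class_accuracy(cm_adv, names)
print(acc)
print(rank_by_vulnerability(acc))
```

Reporting NaN for an empty class is a design choice: it keeps "no data" distinct from "0% robust accuracy", which would otherwise place the class at the top of the vulnerability ranking.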

In [None]:
#@title 🎯 Cell 13: PhD-Level Visualization - Clean vs Adversarial Confusion Matrices
#@markdown **Side-by-side confusion matrices showing attack impact on predictions**

def create_confusion_matrix_comparison(
    clean_results: dict, 
    all_results: dict, 
    seed: int,
    attack_key: str = "pgd",
    epsilon: float = 8/255
) -> plt.Figure:
    """
    Create side-by-side confusion matrices for clean vs adversarial.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Get confusion matrices
    cm_clean = clean_results[seed]["confusion_matrix"]
    
    # Get adversarial CM
    if attack_key == "cw":
        cm_adv = all_results[attack_key][seed]["l2"]["confusion_matrix"]
        attack_label = "C&W L2"
    else:
        cm_adv = all_results[attack_key][seed][epsilon]["confusion_matrix"]
        eps_str = f"{int(round(epsilon * 255))}/255"  # round: epsilon*255 may not be an exact integer
        attack_label = f"PGD-{config.pgd_steps} (ε={eps_str})" if attack_key == "pgd" else f"FGSM (ε={eps_str})"
    
    # Normalize confusion matrices by row (clamp guards against classes with no samples)
    cm_clean_norm = cm_clean.astype(float) / np.maximum(cm_clean.sum(axis=1, keepdims=True), 1)
    cm_adv_norm = cm_adv.astype(float) / np.maximum(cm_adv.sum(axis=1, keepdims=True), 1)
    cm_diff = cm_adv_norm - cm_clean_norm
    
    # Plot 1: Clean CM
    im1 = axes[0].imshow(cm_clean_norm, cmap="Blues", vmin=0, vmax=1)
    axes[0].set_title(f"Clean Predictions\n(Seed {seed})", fontsize=14, fontweight="bold")
    axes[0].set_xlabel("Predicted")
    axes[0].set_ylabel("True Label")
    axes[0].set_xticks(range(len(CLASS_NAMES)))
    axes[0].set_yticks(range(len(CLASS_NAMES)))
    axes[0].set_xticklabels(CLASS_NAMES, rotation=45, ha="right")
    axes[0].set_yticklabels(CLASS_NAMES)
    
    # Add text annotations
    for i in range(len(CLASS_NAMES)):
        for j in range(len(CLASS_NAMES)):
            val = cm_clean_norm[i, j]
            color = "white" if val > 0.5 else "black"
            axes[0].text(j, i, f"{val:.2f}", ha="center", va="center", color=color, fontsize=9)
    
    plt.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)
    
    # Plot 2: Adversarial CM
    im2 = axes[1].imshow(cm_adv_norm, cmap="Reds", vmin=0, vmax=1)
    axes[1].set_title(f"Under {attack_label}\n(Seed {seed})", fontsize=14, fontweight="bold")
    axes[1].set_xlabel("Predicted")
    axes[1].set_ylabel("True Label")
    axes[1].set_xticks(range(len(CLASS_NAMES)))
    axes[1].set_yticks(range(len(CLASS_NAMES)))
    axes[1].set_xticklabels(CLASS_NAMES, rotation=45, ha="right")
    axes[1].set_yticklabels(CLASS_NAMES)
    
    for i in range(len(CLASS_NAMES)):
        for j in range(len(CLASS_NAMES)):
            val = cm_adv_norm[i, j]
            color = "white" if val > 0.5 else "black"
            axes[1].text(j, i, f"{val:.2f}", ha="center", va="center", color=color, fontsize=9)
    
    plt.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)
    
    # Plot 3: Difference (shows where predictions shifted)
    im3 = axes[2].imshow(cm_diff, cmap="RdBu_r", vmin=-0.5, vmax=0.5)
    axes[2].set_title(f"Prediction Shift\n(Adversarial - Clean)", fontsize=14, fontweight="bold")
    axes[2].set_xlabel("Predicted")
    axes[2].set_ylabel("True Label")
    axes[2].set_xticks(range(len(CLASS_NAMES)))
    axes[2].set_yticks(range(len(CLASS_NAMES)))
    axes[2].set_xticklabels(CLASS_NAMES, rotation=45, ha="right")
    axes[2].set_yticklabels(CLASS_NAMES)
    
    for i in range(len(CLASS_NAMES)):
        for j in range(len(CLASS_NAMES)):
            val = cm_diff[i, j]
            color = "white" if abs(val) > 0.25 else "black"
            axes[2].text(j, i, f"{val:+.2f}", ha="center", va="center", color=color, fontsize=9)
    
    cbar = plt.colorbar(im3, ax=axes[2], fraction=0.046, pad=0.04)
    cbar.set_label("Change in row-normalized rate", rotation=270, labelpad=15)
    
    plt.suptitle(
        f"Confusion Matrix Analysis: Impact of {attack_label} Attack",
        fontsize=16, fontweight="bold", y=1.02
    )
    plt.tight_layout()
    
    return fig

# Create confusion matrix for first available seed
first_seed = config.seeds[0]
if first_seed in clean_results and first_seed in all_results.get("pgd", {}):
    fig_cm = create_confusion_matrix_comparison(
        clean_results, all_results, 
        seed=first_seed,
        attack_key="pgd",
        epsilon=8/255
    )
    plt.show()
    
    # Save
    if paths.figures_dir.exists():
        fig_cm.savefig(paths.figures_dir / "confusion_matrix_comparison.png", 
                       dpi=200, bbox_inches="tight", facecolor="white")
        print(f"✅ Saved to {paths.figures_dir / 'confusion_matrix_comparison.png'}")
else:
    print("⚠️ Results not available for confusion matrix visualization")
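
The normalization used in Cell 13 can be sketched on its own. Each row of a confusion matrix is divided by its row sum, so entry (i, j) estimates the fraction of true-class-i samples predicted as class j, and the difference of two such matrices shows where predictions moved. The toy counts and the helper name `row_normalize` are illustrative assumptions, not values or functions from this evaluation.

```python
import numpy as np

def row_normalize(cm: np.ndarray) -> np.ndarray:
    """Divide each row by its sum; rows with no samples stay all-zero."""
    row_sums = cm.sum(axis=1, keepdims=True)
    return cm.astype(float) / np.maximum(row_sums, 1)

# Hypothetical 3-class counts (rows = true label, columns = prediction)
cm_clean_toy = np.array([[8, 1, 1], [2, 6, 2], [0, 0, 0]])
cm_adv_toy = np.array([[2, 5, 3], [4, 1, 5], [0, 0, 0]])

# Positive off-diagonal entries mark the classes absorbing misclassifications
shift = row_normalize(cm_adv_toy) - row_normalize(cm_clean_toy)
print(np.round(shift, 2))
```

Every row of the shift matrix sums to zero, so a negative diagonal entry is always offset by positive off-diagonal entries in the same row.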

In [None]:
#@title 🕸️ Cell 14: PhD-Level Visualization - Radar Chart & Attack Effectiveness
#@markdown **Multi-dimensional attack comparison using radar charts**

def create_radar_chart(all_results: dict, clean_results: dict, config: EvaluationConfig) -> go.Figure:
    """
    Create radar chart comparing attack effectiveness across multiple dimensions.
    """
    # Calculate mean metrics across seeds for strongest epsilon
    eps_8 = 8/255
    
    metrics = {
        "FGSM": {},
        f"PGD-{config.pgd_steps}": {},
        "C&W-L2": {}
    }
    
    dimensions = [
        "Attack Success Rate",
        "Accuracy Drop",
        "Per-Class Variance",
        "Mean Perturbation",
        "Speed (inverse)"
    ]
    
    # Calculate FGSM metrics
    if config.fgsm_enabled:
        fgsm_results = [all_results["fgsm"][s][eps_8] for s in config.seeds if s in all_results["fgsm"] and eps_8 in all_results["fgsm"][s]]
        if fgsm_results:
            metrics["FGSM"] = {
                "Attack Success Rate": np.mean([r["attack_success_rate"] for r in fgsm_results]) * 100,
                "Accuracy Drop": np.mean([r["accuracy_drop"] for r in fgsm_results]) * 100,
                "Per-Class Variance": np.mean([np.std(list(r["per_class_robust_accuracy"].values())) for r in fgsm_results]) * 100,
                "Mean Perturbation": np.mean([r["mean_linf_dist"] for r in fgsm_results]) * 255,  # Scale to 0-8
                "Speed (inverse)": 95  # FGSM is single-step, very fast
            }
    
    # Calculate PGD metrics
    if config.pgd_enabled:
        pgd_results = [all_results["pgd"][s][eps_8] for s in config.seeds if s in all_results["pgd"] and eps_8 in all_results["pgd"][s]]
        if pgd_results:
            metrics[f"PGD-{config.pgd_steps}"] = {
                "Attack Success Rate": np.mean([r["attack_success_rate"] for r in pgd_results]) * 100,
                "Accuracy Drop": np.mean([r["accuracy_drop"] for r in pgd_results]) * 100,
                "Per-Class Variance": np.mean([np.std(list(r["per_class_robust_accuracy"].values())) for r in pgd_results]) * 100,
                "Mean Perturbation": np.mean([r["mean_linf_dist"] for r in pgd_results]) * 255,
                "Speed (inverse)": 40  # Heuristic score (not measured): PGD needs config.pgd_steps gradient steps versus one for FGSM
            }
    
    # Calculate C&W metrics
    if config.cw_enabled:
        cw_results = [all_results["cw"][s]["l2"] for s in config.seeds if s in all_results["cw"] and "l2" in all_results["cw"][s]]
        if cw_results:
            metrics["C&W-L2"] = {
                "Attack Success Rate": np.mean([r["attack_success_rate"] for r in cw_results]) * 100,
                "Accuracy Drop": np.mean([r["accuracy_drop"] for r in cw_results]) * 100,
                "Per-Class Variance": np.mean([np.std(list(r["per_class_robust_accuracy"].values())) for r in cw_results]) * 100,
                "Mean Perturbation": np.mean([r["mean_l2_dist"] for r in cw_results]) * 10,  # Scale L2 for visibility
                "Speed (inverse)": 10  # C&W is slowest
            }
    
    # Create radar chart
    fig = go.Figure()
    
    colors = {"FGSM": "#FF6B6B", f"PGD-{config.pgd_steps}": "#4ECDC4", "C&W-L2": "#9B59B6"}
    
    for attack_name, data in metrics.items():
        if not data:
            continue
        
        values = [data.get(dim, 0) for dim in dimensions]
        values.append(values[0])  # Close the polygon
        
        fig.add_trace(go.Scatterpolar(
            r=values,
            theta=dimensions + [dimensions[0]],
            fill='toself',
            name=attack_name,
            line=dict(color=colors[attack_name], width=2),
            fillcolor=colors[attack_name],
            opacity=0.3
        ))
    
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
                ticksuffix="%"
            ),
            angularaxis=dict(
                tickfont=dict(size=12)
            )
        ),
        title=dict(
            text="<b>Attack Effectiveness Comparison (Radar Chart)</b><br>"
                 "<sup>Higher values = more effective/impactful attack</sup>",
            font=dict(size=16),
            x=0.5
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.2,
            xanchor="center",
            x=0.5
        ),
        height=600,
        width=700
    )
    
    return fig

# Create radar chart
fig_radar = create_radar_chart(all_results, clean_results, config)
fig_radar.show()

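# Save (static PNG export via write_image needs an image backend such as kaleido)
if paths.figures_dir.exists():
    fig_radar.write_html(paths.figures_dir / "attack_radar_chart.html")
    try:
        fig_radar.write_image(paths.figures_dir / "attack_radar_chart.png", scale=2)
        print(f"✅ Saved to {paths.figures_dir / 'attack_radar_chart.png'}")
    except Exception as exc:
        print(f"⚠️ PNG export skipped: {exc}")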
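
The radar chart in Cell 14 mixes percentages, perturbation norms, and a hand-assigned speed score on one 0-100 axis. One possible alternative to the fixed scale factors, sketched here under assumptions, is to rescale each dimension by its maximum across attacks so that every axis shares the same range. The function `rescale_dimensions`, the dimension names, and all numbers below are hypothetical and are not results from this notebook.

```python
def rescale_dimensions(raw: dict) -> dict:
    """Scale each dimension to 0-100 by dividing by its maximum across attacks."""
    dims = list(next(iter(raw.values())).keys())
    scaled = {name: {} for name in raw}
    for dim in dims:
        peak = max(abs(vals[dim]) for vals in raw.values())
        for name, vals in raw.items():
            scaled[name][dim] = 100.0 * vals[dim] / peak if peak > 0 else 0.0
    return scaled

# Hypothetical raw metrics in mixed units
raw_toy = {
    "FGSM": {"Accuracy Drop": 30.0, "Mean Perturbation": 0.031, "Seconds per batch": 0.2},
    "PGD": {"Accuracy Drop": 60.0, "Mean Perturbation": 0.025, "Seconds per batch": 4.0},
}
scaled_toy = rescale_dimensions(raw_toy)
print(scaled_toy["FGSM"])
```

With this scaling the chart shows relative rather than absolute values, so the radial axis should be labeled accordingly.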

In [None]:
#@title 🖼️ Cell 15: PhD-Level Visualization - Adversarial Example Gallery
#@markdown **Visualize clean images, perturbations, and adversarial examples**

def visualize_adversarial_examples(
    all_results: dict,
    attack_key: str = "pgd",
    epsilon: float = 8/255,
    num_examples: int = 5,
    amplification: float = 10.0
) -> plt.Figure:
    """
    Create publication-quality adversarial example visualization.
    
    Shows:
    - Clean image
    - Perturbation (amplified for visibility)
    - Adversarial image
    - Prediction change
    """
    # Get saved examples from first seed
    first_seed = config.seeds[0]
    
    if attack_key == "cw":
        examples = all_results[attack_key][first_seed]["l2"].get("saved_examples", [])
    else:
        examples = all_results[attack_key][first_seed][epsilon].get("saved_examples", [])
    
    if not examples:
        print("⚠️ No saved examples available")
        return None
    
    # Select successful attacks (prediction changed)
    successful = [e for e in examples if e["clean_pred"] != e["adv_pred"]]
    if len(successful) < num_examples:
        successful = examples  # Use all if not enough successful
    
    num_examples = min(num_examples, len(successful))
    
    fig, axes = plt.subplots(num_examples, 4, figsize=(16, 4 * num_examples))
    if num_examples == 1:
        axes = axes.reshape(1, -1)
    
    for idx, example in enumerate(successful[:num_examples]):
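        # Get tensors (detach and move to CPU so .numpy() also works for GPU or grad-tracking tensors)
        clean = example["clean"].detach().cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
        adv = example["adversarial"].detach().cpu().numpy().transpose(1, 2, 0)
        perturbation = example["perturbation"].detach().cpu().numpy().transpose(1, 2, 0)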
        
        true_label = CLASS_NAMES[example["true_label"]]
        clean_pred = CLASS_NAMES[example["clean_pred"]]
        adv_pred = CLASS_NAMES[example["adv_pred"]]
        
        # Ensure valid range
        clean = np.clip(clean, 0, 1)
        adv = np.clip(adv, 0, 1)
        
        # Amplify perturbation for visibility
        pert_amplified = perturbation * amplification + 0.5
        pert_amplified = np.clip(pert_amplified, 0, 1)
        
        # L2 and Linf norms
        l2_norm = np.sqrt(np.sum(perturbation ** 2))
        linf_norm = np.max(np.abs(perturbation))
        
        # Plot clean image (hide ticks instead of axis("off") so the border spines still render)
        axes[idx, 0].imshow(clean)
        axes[idx, 0].set_title(f"Clean Image\nTrue: {true_label}\nPred: {clean_pred}", fontsize=11)
        axes[idx, 0].set_xticks([])
        axes[idx, 0].set_yticks([])
        clean_correct = example["clean_pred"] == example["true_label"]
        for spine in axes[idx, 0].spines.values():
            spine.set_visible(clean_correct)
            spine.set_edgecolor('green')
            spine.set_linewidth(3)
        
        # Plot perturbation
        axes[idx, 1].imshow(pert_amplified)
        axes[idx, 1].set_title(f"Perturbation (×{amplification:.0f})\nL∞: {linf_norm*255:.2f}/255\nL2: {l2_norm:.4f}", fontsize=11)
        axes[idx, 1].axis("off")
        
        # Plot adversarial image
        axes[idx, 2].imshow(adv)
        axes[idx, 2].set_title(f"Adversarial Image\nPred: {adv_pred}", fontsize=11)
        axes[idx, 2].set_xticks([])
        axes[idx, 2].set_yticks([])
        adv_wrong = example["adv_pred"] != example["true_label"]
        for spine in axes[idx, 2].spines.values():
            spine.set_visible(adv_wrong)
            spine.set_edgecolor('red')
            spine.set_linewidth(3)
        
        # Plot difference heatmap
        diff = np.mean(np.abs(adv - clean), axis=2)  # Average across channels
        im = axes[idx, 3].imshow(diff, cmap="hot", vmin=0, vmax=epsilon * 2)
        axes[idx, 3].set_title(f"Absolute Difference\n({clean_pred} → {adv_pred})", fontsize=11)
        axes[idx, 3].axis("off")
        plt.colorbar(im, ax=axes[idx, 3], fraction=0.046, pad=0.04)
    
    attack_label = "C&W L2" if attack_key == "cw" else f"{'PGD' if attack_key == 'pgd' else 'FGSM'}-{config.get_epsilon_str(epsilon)}"
    plt.suptitle(
        f"Adversarial Example Gallery: {attack_label} Attack\n"
        f"Green border = correct, Red border = misclassified",
        fontsize=16, fontweight="bold", y=1.02
    )
    plt.tight_layout()
    
    return fig

# Create visualization for PGD attack
if config.pgd_enabled and config.seeds[0] in all_results.get("pgd", {}):
    fig_examples = visualize_adversarial_examples(
        all_results,
        attack_key="pgd",
        epsilon=8/255,
        num_examples=5,
        amplification=10.0
    )
    if fig_examples:
        plt.show()
        
        # Save
        if paths.figures_dir.exists():
            fig_examples.savefig(paths.figures_dir / "adversarial_examples_pgd.png",
                                 dpi=200, bbox_inches="tight", facecolor="white")
            print(f"✅ Saved to {paths.figures_dir / 'adversarial_examples_pgd.png'}")

# Also show FGSM examples if available
if config.fgsm_enabled and config.seeds[0] in all_results.get("fgsm", {}):
    print("\n" + "="*60)
    print("📸 FGSM Attack Examples")
    print("="*60)
    
    fig_fgsm = visualize_adversarial_examples(
        all_results,
        attack_key="fgsm",
        epsilon=8/255,
        num_examples=3,
        amplification=10.0
    )
    if fig_fgsm:
        plt.show()
        
        if paths.figures_dir.exists():
            fig_fgsm.savefig(paths.figures_dir / "adversarial_examples_fgsm.png",
                            dpi=200, bbox_inches="tight", facecolor="white")
            print(f"✅ Saved to {paths.figures_dir / 'adversarial_examples_fgsm.png'}")
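
The gallery in Cell 15 reports the L∞ and L2 norms of each perturbation and shows an amplified copy centered on mid-grey. A minimal sketch of those quantities on synthetic data follows. The random image, the uniform perturbation, and the `_toy` names are assumptions for illustration only, and no attack is actually run.

```python
import numpy as np

rng = np.random.default_rng(0)
eps = 8 / 255

# Hypothetical 3x4x4 CHW image and an L-infinity-bounded perturbation
clean_toy = rng.uniform(0.0, 1.0, size=(3, 4, 4))
delta_toy = rng.uniform(-eps, eps, size=(3, 4, 4))
adv_toy = np.clip(clean_toy + delta_toy, 0.0, 1.0)

# Effective perturbation after clipping to the valid pixel range
effective = adv_toy - clean_toy
linf = np.max(np.abs(effective))
l2 = np.sqrt(np.sum(effective ** 2))

# Amplify and shift to mid-grey so small signed values become visible
display = np.clip(effective * 10.0 + 0.5, 0.0, 1.0)

print(f"Linf = {linf * 255:.2f}/255, L2 = {l2:.4f}")
```

Because clipping to [0, 1] can only shrink the perturbation, the effective L∞ norm never exceeds the budget `eps`.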

In [None]:
#@title 📊 Cell 16: PhD-Level Analysis - Statistical Significance & Seed Consistency
#@markdown **Bootstrap confidence intervals and cross-seed consistency analysis**

from scipy import stats

def statistical_analysis(df_results: pd.DataFrame, config: EvaluationConfig) -> Dict[str, Any]:
    """
    Perform rigorous statistical analysis of adversarial evaluation results.
    """
    results = {}
    
    print("="*70)
    print("📊 STATISTICAL ANALYSIS OF ADVERSARIAL ROBUSTNESS")
    print("="*70)
    
    # 1. Cross-seed consistency (Coefficient of Variation)
    print("\n1️⃣ CROSS-SEED CONSISTENCY (Coefficient of Variation)")
    print("-" * 50)
    
    for attack in df_results["Attack"].unique():
        for eps in df_results[df_results["Attack"] == attack]["Epsilon"].unique():
            subset = df_results[(df_results["Attack"] == attack) & (df_results["Epsilon"] == eps)]
            if len(subset) >= 2:
                mean_acc = subset["Robust_Acc"].mean()
                std_acc = subset["Robust_Acc"].std()
                cv = (std_acc / mean_acc) * 100 if mean_acc > 0 else 0
                
                consistency = "✅ High" if cv < 5 else "⚠️ Medium" if cv < 10 else "❌ Low"
                print(f"   {attack:15} {eps:>7}: CV = {cv:5.2f}% {consistency}")
    
    # 2. Bootstrap 95% Confidence Intervals
    print("\n2️⃣ BOOTSTRAP 95% CONFIDENCE INTERVALS")
    print("-" * 50)
    
    def bootstrap_ci(data, n_bootstrap=1000, ci=0.95, seed=0):
        """Calculate a percentile bootstrap confidence interval for the mean."""
        data = np.asarray(data, dtype=float)
        if len(data) < 2:
            return data.mean(), data.mean(), data.mean()
        
        rng = np.random.default_rng(seed)  # Fixed seed so the reported intervals are reproducible
        bootstraps = []
        for _ in range(n_bootstrap):
            sample = rng.choice(data, size=len(data), replace=True)
            bootstraps.append(sample.mean())
        
        lower = np.percentile(bootstraps, (1 - ci) / 2 * 100)
        upper = np.percentile(bootstraps, (1 + ci) / 2 * 100)
        return data.mean(), lower, upper
    
    ci_results = []
    for attack in df_results["Attack"].unique():
        attack_data = df_results[df_results["Attack"] == attack]
        
        if attack == "C&W-L2":
            subset = attack_data
            mean, lower, upper = bootstrap_ci(subset["Robust_Acc"].values)
            ci_results.append({
                "Attack": attack,
                "Epsilon": "N/A",
                "Mean": mean,
                "CI_Lower": lower,
                "CI_Upper": upper
            })
            print(f"   {attack:15} {'N/A':>7}: {mean:5.2f}% [{lower:5.2f}%, {upper:5.2f}%]")
        else:
            for eps in attack_data["Epsilon"].unique():
                subset = attack_data[attack_data["Epsilon"] == eps]
                mean, lower, upper = bootstrap_ci(subset["Robust_Acc"].values)
                ci_results.append({
                    "Attack": attack,
                    "Epsilon": eps,
                    "Mean": mean,
                    "CI_Lower": lower,
                    "CI_Upper": upper
                })
                print(f"   {attack:15} {eps:>7}: {mean:5.2f}% [{lower:5.2f}%, {upper:5.2f}%]")
    
    results["confidence_intervals"] = ci_results
    
    # 3. Attack Comparison (Paired t-test: FGSM vs PGD)
    print("\n3️⃣ ATTACK COMPARISON (Paired t-test at ε=8/255)")
    print("-" * 50)
    
    if config.fgsm_enabled and config.pgd_enabled:
        eps_8 = "8/255"
        # Sort by seed so the i-th FGSM value is paired with the PGD value from the same model
        fgsm_acc = df_results[(df_results["Attack"] == "FGSM") & (df_results["Epsilon"] == eps_8)].sort_values("Seed")["Robust_Acc"].values
        pgd_acc = df_results[(df_results["Attack"] == f"PGD-{config.pgd_steps}") & (df_results["Epsilon"] == eps_8)].sort_values("Seed")["Robust_Acc"].values
        
        if len(fgsm_acc) >= 2 and len(pgd_acc) >= 2 and len(fgsm_acc) == len(pgd_acc):
            t_stat, p_value = stats.ttest_rel(fgsm_acc, pgd_acc)
            
            significance = "***" if p_value < 0.001 else "**" if p_value < 0.01 else "*" if p_value < 0.05 else "ns"
            
            print(f"   FGSM mean: {fgsm_acc.mean():.2f}%")
            print(f"   PGD mean:  {pgd_acc.mean():.2f}%")
            print(f"   Difference: {(fgsm_acc.mean() - pgd_acc.mean()):.2f}pp")
            print(f"   t-statistic: {t_stat:.3f}")
            print(f"   p-value: {p_value:.4f} {significance}")
            
            results["fgsm_vs_pgd"] = {
                "t_statistic": t_stat,
                "p_value": p_value,
                "significant": p_value < 0.05
            }
    
    # 4. Effect Size (Cohen's d)
    print("\n4️⃣ EFFECT SIZE (Cohen's d): Clean vs Adversarial")
    print("-" * 50)
    
    for attack in df_results["Attack"].unique():
        attack_data = df_results[df_results["Attack"] == attack]
        clean_acc = attack_data["Clean_Acc"].values
        robust_acc = attack_data["Robust_Acc"].values
        
        # Cohen's d with the pooled sample standard deviation (ddof=1); needs >= 2 rows
        if len(attack_data) >= 2:
            pooled_std = np.sqrt((clean_acc.std(ddof=1)**2 + robust_acc.std(ddof=1)**2) / 2)
        else:
            pooled_std = 0
        if pooled_std > 0:
            cohens_d = (clean_acc.mean() - robust_acc.mean()) / pooled_std
        else:
            cohens_d = 0
        
        effect = "Negligible" if abs(cohens_d) < 0.2 else "Small" if abs(cohens_d) < 0.5 else "Medium" if abs(cohens_d) < 0.8 else "Large"
        print(f"   {attack:15}: d = {cohens_d:6.2f} ({effect})")
    
    print("\n" + "="*70)
    
    return results

# Run statistical analysis
stat_results = statistical_analysis(df_results, config)

# ============================================================================
# VISUALIZATION: SEED CONSISTENCY BOX PLOT
# ============================================================================

def create_seed_comparison_boxplot(df_results: pd.DataFrame) -> go.Figure:
    """Create box plot comparing robust accuracy across seeds."""
    
    fig = go.Figure()
    
    colors = {"FGSM": "#FF6B6B", f"PGD-{config.pgd_steps}": "#4ECDC4", "C&W-L2": "#9B59B6"}
    
    for attack in df_results["Attack"].unique():
        attack_data = df_results[df_results["Attack"] == attack]
        
        fig.add_trace(go.Box(
            y=attack_data["Robust_Acc"],
            x=[attack] * len(attack_data),
            name=attack,
            marker_color=colors.get(attack, "#888888"),
            boxpoints="all",
            jitter=0.3,
            pointpos=-1.8,
            hovertemplate="Seed: %{text}<br>Robust Acc: %{y:.2f}%<extra></extra>",
            text=attack_data["Seed"].astype(str)
        ))
    
    fig.update_layout(
        title=dict(
            text="<b>Cross-Seed Consistency: Robust Accuracy Distribution</b>",
            font=dict(size=16),
            x=0.5
        ),
        xaxis_title="Attack Type",
        yaxis_title="Robust Accuracy (%)",
        yaxis=dict(range=[0, 100]),
        showlegend=False,
        height=500,
        width=800
    )
    
    return fig

fig_boxplot = create_seed_comparison_boxplot(df_results)
fig_boxplot.show()

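# Save (static PNG export via write_image needs an image backend such as kaleido)
if paths.figures_dir.exists():
    fig_boxplot.write_html(paths.figures_dir / "seed_consistency_boxplot.html")
    try:
        fig_boxplot.write_image(paths.figures_dir / "seed_consistency_boxplot.png", scale=2)
        print(f"✅ Saved to {paths.figures_dir / 'seed_consistency_boxplot.png'}")
    except Exception as exc:
        print(f"⚠️ PNG export skipped: {exc}")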
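
The three statistics computed in Cell 16 can be illustrated on toy numbers. The sketch below uses a seeded percentile bootstrap, Cohen's d with the pooled sample standard deviation, and a hand-computed paired t statistic (the same statistic `scipy.stats.ttest_rel` returns alongside its p-value). All function names and accuracy values here are hypothetical and do not come from the evaluation.

```python
import numpy as np

def percentile_bootstrap_ci(values, n_bootstrap=2000, ci=0.95, seed=0):
    """Percentile bootstrap interval for the mean, with a fixed RNG seed."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(values), size=(n_bootstrap, len(values)))
    means = values[idx].mean(axis=1)
    lower, upper = np.percentile(means, [(1 - ci) / 2 * 100, (1 + ci) / 2 * 100])
    return values.mean(), lower, upper

def cohens_d(a, b):
    """Cohen's d using the pooled sample standard deviation (ddof=1)."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    pooled = np.sqrt((a.var(ddof=1) + b.var(ddof=1)) / 2)
    return (a.mean() - b.mean()) / pooled if pooled > 0 else 0.0

def paired_t_statistic(a, b):
    """t statistic for paired samples: mean difference over its standard error."""
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    return diff.mean() / (diff.std(ddof=1) / np.sqrt(len(diff)))

# Hypothetical robust accuracies (%) for three seeds, in the same seed order
fgsm_toy = np.array([41.0, 43.5, 39.8])
pgd_toy = np.array([12.1, 14.0, 10.9])

mean, lo, hi = percentile_bootstrap_ci(pgd_toy)
print(f"PGD mean {mean:.2f}% [{lo:.2f}, {hi:.2f}]")
print(f"paired t = {paired_t_statistic(fgsm_toy, pgd_toy):.2f}, d = {cohens_d(fgsm_toy, pgd_toy):.2f}")
```

With only three seeds, a resample consists of the minimum alone about 1 time in 27, which is more often than the 2.5% tail, so the 95% percentile interval tends to span the full observed range. In that setting the interval is better read as a descriptive range than as a tight inferential bound.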

In [None]:
#@title 💾 Cell 17: Save Results & Export for Dissertation
#@markdown **Export all results in multiple formats for dissertation use**

print("="*70)
print("💾 SAVING RESULTS & EXPORTS")
print("="*70)

# ============================================================================
# SAVE RESULTS AS CSV
# ============================================================================

# Detailed results CSV
csv_path = paths.results_dir / "adversarial_robustness_results.csv"
df_results.to_csv(csv_path, index=False)
print(f"✅ Detailed results: {csv_path}")

# Summary statistics CSV  
summary_path = paths.results_dir / "adversarial_robustness_summary.csv"
df_summary.to_csv(summary_path, index=False)
print(f"✅ Summary statistics: {summary_path}")

# ============================================================================
# SAVE AS JSON (for programmatic access)
# ============================================================================

json_results = {
    "metadata": {
        "evaluation_date": datetime.now().isoformat(),
        "seeds": config.seeds,
        "epsilons": [float(e) for e in config.epsilons],
        "attacks": {
            "fgsm": config.fgsm_enabled,
            "pgd": config.pgd_enabled,
            "pgd_steps": config.pgd_steps,
            "cw": config.cw_enabled,
            "cw_iterations": config.cw_iterations
        },
        "dataset": "ISIC 2018",
        "model": "ResNet-50",
        "num_classes": NUM_CLASSES
    },
    "clean_accuracy": {
        str(seed): {
            "accuracy": float(clean_results[seed]["accuracy"]),
            "balanced_accuracy": float(clean_results[seed]["balanced_accuracy"]),
            "f1_macro": float(clean_results[seed]["f1_macro"]),
            "auroc": float(clean_results[seed]["auroc"])
        }
        for seed in config.seeds if seed in clean_results
    },
    "adversarial_results": df_results.to_dict(orient="records")
}

json_path = paths.results_dir / "adversarial_robustness_results.json"
with open(json_path, "w") as f:
    json.dump(json_results, f, indent=2, default=str)
print(f"✅ JSON results: {json_path}")

# ============================================================================
# GENERATE LATEX TABLE FOR DISSERTATION
# ============================================================================

def generate_latex_table(df_results: pd.DataFrame, config: EvaluationConfig) -> str:
    """Generate LaTeX table for dissertation."""
    
    lines = [
        "\\begin{table}[htbp]",
        "\\centering",
        "\\caption{Adversarial Robustness Evaluation: Baseline ResNet-50 on ISIC 2018}",
        "\\label{tab:adversarial_robustness}",
        "\\begin{tabular}{llcccc}",
        "\\toprule",
        "Attack & $\\epsilon$ & Clean Acc (\\%) & Robust Acc (\\%) & Acc Drop (pp) & ASR (\\%) \\\\",
        "\\midrule"
    ]
    
    # Group and average across seeds
    for attack in df_results["Attack"].unique():
        attack_data = df_results[df_results["Attack"] == attack]
        
        for eps in attack_data["Epsilon"].unique():
            subset = attack_data[attack_data["Epsilon"] == eps]
            
            clean = f"{subset['Clean_Acc'].mean():.1f} $\\pm$ {subset['Clean_Acc'].std():.1f}"
            robust = f"{subset['Robust_Acc'].mean():.1f} $\\pm$ {subset['Robust_Acc'].std():.1f}"
            drop = f"{subset['Acc_Drop'].mean():.1f} $\\pm$ {subset['Acc_Drop'].std():.1f}"
            asr = f"{subset['Attack_Success'].mean():.1f} $\\pm$ {subset['Attack_Success'].std():.1f}"
            
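            # Escape "&" (e.g. in "C&W-L2") so LaTeX does not parse it as a column separator
            attack_tex = attack.replace("&", "\\&")
            lines.append(f"{attack_tex} & {eps} & {clean} & {robust} & {drop} & {asr} \\\\")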
    
    lines.extend([
        "\\bottomrule",
        "\\end{tabular}",
        "\\begin{tablenotes}",  # tablenotes is provided by the threeparttable package
        "\\small",
        f"\\item Note: Values are mean $\\pm$ std across seeds ({', '.join(map(str, config.seeds))}). ASR = Attack Success Rate.",
        "\\end{tablenotes}",
        "\\end{table}"
    ])
    
    return "\n".join(lines)

latex_table = generate_latex_table(df_results, config)

# Save LaTeX table
latex_path = paths.results_dir / "adversarial_robustness_table.tex"
with open(latex_path, "w") as f:
    f.write(latex_table)
print(f"✅ LaTeX table: {latex_path}")

# Print LaTeX table
print("\n📄 LATEX TABLE FOR DISSERTATION:")
print("-" * 60)
print(latex_table)

# ============================================================================
# LIST ALL SAVED FILES
# ============================================================================

print("\n" + "="*70)
print("📁 ALL SAVED FILES")
print("="*70)

if paths.results_dir.exists():
    for f in sorted(paths.results_dir.rglob("*")):
        if f.is_file():
            size_kb = f.stat().st_size / 1024
            print(f"   {f.relative_to(paths.results_dir)} ({size_kb:.1f} KB)")

print("="*70)
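
When table cells are generated from data, as in the LaTeX export in Cell 17, any cell containing a LaTeX special character (for example the `&` in `C&W-L2`) needs escaping before it is written. A minimal sketch of a single-pass escaper follows. The helper `escape_latex` and the mapping `LATEX_ESCAPES` are hypothetical names introduced for illustration and are not part of this notebook.

```python
# Characters with special meaning in LaTeX text mode and their escaped forms
LATEX_ESCAPES = {
    "\\": r"\textbackslash{}",
    "&": r"\&", "%": r"\%", "$": r"\$", "#": r"\#", "_": r"\_",
    "{": r"\{", "}": r"\}",
    "~": r"\textasciitilde{}", "^": r"\textasciicircum{}",
}

def escape_latex(text) -> str:
    """Escape LaTeX special characters in a single pass over the string."""
    return "".join(LATEX_ESCAPES.get(ch, ch) for ch in str(text))

# Hypothetical row: attack name, epsilon label, and a numeric cell
row = [escape_latex(cell) for cell in ["C&W-L2", "N/A", "12.3"]]
print(" & ".join(row) + r" \\")
```

Escaping character by character, rather than chaining `str.replace` calls, avoids re-escaping the backslashes introduced by earlier replacements.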

In [None]:
#@title 📋 Cell 18: Executive Summary & Key Findings
#@markdown **Final summary of adversarial robustness evaluation**

def print_executive_summary(df_results: pd.DataFrame, clean_results: dict, config: EvaluationConfig):
    """Generate executive summary of evaluation results."""
    
    print("╔" + "═"*68 + "╗")
    print("║" + " "*20 + "EXECUTIVE SUMMARY" + " "*31 + "║")
    print("║" + " "*10 + "Phase 4: Adversarial Robustness Evaluation" + " "*16 + "║")  # 10 + 42 + 16 = 68 to match the border
    print("╚" + "═"*68 + "╝")
    
    # Clean accuracy
    mean_clean = np.mean([r["accuracy"] for r in clean_results.values()]) * 100
    std_clean = np.std([r["accuracy"] for r in clean_results.values()]) * 100
    
    print(f"\n📊 BASELINE CLEAN PERFORMANCE")
    print(f"   • Clean Accuracy: {mean_clean:.2f}% ± {std_clean:.2f}%")
    print(f"   • Model: ResNet-50 (ImageNet pretrained)")
    print(f"   • Dataset: ISIC 2018 (7 classes)")
    print(f"   • Seeds evaluated: {config.seeds}")
    
    # Adversarial results summary
    print(f"\n🛡️ ADVERSARIAL ROBUSTNESS FINDINGS")
    
    # FGSM
    if config.fgsm_enabled:
        fgsm_8 = df_results[(df_results["Attack"] == "FGSM") & (df_results["Epsilon"] == "8/255")]
        if len(fgsm_8) > 0:
            mean_robust = fgsm_8["Robust_Acc"].mean()
            mean_drop = fgsm_8["Acc_Drop"].mean()
            print(f"\n   FGSM (ε=8/255):")
            print(f"   ├─ Robust Accuracy: {mean_robust:.2f}%")
            print(f"   ├─ Accuracy Drop: {mean_drop:.2f}pp")
            print(f"   └─ Interpretation: {'Moderate vulnerability' if mean_drop < 40 else 'High vulnerability'}")
    
    # PGD
    if config.pgd_enabled:
        pgd_8 = df_results[(df_results["Attack"] == f"PGD-{config.pgd_steps}") & (df_results["Epsilon"] == "8/255")]
        if len(pgd_8) > 0:
            mean_robust = pgd_8["Robust_Acc"].mean()
            mean_drop = pgd_8["Acc_Drop"].mean()
            print(f"\n   PGD-{config.pgd_steps} (ε=8/255):")
            print(f"   ├─ Robust Accuracy: {mean_robust:.2f}%")
            print(f"   ├─ Accuracy Drop: {mean_drop:.2f}pp")
            print(f"   └─ Interpretation: {'Severe vulnerability' if mean_drop > 60 else 'High vulnerability' if mean_drop > 40 else 'Moderate'}")
    
    # C&W
    if config.cw_enabled:
        cw = df_results[df_results["Attack"] == "C&W-L2"]
        if len(cw) > 0:
            mean_robust = cw["Robust_Acc"].mean()
            mean_drop = cw["Acc_Drop"].mean()
            print(f"\n   C&W L2:")
            print(f"   ├─ Robust Accuracy: {mean_robust:.2f}%")
            print(f"   ├─ Accuracy Drop: {mean_drop:.2f}pp")
            print(f"   └─ Interpretation: Strongest attack, minimal perturbation")
    
    # Key insights
    print(f"\n🔑 KEY INSIGHTS")
    print(f"   1. Standard CNNs are highly vulnerable to adversarial attacks")
    print(f"   2. PGD attacks are stronger than FGSM (multi-step > single-step)")
    print(f"   3. Some classes (e.g., MEL, BCC) may be more vulnerable than others")
    print(f"   4. Adversarial training is needed for robust medical AI deployment")
    
    # Research implications
    print(f"\n📚 RESEARCH IMPLICATIONS")
    print(f"   • These results motivate Phase 5: Adversarial Training")
    print(f"   • Tri-objective optimization will balance accuracy, robustness, and explainability")
    print(f"   • Medical AI systems require robustness validation before clinical use")
    
    # Files generated
    print(f"\n📁 GENERATED OUTPUTS")
    print(f"   • Results CSV: adversarial_robustness_results.csv")
    print(f"   • Summary CSV: adversarial_robustness_summary.csv")
    print(f"   • JSON results: adversarial_robustness_results.json")
    print(f"   • LaTeX table: adversarial_robustness_table.tex")
    print(f"   • Figures: robustness_curves.png, vulnerability_heatmap.png, etc.")
    
    print("\n" + "═"*70)
    print(f"✅ Phase 4 Adversarial Robustness Evaluation COMPLETE")
    print(f"⏱️  Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("═"*70)

# Print executive summary
print_executive_summary(df_results, clean_results, config)

# ============================================================================
# FINAL MEMORY CLEANUP
# ============================================================================

gc.collect()
torch.cuda.empty_cache()
print("\n🧹 Memory cleaned up")