# Post-Training Model Analysis

This notebook provides interactive analysis of your trained SwinUNETR model.
Use this for detailed examination of model performance and planning improvements.

In [None]:
import sys
import json
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import DataLoader
import pandas as pd

# Add project root to path.
# Appends the current working directory to sys.path before the axon_ia
# imports below are executed.
# NOTE(review): this only helps if the kernel's working directory is the
# project root (or axon_ia is installed already) - confirm how the notebook
# is launched.
sys.path.append(str(Path.cwd()))

# Project-local imports; these must stay after the sys.path tweak above.
# NOTE(review): AxonDataset, compute_metrics and load_nifti are not referenced
# in the cells visible here - confirm whether they are still needed.
from axon_ia.config import ConfigParser
from axon_ia.data import AxonDataset
from axon_ia.models import create_model
from axon_ia.evaluation.metrics import compute_metrics
from axon_ia.utils.nifti_utils import load_nifti

# Set style: matplotlib style sheet first, then the seaborn colour palette.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Visible confirmation that imports and styling finished without error.
print("Setup complete!")

## 1. Configuration and Paths

In [None]:
# --- User-editable locations ------------------------------------------------
# Adjust these to match the outputs of your own training run.
CONFIG_PATH = "configs/training/swinunetr_config.yaml"
CHECKPOINT_PATH = "/path/to/your/best_model.pth"    # placeholder - must be edited
METRICS_FILE = "/path/to/evaluation/metrics.json"   # placeholder - must be edited
DATA_DIR = "C:/development/data/axon_ia/processed"  # edit if your data lives elsewhere

# --- Parse the training configuration and echo its key entries --------------
config = ConfigParser(CONFIG_PATH)
print(f"Loaded config from: {CONFIG_PATH}")
for _label, _dotted_key in (
    ("Model architecture", "model.architecture"),
    ("Data directory", "data.root_dir"),
):
    print(f"{_label}: {config.get(_dotted_key)}")

## 2. Load and Analyze Metrics

In [None]:
# Load metrics (run evaluation script first if this file doesn't exist).
#
# `per_patient_df` is initialised before anything can fail so that every later
# cell's `if per_patient_df is not None` guard works regardless of how loading
# goes (missing file, malformed JSON, or an unexpected layout).
per_patient_df = None

try:
    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open(METRICS_FILE, 'r', encoding='utf-8') as f:
        metrics_data = json.load(f)

    print("Overall Metrics:")
    for metric, value in metrics_data['overall'].items():
        if isinstance(value, (int, float)):
            print(f"  {metric}: {value:.4f}")
        else:
            # None, strings or nested structures cannot take the `.4f` format.
            print(f"  {metric}: {value}")

    # Convert to DataFrame for easier analysis: the transpose puts the
    # top-level keys of 'per_patient' on the index and the keys of each
    # entry in the columns.
    per_patient_df = pd.DataFrame(metrics_data['per_patient']).T
    print(f"\nLoaded metrics for {len(per_patient_df)} patients")

except FileNotFoundError:
    print(f"Metrics file not found: {METRICS_FILE}")
    print("Run the evaluation script first:")
    print(f"python scripts/evaluate.py --config {CONFIG_PATH} --checkpoint {CHECKPOINT_PATH} --generate-report")
except (json.JSONDecodeError, KeyError) as err:
    # The file exists but is not the report this notebook expects.
    print(f"Could not read metrics from {METRICS_FILE}: {err!r}")
    print("Expected a JSON object with 'overall' and 'per_patient' entries.")

In [None]:
# Per-metric distributions across patients
if per_patient_df is not None:
    # Tabular overview produced by DataFrame.describe()
    print("\nSummary Statistics:")
    print(per_patient_df.describe())

    # 2x3 grid of histograms: one panel per metric, at most six metrics
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    metrics = list(per_patient_df.columns)
    # zip() stops at the shorter sequence, so only the first six metrics are drawn
    for panel, metric in zip(axes, metrics):
        values = per_patient_df[metric]
        panel.hist(values, bins=20, alpha=0.7, edgecolor='black')
        panel.set_title(f'{metric.replace("_", " ").title()} Distribution')
        panel.set_xlabel('Score')
        panel.set_ylabel('Frequency')

        # Dashed red marker at the mean, labelled in the legend
        mean_val = values.mean()
        panel.axvline(mean_val, color='red', linestyle='--',
                      label=f'Mean: {mean_val:.3f}')
        panel.legend()

    # Blank out any panel that did not receive a metric
    for spare_panel in axes[len(metrics):]:
        spare_panel.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Rank patients by Dice and look at how the metrics relate to one another
if per_patient_df is not None and 'dice' in per_patient_df.columns:
    # Highest Dice first, so head() holds the best cases and tail() the worst
    ranked_df = per_patient_df.sort_values('dice', ascending=False)

    for heading, subset in (
        ("\nTop 5 Best Performing Cases:", ranked_df.head()),
        ("\nTop 5 Worst Performing Cases:", ranked_df.tail()),
    ):
        print(heading)
        print(subset)

    # Pairwise correlation between all metric columns, drawn as a heatmap
    print("\nMetric Correlations:")
    correlation_matrix = per_patient_df.corr()

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        correlation_matrix,
        annot=True,
        cmap='coolwarm',
        center=0,
        square=True,
        fmt='.3f',
    )
    plt.title('Metric Correlations')
    plt.tight_layout()
    plt.show()

## 3. Model Performance Analysis

In [None]:
# Performance benchmarking: bucket per-patient Dice scores into quality bands
if per_patient_df is not None and 'dice' in per_patient_df.columns:
    # A NaN Dice matches none of the bands below. Drop such rows first so the
    # percentages use the same denominator as the counts (and therefore agree
    # with the pie chart, which normalises over the counted cases only).
    all_dice = per_patient_df['dice']
    dice_scores = all_dice.dropna()
    n_missing = len(all_dice) - len(dice_scores)
    n_cases = len(dice_scores)

    if n_missing > 0:
        print(f"\nExcluded {n_missing} case(s) without a valid Dice score")

    if n_cases == 0:
        # Nothing to categorise; also avoids dividing by zero below.
        print("\nNo valid Dice scores available - skipping performance categories.")
    else:
        # Performance categories (half-open bands: lower bound inclusive)
        excellent = int((dice_scores >= 0.8).sum())
        good = int(((dice_scores >= 0.7) & (dice_scores < 0.8)).sum())
        fair = int(((dice_scores >= 0.5) & (dice_scores < 0.7)).sum())
        poor = int((dice_scores < 0.5).sum())

        print(f"\nPerformance Categories (Dice Score):")
        print(f"  Excellent (≥0.8): {excellent} cases ({excellent/n_cases*100:.1f}%)")
        print(f"  Good (0.7-0.8): {good} cases ({good/n_cases*100:.1f}%)")
        print(f"  Fair (0.5-0.7): {fair} cases ({fair/n_cases*100:.1f}%)")
        print(f"  Poor (<0.5): {poor} cases ({poor/n_cases*100:.1f}%)")

        # Visualization
        categories = ['Excellent\n(≥0.8)', 'Good\n(0.7-0.8)', 'Fair\n(0.5-0.7)', 'Poor\n(<0.5)']
        counts = [excellent, good, fair, poor]
        colors = ['green', 'orange', 'yellow', 'red']

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

        # Bar chart
        bars = ax1.bar(categories, counts, color=colors, alpha=0.7, edgecolor='black')
        ax1.set_title('Performance Distribution')
        ax1.set_ylabel('Number of Cases')

        # Add count and percentage labels just above each bar
        for bar, count in zip(bars, counts):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                     f'{count}\n({count/n_cases*100:.1f}%)',
                     ha='center', va='bottom')

        # Pie chart
        ax2.pie(counts, labels=categories, colors=colors, autopct='%1.1f%%', startangle=90)
        ax2.set_title('Performance Distribution (Percentage)')

        plt.tight_layout()
        plt.show()

## 4. Failure Case Analysis

In [None]:
# Identify cases that need attention
if per_patient_df is not None:
    # Column availability is checked once. Every step below only touches
    # columns that are present, so a metrics file lacking one of
    # dice / precision / recall cannot trigger a KeyError.
    available_cols = set(per_patient_df.columns)
    summary_cols = [c for c in ('dice', 'precision', 'recall') if c in available_cols]

    if {'dice', 'precision'} <= available_cols:
        # Cases with low Dice but high precision (missed lesions)
        low_dice_high_precision = per_patient_df[
            (per_patient_df['dice'] < 0.6) & (per_patient_df['precision'] > 0.7)
        ]

        print(f"\nCases with Low Dice but High Precision (likely missed lesions): {len(low_dice_high_precision)}")
        if len(low_dice_high_precision) > 0:
            print(low_dice_high_precision[summary_cols].head())

        # Cases with Dice above 0.6 but precision below 0.6 (many false positives)
        low_precision_cases = per_patient_df[
            (per_patient_df['dice'] > 0.6) & (per_patient_df['precision'] < 0.6)
        ]

        print(f"\nCases with Low Precision (many false positives): {len(low_precision_cases)}")
        if len(low_precision_cases) > 0:
            print(low_precision_cases[summary_cols].head())

    # Scatter plot: Precision vs Recall
    if {'precision', 'recall'} <= available_cols:
        fig, ax = plt.subplots(figsize=(10, 8))

        if 'dice' in available_cols:
            # Colour each point by its Dice score and add the matching colour bar
            scatter = ax.scatter(per_patient_df['precision'], per_patient_df['recall'],
                                 c=per_patient_df['dice'], cmap='viridis', alpha=0.7, s=60)
            fig.colorbar(scatter, ax=ax, label='Dice Score')
            ax.set_title('Precision vs Recall (colored by Dice Score)')
        else:
            # No Dice column: draw the same scatter without colour coding
            ax.scatter(per_patient_df['precision'], per_patient_df['recall'],
                       alpha=0.7, s=60)
            ax.set_title('Precision vs Recall')

        ax.set_xlabel('Precision')
        ax.set_ylabel('Recall')

        # Add diagonal line (equal precision and recall)
        ax.plot([0, 1], [0, 1], 'r--', alpha=0.5, label='Precision = Recall')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        plt.show()

## 5. Model Architecture Analysis

In [None]:
# Build the model described by the config and report its size.
# NOTE: only the architecture is instantiated in this cell - no weights are
# read from CHECKPOINT_PATH and the model is not moved to `device` - so the
# figures below describe the network structure, not a trained checkpoint.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create model
model_config = config.get('model')
# `params` may be missing or empty; use one dict for both construction and
# reporting so the two can never disagree (the summary previously indexed
# model_config['params'] directly and raised KeyError when it was absent).
model_params = model_config.get('params') or {}
model = create_model(
    architecture=model_config['architecture'],
    **model_params
)

# Model summary ('n/a' when the config does not set a value)
print(f"\nModel Architecture: {model_config['architecture']}")
print(f"Input channels: {model_params.get('in_channels', 'n/a')}")
print(f"Output channels: {model_params.get('out_channels', 'n/a')}")
print(f"Feature size: {model_params.get('feature_size', 'n/a')}")

# Count parameters (materialised once; model.parameters() is a generator)
model_parameters = list(model.parameters())
total_params = sum(p.numel() for p in model_parameters)
trainable_params = sum(p.numel() for p in model_parameters if p.requires_grad)
# Use each tensor's real element size instead of assuming 4-byte floats.
param_bytes = sum(p.numel() * p.element_size() for p in model_parameters)

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size (MB): {param_bytes / 1024 / 1024:.2f}")

## 6. Training Configuration Analysis

In [None]:
# Echo the settings the model was trained with, grouped by config section
print("Training Configuration:")
for _label, _dotted_key in (
    ("Epochs", "training.epochs"),
    ("Batch size", "data.batch_size"),
    ("Learning rate", "optimizer.learning_rate"),
    ("Weight decay", "optimizer.weight_decay"),
    ("Use AMP", "training.use_amp"),
):
    print(f"  {_label}: {config.get(_dotted_key)}")

# Loss: the type first, then every entry under 'params' when that key exists
print("\nLoss Configuration:")
loss_config = config.get('loss')
print(f"  Type: {loss_config['type']}")
if 'params' in loss_config:
    for param_name, param_value in loss_config['params'].items():
        print(f"  {param_name}: {param_value}")

# Data: three fields read straight from the 'data' section
print("\nData Configuration:")
data_config = config.get('data')
for _label, _field in (
    ("Modalities", "modalities"),
    ("Target", "target"),
    ("Num workers", "num_workers"),
):
    print(f"  {_label}: {data_config[_field]}")

## 7. Recommendations for Improvement

In [None]:
# Turn the metrics gathered above into a list of human-readable suggestions.
# Each entry is one output line; a leading "\n" starts a new group.
recommendations = []

if per_patient_df is not None and 'dice' in per_patient_df.columns:
    dice_column = per_patient_df['dice']
    mean_dice = dice_column.mean()
    std_dice = dice_column.std()

    print(f"\nModel Performance Analysis:")
    print(f"Mean Dice Score: {mean_dice:.3f} ± {std_dice:.3f}")

    # Overall tier, decided by the mean Dice score
    if mean_dice < 0.6:
        recommendations.extend([
            "🔴 LOW PERFORMANCE: Consider major changes:",
            "   - Review data quality and preprocessing",
            "   - Increase model capacity (feature_size: 96 or 128)",
            "   - Try different architecture (UNETR, SegResNet)",
            "   - Increase training duration (50+ epochs)",
        ])
    elif mean_dice < 0.75:
        recommendations.extend([
            "🟡 MODERATE PERFORMANCE: Consider improvements:",
            "   - Enhance data augmentation",
            "   - Fine-tune loss function weights",
            "   - Implement ensemble methods",
            "   - Add boundary loss component",
        ])
    else:
        recommendations.extend([
            "🟢 GOOD PERFORMANCE: Fine-tuning options:",
            "   - Implement test-time augmentation",
            "   - Try ensemble of multiple models",
            "   - Focus on edge case improvements",
        ])

    # Spread of Dice across patients
    if std_dice > 0.2:
        recommendations.extend([
            "\n⚠️  HIGH VARIABILITY detected:",
            "   - Review data consistency",
            "   - Implement stratified training",
            "   - Consider patient-specific normalization",
        ])

    # Precision / recall balance (only when both columns are available)
    if 'precision' in per_patient_df.columns and 'recall' in per_patient_df.columns:
        mean_precision = per_patient_df['precision'].mean()
        mean_recall = per_patient_df['recall'].mean()

        if mean_precision < 0.7:
            recommendations.extend([
                "\n🔸 LOW PRECISION (many false positives):",
                "   - Increase focal loss weight",
                "   - Add false positive penalty",
                "   - Implement post-processing filters",
            ])

        if mean_recall < 0.7:
            recommendations.extend([
                "\n🔸 LOW RECALL (missing lesions):",
                "   - Increase Dice loss weight",
                "   - Review data augmentation",
                "   - Consider multi-scale training",
            ])

# Efficiency hints depend only on the config, not on the metrics
batch_size = config.get('data.batch_size')
if batch_size == 1:
    recommendations.extend([
        "\n⚡ TRAINING EFFICIENCY:",
        "   - Consider gradient accumulation for larger effective batch size",
        "   - Try mixed precision training if not already enabled",
        "   - Implement learning rate warmup",
    ])

# Report
print("\n" + "="*60)
print("RECOMMENDATIONS FOR IMPROVEMENT")
print("="*60)
for recommendation in recommendations:
    print(recommendation)

# Closing checklist: assembled as a list and printed one entry per line
_banner = "=" * 60
_next_steps = [
    "\n" + _banner,
    "NEXT STEPS",
    _banner,
    "1. Run comprehensive evaluation:",
    f"   python scripts/evaluate.py --config {CONFIG_PATH} --checkpoint {CHECKPOINT_PATH} --generate-report",
    "\n2. Generate visualizations:",
    "   python scripts/visualize_results.py --predictions-dir /path/to/predictions --ground-truth-dir /path/to/gt --images-dir /path/to/images --output-dir visualizations",
    "\n3. Consider training improvements and run next iteration",
    "\n4. Review failure cases manually for data quality issues",
]
for _step_line in _next_steps:
    print(_step_line)

## 8. Next Training Configuration Template

In [None]:
# Template for the next training run, assembled section by section.
# Section and key order is kept as written so the assembled dict is equal,
# key for key, to a single nested literal.

# Network: larger feature size plus dropout for regularisation
_model_section = {
    "architecture": "swinunetr",
    "params": {
        "in_channels": 4,
        "out_channels": 1,
        "feature_size": 96,        # up from 48
        "drop_rate": 0.1,          # dropout added
        "attn_drop_rate": 0.1,
        "dropout_path_rate": 0.1,
        "use_checkpoint": True,
        "use_deep_supervision": True,
    },
}

# Loss: Dice + focal combination
_loss_section = {
    "type": "combo",
    "params": {
        "dice_weight": 1.0,
        "focal_weight": 1.0,       # raise if precision is the problem
        "focal_gamma": 2.0,
        "include_background": False,
    },
}

# Optimiser and learning-rate schedule
_optimizer_section = {
    "type": "adamw",
    "learning_rate": 1e-4,         # slightly higher
    "weight_decay": 0.01,
}
_scheduler_section = {
    "use_scheduler": True,
    "type": "cosine_warmup",
    "params": {
        "warmup_epochs": 20,       # longer warmup
        "min_lr": 1e-7,
    },
}

# Training loop
_training_section = {
    "epochs": 100,                 # more epochs
    "use_amp": True,
    "grad_clip": 1.0,
    "val_interval": 1,
}

# Data pipeline
_data_section = {
    "batch_size": 2,               # increase further if memory allows
    "augmentation": {
        "rotation": [-15, 15],
        "scaling": [0.9, 1.1],
        "elastic_deformation": True,
        "gaussian_noise": 0.1,
    },
}

improved_config = {
    "model": _model_section,
    "loss": _loss_section,
    "optimizer": _optimizer_section,
    "scheduler": _scheduler_section,
    "training": _training_section,
    "data": _data_section,
}

# Print the suggested configuration as a commented YAML template.
# yaml is imported here because this is the only place the notebook uses it.
import yaml

print("Suggested Configuration Improvements:")
print("\n# Enhanced SwinUNETR Configuration")
print("# Based on current model analysis")
print("\n# Key changes:")
print("# - Increased feature_size for more capacity")
print("# - Added dropout for regularization")
print("# - Adjusted learning rate and warmup")
print("# - Extended training duration")
print("# - Enhanced data augmentation")

# Emit the blank line and the YAML text separately: print("\n", text) inserts
# its separator space in front of the first YAML key, which leaves that key
# indented differently from its siblings when the output is copied to a file.
print()
# sort_keys=False keeps the sections in the order they were written in
# improved_config rather than alphabetical order (requires PyYAML >= 5.1).
print(yaml.dump(improved_config, default_flow_style=False, sort_keys=False))