# FedNAMs+ Results Evaluation and Visualization

This notebook provides detailed analysis and visualization of FedNAMs+ experiment results.

## Contents
- Load experiment results
- Analyze classification performance
- Examine explanation quality
- Evaluate uncertainty quantification
- Generate publication-ready figures

## 1. Setup

In [None]:
import sys
from pathlib import Path
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Image

# Notebook-wide plot styling: seaborn theme first, then matplotlib defaults
sns.set_style('whitegrid')
sns.set_palette('husl')
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 11,
})

## 2. Load Results

In [None]:
# Specify experiment directory
experiment_dir = Path('outputs/fednams_colab_demo')  # Update this path
results_path = experiment_dir / 'results.json'

# Fail early with an actionable message instead of a bare traceback from open()
if not results_path.is_file():
    raise FileNotFoundError(
        f"No results.json found at '{results_path}'. "
        "Update `experiment_dir` to point at a completed experiment."
    )

# Load results. The encoding is explicit because open() otherwise uses the
# platform locale, which can mis-decode non-ASCII text in a UTF-8 JSON file.
with open(results_path, 'r', encoding='utf-8') as f:
    results = json.load(f)

print(f"Loaded results from: {experiment_dir}")
print(f"Experiment: {results.get('experiment_name', 'N/A')}")
print(f"Completed rounds: {len(results.get('training_history', {}).get('rounds', []))}")

## 3. Classification Performance Analysis

In [None]:
# Test-set metrics as stored in results.json
test_metrics = results['test_metrics']

# (display label, results.json key) in the order the table should show them
metric_fields = [
    ('Accuracy', 'accuracy'),
    ('Precision', 'precision'),
    ('Recall', 'recall'),
    ('F1-Score', 'f1'),
    ('AUC-ROC', 'auc_roc'),
    ('AUC-PR', 'auc_pr'),
]

# Build the table; a metric missing from the results falls back to 0
metrics_df = pd.DataFrame({
    'Metric': [label for label, key in metric_fields],
    'Value': [test_metrics.get(key, 0) for label, key in metric_fields],
})

print("\n=== Test Set Performance ===")
display(metrics_df.style.format({'Value': '{:.4f}'}))

In [None]:
# Horizontal bar chart of the test-set metrics table
fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.barh(metrics_df['Metric'], metrics_df['Value'], color='steelblue')
ax.set(xlabel='Score', xlim=(0, 1))
ax.set_title('FedNAMs+ Test Set Performance', fontsize=14, fontweight='bold')
ax.grid(axis='x', alpha=0.3)

# Annotate each bar with its score, just right of the bar end and
# vertically centred on the bar
for patch in bars:
    score = patch.get_width()
    label_y = patch.get_y() + patch.get_height()/2
    ax.text(score + 0.01, label_y, f'{score:.3f}', ha='left', va='center')

fig.tight_layout()
fig.savefig(experiment_dir / 'test_metrics.png', dpi=300, bbox_inches='tight')
plt.show()

## 4. Per-Class Performance

In [None]:
# Per-class metrics are optional; report and plot them only when present
if 'per_class_metrics' not in test_metrics:
    print("Per-class metrics not available")
else:
    per_class = test_metrics['per_class_metrics']

    # Fall back to generic "Class i" labels, one per F1 entry
    fallback_names = [f'Class {idx}' for idx in range(len(per_class['f1']))]
    class_names = per_class.get('class_names', fallback_names)

    # Columns shown in both the table and the grouped bar chart
    score_columns = ['Precision', 'Recall', 'F1-Score', 'AUC-ROC']

    per_class_df = pd.DataFrame({
        'Class': class_names,
        'Precision': per_class['precision'],
        'Recall': per_class['recall'],
        'F1-Score': per_class['f1'],
        # AUC-ROC may be absent; show zeros in that case
        'AUC-ROC': per_class.get('auc_roc', [0]*len(class_names))
    })

    print("\n=== Per-Class Performance ===")
    display(per_class_df.style.format({column: '{:.3f}' for column in score_columns}))

    # Grouped bars: four metrics side by side around each class position
    fig, ax = plt.subplots(figsize=(12, 8))
    x = np.arange(len(class_names))
    width = 0.2
    group_offsets = [-1.5, -0.5, 0.5, 1.5]

    for offset, column in zip(group_offsets, score_columns):
        ax.bar(x + offset*width, per_class_df[column], width, label=column, alpha=0.8)

    ax.set(xlabel='Class', ylabel='Score')
    ax.set_title('Per-Class Performance Metrics', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)

    fig.tight_layout()
    fig.savefig(experiment_dir / 'per_class_metrics.png', dpi=300, bbox_inches='tight')
    plt.show()

## 5. Training Progress Analysis

In [None]:
# Extract training history
history = results['training_history']


def _plot_train_val(ax, history, train_key, val_key, ylabel, title):
    """Draw the train/validation curves of one metric on ``ax``.

    Each curve is drawn only when its key exists in ``history``, so a history
    that records only one of the two series still plots. When neither key
    exists the axes is hidden, so the saved figure has no blank framed panel.

    Args:
        ax: Matplotlib axes to draw on.
        history: Training-history dict; must contain ``'rounds'`` for the
            x-axis when at least one curve is drawn.
        train_key: Key of the training series in ``history``.
        val_key: Key of the validation series in ``history``.
        ylabel: Y-axis label.
        title: Panel title.

    Returns:
        True if at least one curve was drawn, False if the panel was hidden.
    """
    candidates = [(train_key, 'o-', 'Train'), (val_key, 's-', 'Validation')]
    available = [curve for curve in candidates if curve[0] in history]
    if not available:
        ax.set_visible(False)
        return False

    for key, fmt, label in available:
        ax.plot(history['rounds'], history[key], fmt, label=label, linewidth=2)
    ax.set_xlabel('Round')
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    return True


# Create comprehensive training plot: one panel per metric, row-major order
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

panel_specs = [
    ('train_loss', 'val_loss', 'Loss', 'Training and Validation Loss'),
    ('train_accuracy', 'val_accuracy', 'Accuracy', 'Training and Validation Accuracy'),
    ('train_f1', 'val_f1', 'F1-Score', 'Training and Validation F1-Score'),
    ('train_auc', 'val_auc', 'AUC-ROC', 'Training and Validation AUC-ROC'),
]

for panel_ax, (train_key, val_key, ylabel, title) in zip(axes.flat, panel_specs):
    _plot_train_val(panel_ax, history, train_key, val_key, ylabel, title)

plt.tight_layout()
plt.savefig(experiment_dir / 'training_progress.png', dpi=300, bbox_inches='tight')
plt.show()

## 6. Explanation Quality Analysis

In [None]:
# Explanation-quality metrics are optional in results.json
if 'explanation_metrics' not in results:
    print("Explanation metrics not available")
else:
    exp_metrics = results['explanation_metrics']

    # (display label, results.json key); a missing key is reported as 0
    explanation_fields = [
        ('SHAP Consistency', 'shap_consistency'),
        ('Feature Stability', 'feature_stability'),
        ('Cross-Client Agreement', 'cross_client_agreement'),
    ]
    exp_df = pd.DataFrame({
        'Metric': [label for label, key in explanation_fields],
        'Score': [exp_metrics.get(key, 0) for label, key in explanation_fields],
    })

    print("\n=== Explanation Quality Metrics ===")
    for label, key in explanation_fields:
        print(f"{label}: {exp_metrics.get(key, 0):.4f}")

    # One bar per metric, each in its own colour
    fig, ax = plt.subplots(figsize=(10, 5))
    bars = ax.bar(exp_df['Metric'], exp_df['Score'], color=['#2ecc71', '#3498db', '#e74c3c'])
    ax.set(ylabel='Score', ylim=(0, 1))
    ax.set_title('Explanation Quality Metrics', fontsize=14, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)

    # Print each score just above its bar, horizontally centred
    for patch in bars:
        score = patch.get_height()
        label_x = patch.get_x() + patch.get_width()/2
        ax.text(label_x, score + 0.02, f'{score:.3f}', ha='center', va='bottom')

    fig.tight_layout()
    fig.savefig(experiment_dir / 'explanation_metrics.png', dpi=300, bbox_inches='tight')
    plt.show()

## 7. SHAP Visualizations

In [None]:
# Show whichever pre-rendered SHAP figures exist in the experiment directory
shap_dir = experiment_dir / 'shap_visualizations'

# (file name, heading printed above the image), in display order
shap_figures = [
    ('shap_summary.png', "\nSHAP Summary Plot:"),
    ('feature_importance.png', "\nFeature Importance:"),
    ('client_comparison.png', "\nCross-Client Feature Comparison:"),
]

if not shap_dir.exists():
    print("SHAP visualizations not found")
else:
    print("\n=== SHAP Visualizations ===")

    # Figures that were not generated are skipped silently
    for file_name, heading in shap_figures:
        figure_path = shap_dir / file_name
        if figure_path.exists():
            print(heading)
            display(Image(filename=str(figure_path)))

## 8. Uncertainty Quantification Analysis

In [None]:
# Load uncertainty metrics (optional section of results.json)
if 'uncertainty_metrics' in results:
    unc_metrics = results['uncertainty_metrics']

    print("\n=== Uncertainty Quantification Metrics ===")
    print(f"Coverage: {unc_metrics.get('coverage', 0):.4f}")
    print(f"Target Confidence: {unc_metrics.get('target_confidence', 0.9):.2f}")
    print(f"Average Set Size: {unc_metrics.get('avg_set_size', 0):.2f}")
    print(f"Median Set Size: {unc_metrics.get('median_set_size', 0):.2f}")

    # Left panel: coverage comparison; right panel: set-size distribution
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    coverage_ax, size_ax = axes

    # Coverage comparison
    coverage_data = {
        'Metric': ['Target', 'Achieved'],
        'Coverage': [
            unc_metrics.get('target_confidence', 0.9),
            unc_metrics.get('coverage', 0)
        ]
    }
    coverage_ax.bar(coverage_data['Metric'], coverage_data['Coverage'],
                    color=['#95a5a6', '#2ecc71'])
    coverage_ax.set_ylabel('Coverage')
    coverage_ax.set_title('Coverage: Target vs Achieved', fontweight='bold')
    coverage_ax.set_ylim(0, 1)
    coverage_ax.grid(axis='y', alpha=0.3)

    # Set size distribution
    if 'set_size_distribution' in unc_metrics:
        set_sizes = unc_metrics['set_size_distribution']
        size_ax.hist(set_sizes, bins=20, color='steelblue', edgecolor='black', alpha=0.7)

        # Mark only the statistics that were actually recorded; drawing a
        # default of 0 would put a misleading "Mean"/"Median" line at x=0.
        stat_markers = [
            ('avg_set_size', 'red', 'Mean'),
            ('median_set_size', 'green', 'Median'),
        ]
        drew_marker = False
        for stat_key, color, label in stat_markers:
            if stat_key in unc_metrics:
                size_ax.axvline(unc_metrics[stat_key],
                                color=color, linestyle='--', linewidth=2, label=label)
                drew_marker = True

        size_ax.set_xlabel('Prediction Set Size')
        size_ax.set_ylabel('Frequency')
        size_ax.set_title('Prediction Set Size Distribution', fontweight='bold')
        if drew_marker:
            # legend() warns when there are no labelled artists
            size_ax.legend()
        size_ax.grid(axis='y', alpha=0.3)
    else:
        # No distribution recorded: hide the panel rather than saving an
        # empty framed axes into the figure.
        size_ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(experiment_dir / 'uncertainty_metrics.png', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("Uncertainty metrics not available")

## 9. Communication Cost Analysis

In [None]:
# Communication-cost metrics are optional in results.json
if 'communication_metrics' not in results:
    print("Communication metrics not available")
else:
    comm_metrics = results['communication_metrics']

    # Read the two values that drive the plot once; missing keys count as 0
    total_rounds = comm_metrics.get('total_rounds', 0)
    mb_per_round = comm_metrics.get('mb_per_round', 0)

    print("\n=== Communication Cost ===")
    print(f"Parameters per round: {comm_metrics.get('params_per_round', 0):,}")
    print(f"MB per round: {mb_per_round:.2f}")
    print(f"Total rounds: {total_rounds}")
    print(f"Total communication: {comm_metrics.get('total_mb', 0):.2f} MB")

    # Cumulative cost is the round number times the per-round cost,
    # i.e. it assumes every round costs the same
    rounds = list(range(1, total_rounds + 1))
    cumulative_mb = [round_no * mb_per_round for round_no in rounds]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(rounds, cumulative_mb, 'o-', linewidth=2, markersize=6)
    ax.set(xlabel='Round', ylabel='Cumulative Communication (MB)')
    ax.set_title('Communication Cost Over Training', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(experiment_dir / 'communication_cost.png', dpi=300, bbox_inches='tight')
    plt.show()

## 10. Generate Summary Report

In [None]:
# Create comprehensive summary.
# Re-read the test metrics here, as is done for the optional sections below,
# so this cell depends only on `results` and `experiment_dir`.
test_metrics = results['test_metrics']

summary = f"""
{'='*60}
FedNAMs+ Experiment Summary
{'='*60}

Experiment: {results.get('experiment_name', 'N/A')}
Date: {results.get('timestamp', 'N/A')}

CLASSIFICATION PERFORMANCE
{'-'*60}
Accuracy:  {test_metrics.get('accuracy', 0):.4f}
F1-Score:  {test_metrics.get('f1', 0):.4f}
AUC-ROC:   {test_metrics.get('auc_roc', 0):.4f}
AUC-PR:    {test_metrics.get('auc_pr', 0):.4f}
"""

# Optional sections: appended only when the experiment recorded them
if 'explanation_metrics' in results:
    exp_metrics = results['explanation_metrics']
    summary += f"""
EXPLANATION QUALITY
{'-'*60}
SHAP Consistency:        {exp_metrics.get('shap_consistency', 0):.4f}
Feature Stability:       {exp_metrics.get('feature_stability', 0):.4f}
Cross-Client Agreement:  {exp_metrics.get('cross_client_agreement', 0):.4f}
"""

if 'uncertainty_metrics' in results:
    unc_metrics = results['uncertainty_metrics']
    summary += f"""
UNCERTAINTY QUANTIFICATION
{'-'*60}
Coverage:           {unc_metrics.get('coverage', 0):.4f}
Target Confidence:  {unc_metrics.get('target_confidence', 0.9):.2f}
Avg Set Size:       {unc_metrics.get('avg_set_size', 0):.2f}
"""

if 'communication_metrics' in results:
    comm_metrics = results['communication_metrics']
    summary += f"""
COMMUNICATION COST
{'-'*60}
MB per round:        {comm_metrics.get('mb_per_round', 0):.2f}
Total communication: {comm_metrics.get('total_mb', 0):.2f} MB
"""

summary += f"""
{'='*60}
"""

print(summary)

# Save summary. The encoding is explicit so the report is written the same way
# on every platform (experiment names and timestamps may contain non-ASCII text).
report_path = experiment_dir / 'summary_report.txt'
with open(report_path, 'w', encoding='utf-8') as f:
    f.write(summary)

# \u2713 is the check mark; written as an escape so it cannot be turned back
# into mojibake if the notebook file is re-encoded.
print(f"\n\u2713 Summary report saved to: {report_path}")

## Summary

This notebook provided comprehensive analysis of FedNAMs+ results including:
- Classification performance metrics
- Per-class analysis
- Training progress visualization
- Explanation quality assessment
- SHAP visualizations
- Uncertainty quantification analysis
- Communication cost tracking

All visualizations have been saved to the experiment directory for publication use.