# FedNAMs+ Training Notebook

This notebook demonstrates end-to-end training of FedNAMs+ on chest X-ray data using Google Colab.

## ⚠️ Current Status
**Note:** Only the setup and configuration cells are fully functional at this stage. The full training implementation still requires:
- Data loading module (Task 2 - in progress)
- Complete training orchestration
- Evaluation pipeline integration

You can run this notebook to verify that the setup and configuration steps work correctly.

## Overview
- Setup environment and mount Google Drive
- Install dependencies
- Configure experiment
- Run federated training (when implemented)
- Visualize results

## 1. Setup Environment

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
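# Clone repository (if not already cloned)
import os
# Skip both steps when this cell is re-run from inside the repository,
# otherwise a nested second clone would be created
if os.path.basename(os.getcwd()) != 'fednams-plus':
    if not os.path.exists('fednams-plus'):
        !git clone https://github.com/yourusername/fednams-plus.git
    %cd fednams-plus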

In [None]:
# Install dependencies
!pip install -q -r requirements.txt

## 2. Import Libraries

In [None]:
import sys
from pathlib import Path
import yaml
import matplotlib.pyplot as plt
import seaborn as sns

# Add project to path
sys.path.insert(0, str(Path.cwd()))

from configs.config_loader import ConfigLoader
from experiments import ExperimentRunner
from utils.logging_utils import setup_logger

# Setup plotting
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## 3. Configure Experiment

Modify the configuration below to customize your experiment.

In [None]:
# Configuration
config_dict = {
    'experiment_name': 'fednams_colab_demo',
    'seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'output_dir': 'outputs/fednams_colab_demo',
    
    'data': {
        'dataset': 'nih-cxr',
        'data_dir': '/content/drive/MyDrive/NIH Chest XRAY Dataset',  # Update this path
        'image_size': [224, 224],
        'normalization': 'imagenet',
        'augmentation': True,
        'augmentation_params': {
            'horizontal_flip_prob': 0.5,
            'rotation_degrees': 10,
            'brightness': 0.2,
            'contrast': 0.2
        }
    },
    
    'federated': {
        'num_clients': 5,
        'num_rounds': 20,  # Reduced for demo
        'client_fraction': 1.0,
        'min_clients': 3,
        'partition_strategy': 'dirichlet',
        'partition_params': {
            'alpha': 0.5,
            'min_samples': 200
        }
    },
    
    'model': {
        'backbone': 'resnet18',
        'pretrained': True,
        'feature_dim': 512,
        'num_classes': 15,
        'nam_hidden_units': [64, 32],
        'dropout': 0.3,
        'use_exu': False
    },
    
    'training': {
        'batch_size': 32,
        'learning_rate': 0.001,
        'num_local_epochs': 5,
        'optimizer': 'adam',
        'scheduler': 'cosine',
        'early_stopping_patience': 10,
        'mixed_precision': True
    },
    
    'explainability': {
        'compute_shap': True,
        'shap_background_samples': 100,
        'shap_test_samples': 500,
        'generate_plots': True,
        'plot_types': ['summary', 'importance', 'comparison']
    },
    
    'uncertainty': {
        'use_conformal': True,
        'confidence_level': 0.9,
        'calibration_fraction': 0.15
    },
    
    'evaluation': {
        'metrics': ['accuracy', 'f1', 'auc_roc', 'auc_pr'],
        'per_class_metrics': True,
        'fairness_analysis': True,
        'save_predictions': True
    }
}

# Save configuration
config_path = Path('configs/colab_experiment.yaml')
config_path.parent.mkdir(parents=True, exist_ok=True)
with open(config_path, 'w') as f:
    # sort_keys=False keeps the section order used in config_dict
    yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)

print("Configuration saved to:", config_path)
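
The `'partition_strategy': 'dirichlet'` setting above is commonly understood as label-skewed splitting: for each class, the share given to each client is drawn from a Dirichlet distribution, and a smaller `alpha` produces more uneven (non-IID) clients. The sketch below illustrates that idea in pure Python. It is an assumption about what the setting means, not the repository's partitioner: `dirichlet_partition` is a hypothetical helper, it treats each sample as having a single label, and it ignores `min_samples`.

```python
import random
from collections import defaultdict


def dirichlet_partition(labels, num_clients=5, alpha=0.5, seed=42):
    """Split sample indices across clients with Dirichlet label skew.

    Hypothetical helper for illustration; not part of the FedNAMs+ code base.
    """
    rng = random.Random(seed)

    # Group sample indices by class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    clients = [[] for _ in range(num_clients)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)

        # A Dirichlet(alpha) sample equals normalized Gamma(alpha, 1) draws
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(draws)

        start = 0
        cumulative = 0.0
        for client_id, draw in enumerate(draws):
            cumulative += draw / total
            if client_id == num_clients - 1:
                end = len(indices)  # last client takes the remainder
            else:
                end = min(len(indices), round(cumulative * len(indices)))
            clients[client_id].extend(indices[start:end])
            start = end
    return clients


# Toy usage: 300 samples spread evenly over 3 classes
toy_labels = [i % 3 for i in range(300)]
toy_parts = dirichlet_partition(toy_labels, num_clients=5, alpha=0.5)
print([len(part) for part in toy_parts])
```

Lowering `alpha` in the toy call makes the client sizes and class mixes more uneven, while a large `alpha` moves them towards an even split.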

## 4. Load Configuration and Setup

In [None]:
# Load configuration
config = ConfigLoader.load_experiment_config(config_path)

print(f"Experiment: {config.experiment_name}")
print(f"Device: {config.device}")
print(f"Num clients: {config.fed_config.num_clients}")
print(f"Num rounds: {config.fed_config.num_rounds}")
print(f"Model backbone: {config.model_config.backbone}")
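
Before a long run it can help to sanity-check a few of the loaded values. The checks below are assumptions about what sensible settings look like (for example, that `min_clients` should not exceed `num_clients`), not rules taken from `ConfigLoader`. `check_config` is a hypothetical helper that works on a plain dict shaped like `config_dict`.

```python
def check_config(cfg):
    """Return a list of problems found; an empty list means none were found.

    Hypothetical helper; the constraints are illustrative assumptions.
    """
    problems = []
    fed = cfg.get('federated', {})
    unc = cfg.get('uncertainty', {})

    num_clients = fed.get('num_clients', 0)
    if num_clients < 1:
        problems.append('num_clients must be at least 1')
    if fed.get('min_clients', 1) > num_clients:
        problems.append('min_clients exceeds num_clients')
    if not 0.0 < fed.get('client_fraction', 1.0) <= 1.0:
        problems.append('client_fraction must be in (0, 1]')
    if not 0.0 < unc.get('confidence_level', 0.9) < 1.0:
        problems.append('confidence_level must be in (0, 1)')
    if not 0.0 < unc.get('calibration_fraction', 0.15) < 1.0:
        problems.append('calibration_fraction must be in (0, 1)')
    return problems


# Example with the same values as the demo configuration
example_cfg = {
    'federated': {'num_clients': 5, 'min_clients': 3, 'client_fraction': 1.0},
    'uncertainty': {'confidence_level': 0.9, 'calibration_fraction': 0.15},
}
for problem in check_config(example_cfg):
    print('Config issue:', problem)
```

In the notebook the same function could be called as `check_config(config_dict)` right after the configuration cell.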

## 5. Initialize Experiment Runner

In [None]:
# Setup logger
logger = setup_logger('fednams_training', level='INFO')

# Create experiment runner
runner = ExperimentRunner(config)
print("✓ Experiment runner initialized")

## 6. Run Federated Training

This will train the FedNAMs+ model across multiple federated clients.

In [None]:
# Run experiment
print("Starting federated training...\n")
results = runner.run_experiment()
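# Only report success when the runner actually produced test metrics
if 'test_metrics' in results:
    print("\n✓ Training complete!")
else:
    print(f"\n⚠️ Run finished without test metrics (status: {results.get('status', 'unknown')})")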
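
This notebook does not state which aggregation rule `ExperimentRunner` applies between rounds. A common baseline is FedAvg, where the server averages client parameters weighted by each client's number of training samples. The sketch below shows that rule on plain Python lists. `fedavg` and its arguments are hypothetical names, and the real implementation would operate on model tensors rather than lists.

```python
def fedavg(client_weights, client_sizes):
    """Sample-size-weighted average of client parameters (FedAvg-style).

    client_weights: one dict per client mapping parameter name -> list of floats
    client_sizes:   number of training samples held by each client
    Hypothetical helper for illustration only.
    """
    total = sum(client_sizes)
    if total <= 0:
        raise ValueError('client_sizes must sum to a positive number')

    averaged = {}
    for name in client_weights[0]:
        length = len(client_weights[0][name])
        averaged[name] = [
            sum(w[name][i] * n for w, n in zip(client_weights, client_sizes)) / total
            for i in range(length)
        ]
    return averaged


# Toy usage: the second client holds three times as much data as the first
global_weights = fedavg(
    [{'w': [1.0, 2.0]}, {'w': [3.0, 4.0]}],
    client_sizes=[1, 3],
)
print(global_weights)  # {'w': [2.5, 3.5]}
```

With `client_fraction` set to 1.0, as in the demo configuration, every client would contribute to this average in each round.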

## 7. View Training Results

In [None]:
# Display final metrics
print("\n=== Final Results ===")

# Check if full results are available
if 'test_metrics' in results:
    print(f"Test Accuracy: {results['test_metrics']['accuracy']:.4f}")
    print(f"Test F1-Score: {results['test_metrics']['f1']:.4f}")
    print(f"Test AUC-ROC: {results['test_metrics']['auc_roc']:.4f}")
    print(f"Test AUC-PR: {results['test_metrics']['auc_pr']:.4f}")
    
    if 'uncertainty_metrics' in results:
        print(f"\nCoverage: {results['uncertainty_metrics']['coverage']:.4f}")
        print(f"Avg Set Size: {results['uncertainty_metrics']['avg_set_size']:.2f}")
else:
    print("⚠️ Note: Full training implementation is pending.")
    print(f"Status: {results.get('status', 'unknown')}")
    print(f"Message: {results.get('message', 'No message')}")
    print("\nThe experiment runner is initialized and ready.")
    print("Complete implementation requires:")
    print("  1. Data loading module (Task 2)")
    print("  2. Training orchestration")
    print("  3. Evaluation pipeline")
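
The `Coverage` and `Avg Set Size` values printed above are the usual summaries of conformal prediction sets. As a rough illustration of how such numbers can arise, the sketch below implements split conformal prediction for a single-label classifier: a threshold is calibrated on held-out data, and each prediction set contains every class whose score falls under it. This is an assumption about the method, not the repository's code. All function names are hypothetical, and the NIH chest X-ray labels are multi-label, which this simplified version does not handle.

```python
import math


def conformal_threshold(cal_probs, cal_labels, confidence_level=0.9):
    """Calibrate a threshold on the score 1 - p(true class)."""
    scores = sorted(1.0 - probs[label] for probs, label in zip(cal_probs, cal_labels))
    n = len(scores)
    # Finite-sample corrected rank of the calibration quantile
    rank = math.ceil((n + 1) * confidence_level)
    if rank > n:
        return 1.0  # too few calibration points: every class is included
    return scores[rank - 1]


def prediction_set(probs, threshold):
    """Classes whose nonconformity score is within the threshold."""
    return [k for k, p in enumerate(probs) if 1.0 - p <= threshold]


def coverage_and_set_size(test_probs, test_labels, threshold):
    """Fraction of sets containing the true label, and the mean set size."""
    sets = [prediction_set(p, threshold) for p in test_probs]
    coverage = sum(label in s for s, label in zip(sets, test_labels)) / len(sets)
    avg_size = sum(len(s) for s in sets) / len(sets)
    return coverage, avg_size


# Toy usage with two classes and four calibration samples
cal_probs = [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
cal_labels = [0, 0, 1, 0]
threshold = conformal_threshold(cal_probs, cal_labels, confidence_level=0.5)
print(coverage_and_set_size([[0.75, 0.25], [0.5, 0.5]], [0, 1], threshold))
```

A higher `confidence_level` raises the threshold, which tends to increase coverage at the cost of larger prediction sets.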

## 8. Visualize Training Progress

In [None]:
# Plot training curves (if available)
if 'training_history' in results:
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Loss curve
    axes[0].plot(results['training_history']['rounds'], 
                 results['training_history']['train_loss'], 
                 label='Train Loss', marker='o')
    axes[0].plot(results['training_history']['rounds'], 
                 results['training_history']['val_loss'], 
                 label='Val Loss', marker='s')
    axes[0].set_xlabel('Round')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True)
    
    # Accuracy curve
    axes[1].plot(results['training_history']['rounds'], 
                 results['training_history']['train_accuracy'], 
                 label='Train Accuracy', marker='o')
    axes[1].plot(results['training_history']['rounds'], 
                 results['training_history']['val_accuracy'], 
                 label='Val Accuracy', marker='s')
    axes[1].set_xlabel('Round')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True)
    
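    plt.tight_layout()
    # Path() guards against output_dir being loaded as a plain string
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_dir / 'training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()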
else:
    print("⚠️ Training history not available yet.")
    print("This will be populated once the full training loop is implemented.")

## 9. View SHAP Explanations

If SHAP computation was enabled, visualize feature importance.

In [None]:
# Display SHAP visualizations
from IPython.display import Image, display

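shap_dir = Path(config.output_dir) / 'shap_visualizations'
shap_plots = {
    'SHAP Summary Plot': shap_dir / 'shap_summary.png',
    'Feature Importance': shap_dir / 'feature_importance.png',
}
# Show each plot only if its file exists, so a partial run does not raise
found = False
for title, plot_path in shap_plots.items():
    if plot_path.exists():
        print(f"{title}:")
        display(Image(filename=str(plot_path)))
        found = True
if not found:
    print("SHAP visualizations not found. They are written only after a full training run with 'compute_shap' enabled.")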

## 10. Save Results

In [None]:
# Save results
runner.save_results(results)
print(f"\n✓ Results saved to: {config.output_dir}")

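# Copy to Google Drive for persistence
import shutil
drive_output = Path('/content/drive/MyDrive/FedNAMs_Results')
drive_output.mkdir(parents=True, exist_ok=True)
# copytree copes with spaces in paths; dirs_exist_ok (Python 3.8+) allows re-running this cell
shutil.copytree(config.output_dir, drive_output / Path(config.output_dir).name, dirs_exist_ok=True)
print(f"✓ Results backed up to Google Drive: {drive_output}")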

## 11. Export Model

In [None]:
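# Save final model
output_dir = Path(config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
model_path = output_dir / 'final_model.pt'

# runner.model may not exist until the full training loop is implemented
model = getattr(runner, 'model', None)
if model is None:
    print("⚠️ No trained model available yet; skipping export.")
else:
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config_dict,
        'results': results
    }, model_path)
    print(f"✓ Model saved to: {model_path}")
    print(f"Model size: {model_path.stat().st_size / 1e6:.2f} MB")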

## Summary

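This notebook covers the following steps (those that depend on training produce output only once the full training implementation is in place):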
1. Setting up FedNAMs+ in Google Colab
2. Configuring a federated learning experiment
3. Training the model across multiple clients
4. Evaluating performance with multiple metrics
5. Visualizing training progress and SHAP explanations
6. Saving results and models

Next steps:
- Use `evaluate_results.ipynb` for detailed analysis
- Use `compare_baselines.ipynb` to compare with other methods
- Adjust hyperparameters and re-run experiments