# 🚀 Neural Conditional Ensemble Averaging — Colab GPU Training

**Complete workflow with Guardian validation and GPU acceleration**

This notebook:
1. ✅ Clones the code from GitHub
2. ✅ Runs Guardian pre-flight validation
3. ✅ Trains on a GPU (T4/V100/A100), including a λ-consistency ablation sweep
4. ✅ Saves results to Google Drive
5. ✅ Shows final metrics (plots are not generated yet — **TODO**)


## Step 1: Setup — Mount Drive & Clone Repository

In [None]:
from google.colab import drive

# Mount Google Drive so anything written under /content/drive/MyDrive
# (e.g. the results directory configured in Step 5) ends up in Drive.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
print("‚úÖ Google Drive mounted")

In [None]:
# Clone repository from GitHub
!git clone https://github.com/paulbroadmission/ncea_denoise.git /content/ncea_denoise
%cd /content/ncea_denoise
!git log --oneline -1
print("‚úÖ Repository cloned")

## Step 2: GPU Check

In [None]:
import random
import torch
import numpy as np
import os
import json
from datetime import datetime

# Seed every RNG the training stack may draw from (Python, NumPy, PyTorch) so a
# Restart & Run All reproduces the same weight initialisation and data shuffling.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

print("="*70)
print("GPU CONFIGURATION")
print("="*70)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"‚úÖ GPU Available: {gpu_name}")
    print(f"‚úÖ GPU Memory: {gpu_mem_gb:.2f} GB")
    print(f"‚úÖ CUDA Version: {torch.version.cuda}")
else:
    print("‚ùå NO GPU DETECTED!")
    print("Go to: Runtime ‚Üí Change runtime type ‚Üí GPU")
    raise RuntimeError("GPU required for training")

# NOTE: Step 6 re-imports DEVICE from config.py, which replaces this value.
DEVICE = "cuda"
print(f"\n‚úÖ Using device: {DEVICE}")

## Step 3: Install Dependencies

In [None]:
# Install requirements (quiet mode)
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q numpy scipy scikit-learn matplotlib seaborn pandas tqdm tensorboard
print("‚úÖ Dependencies installed")

## Step 4: 🛡️ RUN GUARDIAN VALIDATION (Pre-flight Check)

In [None]:
print("="*70)
print("GUARDIAN VALIDATION ‚Äî Pre-flight checks")
print("="*70)

import subprocess
import sys

# Run the guardian with the kernel's own interpreter (sys.executable) rather
# than whatever `python3` resolves to on PATH, so it sees the same packages.
result = subprocess.run(
    [sys.executable, "workspace/src/guardian.py"],
    cwd="/content/ncea_denoise",
    capture_output=True,
    text=True
)

print(result.stdout)

if result.returncode != 0:
    print("‚ùå GUARDIAN FAILED")
    print(result.stderr)
    raise RuntimeError("Guardian validation failed. Fix issues and retry.")

# Surface anything the guardian wrote to stderr (e.g. warnings) even when it
# passed; otherwise the captured text would be dropped unseen.
if result.stderr:
    print(result.stderr)
print("\n‚úÖ GUARDIAN PASSED ‚Äî All checks passed!")
print("Ready for training...")

## Step 5: Setup Colab Paths & Configuration

In [None]:
import sys

# Make the project modules under workspace/src importable (used in Step 6).
# Guarded so re-running the cell does not stack duplicate sys.path entries.
SRC_DIR = '/content/ncea_denoise/workspace/src'
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

# Each output location is defined once and created from that same constant,
# so the directories that exist and the paths used later cannot drift apart.
RESULTS_DIR = '/content/drive/MyDrive/ncea_results'
CHECKPOINT_DIR = '/content/ncea_denoise/workspace/checkpoints'
LOG_DIR = '/content/ncea_denoise/workspace/logs'

for output_dir in (RESULTS_DIR, CHECKPOINT_DIR, LOG_DIR):
    os.makedirs(output_dir, exist_ok=True)

print("‚úÖ Colab paths configured")
print(f"   Results dir: {RESULTS_DIR}")
print(f"   Checkpoint dir: {CHECKPOINT_DIR}")

## Step 6: Import Training Components

In [None]:
# Project modules from workspace/src (placed on sys.path in Step 5).
from config import DEVICE, NUM_EPOCHS, BATCH_SIZE, LAMBDA_CONSISTENCY
from model import create_encoder
from data import create_data_loaders
from train import Trainer
from evaluate import Evaluator

# NOTE: importing DEVICE from config rebinds the DEVICE set in Step 2.
# Echo the configuration values read from config.py.
rule = "=" * 70
config_rows = (
    ("Device:              ", DEVICE),
    ("Default Epochs:      ", NUM_EPOCHS),
    ("Batch Size:          ", BATCH_SIZE),
    ("Lambda Consistency:  ", LAMBDA_CONSISTENCY),
)

print(rule)
print("CONFIGURATION SUMMARY")
print(rule)
for row_label, row_value in config_rows:
    print(f"{row_label}{row_value}")
print("\n‚úÖ All imports successful!")

## Step 7: QUICK TEST — Synthetic Data (2 epochs, ~1 min)

In [None]:
print("="*70)
print("QUICK TEST ‚Äî Synthetic Data (2 epochs)")
print("="*70)

QUICK_TEST_EPOCHS = 2

# Give the smoke test its own checkpoint/log sub-directories. Step 10 copies
# CHECKPOINT_DIR/best_model.pt to Drive as the final model, so this throwaway
# run must not share that checkpoint_dir.
quick_checkpoint_dir = os.path.join(CHECKPOINT_DIR, 'quick_test')
quick_log_dir = os.path.join(CHECKPOINT_DIR, '..', 'logs', 'quick_test')
os.makedirs(quick_checkpoint_dir, exist_ok=True)
os.makedirs(quick_log_dir, exist_ok=True)

# Load synthetic data
print("\n[1/4] Loading synthetic SSVEP data...")
train_loader, val_loader, test_loader = create_data_loaders(
    dataset_name="synthetic",
    batch_size=BATCH_SIZE,
)

# Create model
print("[2/4] Creating CNN encoder...")
model = create_encoder(encoder_type="cnn")

# Train (quick test: only proves the GPU training loop runs end to end)
print("[3/4] Training on GPU (2 epochs)...")
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    lambda_consistency=LAMBDA_CONSISTENCY,
    num_epochs=QUICK_TEST_EPOCHS,
    device=DEVICE,
    checkpoint_dir=quick_checkpoint_dir,
    log_dir=quick_log_dir,
)

history_quick = trainer.train()

# Evaluate
# NOTE(review): `model` is evaluated as trainer.train() leaves it. Confirm whether
# Trainer restores the best checkpoint or these are last-epoch weights.
print("\n[4/4] Evaluating on test set...")
evaluator = Evaluator(model, test_loader, device=DEVICE)
metrics_quick = evaluator.evaluate()

print("\n" + "="*70)
print("QUICK TEST RESULTS")
print("="*70)
print(f"‚úÖ Test Accuracy:    {metrics_quick['accuracy']:.4f}")
print(f"‚úÖ F1 Score:         {metrics_quick['f1_score']:.4f}")
print(f"‚úÖ ITR:              {metrics_quick['itr']:.2f} bits/min")
print("\n‚úÖ GPU training works! Ready for full training.")

## Step 8: FULL TRAINING — Up to 500 Epochs with Early Stopping (30–45 min)

In [None]:
print("="*70)
print("FULL TRAINING ‚Äî Synthetic Data (500 epochs)")
print("="*70)

# Upper bound on epochs; the Trainer may stop earlier via early stopping.
FULL_TRAINING_EPOCHS = 500

# Reload data and create fresh model
print("\n[1/4] Loading synthetic SSVEP data...")
train_loader, val_loader, test_loader = create_data_loaders(
    dataset_name="synthetic",
    batch_size=BATCH_SIZE,
)

print("[2/4] Creating fresh CNN encoder...")
model = create_encoder(encoder_type="cnn")

# Train full (500 epochs with early stopping)
print("[3/4] Training on GPU (up to 500 epochs with early stopping)...")
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    lambda_consistency=LAMBDA_CONSISTENCY,
    num_epochs=FULL_TRAINING_EPOCHS,
    device=DEVICE,
    # This run checkpoints straight into CHECKPOINT_DIR; Step 10 copies
    # CHECKPOINT_DIR/best_model.pt to Drive as the final model.
    checkpoint_dir=CHECKPOINT_DIR,
    log_dir=os.path.join(CHECKPOINT_DIR, '..', 'logs'),
)

history_full = trainer.train()

# Evaluate
# NOTE(review): `model` is evaluated as trainer.train() leaves it. Confirm whether
# Trainer restores the best checkpoint or these are last-epoch weights.
print("\n[4/4] Evaluating on test set...")
evaluator = Evaluator(model, test_loader, device=DEVICE)
metrics_full = evaluator.evaluate()

print("\n" + "="*70)
print("FULL TRAINING RESULTS")
print("="*70)
print(f"Best Validation Accuracy:  {history_full['best_val_accuracy']:.4f}")
print(f"Best Epoch:                {history_full['best_epoch']}")
print("\nTest Set Metrics:")
print(f"  Accuracy:                {metrics_full['accuracy']:.4f}")
print(f"  F1 Score:                {metrics_full['f1_score']:.4f}")
print(f"  ITR:                     {metrics_full['itr']:.2f} bits/min")

# Step 10 treats these representation metrics as optional (it reads them with
# .get), so print "n/a" instead of raising KeyError after a long training run.
optional_metrics = (
    ("  Within-class Distance:   ", "within_class_distance", ".6f"),
    ("  Between-class Distance:  ", "between_class_distance", ".6f"),
    ("  Consistency Ratio:       ", "consistency_ratio", ".4f"),
)
for metric_label, metric_key, metric_spec in optional_metrics:
    metric_value = metrics_full.get(metric_key)
    print(metric_label + ("n/a" if metric_value is None else format(metric_value, metric_spec)))
print("\n‚úÖ Training complete!")

## Step 9: ABLATION STUDIES — Lambda (λ) Consistency

In [None]:
import random

print("="*70)
print("ABLATION STUDIES ‚Äî Lambda Consistency")
print("="*70)

lambda_values = [0.0, 0.01, 0.1, 1.0]
ABLATION_EPOCHS = 100  # shorter than the full run to keep the sweep affordable
# Every lambda starts from the same seed so runs differ only in lambda, not in
# weight initialisation or data shuffling.
ABLATION_SEED = 42
ablation_results = {}

for lambda_val in lambda_values:
    print(f"\n{'‚îÄ'*70}")
    print(f"Lambda = {lambda_val:.2f}")
    print(f"{'‚îÄ'*70}")

    random.seed(ABLATION_SEED)
    np.random.seed(ABLATION_SEED)
    torch.manual_seed(ABLATION_SEED)

    # Reload data
    train_loader, val_loader, test_loader = create_data_loaders(
        dataset_name="synthetic",
        batch_size=BATCH_SIZE,
    )

    # Create fresh model
    model = create_encoder(encoder_type="cnn")

    # Separate checkpoint/log sub-directories per lambda. Step 10 copies
    # CHECKPOINT_DIR/best_model.pt as the full-training model, so these shorter
    # runs must not share that checkpoint_dir.
    run_name = f"ablation_lambda_{lambda_val}"
    run_checkpoint_dir = os.path.join(CHECKPOINT_DIR, run_name)
    run_log_dir = os.path.join(CHECKPOINT_DIR, '..', 'logs', run_name)
    os.makedirs(run_checkpoint_dir, exist_ok=True)
    os.makedirs(run_log_dir, exist_ok=True)

    # Train with different lambda
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        lambda_consistency=lambda_val,
        num_epochs=ABLATION_EPOCHS,
        device=DEVICE,
        checkpoint_dir=run_checkpoint_dir,
        log_dir=run_log_dir,
    )

    history = trainer.train()

    # Evaluate
    evaluator = Evaluator(model, test_loader, device=DEVICE)
    metrics = evaluator.evaluate()

    # Store plain Python numbers (not numpy/torch scalars) so Step 10 can
    # json.dump them and best_epoch stays an integer.
    ablation_results[lambda_val] = {
        "best_val_accuracy": float(history["best_val_accuracy"]),
        "best_epoch": int(history["best_epoch"]),
        "test_accuracy": float(metrics["accuracy"]),
        "test_f1": float(metrics["f1_score"]),
        "test_itr": float(metrics["itr"]),
    }

    print(f"\n‚úÖ Œª={lambda_val:.2f} ‚Üí Val Acc: {history['best_val_accuracy']:.4f}, Test Acc: {metrics['accuracy']:.4f}")

# Print ablation summary
print("\n" + "="*70)
print("ABLATION SUMMARY")
print("="*70)
print("\nLambda  | Best Val Acc | Test Acc | Best Epoch")
print("‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ")
for lam, res in sorted(ablation_results.items()):
    print(f"{lam:6.2f} | {res['best_val_accuracy']:12.4f} | {res['test_accuracy']:8.4f} | {res['best_epoch']:10d}")

## Step 10: Save Results to Google Drive

In [None]:
import shutil

print("="*70)
print("SAVING RESULTS TO GOOGLE DRIVE")
print("="*70)


def _to_json_scalar(value):
    """Return `value` as a plain Python scalar that json.dump can serialise.

    NumPy scalars (numpy.int64 is not an int subclass) and 0-d torch tensors
    expose .item(); plain Python values are returned unchanged, so ints such
    as best_epoch stay ints.
    """
    return value.item() if hasattr(value, "item") else value


# Only record the epoch count when the Trainer's history reports it. A
# hard-coded fallback of 500 would misreport a run that stopped early.
epochs_trained = history_full.get("epoch")

# Save full training results
full_results = {
    "experiment": "Neural Conditional Ensemble Averaging",
    "timestamp": datetime.now().isoformat(),
    "device": DEVICE,
    "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none",
    "training": {
        "best_val_accuracy": float(history_full["best_val_accuracy"]),
        "best_epoch": int(history_full["best_epoch"]),
        "num_epochs_trained": None if epochs_trained is None else int(epochs_trained),
        "lambda_consistency": float(LAMBDA_CONSISTENCY),
    },
    "test_metrics": {
        "accuracy": float(metrics_full["accuracy"]),
        "f1_score": float(metrics_full["f1_score"]),
        "itr": float(metrics_full["itr"]),
        "within_class_distance": float(metrics_full.get("within_class_distance", 0)),
        "between_class_distance": float(metrics_full.get("between_class_distance", 0)),
        "consistency_ratio": float(metrics_full.get("consistency_ratio", 0)),
    },
    "ablation_studies": {
        str(lam): {k: _to_json_scalar(v) for k, v in res.items()}
        for lam, res in ablation_results.items()
    }
}

# Save to Drive (explicit UTF-8 so output does not depend on the platform locale)
results_file = os.path.join(RESULTS_DIR, 'colab_results.json')
with open(results_file, 'w', encoding='utf-8') as f:
    json.dump(full_results, f, indent=2)

print(f"‚úÖ Results saved: {results_file}")

# Also save best checkpoint.
# NOTE(review): this copies whatever best_model.pt sits in CHECKPOINT_DIR. It is
# only the full-training model if no later run shares that directory. Confirm
# where Trainer writes its checkpoints.
best_checkpoint = os.path.join(CHECKPOINT_DIR, 'best_model.pt')
if os.path.exists(best_checkpoint):
    shutil.copy2(best_checkpoint, os.path.join(RESULTS_DIR, 'best_model.pt'))
    print("‚úÖ Best checkpoint saved to Drive")
else:
    # Report the miss instead of skipping silently; Step 11 still lists the path.
    print(f"WARNING: no checkpoint found at {best_checkpoint}; nothing copied to Drive")

# Create a summary markdown file
summary_md = f"""# üéâ Colab Training Complete

**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
**Device:** {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}

## Full Training (500 epochs)
- **Best Validation Accuracy:** {history_full['best_val_accuracy']:.4f}
- **Best Epoch:** {history_full['best_epoch']}
- **Test Accuracy:** {metrics_full['accuracy']:.4f}
- **Test F1 Score:** {metrics_full['f1_score']:.4f}
- **Test ITR:** {metrics_full['itr']:.2f} bits/min

## Ablation Results

| Lambda | Val Acc | Test Acc | Best Epoch |
|--------|---------|----------|------------|
"""

for lam, res in sorted(ablation_results.items()):
    summary_md += f"| {lam:.2f} | {res['best_val_accuracy']:.4f} | {res['test_accuracy']:.4f} | {res['best_epoch']} |\n"

summary_file = os.path.join(RESULTS_DIR, 'SUMMARY.md')
with open(summary_file, 'w', encoding='utf-8') as f:
    f.write(summary_md)

print(f"‚úÖ Summary saved: {summary_file}")
print(f"\nüìÇ All results available at: {RESULTS_DIR}")

## Step 11: Display Final Summary

In [None]:
print("\n" + "="*70)
print("üéâ COLAB TRAINING COMPLETE")
print("="*70)

# The ablation count comes from the results actually collected in Step 9, so it
# stays correct if lambda_values changes. A space now separates the long
# "Full Training" label from its value.
print(f"""
‚úÖ Pre-flight Validation:     PASSED (Guardian)
‚úÖ Quick Test (2 epochs):     {metrics_quick['accuracy']:.4f} accuracy
‚úÖ Full Training (500 epochs): {history_full['best_val_accuracy']:.4f} best val accuracy
‚úÖ Ablation Studies:          {len(ablation_results)} lambda values tested
‚úÖ Results Saved:             Google Drive/ncea_results/

üìä FINAL TEST METRICS
   Accuracy:     {metrics_full['accuracy']:.4f}
   F1 Score:     {metrics_full['f1_score']:.4f}
   ITR:          {metrics_full['itr']:.2f} bits/min

üìÅ Drive Path: {RESULTS_DIR}
üìù Summary:    {summary_file}
üíæ Data:       {results_file}
ü§ñ Model:      {RESULTS_DIR}/best_model.pt

Next Steps:
  1. Download results from Google Drive
  2. Compare with baseline methods (TRCA, CNN, Li et al. 2024)
  3. Generate paper figures and tables
  4. Write Results section
""")

print("="*70)

## Optional: Load Real BETA Dataset

If you have access to the BETA dataset, enable and run the following cell to load real data instead of the synthetic data:

In [None]:
# Optional: load the real BETA dataset instead of synthetic data.
# Opt-in via a flag (rather than commented-out code) so the cell stays runnable
# and is a no-op under Restart & Run All. Set USE_BETA = True to enable it.
USE_BETA = False

if USE_BETA:
    import os
    import subprocess

    BETA_PATH = '/content/ncea_denoise/workspace/data/BETA.mat'
    # NOTE(review): this download URL has not been verified; confirm it points
    # at the BETA dataset before relying on it.
    BETA_URL = 'https://github.com/gumpy-bci/data/raw/master/BETA/BETA.mat'

    # `wget -O` fails when the target directory does not exist.
    os.makedirs(os.path.dirname(BETA_PATH), exist_ok=True)
    if not os.path.exists(BETA_PATH):
        try:
            subprocess.run(["wget", "-q", BETA_URL, "-O", BETA_PATH], check=True)
        except subprocess.CalledProcessError:
            # `wget -O` leaves an empty/partial file behind on failure; remove it
            # so the next attempt re-downloads instead of loading a corrupt file.
            if os.path.exists(BETA_PATH):
                os.remove(BETA_PATH)
            raise
        print("‚úÖ BETA dataset downloaded")

    # Then build loaders from BETA instead of synthetic data.
    # NOTE: the training cells above create their own synthetic loaders, so these
    # loaders are only used by code run after this cell.
    train_loader, val_loader, test_loader = create_data_loaders(
        dataset_name="BETA",
        batch_size=BATCH_SIZE,
    )