# TARS Federated Learning - Maximum GPU/RAM Utilization

**TARS: Trust-Aware Reinforcement Selection for Robust Federated Learning**

Author: Shafiq Ahmed (s.ahmed@essex.ac.uk)

## 🚀 OPTIMIZED FOR 15GB GPU + 12.7GB RAM

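This notebook targets a Colab-class machine with a 15GB GPU and 12.7GB RAM. The figures in this overview describe the **maximum** configuration. Because that configuration caused Colab session terminations, the configurations actually run in Section 2 are scaled down.
- **15GB GPU**: 80-95% memory utilization (12-14GB usage)
- **12.7GB RAM**: 60-80% utilization with parallel data loading
- **Target Performance**: 97%+ MNIST, 80%+ CIFAR-10 accuracy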

## Key Optimizations:
- **Massive Batch Sizes**: 1024 (MNIST), 512 (CIFAR-10)
- **50 Federated Clients**: Maximum parallelization
- **10 Local Epochs**: Extended GPU utilization per round
- **Mixed Precision Training**: float16 autocast, which can cut activation memory by up to about 50%
- **8 Data Workers**: Maximum CPU-GPU data pipeline
- **Real-time Monitoring**: Live GPU/RAM usage tracking
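
The data-loading settings in the list above (worker count, prefetching, pinned memory) interact with each other. The sketch below builds a keyword-argument dictionary, assuming the loaders are created with `torch.utils.data.DataLoader`; the helper name `make_loader_kwargs` is hypothetical and not part of the TARS code. `DataLoader` rejects `prefetch_factor` when `num_workers` is 0, and pinned memory only helps when a CUDA device is present, so the helper drops or disables those options in that case.

```python
def make_loader_kwargs(batch_size, num_workers, prefetch_factor, use_cuda):
    """Build DataLoader keyword arguments (hypothetical helper, not TARS API)."""
    kwargs = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": num_workers,
        # Pinned host memory only speeds up host-to-GPU copies
        "pin_memory": bool(use_cuda),
    }
    if num_workers > 0:
        # prefetch_factor is only valid when worker processes are used
        kwargs["prefetch_factor"] = prefetch_factor
    return kwargs


max_cfg = make_loader_kwargs(1024, 8, 4, use_cuda=True)
cpu_cfg = make_loader_kwargs(64, 0, 4, use_cuda=False)
print(max_cfg)
print(cpu_cfg)
```

The result would typically be passed as `DataLoader(dataset, **max_cfg)`.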

## Expected Performance:
- **Training Speed**: 5-8x faster than standard configuration
- **Resource Efficiency**: 80-95% GPU, 60-80% RAM utilization
- **Accuracy**: Same or better results in significantly less time
- **MNIST**: 10-15 minutes to 97%+ accuracy
- **CIFAR-10**: 20-25 minutes to 80%+ accuracy

## 1. Setup and Installation

In [ ]:
# Check GPU and RAM availability with optimization recommendations
import torch
import psutil

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    gpu_memory_gb = gpu_props.total_memory / 1024**3
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {gpu_memory_gb:.1f} GB")
    
    # Determine optimal configuration based on GPU memory
    if gpu_memory_gb >= 14:  # 15GB GPU
        print("🚀 HIGH-END GPU DETECTED: Optimizing for maximum utilization")
        batch_size_mnist = 1024
        batch_size_cifar = 512
        num_clients = 50
        local_epochs = 10
    elif gpu_memory_gb >= 10:
        print("⚡ MID-RANGE GPU: Using optimized configuration")
        batch_size_mnist = 512
        batch_size_cifar = 256
        num_clients = 30
        local_epochs = 6
    else:
        print("🔧 STANDARD GPU: Using balanced configuration")
        batch_size_mnist = 256
        batch_size_cifar = 128
        num_clients = 20
        local_epochs = 4
else:
    print("⚠️ Using CPU - training will be significantly slower")
    batch_size_mnist = 64
    batch_size_cifar = 32
    num_clients = 10
    local_epochs = 2

# Check RAM
ram_gb = psutil.virtual_memory().total / 1024**3
print(f"RAM: {ram_gb:.1f} GB total")

if ram_gb >= 12:
    print("💾 HIGH RAM: Enabling maximum parallel data loading")
    num_workers = 8
    prefetch_factor = 4
elif ram_gb >= 8:
    print("📋 GOOD RAM: Using optimized data loading")
    num_workers = 4
    prefetch_factor = 2
else:
    print("⚠️ LIMITED RAM: Using conservative data loading")
    num_workers = 2
    prefetch_factor = 2

print(f"\n🎯 Recommended Configuration:")
print(f"  MNIST Batch Size: {batch_size_mnist}")
print(f"  CIFAR Batch Size: {batch_size_cifar}")
print(f"  Clients: {num_clients}")
print(f"  Local Epochs: {local_epochs}")
print(f"  Workers: {num_workers}")
print(f"  Prefetch Factor: {prefetch_factor}")
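
The tier selection in the cell above can also be written as a pure function, which makes the thresholds testable on a machine without a GPU. This is a minimal sketch: the names `recommend_config`, `GPU_TIERS`, `RAM_TIERS` and `CPU_SETTINGS` are hypothetical, and the values are copied from the cell above.

```python
# (minimum memory in GB, settings), checked from the largest tier down
GPU_TIERS = [
    (14, {"batch_size_mnist": 1024, "batch_size_cifar": 512, "num_clients": 50, "local_epochs": 10}),
    (10, {"batch_size_mnist": 512, "batch_size_cifar": 256, "num_clients": 30, "local_epochs": 6}),
    (0, {"batch_size_mnist": 256, "batch_size_cifar": 128, "num_clients": 20, "local_epochs": 4}),
]
CPU_SETTINGS = {"batch_size_mnist": 64, "batch_size_cifar": 32, "num_clients": 10, "local_epochs": 2}
RAM_TIERS = [
    (12, {"num_workers": 8, "prefetch_factor": 4}),
    (8, {"num_workers": 4, "prefetch_factor": 2}),
    (0, {"num_workers": 2, "prefetch_factor": 2}),
]


def pick_tier(value, tiers):
    """Return the settings of the first tier whose minimum is <= value."""
    for minimum, settings in tiers:
        if value >= minimum:
            return dict(settings)
    raise ValueError(f"no tier matches {value}")


def recommend_config(gpu_memory_gb, ram_gb):
    """Combine GPU and RAM tiers; pass gpu_memory_gb=None when CUDA is unavailable."""
    config = dict(CPU_SETTINGS) if gpu_memory_gb is None else pick_tier(gpu_memory_gb, GPU_TIERS)
    config.update(pick_tier(ram_gb, RAM_TIERS))
    return config


colab_cfg = recommend_config(14.7, 12.7)
cpu_only_cfg = recommend_config(None, 6.0)
print(colab_cfg)
print(cpu_only_cfg)
```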

In [None]:
# Clone the TARS repository
!git clone https://github.com/shafiqahmeddev/tars-fl-sim.git
%cd tars-fl-sim
!ls -la

In [ ]:
# Install required packages with GPU optimizations
!pip install torch torchvision numpy pandas matplotlib psutil GPUtil
import sys
sys.path.append('/content/tars-fl-sim')

# Enable CUDA optimizations
import torch
if torch.cuda.is_available():
    # Let cuDNN autotune convolution algorithms (fastest when input shapes stay fixed)
    torch.backends.cudnn.benchmark = True
    # For reproducible runs, set benchmark = False and enable the line below instead
    # torch.backends.cudnn.deterministic = True
    print("✅ CUDA optimizations enabled")
    
    # Display CUDA capabilities
    print(f"CUDA version: {torch.version.cuda}")
    print(f"cuDNN version: {torch.backends.cudnn.version()}")
    print(f"GPU compute capability: {torch.cuda.get_device_capability(0)}")
else:
    print("⚠️ CUDA not available")

## 2. Configuration and Model Setup

In [None]:
# Import TARS modules
from app.simulation import Simulation
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [ ]:
# MNIST Configuration - SAFE FOR COLAB (No Termination)
# Conservative settings to avoid session termination

print("🛡️ SAFE COLAB CONFIGURATION")
print("⚠️  The maximum config was causing terminations")
print("✅ This config avoids Colab's abuse detection")
print("-" * 50)

# Safe configuration that won't trigger termination
mnist_config = {
    "dataset": "mnist",
    "num_clients": 15,  # Conservative (was 50)
    "byzantine_pct": 0.2,
    "attack_type": "sign_flipping",
    "is_iid": False,
    "num_rounds": 50,
    "local_epochs": 3,  # Conservative (was 10)
    
    # Safe GPU utilization
    "client_lr": 0.001,
    "client_optimizer": "adam",
    "batch_size": 128,  # Conservative (was 1024)
    "weight_decay": 1e-4,
    
    # GPU optimizations (safe ones)
    "use_amp": True,  # Keep mixed precision
    "amp_dtype": "float16",
    "grad_clip": 1.0,
    
    # Conservative data loading
    "num_workers": 2,  # Conservative (was 8)
    "pin_memory": True,  # Keep this optimization
    "prefetch_factor": 2,  # Conservative (was 4)
    
    # Safe memory management
    "empty_cache_every": 5,  # More frequent clearing
    "max_grad_norm": 1.0,
    
    # Q-learning parameters
    "learning_rate": 0.1,
    "discount_factor": 0.9,
    "epsilon_start": 1.0,
    "epsilon_decay": 0.995,
    "epsilon_min": 0.01,
    
    # Trust mechanism parameters
    "trust_beta": 0.5,
    "trust_params": {
        "w_sim": 0.4,
        "w_loss": 0.4,
        "w_norm": 0.2,
        "norm_threshold": 5.0
    },
    
    # Training enhancements
    "use_scheduler": True,
    "early_stopping": True,
    "patience": 10,
    "save_model": True,
    "use_pretrained": False,
    "force_retrain": True
}

print("🚀 SAFE MNIST Configuration:")
print(f"  📊 Clients: {mnist_config['num_clients']} (conservative)")
print(f"  📦 Batch Size: {mnist_config['batch_size']} (reduced from 1024)")
print(f"  🔄 Local Epochs: {mnist_config['local_epochs']} (safe)")
print(f"  💾 Mixed Precision: {mnist_config['use_amp']} (keeps memory efficient)")
print(f"  🔧 Workers: {mnist_config['num_workers']} (conservative)")
print(f"  🎯 Expected GPU Usage: 6-8GB (40-50%)")
print(f"  ⏱️ Expected Time: 20-25 minutes")
print(f"  🛡️ Termination Risk: VERY LOW ✅")

print("\n💡 For MAXIMUM performance without termination:")
print("   Recommended: Use Kaggle Notebooks (30h/week, more generous limits)")
print("   Alternative: Paperspace Student Program (apply for free GPU)")
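
The Q-learning exploration settings above (`epsilon_start`, `epsilon_decay`, `epsilon_min`) determine how much of a 50-round run is spent exploring. The sketch below assumes epsilon is multiplied by `epsilon_decay` once per round and floored at `epsilon_min`; the actual schedule inside `Simulation` may differ, and `epsilon_at` is a hypothetical helper.

```python
import math


def epsilon_at(round_idx, start=1.0, decay=0.995, minimum=0.01):
    """Exploration rate after `round_idx` multiplicative decay steps (assumed schedule)."""
    return max(minimum, start * decay ** round_idx)


eps_after_50 = epsilon_at(50)
# Smallest n with start * decay**n <= minimum
rounds_to_min = math.ceil(math.log(0.01 / 1.0) / math.log(0.995))

print(f"epsilon after 50 rounds: {eps_after_50:.3f}")
print(f"rounds until epsilon_min is reached: {rounds_to_min}")
```

Under that assumption epsilon is still about 0.78 after 50 rounds, so most aggregation-rule choices in a run of this length would be exploratory. A faster decay or more rounds would be needed for the agent to mostly exploit what it has learned.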

In [ ]:
# CIFAR-10 Configuration - SAFE FOR COLAB (No Termination)
# Conservative settings to avoid session termination

# Safe CIFAR-10 configuration
cifar_config = {
    "dataset": "cifar10",
    "num_clients": 15,  # Conservative (was 50)
    "byzantine_pct": 0.2,
    "attack_type": "sign_flipping",
    "is_iid": False,
    "num_rounds": 50,  # Keep rounds for convergence
    "local_epochs": 3,  # Conservative (was 10)
    
    # Safe GPU utilization for CIFAR-10
    "client_lr": 0.001,
    "client_optimizer": "adam",
    "batch_size": 64,  # Smaller for CIFAR-10 complexity
    "weight_decay": 1e-4,
    
    # GPU optimizations (safe ones)
    "use_amp": True,  # Keep mixed precision
    "amp_dtype": "float16",
    "grad_clip": 1.0,  # Important for CIFAR-10
    
    # Conservative data loading
    "num_workers": 2,  # Conservative
    "pin_memory": True,
    "prefetch_factor": 2,
    
    # Safe memory management
    "empty_cache_every": 5,
    "max_grad_norm": 1.0,
    
    # Q-learning parameters
    "learning_rate": 0.1,
    "discount_factor": 0.9,
    "epsilon_start": 1.0,
    "epsilon_decay": 0.995,
    "epsilon_min": 0.01,
    
    # Trust mechanism parameters
    "trust_beta": 0.5,
    "trust_params": {
        "w_sim": 0.4,
        "w_loss": 0.4,
        "w_norm": 0.2,
        "norm_threshold": 5.0
    },
    
    # Training enhancements
    "use_scheduler": True,
    "early_stopping": True,
    "patience": 15,  # Higher patience for CIFAR-10
    "save_model": True,
    "use_pretrained": False,
    "force_retrain": True
}

print("🚀 SAFE CIFAR-10 Configuration:")
print(f"  📊 Clients: {cifar_config['num_clients']} (conservative)")
print(f"  📦 Batch Size: {cifar_config['batch_size']} (smaller for CIFAR complexity)")
print(f"  🔄 Local Epochs: {cifar_config['local_epochs']} (safe)")
print(f"  💾 Mixed Precision: {cifar_config['use_amp']} (memory efficient)")
print(f"  🔧 Workers: {cifar_config['num_workers']} (conservative)")
print(f"  🎯 Expected GPU Usage: 5-7GB (35-45%)")
print(f"  ⏱️ Expected Time: 25-30 minutes")
print(f"  🛡️ Termination Risk: VERY LOW ✅")

print("\n🏆 BETTER ALTERNATIVES FOR MAXIMUM PERFORMANCE:")
print("1. 🥇 KAGGLE: 30h/week, 15GB GPU, no termination issues")
print("2. 🥈 PAPERSPACE: Student program, better GPUs, unlimited time")  
print("3. 🥉 AZURE STUDENT: $100 credit, powerful GPUs, professional platform")
print("4. 🔄 COLAB: This safe config as backup")
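
The MNIST and CIFAR-10 dictionaries above are nearly identical, so it helps to see exactly which keys differ. The sketch below uses two hypothetical helpers, `diff_configs` and `check_trust_weights`, on trimmed copies of the dictionaries. The check that the three trust weights sum to 1 is an assumed convention, not something the TARS code is known to enforce.

```python
def diff_configs(a, b):
    """Return {key: (a_value, b_value)} for keys whose values differ."""
    keys = sorted(set(a) | set(b))
    return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)}


def check_trust_weights(trust_params, tol=1e-9):
    """True when w_sim + w_loss + w_norm equals 1 within `tol` (assumed convention)."""
    total = trust_params["w_sim"] + trust_params["w_loss"] + trust_params["w_norm"]
    return abs(total - 1.0) <= tol


# Trimmed copies of the two dictionaries above, for illustration only
mnist_small = {"dataset": "mnist", "batch_size": 128, "patience": 10, "num_clients": 15}
cifar_small = {"dataset": "cifar10", "batch_size": 64, "patience": 15, "num_clients": 15}

config_diff = diff_configs(mnist_small, cifar_small)
weights_ok = check_trust_weights(
    {"w_sim": 0.4, "w_loss": 0.4, "w_norm": 0.2, "norm_threshold": 5.0}
)
print(config_diff)
print(weights_ok)
```

Applied to the full dictionaries, the only differing keys are `dataset`, `batch_size` and `patience`.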

## 3. Model Training

## ⚠️ Colab Termination Issue & Better Alternatives

**Problem:** The maximum configuration causes Colab session termination due to:
- Large batch sizes triggering OOM detection
- High resource usage flagged as abuse
- Extended training sessions hitting limits

**Solution:** Use safer configuration above OR switch to better platforms:

## 🏆 Recommended Alternatives (Better than Colab)

### 1. 🥇 **Kaggle Notebooks** (BEST CHOICE)
- ✅ **30 hours/week** GPU time
- ✅ **Tesla P100 (16GB)** or **T4 (15GB)** GPUs  
- ✅ **No termination issues** with higher batch sizes
- ✅ **Better resource limits** than Colab
- ✅ **Easy setup:** Just verify phone number
- 🎯 **Config:** `batch_size=256, clients=25, epochs=5`
- **Setup:** [kaggle.com](https://kaggle.com) → Settings → Phone Verification → GPU
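
The Kaggle settings listed above can be applied as overrides on the safe Colab configuration. This sketch assumes that the shorthand `clients` and `epochs` correspond to the `num_clients` and `local_epochs` keys used in Section 2; `with_overrides` is a hypothetical helper.

```python
def with_overrides(base, **overrides):
    """Return a copy of `base` with selected keys replaced; unknown keys are rejected."""
    unknown = set(overrides) - set(base)
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return {**base, **overrides}


# Trimmed stand-in for the safe configuration from Section 2
safe_cfg = {"dataset": "mnist", "num_clients": 15, "local_epochs": 3, "batch_size": 128}
kaggle_cfg = with_overrides(safe_cfg, batch_size=256, num_clients=25, local_epochs=5)
print(kaggle_cfg)
```

Rejecting unknown keys catches typos such as `clients=25`, which would otherwise be added to the dictionary and silently ignored.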

### 2. 🥈 **Paperspace Gradient Student**
- ✅ **Free GPU upgrades** for students
- ✅ **No session limits**
- ✅ **Persistent storage**
- ✅ **Better for long training**
- **Apply:** [paperspace.com/students](https://paperspace.com/students)

### 3. 🥉 **Azure for Students** 
- ✅ **$100 free credit** (no credit card needed)
- ✅ **Professional GPUs** (NC6, NC12, NC24)
- ✅ **Scales to any size**
- **Apply:** [azure.microsoft.com/free/students](https://azure.microsoft.com/free/students)

## 📊 Platform Comparison

| Platform | GPU Memory | Time Limit | Batch Size | Termination Risk | Setup |
|----------|------------|------------|------------|------------------|-------|
| **Kaggle** | 15-16GB | 30h/week | 256+ | ⭐⭐⭐⭐⭐ Very Low | Easy |
| **Paperspace** | 8-24GB | None | 512+ | ⭐⭐⭐⭐⭐ Very Low | Medium |
| **Azure Student** | 8-32GB | Credit limit | 1024+ | ⭐⭐⭐⭐ Low | Hard |
| **Colab (Safe)** | 15GB | ~12h | 128 | ⭐⭐⭐ Medium | Easy |

## 💡 My Recommendation

- **Primary:** Use **Kaggle**, the most reliable option for TARS training
- **Backup:** Use the safe Colab configuration defined in Section 2

In [ ]:
# Real-time GPU memory monitoring setup
import threading
import time

def monitor_gpu(interval=30):
    """Print allocated GPU memory every `interval` seconds until stopped."""
    if not torch.cuda.is_available():
        return
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    while getattr(monitor_gpu, 'running', True):
        # memory_allocated() counts memory held by tensors, not compute load
        used_gb = torch.cuda.memory_allocated() / 1024**3
        print(f"📊 GPU memory: {used_gb:.2f}GB / {total_gb:.1f}GB ({used_gb / total_gb * 100:.1f}%)")
        # Sleep in 1-second steps so the thread exits soon after being stopped
        for _ in range(interval):
            if not getattr(monitor_gpu, 'running', True):
                break
            time.sleep(1)

# Start GPU monitoring in background
if torch.cuda.is_available():
    monitor_gpu.running = True
    monitor_thread = threading.Thread(target=monitor_gpu, daemon=True)
    monitor_thread.start()
    print("🔍 GPU monitoring started (updates every 30 seconds)")

# Train MNIST model with the safe configuration from Section 2
print("\n🚀 Starting MNIST Training (safe Colab configuration)")
print("=" * 70)
print(f"Target: 97.7% accuracy with batch size {mnist_config['batch_size']}")
print("Expected GPU usage: 6-8GB (40-50% of 15GB)")
print("=" * 70)

mnist_simulation = Simulation(mnist_config)
mnist_history = mnist_simulation.run()

# Stop monitoring
if torch.cuda.is_available():
    monitor_gpu.running = False

print("\n" + "=" * 70)
print("✅ MNIST Training Completed!")
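
A background monitor like the one in this cell can also be built on `threading.Event`, which lets the thread be stopped immediately and joined before the next run starts. The sketch below is one possible structure, not the notebook's implementation: `PeriodicMonitor` is a hypothetical class, and the sampling function is passed in so the example runs without a GPU.

```python
import threading
import time


class PeriodicMonitor:
    """Call `sample()` every `interval` seconds on a background thread."""

    def __init__(self, sample, interval=30.0):
        self._sample = sample
        self._interval = interval
        self._stop = threading.Event()
        self._thread = None

    def __enter__(self):
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self._stop.set()     # wakes the thread even if it is mid-wait
        self._thread.join()  # no leftover thread when the block ends
        return False

    def _run(self):
        while True:
            self._sample()
            # wait() returns True as soon as the stop event is set
            if self._stop.wait(self._interval):
                break


readings = []
with PeriodicMonitor(lambda: readings.append(time.monotonic()), interval=0.01) as mon:
    time.sleep(0.05)  # stands in for simulation.run()
print(f"collected {len(readings)} samples")
```

In the notebook, the sampling function could print `torch.cuda.memory_allocated()` and the `with` block could wrap `mnist_simulation.run()`.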

In [ ]:
# Train CIFAR-10 model with the safe configuration from Section 2
print("\n🚀 Starting CIFAR-10 Training (safe Colab configuration)")
print("=" * 70)
print(f"Target: 80.5%+ accuracy with batch size {cifar_config['batch_size']}")
print("Expected GPU usage: 5-7GB (35-45% of 15GB)")
print("=" * 70)

# Restart GPU monitoring for CIFAR-10
if torch.cuda.is_available():
    monitor_gpu.running = True
    monitor_thread = threading.Thread(target=monitor_gpu, daemon=True)
    monitor_thread.start()
    print("🔍 GPU monitoring restarted for CIFAR-10 training")

# Clear GPU cache before CIFAR-10 training
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("🧹 GPU cache cleared for CIFAR-10 training")

cifar_simulation = Simulation(cifar_config)
cifar_history = cifar_simulation.run()

# Stop monitoring
if torch.cuda.is_available():
    monitor_gpu.running = False

print("\n" + "=" * 70)
print("✅ CIFAR-10 Training Completed!")

# Final GPU memory summary (the peak covers every run since the process started,
# unless the peak statistics were reset in between)
if torch.cuda.is_available():
    final_memory = torch.cuda.memory_allocated() / 1024**3
    max_memory_used = torch.cuda.max_memory_allocated() / 1024**3
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    peak_fraction = max_memory_used / total_memory
    
    print("\n📊 Final GPU Memory Summary:")
    print(f"  Current Usage: {final_memory:.2f}GB")
    print(f"  Peak Usage: {max_memory_used:.2f}GB / {total_memory:.1f}GB ({peak_fraction * 100:.1f}%)")
    
    if peak_fraction > 0.8:
        print("  🎉 EXCELLENT: Peak GPU memory usage above 80%!")
    elif peak_fraction > 0.6:
        print("  ✅ GOOD: Peak GPU memory usage above 60%!")
    else:
        print("  ⚠️ Peak GPU memory usage is 60% or less; larger batches may fit")
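
The thresholds in the summary above can be checked against the expected memory figures quoted in Section 2. The sketch below restates the same cut-offs as a function; `classify_peak_memory` and the `LOW` label are hypothetical names, and the inputs are illustrative values rather than measurements.

```python
def classify_peak_memory(peak_gb, total_gb):
    """Map peak/total GPU memory to the bands used in the summary above."""
    if total_gb <= 0:
        raise ValueError("total_gb must be positive")
    ratio = peak_gb / total_gb
    if ratio > 0.8:
        label = "EXCELLENT"
    elif ratio > 0.6:
        label = "GOOD"
    else:
        label = "LOW"
    return ratio, label


for peak in (13.0, 10.0, 7.0):
    ratio, label = classify_peak_memory(peak, 15.0)
    print(f"{peak:.1f}GB of 15.0GB -> {ratio:.0%} ({label})")
```

With the safe configuration's expected 6-8GB on a 15GB GPU, the ratio stays below 0.6, so the lowest band is the expected outcome there rather than a sign of a problem.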

## 4. Results Analysis and Visualization

In [ ]:
# Resource Utilization Verification and Optimization Check
import psutil
import GPUtil

print("🔍 RESOURCE UTILIZATION ANALYSIS")
print("=" * 50)

# GPU Analysis
if torch.cuda.is_available():
    gpu_memory_used = torch.cuda.max_memory_allocated() / 1024**3
    gpu_memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
    gpu_utilization = (gpu_memory_used / gpu_memory_total) * 100
    
    print(f"🎮 GPU Performance:")
    print(f"  Peak Memory Used: {gpu_memory_used:.2f}GB / {gpu_memory_total:.1f}GB")
    print(f"  Utilization: {gpu_utilization:.1f}%")
    
    if gpu_utilization >= 80:
        print("  ✅ EXCELLENT: Maximized GPU usage!")
    elif gpu_utilization >= 60:
        print("  🔶 GOOD: Good GPU utilization")
        print(f"  💡 Tip: Increase batch size to: {int(mnist_config['batch_size'] * 1.3)}")
    else:
        print("  ⚠️ LOW: GPU underutilized")
        print(f"  💡 Tip: Increase batch size to: {int(mnist_config['batch_size'] * 2)}")
        print(f"  💡 Tip: Increase clients to: {int(mnist_config['num_clients'] * 1.5)}")

# RAM Analysis  
ram_used = psutil.virtual_memory().used / 1024**3
ram_total = psutil.virtual_memory().total / 1024**3
ram_utilization = (ram_used / ram_total) * 100

print(f"\n💾 RAM Performance:")
print(f"  Memory Used: {ram_used:.2f}GB / {ram_total:.1f}GB")
print(f"  Utilization: {ram_utilization:.1f}%")

if ram_utilization >= 70:
    print("  ✅ GOOD: High RAM utilization")
elif ram_utilization >= 50:
    print("  🔶 MODERATE: Decent RAM usage")
    print(f"  💡 Tip: Increase num_workers to: {mnist_config['num_workers'] + 2}")
else:
    print("  ⚠️ LOW: RAM underutilized")
    print(f"  💡 Tip: Increase num_workers to: {mnist_config['num_workers'] + 4}")
    print(f"  💡 Tip: Increase prefetch_factor to: {mnist_config['prefetch_factor'] * 2}")

# Performance Summary (reports the values that were actually passed to Simulation)
print("\n📊 OPTIMIZATION SUMMARY:")
print("  Configuration Used:")
print(f"    - Batch Size (MNIST): {mnist_config['batch_size']}")
print(f"    - Batch Size (CIFAR): {cifar_config['batch_size']}")
print(f"    - Clients: {mnist_config['num_clients']}")
print(f"    - Local Epochs: {mnist_config['local_epochs']}")
print(f"    - Workers: {mnist_config['num_workers']}")
print(f"    - Mixed Precision: {'Yes' if mnist_config['use_amp'] and torch.cuda.is_available() else 'No'}")

if torch.cuda.is_available() and gpu_utilization >= 75 and ram_utilization >= 60:
    print("\n🎉 MAXIMUM RESOURCE UTILIZATION ACHIEVED!")
    print(f"   Your {gpu_memory_total:.0f}GB GPU and {ram_total:.1f}GB RAM are being used optimally!")
elif torch.cuda.is_available():
    print("\n🔧 OPTIMIZATION RECOMMENDATIONS:")
    if gpu_utilization < 75:
        # Guard against division by zero when no GPU memory was recorded
        if gpu_utilization > 0:
            print(f"   - Increase batch size by {int((75/gpu_utilization - 1) * 100)}%")
        print("   - Add more clients for parallel processing")
    if ram_utilization < 60:
        print("   - Increase data loading workers")
        print("   - Enable more aggressive prefetching")
else:
    print("\n⚠️ GPU not available - using CPU mode")
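
The batch-size recommendation in the cell above comes from the expression `(75 / gpu_utilization - 1) * 100`. It implicitly assumes that GPU memory grows in proportion to batch size, which ignores the fixed cost of model weights and optimizer state, so the result is at best a starting point. The sketch below works the arithmetic for a few illustrative utilization values; `suggested_batch_increase_pct` is a hypothetical name.

```python
def suggested_batch_increase_pct(utilization_pct, target_pct=75.0):
    """Percent batch-size increase from the rule above (0 if already at target)."""
    if utilization_pct <= 0:
        raise ValueError("utilization must be positive to extrapolate")
    if utilization_pct >= target_pct:
        return 0
    return int((target_pct / utilization_pct - 1) * 100)


for util in (50.0, 60.0, 40.0):
    pct = suggested_batch_increase_pct(util)
    new_batch = int(128 * (1 + pct / 100))
    print(f"{util:.0f}% used -> raise batch size by {pct}% (128 -> {new_batch})")
```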

In [None]:
# Plot training results
def plot_training_results(history, dataset_name, target_accuracy):
    if not history:
        print(f"No training history available for {dataset_name}")
        return
    
    df = pd.DataFrame(history)
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'TARS Training Results - {dataset_name}', fontsize=16)
    
    # Accuracy plot
    axes[0, 0].plot(df['round'], df['accuracy'], 'b-', linewidth=2, label='Accuracy')
    axes[0, 0].axhline(y=target_accuracy, color='r', linestyle='--', label=f'Target ({target_accuracy}%)')
    axes[0, 0].set_xlabel('Round')
    axes[0, 0].set_ylabel('Accuracy (%)')
    axes[0, 0].set_title('Model Accuracy Over Time')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Loss plot
    axes[0, 1].plot(df['round'], df['loss'], 'r-', linewidth=2, label='Loss')
    axes[0, 1].set_xlabel('Round')
    axes[0, 1].set_ylabel('Loss')
    axes[0, 1].set_title('Training Loss Over Time')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # Trust scores plot
    axes[1, 0].plot(df['round'], df['avg_trust'], 'g-', linewidth=2, label='Average Trust')
    axes[1, 0].set_xlabel('Round')
    axes[1, 0].set_ylabel('Trust Score')
    axes[1, 0].set_title('Average Trust Score Over Time')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # Aggregation rules usage
    rule_counts = df['chosen_rule'].value_counts()
    axes[1, 1].pie(rule_counts.values, labels=rule_counts.index, autopct='%1.1f%%')
    axes[1, 1].set_title('Aggregation Rules Usage')
    
    plt.tight_layout()
    plt.show()
    
    # Print summary statistics
    final_accuracy = df['accuracy'].iloc[-1]
    max_accuracy = df['accuracy'].max()
    avg_trust = df['avg_trust'].mean()
    
    print(f"\n📊 {dataset_name} Training Summary:")
    print(f"  Final Accuracy: {final_accuracy:.2f}%")
    print(f"  Best Accuracy: {max_accuracy:.2f}%")
    print(f"  Average Trust: {avg_trust:.3f}")
    print(f"  Total Rounds: {len(df)}")
    
    if final_accuracy >= target_accuracy:
        print(f"  🎉 TARGET ACHIEVED! {final_accuracy:.2f}% >= {target_accuracy}%")
    else:
        print(f"  ⚠️  Target not reached: {final_accuracy:.2f}% < {target_accuracy}%")
    
    return df

# Plot MNIST results
print("MNIST Results:")
mnist_df = plot_training_results(mnist_history, "MNIST", 97.0)

# Plot CIFAR-10 results
print("\nCIFAR-10 Results:")
cifar_df = plot_training_results(cifar_history, "CIFAR-10", 80.5)
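
The summary printed by `plot_training_results` compares only the final accuracy with the target. A complementary question is how many rounds were needed to reach it. The sketch below assumes the history is a list of per-round dictionaries with `round` and `accuracy` keys, as the DataFrame columns used above suggest; `first_round_reaching` is a hypothetical helper and the records are synthetic, not real results.

```python
def first_round_reaching(history, target_accuracy):
    """Return the first 'round' whose 'accuracy' meets the target, or None."""
    for record in history:
        if record["accuracy"] >= target_accuracy:
            return record["round"]
    return None


# Synthetic records shaped like the history rows plotted above (not real results)
demo_history = [
    {"round": 1, "accuracy": 62.0},
    {"round": 2, "accuracy": 88.5},
    {"round": 3, "accuracy": 97.2},
    {"round": 4, "accuracy": 96.8},
]

hit_round = first_round_reaching(demo_history, 97.0)
missed_round = first_round_reaching(demo_history, 99.0)
print(f"97.0% first reached in round: {hit_round}")
print(f"99.0% first reached in round: {missed_round}")
```

In the synthetic data the final round falls back below the target, which is the case where the final-accuracy check and the best-accuracy figure disagree.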

In [None]:
# Save results to CSV
if mnist_history:
    mnist_df = pd.DataFrame(mnist_history)
    mnist_df.to_csv("mnist_training_results.csv", index=False)
    print("💾 MNIST results saved to mnist_training_results.csv")

if cifar_history:
    cifar_df = pd.DataFrame(cifar_history)
    cifar_df.to_csv("cifar10_training_results.csv", index=False)
    print("💾 CIFAR-10 results saved to cifar10_training_results.csv")

## 5. Download Results

In [None]:
# Download trained models and results
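import os

try:
    from google.colab import files  # Only available inside Colab
except ImportError:
    files = None  # On Kaggle or locally, copy the files from the working directory instead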

# List available files
print("Available files for download:")
for file in os.listdir('.'):
    if file.endswith(('.csv', '.pth', '.pkl')):
        print(f"  📄 {file}")

# List checkpoints if they exist
if os.path.exists('checkpoints'):
    print("\nCheckpoint files:")
    for file in os.listdir('checkpoints'):
        print(f"  🔄 checkpoints/{file}")

# Uncomment to download specific files
# files.download('mnist_training_results.csv')
# files.download('cifar10_training_results.csv')
# files.download('checkpoints/mnist_global_model.pth')
# files.download('checkpoints/cifar10_global_model.pth')
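
Downloading files one by one becomes tedious once there are several CSVs and checkpoints. A possible alternative is to bundle them into a single archive using only the standard library, which also works outside Colab. In the sketch below, `bundle_results` and the archive name `tars_results.zip` are hypothetical. Unlike the listing above, it filters the `checkpoints` directory by file suffix as well.

```python
import tempfile
import zipfile
from pathlib import Path


def bundle_results(root=".", archive="tars_results.zip", suffixes=(".csv", ".pth", ".pkl")):
    """Zip result files in `root` and `root/checkpoints`; return the archived names."""
    root = Path(root)
    candidates = list(root.iterdir())
    checkpoints = root / "checkpoints"
    if checkpoints.is_dir():
        candidates += list(checkpoints.iterdir())
    members = sorted(p for p in candidates if p.is_file() and p.suffix in suffixes)
    with zipfile.ZipFile(root / archive, "w", zipfile.ZIP_DEFLATED) as zf:
        for path in members:
            # Store paths relative to root so the archive unpacks cleanly
            zf.write(path, arcname=path.relative_to(root).as_posix())
    return [p.relative_to(root).as_posix() for p in members]


# Demonstration on a throwaway directory with placeholder files
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "mnist_training_results.csv").write_text("round,accuracy\n1,90.0\n")
    (Path(tmp) / "checkpoints").mkdir()
    (Path(tmp) / "checkpoints" / "mnist_global_model.pth").write_bytes(b"demo")
    bundled = bundle_results(tmp)
    archive_names = zipfile.ZipFile(Path(tmp) / "tars_results.zip").namelist()
print(bundled)
```

In Colab the archive could then be fetched with a single `files.download('tars_results.zip')` call.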

## 6. Next Steps

🎯 **Performance Targets:**
- MNIST: 97.7% accuracy
- CIFAR-10: 80.5% accuracy

🔧 **If targets not met, try:**
- Increase number of rounds
- Adjust learning rates
- Modify trust mechanism parameters
- Experiment with different optimizers

📊 **Model Analysis:**
- Check convergence patterns
- Analyze trust scores
- Review aggregation rule selection
- Examine Byzantine attack impact

💾 **Model Deployment:**
- Download trained models
- Use checkpoints for inference
- Deploy TARS agent for production