# TARS Federated Learning Training on Google Colab

**TARS: Trust-Aware Reinforcement Selection for Robust Federated Learning**

Author: Shafiq Ahmed (s.ahmed@essex.ac.uk)

This notebook trains the TARS federated learning system toward its target accuracies: 97.7% on MNIST and 80.5% on CIFAR-10.

## Key Features:
- Enhanced CNN architectures with batch normalization
- Comprehensive data augmentation
- TARS trust mechanism with Q-learning
- Byzantine fault tolerance
- Performance monitoring and early stopping

## 1. Setup and Installation

In [None]:
# Report the PyTorch build and whether a CUDA device is usable in this runtime
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("Using CPU - training will be slower")
else:
    # Name and memory are queried for device index 0 only
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

In [None]:
# Clone the TARS repository
!git clone https://github.com/shafiqahmeddev/tars-fl-sim.git
%cd tars-fl-sim
!ls -la

In [None]:
# Install required packages
!pip install torch torchvision numpy pandas matplotlib
import sys
sys.path.append('/content/tars-fl-sim')

## 2. Configuration and Model Setup

In [None]:
# Import TARS modules
from app.simulation import Simulation
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np

# Pick the compute device: "cuda" when PyTorch reports CUDA as available,
# "cpu" otherwise.
# NOTE(review): `device` is not referenced again in the visible cells (neither
# config dict contains it) — confirm how Simulation selects its device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# MNIST Configuration - Target: 97.7% accuracy
# Plain dict handed to Simulation(...) in the training cell. Key order is kept
# as-is because the summary printed below follows insertion order.
mnist_config = {
    # --- Federation setup ---
    "dataset": "mnist",
    "num_clients": 10,
    "byzantine_pct": 0.2,
    "attack_type": "sign_flipping",
    "is_iid": False,  # Non-IID for realistic federated learning
    "num_rounds": 50,
    "local_epochs": 3,

    # --- Client-side training ---
    "client_lr": 0.001,
    "client_optimizer": "adam",
    "batch_size": 32,
    "weight_decay": 1e-4,

    # --- Q-learning agent ---
    # "learning_rate" belongs to this group and is separate from "client_lr".
    # NOTE(review): 0.995 ** 50 is about 0.78, so if epsilon decays once per
    # round it stays far above epsilon_min for all 50 rounds — confirm how
    # Simulation applies the decay.
    "learning_rate": 0.1,
    "discount_factor": 0.9,
    "epsilon_start": 1.0,
    "epsilon_decay": 0.995,
    "epsilon_min": 0.01,

    # --- Trust mechanism ---
    "trust_beta": 0.5,
    "trust_params": {
        # The three w_* weights add up to 1.0
        "w_sim": 0.4,
        "w_loss": 0.4,
        "w_norm": 0.2,
        "norm_threshold": 5.0,
    },

    # --- Training enhancements ---
    "use_scheduler": True,
    "early_stopping": True,
    "patience": 10,
    "save_model": True,
    "use_pretrained": False,
    "force_retrain": True,
}

# Echo every setting, one "  key: value" line per entry
print("MNIST Configuration:")
print(*(f"  {name}: {setting}" for name, setting in mnist_config.items()), sep="\n")

In [None]:
# CIFAR-10 Configuration - Target: 80.5%+ accuracy
# Plain dict handed to Simulation(...) in the training cell. Key order is kept
# as-is because the summary printed below follows insertion order.
cifar_config = {
    # --- Federation setup ---
    "dataset": "cifar10",
    "num_clients": 10,
    "byzantine_pct": 0.2,
    "attack_type": "sign_flipping",
    "is_iid": False,
    "num_rounds": 50,
    "local_epochs": 3,

    # --- Client-side training ---
    # NOTE(review): this group was previously annotated "higher learning rate
    # for CIFAR-10", but client_lr is 0.001 — the same value the MNIST
    # configuration uses. Confirm whether a higher rate was intended.
    "client_lr": 0.001,
    "client_optimizer": "adam",
    "batch_size": 32,
    "weight_decay": 1e-4,

    # --- Q-learning agent ---
    # "learning_rate" belongs to this group and is separate from "client_lr".
    "learning_rate": 0.1,
    "discount_factor": 0.9,
    "epsilon_start": 1.0,
    "epsilon_decay": 0.995,
    "epsilon_min": 0.01,

    # --- Trust mechanism ---
    "trust_beta": 0.5,
    "trust_params": {
        # The three w_* weights add up to 1.0
        "w_sim": 0.4,
        "w_loss": 0.4,
        "w_norm": 0.2,
        "norm_threshold": 5.0,
    },

    # --- Training enhancements ---
    "use_scheduler": True,
    "early_stopping": True,
    "patience": 10,
    "save_model": True,
    "use_pretrained": False,
    "force_retrain": True,
}

# Echo every setting, one "  key: value" line per entry
print("CIFAR-10 Configuration:")
print(*(f"  {name}: {setting}" for name, setting in cifar_config.items()), sep="\n")

## 3. Model Training

In [None]:
# Train MNIST Model
# Builds a Simulation from mnist_config and runs it; whatever run() returns is
# kept as mnist_history for the analysis cells further down.
print("🚀 Starting MNIST Training - Target: 97.7% accuracy", "=" * 60, sep="\n")

mnist_simulation = Simulation(mnist_config)
mnist_history = mnist_simulation.run()

# Blank line, rule, then the completion message
print("", "=" * 60, "✅ MNIST Training Completed!", sep="\n")

In [None]:
# Train CIFAR-10 Model
# Builds a Simulation from cifar_config and runs it; whatever run() returns is
# kept as cifar_history for the analysis cells further down.
print("🚀 Starting CIFAR-10 Training - Target: 80.5%+ accuracy", "=" * 60, sep="\n")

cifar_simulation = Simulation(cifar_config)
cifar_history = cifar_simulation.run()

# Blank line, rule, then the completion message
print("", "=" * 60, "✅ CIFAR-10 Training Completed!", sep="\n")

## 4. Results Analysis and Visualization

In [None]:
# Plot training results
def plot_training_results(history, dataset_name, target_accuracy):
    """Plot a 2x2 overview of a training history and print summary statistics.

    Panels: accuracy per round (with a dashed line at ``target_accuracy``),
    loss per round, average trust per round, and a pie chart of how often
    each value occurs in the ``chosen_rule`` column.

    Args:
        history: Data accepted by ``pd.DataFrame`` that yields the columns
            ``round``, ``accuracy``, ``loss``, ``avg_trust`` and
            ``chosen_rule``. A falsy value (e.g. ``None`` or an empty list)
            skips plotting entirely.
        dataset_name: Label used in the figure title and the printed text.
        target_accuracy: Threshold drawn on the accuracy panel and compared
            with the last accuracy value. It is compared directly against
            the ``accuracy`` column, so both must use the same scale (the
            printed text appends a ``%`` sign to each).

    Returns:
        The DataFrame built from ``history``, or ``None`` when ``history``
        is falsy.
    """
    if not history:
        print(f"No training history available for {dataset_name}")
        return None

    def _decorate(ax, y_label, title):
        # Axis labels, title, legend and light grid shared by the line panels.
        ax.set_xlabel('Round')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    frame = pd.DataFrame(history)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'TARS Training Results - {dataset_name}', fontsize=16)
    (ax_acc, ax_loss), (ax_trust, ax_rules) = axes
    rounds = frame['round']

    # Top-left: accuracy curve; the target line is added before the legend
    # is built so that it appears in it.
    ax_acc.plot(rounds, frame['accuracy'], 'b-', linewidth=2, label='Accuracy')
    ax_acc.axhline(y=target_accuracy, color='r', linestyle='--', label=f'Target ({target_accuracy}%)')
    _decorate(ax_acc, 'Accuracy (%)', 'Model Accuracy Over Time')

    # Top-right: loss curve
    ax_loss.plot(rounds, frame['loss'], 'r-', linewidth=2, label='Loss')
    _decorate(ax_loss, 'Loss', 'Training Loss Over Time')

    # Bottom-left: average trust curve
    ax_trust.plot(rounds, frame['avg_trust'], 'g-', linewidth=2, label='Average Trust')
    _decorate(ax_trust, 'Trust Score', 'Average Trust Score Over Time')

    # Bottom-right: share of rounds per aggregation rule
    rule_counts = frame['chosen_rule'].value_counts()
    ax_rules.pie(rule_counts.values, labels=rule_counts.index, autopct='%1.1f%%')
    ax_rules.set_title('Aggregation Rules Usage')

    plt.tight_layout()
    plt.show()

    # Headline numbers: last and best accuracy, mean trust, number of rows
    accuracy = frame['accuracy']
    final_accuracy = accuracy.iloc[-1]
    max_accuracy = accuracy.max()
    avg_trust = frame['avg_trust'].mean()

    print(f"\n📊 {dataset_name} Training Summary:")
    print(f"  Final Accuracy: {final_accuracy:.2f}%")
    print(f"  Best Accuracy: {max_accuracy:.2f}%")
    print(f"  Average Trust: {avg_trust:.3f}")
    print(f"  Total Rounds: {len(frame)}")

    # The verdict is based on the final accuracy, not the best one.
    if final_accuracy >= target_accuracy:
        verdict = f"  🎉 TARGET ACHIEVED! {final_accuracy:.2f}% >= {target_accuracy}%"
    else:
        verdict = f"  ⚠️  Target not reached: {final_accuracy:.2f}% < {target_accuracy}%"
    print(verdict)

    return frame

# Plot MNIST results
# The threshold is 97.7 to match the MNIST target stated in the MNIST
# configuration and training cells ("Target: 97.7% accuracy"). It was 97.0,
# so the dashed target line and the "TARGET ACHIEVED" message could disagree
# with the stated target.
print("MNIST Results:")
mnist_df = plot_training_results(mnist_history, "MNIST", 97.7)

# Plot CIFAR-10 results (80.5 matches the stated CIFAR-10 target)
print("\nCIFAR-10 Results:")
cifar_df = plot_training_results(cifar_history, "CIFAR-10", 80.5)

In [None]:
# Save results to CSV
def _export_history(history, csv_path, message):
    """Write ``history`` to ``csv_path`` without the index column, print ``message``, and return the frame."""
    frame = pd.DataFrame(history)
    frame.to_csv(csv_path, index=False)
    print(message)
    return frame


# A history is exported only when it is truthy; otherwise the corresponding
# *_df name is left as it was.
if mnist_history:
    mnist_df = _export_history(
        mnist_history,
        "mnist_training_results.csv",
        "💾 MNIST results saved to mnist_training_results.csv",
    )

if cifar_history:
    cifar_df = _export_history(
        cifar_history,
        "cifar10_training_results.csv",
        "💾 CIFAR-10 results saved to cifar10_training_results.csv",
    )

## 5. Download Results

In [None]:
# List result files and checkpoints available for download.
# Nothing is downloaded by this cell as written; see the commented
# files.download(...) lines at the bottom.
import os
from google.colab import files

# Entries of the current directory whose names end in .csv, .pth or .pkl
print("Available files for download:")
for entry in os.listdir('.'):
    if not entry.endswith(('.csv', '.pth', '.pkl')):
        continue
    print(f"  📄 {entry}")

# Contents of ./checkpoints, shown only when that path exists
if os.path.exists('checkpoints'):
    print("\nCheckpoint files:")
    for entry in os.listdir('checkpoints'):
        print(f"  🔄 checkpoints/{entry}")

# Uncomment to download specific files
# files.download('mnist_training_results.csv')
# files.download('cifar10_training_results.csv')
# files.download('checkpoints/mnist_global_model.pth')
# files.download('checkpoints/cifar10_global_model.pth')

## 6. Next Steps

🎯 **Performance Targets:**
- MNIST: 97.7% accuracy
- CIFAR-10: 80.5% accuracy

🔧 **If the targets are not met, try:**
- Increase number of rounds
- Adjust learning rates
- Modify trust mechanism parameters
- Experiment with different optimizers

📊 **Model Analysis:**
- Check convergence patterns
- Analyze trust scores
- Review aggregation rule selection
- Examine Byzantine attack impact

💾 **Model Deployment:**
- Download trained models
- Use checkpoints for inference
- Deploy TARS agent for production