# Tutorial 5: Running the Full Federated Learning Experiment

In this final notebook, we'll:
1. Set up configuration for the FL experiment
2. Run federated learning with Flower CLI
3. Analyze the results
4. Compare with centralized training
5. Explore next steps and advanced topics

## 1. Setting Up the Experiment

Before running federated learning, we need to configure:
- Number of clients
- Number of rounds
- Learning rate
- Local epochs per round
- Batch size

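```python
# Configuration for our experiment.
# These values mirror the pyproject.toml settings in Section 2; `flwr run` reads
# pyproject.toml, so editing this dict does not change the actual run.
config = {
    "num_clients": 5,  # options.num-supernodes in pyproject.toml
    "num_server_rounds": 10,
    "learning_rate": 0.01,
    "local_epochs": 3,
    "batch_size": 32,
    "fraction_evaluate": 1.0,  # Use all clients for evaluation
}

print("Federated Learning Configuration:")
print("=" * 50)
for key, value in config.items():
    print(f"{key:.<30} {value}")
print("=" * 50)
```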

## 2. Understanding the Flower Configuration File

Flower apps are configured through the project's `pyproject.toml` file (TOML format, not YAML). Here's what a typical configuration looks like:

```toml
# pyproject.toml
[tool.flwr.app]
publisher = "flwrlabs"

[tool.flwr.app.components]
serverapp = "fltutorial.server:app"
clientapp = "fltutorial.client:app"

[tool.flwr.app.config]
num-server-rounds = 10
learning-rate = 0.01
local-epochs = 3
batch-size = 32
fraction-evaluate = 1.0

[tool.flwr.federations]
default = "local-simulation"

[tool.flwr.federations.local-simulation]
options.num-supernodes = 5
```

## 3. Running Federated Learning

To run the experiment, we use the Flower CLI:

```bash
# Run in simulation mode (all clients on same machine)
flwr run
```

This command:
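1. Reads the app components and run configuration from `pyproject.toml`
2. Starts the ServerApp
3. Launches the configured number of simulated clients (`options.num-supernodes`), each running the ClientApp
4. Runs federated learning for `num-server-rounds` rounds
5. Saves the final model (if the ServerApp code is written to do so)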

**Note**: `flwr run` is a command-line tool rather than a Python function, so it is easiest to launch from a terminal. See Section 9 for step-by-step instructions.

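## 4. Alternative: A Simplified Simulation in the Notebook

For educational purposes, we can import the task components and walk through the FL process inside the notebook. This is not real federated learning, since everything runs in a single process: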

```python
# Note: This is a simplified example for understanding.
# Real federated learning should use the 'flwr run' command.

import torch
from fltutorial.task import Net, load_data, train as train_fn, test as test_fn
import matplotlib.pyplot as plt

print("Setting up simulation...")
print("This demonstrates the FL process but isn't actual federated learning")
print("For real FL, use: flwr run\n")
```
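To make the loop concrete without relying on `fltutorial.task` (whose `train` and `test` signatures we don't assume here), here is a minimal sketch of FedAvg in plain NumPy. It swaps the CNN and FashionMNIST for a toy linear-regression model on synthetic data; `fedavg` and `local_train` are hypothetical helpers written for this illustration, not Flower APIs. Each round, every client starts from the current global weights and trains locally, then the server averages the results weighted by each client's number of examples, as in FedAvg (McMahan et al., 2017).

```python
import numpy as np


def fedavg(client_weights, client_sizes):
    """Average per-client parameter lists, weighting each client by its number of examples."""
    total = sum(client_sizes)
    return [
        sum(weights[i] * (n / total) for weights, n in zip(client_weights, client_sizes))
        for i in range(len(client_weights[0]))
    ]


def local_train(global_weights, X, y, lr=0.3, epochs=3):
    """Hypothetical local step: full-batch gradient descent on a linear model y ~ X @ w + b."""
    w, b = global_weights[0].copy(), global_weights[1].copy()
    for _ in range(epochs):
        err = X @ w + b - y               # residuals on this client's data only
        w -= lr * (X.T @ err) / len(y)    # gradient of half the mean squared error
        b -= lr * err.mean()
    return [w, b]


# Synthetic IID data for three clients of different sizes
rng = np.random.default_rng(0)
true_w, true_b = np.array([2.0, -1.0]), 0.5
clients = []
for n in (50, 100, 150):
    X = rng.normal(size=(n, 2))
    y = X @ true_w + true_b + 0.01 * rng.normal(size=n)
    clients.append((X, y))
sizes = [len(y) for _, y in clients]

global_weights = [np.zeros(2), np.zeros(1)]
for rnd in range(1, 11):  # 10 server rounds
    updates = [local_train(global_weights, X, y) for X, y in clients]  # local training
    global_weights = fedavg(updates, sizes)                           # server aggregation
    mse = np.mean([np.mean((X @ global_weights[0] + global_weights[1] - y) ** 2) for X, y in clients])
    print(f"Round {rnd:2d}: mean client MSE = {mse:.5f}")

print("Learned w:", global_weights[0], "b:", global_weights[1])
```

Using the CNN instead would mean wrapping your `train` function in place of `local_train` and averaging each layer's tensors in the same size-weighted way; that wiring depends on the actual signatures in `fltutorial.task`, so it is left out here.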

## 5. Analyzing Results

After running federated learning, we can analyze:
- Training curves (loss and accuracy over rounds)
- Final model performance
- Per-client performance
- Communication costs

```python
# Example: Visualizing training progress
# Replace these lists with the real metrics logged by your FL run
import matplotlib.pyplot as plt
import numpy as np

# Placeholder values for illustration only, not measured results
rounds = np.arange(1, 11)
train_loss = [2.3, 1.5, 1.0, 0.8, 0.6, 0.5, 0.45, 0.42, 0.40, 0.38]
test_accuracy = [15, 42, 58, 68, 75, 78, 80, 81, 82, 83]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Loss curve
ax1.plot(rounds, train_loss, 'o-', linewidth=2, markersize=8, color='#e74c3c')
ax1.set_xlabel('Federated Round', fontsize=12)
ax1.set_ylabel('Average Training Loss', fontsize=12)
ax1.set_title('Training Loss Over Federated Rounds', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.set_xticks(rounds)

# Accuracy curve
ax2.plot(rounds, test_accuracy, 's-', linewidth=2, markersize=8, color='#27ae60')
ax2.set_xlabel('Federated Round', fontsize=12)
ax2.set_ylabel('Test Accuracy (%)', fontsize=12)
ax2.set_title('Test Accuracy Over Federated Rounds', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)
ax2.set_xticks(rounds)
ax2.set_ylim(0, 100)

plt.tight_layout()
plt.show()

print(f"Final Test Accuracy: {test_accuracy[-1]:.2f}%")
print(f"Final Training Loss: {train_loss[-1]:.4f}")
```
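The plotting cell above covers the training curves. For per-client performance, a single global number can hide a lot, so it helps to report the example-weighted mean (larger clients count more), the unweighted mean across clients, and the worst client, which matters most under non-IID data. The helper below is a hypothetical sketch; it assumes you have collected each client's `(client_id, num_examples, accuracy)` from your own run, and the inputs in the call are made up purely to show the usage.

```python
def summarize_clients(results):
    """Summarize per-client evaluation results.

    results: list of (client_id, num_examples, accuracy) tuples, accuracy in [0, 1].
    """
    total = sum(n for _, n, _ in results)
    weighted_mean = sum(n * acc for _, n, acc in results) / total  # larger clients count more
    unweighted_mean = sum(acc for _, _, acc in results) / len(results)  # every client counts equally
    worst_id, _, worst_acc = min(results, key=lambda r: r[2])
    return {
        "weighted_mean": weighted_mean,
        "unweighted_mean": unweighted_mean,
        "worst_client": worst_id,
        "worst_accuracy": worst_acc,
    }


# Hypothetical inputs, only to show the call
summary = summarize_clients([(0, 100, 0.80), (1, 300, 0.90)])
print(summary)
```

A large gap between the weighted and unweighted means, or a worst client far below both, is a hint that some clients' data differ substantially from the rest.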

## 6. Comparing Federated vs Centralized Learning

Let's compare federated learning with traditional centralized learning:

```python
import pandas as pd

# Qualitative comparison table
comparison = pd.DataFrame({
    'Aspect': [
        'Data Privacy',
        'Communication Cost',
        'Training Speed',
        'Final Accuracy',
        'Scalability',
        'Regulatory Compliance'
    ],
    'Centralized': [
        'Lower (raw data pooled in one place)',
        'One-time raw data upload',
        'Fast',
        'Usually the upper bound (all data together)',
        'Limited by server capacity',
        'Harder (raw data must be transferred)'
    ],
    'Federated': [
        'Higher (raw data stays local)',
        'Model exchange every round',
        'Slower (multiple rounds)',
        'Often close to centralized; lower on non-IID data',
        'High (distributed compute)',
        'Easier (raw data stays local, but not automatically GDPR-compliant)'
    ]
})

print("Centralized vs Federated Learning Comparison")
print("=" * 100)
print(comparison.to_string(index=False))
print("=" * 100)
```

```python
# Visualize accuracy comparison
import matplotlib.pyplot as plt

# Illustrative values only; replace them with the accuracies from your own runs
methods = ['Centralized\nLearning', 'Federated\nLearning\n(IID)', 'Federated\nLearning\n(Non-IID)']
accuracies = [87, 83, 78]
colors = ['#3498db', '#27ae60', '#e67e22']

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(methods, accuracies, color=colors, alpha=0.7, edgecolor='black', linewidth=2)

# Add value labels on bars
for bar, acc in zip(bars, accuracies):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{acc}%',
            ha='center', va='bottom', fontsize=14, fontweight='bold')

ax.set_ylabel('Test Accuracy (%)', fontsize=12)
ax.set_title('FashionMNIST Classification: Centralized vs Federated', fontsize=14, fontweight='bold')
ax.set_ylim(0, 100)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print("\nKey Observations (for the illustrative values above):")
print("- Centralized learning achieves the highest accuracy (all data together)")
print("- Federated learning (IID) is close, with a ~4 percentage-point drop")
print("- Non-IID data makes FL more challenging (~9 percentage-point drop)")
print("- In many settings, FL's privacy benefits outweigh a small accuracy loss")
```

## 7. Understanding Communication Costs

One important aspect of federated learning is communication efficiency:

```python
# Calculate communication costs
from fltutorial.task import Net

model = Net()
num_params = sum(p.numel() for p in model.parameters())
bytes_per_param = 4  # float32
model_size_mb = (num_params * bytes_per_param) / (1024 * 1024)

# Keep in sync with num-supernodes and num-server-rounds in pyproject.toml
num_clients = 5
num_rounds = 10

# Each round (assuming every client trains): the server sends the global model
# to all clients, and each client sends its updated model back.
# Evaluation traffic (fraction-evaluate) is not counted here.
total_communication_mb = num_rounds * num_clients * 2 * model_size_mb

print("Communication Cost Analysis")
print("=" * 50)
print(f"Model parameters: {num_params:,}")
print(f"Model size: {model_size_mb:.2f} MB")
print(f"\nPer round:")
print(f"  Server → Clients: {num_clients * model_size_mb:.2f} MB")
print(f"  Clients → Server: {num_clients * model_size_mb:.2f} MB")
print(f"  Total per round: {num_clients * 2 * model_size_mb:.2f} MB")
print(f"\nTotal for {num_rounds} rounds:")
print(f"  {total_communication_mb:.2f} MB")
print("=" * 50)

# Comparison with centralized training, where you'd upload the raw training data once.
# FashionMNIST training set: 60,000 grayscale 28x28 images at 1 byte (uint8) per pixel; labels ignored
dataset_size_mb = 60000 * 28 * 28 / (1024 * 1024)
print(f"\nFor comparison:")
print(f"Sending all FashionMNIST training data: {dataset_size_mb:.2f} MB")
print(f"FL traffic relative to the raw dataset: {(total_communication_mb / dataset_size_mb):.2f}x")
print("\nBut remember: In FL, raw data NEVER leaves the client!")
```
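The cell above assumes every client trains in every round and ignores evaluation traffic. With `fraction-evaluate = 1.0`, as in our `pyproject.toml`, the server also sends the global model to each client for federated evaluation. The hypothetical helper below accounts for both under simplifying assumptions stated in its docstring; it is a back-of-the-envelope estimate, not how Flower measures traffic.

```python
def fl_traffic_mb(model_size_mb, num_clients, num_rounds, fraction_fit=1.0, fraction_evaluate=0.0):
    """Rough total traffic (MB) for FedAvg-style training.

    Assumptions: each sampled training client downloads the full model and uploads a
    full update; each sampled evaluation client downloads the full model and uploads
    only metrics (treated as negligible). Sampled counts are rounded.
    """
    fit_clients = round(fraction_fit * num_clients)
    eval_clients = round(fraction_evaluate * num_clients)
    per_round = model_size_mb * (2 * fit_clients + eval_clients)
    return num_rounds * per_round


# Per 1 MB of model: training traffic only (same setting as the cell above)
print(fl_traffic_mb(1.0, num_clients=5, num_rounds=10))
# Also counting the evaluation downlink implied by fraction-evaluate = 1.0
print(fl_traffic_mb(1.0, num_clients=5, num_rounds=10, fraction_evaluate=1.0))
```

Lowering `fraction_fit` (sampling fewer clients per round) is the simplest lever for cutting traffic, at the cost of each round seeing less data.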

## 8. Loading and Testing the Final Model

After FL training completes, we can load and test the final model:

```python
# This code works after running 'flwr run' (assuming the ServerApp saves the final model)
# It loads the saved weights and evaluates them

import torch
from fltutorial.task import Net, load_centralized_dataset, test as test_fn

# Path of the saved final model (if it exists)
model_path = "final_model.pt"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

try:
    # Load saved weights; map_location lets GPU-saved weights load on a CPU-only machine
    model = Net()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    
    # Test on centralized test set
    test_dataloader = load_centralized_dataset()
    test_loss, test_acc = test_fn(model, test_dataloader, device)
    
    print("Final Model Performance")
    print("=" * 50)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc*100:.2f}%")
    print("=" * 50)
    
except FileNotFoundError:
    print("No saved model found.")
    print("Run 'flwr run' first to train the model.")
    print("\nFor now, this notebook demonstrates the concepts.")
```

## 9. Running the Real Experiment

To run the actual federated learning experiment:

### Step 1: Open a terminal in the project directory

### Step 2: Run the Flower command
```bash
flwr run
```

### Step 3: Watch the output
You'll see:
- Server starting
- Clients connecting
- Training progress for each round
- Global model evaluation after each round
- Final model being saved

### Step 4: Analyze results
Come back to this notebook and run the analysis cells above, replacing the placeholder metrics with the values from your run.

### Alternative: Docker Setup
```bash
# Build and run with docker-compose
docker-compose up
```

## 10. Advanced Topics and Next Steps

Now that you understand the basics, here are advanced topics to explore:

### 1. **Non-IID Data**
   - Use `DirichletPartitioner` instead of `IidPartitioner`
   - Observe how heterogeneous data affects convergence
   - Experiment with different alpha values
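`DirichletPartitioner` from Flower Datasets (`flwr_datasets`) handles this for you. To see what the `alpha` knob does, here is a small NumPy sketch of the same idea, not the library's implementation: for each class, the share that each client receives is drawn from a Dirichlet(alpha) distribution. `dirichlet_partition` is a hypothetical helper, and the toy labels stand in for FashionMNIST's 10 classes.

```python
import numpy as np


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split example indices across clients using per-class Dirichlet(alpha) proportions."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        # Fraction of this class given to each client; small alpha -> very uneven shares
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return [np.array(sorted(ix), dtype=int) for ix in client_indices]


# Toy labels: 10 classes x 100 examples each
labels = np.repeat(np.arange(10), 100)
for alpha in (100.0, 0.1):
    parts = dirichlet_partition(labels, num_clients=5, alpha=alpha)
    counts = np.array([np.bincount(labels[p], minlength=10) for p in parts])
    print(f"alpha={alpha}: per-client class counts (rows = clients)\n{counts}")
```

With a large `alpha`, each client's class counts come out roughly balanced (close to IID); with a small `alpha`, most clients end up dominated by a few classes. The printed count matrix is also a starting point for Exercise 3.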

### 2. **Advanced Aggregation Strategies**
   - **FedProx**: Adds a proximal term to handle heterogeneity
   - **FedAdam**: Uses adaptive learning rates
   - **FedBN**: Doesn't aggregate batch normalization layers
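FedProx (Li et al., 2020, listed under Papers below) changes only the client's local objective: instead of minimizing its loss `F_k(w)` alone, client k minimizes `F_k(w) + (mu / 2) * ||w - w_global||^2`, where `w_global` is the global model it received. The sketch below applies that to a one-parameter toy loss; `fedprox_local_update` and the toy loss are hypothetical, and plain gradient descent stands in for the client's optimizer.

```python
import numpy as np


def fedprox_local_update(w_global, grad_fn, mu, lr=0.1, steps=200):
    """Local gradient descent on F_k(w) + (mu / 2) * ||w - w_global||^2."""
    w = w_global.copy()
    for _ in range(steps):
        # Gradient of the local loss plus the proximal term's gradient, mu * (w - w_global)
        g = grad_fn(w) + mu * (w - w_global)
        w -= lr * g
    return w


# Toy client whose local optimum is at 2.0: F_k(w) = 0.5 * ||w - 2||^2
target = np.array([2.0])


def toy_grad(w):
    return w - target


w_global = np.array([0.0])
for mu in (0.0, 1.0, 10.0):
    w = fedprox_local_update(w_global, toy_grad, mu=mu)
    print(f"mu={mu:>4}: local model ends at {w[0]:.3f}")
```

Larger `mu` pulls the local model toward the global one (for this toy loss it settles at `2 / (1 + mu)`), which limits client drift on heterogeneous data; `mu = 0` recovers plain FedAvg local training.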

### 3. **Differential Privacy**
   - Add noise to model updates
   - Implement DP-SGD in the client
   - Balance privacy and accuracy
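A common way to add noise to model updates, in the spirit of Geyer et al. (2017), listed below, is to bound each client's influence by clipping its update to a maximum L2 norm, then add Gaussian noise scaled to that bound when aggregating. The sketch below shows just those two steps in NumPy. The helper names and parameter values are hypothetical, this is update-level noise rather than per-example DP-SGD, and turning `noise_multiplier` into an actual (epsilon, delta) guarantee requires a privacy accountant, which is out of scope here.

```python
import numpy as np


def clip_update(update, clip_norm):
    """Scale an update down so its L2 norm is at most clip_norm."""
    norm = np.linalg.norm(update)
    return update * min(1.0, clip_norm / max(norm, 1e-12))


def dp_aggregate(updates, clip_norm, noise_multiplier, rng):
    """Average clipped updates and add Gaussian noise calibrated to the clipping bound."""
    clipped = [clip_update(u, clip_norm) for u in updates]
    total = np.sum(clipped, axis=0)
    # Each client can change the sum by at most clip_norm, so the noise std scales with it
    noisy_total = total + rng.normal(0.0, noise_multiplier * clip_norm, size=total.shape)
    return noisy_total / len(updates)


rng = np.random.default_rng(0)
updates = [rng.normal(size=4) * s for s in (0.1, 1.0, 10.0)]  # one very large update
print("norms before clipping:", [round(float(np.linalg.norm(u)), 2) for u in updates])
print("noisy average:", dp_aggregate(updates, clip_norm=1.0, noise_multiplier=1.0, rng=rng))
```

Clipping keeps the one oversized update from dominating the average; raising `noise_multiplier` strengthens privacy but makes the aggregate noisier, which is the privacy/accuracy balance mentioned above.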

### 4. **Communication Efficiency**
   - **Gradient compression**: Reduce communication by compressing updates
   - **Partial aggregation**: Only update some layers
   - **Quantization**: Use lower precision for communication
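As a minimal sketch of quantization, assuming simple symmetric per-tensor 8-bit quantization (one of several possible schemes), each float32 value is mapped to a single signed byte plus one shared scale, cutting the payload by about 4x. The functions are hypothetical illustrations, not part of Flower, and the random vector stands in for a real model update.

```python
import numpy as np


def quantize_int8(x):
    """Symmetric per-tensor quantization of a float array to int8 plus one float scale."""
    max_abs = float(np.max(np.abs(x)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize_int8(q, scale):
    """Map int8 values back to approximate floats on the receiving side."""
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
update = rng.normal(scale=0.01, size=100_000).astype(np.float32)  # stand-in for a model update
q, scale = quantize_int8(update)
restored = dequantize_int8(q, scale)

print(f"float32 payload: {update.nbytes / 1024:.1f} KiB")
print(f"int8 payload:    {q.nbytes / 1024:.1f} KiB (+ one scale value)")
print(f"max abs error:   {np.max(np.abs(update - restored)):.2e} (bound: scale / 2 = {scale / 2:.2e})")
```

Rounding to the nearest level keeps each value's error within half a quantization step; a single outlier inflates `scale` for the whole tensor, which is why per-layer or per-channel scales are common refinements.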

### 5. **Cross-Device vs Cross-Silo**
   - **Cross-device**: Many mobile devices (phones, IoT)
   - **Cross-silo**: Few organizations (hospitals, banks)
   - Different challenges and solutions

### 6. **Secure Aggregation**
   - Prevent server from seeing individual updates
   - Homomorphic encryption
   - Secure multi-party computation
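The core trick behind many secure aggregation protocols is pairwise masking: each pair of clients shares a random mask that one adds and the other subtracts, so every individual upload looks like noise to the server while the masks cancel in the sum. The NumPy sketch below is a toy illustration only. A single shared random generator stands in for the per-pair secrets that real protocols derive through key agreement, it ignores client dropouts, and it uses floating-point masks for readability (real protocols work over integers modulo a large number), so it offers no actual security.

```python
import numpy as np


def masked_updates(updates, seed=0):
    """For each pair (i, j) with i < j, client i adds mask m_ij and client j subtracts it."""
    rng = np.random.default_rng(seed)  # toy stand-in for per-pair shared secrets
    masked = [u.astype(float).copy() for u in updates]
    n = len(updates)
    for i in range(n):
        for j in range(i + 1, n):
            mask = rng.normal(scale=100.0, size=updates[0].shape)
            masked[i] += mask
            masked[j] -= mask
    return masked


updates = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]
masked = masked_updates(updates)
print("masked client 0:", masked[0])  # individually, looks like noise to the server
print("sum of masked:  ", np.sum(masked, axis=0))
print("true sum:       ", np.sum(updates, axis=0))
```

The server can compute the FedAvg sum (up to floating-point rounding in this toy) without seeing any single client's update.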

### 7. **Personalization**
   - Fine-tune global model for each client
   - Meta-learning approaches
   - Per-client model customization

## 11. Resources for Further Learning

### Documentation
- **Flower Docs**: https://flower.ai/docs/
- **Flower Examples**: https://github.com/adap/flower/tree/main/examples

### Papers
- **Original FedAvg**: McMahan et al. (2017) "Communication-Efficient Learning of Deep Networks from Decentralized Data"
- **FedProx**: Li et al. (2020) "Federated Optimization in Heterogeneous Networks"
- **DP-FL**: Geyer et al. (2017) "Differentially Private Federated Learning: A Client Level Perspective"

### Courses
- Stanford CS329S: Machine Learning Systems Design
- DeepLearning.AI short courses on federated learning (built with Flower)

### Communities
- Flower Slack community
- r/MachineLearning on Reddit
- Federated Learning subreddit

## Summary

Congratulations! You've completed the Federated Learning tutorial! 🎉

### What you learned:
1. ✅ Federated learning concepts and motivation
2. ✅ FashionMNIST dataset and data partitioning
3. ✅ Building CNN models for image classification
4. ✅ Implementing Flower clients
5. ✅ Implementing Flower servers and FedAvg
6. ✅ Running complete FL experiments
7. ✅ Analyzing results and comparing with centralized learning
8. ✅ Understanding communication costs

### Your Journey:
```
Introduction → Dataset → Model → Client → Server → Experiment
                                                         ↓
                                              🎯 You are here!
```

### What's Next?
- Experiment with different configurations
- Try non-IID data partitioning
- Implement your own aggregation strategy
- Apply FL to your own datasets
- Contribute to the Flower community!

**Happy Federated Learning!** 🚀

## Final Exercises

**Exercise 1**: Run the FL experiment with different numbers of clients (2, 5, 10). How does it affect convergence?

**Exercise 2**: Change the local epochs parameter. What happens with more/fewer local epochs?

**Exercise 3**: Implement a function to visualize the class distribution across clients for non-IID partitioning.

**Exercise 4**: Research and propose a solution for the "stragglers problem" (when some clients are much slower than others).

**Challenge Project**: Implement federated learning for a different dataset (e.g., MNIST, CIFAR-10) or a different task (e.g., sentiment analysis, time series forecasting).