# Tutorial 4: Building a Flower Server

In this notebook, we'll learn:
1. What is a Flower server?
2. Understanding aggregation strategies (FedAvg)
3. Implementing server logic
4. Global model evaluation
5. Saving the final model

In [None]:
import torch
from flwr.app import ArrayRecord, ConfigRecord, Context, MetricRecord
from flwr.serverapp import Grid, ServerApp
from flwr.serverapp.strategy import FedAvg
from fltutorial.task import Net, load_centralized_dataset, test

print("All imports successful!")

## 1. What is a Flower Server?

The server is the coordinator of federated learning:

### Server Responsibilities:
1. **Initialize** the global model
2. **Select** clients for each round
3. **Send** global model to selected clients
4. **Receive** model updates from clients
5. **Aggregate** updates using a strategy (e.g., FedAvg)
6. **Evaluate** the global model (optional)
7. **Save** the final model

### Federated Learning Flow:
```
Server                      Clients
  │                           │
  ├─ Initialize model         │
  │                           │
  ├─ Round 1 ─────────────────┤
  │  ├─ Select clients         │
  │  ├─ Send model ──────────>│ (Train locally)
  │  │                         │
  │  │<─ Receive updates ──────┤
  │  ├─ Aggregate (FedAvg)     │
  │  └─ Evaluate               │
  │                           │
  ├─ Round 2 ─────────────────┤
  │  ...                       │
  │                           │
  └─ Save final model         │
```
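
The flow above can be sketched as a toy simulation in plain Python. This is an illustration only, not Flower code: the model is a single float, `local_train` is a hypothetical stand-in for client training, and each client's `target` is an assumed local optimum invented for the example.

```python
import random


def local_train(global_w: float, target: float, lr: float = 0.5) -> float:
    """Toy 'training': move the weight part of the way toward the client's optimum."""
    return global_w + lr * (target - global_w)


def run_rounds(clients, num_rounds, fraction_train=1.0, seed=0):
    """Run a toy federated loop and return the global weight after each round."""
    rng = random.Random(seed)
    global_w = 0.0  # Initialize the global model
    history = []
    for _ in range(num_rounds):
        # Select clients for this round
        k = max(1, int(len(clients) * fraction_train))
        selected = rng.sample(clients, k)
        # Send the global model and receive (update, sample count) pairs
        updates = [(local_train(global_w, c["target"]), c["samples"]) for c in selected]
        # Aggregate with a sample-weighted average (FedAvg)
        total = sum(n for _, n in updates)
        global_w = sum(w * n for w, n in updates) / total
        history.append(global_w)
    return history


# Made-up clients: a local optimum and a dataset size for each
clients = [
    {"target": 2.5, "samples": 100},
    {"target": 3.0, "samples": 200},
    {"target": 1.5, "samples": 50},
]
history = run_rounds(clients, num_rounds=20)
print(f"Global weight after 20 rounds: {history[-1]:.4f}")
```

When every client participates in every round, the global weight in this toy setup moves toward the sample-weighted mean of the client targets.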

## 2. Understanding FedAvg Strategy

**FedAvg** (Federated Averaging) is the most common aggregation strategy.

### How FedAvg Works:

1. **Receive Updates**: Get model weights from K clients
   - Client 1: $w_1$, trained on $n_1$ samples
   - Client 2: $w_2$, trained on $n_2$ samples
   - ...
   - Client K: $w_K$, trained on $n_K$ samples

2. **Compute Weighted Average**:
   $$w_{global} = \frac{\sum_{i=1}^{K} n_i \cdot w_i}{\sum_{i=1}^{K} n_i}$$

3. **Update Global Model**: Set $w_{global}$ as the new global model

### Why Weighted?
- Clients with more data contribute more to the global model
- Every training sample counts equally in the average, regardless of which client holds it
- Handles heterogeneous data sizes
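
Real models have many weight arrays rather than one number, and the same formula is applied to each array independently. A minimal NumPy sketch, assuming each client's update is a list of arrays in the same order (the `fedavg` helper and the toy values are hypothetical, not Flower's implementation):

```python
import numpy as np


def fedavg(client_arrays, num_samples):
    """Sample-weighted average of per-client lists of arrays, layer by layer."""
    total = sum(num_samples)
    num_layers = len(client_arrays[0])
    return [
        sum(arrays[layer] * n for arrays, n in zip(client_arrays, num_samples)) / total
        for layer in range(num_layers)
    ]


# Two toy clients, each with a "weight matrix" and a "bias" (made-up values)
client_arrays = [
    [np.array([[1.0, 2.0]]), np.array([0.0])],
    [np.array([[3.0, 4.0]]), np.array([1.0])],
]
num_samples = [100, 300]

global_arrays = fedavg(client_arrays, num_samples)
for layer in global_arrays:
    print(layer)
```

The second client holds three times as much data, so each averaged value sits three quarters of the way from the first client's value to the second's.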

In [None]:
# Simple demonstration of FedAvg
import numpy as np
import matplotlib.pyplot as plt

# Simulate 3 clients with different number of samples
clients = [
    {"id": 0, "weight": 2.5, "samples": 100},
    {"id": 1, "weight": 3.0, "samples": 200},
    {"id": 2, "weight": 1.5, "samples": 50},
]

# Simple average (ignores how much data each client has)
simple_avg = np.mean([c["weight"] for c in clients])

# Weighted average (FedAvg - correct)
total_samples = sum(c["samples"] for c in clients)
weighted_avg = sum(c["weight"] * c["samples"] for c in clients) / total_samples

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart of weights
client_ids = [c["id"] for c in clients]
weights = [c["weight"] for c in clients]
samples = [c["samples"] for c in clients]

ax1.bar(client_ids, weights, color=['lightblue', 'lightgreen', 'lightcoral'])
ax1.axhline(y=simple_avg, color='red', linestyle='--', label=f'Simple avg: {simple_avg:.2f}')
ax1.axhline(y=weighted_avg, color='green', linestyle='--', label=f'Weighted avg (FedAvg): {weighted_avg:.2f}')
ax1.set_xlabel('Client ID')
ax1.set_ylabel('Model Weight (example)')
ax1.set_title('Client Weights')
ax1.legend()
ax1.set_xticks(client_ids)

# Pie chart of sample distribution
ax2.pie(samples, labels=[f'Client {i}\n{s} samples' for i, s in zip(client_ids, samples)], 
        autopct='%1.1f%%', colors=['lightblue', 'lightgreen', 'lightcoral'])
ax2.set_title('Data Distribution')

plt.tight_layout()
plt.show()

print(f"Simple average: {simple_avg:.4f}")
print(f"Weighted average (FedAvg): {weighted_avg:.4f}")
print("\nClient 1 has more data (200 samples), so it influences the global model more!")

## 3. Implementing the Server

Let's build the server step by step:

In [None]:
# Create ServerApp
app = ServerApp()

@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp."""
    
    print("\n" + "="*60)
    print("[SERVER] Starting Federated Learning Server")
    print("="*60)
    
    # Read configuration from context
    fraction_evaluate: float = context.run_config["fraction-evaluate"]
    num_rounds: int = context.run_config["num-server-rounds"]
    lr: float = context.run_config["learning-rate"]
    
    print("[SERVER] Configuration:")
    print(f"  - Number of rounds: {num_rounds}")
    print(f"  - Learning rate: {lr}")
    print(f"  - Fraction evaluate: {fraction_evaluate}")
    
    # Initialize global model
    print("\n[SERVER] Initializing global model...")
    global_model = Net()
    arrays = ArrayRecord(global_model.state_dict())
    print(f"[SERVER] Global model initialized with {sum(p.numel() for p in global_model.parameters()):,} parameters")
    
    # Initialize FedAvg strategy
    print("\n[SERVER] Initializing FedAvg strategy...")
    strategy = FedAvg(fraction_evaluate=fraction_evaluate)
    
    # Start federated learning
    print(f"\n[SERVER] Starting federated learning for {num_rounds} rounds...")
    print("="*60 + "\n")
    
    result = strategy.start(
        grid=grid,
        initial_arrays=arrays,
        train_config=ConfigRecord({"lr": lr}),
        num_rounds=num_rounds,
        # global_evaluate is defined in Section 4; that cell must run before the app starts
        evaluate_fn=global_evaluate,
    )
    
    # Save final model
    print("\n" + "="*60)
    print("[SERVER] Federated learning completed!")
    print("[SERVER] Saving final model to disk...")
    state_dict = result.arrays.to_torch_state_dict()
    torch.save(state_dict, "final_model.pt")
    print("[SERVER] Model saved as 'final_model.pt'")
    print("="*60 + "\n")

print("Server main function defined!")

## 4. Global Model Evaluation

The server can evaluate the global model on a centralized test set:

### Why Global Evaluation?
- **Monitor Progress**: Track how the global model improves over rounds
- **Early Stopping**: Detect when the model has stopped improving, so you know further rounds add little
- **Fair Comparison**: Evaluate on same data across all rounds

### When to Evaluate?
- After each round of aggregation
- On a centralized test set (not distributed to clients)
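
One way to turn per-round metrics into a convergence signal is a patience rule on the test loss. The sketch below is a hypothetical helper, not part of Flower, and it is not wired into the strategy used in this tutorial; `patience` and `min_delta` are assumed names and values.

```python
def should_stop(losses, patience=3, min_delta=1e-3):
    """Return True if the last `patience` rounds did not beat the earlier best loss by `min_delta`."""
    if len(losses) <= patience:
        return False  # Not enough history yet
    best_before = min(losses[:-patience])
    best_recent = min(losses[-patience:])
    return best_recent > best_before - min_delta


# Made-up loss curves for illustration
still_improving = [2.3, 1.8, 1.4, 1.1]
plateaued = [1.0, 0.5, 0.5, 0.5005, 0.4995]

print(should_stop(still_improving))
print(should_stop(plateaued))
```

Because the check needs only the list of losses seen so far, it can be run on whatever metrics the server collects after each round.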

In [None]:
def global_evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
    """Evaluate model on central test data."""
    
    print(f"\n[SERVER] Round {server_round}: Evaluating global model...")
    
    # Load the model and initialize with received weights
    model = Net()
    model.load_state_dict(arrays.to_torch_state_dict())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Load entire test set
    test_dataloader = load_centralized_dataset()
    
    # Evaluate the global model
    test_loss, test_acc = test(model, test_dataloader, device)
    
    print(f"[SERVER] Round {server_round}: Loss={test_loss:.4f}, Accuracy={test_acc*100:.2f}%")
    
    # Return the evaluation metrics
    return MetricRecord({"accuracy": test_acc, "loss": test_loss})

print("Global evaluate function defined!")

## 5. Understanding the Grid

The **Grid** is the interface the `ServerApp` uses to reach the connected clients:

### Grid Responsibilities:
- **Client Discovery**: Report which client nodes are currently connected
- **Client Selection**: Supply the list of connected nodes from which the strategy (here `FedAvg`) picks participants for each round
- **Communication**: Handle message passing between server and clients
- **Fault Tolerance**: Handle client failures gracefully

### Client Selection Strategies:
1. **All clients**: Use all available clients (default)
2. **Random sampling**: Randomly select a fraction of clients
3. **Custom selection**: Implement custom logic based on client properties
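
Random sampling can be sketched in a few lines of plain Python. The `sample_nodes` helper and its `fraction`, `min_nodes`, and `seed` parameters are hypothetical names chosen for this illustration; they do not mirror a specific Flower function.

```python
import random


def sample_nodes(node_ids, fraction, min_nodes=2, seed=None):
    """Pick a random subset of node IDs, keeping at least `min_nodes` when possible."""
    node_ids = list(node_ids)
    rng = random.Random(seed)
    k = max(min_nodes, int(len(node_ids) * fraction))
    k = min(k, len(node_ids))  # Cannot select more nodes than exist
    return rng.sample(node_ids, k)


available = list(range(10))  # Made-up node IDs
selected = sample_nodes(available, fraction=0.5, seed=0)
print(f"Selected {len(selected)} of {len(available)} nodes: {sorted(selected)}")
```

A lower bound such as `min_nodes` keeps a round from running with too few participants when the fraction is small.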

## 6. Server Configuration

The server uses configuration parameters from `context.run_config`:

### Common Parameters:
- `num-server-rounds`: Number of federated learning rounds
- `learning-rate`: Learning rate for client training
- `fraction-evaluate`: Fraction of clients to use for evaluation
- `local-epochs`: Number of local training epochs per round
- `batch-size`: Batch size for training

In a Flower project, these parameters are typically set under `[tool.flwr.app.config]` in `pyproject.toml` and can be overridden on the command line with `flwr run --run-config`.
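
Values read from the run config are worth validating before training starts. A minimal sketch, assuming a plain dict stands in for `context.run_config`; the `read_server_config` helper and the example values are hypothetical:

```python
def read_server_config(run_config: dict) -> dict:
    """Extract and validate the three values the server reads (hypothetical helper)."""
    num_rounds = int(run_config["num-server-rounds"])
    lr = float(run_config["learning-rate"])
    fraction_evaluate = float(run_config["fraction-evaluate"])

    if num_rounds < 1:
        raise ValueError("num-server-rounds must be at least 1")
    if lr <= 0:
        raise ValueError("learning-rate must be positive")
    if not 0.0 <= fraction_evaluate <= 1.0:
        raise ValueError("fraction-evaluate must be between 0 and 1")
    return {"num_rounds": num_rounds, "lr": lr, "fraction_evaluate": fraction_evaluate}


# Example values only; a plain dict stands in for context.run_config
example_run_config = {
    "num-server-rounds": 3,
    "learning-rate": 0.01,
    "fraction-evaluate": 0.5,
    "local-epochs": 1,
    "batch-size": 32,
}
server_config = read_server_config(example_run_config)
print(server_config)
```

Failing early on a bad value is cheaper than discovering it after several rounds of training.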

## 7. Tracking Training Progress

Let's visualize what a typical run looks like, using made-up illustrative numbers rather than results from a real experiment:

In [None]:
# Simulate federated learning progress
import matplotlib.pyplot as plt

# Simulated metrics over 10 rounds
rounds = list(range(1, 11))
loss = [2.3, 1.8, 1.4, 1.1, 0.9, 0.8, 0.7, 0.65, 0.62, 0.60]
accuracy = [10, 35, 52, 65, 72, 77, 80, 82, 83, 84]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot loss
ax1.plot(rounds, loss, marker='o', linewidth=2, markersize=8, color='red')
ax1.set_xlabel('Round', fontsize=12)
ax1.set_ylabel('Test Loss', fontsize=12)
ax1.set_title('Global Model Loss Over Rounds', fontsize=14)
ax1.grid(True, alpha=0.3)

# Plot accuracy
ax2.plot(rounds, accuracy, marker='s', linewidth=2, markersize=8, color='green')
ax2.set_xlabel('Round', fontsize=12)
ax2.set_ylabel('Test Accuracy (%)', fontsize=12)
ax2.set_title('Global Model Accuracy Over Rounds', fontsize=14)
ax2.grid(True, alpha=0.3)
ax2.set_ylim(0, 100)

plt.tight_layout()
plt.show()

print("Notice how:")
print("- Loss decreases over rounds (model is learning)")
print("- Accuracy increases over rounds")
print("- Improvement slows down in later rounds (convergence)")

## 8. The Complete Server Code

The complete server implementation is available in [src/fltutorial/server.py](../src/fltutorial/server.py).

### Key Takeaways:

1. **Coordinator**: Server orchestrates the entire training process
2. **Aggregation**: Uses FedAvg to combine client updates
3. **Evaluation**: Tracks global model performance
4. **Stateful**: Server maintains the global model across rounds
5. **Configurable**: Behavior controlled by configuration parameters

## Summary

In this notebook, we learned:
1. ✅ What a Flower server is and its responsibilities
2. ✅ Understanding the FedAvg aggregation strategy
3. ✅ Implementing server logic with ServerApp
4. ✅ Global model evaluation
5. ✅ Tracking training progress over rounds

**Next Steps**: In Notebook 5, we'll run the complete federated learning experiment with multiple clients!

## Exercises for Students

**Exercise 1**: Why do we use weighted averaging instead of simple averaging in FedAvg?

**Exercise 2**: What would happen if we didn't evaluate the global model? How would we know if training is working?

**Exercise 3**: Research: What are some alternative aggregation strategies besides FedAvg? (Hint: FedProx, FedAdam)

**Exercise 4**: If one client sends corrupted updates, how might it affect the global model? How could we detect this?