# Tutorial 3: Building a Flower Client

In this notebook, we'll learn:
1. What a Flower client is
2. How to implement the train and evaluate functions
3. How the client lifecycle works
4. How message passing works in Flower

In [None]:
import torch
from flwr.app import ArrayRecord, ConfigRecord, Context, Message, MetricRecord, RecordDict
from flwr.clientapp import ClientApp
from fltutorial.task import Net, load_data, train as train_fn, test as test_fn

print("All imports successful!")

## 1. What is a Flower Client?

In Flower's federated learning architecture:

- **Client**: Trains the model on local data and sends updates to the server
- **Server**: Coordinates the training process and aggregates model updates

### Client Responsibilities:
1. **Receive** global model parameters from server
2. **Train** the model on local data
3. **Send** updated model parameters back to server
4. **Evaluate** the model on local test data (optional)

### Communication Flow:
```
Server                          Client
   │                              │
   ├──── Send Model ─────────────>│
   │                              │ (Load model)
   │                              │ (Train locally)
   │                              │
   │<──── Send Updates ───────────┤
   │                              │
   ├──── Request Eval ───────────>│
   │                              │ (Evaluate)
   │<──── Send Metrics ───────────┤
```

## 2. Understanding Flower Messages

Flower uses **Messages** to communicate between server and clients.

### Message Structure:
```python
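# Schematic: the record types a message's content can carry.
# A given message usually holds only some of them, e.g. arrays + config
# in a train request and arrays + metrics in the train reply.
Message(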
    content={
        "arrays": ArrayRecord,    # Model weights
        "config": ConfigRecord,   # Training config (lr, epochs, etc.)
        "metrics": MetricRecord   # Metrics (loss, accuracy, etc.)
    }
)
```

### Key Components:
- **ArrayRecord**: Stores model parameters (weights, biases)
- **ConfigRecord**: Training configuration (learning rate, batch size, etc.)
- **MetricRecord**: Performance metrics (loss, accuracy, number of samples, etc.)
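
To make the three record types concrete, the sketch below mimics one train exchange with plain Python dictionaries. It is a stand-in, not Flower's API: `make_train_request` and `make_train_reply` are hypothetical helpers, and the real record classes also validate types and handle serialization. The keys follow the ones used in this notebook (`arrays`, `config`, `metrics`), and the values are invented.

```python
# Plain-dict stand-ins for Flower's record types (illustration only).

def make_train_request(weights: dict, lr: float) -> dict:
    """Server -> client: model weights plus training configuration."""
    return {
        "arrays": dict(weights),  # plays the role of ArrayRecord
        "config": {"lr": lr},     # plays the role of ConfigRecord
    }


def make_train_reply(weights: dict, train_loss: float, num_examples: int) -> dict:
    """Client -> server: updated weights plus metrics."""
    return {
        "arrays": dict(weights),  # plays the role of ArrayRecord
        "metrics": {              # plays the role of MetricRecord
            "train_loss": train_loss,
            "num-examples": num_examples,
        },
    }


request = make_train_request({"fc.weight": [0.1, -0.2]}, lr=0.01)
reply = make_train_reply({"fc.weight": [0.09, -0.18]}, train_loss=0.42, num_examples=100)

print(sorted(request))  # ['arrays', 'config']
print(sorted(reply))    # ['arrays', 'metrics']
```

Note that the raw training data appears in neither dictionary: only weights, configuration values, and metrics travel between server and client.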

## 3. Implementing the Train Function

The train function:
1. Receives a message with global model weights
2. Loads the model and sets the weights
3. Trains on local data
4. Returns updated weights and metrics

In [None]:
# Create a ClientApp instance
app = ClientApp()

@app.train()
def train(msg: Message, context: Context) -> Message:
    """Train the model on local data."""
    
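    # Validate the records sent by the server (the asserts also help type checkers).
    # With Flower's built-in strategies such as FedAvg, a train request carries
    # "arrays" and "config"; metrics travel only in the reply, so asserting on
    # msg.content["metrics"] here would raise a KeyError.
    assert isinstance(msg.content["arrays"], ArrayRecord)
    assert isinstance(msg.content["config"], ConfigRecord)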
    
    print("\n[CLIENT] Received training request")
    
    # Load the model and initialize it with received weights
    model = Net()
    model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"[CLIENT] Model loaded on {device}")
    
    # Load local data
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    batch_size = context.run_config["batch-size"]
    
    if (
        isinstance(partition_id, int)
        and isinstance(num_partitions, int)
        and isinstance(batch_size, int)
    ):
        trainloader, _ = load_data(partition_id, num_partitions, batch_size)
        print(f"[CLIENT] Loaded data for partition {partition_id}/{num_partitions}")
        print(f"[CLIENT] Training samples: {len(trainloader.dataset)}")
    else:
        raise ValueError(
            "partition_id, num_partitions, and batch_size must be integers"
        )
    
    # Train the model
    local_epochs = context.run_config["local-epochs"]
    learning_rate = msg.content["config"]["lr"]
    print(f"[CLIENT] Training for {local_epochs} epochs with lr={learning_rate}")
    
    train_loss = train_fn(
        model,
        trainloader,
        local_epochs,
        learning_rate,
        device,
    )
    print(f"[CLIENT] Training completed. Loss: {train_loss:.4f}")
    
    # Construct reply message with updated model and metrics
    model_record = ArrayRecord(model.state_dict())
    metrics = {
        "train_loss": train_loss,
        "num-examples": len(trainloader.dataset),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"arrays": model_record, "metrics": metric_record})
    
    print(f"[CLIENT] Sending updated weights (trained on {metrics['num-examples']} examples) back to server\n")
    return Message(content=content, reply_to=msg)

print("Train function defined!")
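
The `isinstance` check in the train function guards against config values of the wrong type. One way to factor that pattern out is a small helper. In the sketch below, `require_int` is a hypothetical name (it is not part of Flower), and ordinary dictionaries stand in for `context.node_config` and `context.run_config`. Unlike the plain `isinstance(value, int)` check, this version also rejects booleans.

```python
from collections.abc import Mapping


def require_int(config: Mapping, key: str) -> int:
    """Return config[key] if it is an int, otherwise raise a descriptive error."""
    value = config[key]
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(value, bool) or not isinstance(value, int):
        raise ValueError(f"{key!r} must be an integer, got {type(value).__name__}")
    return value


# Stand-ins for context.node_config and context.run_config (values invented).
node_config = {"partition-id": 0, "num-partitions": 5}
run_config = {"batch-size": 32, "local-epochs": "1"}

print(require_int(node_config, "partition-id"))  # 0
print(require_int(run_config, "batch-size"))     # 32

try:
    require_int(run_config, "local-epochs")
except ValueError as err:
    print(err)  # 'local-epochs' must be an integer, got str
```

The error message names the offending key, which is easier to debug than a single message covering three values at once.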

## 4. Implementing the Evaluate Function

The evaluate function:
1. Receives a message with model weights
2. Evaluates the model on local test data
3. Returns evaluation metrics

In [None]:
@app.evaluate()
def evaluate(msg: Message, context: Context) -> Message:
    """Evaluate the model on local data."""
    
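    # Validate the records sent by the server (the asserts also help type checkers).
    # With Flower's built-in strategies such as FedAvg, an evaluate request carries
    # "arrays" and "config"; metrics travel only in the reply, so asserting on
    # msg.content["metrics"] here would raise a KeyError.
    assert isinstance(msg.content["arrays"], ArrayRecord)
    assert isinstance(msg.content["config"], ConfigRecord)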
    
    print("\n[CLIENT] Received evaluation request")
    
    # Load the model and initialize it with received weights
    model = Net()
    model.load_state_dict(msg.content["arrays"].to_torch_state_dict())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Load local data
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    batch_size = context.run_config["batch-size"]
    
    if (
        isinstance(partition_id, int)
        and isinstance(num_partitions, int)
        and isinstance(batch_size, int)
    ):
        _, valloader = load_data(partition_id, num_partitions, batch_size)
        print(f"[CLIENT] Evaluating on partition {partition_id}")
        print(f"[CLIENT] Test samples: {len(valloader.dataset)}")
    else:
        raise ValueError(
            "partition_id, num_partitions, and batch_size must be integers"
        )
    
    # Evaluate the model
    eval_loss, eval_acc = test_fn(
        model,
        valloader,
        device,
    )
    print("[CLIENT] Evaluation completed.")
    print(f"[CLIENT] Loss: {eval_loss:.4f}, Accuracy: {eval_acc*100:.2f}%")
    
    # Construct reply message with metrics
    metrics = {
        "eval_loss": eval_loss,
        "eval_acc": eval_acc,
        "num-examples": len(valloader.dataset),
    }
    metric_record = MetricRecord(metrics)
    content = RecordDict({"metrics": metric_record})
    
    print("[CLIENT] Sending metrics back to server\n")
    return Message(content=content, reply_to=msg)

print("Evaluate function defined!")
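
The reply includes `num-examples` so the server can weight each client's metrics by the amount of data behind them. The sketch below shows one way such a weighted average could be computed from several replies. `weighted_average` is a hypothetical helper and the numbers are made up for illustration, so it should not be read as Flower's exact aggregation code.

```python
def weighted_average(replies: list, key: str) -> float:
    """Average replies[i][key], weighting each client by its "num-examples"."""
    total = sum(r["num-examples"] for r in replies)
    if total == 0:
        raise ValueError("no examples reported")
    return sum(r[key] * r["num-examples"] for r in replies) / total


# Made-up metrics from three clients with different amounts of local data.
replies = [
    {"eval_loss": 0.9, "eval_acc": 0.60, "num-examples": 100},
    {"eval_loss": 0.5, "eval_acc": 0.80, "num-examples": 300},
    {"eval_loss": 0.7, "eval_acc": 0.70, "num-examples": 100},
]

print(f"accuracy: {weighted_average(replies, 'eval_acc'):.3f}")   # accuracy: 0.740
print(f"loss:     {weighted_average(replies, 'eval_loss'):.3f}")  # loss:     0.620
```

An unweighted mean of the three accuracies would be 0.70. The weighted value is higher because the client with 300 examples counts three times as much as each of the others.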

## 5. Understanding the Client Lifecycle

During a federated learning session, each client goes through multiple rounds:

### Round Structure:
```
Round 1:
  1. Receive global model from server
  2. Train on local data (update weights)
  3. Send updated weights to server
  4. [Optional] Evaluate model on local test data
  5. Send evaluation metrics to server

Round 2:
  1. Receive NEW global model (aggregated from all clients)
  2. Train on local data
  3. Send updated weights to server
  4. ...
  
... (continue for N rounds)
```

### Key Points:
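- Each round starts from the **latest global model** sent by the server, not from the client's previous local weights
- The client functions in this notebook keep no state between rounds: the model is rebuilt from the received weights on every call
- The server aggregates the updates from the clients that took part in a round before the next round starts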
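
The round structure can be simulated without Flower at all. In the toy sketch below, the "model" is a single number, each client's local training is one gradient step toward the mean of its own data, and the server combines the results with an example-weighted average in the spirit of FedAvg. `local_train` and `run_round` are hypothetical names, and the data is invented for illustration.

```python
def local_train(global_w: float, data: list, lr: float = 0.25) -> float:
    """One gradient step on mean squared error, starting from the global weight."""
    grad = sum(2 * (global_w - x) for x in data) / len(data)
    return global_w - lr * grad


def run_round(global_w: float, client_data: list) -> float:
    """Send global_w to every client, then average the replies by data size."""
    updates = [(local_train(global_w, d), len(d)) for d in client_data]
    total = sum(n for _, n in updates)
    return sum(w * n for w, n in updates) / total


# Two clients: one with 2 examples, one with 6.
client_data = [[1.0, 1.0], [3.0, 3.0, 3.0, 3.0, 3.0, 3.0]]

global_w = 0.0
for rnd in range(1, 4):
    # Clients receive the aggregated model, not their own previous weights.
    global_w = run_round(global_w, client_data)
    print(f"round {rnd}: global weight = {global_w:.4f}")
```

Over the rounds the global weight moves toward 2.5, the mean over all eight examples, because the client holding more data contributes more to the average.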

## 6. Visualizing Weight Updates

Let's visualize how weights change during training:

In [None]:
import matplotlib.pyplot as plt

# Create a simple demonstration
device = torch.device("cpu")

# Initialize model
model = Net()
initial_weights = model.fc3.weight.data.clone()

# Load data for one client (load_data was imported at the top of the notebook)
trainloader, _ = load_data(partition_id=0, num_partitions=5, batch_size=32)

# Train for 1 epoch
train_loss = train_fn(model, trainloader, epochs=1, lr=0.01, device=device)
updated_weights = model.fc3.weight.data.clone()

# Calculate the difference
weight_diff = (updated_weights - initial_weights).abs()

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

im1 = axes[0].imshow(initial_weights.numpy(), cmap='viridis', aspect='auto')
axes[0].set_title('Initial Weights (from server)')
axes[0].set_xlabel('Input features')
axes[0].set_ylabel('Output classes')
plt.colorbar(im1, ax=axes[0])

im2 = axes[1].imshow(updated_weights.numpy(), cmap='viridis', aspect='auto')
axes[1].set_title('Updated Weights (after local training)')
axes[1].set_xlabel('Input features')
axes[1].set_ylabel('Output classes')
plt.colorbar(im2, ax=axes[1])

im3 = axes[2].imshow(weight_diff.numpy(), cmap='hot', aspect='auto')
axes[2].set_title('Absolute Weight Change (updated - initial)')
axes[2].set_xlabel('Input features')
axes[2].set_ylabel('Output classes')
plt.colorbar(im3, ax=axes[2])

plt.tight_layout()
plt.show()

print(f"Average weight change: {weight_diff.mean():.6f}")
print(f"Max weight change: {weight_diff.max():.6f}")

## 7. The Complete Client Code

The complete client implementation is available in [src/fltutorial/client.py](../src/fltutorial/client.py).

### Key Takeaways:

1. **Stateless**: The client functions shown here keep no state between rounds
2. **Message-Based**: All communication happens through Flower Messages
3. **Flexible**: You can customize the training logic, data loading, and reported metrics
4. **Privacy-Preserving**: Only model weights and metrics are shared; the raw data never leaves the client

## Summary

In this notebook, we learned:
1. ✅ What a Flower client is and its responsibilities
2. ✅ How to implement train and evaluate functions
3. ✅ How message passing works in Flower
4. ✅ The client lifecycle and round structure
5. ✅ Visualizing weight updates during training

**Next Steps**: In Notebook 4, we'll learn how to build the server that coordinates all clients!

## Exercises for Students

**Exercise 1**: What information is sent from client to server? What is NOT sent?

**Exercise 2**: Why is it important that clients are stateless (don't keep info between rounds)?

**Exercise 3**: Modify the code to track and plot the training loss over epochs within a single round.

**Exercise 4**: What would happen if one client has much more data than others? How does FedAvg handle this?