# Federated Learning with MNIST - Complete Tutorial

This notebook demonstrates how to implement Federated Learning using the **Flower** framework with the MNIST dataset.

## ⚡ FAST MODE ENABLED
This notebook is configured for **quick testing** (~5-10 minutes on Kaggle).  
Settings: 5 clients, 3 rounds, 3 clients per round.

To run full experiments, modify Cell 3 with:
- `NUM_CLIENTS = 10`
- `NUM_ROUNDS = 10`
- `CLIENTS_PER_ROUND = 5`

## What you'll learn:
1. How to partition a dataset for federated learning (IID vs Non-IID)
2. How to create FL clients and server
3. How to run a federated learning simulation
4. How to visualize data distribution across clients

**Author:** Generated for Federated Learning Tutorial  
**Platform:** Kaggle / Google Colab  
**Framework:** Flower (flwr)

---
## Cell 1: Install Dependencies
Run this cell first to install the required packages.

In [None]:
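# Install Flower (with the simulation extra, which pulls in Ray for start_simulation)
# and Flower Datasets (with the vision extra for image datasets such as MNIST)
!pip install -q "flwr[simulation]" "flwr-datasets[vision]" torch torchvision matplotlib seaborn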

print("Installation complete!")

---
## Cell 2: Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose
from collections import OrderedDict
from typing import List, Tuple, Dict
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Flower imports
import flwr as fl
from flwr.common import Metrics
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, DirichletPartitioner

# Check versions
print(f"PyTorch version: {torch.__version__}")
print(f"Flower version: {fl.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Set device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

---
## Cell 3: Configuration (⚡ FAST MODE)

These settings are optimized for quick testing (~5-10 minutes on Kaggle).

| Parameter | Fast Mode | Full Experiment |
|-----------|-----------|-----------------|
| NUM_CLIENTS | 5 | 10 |
| NUM_ROUNDS | 3 | 10 |
| CLIENTS_PER_ROUND | 3 | 5 |
| BATCH_SIZE | 64 | 32 |
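
To see how the per-round counts in this table relate to the `fraction_fit` value passed to the strategy in Cell 11, here is a small arithmetic sketch. It assumes the strategy picks `max(int(available * fraction), minimum)` clients per round; that rule is an assumption about FedAvg's sampling, and the helper name `clients_sampled` is hypothetical.

```python
def clients_sampled(num_available: int, fraction: float, minimum: int) -> int:
    """Assumed sampling rule: truncate the fractional count, then apply the minimum."""
    return max(int(num_available * fraction), minimum)

# Fast mode: 5 clients, 3 per round
fast = clients_sampled(5, 3 / 5, 3)
# Full experiment: 10 clients, 5 per round
full = clients_sampled(10, 5 / 10, 5)
# Evaluation in fast mode: fraction 0.5 with a minimum of 2 clients
fast_eval = clients_sampled(5, 0.5, 2)

print(f"fit (fast): {fast}, fit (full): {full}, evaluate (fast): {fast_eval}")
```

Under that assumption, the minimum acts as a safety net: even if the fraction rounds down, at least `CLIENTS_PER_ROUND` clients train each round.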

In [None]:
# =============================================================================
# CONFIGURATION - FAST MODE (for quick testing on Kaggle ~5-10 minutes)
# =============================================================================
# To run full experiments later, increase these values:
#   NUM_CLIENTS=10, NUM_ROUNDS=10, CLIENTS_PER_ROUND=5

# Federated Learning Settings (REDUCED FOR FAST TESTING)
NUM_CLIENTS = 5           # Number of clients (simulated devices) [was 10]
NUM_ROUNDS = 3            # Number of federated learning rounds [was 10]
CLIENTS_PER_ROUND = 3     # Number of clients selected per round [was 5]

# Data Partitioning Settings
PARTITION_TYPE = "dirichlet"  # Options: "iid" or "dirichlet"
DIRICHLET_ALPHA = 0.5         # Lower = more non-IID (only used if PARTITION_TYPE="dirichlet")

# Training Settings
BATCH_SIZE = 64           # Larger batch = faster training [was 32]
LOCAL_EPOCHS = 1          # Number of local epochs per client per round
LEARNING_RATE = 0.01

# Random seed for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

print("="*60)
print("⚡ FAST MODE CONFIGURATION (for quick pipeline testing)")
print("="*60)
print(f"  - Number of clients: {NUM_CLIENTS}")
print(f"  - Clients per round: {CLIENTS_PER_ROUND}")
print(f"  - Number of rounds: {NUM_ROUNDS}")
print(f"  - Partition type: {PARTITION_TYPE}")
if PARTITION_TYPE == "dirichlet":
    print(f"  - Dirichlet alpha: {DIRICHLET_ALPHA}")
print(f"  - Local epochs: {LOCAL_EPOCHS}")
print(f"  - Batch size: {BATCH_SIZE}")
print("="*60)
print("Estimated time on Kaggle: 5-10 minutes")

---
## Cell 4: Create Federated Dataset

This is the KEY difference from centralized learning!

Instead of one dataset, we partition MNIST into multiple client datasets:
- **IID**: Each client gets a uniform random sample, so every class appears in roughly equal proportion on every client
- **Non-IID (Dirichlet)**: Each client gets a skewed label distribution (some classes overrepresented, others rare or missing)
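
To build intuition for the Dirichlet option, the sketch below splits a toy label list across clients using only the standard library. It is a simplified illustration, not Flower's `DirichletPartitioner` (it has no minimum partition size and no self-balancing); the function names `dirichlet_sample` and `partition_labels` are hypothetical.

```python
import random
from collections import Counter

def dirichlet_sample(alpha, k, rng):
    """Draw one point from a symmetric Dirichlet(alpha) via normalized Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]

def partition_labels(labels, num_clients, alpha, seed=42):
    """Assign every sample index to one client; per-class client shares follow Dirichlet(alpha)."""
    rng = random.Random(seed)
    clients = [[] for _ in range(num_clients)]
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idx)
        shares = dirichlet_sample(alpha, num_clients, rng)
        # Convert the shares into cumulative cut points over this class's samples
        start, cum = 0, 0.0
        for c, share in enumerate(shares):
            cum += share
            end = len(idx) if c == num_clients - 1 else round(cum * len(idx))
            clients[c].extend(idx[start:end])
            start = end
    return clients

# Toy dataset: 1000 samples, 100 per digit
labels = [i % 10 for i in range(1000)]
for alpha in (0.5, 100.0):
    parts = partition_labels(labels, num_clients=5, alpha=alpha)
    print(f"alpha={alpha}")
    for cid, part in enumerate(parts):
        counts = Counter(labels[i] for i in part)
        print(f"  client {cid}: {len(part):4d} samples, per-digit counts {sorted(counts.items())}")
```

With a small `alpha` each class tends to concentrate on a few clients, while a large `alpha` pushes the split toward the IID case.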

In [None]:
# Create the partitioner based on configuration
if PARTITION_TYPE == "iid":
    partitioner = IidPartitioner(num_partitions=NUM_CLIENTS)
    print(f"Using IID partitioning with {NUM_CLIENTS} partitions")
else:
    partitioner = DirichletPartitioner(
        num_partitions=NUM_CLIENTS,
        partition_by="label",
        alpha=DIRICHLET_ALPHA,
        min_partition_size=100,
        self_balancing=True
    )
    print(f"Using Dirichlet partitioning with alpha={DIRICHLET_ALPHA}")

# Create FederatedDataset - this downloads MNIST automatically!
print("\nDownloading and partitioning MNIST dataset...")
fds = FederatedDataset(
    dataset="mnist",
    partitioners={"train": partitioner}
)

print("Dataset ready!")

---
## Cell 5: Visualize Data Distribution Across Clients

This visualization shows how data is distributed across clients.
- **IID**: All clients have similar distributions
- **Non-IID**: Each client has a different, skewed distribution
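
If you want a single number to go with the plots, one option is the total variation distance between a client's label distribution and the uniform distribution. This is an illustrative metric chosen for this sketch, not something the notebook computes; the counts below are made up and the helper names are hypothetical.

```python
def normalize(counts):
    """Turn raw label counts into proportions that sum to 1."""
    total = sum(counts)
    return [c / total for c in counts]

def total_variation(p, q):
    """Total variation distance: 0 means identical distributions, 1 means disjoint."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))

# Made-up label counts over three classes for two clients
balanced_client = [100, 100, 100]
skewed_client = [280, 15, 5]
uniform = [1 / 3] * 3

tv_balanced = total_variation(normalize(balanced_client), uniform)
tv_skewed = total_variation(normalize(skewed_client), uniform)
print(f"balanced client: {tv_balanced:.2f}, skewed client: {tv_skewed:.2f}")
```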

In [None]:
def visualize_data_distribution(fds, num_clients):
    """Visualize the label distribution across all clients."""
    
    # Collect label counts for each client
    client_label_counts = []
    
    print("Data distribution across clients:")
    print("=" * 60)
    
    for client_id in range(num_clients):
        partition = fds.load_partition(client_id, "train")
        # Read only the label column; iterating over samples would also decode every image
        labels = list(partition["label"])
        
        # Count labels
        label_counts = np.zeros(10)
        for label in labels:
            label_counts[label] += 1
        
        client_label_counts.append(label_counts)
        
        # Print summary
        present_labels = [i for i in range(10) if label_counts[i] > 0]
        print(f"Client {client_id:2d}: {len(labels):5d} samples | "
              f"Classes present: {present_labels}")
    
    # Create visualization
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Heatmap
    data_matrix = np.array(client_label_counts)
    sns.heatmap(data_matrix, annot=True, fmt='.0f', cmap='YlOrRd',
                xticklabels=[str(i) for i in range(10)],
                yticklabels=[f'Client {i}' for i in range(num_clients)],
                ax=axes[0])
    axes[0].set_xlabel('Digit Label')
    axes[0].set_ylabel('Client')
    axes[0].set_title(f'Label Distribution Across Clients\n({PARTITION_TYPE.upper()} Partitioning)')
    
    # Stacked bar chart
    data_normalized = data_matrix / data_matrix.sum(axis=1, keepdims=True)
    bottom = np.zeros(num_clients)
    colors = plt.cm.tab10(np.linspace(0, 1, 10))
    
    for label in range(10):
        axes[1].bar(range(num_clients), data_normalized[:, label], 
                   bottom=bottom, label=str(label), color=colors[label])
        bottom += data_normalized[:, label]
    
    axes[1].set_xlabel('Client ID')
    axes[1].set_ylabel('Proportion')
    axes[1].set_title('Normalized Label Distribution')
    axes[1].legend(title='Digit', bbox_to_anchor=(1.05, 1), loc='upper left')
    axes[1].set_xticks(range(num_clients))
    
    plt.tight_layout()
    plt.show()
    
    return client_label_counts

# Visualize the distribution
client_distributions = visualize_data_distribution(fds, NUM_CLIENTS)

---
## Cell 6: Define the Neural Network Model

A simple CNN for MNIST classification. This same model architecture is used by all clients.
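
The layer sizes used in this cell (two 3x3 convolutions with padding 1, each followed by 2x2 max-pooling, then 128-unit and 10-unit linear layers) can be checked by hand. The sketch below redoes that arithmetic in plain Python; the helper names are hypothetical and the numbers are taken from the `MNISTNet` definition.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_params(c_in, c_out, kernel):
    """Weights plus one bias per output channel."""
    return c_in * c_out * kernel * kernel + c_out

def linear_params(n_in, n_out):
    """Weights plus one bias per output unit."""
    return n_in * n_out + n_out

side = 28
side = conv_out(side, kernel=3, padding=1) // 2  # conv1 keeps 28, pooling halves it to 14
side = conv_out(side, kernel=3, padding=1) // 2  # conv2 keeps 14, pooling halves it to 7
flat = 64 * side * side                          # input size of the first linear layer

total = (
    conv_params(1, 32, 3)
    + conv_params(32, 64, 3)
    + linear_params(flat, 128)
    + linear_params(128, 10)
)
print(f"feature map: {side}x{side}, flattened: {flat}, parameters: {total:,}")
```

The total should match the "Total parameters" line printed by the code cell.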

In [None]:
class MNISTNet(nn.Module):
    """Simple CNN for MNIST classification."""
    
    def __init__(self):
        super(MNISTNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        # Ensure input is 4D: (batch, channels, height, width)
        if x.dim() == 3:
            x = x.unsqueeze(1)
        
        x = self.pool(F.relu(self.conv1(x)))  # 28x28 -> 14x14
        x = self.pool(F.relu(self.conv2(x)))  # 14x14 -> 7x7
        x = self.dropout1(x)
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

# Test the model
model = MNISTNet().to(DEVICE)
print(f"Model architecture:")
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

---
## Cell 7: Data Loading Functions

Functions to load and preprocess data for each client.
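
The loaders in this cell use an 80/20 split, and the training loader drops its last incomplete batch. The sketch below shows how that affects the batch counts you will see printed. It assumes the validation size is rounded up, as in scikit-learn-style splits; the helper name `loader_batches` is hypothetical.

```python
import math

def loader_batches(partition_size, batch_size, val_percent=20):
    """Return (train_batches, val_batches) for one client's partition."""
    n_val = math.ceil(partition_size * val_percent / 100)  # assumed rounding rule
    n_train = partition_size - n_val
    train_batches = n_train // batch_size          # drop_last=True discards the remainder
    val_batches = math.ceil(n_val / batch_size)    # the last partial batch is kept
    return train_batches, val_batches

# An evenly sized IID partition: 60,000 MNIST training images over 5 clients
print(loader_batches(12000, batch_size=64))
# A small partition where drop_last discards part of the training data
print(loader_batches(1000, batch_size=64))
```

Note that a partition whose training split is smaller than one batch would yield zero training batches, which is why the partitioner's minimum partition size matters.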

In [None]:
def apply_transforms(batch):
    """Apply transformations to a batch of data."""
    transforms = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,))  # MNIST mean and std
    ])
    batch["image"] = [transforms(img) for img in batch["image"]]
    return batch


def load_client_data(partition_id: int) -> Tuple[DataLoader, DataLoader]:
    """
    Load training and validation data for a specific client.
    
    Args:
        partition_id: The client ID (0 to NUM_CLIENTS-1)
    
    Returns:
        trainloader, valloader: DataLoader objects for training and validation
    """
    # Load the partition for this client
    partition = fds.load_partition(partition_id, "train")
    
    # Split into train (80%) and validation (20%)
    partition = partition.train_test_split(test_size=0.2, seed=SEED)
    
    # Apply transforms
    partition = partition.with_transform(apply_transforms)
    
    # Create DataLoaders
    trainloader = DataLoader(
        partition["train"], 
        batch_size=BATCH_SIZE, 
        shuffle=True,
        drop_last=True
    )
    valloader = DataLoader(
        partition["test"], 
        batch_size=BATCH_SIZE
    )
    
    return trainloader, valloader


# Test loading data for client 0
print("Testing data loading for Client 0...")
train_loader, val_loader = load_client_data(0)
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")

# Check a batch
batch = next(iter(train_loader))
print(f"  Batch image shape: {batch['image'].shape}")
print(f"  Batch label shape: {batch['label'].shape}")

---
## Cell 8: Training and Evaluation Functions

In [None]:
def train(model: nn.Module, trainloader: DataLoader, epochs: int) -> float:
    """
    Train the model on the local data.
    
    Args:
        model: The neural network model
        trainloader: DataLoader for training data
        epochs: Number of local epochs
    
    Returns:
        Average training loss
    """
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    
    total_loss = 0.0
    num_batches = 0
    
    for epoch in range(epochs):
        for batch in trainloader:
            images = batch["image"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
    
    # Guard against an empty loader (e.g. fewer samples than one batch with drop_last=True)
    return total_loss / max(num_batches, 1)


def evaluate(model: nn.Module, valloader: DataLoader) -> Tuple[float, float]:
    """
    Evaluate the model on validation data.
    
    Args:
        model: The neural network model
        valloader: DataLoader for validation data
    
    Returns:
        loss: Average validation loss
        accuracy: Validation accuracy
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    
    total_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in valloader:
            images = batch["image"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = correct / total
    avg_loss = total_loss / total
    
    return avg_loss, accuracy


print("Training and evaluation functions defined!")

---
## Cell 9: Define the Flower Client

The Flower Client is the core component that runs on each "device".
It handles:
1. Receiving model parameters from the server
2. Training on local data
3. Sending updated parameters back to the server
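
The parameter exchange in steps 1 and 3 boils down to stripping the names from the model's state on the way out and re-attaching them on the way in. The sketch below mimics that round trip with plain Python lists standing in for tensors; the layer names and values are made up for illustration.

```python
from collections import OrderedDict

def get_parameters(state):
    """Send values only, in the state's own key order."""
    return [list(v) for v in state.values()]

def set_parameters(state, parameters):
    """Re-attach names by zipping received values with the local key order."""
    return OrderedDict((k, list(v)) for k, v in zip(state.keys(), parameters))

# Made-up local state with two named entries
local_state = OrderedDict([("fc.weight", [0.0, 0.0]), ("fc.bias", [0.0])])
# Values as they might arrive from the server (no names attached)
from_server = [[0.5, -0.5], [0.1]]

updated = set_parameters(local_state, from_server)
print(updated)
print(get_parameters(updated) == from_server)
```

This only works because every client and the server share the same architecture, so the key order is identical everywhere.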

In [None]:
class FlowerClient(fl.client.NumPyClient):
    """
    Flower client for federated learning.
    
    Each client:
    1. Receives global model parameters from server
    2. Trains on its local (private) data
    3. Sends updated parameters back to server
    """
    
    def __init__(self, partition_id: int):
        self.partition_id = partition_id
        self.model = MNISTNet().to(DEVICE)
        self.trainloader, self.valloader = load_client_data(partition_id)
    
    def get_parameters(self, config):
        """Return the current model parameters."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
    
    def set_parameters(self, parameters):
        """Set model parameters received from the server."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)
    
    def fit(self, parameters, config):
        """
        Train the model on local data.
        
        This is called by the server during each round.
        """
        # Update local model with global parameters
        self.set_parameters(parameters)
        
        # Train on local data
        avg_loss = train(self.model, self.trainloader, LOCAL_EPOCHS)
        
        # Return updated parameters and training info
        return (
            self.get_parameters(config={}),
            len(self.trainloader.dataset),
            {"loss": avg_loss, "partition_id": self.partition_id}
        )
    
    def evaluate(self, parameters, config):
        """
        Evaluate the model on local validation data.
        """
        self.set_parameters(parameters)
        loss, accuracy = evaluate(self.model, self.valloader)
        
        return (
            loss,
            len(self.valloader.dataset),
            {"accuracy": accuracy, "partition_id": self.partition_id}
        )


def client_fn(cid: str) -> fl.client.Client:
    """Create a Flower client for the given client ID."""
    # Recent Flower versions expect a Client, so convert the NumPyClient explicitly
    return FlowerClient(partition_id=int(cid)).to_client()


print("Flower client defined!")

---
## Cell 10: Define Metrics Aggregation

Functions to aggregate metrics (like accuracy) from all clients.
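
A quick worked example shows why the aggregation weights each client by its number of examples instead of taking a plain mean. The numbers below are made up, and `weighted_accuracy` is a hypothetical stand-alone version of the same idea.

```python
def weighted_accuracy(results):
    """results: list of (num_examples, accuracy) pairs, one per client."""
    total = sum(n for n, _ in results)
    return sum(n * acc for n, acc in results) / total

# Made-up evaluation results from three clients of very different sizes
results = [(100, 0.90), (300, 0.80), (600, 0.95)]

weighted = weighted_accuracy(results)
unweighted = sum(acc for _, acc in results) / len(results)
print(f"weighted: {weighted:.4f}, unweighted: {unweighted:.4f}")
```

The weighted value equals the accuracy you would get by pooling all validation samples, which the plain mean does not.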

In [None]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """
    Aggregate accuracy metrics from all clients using weighted average.
    
    Args:
        metrics: List of (num_examples, metrics_dict) tuples from each client
    
    Returns:
        Aggregated metrics dictionary
    """
    # Multiply accuracy by number of examples
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    
    # Weighted average
    return {"accuracy": sum(accuracies) / sum(examples)}


# Store metrics for plotting
round_metrics = {
    "round": [],
    "accuracy": [],
    "loss": []
}

print("Metrics aggregation defined!")

---
## Cell 11: Run Federated Learning Simulation ⏱️

**Estimated time: 3-5 minutes** (Fast Mode)

This is where the magic happens! The simulation:
1. Initializes a global model on the server
2. For each round:
   - Server sends model to selected clients
   - Clients train on their local data
   - Clients send updated models back
   - Server aggregates updates (FedAvg)
3. Stops after `NUM_ROUNDS` rounds (a fixed budget, not a convergence check)
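
The round structure above can be sketched in a few lines of plain Python. This is a toy model, not what Flower runs internally: "training" is a single step toward a made-up per-client optimum, weights are flat lists, and all function names are hypothetical.

```python
import random

def local_train(weights, target, lr=0.5):
    """Toy stand-in for local training: one step toward the client's own optimum."""
    return [w + lr * (t - w) for w, t in zip(weights, target)]

def fedavg(updates):
    """Average client weights, weighting each client by its number of examples."""
    total = sum(n for n, _ in updates)
    dim = len(updates[0][1])
    return [sum(n * w[i] for n, w in updates) / total for i in range(dim)]

def run_rounds(clients, num_rounds, per_round, seed=42):
    """clients: list of (num_examples, local_optimum) pairs."""
    rng = random.Random(seed)
    global_w = [0.0, 0.0]                            # server initializes the global model
    for rnd in range(1, num_rounds + 1):
        selected = rng.sample(clients, per_round)    # server selects clients
        updates = [(n, local_train(global_w, target)) for n, target in selected]
        global_w = fedavg(updates)                   # server aggregates the updates
        print(f"round {rnd}: global weights = {global_w}")
    return global_w

# Made-up clients: (number of examples, weights that are optimal for their local data)
clients = [(100, [1.0, 0.0]), (300, [0.0, 1.0]), (200, [1.0, 1.0])]
final = run_rounds(clients, num_rounds=3, per_round=2)
```

Because clients pull toward different optima, the global model settles on a compromise, which is the same tension that non-IID data creates in the real simulation.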

In [None]:
import time

# Define the federated learning strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=CLIENTS_PER_ROUND / NUM_CLIENTS,  # Fraction of clients for training
    fraction_evaluate=0.5,  # Fraction of clients for evaluation
    min_fit_clients=CLIENTS_PER_ROUND,  # Minimum clients for training
    min_evaluate_clients=2,  # Minimum clients for evaluation
    min_available_clients=NUM_CLIENTS,  # Minimum available clients
    evaluate_metrics_aggregation_fn=weighted_average,  # Aggregate evaluation metrics
)

print("="*60)
print("⚡ STARTING FEDERATED LEARNING SIMULATION (FAST MODE)")
print("="*60)
print(f"\nConfiguration:")
print(f"  - Total clients: {NUM_CLIENTS}")
print(f"  - Clients per round: {CLIENTS_PER_ROUND}")
print(f"  - Total rounds: {NUM_ROUNDS}")
print(f"  - Partition type: {PARTITION_TYPE}")
print(f"  - Local epochs: {LOCAL_EPOCHS}")
print("\n" + "="*60)
print("Starting... (this will take ~3-5 minutes)")
print("="*60 + "\n")

# Track time
start_time = time.time()

# Run the simulation
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
    client_resources={"num_cpus": 1, "num_gpus": 0.1 if torch.cuda.is_available() else 0.0},
)

# Calculate elapsed time
elapsed_time = time.time() - start_time
minutes = int(elapsed_time // 60)
seconds = int(elapsed_time % 60)

print("\n" + "="*60)
print("✅ SIMULATION COMPLETE!")
print(f"⏱️  Total time: {minutes}m {seconds}s")
print("="*60)

---
## Cell 12: Visualize Training Results

In [None]:
# Extract metrics from history
print("\nTraining History:")
print("-" * 40)

# Start with empty lists so the checks further down work even if nothing was recorded
rounds, accuracies, losses = [], [], []

# Get distributed (federated) evaluation accuracy
if history.metrics_distributed and "accuracy" in history.metrics_distributed:
    rounds = [r for r, _ in history.metrics_distributed["accuracy"]]
    accuracies = [acc for _, acc in history.metrics_distributed["accuracy"]]
    
    for r, acc in zip(rounds, accuracies):
        print(f"Round {r}: Accuracy = {acc:.4f} ({acc*100:.2f}%)")

# Get losses
if history.losses_distributed:
    losses = [loss for _, loss in history.losses_distributed]

# Create visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Plot accuracy
if history.metrics_distributed and "accuracy" in history.metrics_distributed:
    axes[0].plot(rounds, accuracies, 'b-o', linewidth=2, markersize=8)
    axes[0].set_xlabel('Round')
    axes[0].set_ylabel('Accuracy')
    axes[0].set_title('Federated Learning Accuracy')
    axes[0].grid(True, alpha=0.3)
    axes[0].set_ylim([0, 1])

# Plot loss
if history.losses_distributed:
    loss_rounds = [r for r, _ in history.losses_distributed]
    axes[1].plot(loss_rounds, losses, 'r-o', linewidth=2, markersize=8)
    axes[1].set_xlabel('Round')
    axes[1].set_ylabel('Loss')
    axes[1].set_title('Federated Learning Loss')
    axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print final results
if accuracies:
    print(f"\n" + "="*40)
    print(f"FINAL RESULTS")
    print(f"="*40)
    print(f"Final Accuracy: {accuracies[-1]*100:.2f}%")
    print(f"Best Accuracy:  {max(accuracies)*100:.2f}%")

---
## Cell 13: Compare IID vs Non-IID (Optional)

Run this cell to see how different partitioning strategies affect training.
This demonstrates the challenge of non-IID data in federated learning.

In [None]:
def run_experiment(partition_type: str, alpha: float = 0.5, num_rounds: int = 3):
    """
    Run a federated learning experiment with specified partitioning.
    
    Args:
        partition_type: "iid" or "dirichlet"
        alpha: Dirichlet concentration parameter (only used if partition_type="dirichlet")
        num_rounds: Number of FL rounds
    
    Returns:
        History object with training metrics
    """
    global fds
    
    # Create partitioner
    if partition_type == "iid":
        partitioner = IidPartitioner(num_partitions=NUM_CLIENTS)
    else:
        partitioner = DirichletPartitioner(
            num_partitions=NUM_CLIENTS,
            partition_by="label",
            alpha=alpha,
            min_partition_size=100,
            self_balancing=True
        )
    
    # Create new FederatedDataset
    fds = FederatedDataset(
        dataset="mnist",
        partitioners={"train": partitioner}
    )
    
    # Run simulation
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=CLIENTS_PER_ROUND / NUM_CLIENTS,
        fraction_evaluate=0.5,
        min_fit_clients=CLIENTS_PER_ROUND,
        min_evaluate_clients=2,
        min_available_clients=NUM_CLIENTS,
        evaluate_metrics_aggregation_fn=weighted_average,
    )
    
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=NUM_CLIENTS,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
        client_resources={"num_cpus": 1, "num_gpus": 0.1 if torch.cuda.is_available() else 0.0},
    )
    
    return history


# =============================================================================
# OPTIONAL: Run comparison experiments
# =============================================================================
# These are commented out by default to save time.
# Uncomment to compare IID vs Non-IID performance (~10 min total)

# print("Running IID experiment...")
# history_iid = run_experiment("iid", num_rounds=3)
# iid_acc = [acc for _, acc in history_iid.metrics_distributed["accuracy"]]
# print(f"IID Final Accuracy: {iid_acc[-1]*100:.2f}%")

# print("\nRunning Non-IID (alpha=0.1) experiment...")
# history_noniid = run_experiment("dirichlet", alpha=0.1, num_rounds=3)
# noniid_acc = [acc for _, acc in history_noniid.metrics_distributed["accuracy"]]
# print(f"Non-IID Final Accuracy: {noniid_acc[-1]*100:.2f}%")

print("Comparison function defined!")
print("Uncomment the experiment lines above to run IID vs Non-IID comparison.")

---
## Cell 14: Examine a Single Client's Training (Educational)

This cell shows what happens on a single client during federated learning.

In [None]:
print("Examining Client 0's local training...")
print("="*50)

# Create a client
client = FlowerClient(partition_id=0)

# Show client's data distribution
partition = fds.load_partition(0, "train")
labels = [sample["label"] for sample in partition]
unique, counts = np.unique(labels, return_counts=True)

print(f"\nClient 0 data distribution:")
for label, count in zip(unique, counts):
    bar = "█" * (count // 50)
    print(f"  Digit {label}: {count:4d} samples {bar}")

# Train for a few epochs and show progress
print(f"\nTraining Client 0 for {LOCAL_EPOCHS} epoch(s)...")
model = MNISTNet().to(DEVICE)
trainloader, valloader = load_client_data(0)

# Evaluate before training
loss_before, acc_before = evaluate(model, valloader)
print(f"  Before training: Loss={loss_before:.4f}, Accuracy={acc_before*100:.2f}%")

# Train
train_loss = train(model, trainloader, LOCAL_EPOCHS)

# Evaluate after training
loss_after, acc_after = evaluate(model, valloader)
print(f"  After training:  Loss={loss_after:.4f}, Accuracy={acc_after*100:.2f}%")
print(f"  Change in accuracy: {(acc_after - acc_before)*100:+.2f} percentage points")

---
## Cell 15: Summary and Key Takeaways

### What We Learned:

1. **Data Partitioning** is the key difference between centralized and federated learning
   - IID: Data uniformly distributed across clients
   - Non-IID: Data heterogeneously distributed (realistic scenario)

2. **Flower Framework** makes FL easy:
   - `FederatedDataset`: Automatic dataset partitioning
   - `FlowerClient`: Define client behavior
   - `start_simulation`: Run FL on a single machine

3. **FedAvg Algorithm**:
   - Server sends global model to clients
   - Clients train locally
   - Server averages client models (weighted by data size)

4. **Non-IID Challenge**:
   - Lower Dirichlet α = more heterogeneous data
   - Non-IID data can slow convergence and reduce accuracy

### 🚀 Scale Up for Full Experiments

Now that you've verified the pipeline works, try these settings for better results:

```python
# In Cell 3, change to:
NUM_CLIENTS = 10
NUM_ROUNDS = 10
CLIENTS_PER_ROUND = 5
BATCH_SIZE = 32
```

Rough expectations with full settings (exact numbers vary with the seed and library versions):
- **IID**: ~95%+ accuracy
- **Non-IID (α=0.5)**: ~90%+ accuracy
- **Non-IID (α=0.1)**: ~80-85% accuracy

### Next Steps:
- Try different `DIRICHLET_ALPHA` values (0.1, 0.5, 1.0, 10.0)
- Increase `NUM_ROUNDS` for better convergence
- Experiment with different models and optimizers
- Explore other FL algorithms (FedProx, SCAFFOLD, etc.)

In [None]:
print("\n" + "="*60)
print("   FEDERATED LEARNING TUTORIAL COMPLETE!")
print("="*60)
print("\nYou've learned how to:")
print("  ✓ Partition datasets for federated learning")
print("  ✓ Create FL clients using Flower")
print("  ✓ Run FL simulations")
print("  ✓ Visualize data distribution and training progress")
print("\nHappy Federated Learning! 🌸")