<a href="https://colab.research.google.com/github/victorkobani/Federated-Deep-Reinforcement-Learning/blob/main/Lunar_Lander_Federated_PPO_Script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

INSTALL DEPENDENCIES AND IMPORTS

In [None]:
# Install all necessary libraries
# System packages first. NOTE(review): swig/cmake are presumably required to
# build the Box2D bindings pulled in by gymnasium[box2d], and ffmpeg by the
# video tooling (moviepy) -- confirm against the target runtime image.
!apt-get update && apt-get install -y swig cmake ffmpeg
# Remove the legacy `gym` package; this notebook imports `gymnasium as gym`.
!pip uninstall -y gym
# `-U` upgrades to the newest releases and no versions are pinned, so results
# may differ between runs as these libraries change.
!pip install -U flwr[simulation] stable-baselines3[extra] gymnasium[box2d] moviepy

import flwr as fl
import gymnasium as gym
import numpy as np
import torch
from collections import OrderedDict
from typing import Dict, List, Tuple

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
import matplotlib.pyplot as plt

DEFINE HYPERPARAMETERS

In [None]:
# --- Federation Hyperparameters ---
# Number of simulated clients, number of federated rounds, and how many
# environment steps each client trains for per round.
NUM_CLIENTS = 5
NUM_ROUNDS = 25
LOCAL_TIMESTEPS_PER_ROUND = 8192

# --- Environment ---
# Gymnasium environment id used by every client.
ENV_NAME = "LunarLander-v3"

# --- SB3 PPO Hyperparameters ---
# Keyword arguments forwarded to the PPO constructor. verbose=0 keeps SB3 from
# printing its own training logs inside each client.
ppo_params = dict(
    policy="MlpPolicy",
    verbose=0,
)

THE FLOWER CLIENT (wrapping a Stable-Baselines3 Agent)

In [None]:
class SB3FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping a Stable-Baselines3 PPO agent.

    Each instance creates its own ``ENV_NAME`` environment and a PPO model
    built from ``ppo_params``. Only the policy network's ``state_dict`` is
    exchanged with the server, as a list of NumPy arrays in ``state_dict``
    order.
    """

    # Number of evaluation episodes used when the server-provided config does
    # not contain an "n_eval_episodes" entry.
    DEFAULT_EVAL_EPISODES = 10

    def __init__(self):
        """Initialize the Flower client with a local environment and PPO model."""
        self.env = gym.make(ENV_NAME)
        self.model = PPO(env=self.env, **ppo_params)

    def get_parameters(self, config: Dict) -> List[np.ndarray]:
        """Return the policy network's parameters as NumPy arrays.

        Args:
            config: Server-provided configuration; not used here.

        Returns:
            One array per ``state_dict`` entry, in ``state_dict`` order.
        """
        return [val.cpu().numpy() for val in self.model.policy.state_dict().values()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        """Load ``parameters`` into the local PPO model's policy network.

        Args:
            parameters: Arrays in the same order as produced by
                ``get_parameters``.

        Raises:
            ValueError: If the number of arrays does not match the number of
                ``state_dict`` entries.
        """
        keys = list(self.model.policy.state_dict().keys())
        # zip() would silently drop surplus arrays, so check the count up front.
        if len(keys) != len(parameters):
            raise ValueError(
                f"Expected {len(keys)} parameter arrays, got {len(parameters)}"
            )
        state_dict = OrderedDict(
            (key, torch.tensor(value)) for key, value in zip(keys, parameters)
        )
        self.model.policy.load_state_dict(state_dict, strict=True)

    def fit(self, parameters: List[np.ndarray], config: Dict) -> Tuple[List[np.ndarray], int, Dict]:
        """Train the local PPO model starting from the given global parameters.

        Args:
            parameters: Global policy parameters to start from.
            config: Round configuration. ``"local_timesteps"`` sets how many
                timesteps to train for; ``LOCAL_TIMESTEPS_PER_ROUND`` is used
                when the key is absent.

        Returns:
            The updated policy parameters, the number of timesteps trained in
            this call (used by the server as the aggregation weight), and an
            empty metrics dict.
        """
        self.set_parameters(parameters)
        local_timesteps = int(config.get("local_timesteps", LOCAL_TIMESTEPS_PER_ROUND))

        # reset_num_timesteps=False keeps the model's step counter running, so
        # the counter is cumulative over this object's lifetime. Report only
        # the steps added by this call so that a reused client is not
        # over-weighted during aggregation.
        steps_before = int(self.model.num_timesteps)
        self.model.learn(
            total_timesteps=local_timesteps,
            reset_num_timesteps=False,
        )
        steps_trained = int(self.model.num_timesteps) - steps_before
        return self.get_parameters(config={}), steps_trained, {}

    def evaluate(self, parameters: List[np.ndarray], config: Dict) -> Tuple[float, int, Dict]:
        """Evaluate the given global parameters on the local environment.

        Args:
            parameters: Global policy parameters to evaluate.
            config: Evaluation configuration. ``"n_eval_episodes"`` overrides
                the number of episodes (default ``DEFAULT_EVAL_EPISODES``).

        Returns:
            The negated mean episode reward as the loss, the number of
            evaluation episodes, and ``{"avg_reward": mean_reward}``.
        """
        self.set_parameters(parameters)
        n_eval_episodes = int(config.get("n_eval_episodes", self.DEFAULT_EVAL_EPISODES))
        mean_reward, _ = evaluate_policy(
            self.model, self.env, n_eval_episodes=n_eval_episodes, deterministic=True
        )
        mean_reward = float(mean_reward)
        # The first return value is reported to the server as a loss (lower is
        # better), so the reward is negated; the raw reward travels in metrics.
        return -mean_reward, n_eval_episodes, {"avg_reward": mean_reward}

FEDERATED SIMULATION SETUP

In [None]:
def client_fn(cid: str) -> SB3FlowerClient:
    """Build the client for one simulated participant.

    ``cid`` is accepted but not used: every client is constructed identically.
    """
    client = SB3FlowerClient()
    return client

def fit_config(server_round: int) -> Dict:
    """Build the training configuration sent to clients for a given round.

    The dict carries the current round number and the per-round local
    timestep budget (``LOCAL_TIMESTEPS_PER_ROUND``).
    """
    round_config: Dict = {}
    round_config["server_round"] = server_round
    round_config["local_timesteps"] = LOCAL_TIMESTEPS_PER_ROUND
    return round_config

def evaluate_metrics_aggregation_fn(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Aggregate per-client evaluation metrics into a single ``avg_reward``.

    Args:
        metrics: One ``(num_examples, metrics_dict)`` pair per client. Every
            ``metrics_dict`` must contain an ``"avg_reward"`` entry.

    Returns:
        ``{"avg_reward": mean}`` where ``mean`` is the average reward weighted
        by each client's ``num_examples``, or an empty dict when ``metrics``
        is empty. If all example counts are zero, the plain (unweighted) mean
        is used instead.
    """
    if not metrics:
        return {}

    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples > 0:
        # Weight each client's reward by how many episodes it was measured on.
        weighted_sum = sum(
            num_examples * client_metrics["avg_reward"]
            for num_examples, client_metrics in metrics
        )
        mean_reward = weighted_sum / total_examples
    else:
        # No usable weights: fall back to the unweighted mean.
        mean_reward = sum(
            client_metrics["avg_reward"] for _, client_metrics in metrics
        ) / len(metrics)

    # The round number is not passed to this callback, so it is not printed.
    print(f"Aggregated evaluation results | Average Reward: {mean_reward:.2f}")
    return {"avg_reward": mean_reward}

# FedAvg strategy in which all NUM_CLIENTS clients must be available and take
# part in both training and evaluation every round.
strategy = fl.server.strategy.FedAvg(
    min_available_clients=NUM_CLIENTS,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    # Per-round training config and aggregation of client evaluation metrics.
    on_fit_config_fn=fit_config,
    evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
)

# Server runs for NUM_ROUNDS federated rounds.
_server_config = fl.server.ServerConfig(num_rounds=NUM_ROUNDS)

# Launch the simulation; each client is allotted one CPU.
print(f"--- Starting Federated PPO Simulation using Stable-Baselines3 ---")
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=_server_config,
    strategy=strategy,
    client_resources={"num_cpus": 1},
)
print("\n--- Federated Learning Final Results ---")

PLOT THE RESULTS

In [None]:
# Rewards are reported through client-side evaluation, so they are read from
# history.metrics_distributed rather than history.metrics_centralized.
distributed_metrics = history.metrics_distributed
if not (distributed_metrics and "avg_reward" in distributed_metrics):
    print("No evaluation metrics were recorded.")
else:
    # Split the (round, metric) pairs into parallel x / y sequences.
    rounds, rewards = zip(*distributed_metrics["avg_reward"])

    plt.figure(figsize=(12, 6))
    plt.plot(rounds, rewards)
    # Dashed reference line at the reward level labelled as the success threshold.
    plt.axhline(y=200, color='r', linestyle='--', label='Success Threshold (200)')
    plt.title("Federated PPO with Stable-Baselines3 - Average Reward")
    plt.xlabel("Federated Round")
    plt.ylabel("Average Reward")
    plt.legend()
    plt.grid(True)
    plt.show()