<a href="https://colab.research.google.com/github/victorkobani/Federated-Deep-Reinforcement-Learning/blob/main/Lunar_Lander_Federated_DQN_Script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

INSTALL DEPENDENCIES AND IMPORTS

In [None]:
!apt-get update && apt-get install -y swig cmake ffmpeg
!pip uninstall -y gym
!pip install -U flwr[simulation] stable-baselines3[extra] gymnasium[box2d] moviepy

from collections import OrderedDict
from typing import Dict, List, Tuple

import flwr as fl
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

DEFINE HYPERPARAMETERS

In [None]:
# --- Environment ---
# Gymnasium environment id (Box2D-based; needs the gymnasium[box2d] extra installed above).
ENV_NAME = "LunarLander-v3"

# --- SB3 DQN hyperparameters ---
# Passed as keyword arguments to DQN(env=..., **DQN_PARAMS) when each client is built.
DQN_PARAMS = dict(
    policy="MlpPolicy",
    learning_rate=1e-4,
    buffer_size=50_000,
    learning_starts=1_000,
    batch_size=64,
    gamma=0.99,
    train_freq=(4, "step"),  # train every 4 environment steps...
    gradient_steps=1,  # ...taking one gradient step each time
    target_update_interval=250,
    exploration_fraction=0.12,
    exploration_final_eps=0.1,
    verbose=0,  # keep SB3 quiet so only the federated logs are printed
)

# --- Federation hyperparameters ---
NUM_CLIENTS = 5  # number of simulated Flower clients
NUM_ROUNDS = 25  # number of federated server rounds
LOCAL_TIMESTEPS_PER_ROUND = 4_096  # environment steps each client trains per round

THE FLOWER CLIENT (wraps a Stable-Baselines3 DQN agent)

In [None]:
class DQNFlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a Stable-Baselines3 DQN agent.

    Weights are exchanged with the server as the list of tensors in
    ``self.model.policy.state_dict()``, in that dict's iteration order. Every entry
    of the policy's state dict is sent, so for SB3's DQN policy this includes both
    the online and the target Q-network.
    """

    def __init__(self) -> None:
        # Environment the DQN model trains on.
        self.env = gym.make(ENV_NAME)
        # Dedicated evaluation environment: evaluate_policy resets and steps the env
        # it is given, so reusing self.env would interfere with the training env.
        # The Monitor wrapper lets evaluate_policy read true episode returns (and
        # avoids its "not wrapped with a Monitor" warning).
        self.eval_env = Monitor(gym.make(ENV_NAME))
        self.model = DQN(env=self.env, **DQN_PARAMS)

    def get_parameters(self, config: Dict) -> List[np.ndarray]:
        """Return the local policy weights as NumPy arrays, in state_dict order.

        ``config`` is unused; it is required by the NumPyClient interface.
        """
        return [tensor.cpu().numpy() for tensor in self.model.policy.state_dict().values()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        """Load weights (ordered as produced by ``get_parameters``) into the policy.

        Raises:
            ValueError: if the number of arrays differs from the number of
                state_dict entries (``zip`` would otherwise silently drop extras).
            RuntimeError: from ``load_state_dict(strict=True)`` on a shape mismatch.
        """
        keys = list(self.model.policy.state_dict().keys())
        if len(parameters) != len(keys):
            raise ValueError(
                f"Expected {len(keys)} parameter arrays, got {len(parameters)}"
            )
        state_dict = OrderedDict(
            (key, torch.tensor(value)) for key, value in zip(keys, parameters)
        )
        self.model.policy.load_state_dict(state_dict, strict=True)

    def fit(self, parameters: List[np.ndarray], config: Dict) -> Tuple[List[np.ndarray], int, Dict]:
        """Load the global weights, train locally, and return the updated weights.

        Args:
            parameters: global weights sent by the server.
            config: must contain ``"local_timesteps"`` (steps to train this round).

        Returns:
            ``(weights, num_examples, metrics)``. ``num_examples`` is the number of
            environment steps taken during *this* call, which FedAvg uses as this
            client's aggregation weight. The model's cumulative ``num_timesteps``
            would over-weight clients that have been training for longer.
        """
        self.set_parameters(parameters)
        steps_before = self.model.num_timesteps
        # reset_num_timesteps=False continues the step counter rather than restarting it.
        self.model.learn(
            total_timesteps=int(config["local_timesteps"]),
            reset_num_timesteps=False,
        )
        steps_this_round = int(self.model.num_timesteps - steps_before)
        return self.get_parameters(config={}), steps_this_round, {}

    def evaluate(self, parameters: List[np.ndarray], config: Dict) -> Tuple[float, int, Dict]:
        """Evaluate the global weights on the dedicated evaluation environment.

        Args:
            parameters: global weights sent by the server.
            config: may contain ``"n_eval_episodes"`` (defaults to 10).

        Returns:
            ``(loss, num_examples, metrics)`` where loss is the negated mean episode
            reward, num_examples is the number of evaluation episodes, and metrics
            holds ``{"avg_reward": mean_reward}``.
        """
        self.set_parameters(parameters)
        n_eval_episodes = int(config.get("n_eval_episodes", 10))
        mean_reward, _ = evaluate_policy(
            self.model, self.eval_env, n_eval_episodes=n_eval_episodes, deterministic=True
        )
        # Flower reports a "loss" where lower is better, so the reward is negated.
        return -float(mean_reward), n_eval_episodes, {"avg_reward": float(mean_reward)}

FEDERATED SIMULATION SETUP

In [None]:
def client_fn(cid: str) -> DQNFlowerClient:
    """Construct a fresh DQN client for the simulated client identified by ``cid``.

    The id is not used: every simulated client is built identically.
    """
    new_client = DQNFlowerClient()
    return new_client

def fit_config(server_round: int) -> Dict:
    """Build the training config sent to every client for a given server round.

    Args:
        server_round: index of the current federated round.

    Returns:
        Dict with ``"server_round"`` and ``"local_timesteps"``. The latter is
        the number of environment steps each client should train for.
    """
    round_config = dict(server_round=server_round)
    round_config["local_timesteps"] = LOCAL_TIMESTEPS_PER_ROUND
    return round_config

def evaluate_metrics_aggregation_fn(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Combine per-client evaluation metrics into a single ``"avg_reward"``.

    Each client's reward is weighted by its reported example count (its number of
    evaluation episodes). This matches how FedAvg weights the evaluation losses.
    If every count is zero, the plain unweighted mean is used instead.

    Args:
        metrics: one ``(num_examples, metrics_dict)`` pair per client; every
            ``metrics_dict`` must contain ``"avg_reward"``.

    Returns:
        ``{"avg_reward": mean}``, or ``{}`` when no client reported metrics.
    """
    if not metrics:
        return {}
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples > 0:
        mean_reward = (
            sum(num_examples * m["avg_reward"] for num_examples, m in metrics)
            / total_examples
        )
    else:
        # No usable weights: fall back to treating every client equally.
        mean_reward = sum(m["avg_reward"] for _, m in metrics) / len(metrics)
    print(f"Aggregated evaluation results | Average Reward: {mean_reward:.2f}")
    return {"avg_reward": mean_reward}

# FedAvg with full participation: every one of the NUM_CLIENTS clients is asked to
# fit and to evaluate in each round, and a round waits until all are available.
strategy = fl.server.strategy.FedAvg(
    min_available_clients=NUM_CLIENTS,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    on_fit_config_fn=fit_config,
    evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
)

# Launch the simulation; `history` records the per-round losses and metrics.
# NOTE(review): recent Flower releases mark start_simulation as deprecated in
# favour of run_simulation; confirm against the installed flwr version.
print("--- Starting Federated DQN Simulation using Stable-Baselines3 ---")
history = fl.simulation.start_simulation(
    num_clients=NUM_CLIENTS,
    client_fn=client_fn,
    strategy=strategy,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    client_resources={"num_cpus": 1},  # one CPU per simulated client; no GPU requested
)

PLOT THE RESULTS

In [None]:
print("\n--- Federated Learning Final Results ---")

# history.metrics_distributed maps a metric name to a list of (round, value) pairs.
# "avg_reward" is the key returned by evaluate_metrics_aggregation_fn. Using .get()
# with a default guards against a missing dict, a missing key, or an empty series;
# an empty series would otherwise crash the zip(*...) unpacking below.
reward_history = (history.metrics_distributed or {}).get("avg_reward", [])
if reward_history:
    rounds, rewards = zip(*reward_history)

    plt.figure(figsize=(12, 6))
    plt.title("Federated DQN with Stable-Baselines3 - Average Reward")
    plt.xlabel("Federated Round")
    plt.ylabel("Average Reward")
    # Label the reward curve too, so the legend describes both plotted lines.
    plt.plot(rounds, rewards, marker="o", label="Average Reward")
    plt.axhline(y=200, color='r', linestyle='--', label='Success Threshold (200)')
    plt.legend()
    plt.grid(True)
    plt.show()
else:
    print("No distributed evaluation metrics were recorded.")