<a href="https://colab.research.google.com/github/victorkobani/Federated-Deep-Reinforcement-Learning/blob/main/Lunar_Lander_Federated_A2C_Script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

INSTALL DEPENDENCIES AND IMPORTS

In [None]:
# Refresh the apt index and install the system packages swig, cmake and ffmpeg
# (presumably build/runtime requirements of the pip packages below - confirm).
!apt-get update && apt-get install -y swig cmake ffmpeg
# Remove any preinstalled `gym` package; this notebook imports `gymnasium` instead.
!pip uninstall -y gym
# Install/upgrade Flower (simulation extra), Stable-Baselines3 (extra), Gymnasium (box2d extra) and moviepy.
!pip install -U flwr[simulation] stable-baselines3[extra] gymnasium[box2d] moviepy

from collections import OrderedDict
from typing import Dict, List, Tuple

import flwr as fl
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

DEFINE HYPERPARAMETERS

In [None]:
# --- Environment ---
# Gymnasium environment id passed to `gym.make` by every client.
ENV_NAME = "LunarLander-v3"

# --- SB3 A2C Hyperparameters ---
# Keyword arguments unpacked into the A2C constructor.
A2C_PARAMS = dict(
    policy="MlpPolicy",
    n_steps=5,
    gamma=0.99,
    gae_lambda=1.0,
    ent_coef=0.0,
    vf_coef=0.5,
    max_grad_norm=0.5,
    use_rms_prop=True,
    # Callable that ignores its argument, i.e. a constant learning rate.
    learning_rate=lambda _progress: 0.0007,
    verbose=0,  # 0 keeps SB3 from printing its own logs
)

# --- Federation Hyperparameters ---
NUM_CLIENTS = 5  # number of simulated clients
NUM_ROUNDS = 30  # A2C may need more rounds
# Timesteps each client trains per round; A2C is on-policy, so it learns
# from data collected in the current round.
LOCAL_TIMESTEPS_PER_ROUND = 5000

THE FLOWER CLIENT (wrapping a Stable-Baselines3 A2C Agent)

In [None]:
class A2CFlowerClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping a local Stable-Baselines3 A2C agent.

    Each instance builds its own training environment from ``ENV_NAME`` and
    an A2C model from ``A2C_PARAMS``. Only the policy network's
    ``state_dict`` tensors are exchanged with the server, in ``state_dict``
    key order.
    """

    # Episodes run per `evaluate` call; also reported to the server as the
    # number of evaluation examples so both values always agree.
    N_EVAL_EPISODES = 10

    def __init__(self):
        self.env = gym.make(ENV_NAME)
        # Instantiate the A2C model with the module-level hyperparameters.
        self.model = A2C(env=self.env, **A2C_PARAMS)
        # Evaluation environment, created on first use so that clients that
        # only ever train do not allocate a second environment.
        self._eval_env = None

    def get_parameters(self, config: Dict) -> List[np.ndarray]:
        """Return the policy network weights as a list of NumPy arrays.

        Args:
            config: Server-provided configuration (unused).

        Returns:
            One array per ``state_dict`` entry, in ``state_dict`` order.
        """
        return [
            tensor.cpu().numpy()
            for tensor in self.model.policy.state_dict().values()
        ]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        """Load ``parameters`` into the policy network.

        Args:
            parameters: Arrays in the same order as produced by
                ``get_parameters``.

        Raises:
            ValueError: If the number of arrays does not match the number of
                ``state_dict`` entries (``zip`` would otherwise silently drop
                the surplus).
        """
        keys = list(self.model.policy.state_dict().keys())
        if len(parameters) != len(keys):
            raise ValueError(
                f"Expected {len(keys)} parameter arrays, got {len(parameters)}"
            )
        state_dict = OrderedDict(
            (key, torch.tensor(array)) for key, array in zip(keys, parameters)
        )
        self.model.policy.load_state_dict(state_dict, strict=True)

    def fit(self, parameters: List[np.ndarray], config: Dict) -> Tuple[List[np.ndarray], int, Dict]:
        """Train the local A2C model starting from the given global weights.

        Args:
            parameters: Global policy weights to start from.
            config: Round configuration; ``"local_timesteps"`` sets the
                training budget and falls back to
                ``LOCAL_TIMESTEPS_PER_ROUND`` when absent.

        Returns:
            The updated weights, the number of timesteps trained during this
            call, and an empty metrics dict.
        """
        self.set_parameters(parameters)
        local_timesteps = int(
            config.get("local_timesteps", LOCAL_TIMESTEPS_PER_ROUND)
        )
        # `reset_num_timesteps=False` keeps the model's step counter running
        # across calls, so measure the delta instead of reporting the
        # cumulative total as this round's example count.
        steps_before = self.model.num_timesteps
        self.model.learn(
            total_timesteps=local_timesteps,
            reset_num_timesteps=False,
        )
        steps_trained = int(self.model.num_timesteps - steps_before)
        return self.get_parameters(config={}), steps_trained, {}

    def evaluate(self, parameters: List[np.ndarray], config: Dict) -> Tuple[float, int, Dict]:
        """Evaluate the given global weights on a dedicated environment.

        Args:
            parameters: Global policy weights to evaluate.
            config: Server-provided configuration (unused).

        Returns:
            The negated mean episode reward (as the loss), the number of
            evaluation episodes, and ``{"avg_reward": mean_reward}``.
        """
        self.set_parameters(parameters)
        if self._eval_env is None:
            # A separate Monitor-wrapped env keeps evaluation from resetting
            # the env the model trains on and gives `evaluate_policy` the
            # episode statistics it looks for.
            self._eval_env = Monitor(gym.make(ENV_NAME))
        mean_reward, _ = evaluate_policy(
            self.model,
            self._eval_env,
            n_eval_episodes=self.N_EVAL_EPISODES,
            deterministic=True,
        )
        reward = float(mean_reward)
        # Flower's evaluate contract reports a loss (lower is better), so the
        # reward is negated; the raw reward is kept in the metrics dict.
        return -reward, self.N_EVAL_EPISODES, {"avg_reward": reward}

FEDERATED SIMULATION SETUP

In [None]:
def client_fn(cid: str) -> fl.client.Client:
    """Create a fresh client for the simulation.

    Args:
        cid: Client id supplied by the simulation engine (unused; every
            client is constructed identically).

    Returns:
        A new ``A2CFlowerClient`` converted with ``to_client()``, because
        Flower expects ``client_fn`` to return a ``Client`` rather than a
        ``NumPyClient``.
    """
    return A2CFlowerClient().to_client()

def fit_config(server_round: int) -> Dict:
    """Build the training configuration sent to clients for one round.

    Args:
        server_round: Index of the current federated round.

    Returns:
        A dict with the round index under ``"server_round"`` and the
        per-round training budget under ``"local_timesteps"``.
    """
    round_config: Dict = {}
    round_config["server_round"] = server_round
    round_config["local_timesteps"] = LOCAL_TIMESTEPS_PER_ROUND
    return round_config

def evaluate_metrics_aggregation_fn(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Aggregate per-client evaluation rewards into one weighted mean.

    Args:
        metrics: ``(num_examples, client_metrics)`` pairs; each
            ``client_metrics`` must contain ``"avg_reward"``.

    Returns:
        ``{"avg_reward": mean}`` where each client's reward is weighted by
        its ``num_examples``, or ``{}`` when ``metrics`` is empty.
    """
    if not metrics:
        return {}
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples > 0:
        # Weight by example count so clients that evaluated on more episodes
        # contribute proportionally; equal counts reduce to a plain mean.
        weighted_sum = sum(
            num_examples * client_metrics["avg_reward"]
            for num_examples, client_metrics in metrics
        )
        mean_reward = weighted_sum / total_examples
    else:
        # No usable weights: fall back to an unweighted mean rather than
        # dividing by zero.
        mean_reward = sum(m["avg_reward"] for _, m in metrics) / len(metrics)
    print(f"Aggregated evaluation results | Average Reward: {mean_reward:.2f}")
    return {"avg_reward": mean_reward}

# FedAvg strategy: fractions of 1.0 together with minimums equal to
# NUM_CLIENTS request every client for both fitting and evaluation.
strategy = fl.server.strategy.FedAvg(
    min_available_clients=NUM_CLIENTS,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
    on_fit_config_fn=fit_config,
)

# Run the simulation; the returned history is kept at module level.
print("--- Starting Federated A2C Simulation using Stable-Baselines3 ---")
server_config = fl.server.ServerConfig(num_rounds=NUM_ROUNDS)
history = fl.simulation.start_simulation(
    strategy=strategy,
    config=server_config,
    num_clients=NUM_CLIENTS,
    client_fn=client_fn,
    client_resources={"num_cpus": 1},  # Assuming CPU-only training
)

PLOT THE RESULTS

In [None]:
print("\n--- Federated Learning Final Results ---")

distributed_metrics = history.metrics_distributed
if not distributed_metrics or "avg_reward" not in distributed_metrics:
    print("No distributed evaluation metrics were recorded.")
else:
    # Entries are (round, value) tuples; split them into parallel sequences.
    rounds, rewards = zip(*distributed_metrics["avg_reward"])

    plt.figure(figsize=(12, 6))
    plt.plot(rounds, rewards)
    # Dashed reference line at a reward of 200, labelled as the success threshold.
    plt.axhline(y=200, color='r', linestyle='--', label='Success Threshold (200)')
    plt.title("Federated A2C with Stable-Baselines3 - Average Reward")
    plt.xlabel("Federated Round")
    plt.ylabel("Average Reward")
    plt.grid(True)
    plt.legend()
    plt.show()