<a href="https://colab.research.google.com/github/victorkobani/Federated-Deep-Reinforcement-Learning/blob/main/Lunar_Lander_Standalone_DQN_Script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

INSTALL DEPENDENCIES

In [None]:
# Installing required dependencies
!apt-get update
!apt-get install -y swig cmake ffmpeg freeglut3-dev xvfb

# Installing more dependencies
!pip install "gymnasium[box2d]"
!pip install "stable-baselines3[extra]>=2.7.0"
!pip install "huggingface_sb3>=3.0"
!pip install "moviepy>=2.2.1"

IMPORTS

In [None]:
import gymnasium as gym
import os
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
from IPython.display import HTML, display
from base64 import b64encode

CREATE GYM ENVIRONMENT AND INSTANTIATE AGENT

In [None]:
# SB3 constructs the LunarLander-v3 environment itself from the registered ID.
# A fixed seed makes network initialisation, exploration and environment resets
# repeatable across "Restart & Run All" (up to non-deterministic GPU kernels).
SEED = 42

model = DQN(
    "MlpPolicy",
    "LunarLander-v3",
    verbose=1,
    exploration_final_eps=0.1,   # final epsilon of the epsilon-greedy schedule
    target_update_interval=250,  # env steps between target-network updates
    seed=SEED,
)

EVALUATE UNTRAINED AGENT

In [None]:
# Baseline: evaluate the agent before any training. With deterministic=True the
# agent always takes the arg-max of its untrained Q-network, so this measures an
# arbitrary fixed policy rather than a uniformly random one.
# Wrapping in Monitor lets evaluate_policy read the true episode returns and
# lengths (and avoids SB3's "not wrapped with a Monitor wrapper" warning).
# This env is also used for the final evaluation after training.
eval_env = Monitor(gym.make("LunarLander-v3"))
mean_reward, std_reward = evaluate_policy(
    model,
    eval_env,
    n_eval_episodes=20,
    deterministic=True,
)
print(f"Untrained agent mean_reward={mean_reward:.2f} +/- {std_reward}")

SETUP CALLBACK AND TRAIN THE AGENT

In [None]:
# EvalCallback periodically scores the agent on a separate Monitor-wrapped env,
# saves the best-scoring model under best_model/ and writes the evaluation
# history to results/evaluations.npz (plotted further below).
log_dir = "/tmp/gym_logs/"
os.makedirs(log_dir, exist_ok=True)

TOTAL_TIMESTEPS = 1_000_000  # training budget in environment steps
EVAL_FREQ = 5_000            # evaluate every EVAL_FREQ training steps (single env)

eval_env_monitored = Monitor(gym.make("LunarLander-v3"))

eval_callback = EvalCallback(
    eval_env_monitored,
    best_model_save_path=os.path.join(log_dir, 'best_model'),
    log_path=os.path.join(log_dir, 'results'),
    eval_freq=EVAL_FREQ,
    deterministic=True,
    render=False,
)

print("\n--- Starting Training ---")
try:
    model.learn(
        total_timesteps=TOTAL_TIMESTEPS,
        log_interval=400,  # per SB3 docs, counted in episodes for off-policy algorithms like DQN
        progress_bar=True,
        callback=eval_callback,
    )
finally:
    # The callback's evaluation env is only needed during training.
    eval_env_monitored.close()

model.save("dqn_lunar_v3")
del model  # delete the trained model to demonstrate loading it back from disk

LOAD AND EVALUATE TRAINED AGENT

In [None]:
print("\n--- Loading and Evaluating Final Model ---")
# Restore the weights saved at the end of training (model.save added the .zip
# suffix) and score them on the same evaluation env used for the baseline.
final_model_path = "dqn_lunar_v3"
model = DQN.load(final_model_path)
mean_reward, std_reward = evaluate_policy(
    model,
    eval_env,
    n_eval_episodes=20,
    deterministic=True,
)
print(f"Final trained agent mean_reward={mean_reward:.2f} +/- {std_reward}")

PLOT THE RESULTS

In [None]:
print("\n--- Plotting Training Progress ---")

# The log_path for EvalCallback ('results') is a FOLDER.
# The actual data is in a file named 'evaluations.npz' inside that folder.
results_path = os.path.join(log_dir, "results")
log_file = os.path.join(results_path, "evaluations.npz")

if os.path.exists(log_file):
    print(f"Loading log file from: {log_file}")
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the arrays have been read out.
    with np.load(log_file) as data:
        timesteps = data['timesteps']
        # 'results' holds one row of episode returns per evaluation, i.e. shape
        # (n_evaluations, n_eval_episodes). Indexing [:, 0] would plot only the
        # first episode of each evaluation, so aggregate across the row instead.
        episode_rewards = data['results']
    mean_rewards = episode_rewards.mean(axis=1)
    std_rewards = episode_rewards.std(axis=1)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.set_title("Standalone DQN Training Performance on LunarLander-v3")
    ax.set_xlabel("Training Timesteps")
    ax.set_ylabel("Average Reward")
    ax.plot(timesteps, mean_rewards, label='Mean evaluation reward')
    ax.fill_between(
        timesteps,
        mean_rewards - std_rewards,
        mean_rewards + std_rewards,
        alpha=0.2,
        label='+/- 1 std across evaluation episodes',
    )
    ax.axhline(y=200, color='r', linestyle='--', label='Success Threshold (200)')
    ax.legend()
    ax.grid(True)
    plt.show()
else:
    print(f"Log file not found at {log_file}. Cannot plot results.")

RECORD VIDEO OF TRAINED AGENT

In [None]:
print("\n--- Recording Video ---")

env_id = "LunarLander-v3"
video_folder = "logs/videos/"
video_length = 6000  # number of environment steps to record
name_prefix = f"dqn-agent-{env_id}"
os.makedirs(video_folder, exist_ok=True)

vec_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")])

# Record a single clip that starts at step 0.
vec_env = VecVideoRecorder(
    vec_env,
    video_folder,
    record_video_trigger=lambda step: step == 0,
    video_length=video_length,
    name_prefix=name_prefix,
)

# Reset through the recorder and keep its observation, so the first action is
# predicted from the state the recorded episode actually starts in.
obs = vec_env.reset()
try:
    for _ in range(video_length + 1):
        action, _state = model.predict(obs, deterministic=True)
        obs, _, _, _ = vec_env.step(action)
finally:
    # Closing the recorder finishes the recording and releases the env.
    vec_env.close()

# Build the expected file name from the same settings passed to the recorder
# instead of hard-coding the folder, prefix and length a second time.
mp4_path = os.path.join(video_folder, f"{name_prefix}-step-0-to-step-{video_length}.mp4")
if os.path.exists(mp4_path):
    with open(mp4_path, 'rb') as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    display(HTML("""
    <video width=400 controls>
          <source src="%s" type="video/mp4">
    </video>
    """ % data_url))
else:
    print(f"Video file not found at {mp4_path}")