# Using Stable Baselines3 to solve Lunar Lander

References:
- [Lunar Lander - Gymnasium Documentation](https://gymnasium.farama.org/environments/box2d/lunar_lander/)
- [Examples — Stable Baselines3](https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#basic-usage-training-saving-loading)

In [2]:
import gymnasium as gym
import numpy as np

from stable_baselines3 import DQN

In [None]:
# Build a DQN agent with an MLP policy. The environment is given by its
# registered id, so Stable Baselines3 instantiates LunarLander-v3 itself.
model = DQN(
    policy="MlpPolicy",
    env="LunarLander-v3",
    # final value of the random-action probability used for exploration
    exploration_final_eps=0.1,
    # number of environment steps between target network updates
    target_update_interval=250,
    verbose=1,
)

In [4]:
from stable_baselines3.common.evaluation import evaluate_policy

In [5]:
from stable_baselines3.common.evaluation import evaluate_policy

In [None]:
from stable_baselines3.common.monitor import Monitor

# Separate env for evaluation. It is wrapped in a Monitor so evaluate_policy
# reads the unmodified episode rewards and lengths; with a bare env,
# Stable Baselines3 warns that the reported numbers may be altered by wrappers.
eval_env = Monitor(gym.make("LunarLander-v3"))

# Untrained agent, before training: gives a baseline to compare against later
mean_reward, std_reward = evaluate_policy(
    model,
    eval_env,
    n_eval_episodes=10,
    deterministic=True,
)

print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

In [None]:
# Train the agent for 100k timesteps, displaying a progress bar meanwhile
model.learn(total_timesteps=100_000, progress_bar=True)
# Save the trained agent under the name "dqn_lunar"
model.save("dqn_lunar")

In [None]:
# Evaluate the agent again now that it has been trained, over 10 episodes
# with deterministic actions
mean_reward, std_reward = evaluate_policy(
    model=model,
    env=eval_env,
    n_eval_episodes=10,
    deterministic=True,
)

print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

In [9]:
from typing import Any, Callable, Optional, Sequence, SupportsFloat

import gymnasium as gym
from gymnasium.core import ActType, ObsType, RenderFrame

from matplotlib import pyplot as plt
from matplotlib.animation import ArtistAnimation, TimedAnimation

def plt_animation(frames: Sequence[RenderFrame], fps: int) -> TimedAnimation:
    """Build a Pyplot animation out of a sequence of Gymnasium RGB renderings.

    Every frame is drawn as one image artist on a single subplot whose axes
    are hidden; the resulting artists are played back once, with a delay of
    ``int(1000 / fps)`` milliseconds between consecutive frames.

    Args:
        frames: The rendered frames, in playback order.
        fps: Playback speed in frames per second.

    Returns:
        The animation object wrapping all the frames.
    """
    figure, axes = plt.subplots()
    axes.set_axis_off()
    artists = []
    for index, frame in enumerate(frames):
        if index == 0:
            # the first frame is drawn as a regular image
            image = axes.imshow(frame)
        else:
            image = axes.imshow(frame, animated=True)
        artists.append([image])
    # prevent pyplot from also showing its default window
    plt.close(figure)
    frame_delay_ms = int(1000 / fps)
    return ArtistAnimation(figure, artists, interval=frame_delay_ms, repeat=False, blit=True)

def run_episode(env: gym.Env, agent: Optional[Callable[[ObsType], ActType]]=None) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
    """Run one full episode of the environment and return its last step output.

    The environment is reset, then stepped with the actions chosen by
    ``agent`` until a step reports ``terminated`` or ``truncated``.

    Args:
        env: The environment to play; it is reset at the start of the call.
        agent: Callable mapping an observation to an action. When ``None``,
            actions are sampled at random from ``env.action_space``.

    Returns:
        The ``(observation, reward, terminated, truncated, info)`` tuple
        produced by the final ``env.step`` call of the episode.
    """
    observation, info = env.reset()
    if agent is None:
        # Identity check on purpose: `== None` would invoke the agent's own
        # __eq__, which may answer truthily and discard a valid agent.
        def choose_action(_observation):
            return env.action_space.sample()
    else:
        choose_action = agent
    while True:
        action = choose_action(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            return observation, reward, terminated, truncated, info

In [None]:
from IPython import display

# Environment used for the replay: rendered as RGB arrays, with episode
# statistics recorded and every rendered frame collected (frames are kept,
# not popped, when they are read back).
env = gym.make("LunarLander-v3", render_mode='rgb_array')
env = gym.wrappers.RecordEpisodeStatistics(env)
env = gym.wrappers.RenderCollection(env, pop_frames=False)

def agent_from_model(model: DQN) -> Callable[[ObsType], ActType]:
    """Adapt a model into a plain ``observation -> action`` callable.

    The returned function asks the model for a deterministic prediction and
    returns only the first element of what ``model.predict`` gives back.
    """
    def act(obs: ObsType) -> ActType:
        prediction = model.predict(obs, deterministic=True)
        return prediction[0]
    return act

# Play a single episode with the trained model choosing the actions
observation, reward, terminated, truncated, info = run_episode(
    env,
    agent_from_model(model),
)

# Heuristic: count the episode as a landing when it terminated (rather than
# being truncated) and the reward of its final step is at least 100
landed = terminated and reward >= 100
# Episode return as stored under info['episode']['r'], or None when absent
total = info.get('episode', {}).get('r')
print("{}, with a total reward of {}".format('Landed' if landed else 'Crashed', total))
# Kept as the cell's last expression so the notebook displays the video
display.HTML(plt_animation(env.render(), fps=30).to_html5_video())