<a href="https://colab.research.google.com/github/rumeshsmrr/reinforcement-learning-lunarlander/blob/main/a2c_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# System tools used by the Python packages below
# (swig to build the Box2D bindings, ffmpeg for video encoding).
!apt-get install -y swig ffmpeg
# %pip (not !pip) installs into the environment of the running kernel;
# the requirement strings are quoted so the shell never treats [..] as a glob.
%pip install "gymnasium[box2d]" "stable-baselines3[extra]"


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
Suggested packages:
  swig-doc swig-examples swig4.0-examples swig4.0-doc
The following NEW packages will be installed:
  swig swig4.0
0 upgraded, 2 newly installed, 0 to remove and 34 not upgraded.
Need to get 1,116 kB of archives.
After this operation, 5,542 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 swig4.0 amd64 4.0.2-1ubuntu1 [1,110 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 swig all 4.0.2-1ubuntu1 [5,632 B]
Fetched 1,116 kB in 1s (1,534 kB/s)
Selecting previously unselected package swig4.0.
(Reading database ... 126333 files and directories currently installed.)
Preparing to unpack .../swig4.0_4.0.2-1ubuntu1_amd64.deb ...
Unpacking swig4.0 (4.0.2-1ubuntu1) ...
Selecting previously unselected package swig.
Preparing to unpack .../swig_4.0.2-1u

In [4]:
import os
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from gymnasium.wrappers import RecordVideo
from IPython.display import HTML
from base64 import b64encode


In [5]:
def display_video(video_path):
    """Build an inline HTML5 player for an MP4 file.

    The file is read in full and embedded as a base64 ``data:`` URL, so the
    returned object is self-contained.

    Args:
        video_path: Path of the MP4 file to embed.

    Returns:
        An ``IPython.display.HTML`` object wrapping a ``<video>`` tag, or
        ``None`` (after printing a notice) when ``video_path`` does not exist.
    """
    if os.path.exists(video_path):
        # Read and encode while the file handle is open.
        with open(video_path, 'rb') as video_file:
            encoded = b64encode(video_file.read()).decode()
        src = "data:video/mp4;base64," + encoded
        return HTML(f"""
    <video width=600 controls>
        <source src="{src}" type="video/mp4">
    </video>
    """)

    # Missing file: tell the user and hand back nothing to display.
    print(f"Video file {video_path} not found.")
    return None


In [6]:
# --- Training schedule ---
# Training runs in chunks of `record_every_timesteps` steps until
# `total_timesteps` is reached; one video is recorded after each chunk.
record_every_timesteps = 500_000  # chunk size between recorded videos
total_timesteps = 2_000_000  # overall training budget (2 million timesteps)

# --- Output location ---
# Directory receiving the recorded MP4 files; created up front if absent.
video_folder = "./a2c_training_videos"
os.makedirs(video_folder, exist_ok=True)


In [7]:
# Fixed seed so a fresh "Restart & Run All" starts from the same random state
# (GPU kernels may still introduce small run-to-run differences).
SEED = 42

# Main training environment
train_env = gym.make("LunarLander-v3")

# A2C model with the default MLP policy. SB3 wraps `train_env` in a Monitor
# and a DummyVecEnv internally; `seed` seeds the algorithm's random generators.
model = A2C("MlpPolicy", train_env, verbose=1, seed=SEED)


  from pkg_resources import resource_stream, resource_exists
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




In [8]:
# Train in chunks of `record_every_timesteps`; after each chunk, record one
# deterministic evaluation episode and show it inline.
for step in range(0, total_timesteps, record_every_timesteps):
    # Timestep count reached once this chunk is done; also labels the video.
    video_prefix = f"a2c_step_{step + record_every_timesteps}"

    # Train the model. reset_num_timesteps=False continues the model's running
    # timestep counter instead of restarting it for every chunk.
    model.learn(total_timesteps=record_every_timesteps, reset_num_timesteps=False)

    # Set up video recording on a fresh env that renders to RGB frames
    eval_env = gym.make("LunarLander-v3", render_mode="rgb_array")
    eval_env = RecordVideo(eval_env, video_folder=video_folder, name_prefix=video_prefix, episode_trigger=lambda x: True)

    try:
        # Roll out a single episode with the greedy (deterministic) policy.
        obs, _ = eval_env.reset()
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _ = eval_env.step(action)
            done = terminated or truncated
    finally:
        # Always release the env/recorder, even if the rollout raised.
        eval_env.close()

    # Display the video recorded for THIS checkpoint. Selecting by prefix
    # (rather than sorting every file name) matters: a plain string sort ranks
    # "a2c_step_500000" after "a2c_step_2000000", and would also pick up
    # leftover files from earlier runs.
    video_files = [
        os.path.join(video_folder, f)
        for f in os.listdir(video_folder)
        if f.startswith(f"{video_prefix}-") and f.endswith('.mp4')
    ]
    if video_files:
        # Most recently written file, should several share the prefix.
        latest_video = max(video_files, key=os.path.getmtime)
        display(display_video(latest_video))
    else:
        print(f"No video found for {video_prefix} in {video_folder}.")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
|    value_loss         | 10.6     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 464      |
|    ep_rew_mean        | 68.8     |
| time/                 |          |
|    fps                | 402      |
|    iterations         | 70700    |
|    time_elapsed       | 878      |
|    total_timesteps    | 353500   |
| train/                |          |
|    entropy_loss       | -0.207   |
|    explained_variance | 0.327    |
|    learning_rate      | 0.0007   |
|    n_updates          | 70699    |
|    policy_loss        | 0.0391   |
|    value_loss         | 1.06     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 472      |
|    ep_rew_mean        | 68.3     |
| time/                 |          |
|    fps                | 402      |
|    itera

  logger.warn(
  """


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
|    value_loss         | 0.0853   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 474      |
|    ep_rew_mean        | 42.4     |
| time/                 |          |
|    fps                | 395      |
|    iterations         | 70700    |
|    time_elapsed       | 893      |
|    total_timesteps    | 853500   |
| train/                |          |
|    entropy_loss       | -0.604   |
|    explained_variance | 0.387    |
|    learning_rate      | 0.0007   |
|    n_updates          | 170699   |
|    policy_loss        | -0.214   |
|    value_loss         | 1.27     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 469      |
|    ep_rew_mean        | 44.5     |
| time/                 |          |
|    fps                | 395      |
|    itera

  logger.warn(


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
|    value_loss         | 13.4     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 400      |
|    ep_rew_mean        | 119      |
| time/                 |          |
|    fps                | 398      |
|    iterations         | 70700    |
|    time_elapsed       | 887      |
|    total_timesteps    | 1353500  |
| train/                |          |
|    entropy_loss       | -0.334   |
|    explained_variance | 0.621    |
|    learning_rate      | 0.0007   |
|    n_updates          | 270699   |
|    policy_loss        | 0.828    |
|    value_loss         | 0.97     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 407      |
|    ep_rew_mean        | 118      |
| time/                 |          |
|    fps                | 398      |
|    itera

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
|    value_loss         | 4.18     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 449      |
|    ep_rew_mean        | 140      |
| time/                 |          |
|    fps                | 404      |
|    iterations         | 70700    |
|    time_elapsed       | 873      |
|    total_timesteps    | 1853500  |
| train/                |          |
|    entropy_loss       | -0.208   |
|    explained_variance | 0.992    |
|    learning_rate      | 0.0007   |
|    n_updates          | 370699   |
|    policy_loss        | -0.00459 |
|    value_loss         | 0.122    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 448      |
|    ep_rew_mean        | 143      |
| time/                 |          |
|    fps                | 404      |
|    itera

In [9]:
# Final score over 10 evaluation episodes.
# Evaluate through the model's own environment: SB3 wrapped `train_env` in a
# Monitor when the model was created, and evaluate_policy uses that wrapper to
# read episode returns and lengths. Passing the raw `train_env` would bypass it.
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
print(f"Final A2C agent performance: {mean_reward:.2f} +/- {std_reward:.2f}")




Final A2C agent performance: 108.68 +/- 123.33
