In [1]:
import gymnasium
import torch
import numpy as np
from stable_baselines3 import DQN
import imageio

# import gym

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


In [2]:
import warnings
warnings.simplefilter("ignore", UserWarning)

In [None]:
# -------------------------------------------------------
# 1. Load Stable-Baselines3 Pretrained DQN LunarLander Model
# -------------------------------------------------------

# Pretrained DQN checkpoint from the RL Baselines3 Zoo ("rl-trained-agents"),
# resolved relative to the notebook's working directory.
# NOTE(review): the checkpoint name says LunarLander-v2 -- confirm it is
# compatible with the env version it is rolled out in.
model_path = "../../rl-baselines3-zoo/rl-trained-agents/dqn/LunarLander-v2_1/LunarLander-v2.zip"
# Alternative: a locally retrained checkpoint.
# model_path = "lunar_lander_retrained.zip"

# Load the model
model = DQN.load(model_path)

# Retrieve the underlying PyTorch Q-network: a torch.nn.Module that maps a
# batch of states -> Q-values.
policy = model.policy
# The Q-network is also called directly (outside model.predict), so switch the
# policy to evaluation mode explicitly rather than relying on predict() to do it.
policy.set_training_mode(False)
q_net = policy.q_net

In [6]:
# -------------------------------------------------------
# 2. Run 1 Episode and Record Frames
# -------------------------------------------------------
# Roll out the deterministic (greedy) policy for one episode, render every
# frame, and save the frames as an animated GIF.

GIF_PATH = "lunar_lander_pretrained.gif"
GIF_FPS = 10  # playback rate of the saved GIF (lower = slower animation)
SEED = 0      # seed for env.reset so the recorded episode is reproducible; None = random

env = gymnasium.make("LunarLander-v3", render_mode="rgb_array")

# The checkpoint was loaded from a LunarLander-v2 zoo run: at least make sure
# the observation/action spaces the network was built for still match this env.
assert env.observation_space.shape == model.observation_space.shape, (
    env.observation_space, model.observation_space)
assert env.action_space.n == model.action_space.n, (
    env.action_space, model.action_space)

frames = []
episode_return = 0.0

try:
    state, _ = env.reset(seed=SEED)
    frames.append(env.render())  # starting frame, before any action is taken

    while True:
        # Q-values of the current state, kept for inspection (the action itself
        # comes from model.predict). The input must live on the same device as
        # the network, which is not necessarily the CPU (DQN.load picks "auto").
        state_tensor = torch.as_tensor(
            state, dtype=torch.float32, device=model.device
        ).unsqueeze(0)
        with torch.no_grad():
            q_values = q_net(state_tensor).cpu().numpy().squeeze()

        action, _ = model.predict(state, deterministic=True)
        action = int(action.item())

        next_state, reward, terminated, truncated, _ = env.step(action)
        episode_return += float(reward)
        done = terminated or truncated

        frames.append(env.render())
        if done:
            break

        state = next_state
finally:
    # Release the renderer/env even if the rollout raises midway.
    env.close()

imageio.mimwrite(GIF_PATH, frames, fps=GIF_FPS, loop=0)
print(f"Episode return: {episode_return:.1f} over {len(frames) - 1} steps -> {GIF_PATH}")