Skip to content

Commit

Permalink
Add gymnasium support to SAC (#378)
Browse files Browse the repository at this point in the history
* Update sac_atari.py and sac_continuous_action.py to gymnasium's api

* Add testing

* #383

* move test file

* fix final_info bug

* clean up mujoco tests

* update ci

* fix tests scripts

* Comment out test-mujoco-envs-mac

* fix final_observation

* test_pybullet.py

---------

Co-authored-by: Adam Zhao <pazyx728@gmail.com>
  • Loading branch information
pseudo-rnd-thoughts and sdpkjc committed Oct 15, 2023
1 parent 1c4e8b9 commit d8d4ebf
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 237 deletions.
117 changes: 28 additions & 89 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -117,36 +117,6 @@ jobs:
run: poetry run pytest tests/test_procgen.py

test-mujoco-envs:
strategy:
fail-fast: false
matrix:
python-version: [3.8]
poetry-version: [1.3.1]
os: [ubuntu-22.04]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Run image
uses: abatilo/actions-poetry@v2.0.0
with:
poetry-version: ${{ matrix.poetry-version }}

# mujoco tests
- name: Install dependencies
run: poetry install -E "pytest mujoco dm_control jax"
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: install mujoco dependencies
run: |
sudo apt-get update && sudo apt-get -y install libgl1-mesa-glx libosmesa6 libglfw3
- name: Run mujoco tests
continue-on-error: true # MUJOCO_GL=osmesa results in `free(): invalid pointer`
run: poetry run pytest tests/test_mujoco.py

test-mujoco-gymnasium-envs:
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -176,68 +146,37 @@ jobs:
sudo apt-get update && sudo apt-get -y install libgl1-mesa-glx libosmesa6 libglfw3
- name: Run mujoco tests
continue-on-error: true # MUJOCO_GL=osmesa results in `free(): invalid pointer`
run: poetry run pytest tests/test_mujoco_gymnasium.py

test-mujoco-envs-mac:
strategy:
fail-fast: false
matrix:
python-version: [3.8]
poetry-version: [1.3.1]
os: [macos-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Run image
uses: abatilo/actions-poetry@v2.0.0
with:
poetry-version: ${{ matrix.poetry-version }}

# mujoco tests
- name: Install dependencies
run: poetry install -E "pytest mujoco dm_control jax"
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: Run mujoco tests
run: poetry run pytest tests/test_mujoco.py

test-mujoco_py-envs:
strategy:
fail-fast: false
matrix:
python-version: [3.8]
poetry-version: [1.3.1]
os: [ubuntu-22.04]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Run image
uses: abatilo/actions-poetry@v2.0.0
with:
poetry-version: ${{ matrix.poetry-version }}
# test-mujoco-envs-mac:
# strategy:
# fail-fast: false
# matrix:
# python-version: [3.8]
# poetry-version: [1.3.1]
# os: [macos-latest]
# runs-on: ${{ matrix.os }}
# steps:
# - uses: actions/checkout@v2
# - uses: actions/setup-python@v2
# with:
# python-version: ${{ matrix.python-version }}
# - name: Run image
# uses: abatilo/actions-poetry@v2.0.0
# with:
# poetry-version: ${{ matrix.poetry-version }}

# mujoco_py tests
- name: Install dependencies
run: poetry install -E "pytest mujoco_py mujoco jax"
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: install mujoco_py dependencies
run: |
sudo apt-get update && sudo apt-get -y install wget unzip software-properties-common \
libgl1-mesa-dev \
libgl1-mesa-glx \
libglew-dev \
libosmesa6-dev patchelf
- name: Run mujoco_py tests
run: poetry run pytest tests/test_mujoco_py.py
# # mujoco tests
# - name: Install dependencies
# run: poetry install -E "pytest mujoco dm_control jax"
# - name: Downgrade setuptools
# run: poetry run pip install setuptools==59.5.0
# - name: Run gymnasium migration dependencies
# run: poetry run pip install "stable_baselines3==2.0.0a1"
# - name: Run mujoco tests
# run: poetry run pytest tests/test_mujoco.py

test-mujoco_py-envs-gymnasium:
test-mujoco_py-envs:
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -270,7 +209,7 @@ jobs:
libglew-dev \
libosmesa6-dev patchelf
- name: Run mujoco_py tests
run: poetry run pytest tests/test_mujoco_py_gymnasium.py
run: poetry run pytest tests/test_mujoco_py.py

test-envpool-envs:
strategy:
Expand Down
47 changes: 30 additions & 17 deletions cleanrl/sac_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from distutils.util import strtobool

import gym
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
Expand Down Expand Up @@ -79,11 +79,13 @@ def parse_args():

def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")

env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
Expand All @@ -93,9 +95,8 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.seed(seed)

env.action_space.seed(seed)
env.observation_space.seed(seed)
return env

return thunk
Expand Down Expand Up @@ -174,6 +175,15 @@ def get_action(self, x):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
)

args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
Expand Down Expand Up @@ -231,12 +241,12 @@ def get_action(self, x):
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, info = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
Expand All @@ -246,22 +256,25 @@ def get_action(self, x):
actions = actions.detach().cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
if "final_info" in infos:
for info in infos["final_info"]:
# Skip the envs that are not done
if "episode" not in info:
continue
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
        # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
Expand Down
42 changes: 25 additions & 17 deletions cleanrl/sac_continuous_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from distutils.util import strtobool

import gym
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
Expand Down Expand Up @@ -71,14 +71,13 @@ def parse_args():

def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env

return thunk
Expand Down Expand Up @@ -145,6 +144,15 @@ def get_action(self, x):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
)

args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
Expand Down Expand Up @@ -204,12 +212,12 @@ def get_action(self, x):
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, info = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
Expand All @@ -219,22 +227,22 @@ def get_action(self, x):
actions = actions.detach().cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
if "final_info" in infos:
for info in infos["final_info"]:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
        # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
Expand Down
8 changes: 0 additions & 8 deletions tests/test_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,3 @@ def test_ppo_lstm():
shell=True,
check=True,
)


def test_sac():
subprocess.run(
"python cleanrl/sac_atari.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4",
shell=True,
check=True,
)
8 changes: 8 additions & 0 deletions tests/test_atari_gymnasium.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,11 @@ def test_c51_atari_eval():
shell=True,
check=True,
)


def test_sac():
subprocess.run(
"python cleanrl/sac_atari.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4",
shell=True,
check=True,
)
Loading

1 comment on commit d8d4ebf

@vercel
Copy link

@vercel vercel bot commented on d8d4ebf Oct 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

cleanrl – ./

cleanrl-git-master-vwxyzjn.vercel.app
cleanrl-vwxyzjn.vercel.app
docs.cleanrl.dev
cleanrl.vercel.app

Please sign in to comment.