Skip to content

Commit

Permalink
fix: correctly close gymnasium environments (#340)
Browse files Browse the repository at this point in the history
This commit ensures that gymnasium environments get correctly closed
down when used.
  • Loading branch information
rickstaa committed Aug 25, 2023
1 parent 90e16e9 commit a179176
Show file tree
Hide file tree
Showing 13 changed files with 30 additions and 5 deletions.
4 changes: 3 additions & 1 deletion examples/manual_env_policy_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)

AGENT_TYPE = "torch" # The type of agent that was trained. Options: 'tf2' and 'torch'.
AGENT_FOLDER = "/home/ricks/Development/stable-learning-control/data/2022-02-24_staa_lac_panda_reach/2022-02-24_09-12-19-staa_lac_panda_reach_s25" # noqa: E501
AGENT_FOLDER = "/home/ricks/development/work/stable-learning-control/data/cmd_lac_pytorch/cmd_lac_pytorch_s0" # noqa: E501

if __name__ == "__main__":
# NOTE: STEP 1a: Try to load the policy and environment
Expand Down Expand Up @@ -45,3 +45,5 @@
"'AGENT_FOLDER' and try again. If the problem persists please open a issue "
"on https://github.com/rickstaa/stable-learning-control/issues."
)

env.close()
3 changes: 3 additions & 0 deletions sandbox/test_finite_horizon_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,6 @@
print("Environment terminated or truncated. Resetting.")
o, _ = env.reset()
ep_ret, ep_len, t = 0, 0, 0

print("Done")
env.close()
2 changes: 2 additions & 0 deletions sandbox/test_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,6 @@
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, loc=2, fancybox=False, shadow=False)
plt.show()

print("Done")
env.close()
3 changes: 3 additions & 0 deletions sandbox/test_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,6 @@
print("Environment terminated or truncated. Resetting.")
o, _ = env.reset()
ep_ret, ep_len = 0, 0

print("Done")
env.close()
3 changes: 3 additions & 0 deletions sandbox/test_traj_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,6 @@

# Print data.
print(f"Epoch {epoch}:")

print("Done")
env.close()
2 changes: 2 additions & 0 deletions stable_learning_control/algos/pytorch/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,8 @@ def lac(
type="info",
)

# Close environment and return policy and replay buffer.
env.close()
return policy, replay_buffer


Expand Down
2 changes: 2 additions & 0 deletions stable_learning_control/algos/pytorch/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,8 @@ def sac(
type="info",
)

# Close environment and return policy and replay buffer.
env.close()
return policy, replay_buffer


Expand Down
2 changes: 2 additions & 0 deletions stable_learning_control/algos/tf2/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,8 @@ def lac(
type="info",
)

# Close environment and return policy and replay buffer.
env.close()
return policy, replay_buffer


Expand Down
2 changes: 2 additions & 0 deletions stable_learning_control/algos/tf2/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,8 @@ def sac(
type="info",
)

# Close environment and return policy and replay buffer.
env.close()
return policy, replay_buffer


Expand Down
3 changes: 2 additions & 1 deletion tests/algos/gpu/test_algos_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def env(self):
env.action_space.seed(0)
env.observation_space.seed(0)

return env
yield env
env.close()

def test_reproducibility(self, algo, device, snapshot, env):
"""Checks if the algorithm is still working as expected."""
Expand Down
3 changes: 2 additions & 1 deletion tests/algos/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def env(self):
env.action_space.seed(0)
env.observation_space.seed(0)

return env
yield env
env.close()

def test_reproducibility(self, algo, snapshot, env):
"""Checks if the algorithm is still working as expected."""
Expand Down
3 changes: 2 additions & 1 deletion tests/algos/tf2/gpu/test_tf2_algos_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def env(self):
env.action_space.seed(0)
env.observation_space.seed(0)

return env
yield env
env.close()

def test_reproducibility(self, algo, device, snapshot, env):
"""Checks if the algorithm is still working as expected."""
Expand Down
3 changes: 2 additions & 1 deletion tests/algos/tf2/test_tf2_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def env(self):
env.action_space.seed(0)
env.observation_space.seed(0)

return env
yield env
env.close()

def test_reproducibility(self, algo, snapshot, env):
"""Checks if the algorithm is still working as expected."""
Expand Down

0 comments on commit a179176

Please sign in to comment.