Skip to content

Commit

Permalink
fix: ensure 'test_policy' works with gymnasium>=0.28.1 (#276)
Browse files Browse the repository at this point in the history
This commit ensures the `test_policy` utility works with
gymnasium>=0.28.1 (see
https://gymnasium.farama.org/content/migration-guide/).
  • Loading branch information
rickstaa committed Jul 3, 2023
1 parent d2b95e8 commit 80fe370
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions stable_learning_control/utils/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ def run_policy(
)

logger = EpochLogger(verbose_fmt="table")
o, r, d, ep_ret, ep_len, n = env.reset(), 0, False, 0, 0, 0
o, _ = env.reset()
r, d, ep_ret, ep_len, n = 0, False, 0, 0, 0
supports_deterministic = True # Only supported with gaussian algorithms.
render_error = False
while n < num_episodes:
Expand All @@ -298,7 +299,7 @@ def run_policy(
render_error = True
log_to_std_out(
(
"WARNING: Nothing was rendered since no render method was "
"Nothing was rendered since no render method was "
f"implemented for the '{env.unwrapped.spec.id}' environment."
),
type="warning",
Expand All @@ -321,13 +322,14 @@ def run_policy(
a = policy.get_action(o)

# Perform action in the environment and store result.
o, r, d, _ = env.step(a)
o, r, d, truncated, _ = env.step(a)
ep_ret += r
ep_len += 1
if d or (ep_len == max_ep_len):
if d or truncated:
logger.store(EpRet=ep_ret, EpLen=ep_len)
logger.log("Episode %d \t EpRet %.3f \t EpLen %d" % (n, ep_ret, ep_len))
o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
o, _ = env.reset()
r, d, ep_ret, ep_len = 0, False, 0, 0
n += 1

print("")
Expand Down

0 comments on commit 80fe370

Please sign in to comment.