Skip to content

Commit

Permalink
fix: reward normalize humanoid and swimmer tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
typoverflow committed Aug 11, 2023
1 parent 93bfa58 commit 3a9f74a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion offlinerllib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@

__version__ = "0.0.13"
__version__ = "0.0.14"
8 changes: 5 additions & 3 deletions offlinerllib/utils/d4rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,16 @@ def get_d4rl_dataset(task, normalize_reward=False, normalize_obs=False, terminat
if normalize_reward:
if "antmaze" in task:
dataset, _ = antmaze_normalize_reward(dataset)
elif "halfcheetah" in task or "hopper" in task or "walker2d" in task or "ant" in task:
elif "halfcheetah" in task or "hopper" in task or "walker2d" in task or "ant" in task or "humanoid" in task or "swimmer" in task:
dataset, _ = mujoco_normalize_reward(dataset)
termination_fn = get_termination_fn(task)
if return_termination_fn:
termination_fn = get_termination_fn(task)
if normalize_obs:
dataset, info = _normalize_obs(dataset)
from gym.wrappers.transform_observation import TransformObservation
env = TransformObservation(env, lambda obs: (obs - info["obs_mean"])/info["obs_std"])
termination_fn = get_termination_fn(task, info["obs_mean"], info["obs_std"])
if return_termination_fn:
termination_fn = get_termination_fn(task, info["obs_mean"], info["obs_std"])
if return_termination_fn:
return env, dataset, termination_fn
else:
Expand Down

0 comments on commit 3a9f74a

Please sign in to comment.