Skip to content

Commit

Permalink
Fix action scaling for d4rl
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed May 6, 2024
1 parent 2cfa571 commit 4dee692
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
20 changes: 16 additions & 4 deletions d3rlpy/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def get_d4rl(
transition_picker: Optional[TransitionPickerProtocol] = None,
trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
render_mode: Optional[str] = None,
max_episode_steps: int = 1000,
) -> Tuple[ReplayBuffer, gym.Env[NDArray, NDArray]]:
"""Returns d4rl dataset and envrironment.
Expand All @@ -410,12 +411,17 @@ def get_d4rl(
transition_picker: TransitionPickerProtocol object.
trajectory_slicer: TrajectorySlicerProtocol object.
render_mode: Mode of rendering (``human``, ``rgb_array``).
max_episode_steps: Maximum episode environmental steps.
Returns:
tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
"""
try:
import d4rl # type: ignore
import d4rl
from d4rl.locomotion.wrappers import NormalizedBoxEnv
from d4rl.utils.wrappers import (
NormalizedBoxEnv as NormalizedBoxEnvFromUtils,
)

env = gym.make(env_name)
raw_dataset: Dict[str, NDArray] = env.get_dataset() # type: ignore
Expand All @@ -436,11 +442,17 @@ def get_d4rl(
trajectory_slicer=trajectory_slicer,
)

# wrapped by NormalizedBoxEnv that is incompatible with newer Gym
unwrapped_env: gym.Env[Any, Any] = env.env.env.env.wrapped_env # type: ignore
# remove incompatible wrappers
normalized_env = env.env.env.env # type: ignore
assert isinstance(
normalized_env, (NormalizedBoxEnv, NormalizedBoxEnvFromUtils)
)
unwrapped_env: gym.Env[Any, Any] = normalized_env.wrapped_env
unwrapped_env.render_mode = render_mode # overwrite

return dataset, TimeLimit(unwrapped_env, max_episode_steps=1000)
return dataset, TimeLimit(
normalized_env, max_episode_steps=max_episode_steps
)
except ImportError as e:
raise ImportError(
"d4rl is not installed.\n" "$ d3rlpy install d4rl"
Expand Down
5 changes: 5 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,8 @@ follow_imports_for_stubs = True

[mypy-minari.*]
ignore_missing_imports = True

[mypy-d4rl.*]
ignore_missing_imports = True
follow_imports = skip
follow_imports_for_stubs = True

0 comments on commit 4dee692

Please sign in to comment.