Skip to content

Commit

Permalink
Implement MultiStepTransitionPicker
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jun 18, 2023
1 parent 489169a commit f28c124
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 6 deletions.
51 changes: 49 additions & 2 deletions d3rlpy/dataset/transition_pickers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from typing_extensions import Protocol

from .components import EpisodeBase, Transition
Expand All @@ -11,6 +12,7 @@
"TransitionPickerProtocol",
"BasicTransitionPicker",
"FrameStackTransitionPicker",
"MultiStepTransitionPicker",
]


Expand All @@ -28,7 +30,7 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
_validate_index(episode, index)

observation = retrieve_observation(episode.observations, index)
is_terminal = index == episode.size() - 1
is_terminal = episode.terminated and index == episode.size() - 1
if is_terminal:
next_observation = create_zero_observation(observation)
else:
Expand Down Expand Up @@ -57,7 +59,7 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
observation = stack_recent_observations(
episode.observations, index, self._n_frames
)
is_terminal = index == episode.size() - 1
is_terminal = episode.terminated and index == episode.size() - 1
if is_terminal:
next_observation = create_zero_observation(observation)
else:
Expand All @@ -72,3 +74,48 @@ def __call__(self, episode: EpisodeBase, index: int) -> Transition:
terminal=float(is_terminal),
interval=1,
)


class MultiStepTransitionPicker(TransitionPickerProtocol):
_n_steps: int
_gamma: float

def __init__(self, n_steps: int, gamma: float):
self._n_steps = n_steps
self._gamma = gamma

def __call__(self, episode: EpisodeBase, index: int) -> Transition:
_validate_index(episode, index)

observation = retrieve_observation(episode.observations, index)

# get observation N-step ahead
if episode.terminated:
next_index = min(index + self._n_steps, episode.size())
is_terminal = next_index == episode.size()
if is_terminal:
next_observation = create_zero_observation(observation)
else:
next_observation = retrieve_observation(
episode.observations, next_index
)
else:
is_terminal = False
next_index = min(index + self._n_steps, episode.size() - 1)
next_observation = retrieve_observation(
episode.observations, next_index
)

# compute multi-step return
interval = next_index - index
cum_gammas = np.expand_dims(self._gamma ** np.arange(interval), axis=1)
ret = np.sum(episode.rewards[index:next_index] * cum_gammas, axis=0)

return Transition(
observation=observation,
action=episode.actions[index],
reward=ret,
next_observation=next_observation,
terminal=float(is_terminal),
interval=interval,
)
61 changes: 61 additions & 0 deletions examples/multi_step_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import argparse

import gym

import d3rlpy

GAMMA = 0.99


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--env", type=str, default="Hopper-v2")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--n-steps", type=int, default=1)
parser.add_argument("--gpu", action="store_true")
args = parser.parse_args()

env = gym.make(args.env)
eval_env = gym.make(args.env)

# fix seed
d3rlpy.seed(args.seed)
d3rlpy.envs.seed_env(env, args.seed)
d3rlpy.envs.seed_env(eval_env, args.seed)

# setup algorithm
sac = d3rlpy.algos.SACConfig(
batch_size=256,
gamma=GAMMA,
actor_learning_rate=3e-4,
critic_learning_rate=3e-4,
temp_learning_rate=3e-4,
).create(device=args.gpu)

# multi-step transition sampling
transition_picker = d3rlpy.dataset.MultiStepTransitionPicker(
n_steps=args.n_steps,
gamma=GAMMA,
)

# replay buffer for experience replay
buffer = d3rlpy.dataset.create_fifo_replay_buffer(
limit=1000000,
env=env,
transition_picker=transition_picker,
)

# start training
sac.fit_online(
env,
buffer,
eval_env=eval_env,
n_steps=1000000,
n_steps_per_epoch=10000,
update_interval=1,
update_start_step=1000,
)


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions scripts/format
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ else
fi

# format package imports
isort -l 80 --profile black $ISORT_ARG d3rlpy tests setup.py reproductions
isort -l 80 --profile black $ISORT_ARG d3rlpy tests setup.py reproductions examples

# use black for the better type annotations
black -l 80 $BLACK_ARG d3rlpy tests setup.py reproductions
black -l 80 $BLACK_ARG d3rlpy tests setup.py reproductions examples
4 changes: 2 additions & 2 deletions scripts/lint
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash -eux

# type check
mypy d3rlpy reproductions tests
mypy d3rlpy reproductions tests examples

# code-format check
pylint d3rlpy reproductions tests
pylint d3rlpy reproductions tests examples
66 changes: 66 additions & 0 deletions tests/dataset/test_transition_pickers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from d3rlpy.dataset import (
BasicTransitionPicker,
FrameStackTransitionPicker,
MultiStepTransitionPicker,
Shape,
)

Expand Down Expand Up @@ -111,3 +112,68 @@ def test_frame_stack_transition_picker(
transition = picker(episode, length - 1)
assert np.all(transition.next_observation == 0.0)
assert transition.terminal == 1.0


@pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))])
@pytest.mark.parametrize("action_size", [2])
@pytest.mark.parametrize("length", [100])
@pytest.mark.parametrize("n_steps", [1, 3])
@pytest.mark.parametrize("gamma", [0.99])
def test_multi_step_transition_picker(
observation_shape: Shape,
action_size: int,
length: int,
n_steps: int,
gamma: float,
) -> None:
episode = create_episode(
observation_shape, action_size, length, terminated=True
)

picker = MultiStepTransitionPicker(n_steps=n_steps, gamma=gamma)

# check transition
transition = picker(episode, 0)
if isinstance(observation_shape[0], tuple):
for i, shape in enumerate(observation_shape):
assert transition.observation_signature.shape[i] == shape
assert np.all(
transition.observation[i] == episode.observations[i][0]
)
assert np.all(
transition.next_observation[i]
== episode.observations[i][n_steps]
)
else:
assert transition.observation_signature.shape[0] == observation_shape
assert np.all(transition.observation == episode.observations[0])
assert np.all(
transition.next_observation == episode.observations[n_steps]
)
gammas = gamma ** np.arange(n_steps)
ref_reward = np.sum(gammas * np.reshape(episode.rewards[:n_steps], [-1]))
assert np.all(transition.action == episode.actions[0])
assert np.all(transition.reward == np.reshape(ref_reward, [1]))
assert transition.interval == n_steps
assert transition.terminal == 0

# check terminal transition
transition = picker(episode, length - n_steps)
if isinstance(observation_shape[0], tuple):
for i, shape in enumerate(observation_shape):
dummy_observation = np.zeros(shape)
assert transition.observation_signature.shape[i] == shape
assert np.all(
transition.observation[i] == episode.observations[i][-n_steps]
)
assert np.all(transition.next_observation[i] == dummy_observation)
else:
dummy_observation = np.zeros(observation_shape)
assert transition.observation_signature.shape[0] == observation_shape
assert np.all(transition.observation == episode.observations[-n_steps])
assert np.all(transition.next_observation == dummy_observation)
assert np.all(transition.action == episode.actions[-n_steps])
ref_reward = np.sum(gammas * np.reshape(episode.rewards[-n_steps:], [-1]))
assert np.all(transition.reward == np.reshape(ref_reward, [1]))
assert transition.interval == n_steps
assert transition.terminal == 1.0

0 comments on commit f28c124

Please sign in to comment.