Skip to content

Commit

Permalink
Fix returns to go calculation
Browse files (browse the repository at this point in the history)
  • Loading branch information
takuseno committed May 14, 2023
1 parent 117b65e commit f1bf0b9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
8 changes: 5 additions & 3 deletions d3rlpy/dataset/trajectory_slicers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,26 @@ def __call__(
) -> PartialTrajectory:
end = end_index + 1
start = max(end - size, 0)
+ actual_size = end - start

# prepare terminal flags
- terminals = np.zeros((end - start, 1), dtype=np.float32)
+ terminals = np.zeros((actual_size, 1), dtype=np.float32)
if episode.terminated and end_index == episode.size() - 1:
terminals[-1][0] = 1.0

# slice data
observations = slice_observations(episode.observations, start, end)
actions = episode.actions[start:end]
rewards = episode.rewards[start:end]
- returns_to_go = np.cumsum(rewards, axis=0).reshape((end - start, 1))
+ all_returns_to_go = np.cumsum(episode.rewards[start:], axis=0)
+ returns_to_go = all_returns_to_go[:actual_size].reshape((-1, 1))

# prepare metadata
timesteps = np.arange(start, end)
masks = np.ones(end - start, dtype=np.float32)

# compute backward padding size
- pad_size = size - (end - start)
+ pad_size = size - actual_size

if pad_size == 0:
return PartialTrajectory(
Expand Down
5 changes: 5 additions & 0 deletions tests/dataset/test_trajectory_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def test_basic_trajectory_slicer(
episode = create_episode(
observation_shape, action_size, length, terminated=terminated
)
+ returns_to_go = np.reshape(
+ np.cumsum(np.reshape(episode.rewards, [-1])), [-1, 1]
+ )

slicer = BasicTrajectorySlicer()

Expand Down Expand Up @@ -64,6 +67,8 @@ def test_basic_trajectory_slicer(
assert np.all(traj.actions[:pad_size] == 0.0)
assert np.all(traj.rewards[pad_size:] == episode.rewards[start:end])
assert np.all(traj.rewards[:pad_size] == 0.0)
+ assert np.all(traj.returns_to_go[pad_size:] == returns_to_go[start:end])
+ assert np.all(traj.returns_to_go[:pad_size] == 0.0)
assert np.all(traj.terminals == 0.0)
assert np.all(traj.timesteps[pad_size:] == np.arange(start, end))
assert np.all(traj.timesteps[:pad_size] == 0.0)
Expand Down

0 comments on commit f1bf0b9

Please sign in to comment.