
Commit

Add docstring
takuseno committed Jun 23, 2023
1 parent c210d41 commit 4f60713
Showing 6 changed files with 284 additions and 51 deletions.
6 changes: 3 additions & 3 deletions d3rlpy/dataset/buffers.py
@@ -14,7 +14,7 @@ def append(self, episode: EpisodeBase, index: int) -> None:
Args:
episode: Episode object.
index: transition index.
index: Transition index.
"""
raise NotImplementedError

@@ -23,7 +23,7 @@ def episodes(self) -> Sequence[EpisodeBase]:
r"""Returns list of episodes.
Returns:
a list of saved episodes.
List of saved episodes.
"""
raise NotImplementedError

@@ -32,7 +32,7 @@ def transition_count(self) -> int:
r"""Returns the number of transitions.
Returns:
the number of transitions.
Number of transitions.
"""
raise NotImplementedError

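
For context on the buffer interface documented above, here is a minimal sketch of an in-memory buffer that satisfies ``append``, ``episodes``, and ``transition_count``. The class name and import path are hypothetical and not part of d3rlpy; only the three members shown in this diff are assumed.

from typing import List, Sequence, Tuple

from d3rlpy.dataset.components import EpisodeBase  # import path assumed from this diff


class SimpleBuffer:
    """Hypothetical minimal buffer; not part of d3rlpy itself."""

    def __init__(self) -> None:
        self._transitions: List[Tuple[EpisodeBase, int]] = []
        self._episodes: List[EpisodeBase] = []

    def append(self, episode: EpisodeBase, index: int) -> None:
        # register the transition identified by (episode, transition index)
        if not self._episodes or self._episodes[-1] is not episode:
            self._episodes.append(episode)
        self._transitions.append((episode, index))

    @property
    def episodes(self) -> Sequence[EpisodeBase]:
        return self._episodes

    @property
    def transition_count(self) -> int:
        return len(self._transitions)
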
14 changes: 7 additions & 7 deletions d3rlpy/dataset/compat.py
@@ -16,16 +16,16 @@ class MDPDataset(ReplayBuffer):
r"""Backward-compability class of MDPDataset.
Args:
observations (ObservationSequence): observations.
actions (np.ndarray): actions.
rewards (np.ndarray): rewards.
terminals (np.ndarray): environmental terminal flags.
timeouts (np.ndarray): timeouts.
observations (ObservationSequence): Observations.
actions (np.ndarray): Actions.
rewards (np.ndarray): Rewards.
terminals (np.ndarray): Environmental terminal flags.
timeouts (np.ndarray): Timeouts.
transition_picker (Optional[TransitionPickerProtocol]):
transition picker implementation for Q-learning-based algorithms.
Transition picker implementation for Q-learning-based algorithms.
If ``None`` is given, ``BasicTransitionPicker`` is used by default.
trajectory_slicer (Optional[TrajectorySlicerProtocol]):
trajectory slicer implementation for Transformer-based algorithms.
Trajectory slicer implementation for Transformer-based algorithms.
If ``None`` is given, ``BasicTrajectorySlicer`` is used by default.
"""

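
As a rough usage sketch of the arguments documented above (array shapes, dtypes, and the import path are assumptions; the optional ``timeouts``, ``transition_picker``, and ``trajectory_slicer`` arguments are left at their defaults):

import numpy as np

from d3rlpy.dataset import MDPDataset  # import path assumed

# 100 toy steps with 4-dimensional observations and 2-dimensional actions
observations = np.random.random((100, 4)).astype(np.float32)
actions = np.random.random((100, 2)).astype(np.float32)
rewards = np.random.random((100, 1)).astype(np.float32)
terminals = np.zeros(100, dtype=np.float32)
terminals[-1] = 1.0  # the last step terminates the episode

dataset = MDPDataset(
    observations=observations,
    actions=actions,
    rewards=rewards,
    terminals=terminals,
)
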
164 changes: 164 additions & 0 deletions d3rlpy/dataset/components.py
@@ -26,10 +26,21 @@

@dataclasses.dataclass(frozen=True)
class Signature:
r"""Signature of arrays.
Args:
dtype: List of numpy data types.
shape: List of array shapes.
"""
dtype: Sequence[np.dtype]
shape: Sequence[Sequence[int]]

def sample(self) -> Sequence[np.ndarray]:
r"""Returns sampled arrays.
Returns:
List of arrays based on dtypes and shapes.
"""
return [
np.random.random(shape).astype(dtype)
for shape, dtype in zip(self.shape, self.dtype)
@@ -38,6 +49,17 @@ def sample(self) -> Sequence[np.ndarray]:
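
A brief sketch of how ``Signature`` can be used, based on the fields and ``sample()`` method above (the import path follows the file location in this diff and is an assumption):

import numpy as np

from d3rlpy.dataset.components import Signature  # import path assumed

# signature describing a single float32 vector of size 4
sig = Signature(dtype=[np.dtype(np.float32)], shape=[(4,)])

# sample() draws random arrays that match the declared dtypes and shapes
arrays = sig.sample()
print(arrays[0].shape, arrays[0].dtype)  # (4,) float32
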

@dataclasses.dataclass(frozen=True)
class Transition:
r"""Transition tuple.
Args:
observation: Observation.
action: Action.
reward: Reward. This could be a multi-step discounted return.
next_observation: Observation at next timestep. This could be
observation at multi-step ahead.
terminal: Flag of environment termination.
interval: Timesteps between ``observation`` and ``next_observation``.
"""
observation: Observation # (...)
action: np.ndarray # (...)
reward: np.ndarray # (1,)
@@ -47,6 +69,11 @@ class Transition:

@property
def observation_signature(self) -> Signature:
r"""Returns observation sigunature.
Returns:
Observation signature.
"""
shape = get_shape_from_observation(self.observation)
dtype = get_dtype_from_observation(self.observation)
if isinstance(self.observation, np.ndarray):
@@ -56,13 +83,23 @@ def observation_signature(self) -> Signature:

@property
def action_signature(self) -> Signature:
r"""Returns action signature.
Returns:
Action signature.
"""
return Signature(
dtype=[self.action.dtype],
shape=[self.action.shape],
)

@property
def reward_signature(self) -> Signature:
r"""Returns reward signature.
Returns:
Reward signature.
"""
return Signature(
dtype=[self.reward.dtype],
shape=[self.reward.shape],
@@ -71,6 +108,18 @@ def reward_signature(self) -> Signature:
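
A sketch of constructing a ``Transition`` from the fields documented above (array shapes and the scalar types of ``terminal`` and ``interval`` are assumptions):

import numpy as np

from d3rlpy.dataset.components import Transition  # import path assumed

transition = Transition(
    observation=np.random.random(4).astype(np.float32),
    action=np.random.random(2).astype(np.float32),
    reward=np.array([1.0], dtype=np.float32),
    next_observation=np.random.random(4).astype(np.float32),
    terminal=0.0,  # non-terminal transition
    interval=1,    # single-step transition
)

# signatures are derived from the stored arrays
print(transition.observation_signature)
print(transition.action_signature)
print(transition.reward_signature)
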

@dataclasses.dataclass(frozen=True)
class PartialTrajectory:
r"""Partial trajectory.
Args:
observations: Sequence of observations.
actions: Sequence of actions.
rewards: Sequence of rewards.
returns_to_go: Sequence of remaining returns.
terminals: Sequence of terminal flags.
timesteps: Sequence of timesteps.
masks: Sequence of masks that represent padding.
length: Sequence length.
"""
observations: ObservationSequence # (L, ...)
actions: np.ndarray # (L, ...)
rewards: np.ndarray # (L, 1)
@@ -82,6 +131,11 @@ class PartialTrajectory:

@property
def observation_signature(self) -> Signature:
r"""Returns observation sigunature.
Returns:
Observation signature.
"""
shape = get_shape_from_observation_sequence(self.observations)
dtype = get_dtype_from_observation_sequence(self.observations)
if isinstance(self.observations, np.ndarray):
@@ -91,13 +145,23 @@ def observation_signature(self) -> Signature:

@property
def action_signature(self) -> Signature:
r"""Returns action signature.
Returns:
Action signature.
"""
return Signature(
dtype=[self.actions.dtype],
shape=[self.actions.shape[1:]],
)

@property
def reward_signature(self) -> Signature:
r"""Returns reward signature.
Returns:
Reward signature.
"""
return Signature(
dtype=[self.rewards.dtype],
shape=[self.rewards.shape[1:]],
@@ -108,57 +172,138 @@ def __len__(self) -> int:
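
A sketch of building a short ``PartialTrajectory`` with the fields documented above (shapes and dtypes are assumptions):

import numpy as np

from d3rlpy.dataset.components import PartialTrajectory  # import path assumed

L = 3  # trajectory length
trajectory = PartialTrajectory(
    observations=np.random.random((L, 4)).astype(np.float32),
    actions=np.random.random((L, 2)).astype(np.float32),
    rewards=np.random.random((L, 1)).astype(np.float32),
    returns_to_go=np.random.random((L, 1)).astype(np.float32),
    terminals=np.zeros((L, 1), dtype=np.float32),
    timesteps=np.arange(L),
    masks=np.ones(L, dtype=np.float32),  # no padding in this toy trajectory
    length=L,
)
print(len(trajectory))  # expected to equal L
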


class EpisodeBase(Protocol):
r"""Episode interface.
``Episode`` represents an entire episode.
"""

@property
def observations(self) -> ObservationSequence:
r"""Returns sequence of observations.
Returns:
Sequence of observations.
"""
raise NotImplementedError

@property
def actions(self) -> np.ndarray:
r"""Returns sequence of actions.
Returns:
Sequence of actions.
"""
raise NotImplementedError

@property
def rewards(self) -> np.ndarray:
r"""Returns sequence of rewards.
Returns:
Sequence of rewards.
"""
raise NotImplementedError

@property
def terminated(self) -> bool:
r"""Returns environment terminal flag.
This flag becomes true when this episode is terminated. For timeouts,
this flag stays false.
Returns:
Terminal flag.
"""
raise NotImplementedError

@property
def observation_signature(self) -> Signature:
r"""Returns observation signature.
Returns:
Observation signature.
"""
raise NotImplementedError

@property
def action_signature(self) -> Signature:
r"""Returns action signature.
Returns:
Action signature.
"""
raise NotImplementedError

@property
def reward_signature(self) -> Signature:
r"""Returns reward signature.
Returns:
Reward signature.
"""
raise NotImplementedError

def size(self) -> int:
r"""Returns length of an episode.
Returns:
Episode length.
"""
raise NotImplementedError

def compute_return(self) -> float:
r"""Computes total episode return.
Returns:
Total episode return.
"""
raise NotImplementedError

def serialize(self) -> Dict[str, Any]:
r"""Returns serized episode data.
Returns:
Serialized episode data.
"""
raise NotImplementedError

@classmethod
def deserialize(cls, serializedData: Dict[str, Any]) -> "EpisodeBase":
r"""Constructs episode from serialized data.
This is the inverse operation of the ``serialize`` method.
Args:
serializedData: Serialized episode data.
Returns:
Episode object.
"""
raise NotImplementedError

def __len__(self) -> int:
raise NotImplementedError

@property
def transition_count(self) -> int:
r"""Returns the number of transitions.
Returns:
Number of transitions.
"""
raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class Episode:
r"""Standard episode implementation.
Args:
observations: Sequence of observations.
actions: Sequence of actions.
rewards: Sequence of rewards.
terminated: Flag of environment termination.
"""
observations: ObservationSequence
actions: np.ndarray
rewards: np.ndarray
@@ -220,6 +365,17 @@ def transition_count(self) -> int:
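
A sketch of creating an ``Episode`` and exercising the ``EpisodeBase`` methods documented above (shapes, dtypes, and the import path are assumptions):

import numpy as np

from d3rlpy.dataset.components import Episode  # import path assumed

episode = Episode(
    observations=np.random.random((100, 4)).astype(np.float32),
    actions=np.random.random((100, 2)).astype(np.float32),
    rewards=np.random.random((100, 1)).astype(np.float32),
    terminated=True,
)
print(episode.size())            # 100
print(episode.compute_return())  # sum of rewards

# serialize() and deserialize() should round-trip the episode data
data = episode.serialize()
restored = Episode.deserialize(data)
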

@dataclasses.dataclass(frozen=True)
class DatasetInfo:
r"""Dataset information.
Args:
observation_signature: Observation signature.
action_signature: Action signature.
reward_signature: Reward signature.
action_space: Action space type.
action_size: Size of action space. For a continuous action space,
this represents the dimension of action vectors. For a discrete
action space, this represents the number of discrete actions.
"""
observation_signature: Signature
action_signature: Signature
reward_signature: Signature
Expand All @@ -228,6 +384,14 @@ class DatasetInfo:

@classmethod
def from_episodes(cls, episodes: Sequence[EpisodeBase]) -> "DatasetInfo":
r"""Constructs from sequence of episodes.
Args:
episodes: Sequence of episodes.
Returns:
DatasetInfo object.
"""
action_space = detect_action_space(episodes[0].actions)
if action_space == ActionSpace.CONTINUOUS:
action_size = episodes[0].action_signature.shape[0][0]
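
A sketch of ``DatasetInfo.from_episodes`` based on the fields documented above (the toy episode and import paths are assumptions):

import numpy as np

from d3rlpy.dataset.components import DatasetInfo, Episode  # import paths assumed

episode = Episode(
    observations=np.random.random((10, 4)).astype(np.float32),
    actions=np.random.random((10, 2)).astype(np.float32),  # continuous actions
    rewards=np.random.random((10, 1)).astype(np.float32),
    terminated=True,
)

dataset_info = DatasetInfo.from_episodes([episode])
print(dataset_info.action_space)  # detected from the episode's actions
print(dataset_info.action_size)   # 2 for the continuous actions above
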
18 changes: 17 additions & 1 deletion d3rlpy/dataset/episode_generator.py
@@ -11,11 +11,27 @@


class EpisodeGeneratorProtocol(Protocol):
r"""Episode generator interface."""

def __call__(self) -> Sequence[EpisodeBase]:
...
r"""Returns generated episodes.
Returns:
Sequence of episodes.
"""
raise NotImplementedError


class EpisodeGenerator(EpisodeGeneratorProtocol):
r"""Standard episode generator implementation.
Args:
observations: Sequence of observations.
actions: Sequence of actions.
rewards: Sequence of rewards.
terminals: Sequence of environment terminal flags.
timeouts: Sequence of timeout flags.
"""
_observations: ObservationSequence
_actions: np.ndarray
_rewards: np.ndarray
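
A sketch of splitting flat arrays into episodes with ``EpisodeGenerator`` (shapes, dtypes, and the import path are assumptions):

import numpy as np

from d3rlpy.dataset.episode_generator import EpisodeGenerator  # import path assumed

N = 1000
terminals = np.zeros(N, dtype=np.float32)
terminals[99::100] = 1.0  # terminate every 100 steps

episode_generator = EpisodeGenerator(
    observations=np.random.random((N, 4)).astype(np.float32),
    actions=np.random.random((N, 2)).astype(np.float32),
    rewards=np.random.random((N, 1)).astype(np.float32),
    terminals=terminals,
    timeouts=np.zeros(N, dtype=np.float32),
)
episodes = episode_generator()  # Sequence[EpisodeBase]
print(len(episodes))  # expected: 10 episodes of 100 steps each
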
