Add docformatter
takuseno committed Jun 20, 2023
1 parent 427637a commit c210d41
Showing 28 changed files with 13 additions and 139 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/format_check.yml
@@ -24,7 +24,7 @@ jobs:
python -m pip install --upgrade pip
pip install Cython numpy matplotlib onnxruntime onnx pytest tensorboardX
pip install -e .
-pip install black mypy pylint==2.13.5 isort
+pip install black mypy pylint==2.13.5 isort docformatter
- name: Check format
run: |
./scripts/format
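The contents of `./scripts/format` are not part of this diff, so the exact docformatter invocation it uses is not shown. As a rough, hypothetical sketch of a check-mode run over the package (flags per docformatter's standard CLI):
```
$ docformatter --check --recursive d3rlpy  # exits non-zero if any docstring would be rewritten
```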
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
@@ -21,15 +21,15 @@ Before making your nice PR, please run the following commands to inspect code qua
### testing
```
$ pip install pytest-cov onnxruntime stable-baselines3 # dependencies used in unit tests
$ pip install git+https://github.com/takuseno/d4rl-pybullet
$ ./scripts/test
```

### coding style
This repository is styled with [black](https://github.com/psf/black) formatter.
Also, [isort](https://github.com/PyCQA/isort) is used to format package imports.
+[docformatter](https://github.com/PyCQA/docformatter) is additionally used to format docstrings.
```
-$ pip install black # formatters
+$ pip install black isort docformatter # formatters
$ ./scripts/format
```
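Most of the hunks below are mechanical docformatter fixes: it deletes the blank line left before a docstring's closing quotes and re-wraps over-long summary lines. A minimal before/after sketch of the blank-line case, modeled on the `d3rlpy.seed` hunk further down (the summary text and exact spacing here are assumptions, since the removed line is not rendered in this view):
```
# Before: a blank line sits just before the closing quotes
def seed(n: int) -> None:
    """Set the global random seed.

    Args:
        n (int): seed value.

    """

# After running ./scripts/format with docformatter: the blank line is removed
def seed(n: int) -> None:
    """Set the global random seed.

    Args:
        n (int): seed value.
    """
```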

1 change: 0 additions & 1 deletion d3rlpy/__init__.py
@@ -23,7 +23,6 @@ def seed(n: int) -> None:
Args:
n (int): seed value.
"""
random.seed(n)
np.random.seed(n)
1 change: 0 additions & 1 deletion d3rlpy/algos/qlearning/awac.py
@@ -69,7 +69,6 @@ class AWACConfig(LearnableConfig):
:math:`A^\pi(s_t, a_t)`.
n_critics (int): the number of Q functions for ensemble.
update_actor_interval (int): interval to update policy function.
"""
actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4
17 changes: 1 addition & 16 deletions d3rlpy/algos/qlearning/base.py
@@ -182,7 +182,6 @@ def save_policy(self, fname: str) -> None:
Args:
fname: destination file path.
"""
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR

@@ -254,7 +253,6 @@ def predict(self, x: Observation) -> np.ndarray:
Returns:
greedy actions
"""
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
assert check_non_1d_array(x), "Input must have batch dimension."
@@ -300,7 +298,6 @@ def predict_value(self, x: Observation, action: np.ndarray) -> np.ndarray:
Returns:
predicted action-values
"""
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
assert check_non_1d_array(x), "Input must have batch dimension."
@@ -341,7 +338,6 @@ def sample_action(self, x: Observation) -> np.ndarray:
Returns:
sampled actions.
"""
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
assert check_non_1d_array(x), "Input must have batch dimension."
@@ -400,7 +396,6 @@ def fit(
Returns:
list of result tuples (epoch, metrics) per epoch.
"""
results = list(
self.fitter(
@@ -432,7 +427,7 @@ def fitter(
callback: Optional[Callable[[Self, int, int], None]] = None,
) -> Generator[Tuple[int, Dict[str, float]], None, None]:
"""Iterate over epochs steps to train with the given dataset. At each
-iteration algo methods and properties can be changed or queried.
+iteration algo methods and properties can be changed or queried.
.. code-block:: python
@@ -458,7 +453,6 @@ def fitter(
Returns:
iterator yielding current epoch and metrics dict.
"""
dataset_info = DatasetInfo.from_episodes(dataset.episodes)
LOG.info("dataset info", dataset_info=dataset_info)
@@ -593,7 +587,6 @@ def fit_online(
show_progress: flag to show progress bar for iterations.
callback: callable function that takes ``(algo, epoch, total_step)``
, which is called at the end of epochs.
"""

# create default replay buffer
@@ -740,7 +733,6 @@ def collect(
Returns:
replay buffer with the collected data.
"""
# create default replay buffer
if buffer is None:
@@ -806,7 +798,6 @@ def update(self, batch: TransitionMiniBatch) -> Dict[str, float]:
Returns:
dictionary of metrics.
"""
torch_batch = TorchMiniBatch.from_batch(
batch=batch,
@@ -828,7 +819,6 @@ def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
Returns:
dictionary of metrics.
"""
raise NotImplementedError

@@ -850,7 +840,6 @@ def copy_policy_from(
Args:
algo: algorithm object.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
assert isinstance(algo.impl, QLearningAlgoImplBase)
@@ -874,7 +863,6 @@ def copy_policy_optim_from(
Args:
algo: algorithm object.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
assert isinstance(algo.impl, QLearningAlgoImplBase)
@@ -898,7 +886,6 @@ def copy_q_function_from(
Args:
algo: algorithm object.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
assert isinstance(algo.impl, QLearningAlgoImplBase)
@@ -922,7 +909,6 @@ def copy_q_function_optim_from(
Args:
algo: algorithm object.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
assert isinstance(algo.impl, QLearningAlgoImplBase)
@@ -933,7 +919,6 @@ def reset_optimizer_states(self) -> None:
This is especially useful when fine-tuning policies with setting initial
optimizer states.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
self._impl.reset_optimizer_states()
4 changes: 0 additions & 4 deletions d3rlpy/algos/qlearning/explorers.py
@@ -93,7 +93,6 @@ def sample(
Returns:
:math:`\\epsilon`-greedy action.
"""
greedy_actions = algo.predict(x)
random_actions = np.random.randint(algo.action_size, size=x.shape[0])
@@ -105,7 +104,6 @@ def compute_epsilon(self, step: int) -> float:
Returns:
:math:`\\epsilon`.
"""
if step >= self._duration:
return self._end_epsilon
@@ -119,7 +117,6 @@ class NormalNoise(Explorer):
Args:
mean (float): mean.
std (float): standard deviation.
"""

_mean: float
@@ -140,7 +137,6 @@ def sample(
Returns:
action with noise injection.
"""
action = algo.predict(x)
noise = np.random.normal(self._mean, self._std, size=action.shape)
3 changes: 2 additions & 1 deletion d3rlpy/algos/qlearning/plas.py
@@ -181,7 +181,8 @@ def get_action_type(self) -> ActionSpace:

@dataclasses.dataclass()
class PLASWithPerturbationConfig(PLASConfig):
r"""Config of Policy in Latent Action Space algorithm with perturbation layer.
r"""Config of Policy in Latent Action Space algorithm with perturbation
layer.
PLAS with perturbation layer enables PLAS to output out-of-distribution
action.
2 changes: 0 additions & 2 deletions d3rlpy/algos/qlearning/random_policy.py
@@ -30,7 +30,6 @@ class RandomPolicyConfig(LearnableConfig):
``['uniform', 'normal']``.
normal_std (float): standard deviation of the normal distribution. This
is only used when ``distribution='normal'``.
"""
distribution: str = "uniform"
normal_std: float = 1.0
@@ -96,7 +95,6 @@ class DiscreteRandomPolicyConfig(LearnableConfig):
This is designed for data collection and lightweight interaction tests.
``fit`` and ``fit_online`` methods will raise exceptions.
"""

def create(self, device: DeviceArg = False) -> "DiscreteRandomPolicy": # type: ignore
9 changes: 1 addition & 8 deletions d3rlpy/algos/transformer/base.py
@@ -82,7 +82,6 @@ class StatefulTransformerWrapper(Generic[TTransformerImpl, TTransformerConfig]):
Args:
algo (TransformerAlgoBase): Transformer-based algorithm.
target_return (float): target return to achieve.
"""
_algo: "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]"
_target_return: float
@@ -121,7 +120,6 @@ def predict(self, x: Observation, reward: float) -> Union[np.ndarray, int]:
Returns:
action.
"""
self._observations.append(x)
self._rewards.append(reward)
@@ -182,7 +180,6 @@ def predict(self, inpt: TransformerInput) -> np.ndarray:
Returns:
action.
"""
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
with torch.no_grad():
@@ -216,7 +213,7 @@ def fit(
callback: Optional[Callable[[Self, int, int], None]] = None,
) -> None:
"""Iterate over epochs steps to train with the given dataset. At each
-iteration algo methods and properties can be changed or queried.
+iteration algo methods and properties can be changed or queried.
.. code-block:: python
@@ -240,7 +237,6 @@ def fit(
save_interval: interval to save parameters.
callback: callable function that takes ``(algo, epoch, total_step)``
, which is called every step.
"""
dataset_info = DatasetInfo.from_episodes(dataset.episodes)
LOG.info("dataset info", dataset_info=dataset_info)
@@ -344,7 +340,6 @@ def update(self, batch: TrajectoryMiniBatch) -> Dict[str, float]:
Returns:
dictionary of metrics.
"""
torch_batch = TorchTrajectoryMiniBatch.from_batch(
batch=batch,
@@ -366,7 +361,6 @@ def inner_update(self, batch: TorchTrajectoryMiniBatch) -> Dict[str, float]:
Returns:
dictionary of metrics.
"""
raise NotImplementedError

@@ -380,6 +374,5 @@ def as_stateful_wrapper(
Returns:
StatefulTransformerWrapper object.
"""
return StatefulTransformerWrapper(self, target_return)
1 change: 0 additions & 1 deletion d3rlpy/algos/transformer/decision_transformer.py
@@ -54,7 +54,6 @@ class DecisionTransformerConfig(TransformerConfig):
(``simple`` or ``global``).
warmup_steps (int): warmup steps for learning rate scheduler.
clip_grad_norm (float): norm of gradient clipping.
"""

batch_size: int = 64
(remaining changed files not shown)
