Skip to content

Commit

Permalink
Fix: Address flake8 and black formatting issues (#395)
Browse files Browse the repository at this point in the history
This commit applies the black formatter to the codebase and resolves
several flake8 errors.
  • Loading branch information
rickstaa committed Feb 6, 2024
1 parent cfcf81c commit 517ee30
Show file tree
Hide file tree
Showing 74 changed files with 201 additions and 128 deletions.
31 changes: 19 additions & 12 deletions examples/eval_robustness.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Script version of the eval robustness tool. This can be used to manually evaluate the
disturbance if you don't want to implement a disturber.
"""

import os
from pathlib import Path

Expand Down Expand Up @@ -470,14 +471,16 @@ def noise_disturbance(mean, std):

# Replace observations and references with short names.
obs_disturbance_df["observation"] = obs_disturbance_df["observation"].apply(
lambda x: "observation"
if x == "observation"
else x.replace("observation_", "obs_")
lambda x: (
"observation"
if x == "observation"
else x.replace("observation_", "obs_")
)
)
refs_disturbance_df["reference"] = refs_disturbance_df["reference"].apply(
lambda x: "reference"
if x == "reference"
else x.replace("reference_", "ref_")
lambda x: (
"reference" if x == "reference" else x.replace("reference_", "ref_")
)
)

# Initialize plot.
Expand Down Expand Up @@ -610,19 +613,23 @@ def noise_disturbance(mean, std):
ref_errors_disturbance_df["reference_error"] = ref_errors_disturbance_df[
"reference_error"
].apply(
lambda x: "reference_error"
if x == "reference_error"
else x.replace("reference_error_", "ref_error_")
lambda x: (
"reference_error"
if x == "reference_error"
else x.replace("reference_error_", "ref_error_")
)
)

# Initialize plot.
fig, ax = plt.subplots(figsize=(12, 6), tight_layout=True)

# Create plot title.
plot_title = "Mean {} ".format(
"reference errors"
if len(available_ref_errors) > 1
else "reference error",
(
"reference errors"
if len(available_ref_errors) > 1
else "reference error"
),
)
plot_title += (
"under 'RandomActionNoise' disturber with mean {} and std {}.".format(
Expand Down
1 change: 1 addition & 0 deletions examples/manual_env_policy_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
the CLI fails. It uses the ``Oscillator-v1`` environment that is found in the
:stable_gym:`stable_gym <>` package.
"""

import gymnasium as gym

from stable_learning_control.utils.test_policy import (
Expand Down
1 change: 1 addition & 0 deletions examples/manual_robustness_eval_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
generated by the :ref:`robustness evaluation utility <eval_robustness>` utility. This
example plots the first observation and reference for each disturbance in the dataset.
"""

import argparse
from pathlib import Path

Expand Down
2 changes: 1 addition & 1 deletion examples/pytorch/lac_ray_hyper_parameter_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
.. _`Weights & Biases`: https://wandb.ai/site
.. _`Weights & Biases documentation`: https://docs.wandb.ai/
""" # noqa: E501

import os.path as osp

import gymnasium as gym
Expand All @@ -49,7 +50,6 @@ def train_lac(config):
Args:
config (dict): The Ray tuning configuration dictionary.
"""

# Unpack trainable arguments.
env_name = config.pop("env_name")

Expand Down
1 change: 1 addition & 0 deletions examples/pytorch/sac_exp_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
.. _`SpinningUp documentation`: https://spinningup.openai.com/en/latest/user/running.html#using-experimentgrid
""" # noqa

import argparse

import torch
Expand Down
2 changes: 1 addition & 1 deletion examples/pytorch/sac_ray_hyper_parameter_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
.. _`Weights & Biases`: https://wandb.ai/site
.. _`Weights & Biases documentation`: https://docs.wandb.ai/
""" # noqa: E501

import os.path as osp

import gymnasium as gym
Expand All @@ -49,7 +50,6 @@ def train_sac(config):
Args:
config (dict): The Ray tuning configuration dictionary.
"""

# Unpack trainable arguments.
env_name = config.pop("env_name")

Expand Down
2 changes: 1 addition & 1 deletion examples/tf2/lac_ray_hyper_parameter_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
.. _`Weights & Biases`: https://wandb.ai/site
.. _`Weights & Biases documentation`: https://docs.wandb.ai/
""" # noqa: E501

import os.path as osp

import gymnasium as gym
Expand All @@ -49,7 +50,6 @@ def train_lac(config):
Args:
config (dict): The Ray tuning configuration dictionary.
"""

# Unpack trainable arguments.
env_name = config.pop("env_name")

Expand Down
1 change: 1 addition & 0 deletions examples/tf2/sac_exp_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
.. _`SpinningUp documentation`: https://spinningup.openai.com/en/latest/user/running.html#using-experimentgrid
""" # noqa

import argparse

import tensorflow as tf
Expand Down
2 changes: 1 addition & 1 deletion examples/tf2/sac_ray_hyper_parameter_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
.. _`Weights & Biases`: https://wandb.ai/site
.. _`Weights & Biases documentation`: https://docs.wandb.ai/
""" # noqa: E501

import os.path as osp

import gymnasium as gym
Expand All @@ -49,7 +50,6 @@ def train_sac(config):
Args:
config (dict): The Ray tuning configuration dictionary.
"""

# Unpack trainable arguments.
env_name = config.pop("env_name")

Expand Down
1 change: 1 addition & 0 deletions sandbox/test_finite_horizon_replay_buffer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Script used for performing some quick tests on the FiniteHorizonReplayBuffer class.
"""

import gymnasium as gym

# from stable_learning_control.common.buffers import TrajectoryBuffer
Expand Down
1 change: 1 addition & 0 deletions sandbox/test_gym_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
example, we use the ``Oscillator-v1`` environment in the :stable_gym:`stable_gym <>`
package.
"""

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
Expand Down
1 change: 1 addition & 0 deletions sandbox/test_replay_buffer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Script used for performing some quick tests on the ReplayBuffer class.
"""

import gymnasium as gym

# from stable_learning_control.common.buffers import TrajectoryBuffer
Expand Down
1 change: 1 addition & 0 deletions sandbox/test_traj_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
buffer was created for a new monte-carlo algorithm we had in mind. The buffer is
designed to store trajectories of variable length.
"""

import gymnasium as gym

# from stable_learning_control.common.buffers import TrajectoryBuffer
Expand Down
6 changes: 4 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
[flake8]
max-line-length = 89
extend-ignore = E203
exclude =
docs/source/conf.py,
sandbox,
build,
node_modules
node_modules,
tests,
stable_learning_control/version.py
per-file-ignores =
__init__.py: F401, E501
extend-ignore = E203, D400, D401, D205
1 change: 1 addition & 0 deletions stable_learning_control/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module that Initialises the stable_learning_control package."""

# Put algorithms in main namespace.
from stable_learning_control.algos.pytorch.lac.lac import lac as lac_pytorch
from stable_learning_control.algos.pytorch.latc.latc import latc as latc_pytorch
Expand Down
1 change: 1 addition & 0 deletions stable_learning_control/algos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
research. As a result, the TensorFlow implementation has yet to be thoroughly tested,
and no guarantees can be given about the correctness of these algorithms.
"""

from stable_learning_control.algos.pytorch.lac.lac import LAC as LAC_pytorch
from stable_learning_control.algos.pytorch.sac.sac import SAC as SAC_pytorch
from stable_learning_control.utils.import_utils import tf_installed
Expand Down
1 change: 1 addition & 0 deletions stable_learning_control/algos/common/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions that are used in multiple Pytorch and TensorFlow algorithms."""

import importlib

import numpy as np
Expand Down
4 changes: 2 additions & 2 deletions stable_learning_control/algos/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Contains the Pytorch implementations of the RL algorithms.
"""
"""Contains the Pytorch implementations of the RL algorithms."""

from stable_learning_control.algos.pytorch.lac.lac import LAC
from stable_learning_control.algos.pytorch.sac.sac import SAC
4 changes: 2 additions & 2 deletions stable_learning_control/algos/pytorch/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Contains several functions that are used across all the RL algorithms.
"""
"""Contains several functions that are used across all the RL algorithms."""

from stable_learning_control.algos.pytorch.common.get_lr_scheduler import (
get_lr_scheduler,
)
4 changes: 2 additions & 2 deletions stable_learning_control/algos/pytorch/common/buffers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Contains several replay buffers used in the Pytorch algorithms.
"""
"""Contains several replay buffers used in the Pytorch algorithms."""

import torch

from stable_learning_control.algos.common.buffers import (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains functions used for creating Pytorch learning rate schedulers."""

from decimal import Decimal

import numpy as np
Expand All @@ -22,7 +23,7 @@ def get_exponential_decay_rate(lr_start, lr_final, steps):


def calc_linear_decay_rate(lr_init, lr_final, steps):
"""Returns the linear decay factor (G) needed to achieve a given final learning
r"""Returns the linear decay factor (G) needed to achieve a given final learning
rate at a certain step. This decay factor can for example be used with a
:class:`torch.optim.lr_scheduler.LambdaLR` scheduler. Keep in mind that this
function assumes the following formula for the learning rate decay.
Expand Down
4 changes: 2 additions & 2 deletions stable_learning_control/algos/pytorch/common/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Contains several Pytorch helper functions.
"""
"""Contains several Pytorch helper functions."""

import numpy as np
import torch
import torch.nn as nn
Expand Down
4 changes: 2 additions & 2 deletions stable_learning_control/algos/pytorch/lac/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""A Lyapunov (soft) Actor-Critic Agent.
"""
"""A Lyapunov (soft) Actor-Critic Agent."""

from stable_learning_control.algos.pytorch.lac.lac import LAC, lac
4 changes: 2 additions & 2 deletions stable_learning_control/algos/pytorch/latc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""A Lyapunov (soft) Actor-Twin Critic Agent.
"""
"""A Lyapunov (soft) Actor-Twin Critic Agent."""

from stable_learning_control.algos.pytorch.latc.latc import latc
4 changes: 2 additions & 2 deletions stable_learning_control/algos/pytorch/policies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Policies and networks used to create the RL agents.
"""
"""Policies and networks used to create the RL agents."""

from stable_learning_control.algos.pytorch.policies.actors.squashed_gaussian_actor import (
SquashedGaussianActor,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Actor network structures.
"""
"""Actor network structures."""

from stable_learning_control.algos.pytorch.policies.actors.squashed_gaussian_actor import (
SquashedGaussianActor,
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This module contains a Pytorch implementation of the Squashed Gaussian Actor policy of
`Haarnoja et al. 2019 <https://arxiv.org/abs/1812.05905>`_.
"""

import numpy as np
import torch
import torch.nn as nn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This module contains a Pytorch implementation of the Lyapunov Critic policy of
`Han et al. 2020 <https://arxiv.org/abs/2004.14288>`_.
"""

import torch
import torch.nn as nn

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This module contains a Pytorch implementation of the Q Critic policy of
`Haarnoja et al. 2019 <https://arxiv.org/abs/1812.05905>`_.
"""

import torch
import torch.nn as nn

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Critic network structures.
"""
"""Critic network structures."""

from stable_learning_control.algos.pytorch.policies.critics.L_critic import LCritic
from stable_learning_control.algos.pytorch.policies.critics.Q_critic import QCritic
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
This module contains a Pytorch implementation of the Soft Actor Critic policy of
`Haarnoja et al. 2019 <https://arxiv.org/abs/1812.05905>`_.
"""

import torch
import torch.nn as nn

Expand Down
4 changes: 2 additions & 2 deletions stable_learning_control/algos/pytorch/sac/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""A Soft Actor-Critic Agent.
"""
"""A Soft Actor-Critic Agent."""

from stable_learning_control.algos.pytorch.sac.sac import sac
4 changes: 2 additions & 2 deletions stable_learning_control/algos/tf2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Contains the TensorFlow 2.x implementations of the RL algorithms.
"""
"""Contains the TensorFlow 2.x implementations of the RL algorithms."""

from stable_learning_control.algos.tf2.lac.lac import LAC
from stable_learning_control.algos.tf2.sac.sac import SAC
4 changes: 2 additions & 2 deletions stable_learning_control/algos/tf2/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Contains several functions that are used across all the RL algorithms.
"""
"""Contains several functions that are used across all the RL algorithms."""

from stable_learning_control.algos.tf2.common.get_lr_scheduler import get_lr_scheduler
4 changes: 3 additions & 1 deletion stable_learning_control/algos/tf2/common/bijectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
`the TensorFlow documentation <https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/Bijector>`_ and
`this stackoverflow question <https://stackoverflow.com/questions/56425301/what-is-bijectors-in-layman-terms-in-tensorflow-probability>`_.
""" # noqa: E501

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
Expand All @@ -12,7 +13,8 @@

class SquashBijector(tfp.bijectors.Bijector):
"""A squash bijector used to keeps track of the distribution properties when the
distribution is transformed using the tanh squash function."""
distribution is transformed using the tanh squash function.
"""

def __init__(self, validate_args=False, name="tanh"):
"""Initiate squashed bijector object.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module used for creating TensorFlow learning rate schedulers."""

import numpy as np

from stable_learning_control.utils.import_utils import import_tf
Expand Down
4 changes: 2 additions & 2 deletions stable_learning_control/algos/tf2/common/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Several TensorFlow helper functions.
"""
"""Several TensorFlow helper functions."""

import numpy as np
import tensorflow as tf

Expand Down
4 changes: 2 additions & 2 deletions stable_learning_control/algos/tf2/lac/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""A Lyapunov Actor Critic Agent.
"""
"""A Lyapunov Actor Critic Agent."""

from stable_learning_control.algos.tf2.lac.lac import LAC, lac
4 changes: 2 additions & 2 deletions stable_learning_control/algos/tf2/latc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""A Lyapunov (soft) Actor-Twin Critic Agent.
"""
"""A Lyapunov (soft) Actor-Twin Critic Agent."""

from stable_learning_control.algos.tf2.latc.latc import latc

0 comments on commit 517ee30

Please sign in to comment.