From 3fbc9be54f8700633f31c2ccb528189e7052948e Mon Sep 17 00:00:00 2001
From: muupan
Date: Wed, 14 Oct 2020 11:30:47 +0900
Subject: [PATCH 1/5] Add black-compatible isort config

---
 setup.cfg | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/setup.cfg b/setup.cfg
index 5ba1b695d..13664eed0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -35,3 +35,11 @@ ignore_missing_imports = True
 
 [mypy-slimevolleygym.*]
 ignore_missing_imports = True
+
+[isort]
+multi_line_output = 3
+include_trailing_comma = True
+force_grid_wrap = 0
+use_parentheses = True
+ensure_newline_before_comments = True
+line_length = 88
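Note: the six [isort] settings above are the documented Black-compatible combination; on isort >= 5 the single shorthand "profile = black" expands to exactly these values. A minimal sketch of the behavior they configure, using isort's Python API (the snippet is illustrative and not part of the patch; assumes isort >= 5, which exposes isort.code()):

    import isort

    snippet = "from pfrl import utils\nfrom pfrl import experiments\n"
    # Imports from the same module are merged and alphabetized, which is
    # exactly the kind of rewrite applied throughout the patch below.
    print(isort.code(snippet, profile="black"), end="")
    # -> from pfrl import experiments, utils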
From 1799246ae7f75b480f815bb270b15b85d7258f8b Mon Sep 17 00:00:00 2001
From: muupan
Date: Wed, 14 Oct 2020 11:34:50 +0900
Subject: [PATCH 2/5] Apply isort to pfrl, tests, and examples

---
 examples/atari/reproduction/a3c/train_a3c.py | 6 +--
 examples/atari/reproduction/dqn/train_dqn.py | 16 +++----
 examples/atari/reproduction/iqn/train_iqn.py | 5 +--
 .../reproduction/rainbow/train_rainbow.py | 9 ++---
 examples/atari/train_a2c_ale.py | 3 +-
 examples/atari/train_acer_ale.py | 6 +--
 examples/atari/train_categorical_dqn_ale.py | 8 +---
 examples/atari/train_dqn_ale.py | 15 +++----
 examples/atari/train_dqn_batch_ale.py | 15 +++----
 examples/atari/train_drqn_ale.py | 6 +--
 examples/atari/train_ppo_ale.py | 5 +--
 .../atlas/train_soft_actor_critic_atlas.py | 7 +---
 examples/grasping/train_dqn_batch_grasping.py | 9 ++---
 examples/gym/train_categorical_dqn_gym.py | 8 +---
 examples/gym/train_dqn_gym.py | 15 +++----
 examples/gym/train_reinforce_gym.py | 6 +--
 .../mujoco/reproduction/ddpg/train_ddpg.py | 8 +---
 examples/mujoco/reproduction/ppo/train_ppo.py | 3 +-
 .../train_soft_actor_critic.py | 11 ++---
 examples/mujoco/reproduction/td3/train_td3.py | 5 +--
 examples/slimevolley/train_rainbow.py | 9 ++---
 pfrl/__init__.py | 2 +-
 pfrl/action_value.py | 4 +-
 pfrl/agent.py | 10 +----
 pfrl/agents/__init__.py | 2 +-
 pfrl/agents/a2c.py | 4 +-
 pfrl/agents/a3c.py | 6 +--
 pfrl/agents/acer.py | 8 ++--
 pfrl/agents/al.py | 1 +
 pfrl/agents/ddpg.py | 8 ++--
 pfrl/agents/dpp.py | 3 +-
 pfrl/agents/dqn.py | 40 +++++++++----------
 pfrl/agents/iqn.py | 3 +-
 pfrl/agents/ppo.py | 14 ++++---
 pfrl/agents/reinforce.py | 4 +-
 pfrl/agents/soft_actor_critic.py | 8 ++--
 pfrl/agents/state_q_function_actor.py | 10 +++--
 pfrl/agents/td3.py | 8 ++--
 pfrl/agents/trpo.py | 32 ++++++++-------
 pfrl/collections/persistent_collections.py | 5 +--
 pfrl/distributions/delta.py | 2 +-
 pfrl/env.py | 3 +-
 pfrl/envs/abc.py | 2 +-
 pfrl/envs/multiprocess_vector_env.py | 3 +-
 pfrl/experiments/__init__.py | 5 +--
 pfrl/experiments/hooks.py | 3 +-
 pfrl/experiments/prepare_output_dir.py | 2 +-
 pfrl/experiments/train_agent.py | 3 +-
 pfrl/experiments/train_agent_async.py | 15 ++++---
 pfrl/experiments/train_agent_batch.py | 6 +--
 pfrl/explorer.py | 3 +-
 pfrl/explorers/__init__.py | 2 +-
 pfrl/explorers/boltzmann.py | 2 +-
 pfrl/functions/lower_triangular_matrix.py | 2 +-
 pfrl/initializers/__init__.py | 3 +-
 pfrl/initializers/chainer_default.py | 1 +
 pfrl/nn/__init__.py | 8 ++--
 pfrl/nn/atari_cnn.py | 1 +
 pfrl/nn/empirical_normalization.py | 1 -
 pfrl/nn/mlp.py | 4 +-
 pfrl/nn/mlp_bn.py | 1 +
 pfrl/nn/noisy_linear.py | 1 -
 pfrl/nn/recurrent_sequential.py | 10 +++--
 pfrl/policy.py | 5 +--
 pfrl/q_function.py | 3 +-
 pfrl/q_functions/dueling_dqn.py | 2 +-
 pfrl/q_functions/state_action_q_functions.py | 8 ++--
 pfrl/q_functions/state_q_functions.py | 17 ++++----
 pfrl/replay_buffer.py | 12 +++---
 pfrl/replay_buffers/episodic.py | 3 +-
 pfrl/replay_buffers/persistent.py | 6 +--
 pfrl/replay_buffers/prioritized_episodic.py | 5 +--
 pfrl/replay_buffers/replay_buffer.py | 2 +-
 pfrl/utils/__init__.py | 8 ++--
 pfrl/utils/async_.py | 3 +-
 pfrl/utils/batch_states.py | 4 +-
 pfrl/utils/random_seed.py | 3 +-
 pfrl/wrappers/__init__.py | 7 ----
 pfrl/wrappers/atari_wrappers.py | 2 -
 pfrl/wrappers/monitor.py | 2 +-
 pfrl/wrappers/vector_frame_stack.py | 2 +-
 tests/agents_tests/basetest_ddpg.py | 9 ++---
 tests/agents_tests/basetest_dqn_like.py | 6 +--
 tests/agents_tests/basetest_training.py | 16 +++++---
 tests/agents_tests/test_a2c.py | 6 ++-
 tests/agents_tests/test_a3c.py | 11 ++---
 tests/agents_tests/test_acer.py | 9 ++---
 tests/agents_tests/test_al.py | 1 +
 tests/agents_tests/test_categorical_dqn.py | 17 +++----
 tests/agents_tests/test_ddpg.py | 4 +-
 .../test_double_categorical_dqn.py | 3 +-
 tests/agents_tests/test_double_dqn.py | 1 +
 tests/agents_tests/test_double_pal.py | 4 +-
 tests/agents_tests/test_dpp.py | 7 ++--
 tests/agents_tests/test_dqn.py | 10 ++---
 tests/agents_tests/test_iqn.py | 8 ++--
 tests/agents_tests/test_pal.py | 1 +
 tests/agents_tests/test_ppo.py | 24 ++++++-----
 tests/agents_tests/test_reinforce.py | 10 +++--
 tests/agents_tests/test_soft_actor_critic.py | 16 ++++----
 tests/agents_tests/test_td3.py | 12 ++++--
 tests/agents_tests/test_trpo.py | 18 ++++++---
 tests/collections_tests/test_prioritized.py | 2 +-
 .../test_prepare_output_dir.py | 3 +-
 .../test_train_agent_async.py | 2 +-
 tests/explorers_tests/test_boltzmann.py | 2 +-
 .../test_lower_triangular_matrix.py | 4 +-
 tests/nn_tests/test_mlp_bn.py | 8 ++--
 tests/nn_tests/test_noisy_chain.py | 5 ++-
 tests/nn_tests/test_noisy_linear.py | 5 +--
 tests/nn_tests/test_recurrent_branched.py | 13 +++---
 tests/nn_tests/test_recurrent_sequential.py | 6 +--
 .../basetest_state_action_q_function.py | 2 +-
 .../test_state_action_q_function.py | 8 ++--
 .../replay_buffers_test/test_replay_buffer.py | 5 +--
 tests/utils_tests/test_batch_states.py | 4 +-
 tests/utils_tests/test_clip_l2_grad_norm.py | 2 +-
 tests/utils_tests/test_copy_param.py | 4 +-
 .../utils_tests/test_mode_of_distribution.py | 2 +-
 tests/utils_tests/test_random_seed.py | 5 ++-
 tests/utils_tests/test_recurrent.py | 8 ++--
 tests/utils_tests/test_stoppable_thread.py | 2 +-
 tests/wrappers_tests/test_atari_wrappers.py | 4 +-
 tests/wrappers_tests/test_monitor.py | 2 +-
 tests/wrappers_tests/test_render.py | 1 +
 .../wrappers_tests/test_vector_frame_stack.py | 6 +--
 126 files changed, 366 insertions(+), 456 deletions(-)

diff --git a/examples/atari/reproduction/a3c/train_a3c.py b/examples/atari/reproduction/a3c/train_a3c.py
index 3bfb8622a..b86e55fb4 100644
--- a/examples/atari/reproduction/a3c/train_a3c.py
+++ b/examples/atari/reproduction/a3c/train_a3c.py
@@ -8,12 +8,10 @@ from torch import nn
 
 import pfrl
+from pfrl import experiments, utils
 from pfrl.agents import a3c
-from pfrl import experiments
-from pfrl import utils
-from pfrl.policies import SoftmaxCategoricalHead
 from pfrl.optimizers import SharedRMSpropEpsInsideSqrt
-
+from pfrl.policies import SoftmaxCategoricalHead
 from pfrl.wrappers import atari_wrappers

diff --git a/examples/atari/reproduction/dqn/train_dqn.py b/examples/atari/reproduction/dqn/train_dqn.py
index 89572f09b..1677109e6 100644
--- a/examples/atari/reproduction/dqn/train_dqn.py
+++ b/examples/atari/reproduction/dqn/train_dqn.py
@@ -1,21 +1,17 @@
 import argparse
+import json
 import os
 
-import torch.nn as nn
 import numpy as np
+import torch.nn as nn
 
 import pfrl
-from pfrl.q_functions import DiscreteActionValueHead
-from pfrl import agents
-from pfrl import experiments
-from pfrl import explorers
+from pfrl import agents, experiments, explorers
 from pfrl import nn as pnn
-from pfrl import utils
-from pfrl import replay_buffers
-
-from pfrl.wrappers import atari_wrappers
+from pfrl import replay_buffers, utils
 from pfrl.initializers import init_chainer_default
-import json
+from pfrl.q_functions import DiscreteActionValueHead
+from pfrl.wrappers import atari_wrappers
 
 
 def main():

diff --git a/examples/atari/reproduction/iqn/train_iqn.py b/examples/atari/reproduction/iqn/train_iqn.py
index 205f4e362..0e605467d 100644
--- a/examples/atari/reproduction/iqn/train_iqn.py
+++ b/examples/atari/reproduction/iqn/train_iqn.py
@@ -7,10 +7,7 @@ from torch import nn
 
 import pfrl
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
+from pfrl import experiments, explorers, replay_buffers, utils
 from pfrl.wrappers import atari_wrappers

diff --git a/examples/atari/reproduction/rainbow/train_rainbow.py b/examples/atari/reproduction/rainbow/train_rainbow.py
index 7e2b75bc1..5edbacc4b 100644
--- a/examples/atari/reproduction/rainbow/train_rainbow.py
+++ b/examples/atari/reproduction/rainbow/train_rainbow.py
@@ -2,17 +2,14 @@
 import json
 import os
 
-import torch
 import numpy as np
+import torch
 
 import pfrl
-from pfrl import agents
-from pfrl import experiments
-from pfrl import explorers
+from pfrl import agents, experiments, explorers
 from pfrl import nn as pnn
-from pfrl import utils
+from pfrl import replay_buffers, utils
 from pfrl.q_functions import DistributionalDuelingDQN
-from pfrl import replay_buffers
 from pfrl.wrappers import atari_wrappers

diff --git a/examples/atari/train_a2c_ale.py b/examples/atari/train_a2c_ale.py
index 7e6dfa8d3..16a90a8ed 100644
--- a/examples/atari/train_a2c_ale.py
+++ b/examples/atari/train_a2c_ale.py
@@ -6,9 +6,8 @@ from torch import nn
 
 import pfrl
+from pfrl import experiments, utils
 from pfrl.agents import a2c
-from pfrl import experiments
-from pfrl import utils
 from pfrl.policies import SoftmaxCategoricalHead
 from pfrl.wrappers import atari_wrappers

diff --git a/examples/atari/train_acer_ale.py b/examples/atari/train_acer_ale.py
index f21c5fdd4..80824e8fe 100644
--- a/examples/atari/train_acer_ale.py
+++ b/examples/atari/train_acer_ale.py
@@ -10,13 +10,11 @@ from torch import nn
 
 import pfrl
+from pfrl import experiments, utils
 from pfrl.agents import acer
-from pfrl import experiments
-from pfrl import utils
-from pfrl.replay_buffers import EpisodicReplayBuffer
 from pfrl.policies import SoftmaxCategoricalHead
 from pfrl.q_functions import DiscreteActionValueHead
-
+from pfrl.replay_buffers import EpisodicReplayBuffer
 from pfrl.wrappers import atari_wrappers

diff --git a/examples/atari/train_categorical_dqn_ale.py b/examples/atari/train_categorical_dqn_ale.py
index 61744481f..4620209a0 100644
--- a/examples/atari/train_categorical_dqn_ale.py
+++ b/examples/atari/train_categorical_dqn_ale.py
@@ -1,14 +1,10 @@
 import argparse
 
-import torch
 import numpy as np
+import torch
 
 import pfrl
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
-
+from pfrl import experiments, explorers, replay_buffers, utils
 from pfrl.wrappers import atari_wrappers

diff --git a/examples/atari/train_dqn_ale.py b/examples/atari/train_dqn_ale.py
index bd68edbab..7f7dcc213 100644
--- a/examples/atari/train_dqn_ale.py
+++ b/examples/atari/train_dqn_ale.py
@@ -1,21 +1,16 @@
 import argparse
 
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 
 import pfrl
-from pfrl.q_functions import DiscreteActionValueHead
-from pfrl import agents
-from pfrl import experiments
-from pfrl import explorers
+from pfrl import agents, experiments, explorers
 from pfrl import nn as pnn
-from pfrl import utils
-from pfrl.q_functions import DuelingDQN
-from pfrl import replay_buffers
-
-from pfrl.wrappers import atari_wrappers
+from pfrl import replay_buffers, utils
 from pfrl.initializers import init_chainer_default
+from pfrl.q_functions import DiscreteActionValueHead, DuelingDQN
+from pfrl.wrappers import atari_wrappers
 
 
 class SingleSharedBias(nn.Module):

diff --git a/examples/atari/train_dqn_batch_ale.py b/examples/atari/train_dqn_batch_ale.py
index 25bd52dce..413ab830d 100644
--- a/examples/atari/train_dqn_batch_ale.py
+++ b/examples/atari/train_dqn_batch_ale.py
@@ -1,23 +1,18 @@
 import argparse
 import functools
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import numpy as np
 
 import pfrl
-from pfrl import agents
-from pfrl import experiments
-from pfrl import explorers
+from pfrl import agents, experiments, explorers
 from pfrl import nn as pnn
-from pfrl import utils
-from pfrl.q_functions import DiscreteActionValueHead
-from pfrl.q_functions import DuelingDQN
-from pfrl import replay_buffers
-
-from pfrl.wrappers import atari_wrappers
+from pfrl import replay_buffers, utils
 from pfrl.initializers import init_chainer_default
+from pfrl.q_functions import DiscreteActionValueHead, DuelingDQN
+from pfrl.wrappers import atari_wrappers
 
 
 class SingleSharedBias(nn.Module):

diff --git a/examples/atari/train_drqn_ale.py b/examples/atari/train_drqn_ale.py
index 9917bb2ea..0b1b3e219 100644
--- a/examples/atari/train_drqn_ale.py
+++ b/examples/atari/train_drqn_ale.py
@@ -18,12 +18,8 @@ from torch import nn
 
 import pfrl
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
+from pfrl import experiments, explorers, replay_buffers, utils
 from pfrl.q_functions import DiscreteActionValueHead
-
 from pfrl.wrappers import atari_wrappers

diff --git a/examples/atari/train_ppo_ale.py b/examples/atari/train_ppo_ale.py
index 68fb28539..aa6e10776 100644
--- a/examples/atari/train_ppo_ale.py
+++ b/examples/atari/train_ppo_ale.py
@@ -15,11 +15,10 @@ from torch import nn
 
 import pfrl
+from pfrl import experiments, utils
 from pfrl.agents import PPO
-from pfrl import experiments
-from pfrl import utils
-from pfrl.wrappers import atari_wrappers
 from pfrl.policies import SoftmaxCategoricalHead
+from pfrl.wrappers import atari_wrappers
 
 
 def main():

diff --git a/examples/atlas/train_soft_actor_critic_atlas.py b/examples/atlas/train_soft_actor_critic_atlas.py
index 6e1f4a559..dafda918b 100644
--- a/examples/atlas/train_soft_actor_critic_atlas.py
+++ b/examples/atlas/train_soft_actor_critic_atlas.py
@@ -8,14 +8,11 @@
 import gym.wrappers
 import numpy as np
 import torch
-from torch import nn
-from torch import distributions
+from torch import distributions, nn
 
 import pfrl
-from pfrl import experiments
+from pfrl import experiments, replay_buffers, utils
 from pfrl.nn.lmbda import Lambda
-from pfrl import utils
-from pfrl import replay_buffers
 
 
 def make_env(args, seed, test):

diff --git a/examples/grasping/train_dqn_batch_grasping.py b/examples/grasping/train_dqn_batch_grasping.py
index 8e78ce06f..2e458e090 100644
--- a/examples/grasping/train_dqn_batch_grasping.py
+++ b/examples/grasping/train_dqn_batch_grasping.py
@@ -9,11 +9,8 @@ from torch import nn
 
 import pfrl
+from pfrl import experiments, explorers, replay_buffers, utils
 from pfrl.q_functions import DiscreteActionValueHead
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
 
 
 class CastAction(gym.ActionWrapper):
@@ -243,9 +240,9 @@ def main():
     max_episode_steps = 8
 
     def make_env(idx, test):
-        from pybullet_envs.bullet.kuka_diverse_object_gym_env import (
+        from pybullet_envs.bullet.kuka_diverse_object_gym_env import (  # NOQA
             KukaDiverseObjectEnv,
-        )  # NOQA
+        )
 
         # Use different random seeds for train and test envs
         process_seed = int(process_seeds[idx])

diff --git a/examples/gym/train_categorical_dqn_gym.py b/examples/gym/train_categorical_dqn_gym.py
index a4e4b7ea2..9971c5a21 100644
--- a/examples/gym/train_categorical_dqn_gym.py
+++ b/examples/gym/train_categorical_dqn_gym.py
@@ -10,15 +10,11 @@
 import argparse
 import sys
 
-import torch
 import gym
+import torch
 
 import pfrl
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import q_functions
-from pfrl import replay_buffers
+from pfrl import experiments, explorers, q_functions, replay_buffers, utils
 
 
 def main():

diff --git a/examples/gym/train_dqn_gym.py b/examples/gym/train_dqn_gym.py
index 2714d9cdd..bb9a0bd49 100644
--- a/examples/gym/train_dqn_gym.py
+++ b/examples/gym/train_dqn_gym.py
@@ -12,22 +12,19 @@
 """
 import argparse
-import sys
 import os
+import sys
 
-import torch.optim as optim
 import gym
-from gym import spaces
 import numpy as np
+import torch.optim as optim
+from gym import spaces
 
 import pfrl
-from pfrl.agents.dqn import DQN
-from pfrl import experiments
-from pfrl import explorers
+from pfrl import experiments, explorers
 from pfrl import nn as pnn
-from pfrl import utils
-from pfrl import q_functions
-from pfrl import replay_buffers
+from pfrl import q_functions, replay_buffers, utils
+from pfrl.agents.dqn import DQN
 
 
 def main():

diff --git a/examples/gym/train_reinforce_gym.py b/examples/gym/train_reinforce_gym.py
index d334356f4..de187cb0d 100644
--- a/examples/gym/train_reinforce_gym.py
+++ b/examples/gym/train_reinforce_gym.py
@@ -17,10 +17,8 @@ from torch import nn
 
 import pfrl
-from pfrl import experiments
-from pfrl import utils
-from pfrl.policies import SoftmaxCategoricalHead
-from pfrl.policies import GaussianHeadWithFixedCovariance
+from pfrl import experiments, utils
+from pfrl.policies import GaussianHeadWithFixedCovariance, SoftmaxCategoricalHead
 
 
 def main():

diff --git a/examples/mujoco/reproduction/ddpg/train_ddpg.py b/examples/mujoco/reproduction/ddpg/train_ddpg.py
index 8fe180128..d8dca96a6 100644
--- a/examples/mujoco/reproduction/ddpg/train_ddpg.py
+++ b/examples/mujoco/reproduction/ddpg/train_ddpg.py
@@ -15,13 +15,9 @@ from torch import nn
 
 import pfrl
+from pfrl import experiments, explorers, replay_buffers, utils
 from pfrl.agents.ddpg import DDPG
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
-from pfrl.nn import ConcatObsAndAction
-from pfrl.nn import BoundByTanh
+from pfrl.nn import BoundByTanh, ConcatObsAndAction
 from pfrl.policies import DeterministicHead

diff --git a/examples/mujoco/reproduction/ppo/train_ppo.py b/examples/mujoco/reproduction/ppo/train_ppo.py
index b2802f88e..4ef24e37c 100644
--- a/examples/mujoco/reproduction/ppo/train_ppo.py
+++ b/examples/mujoco/reproduction/ppo/train_ppo.py
@@ -13,9 +13,8 @@ from torch import nn
 
 import pfrl
+from pfrl import experiments, utils
 from pfrl.agents import PPO
-from pfrl import experiments
-from pfrl import utils
 
 
 def main():

diff --git a/examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py b/examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py
index 158de19f9..91be09af6 100644
--- a/examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py
+++ b/examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py
@@ -4,23 +4,20 @@
 as possible.
 """
 import argparse
-from distutils.version import LooseVersion
 import functools
 import logging
 import sys
+from distutils.version import LooseVersion
 
-import torch
-from torch import nn
-from torch import distributions
 import gym
 import gym.wrappers
 import numpy as np
+import torch
+from torch import distributions, nn
 
 import pfrl
-from pfrl import experiments
+from pfrl import experiments, replay_buffers, utils
 from pfrl.nn.lmbda import Lambda
-from pfrl import utils
-from pfrl import replay_buffers
 
 
 def main():

diff --git a/examples/mujoco/reproduction/td3/train_td3.py b/examples/mujoco/reproduction/td3/train_td3.py
index e0ffb6077..64d978c1d 100644
--- a/examples/mujoco/reproduction/td3/train_td3.py
+++ b/examples/mujoco/reproduction/td3/train_td3.py
@@ -15,10 +15,7 @@ from torch import nn
 
 import pfrl
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
+from pfrl import experiments, explorers, replay_buffers, utils
 
 
 def main():

diff --git a/examples/slimevolley/train_rainbow.py b/examples/slimevolley/train_rainbow.py
index 3960c1d80..1579a89ab 100644
--- a/examples/slimevolley/train_rainbow.py
+++ b/examples/slimevolley/train_rainbow.py
@@ -2,17 +2,14 @@
 
 import gym
 import gym.spaces
+import numpy as np
 import torch
 from torch import nn
-import numpy as np
 
 import pfrl
-from pfrl import agents
-from pfrl import experiments
-from pfrl import explorers
+from pfrl import agents, experiments, explorers
 from pfrl import nn as pnn
-from pfrl import utils
-from pfrl import replay_buffers
+from pfrl import replay_buffers, utils
 
 
 class MultiBinaryAsDiscreteAction(gym.ActionWrapper):

diff --git a/pfrl/__init__.py b/pfrl/__init__.py
index 061bc47f5..779b1b8b0 100644
--- a/pfrl/__init__.py
+++ b/pfrl/__init__.py
@@ -9,7 +9,6 @@
 from pfrl import explorers  # NOQA
 from pfrl import functions  # NOQA
 from pfrl import nn  # NOQA
-from pfrl import utils  # NOQA
 from pfrl import optimizers  # NOQA
 from pfrl import policies  # NOQA
 from pfrl import policy  # NOQA
@@ -17,4 +16,5 @@
 from pfrl import q_functions  # NOQA
 from pfrl import replay_buffer  # NOQA
 from pfrl import replay_buffers  # NOQA
+from pfrl import utils  # NOQA
 from pfrl import wrappers  # NOQA

diff --git a/pfrl/action_value.py b/pfrl/action_value.py
index 9be0d65ba..f6f697042 100644
--- a/pfrl/action_value.py
+++ b/pfrl/action_value.py
@@ -1,7 +1,5 @@
-from abc import ABCMeta
-from abc import abstractmethod
-from abc import abstractproperty
 import warnings
+from abc import ABCMeta, abstractmethod, abstractproperty
 
 import torch
 import torch.nn.functional as F

diff --git a/pfrl/agent.py b/pfrl/agent.py
index 128e10d90..6b612f013 100644
--- a/pfrl/agent.py
+++ b/pfrl/agent.py
@@ -1,13 +1,7 @@
-from abc import ABCMeta
-from abc import abstractmethod
-from abc import abstractproperty
 import contextlib
 import os
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
+from abc import ABCMeta, abstractmethod, abstractproperty
+from typing import Any, List, Optional, Sequence, Tuple
 
 import torch

diff --git a/pfrl/agents/__init__.py b/pfrl/agents/__init__.py
index b8cd64644..1f225d9ec 100644
--- a/pfrl/agents/__init__.py
+++ b/pfrl/agents/__init__.py
@@ -14,6 +14,6 @@
 from pfrl.agents.ppo import PPO  # NOQA
 from pfrl.agents.reinforce import REINFORCE  # NOQA
 from pfrl.agents.soft_actor_critic import SoftActorCritic  # NOQA
+from pfrl.agents.state_q_function_actor import StateQFunctionActor  # NOQA
 from pfrl.agents.td3 import TD3  # NOQA
 from pfrl.agents.trpo import TRPO  # NOQA
-from pfrl.agents.state_q_function_actor import StateQFunctionActor  # NOQA

diff --git a/pfrl/agents/a2c.py b/pfrl/agents/a2c.py
index 0091fea69..615189a03 100644
--- a/pfrl/agents/a2c.py
+++ b/pfrl/agents/a2c.py
@@ -1,12 +1,12 @@
-from logging import getLogger
 import warnings
+from logging import getLogger
 
 import torch
 
 from pfrl import agent
+from pfrl.utils import clip_l2_grad_norm_
 from pfrl.utils.batch_states import batch_states
 from pfrl.utils.mode_of_distribution import mode_of_distribution
-from pfrl.utils import clip_l2_grad_norm_
 
 logger = getLogger(__name__)

diff --git a/pfrl/agents/a3c.py b/pfrl/agents/a3c.py
index 135788c9f..13b0a0e72 100644
--- a/pfrl/agents/a3c.py
+++ b/pfrl/agents/a3c.py
@@ -6,12 +6,10 @@
 
 import pfrl
 from pfrl import agent
+from pfrl.utils import clip_l2_grad_norm_, copy_param
 from pfrl.utils.batch_states import batch_states
-from pfrl.utils import copy_param
 from pfrl.utils.mode_of_distribution import mode_of_distribution
-from pfrl.utils.recurrent import one_step_forward
-from pfrl.utils.recurrent import pack_and_forward
-from pfrl.utils import clip_l2_grad_norm_
+from pfrl.utils.recurrent import one_step_forward, pack_and_forward
 
 logger = getLogger(__name__)

diff --git a/pfrl/agents/acer.py b/pfrl/agents/acer.py
index 4114682dd..b1c4e8f8f 100644
--- a/pfrl/agents/acer.py
+++ b/pfrl/agents/acer.py
@@ -5,14 +5,12 @@
 import torch
 from torch import nn
 
-from pfrl.action_value import SingleActionValue
 from pfrl import agent
+from pfrl.action_value import SingleActionValue
+from pfrl.utils import clip_l2_grad_norm_, copy_param
 from pfrl.utils.batch_states import batch_states
-from pfrl.utils import copy_param
 from pfrl.utils.mode_of_distribution import mode_of_distribution
-from pfrl.utils.recurrent import one_step_forward
-from pfrl.utils.recurrent import detach_recurrent_state
-from pfrl.utils import clip_l2_grad_norm_
+from pfrl.utils.recurrent import detach_recurrent_state, one_step_forward
 
 
 def compute_importance(pi, mu, x):

diff --git a/pfrl/agents/al.py b/pfrl/agents/al.py
index 54227b9d9..1ff88deac 100644
--- a/pfrl/agents/al.py
+++ b/pfrl/agents/al.py
@@ -1,4 +1,5 @@
 import torch
+
 from pfrl.agents import dqn
 from pfrl.utils.recurrent import pack_and_forward

diff --git a/pfrl/agents/ddpg.py b/pfrl/agents/ddpg.py
index 2ad73ef4c..9d2d15589 100644
--- a/pfrl/agents/ddpg.py
+++ b/pfrl/agents/ddpg.py
@@ -2,18 +2,16 @@
 import copy
 from logging import getLogger
 
+import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional as F
-import numpy as np
 
-from pfrl.agent import AttributeSavingMixin
-from pfrl.agent import BatchAgent
+from pfrl.agent import AttributeSavingMixin, BatchAgent
+from pfrl.replay_buffer import ReplayUpdater, batch_experiences
 from pfrl.utils.batch_states import batch_states
 from pfrl.utils.contexts import evaluating
 from pfrl.utils.copy_param import synchronize_parameters
-from pfrl.replay_buffer import batch_experiences
-from pfrl.replay_buffer import ReplayUpdater
 
 
 def _mean_or_nan(xs):

diff --git a/pfrl/agents/dpp.py b/pfrl/agents/dpp.py
index 11044eefb..df6ceb74a 100644
--- a/pfrl/agents/dpp.py
+++ b/pfrl/agents/dpp.py
@@ -1,5 +1,4 @@
-from abc import ABCMeta
-from abc import abstractmethod
+from abc import ABCMeta, abstractmethod
 
 import torch

diff --git a/pfrl/agents/dqn.py b/pfrl/agents/dqn.py
index 3fa766577..68b579734 100644
--- a/pfrl/agents/dqn.py
+++ b/pfrl/agents/dqn.py
@@ -1,39 +1,37 @@
-import copy
 import collections
-import time
+import copy
 import ctypes
 import multiprocessing as mp
 import multiprocessing.synchronize
-from typing import Any
-from typing import Callable
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from logging import Logger
-from logging import getLogger
+import time
+from logging import Logger, getLogger
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
 
+import numpy as np
 import torch
 import torch.nn.functional as F
-import numpy as np
 
 import pfrl
 from pfrl import agent
 from pfrl.action_value import ActionValue
 from pfrl.explorer import Explorer
+from pfrl.replay_buffer import (
+    AbstractEpisodicReplayBuffer,
+    ReplayUpdater,
+    batch_experiences,
+    batch_recurrent_experiences,
+)
+from pfrl.replay_buffers import PrioritizedReplayBuffer
 from pfrl.utils.batch_states import batch_states
 from pfrl.utils.contexts import evaluating
 from pfrl.utils.copy_param import synchronize_parameters
-from pfrl.replay_buffer import AbstractEpisodicReplayBuffer, batch_experiences
-from pfrl.replay_buffer import batch_recurrent_experiences
-from pfrl.replay_buffer import ReplayUpdater
-from pfrl.replay_buffers import PrioritizedReplayBuffer
-from pfrl.utils.recurrent import get_recurrent_state_at
-from pfrl.utils.recurrent import mask_recurrent_state_at
-from pfrl.utils.recurrent import one_step_forward
-from pfrl.utils.recurrent import pack_and_forward
-from pfrl.utils.recurrent import recurrent_state_as_numpy
+from pfrl.utils.recurrent import (
+    get_recurrent_state_at,
+    mask_recurrent_state_at,
+    one_step_forward,
+    pack_and_forward,
+    recurrent_state_as_numpy,
+)
 
 
 def _mean_or_nan(xs: Sequence[float]) -> float:

diff --git a/pfrl/agents/iqn.py b/pfrl/agents/iqn.py
index 91567b4af..d5e132c0d 100644
--- a/pfrl/agents/iqn.py
+++ b/pfrl/agents/iqn.py
@@ -5,8 +5,7 @@
 from pfrl.action_value import QuantileDiscreteActionValue
 from pfrl.agents import dqn
 from pfrl.nn import Recurrent
-from pfrl.utils.recurrent import one_step_forward
-from pfrl.utils.recurrent import pack_and_forward
+from pfrl.utils.recurrent import one_step_forward, pack_and_forward
 
 
 def cosine_basis_functions(x, n_basis_functions=64):

diff --git a/pfrl/agents/ppo.py b/pfrl/agents/ppo.py
index 0468edecf..44b9bf2fc 100644
--- a/pfrl/agents/ppo.py
+++ b/pfrl/agents/ppo.py
@@ -10,12 +10,14 @@
 from pfrl import agent
 from pfrl.utils.batch_states import batch_states
 from pfrl.utils.mode_of_distribution import mode_of_distribution
-from pfrl.utils.recurrent import get_recurrent_state_at
-from pfrl.utils.recurrent import mask_recurrent_state_at
-from pfrl.utils.recurrent import concatenate_recurrent_states
-from pfrl.utils.recurrent import one_step_forward
-from pfrl.utils.recurrent import flatten_sequences_time_first
-from pfrl.utils.recurrent import pack_and_forward
+from pfrl.utils.recurrent import (
+    concatenate_recurrent_states,
+    flatten_sequences_time_first,
+    get_recurrent_state_at,
+    mask_recurrent_state_at,
+    one_step_forward,
+    pack_and_forward,
+)
 
 
 def _mean_or_nan(xs):
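Note: the parenthesized import blocks introduced above (in pfrl/agents/dqn.py and pfrl/agents/ppo.py, for example) are what multi_line_output = 3 ("vertical hanging indent") produces once a merged import exceeds line_length = 88:

    from pfrl.utils.recurrent import (
        one_step_forward,
        pack_and_forward,
    )

include_trailing_comma = True keeps the comma after the last name, and use_parentheses = True prefers parentheses over backslash continuations; this is the wrapping style Black itself emits, so Black and isort stop rewriting each other's output on these lines.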
diff --git a/pfrl/agents/reinforce.py b/pfrl/agents/reinforce.py
index cf7b7672c..155ddcf15 100644
--- a/pfrl/agents/reinforce.py
+++ b/pfrl/agents/reinforce.py
@@ -1,14 +1,14 @@
-from logging import getLogger
 import warnings
+from logging import getLogger
 
 import numpy as np
 import torch
 
 import pfrl
 from pfrl import agent
+from pfrl.utils import clip_l2_grad_norm_
 from pfrl.utils.mode_of_distribution import mode_of_distribution
 from pfrl.utils.recurrent import one_step_forward
-from pfrl.utils import clip_l2_grad_norm_
 
 
 class REINFORCE(agent.AttributeSavingMixin, agent.Agent):

diff --git a/pfrl/agents/soft_actor_critic.py b/pfrl/agents/soft_actor_critic.py
index 2e48ffe01..a3048c93d 100644
--- a/pfrl/agents/soft_actor_critic.py
+++ b/pfrl/agents/soft_actor_critic.py
@@ -8,14 +8,12 @@
 from torch.nn import functional as F
 
 import pfrl
-from pfrl.agent import AttributeSavingMixin
-from pfrl.agent import BatchAgent
+from pfrl.agent import AttributeSavingMixin, BatchAgent
+from pfrl.replay_buffer import ReplayUpdater, batch_experiences
+from pfrl.utils import clip_l2_grad_norm_
 from pfrl.utils.batch_states import batch_states
 from pfrl.utils.copy_param import synchronize_parameters
 from pfrl.utils.mode_of_distribution import mode_of_distribution
-from pfrl.replay_buffer import batch_experiences
-from pfrl.replay_buffer import ReplayUpdater
-from pfrl.utils import clip_l2_grad_norm_
 
 
 def _mean_or_nan(xs):

diff --git a/pfrl/agents/state_q_function_actor.py b/pfrl/agents/state_q_function_actor.py
index 7a4c32531..97a708876 100644
--- a/pfrl/agents/state_q_function_actor.py
+++ b/pfrl/agents/state_q_function_actor.py
@@ -3,11 +3,13 @@
 import torch
 
 from pfrl import agent
-from pfrl.utils.batch_states import batch_states
 from pfrl.utils import evaluating
-from pfrl.utils.recurrent import one_step_forward
-from pfrl.utils.recurrent import get_recurrent_state_at
-from pfrl.utils.recurrent import recurrent_state_as_numpy
+from pfrl.utils.batch_states import batch_states
+from pfrl.utils.recurrent import (
+    get_recurrent_state_at,
+    one_step_forward,
+    recurrent_state_as_numpy,
+)
 
 
 class StateQFunctionActor(agent.AsyncAgent):

diff --git a/pfrl/agents/td3.py b/pfrl/agents/td3.py
index 1e851fb05..2596494e6 100644
--- a/pfrl/agents/td3.py
+++ b/pfrl/agents/td3.py
@@ -7,13 +7,11 @@
 from torch.nn import functional as F
 
 import pfrl
-from pfrl.agent import AttributeSavingMixin
-from pfrl.agent import BatchAgent
+from pfrl.agent import AttributeSavingMixin, BatchAgent
+from pfrl.replay_buffer import ReplayUpdater, batch_experiences
+from pfrl.utils import clip_l2_grad_norm_
 from pfrl.utils.batch_states import batch_states
 from pfrl.utils.copy_param import synchronize_parameters
-from pfrl.replay_buffer import batch_experiences
-from pfrl.replay_buffer import ReplayUpdater
-from pfrl.utils import clip_l2_grad_norm_
 
 
 def _mean_or_nan(xs):

diff --git a/pfrl/agents/trpo.py b/pfrl/agents/trpo.py
index 67ab538af..fca632366 100644
--- a/pfrl/agents/trpo.py
+++ b/pfrl/agents/trpo.py
@@ -1,6 +1,6 @@
 import collections
-from logging import getLogger
 import random
+from logging import getLogger
 
 import numpy as np
 import torch
@@ -9,22 +9,24 @@
 
 import pfrl
 from pfrl import agent
-from pfrl.agents.ppo import _compute_explained_variance
-from pfrl.agents.ppo import _make_dataset
-from pfrl.agents.ppo import _make_dataset_recurrent
-from pfrl.agents.ppo import _yield_minibatches
-from pfrl.agents.ppo import (
+from pfrl.agents.ppo import (  # NOQA
+    _compute_explained_variance,
+    _make_dataset,
+    _make_dataset_recurrent,
+    _yield_minibatches,
     _yield_subset_of_sequences_with_fixed_number_of_items,
-)  # NOQA
-from pfrl.utils.mode_of_distribution import mode_of_distribution
-from pfrl.utils.batch_states import batch_states
-from pfrl.utils.recurrent import flatten_sequences_time_first
-from pfrl.utils.recurrent import one_step_forward
-from pfrl.utils.recurrent import pack_and_forward
-from pfrl.utils.recurrent import get_recurrent_state_at
-from pfrl.utils.recurrent import mask_recurrent_state_at
-from pfrl.utils.recurrent import concatenate_recurrent_states
+)
 from pfrl.utils import clip_l2_grad_norm_
+from pfrl.utils.batch_states import batch_states
+from pfrl.utils.mode_of_distribution import mode_of_distribution
+from pfrl.utils.recurrent import (
+    concatenate_recurrent_states,
+    flatten_sequences_time_first,
+    get_recurrent_state_at,
+    mask_recurrent_state_at,
+    one_step_forward,
+    pack_and_forward,
+)
 
 
 def _flatten_and_concat_variables(vs):

diff --git a/pfrl/collections/persistent_collections.py b/pfrl/collections/persistent_collections.py
index b73404c3d..885097793 100644
--- a/pfrl/collections/persistent_collections.py
+++ b/pfrl/collections/persistent_collections.py
@@ -1,13 +1,12 @@
 import binascii
 import collections
-from datetime import datetime
-from struct import pack, unpack, calcsize
 import os
 import pickle
+from datetime import datetime
+from struct import calcsize, pack, unpack
 
 from pfrl.collections.random_access_queue import RandomAccessQueue
 
-
 # code for future extension. `_VanillaFS` is a dummy of chainerio's
 # FIleSystem class.
 _VanillaFS = collections.namedtuple("_VanillaFS", "exists open makedirs")

diff --git a/pfrl/distributions/delta.py b/pfrl/distributions/delta.py
index afe851230..d0324dfc3 100644
--- a/pfrl/distributions/delta.py
+++ b/pfrl/distributions/delta.py
@@ -1,7 +1,7 @@
 from numbers import Number
 
 import torch
-from torch.distributions import constraints, Distribution
+from torch.distributions import Distribution, constraints
 
 
 class Delta(Distribution):

diff --git a/pfrl/env.py b/pfrl/env.py
index b751bdb43..ad28310a4 100644
--- a/pfrl/env.py
+++ b/pfrl/env.py
@@ -1,5 +1,4 @@
-from abc import ABCMeta
-from abc import abstractmethod
+from abc import ABCMeta, abstractmethod
 
 
 class Env(object, metaclass=ABCMeta):

diff --git a/pfrl/envs/abc.py b/pfrl/envs/abc.py
index c19999799..7018e6816 100644
--- a/pfrl/envs/abc.py
+++ b/pfrl/envs/abc.py
@@ -1,5 +1,5 @@
-from gym import spaces
 import numpy as np
+from gym import spaces
 
 from pfrl import env

diff --git a/pfrl/envs/multiprocess_vector_env.py b/pfrl/envs/multiprocess_vector_env.py
index 47a94cffd..a993e1940 100644
--- a/pfrl/envs/multiprocess_vector_env.py
+++ b/pfrl/envs/multiprocess_vector_env.py
@@ -1,7 +1,6 @@
-from multiprocessing import Pipe
-from multiprocessing import Process
 import signal
 import warnings
+from multiprocessing import Pipe, Process
 
 import numpy as np
 from torch.distributions.utils import lazy_property

diff --git a/pfrl/experiments/__init__.py b/pfrl/experiments/__init__.py
index d62bc918f..cc7a2806a 100644
--- a/pfrl/experiments/__init__.py
+++ b/pfrl/experiments/__init__.py
@@ -1,12 +1,9 @@
 from pfrl.experiments.evaluator import eval_performance  # NOQA
-
 from pfrl.experiments.hooks import LinearInterpolationHook  # NOQA
 from pfrl.experiments.hooks import StepHook  # NOQA
-
-from pfrl.experiments.prepare_output_dir import is_under_git_control  # NOQA
 from pfrl.experiments.prepare_output_dir import generate_exp_id  # NOQA
+from pfrl.experiments.prepare_output_dir import is_under_git_control  # NOQA
 from pfrl.experiments.prepare_output_dir import prepare_output_dir  # NOQA
-
 from pfrl.experiments.train_agent import train_agent  # NOQA
 from pfrl.experiments.train_agent import train_agent_with_evaluation  # NOQA
 from pfrl.experiments.train_agent_async import train_agent_async  # NOQA

diff --git a/pfrl/experiments/hooks.py b/pfrl/experiments/hooks.py
index 311b494ef..47e8a162d 100644
--- a/pfrl/experiments/hooks.py
+++ b/pfrl/experiments/hooks.py
@@ -1,5 +1,4 @@
-from abc import ABCMeta
-from abc import abstractmethod
+from abc import ABCMeta, abstractmethod
 
 import numpy as np

diff --git a/pfrl/experiments/prepare_output_dir.py b/pfrl/experiments/prepare_output_dir.py
index 9574aa706..377bf5102 100644
--- a/pfrl/experiments/prepare_output_dir.py
+++ b/pfrl/experiments/prepare_output_dir.py
@@ -1,5 +1,4 @@
 import argparse
-from binascii import crc32
 import datetime
 import json
 import os
@@ -7,6 +6,7 @@
 import shutil
 import subprocess
 import sys
+from binascii import crc32
 
 import pfrl

diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py
index eff0310fe..9e716fc0c 100644
--- a/pfrl/experiments/train_agent.py
+++ b/pfrl/experiments/train_agent.py
@@ -1,8 +1,7 @@
 import logging
 import os
 
-from pfrl.experiments.evaluator import Evaluator
-from pfrl.experiments.evaluator import save_agent
+from pfrl.experiments.evaluator import Evaluator, save_agent
 from pfrl.utils.ask_yes_no import ask_yes_no

diff --git a/pfrl/experiments/train_agent_async.py b/pfrl/experiments/train_agent_async.py
index 4c3accdfc..2c26824e3 100644
--- a/pfrl/experiments/train_agent_async.py
+++ b/pfrl/experiments/train_agent_async.py
@@ -1,17 +1,16 @@
 import logging
-import torch.multiprocessing as mp
 import os
-import torch
-from torch import nn
+import signal
+import subprocess
+import sys
 
 import numpy as np
+import torch
+import torch.multiprocessing as mp
+from torch import nn
 
 from pfrl.experiments.evaluator import AsyncEvaluator
-from pfrl.utils import async_
-from pfrl.utils import random_seed
-import signal
-import subprocess
-import sys
+from pfrl.utils import async_, random_seed
 
 
 def kill_all():

diff --git a/pfrl/experiments/train_agent_batch.py b/pfrl/experiments/train_agent_batch.py
index 045b839a6..8ad3871fa 100644
--- a/pfrl/experiments/train_agent_batch.py
+++ b/pfrl/experiments/train_agent_batch.py
@@ -1,12 +1,10 @@
-from collections import deque
 import logging
 import os
+from collections import deque
 
 import numpy as np
-
-from pfrl.experiments.evaluator import Evaluator
-from pfrl.experiments.evaluator import save_agent
+from pfrl.experiments.evaluator import Evaluator, save_agent
 
 
 def train_agent_batch(

diff --git a/pfrl/explorer.py b/pfrl/explorer.py
index 2997f465b..b262b2209 100644
--- a/pfrl/explorer.py
+++ b/pfrl/explorer.py
@@ -1,5 +1,4 @@
-from abc import ABCMeta
-from abc import abstractmethod
+from abc import ABCMeta, abstractmethod
 
 
 class Explorer(object, metaclass=ABCMeta):

diff --git a/pfrl/explorers/__init__.py b/pfrl/explorers/__init__.py
index 0c571d2ca..935fba103 100644
--- a/pfrl/explorers/__init__.py
+++ b/pfrl/explorers/__init__.py
@@ -2,6 +2,6 @@
 from pfrl.explorers.additive_ou import AdditiveOU  # NOQA
 from pfrl.explorers.boltzmann import Boltzmann  # NOQA
 from pfrl.explorers.epsilon_greedy import ConstantEpsilonGreedy  # NOQA
-from pfrl.explorers.epsilon_greedy import LinearDecayEpsilonGreedy  # NOQA
 from pfrl.explorers.epsilon_greedy import ExponentialDecayEpsilonGreedy  # NOQA
+from pfrl.explorers.epsilon_greedy import LinearDecayEpsilonGreedy  # NOQA
 from pfrl.explorers.greedy import Greedy  # NOQA

diff --git a/pfrl/explorers/boltzmann.py b/pfrl/explorers/boltzmann.py
index 4a87e001e..25d69ba97 100644
--- a/pfrl/explorers/boltzmann.py
+++ b/pfrl/explorers/boltzmann.py
@@ -1,6 +1,6 @@
+import numpy as np
 import torch
 import torch.nn.functional as F
-import numpy as np
 
 import pfrl

diff --git a/pfrl/functions/lower_triangular_matrix.py b/pfrl/functions/lower_triangular_matrix.py
index d00f78292..76b898d71 100644
--- a/pfrl/functions/lower_triangular_matrix.py
+++ b/pfrl/functions/lower_triangular_matrix.py
@@ -1,5 +1,5 @@
-import torch
 import numpy as np
+import torch
 
 
 def set_batch_non_diagonal(array, non_diag_val):

diff --git a/pfrl/initializers/__init__.py b/pfrl/initializers/__init__.py
index bbaa3f22d..e6bfd5d5d 100644
--- a/pfrl/initializers/__init__.py
+++ b/pfrl/initializers/__init__.py
@@ -1,4 +1,3 @@
 # Add lecun_normal weight initialization for networks in pytorch
-from pfrl.initializers.lecun_normal import init_lecun_normal  # NOQA
-
 from pfrl.initializers.chainer_default import init_chainer_default  # NOQA
+from pfrl.initializers.lecun_normal import init_lecun_normal  # NOQA

diff --git a/pfrl/initializers/chainer_default.py b/pfrl/initializers/chainer_default.py
index b5a638a18..1532a0db9 100644
--- a/pfrl/initializers/chainer_default.py
+++ b/pfrl/initializers/chainer_default.py
@@ -2,6 +2,7 @@
 """
 import torch
 import torch.nn as nn
+
 from pfrl.initializers import init_lecun_normal

diff --git a/pfrl/nn/__init__.py b/pfrl/nn/__init__.py
index 80eda9333..dbab51d76 100644
--- a/pfrl/nn/__init__.py
+++ b/pfrl/nn/__init__.py
@@ -1,7 +1,10 @@
-from pfrl.nn.branched import Branched  # NOQA
 from pfrl.nn.atari_cnn import LargeAtariCNN  # NOQA
 from pfrl.nn.atari_cnn import SmallAtariCNN  # NOQA
+from pfrl.nn.bound_by_tanh import BoundByTanh  # NOQA
+from pfrl.nn.branched import Branched  # NOQA
+from pfrl.nn.concat_obs_and_action import ConcatObsAndAction  # NOQA
 from pfrl.nn.empirical_normalization import EmpiricalNormalization  # NOQA
+from pfrl.nn.lmbda import Lambda  # NOQA
 from pfrl.nn.mlp import MLP  # NOQA
 from pfrl.nn.mlp_bn import MLPBN  # NOQA
 from pfrl.nn.noisy_chain import to_factorized_noisy  # NOQA
@@ -9,6 +12,3 @@
 from pfrl.nn.recurrent import Recurrent  # NOQA
 from pfrl.nn.recurrent_branched import RecurrentBranched  # NOQA
 from pfrl.nn.recurrent_sequential import RecurrentSequential  # NOQA
-from pfrl.nn.lmbda import Lambda  # NOQA
-from pfrl.nn.concat_obs_and_action import ConcatObsAndAction  # NOQA
-from pfrl.nn.bound_by_tanh import BoundByTanh  # NOQA

diff --git a/pfrl/nn/atari_cnn.py b/pfrl/nn/atari_cnn.py
index b16326564..95dcdc33d 100644
--- a/pfrl/nn/atari_cnn.py
+++ b/pfrl/nn/atari_cnn.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
 from pfrl.initializers import init_chainer_default

diff --git a/pfrl/nn/empirical_normalization.py b/pfrl/nn/empirical_normalization.py
index 7b0e638fd..a49560ec2 100644
--- a/pfrl/nn/empirical_normalization.py
+++ b/pfrl/nn/empirical_normalization.py
@@ -1,5 +1,4 @@
 import numpy as np
-
 import torch
 from torch import nn

diff --git a/pfrl/nn/mlp.py b/pfrl/nn/mlp.py
index efd4b2b39..55fdb89ca 100644
--- a/pfrl/nn/mlp.py
+++ b/pfrl/nn/mlp.py
@@ -1,7 +1,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
-from pfrl.initializers import init_chainer_default
-from pfrl.initializers import init_lecun_normal
+
+from pfrl.initializers import init_chainer_default, init_lecun_normal
 
 
 class MLP(nn.Module):

diff --git a/pfrl/nn/mlp_bn.py b/pfrl/nn/mlp_bn.py
index 434245623..f99016033 100644
--- a/pfrl/nn/mlp_bn.py
+++ b/pfrl/nn/mlp_bn.py
@@ -1,5 +1,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
+
 from pfrl.initializers import init_lecun_normal

diff --git a/pfrl/nn/noisy_linear.py b/pfrl/nn/noisy_linear.py
index 4cfadc070..7fa9aa0d4 100644
--- a/pfrl/nn/noisy_linear.py
+++ b/pfrl/nn/noisy_linear.py
@@ -1,5 +1,4 @@
 import numpy as np
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

diff --git a/pfrl/nn/recurrent_sequential.py b/pfrl/nn/recurrent_sequential.py
index a4ec07f19..8459cc6ee 100644
--- a/pfrl/nn/recurrent_sequential.py
+++ b/pfrl/nn/recurrent_sequential.py
@@ -1,10 +1,12 @@
 from torch import nn
 
 from pfrl.nn.recurrent import Recurrent
-from pfrl.utils.recurrent import is_recurrent
-from pfrl.utils.recurrent import get_packed_sequence_info
-from pfrl.utils.recurrent import wrap_packed_sequences_recursive
-from pfrl.utils.recurrent import unwrap_packed_sequences_recursive
+from pfrl.utils.recurrent import (
+    get_packed_sequence_info,
+    is_recurrent,
+    unwrap_packed_sequences_recursive,
+    wrap_packed_sequences_recursive,
+)
 
 
 class RecurrentSequential(Recurrent, nn.Sequential):

diff --git a/pfrl/policy.py b/pfrl/policy.py
index fa31af67f..e80315b25 100644
--- a/pfrl/policy.py
+++ b/pfrl/policy.py
@@ -1,7 +1,4 @@
-from abc import ABCMeta
-from abc import abstractmethod
-
-
+from abc import ABCMeta, abstractmethod
 from logging import getLogger
 
 logger = getLogger(__name__)

diff --git a/pfrl/q_function.py b/pfrl/q_function.py
index f418dbf53..ac9a5e4bb 100644
--- a/pfrl/q_function.py
+++ b/pfrl/q_function.py
@@ -1,5 +1,4 @@
-from abc import ABCMeta
-from abc import abstractmethod
+from abc import ABCMeta, abstractmethod
 
 
 class StateQFunction(object, metaclass=ABCMeta):

diff --git a/pfrl/q_functions/dueling_dqn.py b/pfrl/q_functions/dueling_dqn.py
index 483b0dc91..42a492503 100644
--- a/pfrl/q_functions/dueling_dqn.py
+++ b/pfrl/q_functions/dueling_dqn.py
@@ -3,9 +3,9 @@
 import torch.nn.functional as F
 
 from pfrl import action_value
+from pfrl.initializers import init_chainer_default
 from pfrl.nn.mlp import MLP
 from pfrl.q_function import StateQFunction
-from pfrl.initializers import init_chainer_default
 
 
 def constant_bias_initializer(bias=0.0):

diff --git a/pfrl/q_functions/state_action_q_functions.py b/pfrl/q_functions/state_action_q_functions.py
index 8943603c3..ae88e16af 100644
--- a/pfrl/q_functions/state_action_q_functions.py
+++ b/pfrl/q_functions/state_action_q_functions.py
@@ -1,11 +1,11 @@
-from pfrl.nn.mlp import MLP
-from pfrl.nn.mlp_bn import MLPBN
-from pfrl.q_function import StateActionQFunction
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
 from pfrl.initializers import init_lecun_normal
+from pfrl.nn.mlp import MLP
+from pfrl.nn.mlp_bn import MLPBN
+from pfrl.q_function import StateActionQFunction
 
 
 class SingleModelStateActionQFunction(nn.Module, StateActionQFunction):

diff --git a/pfrl/q_functions/state_q_functions.py b/pfrl/q_functions/state_q_functions.py
index 5568f4595..1eedd573d 100644
--- a/pfrl/q_functions/state_q_functions.py
+++ b/pfrl/q_functions/state_q_functions.py
@@ -1,18 +1,19 @@
 import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
 
-from pfrl.action_value import DiscreteActionValue
-from pfrl.action_value import DistributionalDiscreteActionValue
-from pfrl.action_value import QuadraticActionValue
+from pfrl.action_value import (
+    DiscreteActionValue,
+    DistributionalDiscreteActionValue,
+    QuadraticActionValue,
+)
 from pfrl.functions.lower_triangular_matrix import lower_triangular_matrix
+from pfrl.initializers import init_chainer_default
 from pfrl.nn import Lambda
 from pfrl.nn.mlp import MLP
 from pfrl.q_function import StateQFunction
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from pfrl.initializers import init_chainer_default
-
 
 def scale_by_tanh(x, low, high):
     scale = (high - low) / 2

diff --git a/pfrl/replay_buffer.py b/pfrl/replay_buffer.py
index ec19e4d1c..7da0fd3f9 100644
--- a/pfrl/replay_buffer.py
+++ b/pfrl/replay_buffer.py
@@ -1,15 +1,15 @@
-from abc import ABCMeta
-from abc import abstractmethod
-from abc import abstractproperty
+from abc import ABCMeta, abstractmethod, abstractproperty
 from typing import Optional
 
 import numpy as np
 import torch
 
 from pfrl.utils.batch_states import batch_states
-from pfrl.utils.recurrent import concatenate_recurrent_states
-from pfrl.utils.recurrent import flatten_sequences_time_first
-from pfrl.utils.recurrent import recurrent_state_from_numpy
+from pfrl.utils.recurrent import (
+    concatenate_recurrent_states,
+    flatten_sequences_time_first,
+    recurrent_state_from_numpy,
+)
 
 
 class AbstractReplayBuffer(object, metaclass=ABCMeta):

diff --git a/pfrl/replay_buffers/episodic.py b/pfrl/replay_buffers/episodic.py
index a4b6a471b..31e88b0e4 100644
--- a/pfrl/replay_buffers/episodic.py
+++ b/pfrl/replay_buffers/episodic.py
@@ -3,8 +3,7 @@
 from typing import Optional
 
 from pfrl.collections.random_access_queue import RandomAccessQueue
-from pfrl.replay_buffer import AbstractEpisodicReplayBuffer
-from pfrl.replay_buffer import random_subseq
+from pfrl.replay_buffer import AbstractEpisodicReplayBuffer, random_subseq
 
 
 class EpisodicReplayBuffer(AbstractEpisodicReplayBuffer):

diff --git a/pfrl/replay_buffers/persistent.py b/pfrl/replay_buffers/persistent.py
index df6f6b239..19342df2e 100644
--- a/pfrl/replay_buffers/persistent.py
+++ b/pfrl/replay_buffers/persistent.py
@@ -1,11 +1,11 @@
 import os
 import warnings
 
-from .replay_buffer import ReplayBuffer
-from .episodic import EpisodicReplayBuffer
-
 from pfrl.collections.persistent_collections import PersistentRandomAccessQueue
 
+from .episodic import EpisodicReplayBuffer
+from .replay_buffer import ReplayBuffer
+
 
 class PersistentReplayBuffer(ReplayBuffer):
     """Experience replay buffer that are saved to disk storage

diff --git a/pfrl/replay_buffers/prioritized_episodic.py b/pfrl/replay_buffers/prioritized_episodic.py
index 53361c188..e31a74863 100644
--- a/pfrl/replay_buffers/prioritized_episodic.py
+++ b/pfrl/replay_buffers/prioritized_episodic.py
@@ -1,10 +1,9 @@
 import collections
 
-from pfrl.collections.random_access_queue import RandomAccessQueue
 from pfrl.collections.prioritized import PrioritizedBuffer
+from pfrl.collections.random_access_queue import RandomAccessQueue
 from pfrl.replay_buffer import random_subseq
-from pfrl.replay_buffers import EpisodicReplayBuffer
-from pfrl.replay_buffers import PriorityWeightError
+from pfrl.replay_buffers import EpisodicReplayBuffer, PriorityWeightError
 
 
 class PrioritizedEpisodicReplayBuffer(EpisodicReplayBuffer, PriorityWeightError):

diff --git a/pfrl/replay_buffers/replay_buffer.py b/pfrl/replay_buffers/replay_buffer.py
index 907f83094..0db496dd0 100644
--- a/pfrl/replay_buffers/replay_buffer.py
+++ b/pfrl/replay_buffers/replay_buffer.py
@@ -2,8 +2,8 @@
 import pickle
 from typing import Optional
 
-from pfrl.collections.random_access_queue import RandomAccessQueue
 from pfrl import replay_buffer
+from pfrl.collections.random_access_queue import RandomAccessQueue
 
 
 class ReplayBuffer(replay_buffer.AbstractReplayBuffer):

diff --git a/pfrl/utils/__init__.py b/pfrl/utils/__init__.py
index d63288621..9ad17e560 100644
--- a/pfrl/utils/__init__.py
+++ b/pfrl/utils/__init__.py
@@ -1,9 +1,9 @@
+from pfrl.utils import env_modifiers  # NOQA
 from pfrl.utils.batch_states import batch_states  # NOQA
+from pfrl.utils.clip_l2_grad_norm import clip_l2_grad_norm_  # NOQA
 from pfrl.utils.conjugate_gradient import conjugate_gradient  # NOQA
-from pfrl.utils import env_modifiers  # NOQA
+from pfrl.utils.contexts import evaluating  # NOQA
 from pfrl.utils.is_return_code_zero import is_return_code_zero  # NOQA
-from pfrl.utils.random_seed import set_random_seed  # NOQA
 from pfrl.utils.pretrained_models import download_model  # NOQA
-from pfrl.utils.contexts import evaluating  # NOQA
+from pfrl.utils.random_seed import set_random_seed  # NOQA
 from pfrl.utils.stoppable_thread import StoppableThread  # NOQA
-from pfrl.utils.clip_l2_grad_norm import clip_l2_grad_norm_  # NOQA

diff --git a/pfrl/utils/async_.py b/pfrl/utils/async_.py
index 51788f573..a067d3b61 100644
--- a/pfrl/utils/async_.py
+++ b/pfrl/utils/async_.py
@@ -1,6 +1,7 @@
-import torch.multiprocessing as mp
 import warnings
 
+import torch.multiprocessing as mp
+
 
 class AbnormalExitWarning(Warning):
     """Warning category for abnormal subprocess exit."""

diff --git a/pfrl/utils/batch_states.py b/pfrl/utils/batch_states.py
index fd7fdb55d..63a16371b 100644
--- a/pfrl/utils/batch_states.py
+++ b/pfrl/utils/batch_states.py
@@ -1,6 +1,4 @@
-from typing import Any
-from typing import Callable
-from typing import Sequence
+from typing import Any, Callable, Sequence
 
 import torch
 from torch.utils.data._utils.collate import default_collate

diff --git a/pfrl/utils/random_seed.py b/pfrl/utils/random_seed.py
index 4fe8d523f..726c22fe2 100644
--- a/pfrl/utils/random_seed.py
+++ b/pfrl/utils/random_seed.py
@@ -1,6 +1,7 @@
 import random
-import torch
+
 import numpy as np
+import torch
 
 
 def set_random_seed(seed):

diff --git a/pfrl/wrappers/__init__.py b/pfrl/wrappers/__init__.py
index 7906b0646..0f3e99258 100644
--- a/pfrl/wrappers/__init__.py
+++ b/pfrl/wrappers/__init__.py
@@ -1,16 +1,9 @@
 from pfrl.wrappers.cast_observation import CastObservation  # NOQA
 from pfrl.wrappers.cast_observation import CastObservationToFloat32  # NOQA
-
 from pfrl.wrappers.continuing_time_limit import ContinuingTimeLimit  # NOQA
-
 from pfrl.wrappers.monitor import Monitor  # NOQA
-
 from pfrl.wrappers.normalize_action_space import NormalizeActionSpace  # NOQA
-
 from pfrl.wrappers.randomize_action import RandomizeAction  # NOQA
-
 from pfrl.wrappers.render import Render  # NOQA
-
 from pfrl.wrappers.scale_reward import ScaleReward  # NOQA
-
 from pfrl.wrappers.vector_frame_stack import VectorFrameStack  # NOQA

diff --git a/pfrl/wrappers/atari_wrappers.py b/pfrl/wrappers/atari_wrappers.py
index 40101ca9c..11db1f23c 100644
--- a/pfrl/wrappers/atari_wrappers.py
+++ b/pfrl/wrappers/atari_wrappers.py
@@ -6,12 +6,10 @@
 
 import gym
 import numpy as np
-
 from gym import spaces
 
 import pfrl
 
-
 try:
     import cv2

diff --git a/pfrl/wrappers/monitor.py b/pfrl/wrappers/monitor.py
index 0839b0d85..31153d9a1 100644
--- a/pfrl/wrappers/monitor.py
+++ b/pfrl/wrappers/monitor.py
@@ -1,5 +1,5 @@
-from logging import getLogger
 import time
+from logging import getLogger
 
 from gym.wrappers import Monitor as _GymMonitor
 from gym.wrappers.monitoring.stats_recorder import StatsRecorder as _GymStatsRecorder

diff --git a/pfrl/wrappers/vector_frame_stack.py b/pfrl/wrappers/vector_frame_stack.py
index 1bb73f219..5596f5b87 100644
--- a/pfrl/wrappers/vector_frame_stack.py
+++ b/pfrl/wrappers/vector_frame_stack.py
@@ -1,7 +1,7 @@
 from collections import deque
 
-from gym import spaces
 import numpy as np
+from gym import spaces
 
 from pfrl.env import VectorEnv
 from pfrl.wrappers.atari_wrappers import LazyFrames

diff --git a/tests/agents_tests/basetest_ddpg.py b/tests/agents_tests/basetest_ddpg.py
index a6298538a..40b40480b 100644
--- a/tests/agents_tests/basetest_ddpg.py
+++ b/tests/agents_tests/basetest_ddpg.py
@@ -1,17 +1,14 @@
 import numpy as np
 import torch
+from basetest_training import _TestTraining
 from torch import nn
 
+from pfrl import replay_buffers
 from pfrl.envs.abc import ABC
 from pfrl.explorers.epsilon_greedy import LinearDecayEpsilonGreedy
-from pfrl import replay_buffers
-from pfrl.nn import RecurrentSequential
-from pfrl.nn import ConcatObsAndAction
-from pfrl.nn import BoundByTanh
+from pfrl.nn import BoundByTanh, ConcatObsAndAction, RecurrentSequential
 from pfrl.policies import DeterministicHead
 
-from basetest_training import _TestTraining
-
 
 class _TestDDPGOnABC(_TestTraining):
     def make_agent(self, env, gpu):

diff --git a/tests/agents_tests/basetest_dqn_like.py b/tests/agents_tests/basetest_dqn_like.py
index 21645fa26..f959ea50a 100644
--- a/tests/agents_tests/basetest_dqn_like.py
+++ b/tests/agents_tests/basetest_dqn_like.py
@@ -1,14 +1,12 @@
 import numpy as np
 import torch.nn as nn
 import torch.optim as optim
+from basetest_training import _TestTraining
 
+from pfrl import q_functions, replay_buffers
 from pfrl.envs.abc import ABC
 from pfrl.explorers.epsilon_greedy import LinearDecayEpsilonGreedy
 from pfrl.nn import RecurrentSequential
-from pfrl import q_functions
-from pfrl import replay_buffers
-
-from basetest_training import _TestTraining
 
 
 class _TestDQNLike(_TestTraining):

diff --git a/tests/agents_tests/basetest_training.py b/tests/agents_tests/basetest_training.py
index 067e0ac33..a7d97e501 100644
--- a/tests/agents_tests/basetest_training.py
+++ b/tests/agents_tests/basetest_training.py
@@ -1,16 +1,20 @@
 import logging
 import os
 import tempfile
-import pytest
 
 import numpy as np
+import pytest
 
 import pfrl
-from pfrl.experiments.evaluator import batch_run_evaluation_episodes
-from pfrl.experiments.evaluator import run_evaluation_episodes
-from pfrl.experiments import train_agent_async
-from pfrl.experiments import train_agent_batch_with_evaluation
-from pfrl.experiments import train_agent_with_evaluation
+from pfrl.experiments import (
+    train_agent_async,
+    train_agent_batch_with_evaluation,
+    train_agent_with_evaluation,
+)
+from pfrl.experiments.evaluator import (
+    batch_run_evaluation_episodes,
+    run_evaluation_episodes,
+)
 from pfrl.utils import random_seed

diff --git a/tests/agents_tests/test_a2c.py b/tests/agents_tests/test_a2c.py
index 62206f391..461142a8d 100644
--- a/tests/agents_tests/test_a2c.py
+++ b/tests/agents_tests/test_a2c.py
@@ -11,8 +11,10 @@
 from pfrl.envs.abc import ABC
 from pfrl.experiments.evaluator import batch_run_evaluation_episodes
 from pfrl.nn import Branched
-from pfrl.policies import SoftmaxCategoricalHead
-from pfrl.policies import GaussianHeadWithStateIndependentCovariance
+from pfrl.policies import (
+    GaussianHeadWithStateIndependentCovariance,
+    SoftmaxCategoricalHead,
+)
 
 
 @pytest.mark.parametrize("num_processes", [1, 3])

diff --git a/tests/agents_tests/test_a3c.py b/tests/agents_tests/test_a3c.py
index f4b934428..e02cc56e8 100644
--- a/tests/agents_tests/test_a3c.py
+++ b/tests/agents_tests/test_a3c.py
@@ -11,12 +11,13 @@
 import pfrl
 from pfrl.agents import a3c
 from pfrl.envs.abc import ABC
-from pfrl.experiments.train_agent_async import train_agent_async
 from pfrl.experiments.evaluator import run_evaluation_episodes
-from pfrl.nn import RecurrentBranched
-from pfrl.nn import RecurrentSequential
-from pfrl.policies import SoftmaxCategoricalHead
-from pfrl.policies import GaussianHeadWithStateIndependentCovariance
+from pfrl.experiments.train_agent_async import train_agent_async
+from pfrl.nn import RecurrentBranched, RecurrentSequential
+from pfrl.policies import (
+    GaussianHeadWithStateIndependentCovariance,
+    SoftmaxCategoricalHead,
+)
 
 
 class _TestA3C:

diff --git a/tests/agents_tests/test_acer.py b/tests/agents_tests/test_acer.py
index fc8bf5ba4..1b2223b41 100644
--- a/tests/agents_tests/test_acer.py
+++ b/tests/agents_tests/test_acer.py
@@ -12,13 +12,12 @@
 import pfrl
 from pfrl.agents import acer
 from pfrl.envs.abc import ABC
-from pfrl.experiments.train_agent_async import train_agent_async
 from pfrl.experiments.evaluator import run_evaluation_episodes
-from pfrl.replay_buffers import EpisodicReplayBuffer
-from pfrl.policies import GaussianHeadWithDiagonalCovariance
-from pfrl.policies import SoftmaxCategoricalHead
-from pfrl.q_functions import DiscreteActionValueHead
+from pfrl.experiments.train_agent_async import train_agent_async
 from pfrl.nn import ConcatObsAndAction
+from pfrl.policies import GaussianHeadWithDiagonalCovariance, SoftmaxCategoricalHead
+from pfrl.q_functions import DiscreteActionValueHead
+from pfrl.replay_buffers import EpisodicReplayBuffer
 
 
 def extract_gradients_as_single_vector(mod):

diff --git a/tests/agents_tests/test_al.py b/tests/agents_tests/test_al.py
index 89d75b9d7..2517cb9aa 100644
--- a/tests/agents_tests/test_al.py
+++ b/tests/agents_tests/test_al.py
@@ -1,5 +1,6 @@
 import basetest_dqn_like as base
 from basetest_training import _TestBatchTrainingMixin
+
 from pfrl.agents.al import AL

diff --git a/tests/agents_tests/test_categorical_dqn.py b/tests/agents_tests/test_categorical_dqn.py
index 4858debb3..43e8c85f4 100644
--- a/tests/agents_tests/test_categorical_dqn.py
+++ b/tests/agents_tests/test_categorical_dqn.py
@@ -1,19 +1,14 @@
 import unittest
 
-import numpy as np
-
 import basetest_dqn_like as base
+import numpy as np
+import pytest
+import torch
 from basetest_training import _TestBatchTrainingMixin
 
 import pfrl
-from pfrl.agents import categorical_dqn
-from pfrl.agents.categorical_dqn import compute_value_loss
-from pfrl.agents.categorical_dqn import compute_weighted_value_loss
-from pfrl.agents import CategoricalDQN
-
-import torch
-import pytest
+from pfrl.agents import CategoricalDQN, categorical_dqn
+from pfrl.agents.categorical_dqn import compute_value_loss, compute_weighted_value_loss
 
 assertions = unittest.TestCase("__init__")

diff --git a/tests/agents_tests/test_ddpg.py b/tests/agents_tests/test_ddpg.py
index 200fa420c..4afec0361 100644
--- a/tests/agents_tests/test_ddpg.py
+++ b/tests/agents_tests/test_ddpg.py
@@ -1,8 +1,8 @@
+import basetest_ddpg as base
 import pytest
+from basetest_training import _TestBatchTrainingMixin
 
 from pfrl.agents.ddpg import DDPG
-import basetest_ddpg as base
-from basetest_training import _TestBatchTrainingMixin
 
 
 @pytest.mark.skip  # recurrent=True is not supported yet

diff --git a/tests/agents_tests/test_double_categorical_dqn.py b/tests/agents_tests/test_double_categorical_dqn.py
index b6b204f55..1cb7173eb 100644
--- a/tests/agents_tests/test_double_categorical_dqn.py
+++ b/tests/agents_tests/test_double_categorical_dqn.py
@@ -1,6 +1,5 @@
-import torch.nn as nn
-
 import basetest_dqn_like as base
+import torch.nn as nn
 from basetest_training import _TestBatchTrainingMixin
 
 import pfrl

diff --git a/tests/agents_tests/test_double_dqn.py b/tests/agents_tests/test_double_dqn.py
index 73f6e3ecf..5e5b5797a 100644
--- a/tests/agents_tests/test_double_dqn.py
+++ b/tests/agents_tests/test_double_dqn.py
@@ -1,5 +1,6 @@
 import basetest_dqn_like
 from basetest_training import _TestBatchTrainingMixin
+
 from pfrl.agents.double_dqn import DoubleDQN

diff --git a/tests/agents_tests/test_double_pal.py b/tests/agents_tests/test_double_pal.py
index 56496807d..0e4d7a40a 100644
--- a/tests/agents_tests/test_double_pal.py
+++ b/tests/agents_tests/test_double_pal.py
@@ -1,8 +1,8 @@
-from pfrl.agents.double_pal import DoublePAL
-
 import basetest_dqn_like
 from basetest_training import _TestBatchTrainingMixin
 
+from pfrl.agents.double_pal import DoublePAL
+
 
 class TestDoublePALOnDiscreteABC(
     _TestBatchTrainingMixin, basetest_dqn_like._TestDQNOnDiscreteABC

diff --git a/tests/agents_tests/test_dpp.py b/tests/agents_tests/test_dpp.py
index a86c9b913..f9e8b872e 100644
--- a/tests/agents_tests/test_dpp.py
+++ b/tests/agents_tests/test_dpp.py
@@ -1,9 +1,8 @@
-import pytest
 import basetest_dqn_like as base
+import pytest
 from basetest_training import _TestBatchTrainingMixin
-from pfrl.agents.dpp import DPP
-from pfrl.agents.dpp import DPPGreedy
-from pfrl.agents.dpp import DPPL
+
+from pfrl.agents.dpp import DPP, DPPL, DPPGreedy
 
 
 def parse_dpp_agent(dpp_type):

diff --git a/tests/agents_tests/test_dqn.py b/tests/agents_tests/test_dqn.py
index d57c98532..bb4607a59 100644
--- a/tests/agents_tests/test_dqn.py
+++ b/tests/agents_tests/test_dqn.py
@@ -1,14 +1,12 @@
 import unittest
 
+import basetest_dqn_like as base
 import pytest
 import torch
-import basetest_dqn_like as base
+from basetest_training import _TestActorLearnerTrainingMixin, _TestBatchTrainingMixin
+
 import pfrl
-from pfrl.agents.dqn import compute_value_loss
-from pfrl.agents.dqn import compute_weighted_value_loss
-from pfrl.agents.dqn import DQN
-from basetest_training import _TestActorLearnerTrainingMixin
-from basetest_training import _TestBatchTrainingMixin
+from pfrl.agents.dqn import DQN, compute_value_loss, compute_weighted_value_loss
 
 assertions = unittest.TestCase("__init__")

diff --git a/tests/agents_tests/test_iqn.py b/tests/agents_tests/test_iqn.py
index 7ecd806db..6f6569983 100644
--- a/tests/agents_tests/test_iqn.py
+++ b/tests/agents_tests/test_iqn.py
@@ -1,13 +1,13 @@
+import basetest_dqn_like as base
 import numpy as np
-import torch
-from torch import nn
 import pytest
-
-import basetest_dqn_like as base
+import torch
 
 # IQN does not support the actor-learner interface for now
 # from basetest_training import _TestActorLearnerTrainingMixin
 from basetest_training import _TestBatchTrainingMixin
+from torch import nn
+
 import pfrl
 from pfrl.agents import iqn
 from pfrl.testing import torch_assert_allclose

diff --git a/tests/agents_tests/test_pal.py b/tests/agents_tests/test_pal.py
index 94a859837..ac851fab6 100644
--- a/tests/agents_tests/test_pal.py
+++ b/tests/agents_tests/test_pal.py
@@ -1,5 +1,6 @@
 import basetest_dqn_like as base
 from basetest_training import _TestBatchTrainingMixin
+
 from pfrl.agents.pal import PAL

diff --git a/tests/agents_tests/test_ppo.py b/tests/agents_tests/test_ppo.py
index e5a2a179b..5ccd980a6 100644
--- a/tests/agents_tests/test_ppo.py
+++ b/tests/agents_tests/test_ppo.py
@@ -13,17 +13,21 @@
 from pfrl.agents import ppo
 from pfrl.agents.ppo import PPO
 from pfrl.envs.abc import ABC
-from pfrl.experiments.evaluator import batch_run_evaluation_episodes
-from pfrl.experiments.evaluator import run_evaluation_episodes
-from pfrl.experiments import train_agent_batch_with_evaluation
-from pfrl.experiments import train_agent_with_evaluation
-from pfrl.utils.batch_states import batch_states
-
-from pfrl.nn import RecurrentBranched
-from pfrl.nn import RecurrentSequential
-from pfrl.policies import SoftmaxCategoricalHead
-from pfrl.policies import GaussianHeadWithStateIndependentCovariance
+from pfrl.experiments import (
+    train_agent_batch_with_evaluation,
+    train_agent_with_evaluation,
+)
+from pfrl.experiments.evaluator import (
+    batch_run_evaluation_episodes,
+    run_evaluation_episodes,
+)
+from pfrl.nn import RecurrentBranched, RecurrentSequential
+from pfrl.policies import (
+    GaussianHeadWithStateIndependentCovariance,
+    SoftmaxCategoricalHead,
+)
 from pfrl.testing import torch_assert_allclose
+from pfrl.utils.batch_states import batch_states
 
 
 def make_random_episodes(n_episodes=10, obs_size=2, n_actions=3):

diff --git a/tests/agents_tests/test_reinforce.py b/tests/agents_tests/test_reinforce.py
index d9559497e..d00206665 100644
--- a/tests/agents_tests/test_reinforce.py
+++ b/tests/agents_tests/test_reinforce.py
@@ -2,15 +2,17 @@
 import tempfile
 
 import numpy as np
+import pytest
 import torch
 from torch import nn
-import pytest
 
 import pfrl
-from pfrl.experiments.evaluator import run_evaluation_episodes
 from pfrl.envs.abc import ABC
-from pfrl.policies import SoftmaxCategoricalHead
-from pfrl.policies import GaussianHeadWithStateIndependentCovariance
+from pfrl.experiments.evaluator import run_evaluation_episodes
+from pfrl.policies import (
+    GaussianHeadWithStateIndependentCovariance,
+    SoftmaxCategoricalHead,
+)
 
 
 @pytest.mark.parametrize("discrete", [True, False])

diff --git a/tests/agents_tests/test_soft_actor_critic.py b/tests/agents_tests/test_soft_actor_critic.py
index 1207bcafe..ef99707f8 100644
--- a/tests/agents_tests/test_soft_actor_critic.py
+++ b/tests/agents_tests/test_soft_actor_critic.py
@@ -4,16 +4,18 @@
 import numpy as np
 import pytest
 import torch
-from torch import distributions
-from torch import nn
-
+from torch import distributions, nn
 
 import pfrl
 from pfrl.envs.abc import ABC
-from pfrl.experiments.evaluator import batch_run_evaluation_episodes
-from pfrl.experiments.evaluator import run_evaluation_episodes
-from pfrl.experiments import train_agent_batch_with_evaluation
-from pfrl.experiments import train_agent_with_evaluation
+from pfrl.experiments import (
+    train_agent_batch_with_evaluation,
+    train_agent_with_evaluation,
+)
+from pfrl.experiments.evaluator import (
+    batch_run_evaluation_episodes,
+    run_evaluation_episodes,
+)
 from pfrl.nn.lmbda import Lambda

diff --git a/tests/agents_tests/test_td3.py b/tests/agents_tests/test_td3.py
index f5d521b75..00bf6e17a 100644
--- a/tests/agents_tests/test_td3.py
+++ b/tests/agents_tests/test_td3.py
@@ -8,10 +8,14 @@
 
 import pfrl
 from pfrl.envs.abc import ABC
-from pfrl.experiments.evaluator import batch_run_evaluation_episodes
-from pfrl.experiments.evaluator import run_evaluation_episodes
-from pfrl.experiments import train_agent_batch_with_evaluation
-from pfrl.experiments import train_agent_with_evaluation
+from pfrl.experiments import (
+    train_agent_batch_with_evaluation,
+    train_agent_with_evaluation,
+)
+from pfrl.experiments.evaluator import (
+    batch_run_evaluation_episodes,
+    run_evaluation_episodes,
+)
 
 
 @pytest.mark.parametrize("episodic", [False, True])

diff --git a/tests/agents_tests/test_trpo.py b/tests/agents_tests/test_trpo.py
index f4ad841dd..1430127c9
100644 --- a/tests/agents_tests/test_trpo.py +++ b/tests/agents_tests/test_trpo.py @@ -10,13 +10,19 @@ import pfrl from pfrl.agents import trpo from pfrl.envs.abc import ABC -from pfrl.experiments.evaluator import batch_run_evaluation_episodes -from pfrl.experiments.evaluator import run_evaluation_episodes -from pfrl.experiments import train_agent_batch_with_evaluation -from pfrl.experiments import train_agent_with_evaluation +from pfrl.experiments import ( + train_agent_batch_with_evaluation, + train_agent_with_evaluation, +) +from pfrl.experiments.evaluator import ( + batch_run_evaluation_episodes, + run_evaluation_episodes, +) from pfrl.nn import RecurrentSequential -from pfrl.policies import SoftmaxCategoricalHead -from pfrl.policies import GaussianHeadWithStateIndependentCovariance +from pfrl.policies import ( + GaussianHeadWithStateIndependentCovariance, + SoftmaxCategoricalHead, +) from pfrl.testing import torch_assert_allclose diff --git a/tests/collections_tests/test_prioritized.py b/tests/collections_tests/test_prioritized.py index 7b222c992..702d45fe7 100644 --- a/tests/collections_tests/test_prioritized.py +++ b/tests/collections_tests/test_prioritized.py @@ -1,7 +1,7 @@ +import random import unittest import numpy as np -import random import pytest from pfrl.collections import prioritized diff --git a/tests/experiments_tests/test_prepare_output_dir.py b/tests/experiments_tests/test_prepare_output_dir.py index 3df7ba1fd..4e8ee5dc0 100644 --- a/tests/experiments_tests/test_prepare_output_dir.py +++ b/tests/experiments_tests/test_prepare_output_dir.py @@ -2,11 +2,12 @@ import itertools import json import os -import pytest import subprocess import sys import tempfile +import pytest + import pfrl diff --git a/tests/experiments_tests/test_train_agent_async.py b/tests/experiments_tests/test_train_agent_async.py index 375328073..8c3f861e2 100644 --- a/tests/experiments_tests/test_train_agent_async.py +++ b/tests/experiments_tests/test_train_agent_async.py @@ -1,10 +1,10 @@ -import torch.multiprocessing as mp import os import tempfile import unittest from unittest import mock import pytest +import torch.multiprocessing as mp import pfrl from pfrl.experiments.train_agent_async import train_loop diff --git a/tests/explorers_tests/test_boltzmann.py b/tests/explorers_tests/test_boltzmann.py index 09792a097..f03d9eaee 100644 --- a/tests/explorers_tests/test_boltzmann.py +++ b/tests/explorers_tests/test_boltzmann.py @@ -1,7 +1,7 @@ import unittest -import torch import numpy as np +import torch import pfrl diff --git a/tests/functions_tests/test_lower_triangular_matrix.py b/tests/functions_tests/test_lower_triangular_matrix.py index 49f2a999b..16e4e7167 100644 --- a/tests/functions_tests/test_lower_triangular_matrix.py +++ b/tests/functions_tests/test_lower_triangular_matrix.py @@ -1,7 +1,7 @@ -import torch - import numpy as np import pytest +import torch + from pfrl.functions.lower_triangular_matrix import lower_triangular_matrix diff --git a/tests/nn_tests/test_mlp_bn.py b/tests/nn_tests/test_mlp_bn.py index da3a4c12e..3c97aab60 100644 --- a/tests/nn_tests/test_mlp_bn.py +++ b/tests/nn_tests/test_mlp_bn.py @@ -1,11 +1,11 @@ -import torch +import unittest + import numpy as np +import pytest +import torch import pfrl -import pytest -import unittest - assertions = unittest.TestCase("__init__") diff --git a/tests/nn_tests/test_noisy_chain.py b/tests/nn_tests/test_noisy_chain.py index 28926d056..31b898129 100644 --- a/tests/nn_tests/test_noisy_chain.py +++ b/tests/nn_tests/test_noisy_chain.py @@ 
-1,7 +1,8 @@ -import torch -import numpy import unittest +import numpy +import torch + from pfrl.nn import to_factorized_noisy diff --git a/tests/nn_tests/test_noisy_linear.py b/tests/nn_tests/test_noisy_linear.py index 83245ed42..fc0612468 100644 --- a/tests/nn_tests/test_noisy_linear.py +++ b/tests/nn_tests/test_noisy_linear.py @@ -1,10 +1,9 @@ import numpy +import pytest +import torch from pfrl.nn import noisy_linear -import torch -import pytest - @pytest.mark.parametrize("bias", [False, True]) class TestFactorizedNoisyLinear: diff --git a/tests/nn_tests/test_recurrent_branched.py b/tests/nn_tests/test_recurrent_branched.py index 840208e40..846428a7a 100644 --- a/tests/nn_tests/test_recurrent_branched.py +++ b/tests/nn_tests/test_recurrent_branched.py @@ -4,13 +4,14 @@ import torch from torch import nn -from pfrl.nn import RecurrentBranched -from pfrl.nn import RecurrentSequential -from pfrl.utils.recurrent import mask_recurrent_state_at -from pfrl.utils.recurrent import get_recurrent_state_at -from pfrl.utils.recurrent import concatenate_recurrent_states -from pfrl.utils.recurrent import one_step_forward +from pfrl.nn import RecurrentBranched, RecurrentSequential from pfrl.testing import torch_assert_allclose +from pfrl.utils.recurrent import ( + concatenate_recurrent_states, + get_recurrent_state_at, + mask_recurrent_state_at, + one_step_forward, +) class TestRecurrentBranched(unittest.TestCase): diff --git a/tests/nn_tests/test_recurrent_sequential.py b/tests/nn_tests/test_recurrent_sequential.py index b3d4077e9..9ab41d44d 100644 --- a/tests/nn_tests/test_recurrent_sequential.py +++ b/tests/nn_tests/test_recurrent_sequential.py @@ -2,13 +2,11 @@ import pytest import torch -from torch import nn import torch.nn.functional as F +from torch import nn -from pfrl.nn import RecurrentSequential - +from pfrl.nn import Lambda, RecurrentSequential from pfrl.testing import torch_assert_allclose -from pfrl.nn import Lambda def _step_lstm(lstm, x, state): diff --git a/tests/q_functions_tests/basetest_state_action_q_function.py b/tests/q_functions_tests/basetest_state_action_q_function.py index 63fdd22b0..54019bc51 100644 --- a/tests/q_functions_tests/basetest_state_action_q_function.py +++ b/tests/q_functions_tests/basetest_state_action_q_function.py @@ -1,6 +1,6 @@ import unittest -import numpy as np +import numpy as np import torch assertions = unittest.TestCase("__init__") diff --git a/tests/q_functions_tests/test_state_action_q_function.py b/tests/q_functions_tests/test_state_action_q_function.py index 75f27e4d3..0cc5760db 100644 --- a/tests/q_functions_tests/test_state_action_q_function.py +++ b/tests/q_functions_tests/test_state_action_q_function.py @@ -1,10 +1,10 @@ -import torch.nn.functional as F +import unittest import basetest_state_action_q_function as base -import pfrl - import pytest -import unittest +import torch.nn.functional as F + +import pfrl assertions = unittest.TestCase("__init__") diff --git a/tests/replay_buffers_test/test_replay_buffer.py b/tests/replay_buffers_test/test_replay_buffer.py index 31904c587..bf2b2b037 100644 --- a/tests/replay_buffers_test/test_replay_buffer.py +++ b/tests/replay_buffers_test/test_replay_buffer.py @@ -6,11 +6,10 @@ import numpy as np import pytest - -from pfrl import replay_buffers -from pfrl import replay_buffer import torch +from pfrl import replay_buffer, replay_buffers + @pytest.mark.parametrize("capacity", [100, None]) @pytest.mark.parametrize("num_steps", [1, 3]) diff --git a/tests/utils_tests/test_batch_states.py 
b/tests/utils_tests/test_batch_states.py index 6c5368126..d84aa369b 100644 --- a/tests/utils_tests/test_batch_states.py +++ b/tests/utils_tests/test_batch_states.py @@ -1,10 +1,10 @@ import unittest import numpy as np +import pytest +import torch import pfrl -import torch -import pytest @pytest.mark.skip diff --git a/tests/utils_tests/test_clip_l2_grad_norm.py b/tests/utils_tests/test_clip_l2_grad_norm.py index 7ab8707ef..21f762af8 100644 --- a/tests/utils_tests/test_clip_l2_grad_norm.py +++ b/tests/utils_tests/test_clip_l2_grad_norm.py @@ -1,5 +1,5 @@ -from logging import getLogger import timeit +from logging import getLogger import numpy as np import pytest diff --git a/tests/utils_tests/test_copy_param.py b/tests/utils_tests/test_copy_param.py index 14e7057c9..1ffecb472 100644 --- a/tests/utils_tests/test_copy_param.py +++ b/tests/utils_tests/test_copy_param.py @@ -1,11 +1,11 @@ import unittest +import numpy as np import torch import torch.nn as nn -import numpy as np -from pfrl.utils import copy_param from pfrl.testing import torch_assert_allclose +from pfrl.utils import copy_param class TestCopyParam(unittest.TestCase): diff --git a/tests/utils_tests/test_mode_of_distribution.py b/tests/utils_tests/test_mode_of_distribution.py index dcd8b89f3..025d27bbe 100644 --- a/tests/utils_tests/test_mode_of_distribution.py +++ b/tests/utils_tests/test_mode_of_distribution.py @@ -2,8 +2,8 @@ import torch -from pfrl.utils.mode_of_distribution import mode_of_distribution from pfrl.testing import torch_assert_allclose +from pfrl.utils.mode_of_distribution import mode_of_distribution def test_transform(): diff --git a/tests/utils_tests/test_random_seed.py b/tests/utils_tests/test_random_seed.py index afff3ad73..812d8f158 100644 --- a/tests/utils_tests/test_random_seed.py +++ b/tests/utils_tests/test_random_seed.py @@ -1,9 +1,10 @@ import random import unittest -import pfrl -import torch import pytest +import torch + +import pfrl class TestSetRandomSeed(unittest.TestCase): diff --git a/tests/utils_tests/test_recurrent.py b/tests/utils_tests/test_recurrent.py index 1add419c6..1bb38fbc6 100644 --- a/tests/utils_tests/test_recurrent.py +++ b/tests/utils_tests/test_recurrent.py @@ -4,10 +4,12 @@ import torch from torch import nn -from pfrl.utils.recurrent import get_recurrent_state_at -from pfrl.utils.recurrent import mask_recurrent_state_at -from pfrl.utils.recurrent import concatenate_recurrent_states from pfrl.testing import torch_assert_allclose +from pfrl.utils.recurrent import ( + concatenate_recurrent_states, + get_recurrent_state_at, + mask_recurrent_state_at, +) class TestRecurrentStateFunctions(unittest.TestCase): diff --git a/tests/utils_tests/test_stoppable_thread.py b/tests/utils_tests/test_stoppable_thread.py index 686bbfa8f..f179e6bda 100644 --- a/tests/utils_tests/test_stoppable_thread.py +++ b/tests/utils_tests/test_stoppable_thread.py @@ -1,6 +1,6 @@ +import threading import unittest -import threading from pfrl.utils import StoppableThread diff --git a/tests/wrappers_tests/test_atari_wrappers.py b/tests/wrappers_tests/test_atari_wrappers.py index 62467c474..b8ed38eba 100644 --- a/tests/wrappers_tests/test_atari_wrappers.py +++ b/tests/wrappers_tests/test_atari_wrappers.py @@ -9,9 +9,7 @@ import numpy as np import pytest -from pfrl.wrappers.atari_wrappers import FrameStack -from pfrl.wrappers.atari_wrappers import LazyFrames -from pfrl.wrappers.atari_wrappers import ScaledFloatFrame +from pfrl.wrappers.atari_wrappers import FrameStack, LazyFrames, ScaledFloatFrame 
@pytest.mark.parametrize("dtype", [np.uint8, np.float32]) diff --git a/tests/wrappers_tests/test_monitor.py b/tests/wrappers_tests/test_monitor.py index e6ea7235e..ba65e9cc9 100644 --- a/tests/wrappers_tests/test_monitor.py +++ b/tests/wrappers_tests/test_monitor.py @@ -3,8 +3,8 @@ import tempfile import gym -from gym.wrappers import TimeLimit import pytest +from gym.wrappers import TimeLimit import pfrl diff --git a/tests/wrappers_tests/test_render.py b/tests/wrappers_tests/test_render.py index 9a3e2f89f..f93fcf8c4 100644 --- a/tests/wrappers_tests/test_render.py +++ b/tests/wrappers_tests/test_render.py @@ -1,4 +1,5 @@ from unittest import mock + import pytest import pfrl diff --git a/tests/wrappers_tests/test_vector_frame_stack.py b/tests/wrappers_tests/test_vector_frame_stack.py index 5efafd356..ef45e63c8 100644 --- a/tests/wrappers_tests/test_vector_frame_stack.py +++ b/tests/wrappers_tests/test_vector_frame_stack.py @@ -8,10 +8,8 @@ import pytest import pfrl -from pfrl.wrappers.atari_wrappers import FrameStack -from pfrl.wrappers.atari_wrappers import LazyFrames -from pfrl.wrappers.vector_frame_stack import VectorEnvWrapper -from pfrl.wrappers.vector_frame_stack import VectorFrameStack +from pfrl.wrappers.atari_wrappers import FrameStack, LazyFrames +from pfrl.wrappers.vector_frame_stack import VectorEnvWrapper, VectorFrameStack class TestVectorEnvWrapper(unittest.TestCase): From d6985b4e5991d32066fb4db458757406081ed927 Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 14 Oct 2020 11:52:57 +0900 Subject: [PATCH 3/5] Apply isort in CI --- .pfnci/lint.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pfnci/lint.sh b/.pfnci/lint.sh index 2babf5fab..34f2c5bca 100644 --- a/.pfnci/lint.sh +++ b/.pfnci/lint.sh @@ -3,9 +3,10 @@ set -eux # Use latest black to apply https://github.com/psf/black/issues/1288 -pip3 install git+git://github.com/psf/black.git@88d12f88a97e5e4c8fd0d245df0a311e932fd1e1 flake8 mypy +pip3 install git+git://github.com/psf/black.git@88d12f88a97e5e4c8fd0d245df0a311e932fd1e1 flake8 mypy isort black --diff --check pfrl tests examples +isort --diff --check pfrl tests examples flake8 pfrl tests examples mypy pfrl # mypy does not search child directories unless there is __init__.py From 8b46df7b7e6c816eb7908fa2c571405eca65a35a Mon Sep 17 00:00:00 2001 From: muupan Date: Wed, 14 Oct 2020 12:19:54 +0900 Subject: [PATCH 4/5] Fix errors related to import order --- pfrl/initializers/chainer_default.py | 2 +- pfrl/nn/concat_obs_and_action.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pfrl/initializers/chainer_default.py b/pfrl/initializers/chainer_default.py index 1532a0db9..437674816 100644 --- a/pfrl/initializers/chainer_default.py +++ b/pfrl/initializers/chainer_default.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from pfrl.initializers import init_lecun_normal +from pfrl.initializers.lecun_normal import init_lecun_normal @torch.no_grad() diff --git a/pfrl/nn/concat_obs_and_action.py b/pfrl/nn/concat_obs_and_action.py index 0c15ff276..ff2abef89 100644 --- a/pfrl/nn/concat_obs_and_action.py +++ b/pfrl/nn/concat_obs_and_action.py @@ -1,6 +1,6 @@ import torch -from pfrl.nn import Lambda +from pfrl.nn.lmbda import Lambda def concat_obs_and_action(obs_and_action): From 4c30c1d73f0941a2b649b62937eec346bb55a95e Mon Sep 17 00:00:00 2001 From: Mario Ynocente Castro Date: Thu, 29 Oct 2020 19:01:19 +0900 Subject: [PATCH 5/5] Set stats type (#83) --- pfrl/wrappers/monitor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/pfrl/wrappers/monitor.py b/pfrl/wrappers/monitor.py index 0839b0d85..6fcecad43 100644 --- a/pfrl/wrappers/monitor.py +++ b/pfrl/wrappers/monitor.py @@ -47,6 +47,8 @@ def _start( autoreset=False, env_id=env_id, ) + if mode is not None: + self._set_mode(mode) return ret
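
Notes on the last two patches; both sketches below are illustrative, not code from the series.

On [PATCH 4/5]: importing a name through a package __init__ (e.g. `from pfrl.nn import Lambda` inside another pfrl.nn submodule) re-enters the package while it may still be initializing, so whether the name is already bound depends on the order of imports in __init__.py, which is exactly what isort reshuffles. A minimal two-file sketch of the failure mode, with hypothetical file and module names:

    # pkg/__init__.py
    from pkg.lmbda import Lambda    # must run before pkg.concat for the
    from pkg.concat import concat   # fragile import below to succeed

    # pkg/concat.py
    from pkg import Lambda          # fragile: re-enters pkg/__init__.py while it
                                    # is partially initialized; raises ImportError
                                    # if pkg.concat happens to be imported first
    # robust alternative (what the patch switches to):
    # from pkg.lmbda import Lambda

On [PATCH 5/5]: with `mode` forwarded to `_set_mode()` inside `_start()`, the stats recorder can tag recorded episodes as training or evaluation. A minimal usage sketch, assuming gym's classic Monitor constructor (directory/mode keyword arguments, which pfrl.wrappers.Monitor inherits) and the old 4-tuple step API:

    import gym

    from pfrl.wrappers import Monitor

    # mode="evaluation" is passed through to _start(), which with this patch
    # calls self._set_mode(mode) so the recorded stats carry the episode type.
    env = Monitor(gym.make("CartPole-v0"), directory="results", mode="evaluation")
    obs = env.reset()
    done = False
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())
    env.close()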