
Commit 573c7a2

Merge branch 'master' into her
prabhatnagarajan committed Nov 5, 2020
2 parents: e481b85 + 9be4726
Showing 129 changed files with 380 additions and 459 deletions.
.pfnci/lint.sh (3 changes: 2 additions & 1 deletion)
@@ -3,9 +3,10 @@
 set -eux
 
 # Use latest black to apply https://github.com/psf/black/issues/1288
-pip3 install git+git://github.com/psf/black.git@88d12f88a97e5e4c8fd0d245df0a311e932fd1e1 flake8 mypy
+pip3 install git+git://github.com/psf/black.git@88d12f88a97e5e4c8fd0d245df0a311e932fd1e1 flake8 mypy isort
 
 black --diff --check pfrl tests examples
+isort --diff --check pfrl tests examples
 flake8 pfrl tests examples
 mypy pfrl
 # mypy does not search child directories unless there is __init__.py
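
The same checks can be reproduced locally before pushing; a minimal sketch, assuming a Python 3 environment with pip available (the package set and flags simply mirror the CI script above):

    # Install the pinned black revision plus the other linters, as CI does
    pip3 install git+git://github.com/psf/black.git@88d12f88a97e5e4c8fd0d245df0a311e932fd1e1 flake8 mypy isort

    # --diff --check reports violations without modifying any files
    black --diff --check pfrl tests examples
    isort --diff --check pfrl tests examples
    flake8 pfrl tests examples
    mypy pfrl

    # Dropping --diff --check lets black and isort rewrite files in place
    black pfrl tests examples
    isort pfrl tests examples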
examples/atari/reproduction/a3c/train_a3c.py (6 changes: 2 additions & 4 deletions)
@@ -8,12 +8,10 @@
 from torch import nn
 
 import pfrl
+from pfrl import experiments, utils
 from pfrl.agents import a3c
-from pfrl import experiments
-from pfrl import utils
-from pfrl.policies import SoftmaxCategoricalHead
 from pfrl.optimizers import SharedRMSpropEpsInsideSqrt
-
+from pfrl.policies import SoftmaxCategoricalHead
 from pfrl.wrappers import atari_wrappers
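
The Python diffs below all follow the same pattern, evidently the result of running isort on master (now enforced by the lint.sh change above): imports are regrouped into standard-library, third-party, and first-party (pfrl) blocks, each block is alphabetized, and repeated `from pfrl import ...` lines are merged into single statements.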
examples/atari/reproduction/dqn/train_dqn.py (16 changes: 6 additions & 10 deletions)
@@ -1,21 +1,17 @@
 import argparse
+import json
 import os
 
-import torch.nn as nn
 import numpy as np
+import torch.nn as nn
 
 import pfrl
-from pfrl.q_functions import DiscreteActionValueHead
-from pfrl import agents
-from pfrl import experiments
-from pfrl import explorers
+from pfrl import agents, experiments, explorers
 from pfrl import nn as pnn
-from pfrl import utils
-from pfrl import replay_buffers
-
-from pfrl.wrappers import atari_wrappers
+from pfrl import replay_buffers, utils
 from pfrl.initializers import init_chainer_default
-import json
+from pfrl.q_functions import DiscreteActionValueHead
+from pfrl.wrappers import atari_wrappers
 
 
 def main():
examples/atari/reproduction/iqn/train_iqn.py (5 changes: 1 addition & 4 deletions)
@@ -7,10 +7,7 @@
 from torch import nn
 
 import pfrl
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
+from pfrl import experiments, explorers, replay_buffers, utils
 from pfrl.wrappers import atari_wrappers
examples/atari/reproduction/rainbow/train_rainbow.py (9 changes: 3 additions & 6 deletions)
@@ -2,17 +2,14 @@
 import json
 import os
 
-import torch
 import numpy as np
+import torch
 
 import pfrl
-from pfrl import agents
-from pfrl import experiments
-from pfrl import explorers
+from pfrl import agents, experiments, explorers
 from pfrl import nn as pnn
-from pfrl import utils
+from pfrl import replay_buffers, utils
 from pfrl.q_functions import DistributionalDuelingDQN
-from pfrl import replay_buffers
 from pfrl.wrappers import atari_wrappers
examples/atari/train_a2c_ale.py (3 changes: 1 addition & 2 deletions)
@@ -6,9 +6,8 @@
 from torch import nn
 
 import pfrl
+from pfrl import experiments, utils
 from pfrl.agents import a2c
-from pfrl import experiments
-from pfrl import utils
 from pfrl.policies import SoftmaxCategoricalHead
 from pfrl.wrappers import atari_wrappers
examples/atari/train_acer_ale.py (6 changes: 2 additions & 4 deletions)
@@ -10,13 +10,11 @@
 from torch import nn
 
 import pfrl
+from pfrl import experiments, utils
 from pfrl.agents import acer
-from pfrl import experiments
-from pfrl import utils
-from pfrl.replay_buffers import EpisodicReplayBuffer
 from pfrl.policies import SoftmaxCategoricalHead
 from pfrl.q_functions import DiscreteActionValueHead
-
+from pfrl.replay_buffers import EpisodicReplayBuffer
 from pfrl.wrappers import atari_wrappers
examples/atari/train_categorical_dqn_ale.py (8 changes: 2 additions & 6 deletions)
@@ -1,14 +1,10 @@
 import argparse
 
-import torch
 import numpy as np
+import torch
 
 import pfrl
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
-
+from pfrl import experiments, explorers, replay_buffers, utils
 from pfrl.wrappers import atari_wrappers
examples/atari/train_dqn_ale.py (15 changes: 5 additions & 10 deletions)
@@ -1,21 +1,16 @@
 import argparse
 
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 
 import pfrl
-from pfrl.q_functions import DiscreteActionValueHead
-from pfrl import agents
-from pfrl import experiments
-from pfrl import explorers
+from pfrl import agents, experiments, explorers
 from pfrl import nn as pnn
-from pfrl import utils
-from pfrl.q_functions import DuelingDQN
-from pfrl import replay_buffers
-
-from pfrl.wrappers import atari_wrappers
+from pfrl import replay_buffers, utils
 from pfrl.initializers import init_chainer_default
+from pfrl.q_functions import DiscreteActionValueHead, DuelingDQN
+from pfrl.wrappers import atari_wrappers
 
 
 class SingleSharedBias(nn.Module):
examples/atari/train_dqn_batch_ale.py (15 changes: 5 additions & 10 deletions)
@@ -1,23 +1,18 @@
 import argparse
 import functools
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import numpy as np
 
 import pfrl
-from pfrl import agents
-from pfrl import experiments
-from pfrl import explorers
+from pfrl import agents, experiments, explorers
 from pfrl import nn as pnn
-from pfrl import utils
-from pfrl.q_functions import DiscreteActionValueHead
-from pfrl.q_functions import DuelingDQN
-from pfrl import replay_buffers
-
-from pfrl.wrappers import atari_wrappers
+from pfrl import replay_buffers, utils
 from pfrl.initializers import init_chainer_default
+from pfrl.q_functions import DiscreteActionValueHead, DuelingDQN
+from pfrl.wrappers import atari_wrappers
 
 
 class SingleSharedBias(nn.Module):
examples/atari/train_drqn_ale.py (6 changes: 1 addition & 5 deletions)
@@ -18,12 +18,8 @@
 from torch import nn
 
 import pfrl
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
+from pfrl import experiments, explorers, replay_buffers, utils
 from pfrl.q_functions import DiscreteActionValueHead
-
 from pfrl.wrappers import atari_wrappers
examples/atari/train_ppo_ale.py (5 changes: 2 additions & 3 deletions)
@@ -15,11 +15,10 @@
 from torch import nn
 
 import pfrl
+from pfrl import experiments, utils
 from pfrl.agents import PPO
-from pfrl import experiments
-from pfrl import utils
-from pfrl.wrappers import atari_wrappers
 from pfrl.policies import SoftmaxCategoricalHead
+from pfrl.wrappers import atari_wrappers
 
 
 def main():
examples/atlas/train_soft_actor_critic_atlas.py (7 changes: 2 additions & 5 deletions)
@@ -8,14 +8,11 @@
 import gym.wrappers
 import numpy as np
 import torch
-from torch import nn
-from torch import distributions
+from torch import distributions, nn
 
 import pfrl
-from pfrl import experiments
+from pfrl import experiments, replay_buffers, utils
 from pfrl.nn.lmbda import Lambda
-from pfrl import utils
-from pfrl import replay_buffers
 
 
 def make_env(args, seed, test):
examples/grasping/train_dqn_batch_grasping.py (9 changes: 3 additions & 6 deletions)
@@ -9,11 +9,8 @@
 from torch import nn
 
 import pfrl
+from pfrl import experiments, explorers, replay_buffers, utils
 from pfrl.q_functions import DiscreteActionValueHead
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
 
 
 class CastAction(gym.ActionWrapper):
@@ -243,9 +240,9 @@ def main():
     max_episode_steps = 8
 
     def make_env(idx, test):
-        from pybullet_envs.bullet.kuka_diverse_object_gym_env import (
+        from pybullet_envs.bullet.kuka_diverse_object_gym_env import (  # NOQA
             KukaDiverseObjectEnv,
-        )  # NOQA
+        )
 
         # Use different random seeds for train and test envs
        process_seed = int(process_seeds[idx])
examples/gym/train_categorical_dqn_gym.py (8 changes: 2 additions & 6 deletions)
@@ -10,15 +10,11 @@
 import argparse
 import sys
 
-import torch
 import gym
+import torch
 
 import pfrl
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import q_functions
-from pfrl import replay_buffers
+from pfrl import experiments, explorers, q_functions, replay_buffers, utils
 
 
 def main():
examples/gym/train_dqn_gym.py (15 changes: 6 additions & 9 deletions)
@@ -12,22 +12,19 @@
 """
 
 import argparse
-import sys
 import os
+import sys
 
-import torch.optim as optim
 import gym
-from gym import spaces
 import numpy as np
+import torch.optim as optim
+from gym import spaces
 
 import pfrl
-from pfrl.agents.dqn import DQN
-from pfrl import experiments
-from pfrl import explorers
+from pfrl import experiments, explorers
 from pfrl import nn as pnn
-from pfrl import utils
-from pfrl import q_functions
-from pfrl import replay_buffers
+from pfrl import q_functions, replay_buffers, utils
+from pfrl.agents.dqn import DQN
 
 
 def main():
examples/gym/train_reinforce_gym.py (6 changes: 2 additions & 4 deletions)
@@ -17,10 +17,8 @@
 from torch import nn
 
 import pfrl
-from pfrl import experiments
-from pfrl import utils
-from pfrl.policies import SoftmaxCategoricalHead
-from pfrl.policies import GaussianHeadWithFixedCovariance
+from pfrl import experiments, utils
+from pfrl.policies import GaussianHeadWithFixedCovariance, SoftmaxCategoricalHead
 
 
 def main():
examples/mujoco/reproduction/ddpg/train_ddpg.py (8 changes: 2 additions & 6 deletions)
@@ -15,13 +15,9 @@
 from torch import nn
 
 import pfrl
+from pfrl import experiments, explorers, replay_buffers, utils
 from pfrl.agents.ddpg import DDPG
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
-from pfrl.nn import ConcatObsAndAction
-from pfrl.nn import BoundByTanh
+from pfrl.nn import BoundByTanh, ConcatObsAndAction
 from pfrl.policies import DeterministicHead
examples/mujoco/reproduction/ppo/train_ppo.py (3 changes: 1 addition & 2 deletions)
@@ -13,9 +13,8 @@
 from torch import nn
 
 import pfrl
+from pfrl import experiments, utils
 from pfrl.agents import PPO
-from pfrl import experiments
-from pfrl import utils
 
 
 def main():
examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py
@@ -4,23 +4,20 @@
 as possible.
 """
 import argparse
-from distutils.version import LooseVersion
 import functools
 import logging
 import sys
+from distutils.version import LooseVersion
 
-import torch
-from torch import nn
-from torch import distributions
 import gym
 import gym.wrappers
 import numpy as np
+import torch
+from torch import distributions, nn
 
 import pfrl
-from pfrl import experiments
+from pfrl import experiments, replay_buffers, utils
 from pfrl.nn.lmbda import Lambda
-from pfrl import utils
-from pfrl import replay_buffers
 
 
 def main():
examples/mujoco/reproduction/td3/train_td3.py (5 changes: 1 addition & 4 deletions)
@@ -15,10 +15,7 @@
 from torch import nn
 
 import pfrl
-from pfrl import experiments
-from pfrl import explorers
-from pfrl import utils
-from pfrl import replay_buffers
+from pfrl import experiments, explorers, replay_buffers, utils
 
 
 def main():
[Remaining changed files not shown]
