v0.2 update
got the HER-SAC example working

got the DQN example working

finished TD3 and set up examples for it

removed the railrl reference
vitchyr committed Apr 4, 2019
1 parent 86db9c2 commit c1e9ec6
Showing 43 changed files with 1,464 additions and 1,784 deletions.
44 changes: 30 additions & 14 deletions README.md
@@ -4,39 +4,55 @@ Reinforcement learning framework and algorithms implemented in PyTorch.

Some implemented algorithms:
- Reinforcement Learning with Imagined Goals (RIG)
  - [example script](examples/rig/pusher/rig.py)
  - [paper](https://arxiv.org/abs/1807.04742)
  - [Documentation](docs/RIG.md)
- Temporal Difference Models (TDMs)
  - [example script](examples/tdm/cheetah.py)
  - [paper](https://arxiv.org/abs/1802.09081)
  - [Documentation](docs/TDMs.md)
- Hindsight Experience Replay (HER)
  - [example script](examples/her/her_sac_gym_fetch_reach.py)
  - [paper](https://arxiv.org/abs/1707.01495)
  - [Documentation](docs/HER.md)
- Deep Deterministic Policy Gradient (DDPG)
  - [example script](examples/ddpg.py)
  - [paper](https://arxiv.org/pdf/1509.02971.pdf)
- (Double) Deep Q-Network (DQN)
  - [example script](examples/dqn_and_double_dqn.py)
  - [paper](https://arxiv.org/pdf/1509.06461.pdf)
  - [Double Q-learning paper](https://arxiv.org/pdf/1509.06461.pdf)
- Soft Actor Critic (SAC)
  - [example script](examples/tsac.py)
  - [original paper](https://arxiv.org/abs/1801.01290) and [updated version](https://arxiv.org/abs/1812.05905)
  - [TensorFlow implementation from author](https://github.com/rail-berkeley/softlearning)
  - Includes the "min of Q" method (see the sketch below), the entropy-constrained implementation,
    the reparameterization trick, and a numerically stable tanh-Normal Jacobian calculation.
- Twin Delayed Deep Deterministic Policy Gradient (TD3)
  - [example script](examples/td3.py)
  - [paper](https://arxiv.org/abs/1802.09477)

To get started, check out the example scripts linked above.
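
The "min of Q" trick mentioned in the SAC entry amounts to backing up the smaller of two Q estimates. Below is a minimal sketch of the idea, not rlkit's actual API: `target_qf1` and `target_qf2` stand for any callables mapping an (observation, action) pair to a Q-value tensor, and `entropy_bonus` stands in for the `-alpha * log_pi` term that the entropy-constrained SAC adds (TD3 would use zero).

```python
import torch

# Sketch of the "min of Q" (clipped double-Q) backup used by SAC/TD3-style
# trainers. Names and signatures are illustrative, not rlkit's actual API.
def min_q_backup(target_qf1, target_qf2, next_obs, next_action,
                 reward, done, discount=0.99, entropy_bonus=0.0):
    # Take the elementwise minimum of two independently trained Q estimates
    # to counteract the overestimation bias of a single Q network.
    target_q = torch.min(
        target_qf1(next_obs, next_action),
        target_qf2(next_obs, next_action),
    )
    # For SAC, entropy_bonus would be -alpha * log_pi(next_action | next_obs).
    return reward + (1. - done) * discount * (target_q + entropy_bonus)
```

This is why the example scripts in this commit construct two Q networks (`qf1`, `qf2`) and their targets.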

## What's New
### Version 0.2
The initial release for 0.2 has the following major changes:
- Remove `Serializable` class and use default pickle scheme.
- Remove `PyTorchModule` class and use native `torch.nn.Module` directly.
- Switch to batch-style training rather than online training (see the sketch after this list).
  - Makes the code more amenable to parallelization.
  - Implementing the online version is straightforward.
- Refactor training code into its own object (e.g. `DQNTrainer`, `SACTrainer`), rather than
  integrating it inside of `RLAlgorithm`.
- Refactor sampling code into its own object (e.g. `MdpPathCollector`), rather than
  integrating it inside of `RLAlgorithm`.
- Update Soft Actor Critic to more closely match the TensorFlow implementation:
  - Rename `TwinSAC` to just `SAC`.
  - Use only Q networks (no separate value network).
  - Remove unnecessary policy regularization terms.
  - Use a numerically stable Jacobian computation.
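
A rough sketch of what one epoch of the new batch-style training looks like. This is a simplification for illustration, not rlkit's actual implementation; the method names on the collector, buffer, and trainer objects (`collect_new_paths`, `add_paths`, `random_batch`, `train`) are assumptions here.

```python
# Simplified sketch of one epoch of batch-style training (illustration only;
# the method names on the collector/buffer/trainer objects are assumptions).
def run_one_epoch(expl_data_collector, replay_buffer, trainer,
                  num_expl_steps=1000, num_trains=1000,
                  max_path_length=1000, batch_size=256):
    # 1. Collect a batch of exploration paths with the current policy.
    new_paths = expl_data_collector.collect_new_paths(
        max_path_length, num_expl_steps, discard_incomplete_paths=False,
    )
    # 2. Add them to the replay buffer.
    replay_buffer.add_paths(new_paths)
    # 3. Then take many gradient steps on sampled minibatches.
    #    (Online training would instead interleave one gradient step
    #    per environment step.)
    for _ in range(num_trains):
        trainer.train(replay_buffer.random_batch(batch_size))
```

Evaluation runs its own data collector on a separate evaluation environment, which is why the example scripts in this commit build separate exploration and evaluation path collectors.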

### Version 0.1
12/04/2018
- Add RIG implementation

90 changes: 67 additions & 23 deletions examples/dqn_and_double_dqn.py
@@ -3,53 +3,97 @@
"""

import gym
from torch import nn as nn

from rlkit.exploration_strategies.base import \
    PolicyWrappedWithExplorationStrategy
from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy
from rlkit.policies.argmax import ArgmaxDiscretePolicy
from rlkit.torch.dqn.dqn import DQNTrainer
from rlkit.torch.networks import Mlp
import rlkit.torch.pytorch_util as ptu
from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
from rlkit.launchers.launcher_util import setup_logger
from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm


def experiment(variant):
    expl_env = gym.make('CartPole-v0')
    eval_env = gym.make('CartPole-v0')
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.n

    qf = Mlp(
        hidden_sizes=[32, 32],
        input_size=obs_dim,
        output_size=action_dim,
    )
    target_qf = Mlp(
        hidden_sizes=[32, 32],
        input_size=obs_dim,
        output_size=action_dim,
    )
    qf_criterion = nn.MSELoss()
    eval_policy = ArgmaxDiscretePolicy(qf)
    expl_policy = PolicyWrappedWithExplorationStrategy(
        EpsilonGreedy(expl_env.action_space),
        eval_policy,
    )
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        expl_policy,
    )
    trainer = DQNTrainer(
        qf=qf,
        target_qf=target_qf,
        qf_criterion=qf_criterion,
        **variant['trainer_kwargs']
    )
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()


if __name__ == "__main__":
    # noinspection PyTypeChecker
    variant = dict(
        algorithm="DQN",
        version="normal",
        layer_size=256,
        replay_buffer_size=int(1E6),
        algorithm_kwargs=dict(
            num_epochs=3000,
            num_eval_steps_per_epoch=5000,
            num_trains_per_train_loop=1000,
            num_expl_steps_per_train_loop=1000,
            min_num_steps_before_training=1000,
            max_path_length=1000,
            batch_size=256,
        ),
        trainer_kwargs=dict(
            discount=0.99,
            learning_rate=3E-4,
        ),
    )
    setup_logger('name-of-experiment', variant=variant)
    # ptu.set_gpu_mode(True)  # optionally set the GPU (default=False)
    experiment(variant)
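
Since the file name covers both DQN and Double DQN, here is a sketch of how the two targets differ. `q` and `target_q` are hypothetical callables returning a list of per-action Q-values for a state; this illustrates the idea only and is not rlkit's trainer code.

```python
# Sketch of the two targets (illustration only; q/target_q are hypothetical
# callables mapping a state to a list of per-action Q-values).
def dqn_target(target_q, next_state, reward, done, discount=0.99):
    # Vanilla DQN: the target network both selects and evaluates the action.
    return reward + (1. - done) * discount * max(target_q(next_state))

def double_dqn_target(q, target_q, next_state, reward, done, discount=0.99):
    # Double DQN: the online network selects the action,
    # the target network evaluates it.
    next_qs = q(next_state)
    best_action = max(range(len(next_qs)), key=lambda a: next_qs[a])
    return reward + (1. - done) * discount * target_q(next_state)[best_action]
```

If rlkit exposes a Double DQN trainer (the file name suggests it does), swapping it in for `DQNTrainer` above should be the only change needed.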
129 changes: 129 additions & 0 deletions examples/her/her_sac_gym_fetch_reach.py
@@ -0,0 +1,129 @@
import gym

import rlkit.torch.pytorch_util as ptu
from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
from rlkit.launchers.launcher_util import setup_logger
from rlkit.samplers.data_collector import GoalConditionedPathCollector
from rlkit.torch.her.her import HERTrainer
from rlkit.torch.networks import FlattenMlp
from rlkit.torch.sac.policies import MakeDeterministic, TanhGaussianPolicy
from rlkit.torch.sac.sac import SACTrainer
from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm


def experiment(variant):
    eval_env = gym.make('FetchReach-v1')
    expl_env = gym.make('FetchReach-v1')

    observation_key = 'observation'
    desired_goal_key = 'desired_goal'

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    obs_dim = eval_env.observation_space.spaces['observation'].low.size
    action_dim = eval_env.action_space.low.size
    goal_dim = eval_env.observation_space.spaces['desired_goal'].low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )
    eval_policy = MakeDeterministic(policy)
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['sac_trainer_kwargs']
    )
    trainer = HERTrainer(trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        eval_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()


if __name__ == "__main__":
    variant = dict(
        algorithm='HER-SAC',
        version='normal',
        algo_kwargs=dict(
            batch_size=128,
            num_epochs=100,
            num_eval_steps_per_epoch=5000,
            num_expl_steps_per_train_loop=1000,
            num_trains_per_train_loop=1000,
            min_num_steps_before_training=1000,
            max_path_length=50,
        ),
        sac_trainer_kwargs=dict(
            discount=0.99,
            soft_target_tau=5e-3,
            target_update_period=1,
            policy_lr=3E-4,
            qf_lr=3E-4,
            reward_scale=1,
            use_automatic_entropy_tuning=True,
        ),
        replay_buffer_kwargs=dict(
            max_size=int(1E6),
            fraction_goals_rollout_goals=0.2,  # equal to k = 4 in HER paper
            fraction_goals_env_goals=0,
        ),
        qf_kwargs=dict(
            hidden_sizes=[400, 300],
        ),
        policy_kwargs=dict(
            hidden_sizes=[400, 300],
        ),
    )
    setup_logger('her-sac-fetch-experiment', variant=variant)
    experiment(variant)
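
One way to read the `# equal to k = 4 in HER paper` comment above: if k relabeled goals are sampled for every original rollout goal, the original goals make up 1/(k+1) of the training goals, so k = 4 gives 1/5 = 0.2. This is an interpretation of the comment, not rlkit source.

```python
# Interpretation of the "k = 4" comment: with k relabeled goals per original
# rollout goal, original goals make up 1 / (k + 1) of the sampled goals.
k = 4
fraction_goals_rollout_goals = 1 / (k + 1)
assert fraction_goals_rollout_goals == 0.2
```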