Implement Generative Adversarial Imitation Learning (GAIL) (#550)
Implement GAIL based on PPO and provide example script and sample (i.e., most likely not the best) results with Mujoco tasks. (#531, #173)
nuance1979 committed Mar 6, 2022
1 parent d976a5a commit 2377f2f
Showing 12 changed files with 683 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -41,6 +41,7 @@
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
- [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/pdf/1606.03476.pdf)
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
5 changes: 5 additions & 0 deletions docs/api/tianshou.policy.rst
@@ -134,6 +134,11 @@ Imitation
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.GAILPolicy
    :members:
    :undoc-members:
    :show-inheritance:

Model-based
-----------

1 change: 1 addition & 0 deletions docs/index.rst
@@ -32,6 +32,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
* :class:`~tianshou.policy.GAILPolicy` `Generative Adversarial Imitation Learning <https://arxiv.org/pdf/1606.03476.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module <https://arxiv.org/pdf/1705.05363.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
3 changes: 3 additions & 0 deletions docs/spelling_wordlist.txt
@@ -140,3 +140,6 @@ Strens
Ornstein
Uhlenbeck
mse
gail
airl
ppo
27 changes: 27 additions & 0 deletions examples/inverse/README.md
@@ -0,0 +1,27 @@
# Inverse Reinforcement Learning

In the inverse reinforcement learning setting, the agent learns a policy from two sources: interaction with an environment that provides no reward signal, and a fixed dataset collected with an expert policy.

## Continuous control

Once the dataset is collected, it is not changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train agents for continuous control. Refer to the [d4rl](https://github.com/rail-berkeley/d4rl) repository for how to use these datasets.

We provide an implementation of the GAIL algorithm for continuous control.
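
As a rough illustration of the idea behind GAIL (a sketch, not the `GAILPolicy` code itself): a discriminator scores state-action pairs, and the policy is rewarded for pairs that the discriminator mistakes for expert data. The sketch below assumes the discriminator outputs a raw logit, that expert pairs are labeled 1 and policy pairs 0, and that the surrogate reward is `-log(1 - sigmoid(logit))`, which is one common choice. The function names are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def surrogate_reward(logit: float) -> float:
    """Assumed reward form: -log(1 - sigmoid(logit)), which equals softplus(logit)."""
    return softplus(logit)


def discriminator_loss(policy_logits, expert_logits) -> float:
    """Binary cross-entropy with expert pairs labeled 1 and policy pairs labeled 0."""
    # -log(1 - sigmoid(l)) for policy samples
    loss_pi = sum(softplus(l) for l in policy_logits) / len(policy_logits)
    # -log(sigmoid(l)) for expert samples
    loss_exp = sum(softplus(-l) for l in expert_logits) / len(expert_logits)
    return loss_pi + loss_exp
```

Under these assumptions, a pair the discriminator rates as expert-like (large logit) earns a larger reward than a pair it confidently rejects, so the environment's own reward is never needed.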

### Train

You can parse a d4rl dataset into a `ReplayBuffer` and pass it as the `expert_buffer` parameter of `GAILPolicy`. `irl_gail.py` is an example of inverse RL using a d4rl dataset.
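
The parsing step is essentially a column-to-row conversion. The sketch below shows it with plain Python lists standing in for the real d4rl arrays and a list of dicts standing in for the `ReplayBuffer`. The dataset keys match those read in `irl_gail.py`; the helper name `to_transitions` and the toy values are made up for illustration.

```python
from typing import Dict, List


def to_transitions(dataset: Dict[str, list]) -> List[dict]:
    """Turn a column-wise dataset into one record per transition."""
    size = len(dataset["rewards"])
    transitions = []
    for i in range(size):
        transitions.append(
            {
                "obs": dataset["observations"][i],
                "act": dataset["actions"][i],
                "rew": dataset["rewards"][i],
                "done": dataset["terminals"][i],
                "obs_next": dataset["next_observations"][i],
            }
        )
    return transitions


# Toy stand-in for a d4rl dataset: two transitions, 2-D observations, 1-D actions.
toy_dataset = {
    "observations": [[0.0, 0.1], [0.2, 0.3]],
    "actions": [[0.5], [-0.5]],
    "rewards": [1.0, 0.0],
    "terminals": [False, True],
    "next_observations": [[0.2, 0.3], [0.4, 0.5]],
}
expert_transitions = to_transitions(toy_dataset)
```

In the example script, each record is wrapped in a `Batch` and added to the expert buffer with `expert_buffer.add(...)` rather than collected in a list.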

To train an agent with the GAIL algorithm:

```bash
python irl_gail.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2
```
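
The flags in this command become attributes on the parsed arguments, with dashes turned into underscores. Here is a trimmed-down sketch that keeps only three of the script's options, using the defaults from `irl_gail.py`; the `build_parser` helper is a hypothetical name, not part of the script.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """A reduced version of the script's parser, for illustration only."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='HalfCheetah-v2')
    parser.add_argument(
        '--expert-data-task', type=str, default='halfcheetah-expert-v2'
    )
    parser.add_argument('--disc-update-num', type=int, default=2)
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch is reproducible.
args = build_parser().parse_args(
    ['--task', 'Hopper-v2', '--expert-data-task', 'hopper-expert-v2']
)
print(args.task, args.expert_data_task, args.disc_update_num)
# prints: Hopper-v2 hopper-expert-v2 2
```

Options that are not passed, such as `--disc-update-num` here, keep their defaults.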

## GAIL (single run)

| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| HalfCheetah-v2 | 5177.07 | ![](results/gail/HalfCheetah-v2_rew.png) | `python3 irl_gail.py --task "HalfCheetah-v2" --expert-data-task "halfcheetah-expert-v2"` |
| Hopper-v2 | 1761.44 | ![](results/gail/Hopper-v2_rew.png) | `python3 irl_gail.py --task "Hopper-v2" --expert-data-task "hopper-expert-v2"` |
| Walker2d-v2 | 2020.77 | ![](results/gail/Walker2d-v2_rew.png) | `python3 irl_gail.py --task "Walker2d-v2" --expert-data-task "walker2d-expert-v2"` |
277 changes: 277 additions & 0 deletions examples/inverse/irl_gail.py
@@ -0,0 +1,277 @@
#!/usr/bin/env python3

import argparse
import datetime
import os
import pprint

import d4rl
import gym
import numpy as np
import torch
from torch import nn
from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import GAILPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic


class NoRewardEnv(gym.RewardWrapper):
    """Set the reward to 0.

    :param gym.Env env: the environment to wrap.
    """

    def __init__(self, env):
        super().__init__(env)

    def reward(self, reward):
        """Set reward to 0."""
        return np.zeros_like(reward)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='HalfCheetah-v2')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument(
        '--expert-data-task', type=str, default='halfcheetah-expert-v2'
    )
    parser.add_argument('--buffer-size', type=int, default=4096)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--disc-lr', type=float, default=2.5e-5)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=30000)
    parser.add_argument('--step-per-collect', type=int, default=2048)
    parser.add_argument('--repeat-per-collect', type=int, default=10)
    parser.add_argument('--disc-update-num', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--training-num', type=int, default=64)
    parser.add_argument('--test-num', type=int, default=10)
    # ppo special
    parser.add_argument('--rew-norm', type=int, default=True)
    # In theory, `vf-coef` will not make any difference if using Adam optimizer.
    parser.add_argument('--vf-coef', type=float, default=0.25)
    parser.add_argument('--ent-coef', type=float, default=0.001)
    parser.add_argument('--gae-lambda', type=float, default=0.95)
    parser.add_argument('--bound-action-method', type=str, default="clip")
    parser.add_argument('--lr-decay', type=int, default=True)
    parser.add_argument('--max-grad-norm', type=float, default=0.5)
    parser.add_argument('--eps-clip', type=float, default=0.2)
    parser.add_argument('--dual-clip', type=float, default=None)
    parser.add_argument('--value-clip', type=int, default=0)
    parser.add_argument('--norm-adv', type=int, default=0)
    parser.add_argument('--recompute-adv', type=int, default=1)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only'
    )
    return parser.parse_args()


def test_gail(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: NoRewardEnv(gym.make(args.task)) for _ in range(args.training_num)],
        norm_obs=True
    )
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        norm_obs=True,
        obs_rms=train_envs.obs_rms,
        update_obs_rms=False
    )

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net_a = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device
    )
    actor = ActorProb(
        net_a,
        args.action_shape,
        max_action=args.max_action,
        unbounded=True,
        device=args.device
    ).to(args.device)
    net_c = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device
    )
    critic = Critic(net_c, device=args.device).to(args.device)
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # do last policy layer scaling, this will make initial actions have (close to)
    # 0 mean and std, and will help boost performances,
    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
    # discriminator
    net_d = Net(
        args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device,
        concat=True
    )
    disc_net = Critic(net_d, device=args.device).to(args.device)
    for m in disc_net.modules():
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr)

    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect
        ) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
        )

    def dist(*logits):
        return Independent(Normal(*logits), 1)

    # expert replay buffer
    dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
    dataset_size = dataset['rewards'].size

    print("dataset_size", dataset_size)
    expert_buffer = ReplayBuffer(dataset_size)

    for i in range(dataset_size):
        expert_buffer.add(
            Batch(
                obs=dataset['observations'][i],
                act=dataset['actions'][i],
                rew=dataset['rewards'][i],
                done=dataset['terminals'][i],
                obs_next=dataset['next_observations'][i],
            )
        )
    print("dataset loaded")

    policy = GAILPolicy(
        actor,
        critic,
        optim,
        dist,
        expert_buffer,
        disc_net,
        disc_optim,
        disc_update_num=args.disc_update_num,
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        max_grad_norm=args.max_grad_norm,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        action_scaling=True,
        action_bound_method=args.bound_action_method,
        lr_scheduler=lr_scheduler,
        action_space=env.action_space,
        eps_clip=args.eps_clip,
        value_clip=args.value_clip,
        dual_clip=args.dual_clip,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_gail'
    log_path = os.path.join(args.logdir, args.task, 'gail', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    if not args.watch:
        # trainer
        result = onpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.repeat_per_collect,
            args.test_num,
            args.batch_size,
            step_per_collect=args.step_per_collect,
            save_fn=save_fn,
            logger=logger,
            test_in_train=False
        )
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
    test_gail()
Binary file added examples/inverse/results/gail/Hopper-v2_rew.png