
[RLlib] CQL BC loss fixes; PPO/PG/A2|3C action normalization fixes #16531

Merged
merged 74 commits into from
Jun 30, 2021
74 commits
47c91c1
wip
sven1977 Jun 10, 2021
649103c
wip
sven1977 Jun 11, 2021
fbfbd5b
Merge branch 'master' of https://github.com/ray-project/ray into poli…
sven1977 Jun 11, 2021
aa472ca
fix and LINT.
sven1977 Jun 11, 2021
281654d
wip.
sven1977 Jun 13, 2021
8c049fe
wip.
sven1977 Jun 13, 2021
a052fcd
Merge branch 'master' of https://github.com/ray-project/ray into poli…
sven1977 Jun 14, 2021
9d675af
wip.
sven1977 Jun 14, 2021
8dfbec9
fix
sven1977 Jun 14, 2021
eaa6afb
fix
sven1977 Jun 14, 2021
dc6a774
fix
sven1977 Jun 15, 2021
6406687
Merge branch 'master' of https://github.com/ray-project/ray into poli…
sven1977 Jun 15, 2021
e0b6311
wip.
sven1977 Jun 15, 2021
46e84fc
wip.
sven1977 Jun 15, 2021
2835b56
wip.
sven1977 Jun 16, 2021
4351570
wip.
sven1977 Jun 16, 2021
265454a
wip.
sven1977 Jun 16, 2021
ab79eac
wip.
sven1977 Jun 16, 2021
3fb411d
Merge branch 'master' of https://github.com/ray-project/ray into poli…
sven1977 Jun 16, 2021
9443460
wip.
sven1977 Jun 16, 2021
230adee
Merge branch 'master' of https://github.com/ray-project/ray into poli…
sven1977 Jun 16, 2021
f2b4c20
wip.
sven1977 Jun 17, 2021
6e4037c
wip.
sven1977 Jun 17, 2021
45fb626
wip.
sven1977 Jun 17, 2021
2fd6ff7
wip.
sven1977 Jun 17, 2021
e2d0378
wip.
sven1977 Jun 18, 2021
2ad07aa
Merge branch 'master' of https://github.com/ray-project/ray into poli…
sven1977 Jun 18, 2021
97ca8dc
wip.
sven1977 Jun 18, 2021
b5e9542
wip.
sven1977 Jun 18, 2021
967fd1e
wip
sven1977 Jun 18, 2021
f1d0bde
Merge branch 'master' of https://github.com/ray-project/ray into cql_…
sven1977 Jun 18, 2021
58fb139
wip
sven1977 Jun 18, 2021
09ca30f
wip
sven1977 Jun 18, 2021
e12f086
wip
sven1977 Jun 18, 2021
88a49df
wip
sven1977 Jun 18, 2021
3a28efc
wip
sven1977 Jun 18, 2021
9b86f84
fix
sven1977 Jun 18, 2021
6a07e44
Merge branch 'master' of https://github.com/ray-project/ray into cql_…
sven1977 Jun 19, 2021
8700e97
Merge branch 'master' of https://github.com/ray-project/ray into cql_…
sven1977 Jun 19, 2021
cfcd54e
wip
sven1977 Jun 19, 2021
43056d8
Merge branch 'master' of https://github.com/ray-project/ray into cql_…
sven1977 Jun 20, 2021
463901a
fix
sven1977 Jun 20, 2021
b50a138
wip
sven1977 Jun 20, 2021
a3cc35b
Merge branch 'master' into policy_support_add_and_delete
sven1977 Jun 20, 2021
18c41ca
wip
sven1977 Jun 20, 2021
ca44258
wip
sven1977 Jun 20, 2021
2b859e0
wip
sven1977 Jun 21, 2021
6baa539
wip
sven1977 Jun 21, 2021
f33b8b1
Merge branch 'master' of https://github.com/ray-project/ray into cql_…
sven1977 Jun 21, 2021
4598be1
Merge branch 'policy_support_add_and_delete' into cql_fix_bc_loss_term
sven1977 Jun 21, 2021
a7bf42e
LINT
sven1977 Jun 21, 2021
9cb1d60
wip.
sven1977 Jun 21, 2021
503d538
LINT.
sven1977 Jun 21, 2021
9024118
fixes.
sven1977 Jun 21, 2021
13fa9aa
fix and lint
sven1977 Jun 21, 2021
c28b096
fix and lint
sven1977 Jun 22, 2021
6b62aab
fix.
sven1977 Jun 22, 2021
7f40479
fix and lint
sven1977 Jun 23, 2021
9307aca
wip.
sven1977 Jun 23, 2021
42d8c5d
Merge branch 'master' of https://github.com/ray-project/ray into cql_…
sven1977 Jun 25, 2021
3b08816
wip
sven1977 Jun 25, 2021
c079185
wip
sven1977 Jun 28, 2021
15492ab
Merge branch 'master' of https://github.com/ray-project/ray into cql_…
sven1977 Jun 28, 2021
4e5de74
Merge branch 'master' of https://github.com/ray-project/ray into cql_…
sven1977 Jun 29, 2021
a8ab846
wip
sven1977 Jun 29, 2021
28e949f
wip
sven1977 Jun 29, 2021
e43ae8f
wip
sven1977 Jun 29, 2021
3a0f859
wip
sven1977 Jun 29, 2021
ca9f092
wip
sven1977 Jun 29, 2021
6013492
wip
sven1977 Jun 30, 2021
7507cc9
Merge branch 'master' of https://github.com/ray-project/ray into cql_…
sven1977 Jun 30, 2021
5fb5f80
wip
sven1977 Jun 30, 2021
7b3e1ba
wip
sven1977 Jun 30, 2021
b443510
fix
sven1977 Jun 30, 2021
49 changes: 49 additions & 0 deletions rllib/BUILD
@@ -156,6 +156,35 @@ py_test(
args = ["--yaml-dir=tuned_examples/ars", "--framework=torch"]
)

# CQL
py_test(
name = "run_regression_tests_pendulum_cql_tf",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_tf", "learning_tests_pendulum"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
# Include the zipped json data file as well.
data = [
"tuned_examples/cql/pendulum-cql.yaml",
"tests/data/pendulum/huge.zip",
],
args = ["--yaml-dir=tuned_examples/cql"]
)

py_test(
name = "run_regression_tests_pendulum_cql_torch",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_tf", "learning_tests_pendulum"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
# Include the zipped json data file as well.
data = [
"tuned_examples/cql/pendulum-cql.yaml",
"tests/data/pendulum/huge.zip",
],
args = ["--yaml-dir=tuned_examples/cql", "--framework=torch"]
)

# DDPG
py_test(
name = "run_regression_tests_pendulum_ddpg_tf",
@@ -465,6 +494,26 @@ py_test(
args = ["--yaml-dir=tuned_examples/sac", "--framework=torch"]
)

py_test(
name = "run_regression_tests_transformed_actions_pendulum_sac_tf",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_tf", "learning_tests_pendulum"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/sac/pendulum-transformed-actions-sac.yaml"],
args = ["--yaml-dir=tuned_examples/sac"]
)

py_test(
name = "run_regression_tests_transformed_actions_pendulum_sac_torch",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_torch", "learning_tests_pendulum"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/sac/pendulum-transformed-actions-sac.yaml"],
args = ["--yaml-dir=tuned_examples/sac", "--framework=torch"]
)


# TD3
py_test(
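For context, the new CQL regression-test targets added above simply wrap the generic runner with a yaml directory. A rough local equivalent of the torch target (running it outside Bazel and from the rllib/ directory are assumptions; flags and paths are copied from the target):

import subprocess

# Rough local equivalent of run_regression_tests_pendulum_cql_torch.
subprocess.run(
    [
        "python", "tests/run_regression_tests.py",
        "--yaml-dir=tuned_examples/cql",
        "--framework=torch",
    ],
    check=True,
)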
15 changes: 8 additions & 7 deletions rllib/agents/cql/cql.py
@@ -24,21 +24,22 @@
SAC_CONFIG, {
# You should override this to point to an offline dataset.
"input": "sampler",
# Offline RL does not need IS estimators.
# Switch off off-policy evaluation.
"input_evaluation": [],
# Number of iterations with Behavior Cloning Pretraining.
"bc_iters": 20000,
# CQL Loss Temperature.
# CQL loss temperature.
"temperature": 1.0,
# Num Actions to sample for CQL Loss.
# Number of actions to sample for CQL loss.
"num_actions": 10,
# Whether to use the Lagrangian for Alpha Prime (in CQL Loss).
# Whether to use the Lagrangian for Alpha Prime (in CQL loss).
"lagrangian": False,
# Lagrangian Threshold.
# Lagrangian threshold.
"lagrangian_thresh": 5.0,
# Min Q Weight multiplier.
# Min Q weight multiplier.
"min_q_weight": 5.0,
# Replay Buffer should be size of offline dataset.
# Replay buffer should be larger than or equal to the size of the
# offline dataset.
"buffer_size": int(1e6),
})
# __sphinx_doc_end__
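A minimal usage sketch of the config keys documented above (the trainer import path and all example values are assumptions, not part of this PR):

import ray
from ray.rllib.agents.cql import CQLTrainer  # assumed import path for this era

ray.init()
config = {
    "env": "Pendulum-v0",
    "input": ["/path/to/offline_data.json"],  # hypothetical offline dataset
    "input_evaluation": [],    # off-policy evaluation switched off
    "bc_iters": 20000,         # behavior-cloning pretraining iterations
    "temperature": 1.0,        # CQL loss temperature
    "num_actions": 10,         # actions sampled for the CQL loss
    "lagrangian": False,       # no Lagrangian for alpha prime
    "min_q_weight": 5.0,
    "buffer_size": int(1e6),   # >= size of the offline dataset
}
trainer = CQLTrainer(config=config)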
21 changes: 2 additions & 19 deletions rllib/agents/cql/cql_tf_policy.py
@@ -19,8 +19,6 @@
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
MAX_LOG_NN_OUTPUT
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.exploration.random import Random
@@ -126,25 +124,10 @@ def cql_loss(policy: Policy, model: ModelV2,
actor_loss = tf.reduce_mean(
tf.stop_gradient(alpha) * log_pis_t - min_q)
else:

def bc_log(model, obs, actions):
z = tf.math.atanh(actions)
logits = model.get_policy_output(obs)
mean, log_std = tf.split(logits, 2, axis=-1)
# Mean Clamping for Stability
mean = tf.clip_by_value(mean, MEAN_MIN, MEAN_MAX)
log_std = tf.clip_by_value(log_std, MIN_LOG_NN_OUTPUT,
MAX_LOG_NN_OUTPUT)
std = tf.math.exp(log_std)
normal_dist = tfp.distributions.Normal(mean, std)
return tf.reduce_sum(
normal_dist.log_prob(z) -
tf.math.log(1 - actions * actions + SMALL_NUMBER),
axis=-1)

bc_logp = bc_log(model, model_out_t, actions)
bc_logp = action_dist_t.logp(actions)
actor_loss = tf.reduce_mean(
tf.stop_gradient(alpha) * log_pis_t - bc_logp)
# actor_loss = -tf.reduce_mean(bc_logp)

# Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
# SAC Loss:
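The net effect of this hunk: the hand-rolled bc_log helper (atanh, re-built Normal, squash correction) is dropped and the behavior-cloning log-prob comes straight from the action distribution. Conceptually, using the names from the loss above (a sketch, not the full cql_loss):

import tensorflow as tf

def bc_actor_loss(action_dist_t, actions, log_pis_t, alpha):
    # The squashed-Gaussian action distribution already handles unsquashing
    # and the tanh change-of-variables correction internally.
    bc_logp = action_dist_t.logp(actions)
    return tf.reduce_mean(tf.stop_gradient(alpha) * log_pis_t - bc_logp)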
29 changes: 3 additions & 26 deletions rllib/agents/cql/cql_torch_policy.py
@@ -17,14 +17,12 @@
from ray.rllib.policy.policy import LEARNER_STATS_KEY
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
MAX_LOG_NN_OUTPUT
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
TrainerConfigDict
from ray.rllib.utils.torch_ops import apply_grad_clipping, atanh, \
from ray.rllib.utils.torch_ops import apply_grad_clipping, \
convert_to_torch_tensor

torch, nn = try_import_torch()
@@ -130,30 +128,9 @@ def cql_loss(policy: Policy, model: ModelV2,
actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
else:

def bc_log(model, obs, actions):
# Stabilize input to atanh.
normed_actions = \
(actions - action_dist_t.low) / \
(action_dist_t.high - action_dist_t.low) * 2.0 - 1.0
save_normed_actions = torch.clamp(
normed_actions, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER)
z = atanh(save_normed_actions)

logits = model.get_policy_output(obs)
mean, log_std = torch.chunk(logits, 2, dim=-1)
# Mean Clamping for Stability
mean = torch.clamp(mean, MEAN_MIN, MEAN_MAX)
log_std = torch.clamp(log_std, MIN_LOG_NN_OUTPUT,
MAX_LOG_NN_OUTPUT)
std = torch.exp(log_std)
normal_dist = torch.distributions.Normal(mean, std)
return torch.sum(
normal_dist.log_prob(z) -
torch.log(1 - actions * actions + SMALL_NUMBER),
dim=-1)

bc_logp = bc_log(model, model_out_t, actions)
bc_logp = action_dist_t.logp(actions)
actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()
# actor_loss = -bc_logp.mean()

if obs.shape[0] == policy.config["train_batch_size"]:
policy.actor_optim.zero_grad()
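Same simplification on the torch side. For reference, the log-prob of a tanh-squashed Gaussian that the one-liner now relies on has to compute roughly the following (a hedged sketch of the math the removed helper did by hand; RLlib's TorchSquashedGaussian encapsulates it):

import torch

def squashed_gaussian_logp(normal_dist, squashed_actions, eps=1e-6):
    # Undo the tanh squashing, evaluate the base Normal, then apply the
    # change-of-variables correction for tanh.
    z = torch.atanh(torch.clamp(squashed_actions, -1.0 + eps, 1.0 - eps))
    return (normal_dist.log_prob(z)
            - torch.log(1.0 - squashed_actions ** 2 + eps)).sum(dim=-1)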
13 changes: 10 additions & 3 deletions rllib/agents/cql/tests/test_cql.py
@@ -40,11 +40,18 @@ def test_cql_compilation(self):
config["env"] = "Pendulum-v0"
config["input"] = [data_file]

# In the files we use here for testing, actions have already
# been normalized.
# This is usually the case when the file was generated by another
# RLlib algorithm (e.g. PPO or SAC).
config["actions_in_input_normalized"] = False
config["clip_actions"] = True
config["train_batch_size"] = 2000

config["num_workers"] = 0 # Run locally.
config["twin_q"] = True
config["clip_actions"] = True
config["normalize_actions"] = True
config["learning_starts"] = 0
config["bc_iters"] = 2 # 2 BC iters, 2 CQL iters.
config["rollout_fragment_length"] = 1

# Switch on off-policy evaluation.
@@ -56,7 +63,7 @@
config["evaluation_parallel_to_training"] = True
config["evaluation_num_workers"] = 2

num_iterations = 3
num_iterations = 4

# Test for tf/torch frameworks.
for fw in framework_iterator(config):
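The comment above refers to offline files written by another RLlib algorithm. A minimal sketch of producing such a file (the algorithm choice, output path, and iteration count are illustrative only):

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(
    env="Pendulum-v0",
    config={
        "output": "/tmp/pendulum-out",  # JSON sample files are written here
        "num_workers": 0,
    })
for _ in range(2):
    trainer.train()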
10 changes: 3 additions & 7 deletions rllib/agents/sac/sac.py
@@ -72,22 +72,18 @@
"custom_model": None, # Use this to define a custom policy model.
"custom_model_config": {},
},
# Unsquash actions to the upper and lower bounds of env's action space.
# Ignored for discrete action spaces.
"normalize_actions": True,
# Actions are already normalized, no need to clip them further.
"clip_actions": False,

# === Learning ===
# Disable setting done=True at end of episode. This should be set to True
# for infinite-horizon MDPs (e.g., many continuous control problems).
"no_done_at_end": False,
# Update the target by \tau * policy + (1-\tau) * target_policy.
"tau": 5e-3,
# Initial value to use for the entropy weight alpha.
"initial_alpha": 1.0,
# Target entropy lower bound. If "auto", will be set to -|A| (e.g. -2.0 for
# Discrete(2), -3.0 for Box(shape=(3,))).
# This is the inverse of reward scale, and will be optimized automatically.
"target_entropy": None,
"target_entropy": "auto",
# N-step target updates. If >1, sars' tuples in trajectories will be
# postprocessed to become sa[discounted sum of R][s t+n] tuples.
"n_step": 1,
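With target_entropy now defaulting to "auto", the resolved value follows the -|A| convention described in the comment above. A small sketch of that convention (the helper is illustrative, not RLlib's implementation):

import numpy as np
from gym.spaces import Box, Discrete

def auto_target_entropy(action_space):
    # -|A| per the config comment above.
    if isinstance(action_space, Discrete):
        return -float(action_space.n)
    return -float(np.prod(action_space.shape))

print(auto_target_entropy(Discrete(2)))                 # -2.0
print(auto_target_entropy(Box(-1.0, 1.0, shape=(3,))))  # -3.0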
1 change: 1 addition & 0 deletions rllib/agents/sac/sac_tf_policy.py
@@ -155,6 +155,7 @@ def _get_dist_class(config: TrainerConfigDict, action_space: gym.spaces.Space
elif isinstance(action_space, Simplex):
return Dirichlet
else:
assert isinstance(action_space, Box)
if config["normalize_actions"]:
return SquashedGaussian if \
not config["_use_beta_distribution"] else Beta
3 changes: 2 additions & 1 deletion rllib/agents/sac/sac_torch_policy.py
@@ -3,7 +3,7 @@
"""

import gym
from gym.spaces import Discrete
from gym.spaces import Box, Discrete
import logging
from typing import Dict, List, Optional, Tuple, Type, Union

@@ -48,6 +48,7 @@ def _get_dist_class(config: TrainerConfigDict, action_space: gym.spaces.Space
elif isinstance(action_space, Simplex):
return TorchDirichlet
else:
assert isinstance(action_space, Box)
if config["normalize_actions"]:
return TorchSquashedGaussian if \
not config["_use_beta_distribution"] else TorchBeta
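The new assert (added in both the TF and torch variants) makes unsupported action spaces fail fast instead of silently falling into the continuous-distribution branch. A tiny sketch of the guarded branch (error message wording is illustrative):

from gym.spaces import Box, MultiDiscrete

def check_sac_action_space(action_space):
    # Mirrors the final else-branch above: anything that is not Discrete or
    # Simplex must be a Box.
    assert isinstance(action_space, Box), (
        "SAC only supports Discrete, Simplex, or Box action spaces, "
        f"got {action_space}.")

check_sac_action_space(Box(-1.0, 1.0, shape=(3,)))  # OK
# check_sac_action_space(MultiDiscrete([3, 3]))     # would raise AssertionError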