[RLlib] MARWIL loss function test case and cleanup. (#13455)
sven1977 committed Jan 19, 2021
1 parent 2506a6c commit a65ee92
Showing 3 changed files with 138 additions and 71 deletions.
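
For orientation, the cleanup folds the former ValueLoss and ReweightedImitationLoss helpers into a single MARWILLoss class on the tf side and mirrors the same computation in the torch policy. Below is a minimal NumPy sketch of the quantities that loss computes; the sample values and variable names are illustrative only, not RLlib API.

import numpy as np

# Sketch of the MARWIL loss terms (illustrative values):
#   v_loss     = 0.5 * E[(R - V(s))^2]
#   p_loss     = -E[exp(beta * A / c) * log pi(a|s)],  with A = R - V(s)
#   total_loss = p_loss + vf_coeff * v_loss
beta, vf_coeff = 1.0, 1.0                         # stand-ins for config["beta"], config["vf_coeff"]
cumulative_rewards = np.array([10.0, 8.5, 7.0])   # discounted returns (the "advantages" batch key)
value_estimates = np.array([9.0, 9.5, 6.0])       # model.value_function() output
log_probs = np.array([-0.3, -0.9, -0.5])          # action_dist.logp(actions)
moving_avg_sqd_adv = 100.0                        # "c^2" in the paper, initialized in setup_mixins

adv = cumulative_rewards - value_estimates
adv_squared = np.mean(np.square(adv))
v_loss = 0.5 * adv_squared
# Moving-average update of c^2 (step 1e-8 in the eager branch, 1e-6 in the graph/torch branches).
moving_avg_sqd_adv += 1e-8 * (adv_squared - moving_avg_sqd_adv)
# Exponentially weighted advantages (for beta == 0.0 the weights are simply 1.0).
exp_advs = np.exp(beta * adv / np.sqrt(moving_avg_sqd_adv))
p_loss = -np.mean(exp_advs * log_probs)
total_loss = p_loss + vf_coeff * v_loss
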
120 changes: 57 additions & 63 deletions rllib/agents/marwil/marwil_tf_policy.py
@@ -1,3 +1,5 @@
import logging

import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import compute_advantages, \
@@ -8,6 +10,8 @@

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


class ValueNetworkMixin:
def __init__(self, obs_space, action_space, config):
@@ -43,47 +47,6 @@ def value(ob, prev_action, prev_reward, *state):
self._value = value


class ValueLoss:
def __init__(self, state_values, cumulative_rewards):
self.loss = 0.5 * tf.reduce_mean(
tf.math.square(state_values - cumulative_rewards))


class ReweightedImitationLoss:
def __init__(self, policy, state_values, cumulative_rewards, actions,
action_dist, beta):
if beta != 0.0:
# Advantage Estimation.
adv = cumulative_rewards - state_values

# Update averaged advantage norm.
# Eager.
if policy.config["framework"] in ["tf2", "tfe"]:
policy._ma_adv_norm.assign_add(1e-6 * (
tf.reduce_mean(tf.math.square(adv)) - policy._ma_adv_norm))
# Exponentially weighted advantages.
exp_advs = tf.math.exp(beta * tf.math.divide(
adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
# Static graph.
else:
update_adv_norm = tf1.assign_add(
ref=policy._ma_adv_norm,
value=1e-6 * (tf.reduce_mean(tf.math.square(adv)) -
policy._ma_adv_norm))

# Exponentially weighted advantages.
with tf1.control_dependencies([update_adv_norm]):
exp_advs = tf.math.exp(beta * tf.math.divide(
adv, 1e-8 + tf.math.sqrt(policy._ma_adv_norm)))
exp_advs = tf.stop_gradient(exp_advs)
else:
exp_advs = 1.0

# L = - A * log\pi_\theta(a|s)
logprobs = action_dist.logp(actions)
self.loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)


def postprocess_advantages(policy,
sample_batch,
other_agent_batches=None,
@@ -135,43 +98,74 @@ def postprocess_advantages(policy,
sample_batch[SampleBatch.REWARDS][-1],
*next_state)

# Adds the policy logits, VF preds, and advantages to the batch,
# using GAE ("generalized advantage estimation") or not.
# Adds the "advantages" (which in the case of MARWIL are simply the
# discounted cumulative rewards) to the SampleBatch.
return compute_advantages(
sample_batch,
last_r,
policy.config["gamma"],
# We just want the discounted cumulative rewards, so we won't need
# GAE nor critic (use_critic=True: Subtract vf-estimates from returns).
use_gae=False,
use_critic=False)


class MARWILLoss:
def __init__(self, policy, state_values, action_dist, actions, advantages,
vf_loss_coeff, beta):
def __init__(self, policy, value_estimates, action_dist, actions,
cumulative_rewards, vf_loss_coeff, beta):

# Advantage Estimation.
adv = cumulative_rewards - value_estimates
adv_squared = tf.reduce_mean(tf.math.square(adv))

# Value function's loss term (MSE).
self.v_loss = 0.5 * adv_squared

if beta != 0.0:
# Perform moving averaging of advantage^2.

self.v_loss = self._build_value_loss(state_values, advantages)
self.p_loss = self._build_policy_loss(policy, state_values, advantages,
actions, action_dist, beta)
# Update averaged advantage norm.
# Eager.
if policy.config["framework"] in ["tf2", "tfe"]:
update_term = adv_squared - policy._moving_average_sqd_adv_norm
policy._moving_average_sqd_adv_norm.assign_add(
1e-8 * update_term)

self.total_loss = self.p_loss.loss + vf_loss_coeff * self.v_loss.loss
explained_var = explained_variance(advantages, state_values)
self.explained_variance = tf.reduce_mean(explained_var)
# Exponentially weighted advantages.
c = tf.math.sqrt(policy._moving_average_sqd_adv_norm)
exp_advs = tf.math.exp(beta * (adv / c))
# Static graph.
else:
update_adv_norm = tf1.assign_add(
ref=policy._moving_average_sqd_adv_norm,
value=1e-6 *
(adv_squared - policy._moving_average_sqd_adv_norm))

# Exponentially weighted advantages.
with tf1.control_dependencies([update_adv_norm]):
exp_advs = tf.math.exp(beta * tf.math.divide(
adv, 1e-8 + tf.math.sqrt(
policy._moving_average_sqd_adv_norm)))
exp_advs = tf.stop_gradient(exp_advs)
else:
exp_advs = 1.0

# L = - A * log\pi_\theta(a|s)
logprobs = action_dist.logp(actions)
self.p_loss = -1.0 * tf.reduce_mean(exp_advs * logprobs)

def _build_value_loss(self, state_values, cum_rwds):
return ValueLoss(state_values, cum_rwds)
self.total_loss = self.p_loss + vf_loss_coeff * self.v_loss

def _build_policy_loss(self, policy, state_values, cum_rwds, actions,
action_dist, beta):
return ReweightedImitationLoss(policy, state_values, cum_rwds, actions,
action_dist, beta)
self.explained_variance = tf.reduce_mean(
explained_variance(cumulative_rewards, value_estimates))


def marwil_loss(policy, model, dist_class, train_batch):
model_out, _ = model.from_batch(train_batch)
action_dist = dist_class(model_out, model)
state_values = model.value_function()
value_estimates = model.value_function()

policy.loss = MARWILLoss(policy, state_values, action_dist,
policy.loss = MARWILLoss(policy, value_estimates, action_dist,
train_batch[SampleBatch.ACTIONS],
train_batch[Postprocessing.ADVANTAGES],
policy.config["vf_coeff"], policy.config["beta"])
@@ -181,8 +175,8 @@ def marwil_loss(policy, model, dist_class, train_batch):

def stats(policy, train_batch):
return {
"policy_loss": policy.loss.p_loss.loss,
"vf_loss": policy.loss.v_loss.loss,
"policy_loss": policy.loss.p_loss,
"vf_loss": policy.loss.v_loss,
"total_loss": policy.loss.total_loss,
"vf_explained_var": policy.loss.explained_variance,
}
@@ -191,8 +185,8 @@ def stats(policy, train_batch):
def setup_mixins(policy, obs_space, action_space, config):
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
# Set up a tf-var for the moving avg (do this here to make it work with
# eager mode).
policy._ma_adv_norm = get_variable(
# eager mode); "c^2" in the paper.
policy._moving_average_sqd_adv_norm = get_variable(
100.0,
framework="tf",
tf_name="moving_average_of_advantage_norm",
13 changes: 7 additions & 6 deletions rllib/agents/marwil/marwil_torch_policy.py
@@ -48,16 +48,17 @@ def marwil_loss(policy, model, dist_class, train_batch):
advantages = train_batch[Postprocessing.ADVANTAGES]
actions = train_batch[SampleBatch.ACTIONS]

# Advantage estimation.
adv = advantages - state_values
adv_squared = torch.mean(torch.pow(adv, 2.0))

# Value loss.
policy.v_loss = 0.5 * torch.mean(torch.pow(state_values - advantages, 2.0))
policy.v_loss = 0.5 * adv_squared

# Policy loss.
# Advantage estimation.
adv = advantages - state_values
# Update averaged advantage norm.
policy.ma_adv_norm.add_(
1e-6 * (torch.mean(torch.pow(adv, 2.0)) - policy.ma_adv_norm))
# #xponentially weighted advantages.
policy.ma_adv_norm.add_(1e-6 * (adv_squared - policy.ma_adv_norm))
# Exponentially weighted advantages.
exp_advs = torch.exp(policy.config["beta"] *
(adv / (1e-8 + torch.pow(policy.ma_adv_norm, 0.5))))
# log\pi_\theta(a|s)
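
One structural difference worth noting before the test file: the torch implementation stores the individual loss terms directly on the policy (policy.p_loss, policy.v_loss), whereas the tf implementation stores a MARWILLoss object as policy.loss and exposes them as policy.loss.p_loss and policy.loss.v_loss. The new test branches on exactly this. A small illustrative helper (not part of the commit) that mirrors how the test reads the terms:

def get_loss_terms(policy, framework):
    """Return (policy_loss, value_loss) the way test_marwil_loss_function reads them."""
    if framework == "torch":
        # marwil_torch_policy.marwil_loss sets these directly on the policy.
        return policy.p_loss, policy.v_loss
    # tf/tf2/tfe: the terms live on the MARWILLoss object assigned to policy.loss.
    return policy.loss.p_loss, policy.loss.v_loss
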
76 changes: 74 additions & 2 deletions rllib/agents/marwil/tests/test_marwil.py
@@ -1,14 +1,18 @@
import numpy as np
import os
from pathlib import Path
import unittest

import ray
import ray.rllib.agents.marwil as marwil
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_compute_single_action, \
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.offline import JsonReader
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
framework_iterator

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


class TestMARWIL(unittest.TestCase):
@@ -70,6 +74,74 @@ def test_marwil_compilation_and_learning_from_offline_file(self):

trainer.stop()

def test_marwil_loss_function(self):
"""
To generate the historic data used in this test case, first run:
$ ./train.py --run=PPO --env=CartPole-v0 \
--stop='{"timesteps_total": 50000}' \
--config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}'
"""
rllib_dir = Path(__file__).parent.parent.parent.parent
print("rllib dir={}".format(rllib_dir))
data_file = os.path.join(rllib_dir, "tests/data/cartpole/small.json")
print("data_file={} exists={}".format(data_file,
os.path.isfile(data_file)))
config = marwil.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
# Learn from offline data.
config["input"] = [data_file]

for fw in framework_iterator(config, frameworks=["torch", "tf2"]):
reader = JsonReader(inputs=[data_file])
batch = reader.next()

trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()
model = policy.model

# Calculate our own expected values (to then compare against the
# agent's loss output).
cummulative_rewards = compute_advantages(
batch, 0.0, config["gamma"], 1.0, False, False)["advantages"]
if fw == "torch":
cummulative_rewards = torch.tensor(cummulative_rewards)
tensor_batch = policy._lazy_tensor_dict(batch)
model_out, _ = model.from_batch(tensor_batch)
vf_estimates = model.value_function()
adv = cummulative_rewards - vf_estimates
if fw == "torch":
adv = adv.detach().cpu().numpy()
adv_squared = np.mean(np.square(adv))
c_2 = 100.0 + 1e-8 * (adv_squared - 100.0)
c = np.sqrt(c_2)
exp_advs = np.exp(config["beta"] * (adv / c))
logp = policy.dist_class(model_out,
model).logp(tensor_batch["actions"])
if fw == "torch":
logp = logp.detach().cpu().numpy()
# Calculate all expected loss components.
expected_vf_loss = 0.5 * adv_squared
expected_pol_loss = -1.0 * np.mean(exp_advs * logp)
expected_loss = \
expected_pol_loss + config["vf_coeff"] * expected_vf_loss

# Calculate the algorithm's loss (to check against our own
# calculation above).
postprocessed_batch = policy.postprocess_trajectory(batch)
loss_func = marwil.marwil_tf_policy.marwil_loss if fw != "torch" \
else marwil.marwil_torch_policy.marwil_loss
loss_out = loss_func(policy, model, policy.dist_class,
policy._lazy_tensor_dict(postprocessed_batch))

# Check all components.
if fw == "torch":
check(policy.v_loss, expected_vf_loss, decimals=4)
check(policy.p_loss, expected_pol_loss, decimals=4)
else:
check(policy.loss.v_loss, expected_vf_loss, decimals=4)
check(policy.loss.p_loss, expected_pol_loss, decimals=4)
check(loss_out, expected_loss, decimals=3)


if __name__ == "__main__":
import pytest
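
A note on the constants in the test's expected-value calculation: 100.0 is the initial value of the moving-average variable created in setup_mixins ("c^2" in the paper), and 1e-8 matches the eager-mode update step in the tf loss. A condensed NumPy sketch of that single update, with a placeholder value for the batch's mean squared advantage:

import numpy as np

# Placeholder for np.mean(np.square(cummulative_rewards - vf_estimates)) in the test.
adv_squared = 160.0
# One moving-average update, starting from the 100.0 init in setup_mixins.
c_2 = 100.0 + 1e-8 * (adv_squared - 100.0)
c = np.sqrt(c_2)  # ~10.0; the exponential weights are then exp(beta * adv / c)
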
