Update Average tf_metrics to a TF implementation. Speeds up train_eval by ~10%.

PiperOrigin-RevId: 260054790
Change-Id: Id1228e6248037c2872484c468a8cd683e58af725
Oscar Ramirez authored and Copybara-Service committed Jul 26, 2019
1 parent 0a954b1 commit b08a142
Showing 6 changed files with 124 additions and 62 deletions.
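In effect, the commit swaps the Python-side metrics (which crossed the TF/Python boundary on every update) for metrics implemented directly in TensorFlow ops, which is where the ~10% train_eval speedup comes from. A minimal sketch of the construction change; num_parallel_environments is a flag defined elsewhere in these examples, and the value below is illustrative only:

# Before: a TF wrapper around a Python metric; every update hops out to Python.
from tf_agents.metrics import py_metrics
from tf_agents.metrics import tf_py_metric

avg_return = tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric())

# After: a pure-TF metric; batch_size must match the batched environment.
from tf_agents.metrics import tf_metrics

num_parallel_environments = 30  # illustrative; the real value comes from a flag
avg_return = tf_metrics.AverageReturnMetric(
    batch_size=num_parallel_environments)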
6 changes: 4 additions & 2 deletions tf_agents/agents/ppo/examples/v1/train_eval.py
@@ -195,8 +195,10 @@ def train_eval(
       environment_steps_metric,
   ]
   train_metrics = step_metrics + [
-      tf_metrics.AverageReturnMetric(),
-      tf_metrics.AverageEpisodeLengthMetric(),
+      tf_metrics.AverageReturnMetric(
+          batch_size=num_parallel_environments),
+      tf_metrics.AverageEpisodeLengthMetric(
+          batch_size=num_parallel_environments),
   ]
 
   # Add to replay buffer and other agent specific observers.
6 changes: 4 additions & 2 deletions tf_agents/agents/ppo/examples/v2/train_eval.py
@@ -172,8 +172,10 @@ def train_eval(
   ]
 
   train_metrics = step_metrics + [
-      tf_metrics.AverageReturnMetric(),
-      tf_metrics.AverageEpisodeLengthMetric(),
+      tf_metrics.AverageReturnMetric(
+          batch_size=num_parallel_environments),
+      tf_metrics.AverageEpisodeLengthMetric(
+          batch_size=num_parallel_environments),
   ]
 
   eval_policy = tf_agent.policy
12 changes: 5 additions & 7 deletions tf_agents/agents/sac/examples/v2/train_eval.py
@@ -26,7 +26,6 @@
 ```
 """
 
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -47,9 +46,7 @@
 from tf_agents.environments import suite_mujoco
 from tf_agents.environments import tf_py_environment
 from tf_agents.eval import metric_utils
-from tf_agents.metrics import py_metrics
 from tf_agents.metrics import tf_metrics
-from tf_agents.metrics import tf_py_metric
 from tf_agents.networks import actor_distribution_network
 from tf_agents.networks import normal_projection_network
 from tf_agents.policies import greedy_policy
@@ -60,8 +57,7 @@
 
 flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                     'Root directory for writing logs/summaries/checkpoints.')
-flags.DEFINE_multi_string('gin_file', None,
-                          'Path to the trainer config files.')
+flags.DEFINE_multi_string('gin_file', None, 'Path to the trainer config files.')
 flags.DEFINE_multi_string('gin_param', None, 'Gin binding to pass through.')
 
 FLAGS = flags.FLAGS
@@ -193,8 +189,10 @@ def train_eval(
   train_metrics = [
       tf_metrics.NumberOfEpisodes(),
       tf_metrics.EnvironmentSteps(),
-      tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
-      tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
+      tf_metrics.AverageReturnMetric(
+          buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
+      tf_metrics.AverageEpisodeLengthMetric(
+          buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
   ]
 
   eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
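Note the two sizing arguments in the SAC example: buffer_size caps how many recent episode values the metric averages over, while batch_size must match the batch dimension of the TF environment (1 for a single wrapped Python env). A hedged sketch of the surrounding setup, with the flag default and environment name chosen for illustration:

from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment
from tf_agents.metrics import tf_metrics

num_eval_episodes = 30  # assumed flag value, for illustration only
tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load('HalfCheetah-v2'))

train_metrics = [
    # Average over the last num_eval_episodes episodes; one env in the batch.
    tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    tf_metrics.AverageEpisodeLengthMetric(
        buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
]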
3 changes: 2 additions & 1 deletion tf_agents/metrics/tf_metric.py
@@ -117,7 +117,8 @@ def tf_summaries(self, train_step=None, step_metrics=()):
       if self.name == step_metric.name:
         continue
       step_tag = '{}vs_{}/{}'.format(prefix, step_metric.name, self.name)
-      step = step_metric.result()
+      # Summaries expect the step value to be an int64.
+      step = tf.cast(step_metric.result(), tf.int64)
       summaries.append(tf.compat.v2.summary.scalar(
           name=step_tag,
           data=result,
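The cast matters because step_metric.result() comes back in whatever dtype the metric was built with, while the TF2 summary API requires an int64 step. A small standalone sketch of the pattern (writer path, tag and values are illustrative):

import tensorflow as tf

writer = tf.compat.v2.summary.create_file_writer('/tmp/metric_summaries')
steps = tf.constant(1000)  # e.g. an int32 result from a step metric

with writer.as_default():
  # Summaries expect the step value to be an int64.
  tf.compat.v2.summary.scalar(
      name='Metrics/AverageReturn_vs_EnvironmentSteps',
      data=9.0,
      step=tf.cast(steps, tf.int64))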
89 changes: 73 additions & 16 deletions tf_agents/metrics/tf_metrics.py
@@ -23,9 +23,7 @@
 
 import tensorflow as tf
 
-from tf_agents.metrics import py_metrics
 from tf_agents.metrics import tf_metric
-from tf_agents.metrics import tf_py_metric
 from tf_agents.utils import common
 
 
@@ -101,8 +99,7 @@ def call(self, trajectory):
     return trajectory
 
   def result(self):
-    return tf.identity(
-        self.environment_steps, name=self.name)
+    return tf.identity(self.environment_steps, name=self.name)
 
   @common.function
   def reset(self):
@@ -136,37 +133,97 @@ def call(self, trajectory):
     return trajectory
 
   def result(self):
-    return tf.identity(
-        self.number_episodes, name=self.name)
+    return tf.identity(self.number_episodes, name=self.name)
 
   @common.function
   def reset(self):
     self.number_episodes.assign(0)
 
 
-class AverageReturnMetric(tf_py_metric.TFPyMetric):
+class AverageReturnMetric(tf_metric.TFStepMetric):
   """Metric to compute the average return."""
 
-  def __init__(self, name='AverageReturn', dtype=tf.float32, buffer_size=10):
-    py_metric = py_metrics.AverageReturnMetric(buffer_size=buffer_size)
+  def __init__(self,
+               name='AverageReturn',
+               dtype=tf.float32,
+               batch_size=1,
+               buffer_size=10):
+    super(AverageReturnMetric, self).__init__(name=name)
+    self._buffer = TFDeque(buffer_size, dtype)
+    self._dtype = dtype
+    self._return_accumulator = common.create_variable(
+        initial_value=0, dtype=dtype, shape=(batch_size,), name='Accumulator')
+
+  @common.function(autograph=True)
+  def call(self, trajectory):
+    # Zero out batch indices where a new episode is starting.
+    self._return_accumulator.assign(
+        tf.where(trajectory.is_first(), tf.zeros_like(self._return_accumulator),
+                 self._return_accumulator))
+
+    # Update accumulator with received rewards.
+    self._return_accumulator.assign_add(trajectory.reward)
+
+    # Add final returns to buffer.
+    last_episode_indices = tf.squeeze(tf.where(trajectory.is_last()), axis=-1)
+    for indx in last_episode_indices:
+      self._buffer.add(self._return_accumulator[indx])
+
+    return trajectory
+
+  def result(self):
+    return self._buffer.mean()
 
-    super(AverageReturnMetric, self).__init__(
-        py_metric=py_metric, name=name, dtype=dtype)
+  @common.function
+  def reset(self):
+    self._buffer.clear()
+    self._return_accumulator.assign(tf.zeros_like(self._return_accumulator))
 
 
-class AverageEpisodeLengthMetric(tf_py_metric.TFPyMetric):
+class AverageEpisodeLengthMetric(tf_metric.TFStepMetric):
   """Metric to compute the average episode length."""
 
   def __init__(self,
                name='AverageEpisodeLength',
                dtype=tf.float32,
+               batch_size=1,
                buffer_size=10):
-    py_metric = py_metrics.AverageEpisodeLengthMetric(
-        buffer_size=buffer_size)
+    super(AverageEpisodeLengthMetric, self).__init__(name=name)
+    self._buffer = TFDeque(buffer_size, dtype)
+    self._dtype = dtype
+    self._length_accumulator = common.create_variable(
+        initial_value=0, dtype=dtype, shape=(batch_size,), name='Accumulator')
+
+  @common.function(autograph=True)
+  def call(self, trajectory):
+    # Each non-boundary trajectory (first, mid or last) represents a step.
+    non_boundary_indices = tf.squeeze(
+        tf.where(tf.logical_not(trajectory.is_boundary())), axis=-1)
+    self._length_accumulator.scatter_add(
+        tf.IndexedSlices(
+            tf.ones_like(
+                non_boundary_indices, dtype=self._length_accumulator.dtype),
+            non_boundary_indices))
+
+    # Add lengths to buffer when we hit end of episode.
+    last_indices = tf.squeeze(tf.where(trajectory.is_last()), axis=-1)
+    for indx in last_indices:
+      self._buffer.add(self._length_accumulator[indx])
+
+    # Clear length accumulator at the end of episodes.
+    self._length_accumulator.scatter_update(
+        tf.IndexedSlices(
+            tf.zeros_like(last_indices, dtype=self._dtype), last_indices))
+
+    return trajectory
+
+  def result(self):
+    return self._buffer.mean()
 
-    super(AverageEpisodeLengthMetric, self).__init__(
-        py_metric=py_metric, name=name, dtype=dtype)
+  @common.function
+  def reset(self):
+    self._buffer.clear()
+    self._length_accumulator.assign(tf.zeros_like(self._length_accumulator))
 
 
 def log_metrics(metrics, prefix=''):
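Both new metrics lean on a TFDeque helper whose definition is collapsed out of this diff. Its exact implementation is not shown here, but from its usage (add, mean, clear, plus the extend exercised in the tests) it behaves like a fixed-capacity ring buffer held in TF variables; a rough eager-mode sketch under that assumption, not the actual tf_agents code:

import tensorflow as tf

class RingBufferMean(object):
  """Hypothetical stand-in for TFDeque: fixed-capacity buffer with a mean."""

  def __init__(self, max_len, dtype):
    self._dtype = dtype
    self._max_len = max_len
    self._buffer = tf.Variable(tf.zeros((max_len,), dtype=dtype))
    self._head = tf.Variable(0, dtype=tf.int32)

  def add(self, value):
    # Overwrite the oldest slot once the buffer wraps around.
    index = tf.math.mod(self._head, self._max_len)
    self._buffer.scatter_update(tf.IndexedSlices(value, index))
    self._head.assign_add(1)

  def extend(self, values):
    for v in values:
      self.add(v)

  def clear(self):
    self._head.assign(0)
    self._buffer.assign(tf.zeros_like(self._buffer))

  def mean(self):
    # Average only over the slots that have been written so far.
    length = tf.minimum(self._head, self._max_len)
    if tf.equal(length, 0):
      return tf.zeros((), dtype=self._dtype)
    return tf.reduce_sum(self._buffer[:length]) / tf.cast(length, self._dtype)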
70 changes: 36 additions & 34 deletions tf_agents/metrics/tf_metrics_test.py
@@ -22,7 +22,6 @@
 import tensorflow as tf
 from tf_agents.metrics import tf_metrics
 from tf_agents.trajectories import trajectory
-from tf_agents.utils import nest_utils
 
 from tensorflow.python.eager import context  # TF internal
 
@@ -99,32 +98,35 @@ def test_extend(self):
 class TFMetricsTest(parameterized.TestCase, tf.test.TestCase):
 
   def _create_trajectories(self):
+
+    def _concat_nested_tensors(nest1, nest2):
+      return tf.nest.map_structure(lambda t1, t2: tf.concat([t1, t2], axis=0),
+                                   nest1, nest2)
+
     # Order of args for trajectory methods:
     # observation, action, policy_info, reward, discount
-    ts0 = nest_utils.stack_nested_tensors([
-        trajectory.boundary((), (), (), 0., 1.),
-        trajectory.boundary((), (), (), 0., 1.)
-    ])
-    ts1 = nest_utils.stack_nested_tensors([
-        trajectory.first((), (), (), 1., 1.),
-        trajectory.first((), (), (), 2., 1.)
-    ])
-    ts2 = nest_utils.stack_nested_tensors([
-        trajectory.last((), (), (), 3., 1.),
-        trajectory.last((), (), (), 4., 1.)
-    ])
-    ts3 = nest_utils.stack_nested_tensors([
-        trajectory.boundary((), (), (), 0., 1.),
-        trajectory.boundary((), (), (), 0., 1.)
-    ])
-    ts4 = nest_utils.stack_nested_tensors([
-        trajectory.first((), (), (), 5., 1.),
-        trajectory.first((), (), (), 6., 1.)
-    ])
-    ts5 = nest_utils.stack_nested_tensors([
-        trajectory.last((), (), (), 7., 1.),
-        trajectory.last((), (), (), 8., 1.)
-    ])
+    ts0 = _concat_nested_tensors(
+        trajectory.boundary((), (), (), tf.constant([0.], dtype=tf.float32),
+                            [1.]),
+        trajectory.boundary((), (), (), tf.constant([0.], dtype=tf.float32),
+                            [1.]))
+    ts1 = _concat_nested_tensors(
+        trajectory.first((), (), (), tf.constant([1.], dtype=tf.float32), [1.]),
+        trajectory.first((), (), (), tf.constant([2.], dtype=tf.float32), [1.]))
+    ts2 = _concat_nested_tensors(
+        trajectory.last((), (), (), tf.constant([3.], dtype=tf.float32), [1.]),
+        trajectory.last((), (), (), tf.constant([4.], dtype=tf.float32), [1.]))
+    ts3 = _concat_nested_tensors(
+        trajectory.boundary((), (), (), tf.constant([0.], dtype=tf.float32),
+                            [1.]),
+        trajectory.boundary((), (), (), tf.constant([0.], dtype=tf.float32),
+                            [1.]))
+    ts4 = _concat_nested_tensors(
+        trajectory.first((), (), (), tf.constant([5.], dtype=tf.float32), [1.]),
+        trajectory.first((), (), (), tf.constant([6.], dtype=tf.float32), [1.]))
+    ts5 = _concat_nested_tensors(
+        trajectory.last((), (), (), tf.constant([7.], dtype=tf.float32), [1.]),
+        trajectory.last((), (), (), tf.constant([8.], dtype=tf.float32), [1.]))
 
     return [ts0, ts1, ts2, ts3, ts4, ts5]
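Each ts here concatenates two single-transition trajectories along the batch dimension, so the six frames describe two episodes per batch row: ts0 and ts3 are episode boundaries, ts1/ts4 first steps, ts2/ts5 last steps. Running over all six frames therefore produces episode returns of 1+3=4 and 5+7=12 in row 0 and 2+4=6 and 6+8=14 in row 1, for an average return of (4+6+12+14)/4 = 9.0 and an average episode length of 2.0.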

@@ -150,17 +152,17 @@ def testMetric(self, run_mode, metric_class, num_trajectories,
                  expected_result):
     with run_mode():
       trajectories = self._create_trajectories()
-      metric = metric_class()
-      deps = []
+      if metric_class in [tf_metrics.AverageReturnMetric,
+                          tf_metrics.AverageEpisodeLengthMetric]:
+        metric = metric_class(batch_size=2)
+      else:
+        metric = metric_class()
       self.evaluate(tf.compat.v1.global_variables_initializer())
       self.evaluate(metric.init_variables())
       for i in range(num_trajectories):
-        with tf.control_dependencies(deps):
-          traj = metric(trajectories[i])
-          deps = tf.nest.flatten(traj)
-      with tf.control_dependencies(deps):
-        result = metric.result()
-      result_ = self.evaluate(result)
-      self.assertEqual(result_, expected_result)
+        self.evaluate(metric(trajectories[i]))
+
+      self.assertEqual(expected_result, self.evaluate(metric.result()))
       self.evaluate(metric.reset())
       self.assertEqual(0.0, self.evaluate(metric.result()))
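Pulling the pieces together, a minimal eager-mode run of the new metric over trajectories shaped like the ones in the test above; this is an illustrative snippet, not part of the test file:

import tensorflow as tf
from tf_agents.metrics import tf_metrics
from tf_agents.trajectories import trajectory

# Single-environment batch: one two-step episode with rewards 1. and 3.
frames = [
    trajectory.boundary((), (), (), tf.constant([0.]), [1.]),
    trajectory.first((), (), (), tf.constant([1.]), [1.]),
    trajectory.last((), (), (), tf.constant([3.]), [1.]),
]

metric = tf_metrics.AverageReturnMetric(batch_size=1)
for frame in frames:
  metric(frame)

print(metric.result().numpy())  # 4.0: the episode return 1. + 3.
metric.reset()
print(metric.result().numpy())  # 0.0 after reset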
