tf-agents SAC 10x slower than stable-baselines on same hardware #275

Closed
pirobot opened this issue Dec 22, 2019 · 18 comments
Labels: level:p1, type:performance (Issues related to performance. Throughput, time to train, or accuracy/reward.)

Comments

@pirobot

pirobot commented Dec 22, 2019

I am running a simple test of SAC on the LunarLanderContinuous-v2 environment. Training runs for 500,000 steps with a replay buffer of size 50,000 (see code below). tf-agents takes over 10 hours to complete training, whereas the stable-baselines implementation of SAC with the same hyperparameters takes only 39 minutes. I've checked and double-checked my versions of CUDA, tensorflow-gpu, tf-agents, etc. and cannot speed things up.

Here are the details to reproduce:

Ubuntu 16.04, tf-agents==0.3.0, tensorflow-gpu==1.15.0, gym==0.15.4, CUDA==10.0, cudnn==7.6.5, stable-baselines==2.9.0a0, GPU: Quadro M4000 8 GB, CPU: i7, 64 GB RAM

My tf-agents test script is simply the v2 train_eval.py script from sac/examples, after substituting the LunarLanderContinuous-v2 environment for HalfCheetah and changing the hyperparameters as shown below:

# coding=utf-8
# Copyright 2018 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Train and Eval SAC.

To run:

#bash
#tensorboard --logdir $HOME/tmp/sac/gym/HalfCheetah-v2/ --port 2223 &
#
#python tf_agents/agents/sac/examples/v2/train_eval.py \
#  --root_dir=$HOME/tmp/sac/gym/HalfCheetah-v2/ \
#  --alsologtostderr
#```
#"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

from absl import app
from absl import flags
from absl import logging

import gin
import tensorflow as tf

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import parallel_py_environment
from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import normal_projection_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common

flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_multi_string('gin_file', None, 'Path to the trainer config files.')
flags.DEFINE_multi_string('gin_param', None, 'Gin binding to pass through.')

FLAGS = flags.FLAGS


@gin.configurable
def normal_projection_net(action_spec,
                          init_action_stddev=0.35,
                          init_means_output_factor=0.1):
  del init_action_stddev
  return normal_projection_network.NormalProjectionNetwork(
      action_spec,
      mean_transform=None,
      state_dependent_std=True,
      init_means_output_factor=init_means_output_factor,
      std_transform=sac_agent.std_clip_transform,
      scale_distribution=True)


_DEFAULT_REWARD_SCALE = 0


@gin.configurable
def train_eval(
    root_dir,
    env_name='LunarLanderContinuous-v2',
    eval_env_name=None,
    env_load_fn=suite_mujoco.load,
    num_iterations=500000,
    actor_fc_layers=(64, 64),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(64, 64),
    num_parallel_environments=1,
    # Params for collect
    initial_collect_steps=100,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=50000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=_DEFAULT_REWARD_SCALE,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=100,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC on Mujoco.

  All hyperparameters come from the original SAC paper
  (https://arxiv.org/pdf/1801.01290.pdf).
  """

  if reward_scale_factor == _DEFAULT_REWARD_SCALE:
    # Use value recommended by https://arxiv.org/abs/1801.01290
    if env_name.startswith('Humanoid'):
      reward_scale_factor = 20.0
    else:
      reward_scale_factor = 5.0

  root_dir = os.path.expanduser(root_dir)

  summary_writer = tf.compat.v2.summary.create_file_writer(
      root_dir, flush_millis=summaries_flush_secs * 1000)
  summary_writer.set_as_default()

  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # create training environment
    if num_parallel_environments == 1:
      py_env = env_load_fn(env_name)
    else:
      py_env = parallel_py_environment.ParallelPyEnvironment(
          [lambda: env_load_fn(env_name)] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    # create evaluation environment
    eval_env_name = eval_env_name or env_name
    eval_py_env = env_load_fn(eval_env_name)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=normal_projection_net)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
    average_return = tf_metrics.AverageReturnMetric(
        prefix='Train',
        buffer_size=num_eval_episodes,
        batch_size=tf_env.batch_size)
    train_metrics = [
        tf_metrics.NumberOfEpisodes(prefix='Train'),
        env_steps,
        average_return,
        tf_metrics.AverageEpisodeLengthMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'train'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=env_steps.result(),
        summary_writer=summary_writer,
        summary_prefix='Eval',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, env_steps.result())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    time_acc = 0
    env_steps_before = env_steps.result().numpy()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3, sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_step()
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('env steps = %d, average return = %f', env_steps.result(),
                     average_return.result())
        env_steps_per_sec = (env_steps.result().numpy() -
                             env_steps_before) / time_acc
        logging.info('%.3f env steps/sec', env_steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='env_steps_per_sec',
            data=env_steps_per_sec,
            step=env_steps.result())
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

      for train_metric in train_metrics:
        train_metric.tf_summaries(train_step=env_steps.result())

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)


def main(_):
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train_eval(FLAGS.root_dir)


if __name__ == '__main__':
  flags.mark_flag_as_required('root_dir')
  app.run(main)

My stable-baselines script looks like this:

import gym
import numpy as np

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.common import make_vec_env
from stable_baselines.sac.policies import MlpPolicy
from stable_baselines import SAC

env = make_vec_env('LunarLanderContinuous-v2', n_envs=1)

model_name = "sac_lunar_lander"

model = SAC(MlpPolicy, env, verbose=1, tensorboard_log="./tensorboard_logs/stable_baselines_test")

model.learn(total_timesteps=500000, log_interval=10)
model.save(model_name)
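
For reference, the script above relies on stable-baselines defaults, which already line up with the tf-agents settings. A version that spells the hyperparameters out explicitly might look like the sketch below (the keyword names are stable-baselines 2.x constructor arguments as I understand them; treat it as illustrative, not a tuned setup):

import gym

from stable_baselines.sac.policies import MlpPolicy
from stable_baselines import SAC

env = gym.make('LunarLanderContinuous-v2')

# Mirror the tf-agents hyperparameters from the train_eval script above.
model = SAC(
    MlpPolicy,
    env,
    learning_rate=3e-4,                   # actor/critic/alpha learning rates
    buffer_size=50000,                    # replay_buffer_capacity
    learning_starts=100,                  # initial_collect_steps
    batch_size=64,
    tau=0.005,                            # target_update_tau
    gamma=0.99,
    train_freq=1,                         # collect_steps_per_iteration
    gradient_steps=1,                     # train_steps_per_iteration
    policy_kwargs=dict(layers=[64, 64]),  # actor/critic hidden layers
    verbose=1,
)
model.learn(total_timesteps=500000, log_interval=10)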

Finally, here is the output when I run the tf-agents script to show that the GPU is being detected and used:

2019-12-22 11:26:35.054589: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1
2019-12-22 11:26:35.068596: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1618] Found device 0 with properties: 
name: Quadro M4000 major: 5 minor: 2 memoryClockRate(GHz): 0.7725
pciBusID: 0000:01:00.0
2019-12-22 11:26:35.068767: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.0
2019-12-22 11:26:35.069770: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10.0
2019-12-22 11:26:35.070479: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcufft.so.10.0
2019-12-22 11:26:35.070640: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcurand.so.10.0
2019-12-22 11:26:35.071572: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcusolver.so.10.0
2019-12-22 11:26:35.072306: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcusparse.so.10.0
2019-12-22 11:26:35.074604: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7
2019-12-22 11:26:35.075808: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1746] Adding visible gpu devices: 0
2019-12-22 11:26:35.076022: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-12-22 11:26:35.080915: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3407920000 Hz
2019-12-22 11:26:35.081214: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x555945a77880 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2019-12-22 11:26:35.081228: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2019-12-22 11:26:35.144953: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x555945a9b180 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2019-12-22 11:26:35.144974: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Quadro M4000, Compute Capability 5.2
2019-12-22 11:26:35.145550: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1618] Found device 0 with properties: 
name: Quadro M4000 major: 5 minor: 2 memoryClockRate(GHz): 0.7725
pciBusID: 0000:01:00.0
2019-12-22 11:26:35.145578: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.0
2019-12-22 11:26:35.145588: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10.0
2019-12-22 11:26:35.145597: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcufft.so.10.0
2019-12-22 11:26:35.145605: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcurand.so.10.0
2019-12-22 11:26:35.145629: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcusolver.so.10.0
2019-12-22 11:26:35.145650: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcusparse.so.10.0
2019-12-22 11:26:35.145674: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7
2019-12-22 11:26:35.146551: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1746] Adding visible gpu devices: 0
2019-12-22 11:26:35.146575: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.0
2019-12-22 11:26:35.147375: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1159] Device interconnect StreamExecutor with strength 1 edge matrix:
2019-12-22 11:26:35.147384: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1165]      0 
2019-12-22 11:26:35.147388: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1178] 0:   N 
2019-12-22 11:26:35.148348: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1304] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 6876 MB memory) -> physical GPU (device: 0, name: Quadro M4000, pci bus id: 0000:01:00.0, compute capability: 5.2)
/home/patrick/src/gym/gym/logger.py:30: UserWarning: WARN: Box bound precision lowered by casting to float32
  warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))
WARNING:tensorflow:From /home/patrick/src/tf_agents/tf_agents/agents/ddpg/critic_network.py:141: The name tf.keras.initializers.RandomUniform is deprecated. Please use tf.compat.v1.keras.initializers.RandomUniform instead.

W1222 11:26:35.589284 140187933329152 module_wrapper.py:139] From /home/patrick/src/tf_agents/tf_agents/agents/ddpg/critic_network.py:141: The name tf.keras.initializers.RandomUniform is deprecated. Please use tf.compat.v1.keras.initializers.RandomUniform instead.

2019-12-22 11:26:35.600509: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10.0
WARNING:tensorflow:From /home/patrick/src/tf_agents/tf_agents/distributions/utils.py:92: AffineScalar.__init__ (from tensorflow_probability.python.bijectors.affine_scalar) is deprecated and will be removed after 2020-01-01.
Instructions for updating:
`AffineScalar` bijector is deprecated; please use `tfb.Shift(loc)(tfb.Scale(...))` instead.
W1222 11:26:35.787435 140187933329152 deprecation.py:323] From /home/patrick/src/tf_agents/tf_agents/distributions/utils.py:92: AffineScalar.__init__ (from tensorflow_probability.python.bijectors.affine_scalar) is deprecated and will be removed after 2020-01-01.
Instructions for updating:
`AffineScalar` bijector is deprecated; please use `tfb.Shift(loc)(tfb.Scale(...))` instead.
I1222 11:26:35.814536 140187933329152 common.py:920] Checkpoint available: tensorboard_logs/tf_agents_v2/train/ckpt-30000
I1222 11:26:35.902629 140187933329152 common.py:920] Checkpoint available: tensorboard_logs/tf_agents_v2/policy/ckpt-35000
I1222 11:26:35.908307 140187933329152 common.py:923] No checkpoint available at tensorboard_logs/tf_agents_v2/replay_buffer
I1222 11:26:35.910735 140187933329152 tf_agents_v2_lunar_lander.py:267] Initializing replay buffer by collecting experience for 100 stepswith a random policy.
WARNING:tensorflow:From /home/patrick/src/tf_agents/tf_agents/metrics/tf_metrics.py:161: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W1222 11:26:36.424730 140187933329152 deprecation.py:323] From /home/patrick/src/tf_agents/tf_agents/metrics/tf_metrics.py:161: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
I1222 11:28:23.095548 140187933329152 metric_utils.py:47]  
		 AverageReturn = 1.452040195465088
		 AverageEpisodeLength = 501.0
I1222 11:28:34.015443 140187933329152 tf_agents_v2_lunar_lander.py:314] env steps = 31200, average return = -80.228371
I1222 11:28:34.015817 140187933329152 tf_agents_v2_lunar_lander.py:317] 131.060 env steps/sec
etc.

And the output from nvidia-smi while running the script:

Sun Dec 22 11:29:16 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.129      Driver Version: 410.129      CUDA Version: 10.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Quadro M4000        Off  | 00000000:01:00.0  On |                  N/A |
| 51%   56C    P0    43W / 120W |   7865MiB /  8104MiB |     10%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      1370      G   /usr/lib/xorg/Xorg                           435MiB |
|    0      2062      G   compiz                                       146MiB |
|    0      3479      G   ...uest-channel-token=17571043003057555071   211MiB |
|    0     17466      C   python                                      7057MiB |
+-----------------------------------------------------------------------------+
@tfboyd tfboyd self-assigned this Dec 23, 2019
@tfboyd
Member

tfboyd commented Dec 23, 2019

Hi @pirobot

I am going to run a few experiments. It is the holiday season, so my responses may take longer than usual. This item is important to us and to me. Thank you for the detailed report and code snippet. If I need more help reproducing, I will ping you. Thank you again.

@tfboyd tfboyd added the type:performance label Dec 23, 2019
@pirobot
Author

pirobot commented Dec 24, 2019

Many thanks @tfboyd and no hurry as I realize we can all use a break during the holidays. :)

@pirobot
Author

pirobot commented Jan 8, 2020

Just pinging to see if there has been any progress on this issue. Thanks!

@pirobot
Author

pirobot commented Jan 21, 2020

@tfboyd Do you think this issue will get addressed eventually? It would be great if someone could at least reproduce the performance issue I detailed above. Unfortunately, our lab will have to abandon tf-agents if this issue cannot be resolved, since experiments that take a day with stable-baselines would take a month. I'm still hoping I have simply missed a critical parameter that you will discover after running the above code. We would prefer to stick with tf-agents, but this performance issue is preventing us from using it for any kind of serious research. Thanks!

@tfboyd
Member

tfboyd commented Jan 21, 2020

Sorry about the delay. I will confirm with the team tomorrow. I wanted to dig into this but I think we will be focusing on other priorities in Q1 with our existing resources. I will ping back tomorrow and put it on the agenda for our weekly team meeting.

I understand your frustration and need to find a platform that works for your situation.

@pirobot
Author

pirobot commented Jan 22, 2020

@tfboyd No worries and thanks for the response. We'll use stable-baselines for now and wait to hear back eventually if someone can confirm the performance differences.

@ebrevdo
Contributor

ebrevdo commented Jan 22, 2020

@tfboyd quick ping on this one.

@tfboyd
Member

tfboyd commented Jan 22, 2020

@kuanghuei may take a look, as the train time seems really long. Thank you for being reasonable and kind as we work to make improvements.

@tfboyd tfboyd added the level:p1 label Feb 4, 2020
@a-z-e-r-i-l-a

I noticed the same problem; any update on it?
Also, has any comparison with stable-baselines been made on metrics other than speed?

@kuanghuei
Contributor

kuanghuei commented Feb 26, 2020

I tried to run some experiments on GPU. It seems the slowdown only happens with LunarLanderContinuous-v2, not with HalfCheetah. Even with the slowdown, I was able to run ~80 steps/sec on a Titan X, which works out to roughly 1.7 hours for 500,000 steps.

Did you observe the same slowdown in other environments, or only in LunarLander?

By the way, the GPU has higher overhead and is likely to be slower than the CPU for small networks. On HalfCheetah, my CPU experiments get ~210-230 steps/sec, but my GPU experiments only get ~125 steps/sec on my desktop.
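
If anyone wants to check whether GPU overhead is the culprit on their own machine, one quick way is to hide the GPU before TensorFlow is imported and rerun the script from the top of this thread on CPU only. A minimal sketch, assuming that script is saved locally as train_eval_lunar_lander.py (a hypothetical filename):

# Hide the GPU so TensorFlow falls back to CPU, then do a short run to compare
# env steps/sec against the GPU numbers reported above.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # must be set before TensorFlow is imported

import tensorflow as tf

# Assumed local copy of the train_eval script posted at the top of this thread.
from train_eval_lunar_lander import train_eval

tf.compat.v1.enable_v2_behavior()
train_eval(root_dir='/tmp/sac/cpu_only', num_iterations=10000)  # short run, timing only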

@Strateus

Any updates on this? I am not using the environments mentioned here, but the possibility of internal performance issues in the framework is concerning.

@singagan

I tried with a custom environment, and it seems the GPU performs slower than the CPU with the latest tf-agents version.

@tfboyd
Member

tfboyd commented Oct 15, 2020

I have not done the exact testing all of you would want, but I do have some data that I think is useful. I want to stress that this is not remotely apples-to-apples. I have been doing perf testing with TF since before 1.x, and I dislike sharing data that does not answer a question exactly, but I have found that some data can be better than no data. I have also learned that direct comparisons are really hard, and that using common envs (like MuJoCo HalfCheetah in the case of many RL scenarios) as the starting point is the best approach.

For the new SAC example we run nightly internal tests, and I have published the results with full event logs. The runs were on CPU. As I said, not the GPU numbers you want.

I did a test with stable-baselines, but keep in mind I am not an expert on stable-baselines. On CPU I saw a 10-20% (1.1x to 1.2x) performance difference, with stable-baselines being slightly faster. It was a rough test, and aligning how often each library evaluates, and whether eval happens inline, leaves room for error.

I also noticed stable-baselines seems to be getting ~14K for HalfCheetah, which is confirmed in one of their GitHub issues and my own test run. But I copied the pybullet env info for HalfCheetah while I was using MuJoCo, and I have no idea how much that impacts the results. I am not throwing shade; we (anyone doing RL) are all in this together. My results below also show 14K, but I found an error: we were not using the GreedyPolicy for eval, and with that fixed the results match the paper at 15K; I just have not had time to publish all the numbers again. It is possible stable-baselines has a similar mistake; it is not a big deal and I did not have time to look into it.

Again, these are CPU numbers rather than the GPU numbers you want, but in this example it took ~2 hours to reach 500K steps, evaluating inline every 10K steps over 30 episodes (yes, 30 does not match the paper). My rough napkin estimate is 50 evals at 1 minute each, so 50 minutes of eval. So let's say the run took about 1 hour to reach 500K steps of HalfCheetah on 16 2.2 GHz vCPUs (8 physical and 8 logical) without eval, and 2 hours when evaluating every 10K steps with 30 episodes averaged (which does not match the paper; 1 episode is correct). A quick check of the other results: HalfCheetah, Hopper, and Walker2d are similar, with Ant taking much longer at 3 hours minus eval time. None of these are LunarLanderContinuous-v2, and given how much longer Ant takes, the env clearly matters.

Informally, I tested tf-agents on GPU for HalfCheetah on a GTX 1080. I got a modest performance increase on my workstation using the same batch size as on CPU. I suspect larger batch sizes are needed to go faster.

I hope to make time to do some GPU testing of stable-baselines, because if there is a big difference we need to look into tf-agents, since we are both using TensorFlow (v1 vs. v2, I think). I am sorry I have not been able to do this exact test. When I test next time I will try to toss in LunarLanderContinuous-v2, but it is not one of the main envs we test against.
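
On the batch-size point: since train_eval is gin-configurable, a larger batch size can be tried without editing the script, either through the --gin_param flag it already defines or directly in Python. A minimal sketch, again assuming the script is saved locally as train_eval_lunar_lander.py and treating 256 as an illustrative value:

# Push a larger SAC batch size through gin bindings instead of editing the code.
import gin

# Assumed local copy of the train_eval script posted at the top of this thread.
from train_eval_lunar_lander import train_eval

gin.parse_config_files_and_bindings(None, ['train_eval.batch_size = 256'])
train_eval(root_dir='/tmp/sac/batch_256', num_iterations=10000)  # short run for timing

The command-line equivalent is to pass --gin_param='train_eval.batch_size=256' when invoking the script.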

@yasser-h-khalil

Hello, any solutions or workarounds for the speed issue with tf-agents?

Help is much appreciated.

@tfboyd
Member

tfboyd commented Dec 10, 2020

On CPU I did not see a huge performance difference testing on the same hardware. I have not been able to go back and test GPU apples-to-apples. We are forming a plan to test the top agents (likely SAC and PPO) against some other leading libraries in Q1 and address any gaps. If you have a specific use case, that is the most useful starting point. Doing perf work is always difficult, and I prefer to stick with standard agents so there is a decent chance of apples-to-apples comparisons. I was the perf owner for TensorFlow, and the hardest part was getting apples-to-apples in order to narrow down the difference. Even something as seemingly obvious as ResNet took ages before people were testing the same exact network. SAC + HalfCheetah or another env from the paper is the best scenario. I or we could do others, but it is more difficult and less universally valuable across the board.

This matters to us; I have not found a straightforward use case yet and plan to dig deeper in Q1.

This is something we care about. Not to distract you, but we just released full testing of the PPO agent on MuJoCo. I am a few weeks away from numbers for the distributed setup (focused on distributed collect).

I know this puts a lot of work on you, but if you have a clear example of X vs. Y with TF-Agents much slower, that is actionable. @yasser-h-khalil

@ebrevdo
Contributor

ebrevdo commented Dec 10, 2020

Have we resolved @pirobot's original concern? If so, we should close this ticket and ask folks to create new ones with their specific code.

@tfboyd
Member

tfboyd commented Dec 10, 2020

I am going to close this as it is misleading. If there is an exact use case, let's get on it. We currently do not have proof of a gap in performance, but we are going to use our time to look for one, to be sure, in early 2021. As stated above, the process is really hard because apples-to-apples is not as obvious as it seems. If anyone seeing this has a specific use case, post it and @tfboyd. Perf is hard, and it needs to be reproducible for both TF-Agents and the tool that is stated to be faster. Having to dig into another tool to get an exact reproduction is a big time sink. I wish it were easier.

@tfboyd tfboyd closed this as completed Dec 10, 2020
@ebrevdo
Contributor

ebrevdo commented Dec 10, 2020

IMO the benchmark results Toby pointed to show that there is no longer a 10x performance drop; but again, please submit any new regressions as a new bug with repro instructions.
