TRAIN TF-AGENTS WITH MULTIPLE GPUs #289

Closed

JCMiles opened this issue Jan 21, 2020 · 22 comments

Comments

JCMiles commented Jan 21, 2020

Hi,
I finally got my VM up and running with:
2x Tesla P100
NVIDIA driver 440.33.01
CUDA 10.2
tensorflow==2.1.0
tf_agents==0.3.0

I started training a custom model/env based on the SAC agent v2 train loop,
but only one GPU is used.
My question: is tf-agents currently able to manage distributed training on multiple GPUs,
or should I use only one?

kuanghuei (Contributor) commented:

It should work with multiple GPUs. You can set that up in your training script the same way you would in other TensorFlow training code.
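
For context, the usual TF 2 pattern looks something like this (a generic sketch, not TF-Agents specific; the toy Keras model is only a placeholder):

    import tensorflow as tf

    # MirroredStrategy picks up all visible GPUs by default.
    strategy = tf.distribute.MirroredStrategy()
    print('Num replicas:', strategy.num_replicas_in_sync)

    # Variables (networks, optimizers) must be created inside the scope
    # so they are mirrored across the GPUs.
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        model.compile(optimizer='adam', loss='mse')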

JCMiles (Author) commented Jan 22, 2020

@kuanghuei

If you try sac/examples/v2/train_eval_rnn.py as-is on a VM like the one listed above, you'll see only one GPU working.

Could you please provide an example of how to run train_eval_rnn.py on multiple GPUs?
Wrapping the entire pipeline in MirroredStrategy() doesn't solve the problem.

    tf.debugging.set_log_device_placement(True)
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(time_step, policy_state)

        .....

ebrevdo (Contributor) commented Jan 22, 2020

@oars any advice on how to make the MirroredStrategy work here?

I wonder if you need to build the iterator inside the MirroredStrategy scope.

JCMiles (Author) commented Jan 22, 2020

@ebrevdo and @oars

do you mean something like this?

    tf.debugging.set_log_device_placement(True)
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():

        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(time_step, policy_state)

    .....

I can give it a try tomorrow and let you know

JCMiles (Author) commented Jan 23, 2020

With this approach, operations are still placed on a single GPU.

wsun66 commented Jan 27, 2020

I am having exactly the same problem. Any solutions or workarounds? Thanks.

JCMiles (Author) commented Jan 27, 2020

I made some progress by combining some snippets of code from the official TensorFlow docs,
but I'm not able to complete the train step.

Here is what I did so far, based on https://github.com/tensorflow/agents/blob/master/tf_agents/agents/sac/examples/v2/train_eval_rnn.py but with custom networks and environment.
Parallel environments are temporarily disabled.

  1. Add at the very beginning of the train_eval fn:

     tf.config.set_soft_device_placement(True)
     gpus = tf.config.experimental.list_physical_devices('GPU') 
     if gpus:
         try:
             for gpu in gpus:
                 tf.config.experimental.set_memory_growth(gpu, True)
             logical_gpus = tf.config.experimental.list_logical_devices('GPU')
             print(f"{len(gpus)} physical GPUs - {len(logical_gpus)} logical GPUs")
         except RuntimeError as e:
             # Memory growth must be set before GPUs have been initialized
             logging.error(e)
    
  2. Define the mirrored strategy just before agent init:

        batch_size = 256
        strategy = tf.distribute.MirroredStrategy()
        global_batch_size = batch_size * strategy.num_replicas_in_sync

        with strategy.scope():
            tf_agent = sac_agent.SacAgent(batch_size=batch_size)
    
  3. Create the distributed dataset:

        tf_env = tf_py_environment.TFPyEnvironment(custom_py_env)

        # tf_env.batch_size = 1
        replay_buffer = TFUniformReplayBuffer(
            tf_agent.collect_data_spec, batch_size=tf_env.batch_size)

        dataset = replay_buffer.as_dataset(
            num_parallel_calls=num_of_cores,
            sample_batch_size=global_batch_size,
            num_steps=train_sequence_length + 1).prefetch(num_of_cores)

        dist_dataset = strategy.experimental_distribute_dataset(dataset)

        with strategy.scope():
            iterator = iter(dist_dataset)
    
  4. Wrap train_step() as follows:

        def train_step(dist_inputs):
            def step_fn(experience):
                return tf_agent.train(experience)

            result = strategy.experimental_run_v2(step_fn, args=(dist_inputs,))
            mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, result.loss, axis=0)
            return mean_loss

        train_step = common.function(train_step)
    
  5. The main train loop:

        for _ in range(num_iterations):
            time_step, policy_state = collect_driver.run(time_step, policy_state)
            episode_steps = self.env_steps.result() - start_env_steps

            for _ in range(episode_steps):
                with strategy.scope():
                    for __ in range(train_steps_per_iteration):
                        experience, _ = next(iterator)
                        loss = train_step(experience)

I'm probably mixing up batch_size, global_batch_size and tf_env.batch_size.
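
For what it's worth, here is how I currently understand the three to relate under MirroredStrategy (a rough sketch; the concrete values are assumptions based on the snippets above):

    # Assumed values, matching the snippets above.
    per_replica_batch_size = 256                                 # what each GPU trains on
    num_replicas = strategy.num_replicas_in_sync                 # 2 on this VM (2x P100)
    global_batch_size = per_replica_batch_size * num_replicas    # 512, used to sample the dataset

    # strategy.experimental_distribute_dataset() splits each global batch
    # across the replicas, so each GPU receives per_replica_batch_size samples.
    # tf_env.batch_size is unrelated to training: it is the number of parallel
    # environments feeding the replay buffer during collection.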

wsun66 commented Jan 27, 2020

Glad to hear you're making progress. I'd be grateful if you could share your final solution. In the meantime, I'm trying to see if I can make it work as well; I'll share mine here if I find a working one.

tfboyd (Member) commented Feb 5, 2020

Assigning this to me to identify a multi-GPU example or get one made. This is a common problem, and we should publish a definitive example to give people a starting point.

tfboyd self-assigned this Feb 5, 2020

JCMiles (Author) commented Feb 14, 2020

Hi @tfboyd, any update on this?

egonina (Contributor) commented Feb 19, 2020

@JCMiles what errors are you getting with your implementation?

One thing you want to make sure of is that you create your network variables and the dataset iterator within the strategy scope. I am not sure from your code snippets where you create your networks. The only thing that needs to be outside of the strategy scope is the replay buffer itself, so it doesn't get replicated on the GPUs.

Something like this:

  @common.function
  def train_step(iterator):
    def replicated_train_step(experience):
      return tf_agent.train(experience).loss

    experience = eager_utils.get_next(iterator)
    per_replica_losses = strategy.experimental_run_v2(
        replicated_train_step, args=(experience,))
    return strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)

  strategy = tf.distribute.MirroredStrategy()

  with strategy.scope():
    q_net = q_network.QNetwork(...)

    tf_agent = dqn_agent.DqnAgent(...)
 
  # note the replay buffer is outside of strategy.scope():
  replay_buffer = TFUniformReplayBuffer(...)

  with strategy.scope():
    dataset = replay_buffer.as_dataset(...)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    strategy_dataset_iterator = iter(dist_dataset)

    for _ in range(num_iterations):
      train_step(strategy_dataset_iterator)

@tfboyd @sguada @ebrevdo let's discuss internally an example we can share

JCMiles (Author) commented Feb 20, 2020

@egonina
Following your implementation applied to the SAC agent v2 example, I get this error:

    Traceback (most recent call last):
      File "/home/user/test_agent/train/module.py", line 321, in train_step
        experience = eager_utils.get_next(iterator)
      File "/home/user/.local/lib/python3.6/site-packages/tf_agents/utils/eager_utils.py", line 638, in get_next
        return iterator.get_next()
      File "/home/user/.local/lib/python3.6/site-packages/tensorflow_core/python/distribute/input_lib.py", line 306, in get_next
        global_has_value, replicas = _get_next_as_optional(self, self._strategy)
      File "/home/user/.local/lib/python3.6/site-packages/tensorflow_core/python/distribute/input_lib.py", line 223, in _get_next_as_optional
        reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
      File "/home/user/.local/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 808, in reduce
        return self._extended._reduce(reduce_op, value)  # pylint: disable=protected-access
      File "/home/user/.local/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 1449, in _reduce
        device_util.current() or "/device:CPU:0"))[0]
      File "/home/user/.local/lib/python3.6/site-packages/tensorflow_core/python/distribute/mirrored_strategy.py", line 736, in _reduce_to
        reduce_op, value, destinations=destinations)
    TypeError: reduce() missing 1 required positional argument: 'per_replica_value'

egonina (Contributor) commented Feb 20, 2020

@anj-s, can you PTAL at this use of DistributionStrategy? Thanks!

anj-s commented Mar 2, 2020

From the error message, it looks like we are not populating a field that indicates whether there is data remaining in the dataset (that has not been processed yet). Given this is a single-machine MirroredStrategy example, I am not sure why we would not have populated this field. I need a reproducible example to dig into this.

@JCMiles @egonina Can you provide me with a reproducible example?

JCMiles (Author) commented Mar 2, 2020

Just add @egonina's implementation to agents/sac/examples/v2/train_eval_rnn.py.

wsun66 commented Mar 3, 2020

I've made train_eval.py work on multiple GPUs. Attached is the source code. Hope it helps:
train_eval_strategy.zip

JCMiles (Author) commented Mar 4, 2020

Hi,
I tested @wsun66's code on train_eval.py and it worked fine. The key is to place the replay buffer and the dataset outside the mirrored strategy scope, on the CPU (a rough sketch of that layout is at the end of this comment).
But when I try to implement the same thing in train_eval_rnn.py I get the error below.
Like the previous one, a parameter is missing.

    [03-04|23:44:08] [14931] INFO 219 coordinator | Error reported to Coordinator: batch_reduce() missing 1 required positional argument: 'value_destination_pairs'
    Traceback (most recent call last):
      File "/home/user/anaconda3/envs/dev/lib/python3.6/site-packages/tensorflow_core/python/training/coordinator.py", line 297, in stop_on_exception
        yield
      File "/home/user/anaconda3/envs/dev/lib/python3.6/site-packages/tensorflow_core/python/distribute/mirrored_strategy.py", line 195, in _call_for_each_replica
        **merge_kwargs)
      File "/home/user/anaconda3/envs/dev/lib/python3.6/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py", line 449, in _distributed_apply
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
      File "/home/user/anaconda3/envs/dev/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 1494, in batch_reduce_to
        return self._batch_reduce_to(reduce_op, value_destination_pairs)
      File "/home/user/anaconda3/envs/dev/lib/python3.6/site-packages/tensorflow_core/python/distribute/mirrored_strategy.py", line 740, in _batch_reduce_to
        value_destination_pairs)
    TypeError: batch_reduce() missing 1 required positional argument: 'value_destination_pairs'
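
For reference, a rough sketch of the layout that worked for me on train_eval.py (pieced together from the snippets in this thread, not @wsun66's attached file verbatim):

    # Agent and network variables are created inside the strategy scope,
    # so they get mirrored on both GPUs.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        tf_agent = sac_agent.SacAgent(...)

    # Replay buffer and dataset stay outside the scope, pinned to the CPU,
    # so they are not replicated on the GPUs.
    with tf.device('/CPU:0'):
        replay_buffer = TFUniformReplayBuffer(
            tf_agent.collect_data_spec, batch_size=tf_env.batch_size)
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=global_batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)

    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)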

JCMiles (Author) commented Mar 5, 2020

Hello everyone, I found the problem and fixed it in train_eval_rnn as well.
When I implemented @wsun66's solution in my previous version I declared:

    strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.NcclAllReduce)

I passed the class (a callable) as the cross_device_ops param, and that caused the problem.
Remember to pass an instance, tf.distribute.NcclAllReduce(), or whatever cross-device ops you decide to use. After fixing this I was able to run the train pipeline entirely.
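
The corrected declaration, for anyone hitting the same error:

    # fixed: pass an instance, not the class
    strategy = tf.distribute.MirroredStrategy(
        cross_device_ops=tf.distribute.NcclAllReduce())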

I have one small question before closing this issue. I noticed this in the logs:

    [03-05|14:06:21] [26463] INFO 760 cross_device_ops| batch_all_reduce: 112 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
    [03-05|14:06:21] [26463] WARNING 791 cross_device_ops| Efficient allreduce is not supported for 4 IndexedSlices

What are those 4 IndexedSlices, and why is efficient allreduce not supported for them?

egonina (Contributor) commented Mar 5, 2020

Ah, glad you were able to figure it out! And thanks @wsun66 for providing your example!

How many GPUs are you running on? @anj-s, can you provide some insight into what IndexedSlices are?

JCMiles (Author) commented Mar 5, 2020

I'm using 2 Tesla P100s, so I suppose the IndexedSlices are not related to the number of GPUs.

Meanwhile I'm running some performance tests and I'm facing some unexpected downtime.
Sometimes the training loop gets stuck on the first iteration.

Still investigating...
I'll keep you updated on any progress.

JCMiles (Author) commented Mar 10, 2020

Quick update:

To run a training cycle on multiple GPUs with "n" parallel environments, I found it is necessary
to place the parallel environment initialization on the CPU; otherwise CUDA errors can occur:

    with tf.device('/CPU:0'):
        train_env = ParallelPyEnvironment(...)

Hope this helps.

tfboyd (Member) commented Oct 26, 2020

Assigning this to me to test with the new, updated distributed scripts. Using the new distributed collect is the best (maybe only) way to really drive multiple GPUs. ParallelPyEnvironment might work, but we do not use that approach very often, as we prefer to use many CPU-only machines to drive data to the GPU server(s).

https://github.com/tensorflow/agents/tree/master/tf_agents/experimental/distributed/examples/sac

I will close this issue after testing the linked example on multiple GPUs and verifying that it uses both GPUs. The example is unlikely to drive a high usage rate because the network is very small.

JCMiles closed this as completed Sep 1, 2021