tf.nn.ctc_loss calculated sequentially when using tf.distribute.MirroredStrategy() #52752

Open
viktor-haag opened this issue Oct 27, 2021 · 1 comment
Labels: comp:dist-strat (Distribution Strategy related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), TF 2.5 (Issues related to TF 2.5), type:performance (Performance Issue)

Comments

@viktor-haag

Please make sure that this is an issue related to performance of TensorFlow. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:performance_template

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • TensorFlow installed from (source or binary): docker pull tensorflow/tensorflow:2.5.0-gpu, running in a TFJob
  • GPU model and memory: Tesla P100

Describe the current behavior

I am working on a CTC-based model and wanted to accelerate training by applying MirroredStrategy to my custom training loop. I modified my code according to the tutorials but did not observe any speed-up. After digging into it with the profiler, I noticed in the trace viewer that each GPU seemed to compute the loss one after another instead of concurrently (see the image further below). There also seemed to be a lot of communication going on between the devices and the host. Moreover, the overview page stated that more than ~95% of the device time is spent on eager execution. Maybe that's related somehow?

Describe the expected behavior
Maybe there is some misunderstanding on my side about how exactly ctc_loss and MirroredStrategy work, but I would expect that it is possible to calculate the loss concurrently on all GPUs.

Standalone code to reproduce the issue
Here is some example code which reproduces the issue I am facing. I removed the model forward/backward pass, as it is not important here and did not seem to cause any problems:

import tensorflow as tf

# some dummy parameters, not really important
profile_steps = 3
per_replica_batch_size = 2
max_label_seq_length = frames = 20
num_labels = 10
strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = per_replica_batch_size * strategy.num_replicas_in_sync

with strategy.scope():
    def compute_loss(labels, logits, label_length, logit_length):
        per_replica_loss = tf.nn.ctc_loss(labels, logits, label_length, logit_length)
        return tf.nn.compute_average_loss(
            per_replica_loss, global_batch_size=GLOBAL_BATCH_SIZE
        )

@tf.function
def train_step():
    # Dummy inputs; the model forward/backward pass is omitted, only the
    # loss computation is profiled.
    labels = tf.ones((GLOBAL_BATCH_SIZE, max_label_seq_length), dtype=tf.int32)
    logits = tf.random.normal(shape=(frames, GLOBAL_BATCH_SIZE, num_labels))
    label_length = tf.ones(GLOBAL_BATCH_SIZE, dtype=tf.int32) * max_label_seq_length
    logit_length = tf.ones(GLOBAL_BATCH_SIZE, dtype=tf.int32) * frames
    loss = compute_loss(labels, logits, label_length, logit_length)
    return loss

tf.profiler.experimental.start("logs/profiler")
for step in range(profile_steps):
    with tf.profiler.experimental.Trace("train", step_num=step, _r=1):
        loss = strategy.run(train_step)
tf.profiler.experimental.stop()
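
For reference, this is roughly the pattern I adapted from the tf.distribute custom-training-loop tutorial: distribute a dataset and let strategy.run pass a per-replica slice of the global batch to the step function. This is only a minimal sketch with dummy data; the dataset construction, the shapes and the transpose to time-major logits are illustrative and not taken from my actual model:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
per_replica_batch_size = 2
GLOBAL_BATCH_SIZE = per_replica_batch_size * strategy.num_replicas_in_sync
max_label_seq_length = frames = 20
num_labels = 10

# Dummy data for a few global batches; batch-major here, transposed to
# time-major inside the step function.
labels = tf.ones((4 * GLOBAL_BATCH_SIZE, max_label_seq_length), dtype=tf.int32)
logits = tf.random.normal((4 * GLOBAL_BATCH_SIZE, frames, num_labels))
dataset = tf.data.Dataset.from_tensor_slices((labels, logits)).batch(GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(inputs):
    labels, logits = inputs  # per-replica slice of the global batch
    batch = tf.shape(labels)[0]
    label_length = tf.ones(batch, dtype=tf.int32) * max_label_seq_length
    logit_length = tf.ones(batch, dtype=tf.int32) * frames
    per_example_loss = tf.nn.ctc_loss(
        labels,
        tf.transpose(logits, [1, 0, 2]),  # ctc_loss expects time-major logits by default
        label_length,
        logit_length,
    )
    return tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE
    )

for inputs in dist_dataset:
    per_replica_losses = strategy.run(train_step, args=(inputs,))
    loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

The main difference to the reproducer above is that here each replica only sees its own per_replica_batch_size slice of the global batch instead of constructing GLOBAL_BATCH_SIZE inputs inside the step function.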

Other info / logs

Here is a screenshot of the trace viewer for the above code when executed on 3 GPUs.

[Screenshot: trace viewer, tf_distributed]

@viktor-haag viktor-haag added the type:performance Performance Issue label Oct 27, 2021
@mohantym mohantym added comp:dist-strat Distribution Strategy related issues TF 2.5 Issues related to TF 2.5 labels Oct 28, 2021
@mohantym
Contributor

Hi @Saduf2019! Could you please have a look at this issue?

@mohantym mohantym assigned Saduf2019 and unassigned mohantym Oct 28, 2021
@Saduf2019 Saduf2019 assigned jvishnuvardhan and unassigned Saduf2019 Nov 3, 2021
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Nov 3, 2021