
Bijectors are orders of magnitude slower in tf2.1 autograph distributed mirrored single-gpu mode #35415

Closed
olegmyrk opened this issue Dec 26, 2019 · 6 comments
Assignees: gadagashwini-zz
Labels: comp:dist-strat (Distribution Strategy related issues), stale (to be closed automatically if no activity), stat:awaiting response (awaiting response from author), TF 2.1 (for tracking issues in the 2.1 release), type:performance (Performance Issue)

Comments

@olegmyrk

olegmyrk commented Dec 26, 2019

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04.3 LTS
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v2.1.0-rc1-58-g9837ece 2.1.0-rc2 (python3 -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)")
  • Python version: Python 3.6.8
  • CUDA/cuDNN version: Driver Version: 440.33.01, CUDA Version: 10.2, cuDNN 7.6.2
  • GPU model and memory: Tesla V100-SXM2-16GB

Describe the current behavior
I'm using Bijectors as a flexible prior for a VAE.

The code below has negligible overhead in tf1.x (for input batch size 18x256x256x3). In tf2.1 autograph distributed mirrored mode with a single GPU
https://github.com/olegmyrk/SPADE-Tensorflow/blob/85b5fd7943296561dc3d54557fec5346c2adea58/SPADE.py#L1152
it increases the training step duration from 1 second (tf1.x) to 1.9 seconds (tf2.1):

bijectors = []
for i in range(16):
    # Masked autoregressive flow step (shift/log-scale network with two 1024-unit hidden layers)
    bijectors.append(tfb.MaskedAutoregressiveFlow(
      shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(
          code, hidden_layers=[1024, 1024], name=scope + "/maf_" + str(i))))

    # Batch-norm bijector to stabilize training of the flow
    bijectors.append(tfb.BatchNormalization(
        batchnorm_layer=tf.layers.BatchNormalization(
                            name=scope + '/batch_norm_' + str(i)),
        name=scope + '/batch_norm_bijector' + str(i)))

    # Fixed random permutation of the channels between flow steps
    permutation = tf.get_variable('permutation_' + str(i),
                                  initializer=np.random.permutation(out_channel).astype("int32"),
                                  trainable=False)
    bijectors.append(tfb.Permute(permutation))

# Chain in reverse order; the trailing Permute is dropped
flow_bijector = tfb.Chain(list(reversed(bijectors[:-1])))

https://github.com/olegmyrk/SPADE-Tensorflow/blob/85b5fd7943296561dc3d54557fec5346c2adea58/SPADE.py#L190

I'm using a custom masked autoregressive template
https://github.com/olegmyrk/SPADE-Tensorflow/blob/85b5fd7943296561dc3d54557fec5346c2adea58/masked_autoregressive.py
but it is just as slow with the default one:
https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/masked_autoregressive_default_template

Possible suspects:

tfb.masked_dense

https://github.com/olegmyrk/SPADE-Tensorflow/blob/85b5fd7943296561dc3d54557fec5346c2adea58/masked_autoregressive.py#L44

tf1.make_template()

https://github.com/olegmyrk/SPADE-Tensorflow/blob/85b5fd7943296561dc3d54557fec5346c2adea58/masked_autoregressive.py#L115

Describe the expected behavior
Performance in tf2.1 and tf1.x should be comparable.

Code to reproduce the issue
TF2.x code:
https://github.com/olegmyrk/SPADE-Tensorflow/blob/85b5fd7943296561dc3d54557fec5346c2adea58/SPADE.py#L190

TF1.x code:
https://github.com/olegmyrk/SPADE-Tensorflow/blob/develop/SPADE.py#L190


@gadagashwini-zz gadagashwini-zz self-assigned this Dec 26, 2019
@gadagashwini-zz gadagashwini-zz added TF 2.1 for tracking issues in 2.1 release type:performance Performance Issue comp:dist-strat Distribution Strategy related issues labels Dec 26, 2019
@anj-s
Contributor

anj-s commented Dec 26, 2019

@olegmyrk Thank you for posting a detailed summary! To debug performance issues like this we need timeline traces. Could you post the following traces:
trace1: TF 1.x without distribution strategy
trace2: TF 2.x without distribution strategy
trace3: TF2.x with 1 GPU Mirrored strategy
Does the performance of trace1 and trace2 match? If it does, we can then compare trace2 and trace3.
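
For reference, here is a minimal sketch of how such traces could be collected (not taken from your code; `sess`, `fetches`, `feed`, and `train_step` are placeholders for your own training ops/step, and the output paths are arbitrary):

# TF1: capture a Chrome-trace timeline for a single sess.run call.
from tensorflow.python.client import timeline

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
sess.run(fetches, feed_dict=feed, options=run_options, run_metadata=run_metadata)
with open('timeline.json', 'w') as f:
    f.write(timeline.Timeline(run_metadata.step_stats).generate_chrome_trace_format())

# TF 2.1: profile a few training steps and export the trace for TensorBoard.
writer = tf.summary.create_file_writer('logdir')
tf.summary.trace_on(graph=True, profiler=True)
for _ in range(3):
    train_step()
with writer.as_default():
    tf.summary.trace_export(name='train_trace', step=0, profiler_outdir='logdir')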

@olegmyrk
Author

Here is trace1 with TF 1.x
timeline.gz

Here is trace3 with TF2.x with 1 GPU Mirrored strategy
trace.gz

Please note that adding tracing makes training significantly slower on its own (especially in TF1.x).

@olegmyrk
Author

olegmyrk commented Jan 1, 2020

I have managed to create minimal TF1 and TF2 scripts that demonstrate the issue. As you can see from the logs, building the graph in TF1 is about 2x faster than in TF2 single-GPU mirrored mode and about 10x faster than in TF2 multi-GPU mirrored mode.

Note that the TF2 script contains a commented-out line which, when enabled, also makes the training step 2x slower in multi-GPU mirrored mode.

CUDA_VISIBLE_DEVICES=0 time python3 test_maf_tf1.py

...
step: 0
2020-01-01 18:28:19.660927: I tensorflow/stream_executor/dso_loader.cc:152] successfully opened CUDA library libcublas.so.10.0 locally
result -6955.917
duration: 12.701881170272827
step: 1
result -6996.775
duration: 0.745488166809082
step: 2
result -7037.4297
...

CUDA_VISIBLE_DEVICES=0 time python3 test_maf_tf2.py

...
step: 0
result tf.Tensor(-6955.917, shape=(), dtype=float32)
duration: 28.593788623809814
step: 1
result tf.Tensor(-6956.736, shape=(), dtype=float32)
duration: 0.6307415962219238
step: 2
result tf.Tensor(-6957.5557, shape=(), dtype=float32)
duration: 0.5865404605865479
...

CUDA_VISIBLE_DEVICES=0,1,2,3 time python3 test_maf_tf2.py

...
step: 0
result tf.Tensor(-6955.917, shape=(), dtype=float32)
duration: 189.2380485534668
step: 1
result tf.Tensor(-6956.736, shape=(), dtype=float32)
duration: 36.806933879852295
step: 2
result tf.Tensor(-6957.5557, shape=(), dtype=float32)
duration: 0.7145941257476807
...

TF1 script:

import time
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

def dist(loc, diag):
    # Normalizing-flow prior: 16 blocks of (MaskedAutoregressiveFlow, BatchNormalization,
    # Permute) chained on top of a diagonal multivariate normal base distribution.
    channel = int(loc.get_shape()[-1])
    bijectors = []
    for i in range(16):
      bijectors.append(tfb.MaskedAutoregressiveFlow(
        shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(
            hidden_layers=[1024, 1024], name="maf_" + str(i))))

      bijectors.append(tfb.BatchNormalization(
          batchnorm_layer=tf.layers.BatchNormalization(
                              name='batch_norm_' + str(i)),
          name='batch_norm_bijector' + str(i)))

      permutation=tf.get_variable('permutation_'+str(i), dtype=tf.int32, initializer=np.random.permutation(channel).astype("int32"), trainable=False)
      bijectors.append(tfb.Permute(permutation))

    # Chain in reverse order; the trailing Permute is dropped.
    flow_bijector = tfb.Chain(list(reversed(bijectors[:-1])))

    mvn_dist = tfd.MultivariateNormalDiag(loc, diag, name='MultivariateNormalDiag')

    dist = tfd.TransformedDistribution(
                    distribution=mvn_dist,
                    bijector=flow_bijector
                )
    return dist	
    
def compute():
    # Negative mean log-likelihood of the input under the flow, plus a sample.
    x = tf.placeholder(tf.float32, [None, 128])
    x_dist = dist(tf.zeros_like(x), tf.ones_like(x))
    y = x_dist.sample()
    #d = -tf.reduce_mean(x_dist.log_prob(y))
    d = -tf.reduce_mean(x_dist.log_prob(x))
    return x, d, y

with tf.Session() as sess:
  x, d, y = compute()
  optim = tf.train.AdamOptimizer(0.01).minimize(d)
  tf.global_variables_initializer().run()
  for i in range(0,100):
      # Step 0 includes graph construction/compilation time.
      start = time.time()
      print("step:", i)
      v = np.zeros((64,128),np.float32)
      result_d, result_y, _ = sess.run([d, y, optim], feed_dict={ x : v })
      print("result", result_d)
      print("duration:", time.time()-start)  

TF2 script:

import time
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

distribute_strategy = tf.distribute.MirroredStrategy()

dataset = tf.data.Dataset.from_tensor_slices(([np.zeros((128,),np.float32)]))
dataset = dataset.repeat(None).batch(64)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
dataset = distribute_strategy.experimental_distribute_dataset(dataset)
dataset = iter(dataset)

with distribute_strategy.scope():
  optim = tf.keras.optimizers.Adam(0.0002)

  def dist(loc, diag):
    # Same flow prior as the TF1 script: 16 blocks of (MaskedAutoregressiveFlow,
    # BatchNormalization, Permute) over a diagonal multivariate normal base distribution.
    channel = loc.get_shape()[-1]
    bijectors = []
    for i in range(16):
      bijectors.append(tfb.MaskedAutoregressiveFlow(
        shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(
            hidden_layers=[1024, 1024], name="maf_" + str(i))))

      bijectors.append(tfb.BatchNormalization(
          batchnorm_layer=tf.compat.v1.layers.BatchNormalization(
                              name='batch_norm_' + str(i)),
          name='batch_norm_bijector' + str(i)))

      permutation=tf.compat.v1.get_variable('permutation_'+str(i), dtype=tf.int32, initializer=np.random.permutation(channel).astype("int32"), trainable=False)
      bijectors.append(tfb.Permute(permutation))
    
    flow_bijector = tfb.Chain(list(reversed(bijectors[:-1])))

    mvn_dist = tfd.MultivariateNormalDiag(loc, diag, name='MultivariateNormalDiag')

    dist = tfd.TransformedDistribution(
                    distribution=mvn_dist,
                    bijector=flow_bijector
                )
    return dist

  def compute(x):
    # Negative mean log-likelihood of the input under the flow, plus a sample.
    x_dist = dist(tf.zeros_like(x), tf.ones_like(x))
    y = x_dist.sample()
    # I don't need this, but enabling it makes the training step 2x slower in multi-GPU mode:
    #d = -tf.reduce_mean(x_dist.log_prob(y))
    d = -tf.reduce_mean(x_dist.log_prob(x))
    return d, y
    
  @tf.function
  def train(x):
    def train_fn(x):
      with tf.GradientTape(persistent=True) as tape:
        d, y = compute(x)
      vars = tf.compat.v1.trainable_variables()
      g = tape.gradient(d, vars)
      optim.apply_gradients(zip(g, vars))
      return d, y
    # Run one optimization step per replica and average the per-replica losses.
    rd, ry = distribute_strategy.experimental_run_v2(train_fn, args=(x,))
    result = tf.reduce_mean(distribute_strategy.experimental_local_results(rd))
    return result

  for i in range(0,100):
      # Step 0 includes tf.function tracing and graph compilation time.
      start = time.time()
      print("step:", i)
      x = next(dataset)
      result = train(x)
      print("result", result)
      print("duration:", time.time()-start)

@sachinprasadhs
Contributor

Could you please test the performance with the latest TensorFlow version? Many of the experimental modules have been moved to stable, and there could be an improvement in performance. Refer to this document for details. Thank you!
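
For reference, the experimental distribution-strategy calls used in the TF2 script above have stable counterparts in current releases. A minimal sketch (assuming TF 2.4 or later; `step_fn` is a placeholder for the real training step, which would build the bijector chain):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

dataset = tf.data.Dataset.from_tensor_slices(tf.zeros((64, 128), tf.float32)).batch(8)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train(x):
    def step_fn(x):
        # Placeholder loss; the real script builds the flow prior here.
        return tf.reduce_mean(x)
    per_replica = strategy.run(step_fn, args=(x,))  # replaces experimental_run_v2
    # Replaces tf.reduce_mean over experimental_local_results.
    return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)

for x in dist_dataset:
    print(train(x).numpy())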

@sachinprasadhs sachinprasadhs added the stat:awaiting response Status - Awaiting response from author label Nov 17, 2021
@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Nov 24, 2021
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.
