
Using tf.cond to change behaviour for training/inference with dropout causes crash when using XLA #50551

Open
callumm-graphcore opened this issue Jun 30, 2021 · 2 comments
Assignees
Labels
comp:xla XLA stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.5 Issues related to TF 2.5 type:bug Bug

Comments


System information
Some irrelevant fields have been deleted

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04.4 LTS (Bionic Beaver)
  • TensorFlow installed from (source or binary): Binary - installed using pip
  • TensorFlow version (use command below): Tested with both 2.5.0 (v2.5.0-rc3-213-ga4dfb8d1a71 2.5.0) and 1.15.5 (v1.15.4-39-g3db52be 1.15.5)
  • Python version: Python 3.6.9
  • GPU model and memory: N/A - the problem occurs on CPU alone (originally identified using Graphcore IPUs)

Describe the current behavior

The code fails with the following error:

Traceback (most recent call last):
  File "/localdata/callumm/Z2876/upstream_tf_venv/lib/python3.6/site-packages/tensorflow_core/python/client/session.py", line 1365, in _do_call
    return fn(*args)
  File "/localdata/callumm/Z2876/upstream_tf_venv/lib/python3.6/site-packages/tensorflow_core/python/client/session.py", line 1348, in _run_fn
    self._extend_graph()
  File "/localdata/callumm/Z2876/upstream_tf_venv/lib/python3.6/site-packages/tensorflow_core/python/client/session.py", line 1388, in _extend_graph
    tf_session.ExtendSession(self._session)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Merge nodes {cond_1/gradients/cond/Identity/Switch_grad/cond_grad,cond_1/gradients/cond/dropout/mul/Switch_grad/cond_grad} directly dominated by switch nodes with different predicates (cond_1/gradients/cond/Merge_grad/cond_grad/Switch:1 vs is_training_0_arg:0).

Interestingly, the code fails at the call to sess.run(tf.global_variables_initializer()); if you remove all code after that line, it still fails.

Describe the expected behavior

The model - if you can even call it that - should run without issue; it runs fine on the CPU without XLA. To see this, uncomment the line cpu_result = my_net(placeholder, is_training), comment out the next two lines, and change xla_result to cpu_result in the sess.run() calls.

Do you want to contribute a PR? (yes/no): No, I do not yet understand the issue well enough

Standalone code to reproduce the issue

The code here isn't for any particular neural network; it's just a minimal reproducer. The original model is much bigger but gives rise to the same issue.

import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.training import gradient_descent

tf.disable_v2_behavior()

# Config
BATCH_SIZE = 1
MAX_LENGTH = 16
VOCAB_SIZE = 1000
DROPOUT_RATE = 0.9

# Generate random data
train_data = np.random.randint(0, VOCAB_SIZE, size=(BATCH_SIZE, MAX_LENGTH), dtype=np.int32)

# Placeholders
with tf.device("cpu"):
    placeholder = tf.placeholder(dtype=tf.int32, shape=[BATCH_SIZE, MAX_LENGTH], name='placeholder')
    is_training = tf.placeholder_with_default(input=tf.constant(True), shape=(), name="is_training")

# Define network
def my_net(x, training):
    embedding_table = tf.get_variable('embedding_table',
                                      [VOCAB_SIZE, 2],
                                      dtype=tf.float32)
    input_embedding = tf.nn.embedding_lookup(embedding_table, x)

    def dropped_inputs():
        #return input_embedding * 3
        return tf.nn.dropout(input_embedding, rate=DROPOUT_RATE)
    def identity():
        return tf.identity(input_embedding)

    input_embedding = tf.cond(training, dropped_inputs, identity)
    loss = math_ops.reduce_sum(math_ops.square(input_embedding))
    optimizer = gradient_descent.GradientDescentOptimizer(0.0005)

    def minimize():
        return optimizer.minimize(loss)
    def no_op():
        return tf.no_op()
    train = tf.cond(training, minimize, no_op)

    return loss, train

# Compile and run

# Uncomment to see the code working
#cpu_result = my_net(placeholder, is_training)

# Comment out these two lines to see code working
with tf.device('/device:XLA_CPU:0'):
    xla_result = tf.xla.experimental.compile(my_net, inputs=[placeholder, is_training])


with tf.Session() as sess:

    # Uncomment to write graph Protobuf to file
    # tf.io.write_graph(sess.graph, 'xla_tf_cond_issue', 'xla_tf_cond_issue.pb', as_text=False)    

    sess.run(tf.global_variables_initializer())

    print('Training...')
    for _ in range(5):
        # Change `xla_result` to `cpu_result` to see code working
        result = sess.run(xla_result, feed_dict={placeholder: train_data, is_training: True})
        print(result)

    print('Testing...')
    for _ in range(5):
        # Change `xla_result` to `cpu_result` to see code working
        result = sess.run(xla_result, feed_dict={placeholder: train_data, is_training: False})
        print(result)

Other info / logs

I have traced the issue to tensorflow/compiler/tf2xla/functionalize_cond.cc. Looking at the TensorFlow graph, the complaint raised in the error message is legitimate: the Merge nodes named are indeed directly downstream of two Switch nodes - cond_1/gradients/Switch with predicate cond_1/gradients/cond/Merge_grad/cond_grad/Switch:1 and cond_1/gradients/cond/mul_grad/Mul/Switch with predicate cond_1/pred_id.
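To make the failing check concrete, here is a pure-Python toy model (not real TensorFlow code; the class and function names are invented for illustration) of the consistency rule that functionalize_cond.cc appears to enforce: every Switch node directly dominating a Merge node must carry the same predicate, otherwise the cluster cannot be rewritten as a single If node and the error above is raised:

```python
class Switch:
    """Toy stand-in for a TF Switch node: a name plus the predicate tensor it branches on."""
    def __init__(self, name, predicate):
        self.name = name
        self.predicate = predicate

def check_merge(merge_name, switch_inputs):
    # Collect the predicate of every Switch node that directly feeds the Merge.
    predicates = {s.predicate for s in switch_inputs}
    if len(predicates) > 1:
        # Mirrors the shape of the InvalidArgumentError in the traceback above.
        raise ValueError(
            "Merge node %s directly dominated by switch nodes with different "
            "predicates (%s)." % (merge_name, " vs ".join(sorted(predicates))))
    # All dominating Switch nodes agree, so the cluster maps onto one If node.
    return predicates.pop()
```

In the failing graph, the two Switch nodes feeding the gradient Merge carry different (though semantically equivalent) predicates, so the single-predicate assumption is violated.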

I'm guessing that, assuming such a situation should be allowed to arise at all, it results from an optimisation that takes advantage of the fact that the two predicates are essentially the same. Either way, the code in tensorflow/compiler/tf2xla/functionalize_cond.cc seems to assume this will never happen. I would like to use XLA because it is necessary for running TensorFlow programs on the Graphcore IPU.

I've included a line in the reproducer so that you can dump and inspect the graph for yourself using a tool such as Netron - simply uncomment the line # tf.io.write_graph(sess.graph, 'xla_tf_cond_issue', 'xla_tf_cond_issue.pb', as_text=False).

Please let me know if there is any more useful information I can provide. I will continue to investigate this issue for myself in the meantime. Thank you!

@tilakrayal tilakrayal added comp:xla XLA TF 2.5 Issues related to TF 2.5 labels Jul 1, 2021
Contributor

tilakrayal commented Jul 1, 2021

@jvishnuvardhan ,
I was able to reproduce the issue in TF v2.5, v2.4, and v1.15. Please find the gist of it here.

Author

callumm-graphcore commented Jul 2, 2021

Hello,

I was looking at this a bit more today, learned about tf.enable_control_flow_v2(), and found that adding it works around the issue. For example, instead of this, which results in the error described above being thrown:

# Many lines skipped for brevity 
from tensorflow.python.training import gradient_descent

tf.disable_v2_behavior()

# Config
BATCH_SIZE = 1
# Many lines skipped for brevity 

We have this, which results in the program running successfully:

# Many lines skipped for brevity 
from tensorflow.python.training import gradient_descent

tf.disable_v2_behavior()
tf.enable_control_flow_v2()

# Config
BATCH_SIZE = 1
# Many lines skipped for brevity 

I assume this is because, instead of expressing the control flow with Switch and Merge nodes that must later be converted into If-node constructs (which is where the error is thrown), If nodes are used from the start, so no conversion is necessary.
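As a rough illustration (a toy sketch in plain Python, not TensorFlow's actual implementation), the two control-flow styles can be contrasted like this:

```python
def switch(pred, value):
    # v1 Switch op: routes `value` to either the true or the false output port.
    return (value, None) if pred else (None, value)

def merge(true_val, false_val):
    # v1 Merge op: forwards whichever input actually arrived.
    return true_val if true_val is not None else false_val

def cond_v1(pred, x, true_fn, false_fn):
    # Dataflow-style cond: Switch fans the value out, each branch runs on its
    # port, and Merge joins the results. XLA must later reconstruct the two
    # branches from these Switch/Merge pairs (functionalization).
    t, f = switch(pred, x)
    return merge(true_fn(t) if t is not None else None,
                 false_fn(f) if f is not None else None)

def cond_v2(pred, x, true_fn, false_fn):
    # v2-style cond: one structured "If" holding both branches as functions,
    # so there is nothing for functionalize_cond to reconstruct.
    return true_fn(x) if pred else false_fn(x)
```

Both give the same result for a simple cond, but only the v1 form leaves Switch/Merge nodes in the graph for the tf2xla pass to untangle.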

While this gets around the issue, it might still be worth figuring out why a graph in which a Merge node is dominated by two Switch nodes with different predicates is generated in the first place.

Thank you for taking the time to look at this.

With thanks,
Callum

@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jul 13, 2021