Using tf.cond to change behaviour for training/inference with dropout causes crash when using XLA #50551
Labels: comp:xla, stat:awaiting tensorflower, TF 2.5, type:bug
System information
(Some irrelevant fields have been deleted.)
- TensorFlow installed from: pip
- TensorFlow version: 2.5.0 (v2.5.0-rc3-213-ga4dfb8d1a71) and 1.15.5 (v1.15.4-39-g3db52be)

Describe the current behavior
The code fails with the following error:
Interestingly, the code fails at the call to `sess.run(tf.global_variables_initializer())` - if you remove all code after that line, it should still fail.

Describe the expected behavior
The model - if you can even call it that - should run without issue. It runs fine on the CPU without XLA. To see this, uncomment the line `cpu_result = my_net(placeholder, is_training)`, comment out the next two lines, and change `xla_result` to `cpu_result` in the `sess.run()` calls.

Do you want to contribute a PR? (yes/no): No, I do not yet understand the issue well enough.
Standalone code to reproduce the issue
The code here isn't for any particular neural network; it's just a minimal reproducer. The original model is much bigger but gives rise to the same issue.
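The original standalone reproducer was not preserved in this extract. The following is a hypothetical sketch of its shape, reconstructed only from the identifiers mentioned in the report (`my_net`, `placeholder`, `is_training`, `cpu_result`, `xla_result`); the network body, the use of `tf.xla.experimental.jit_scope`, and the exact graph that triggers the crash are assumptions, not the original code.

```python
# Hypothetical minimal sketch of the reproducer's shape; NOT the original code.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def my_net(x, is_training):
    # tf.cond selects dropout for training and the identity for inference.
    h = tf.cond(is_training,
                lambda: tf.nn.dropout(x, rate=0.5),
                lambda: x)
    loss = tf.reduce_sum(tf.square(h))
    # Differentiating through the cond produces the Switch/Merge gradient
    # structure of the kind named in the error message.
    (grad,) = tf.gradients(loss, x)
    return loss + tf.reduce_sum(grad)

placeholder = tf.placeholder(tf.float32, shape=[1, 4])
is_training = tf.placeholder(tf.bool, shape=[])

cpu_result = my_net(placeholder, is_training)
# Marking the same computation for XLA compilation is what triggers the
# crash on affected versions; the plain CPU path runs fine.
with tf.xla.experimental.jit_scope():
    xla_result = my_net(placeholder, is_training)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # tf.io.write_graph(sess.graph, 'xla_tf_cond_issue', 'xla_tf_cond_issue.pb', as_text=False)
    out = sess.run(cpu_result,
                   feed_dict={placeholder: np.ones((1, 4), np.float32),
                              is_training: False})
    print(out)  # inference path on ones: sum(x^2) + sum(2x) = 4 + 8 = 12.0
```

On affected versions, replacing `cpu_result` with `xla_result` in the `sess.run()` call (or enabling XLA globally) is the step that would exercise the failing `functionalize_cond` path.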
Other info / logs
I have traced the issue to `tensorflow/compiler/tf2xla/functionalize_cond.cc`. Looking at the TensorFlow graph, the complaint raised in the error message is legitimate: the named Merge nodes are indeed directly downstream of two Switch nodes - `cond_1/gradients/Switch` with predicate `cond_1/gradients/cond/Merge_grad/cond_grad/Switch:1`, and `cond_1/gradients/cond/mul_grad/Mul/Switch` with predicate `cond_1/pred_id`.

I'm guessing that, assuming such a situation should be allowed to arise at all, it is the result of an optimisation which takes advantage of the fact that the two predicates are essentially the same. Either way, the code in `tensorflow/compiler/tf2xla/functionalize_cond.cc` seems to assume that this will never happen. I would like to use XLA because it is necessary for running TensorFlow programs on the Graphcore IPU.
I've included a line in the reproducer so that you can dump and inspect the graph for yourself using a tool such as Netron - simply uncomment the line `tf.io.write_graph(sess.graph, 'xla_tf_cond_issue', 'xla_tf_cond_issue.pb', as_text=False)`.

Please let me know if there is any more useful information I can provide. I will continue to investigate this issue for myself in the meantime. Thank you!
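As a programmatic alternative to Netron, the Switch/Merge nodes and their inputs can be listed directly from a GraphDef. The sketch below builds a small v1-style cond in-process so that it is self-contained; to inspect the real dump instead, parse the bytes of `xla_tf_cond_issue.pb` into the `GraphDef` rather than building a graph.

```python
# Sketch: listing Switch/Merge nodes and their inputs from a GraphDef.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.disable_control_flow_v2()  # use the Switch/Merge form of tf.cond

x = tf.placeholder(tf.float32, shape=[])
pred = tf.placeholder(tf.bool, shape=[])
y = tf.cond(pred, lambda: x * 2.0, lambda: x + 1.0)

# For the dumped graph, replace this with:
#   graph_def = tf.GraphDef(); graph_def.ParseFromString(open(path, 'rb').read())
graph_def = tf.get_default_graph().as_graph_def()

switch_merge = [(n.op, n.name, list(n.input))
                for n in graph_def.node
                if n.op in ('Switch', 'Merge')]
for op, name, inputs in switch_merge:
    print(op, name, inputs)
```

Tracing which Switch each Merge's inputs pass through is enough to confirm the mixed-predicate situation described above.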