
Make tf.transpose emit simpler graph when possible #21945

Merged
merged 1 commit into tensorflow:master on Oct 4, 2018

Conversation

efagerho
Contributor

If not given an explicit 'perm' parameter, tf.transpose currently
emits a graph that dynamically calculates it from the rank of the
input tensor. This is completely unnecessary when the rank of the
input can be statically determined at graph construction time.

Modify tf.transpose to emit 'perm' as a single Const node whenever
possible.
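
As an illustration (not part of the patch itself), here is a minimal TF 1.x sketch of the graph difference: when the input's rank is statically known, perm no longer has to be computed by Rank/Range/Sub nodes at run time.

    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
      x = tf.placeholder(tf.float32, shape=[2, 3])  # rank statically known
      y = tf.transpose(x)                           # no explicit perm given

    print([op.type for op in g.get_operations()])
    # Before this change: the list includes Rank, Range, and Sub ops that
    # compute perm at run time.
    # After this change: perm is a single Const feeding the Transpose op.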

@aaroey aaroey requested a review from alextp August 31, 2018 05:26
@aaroey aaroey self-assigned this Aug 31, 2018
@alextp alextp added the awaiting testing (then merge), kokoro:force-run, and ready to pull labels Sep 4, 2018
@kokoro-team kokoro-team removed the kokoro:force-run label Sep 4, 2018
Review thread on tensorflow/python/ops/array_ops.py (outdated, resolved)
@alextp alextp added the kokoro:force-run label Sep 13, 2018
@kokoro-team kokoro-team removed the kokoro:force-run label Sep 13, 2018
@efagerho
Contributor Author

efagerho commented Sep 15, 2018

EDIT: It looks like //tensorflow/contrib/learn:dnn_test fails with the patch, in addition to the other pre-existing failures under contrib. I need to debug this further, since I can't yet figure out the root cause.

@efagerho
Contributor Author

efagerho commented Sep 17, 2018

It looks like there are a few tests that this patch causes to fail. They all raise an exception in the same place:

tensorflow/contrib/learn/python/learn/estimators/head.py", line 1924, in _centered_bias_step

What's strange is that the code that builds the graph doesn't fail when tf.transpose is called, i.e. the graph node is created just as expected, so its input parameters appear to be valid. Having gone through every such call with some good old print debugging, the parameters don't look strange at all. The exception in the test is raised when the optimizer is creating the backprop graph for the bias computation, and at that point it looks like some variable and its gradient have different shapes:

File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py", line 1562, in testEnableCenteredBias
    regressor.fit(input_fn=_input_fn, steps=5)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 525, in fit
    loss = self._train_model(input_fn=input_fn, hooks=hooks)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1042, in _train_model
    model_fn_ops = self._get_train_ops(features, labels)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1265, in _get_train_ops
    return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1228, in _call_model_fn
    model_fn_results = self._model_fn(features, labels, **kwargs)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/contrib/learn/python/learn/estimators/dnn.py", line 214, in _dnn_model_fn
    logits=logits)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/contrib/learn/python/learn/estimators/head.py", line 758, in create_model_fn_ops
    enable_centered_bias=self._enable_centered_bias)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/contrib/learn/python/learn/estimators/head.py", line 669, in _create_model_fn_ops
    batch_size, loss_fn, weight_tensor)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/contrib/learn/python/learn/estimators/head.py", line 1940, in _train_op
    weights=weights)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/contrib/learn/python/learn/estimators/head.py", line 1924, in _centered_bias_step
    centered_bias_loss, var_list=(centered_bias,), name=name)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/python/training/optimizer.py", line 410, in minimize
    name=name)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/python/training/optimizer.py", line 607, in apply_gradients
    update_ops.append(processor.update_op(self, grad))
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/python/training/optimizer.py", line 115, in update_op
    update_op = optimizer._apply_dense(g, self._v)  # pylint: disable=protected-access
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/python/training/adagrad.py", line 103, in _apply_dense
    use_locking=self._use_locking)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/python/training/gen_training_ops.py", line 174, in apply_adagrad
    use_locking=use_locking, update_slots=update_slots, name=name)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/python/framework/ops.py", line 3274, in create_op
    op_def=op_def)
  File "/home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/python/framework/ops.py", line 1770, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): var and grad do not have the same shape[1] []
         [[node dnn/regression_head/centered_bias_step/update_dnn/regression_head/centered_bias_weight/ApplyAdagrad (defined at /home/efagerholm/.cache/bazel/_bazel_efagerholm/3bd66cc293ffd5c1e1b6be4e441d09f4/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/contrib/learn/dnn_test.runfiles/org_tensorflow/tensorflow/contrib/learn/python/learn/estimators/head.py:1924)  = ApplyAdagrad[T=DT_FLOAT, _class=["loc:@dnn/r...plyAdagrad"], update_slots=true, use_locking=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](dnn/regression_head/centered_bias_weight, dnn/regression_head/dnn/regression_head/centered_bias_weight/Adagrad, dnn/regression_head/centered_bias_step/learning_rate, dnn/regression_head/gradients/dnn/regression_head/centered_bias_step/Tile_grad/Sum)]]

I don't quite understand how it's possible to have a valid forward graph and then have the optimizer end up with mismatched variable and gradient shapes during backprop. I'll look into this more closely later this week.

@efagerho
Contributor Author

It looks like I've triggered a bug in TensorFlow (probably the Grappler optimizer). The following code fails with the error in the message above:

      a = ops.convert_to_tensor(a, name="a") 
      if not a.get_shape().ndims: 
        rank = gen_array_ops.rank(a) 
        perm = (rank - 1) - gen_math_ops._range(0, rank, 1) 
      else: 
        rank = a.get_shape().ndims 
        perm = (rank - 1) - np.arange(rank, dtype=np.int32) 

However, if I simply add a tf.Print on the perm parameter it works, i.e. the following code passes unit tests:

      a = ops.convert_to_tensor(a, name="a") 
      if not a.get_shape().ndims: 
        rank = gen_array_ops.rank(a) 
        perm = (rank - 1) - gen_math_ops._range(0, rank, 1) 
      else: 
        rank = a.get_shape().ndims 
        perm = (rank - 1) - np.arange(rank, dtype=np.int32) 
        from tensorflow.python.ops import logging_ops 
        perm = logging_ops.Print(perm, [perm], "sdfsdf") 

@alextp
Contributor

alextp commented Sep 17, 2018 via email

@alextp
Contributor

alextp commented Sep 17, 2018

I suggest this because I think it's not a Grappler-related bug but rather an issue where some piece of code downstream behaves differently depending on whether perm is a tensor or not, and Print makes it a tensor.
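
For illustration only, a hypothetical sketch of the kind of divergence meant here; the helper below is made up for this example, not actual TensorFlow code:

    import numpy as np
    from tensorflow.python.framework import ops

    def handle_perm(perm):
      # Hypothetical downstream helper: behaves differently depending on
      # whether perm is already a Tensor (e.g. the output of tf.Print) or a
      # plain numpy array computed at graph construction time.
      if isinstance(perm, ops.Tensor):
        return perm
      return np.asarray(perm, dtype=np.int32)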

@efagerho
Contributor Author

Can you fix this by doing perm = ops.convert_to_tensor((rank - 1) - np.arange(rank, dtype=np.int32))?

I should have mentioned that I already tried this; it doesn't help. In fact, I tried the following things:

1. perm = logging_ops.Print(perm, [perm], "sdfsdf")
2. perm = constant(perm)
3. perm = identity(perm)
4. perm = ops.convert_to_tensor(perm)

The unit tests only pass with (1); the others all fail. The sketch below shows where each variant goes.
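
For clarity, a sketch of where each attempt goes in the patched branch; the helper name is illustrative, the numbers match the list above, and the imports cover the commented-out variants as well:

    import numpy as np
    from tensorflow.python.framework import constant_op, ops
    from tensorflow.python.ops import array_ops, logging_ops

    def _static_perm(a):
      # 'a' is a tensor whose rank is known at graph construction time.
      rank = a.get_shape().ndims
      perm = (rank - 1) - np.arange(rank, dtype=np.int32)
      perm = logging_ops.Print(perm, [perm], "sdfsdf")   # (1) tests pass
      # Each of these was tried in place of the Print above; tests still fail:
      # perm = constant_op.constant(perm)                # (2)
      # perm = array_ops.identity(perm)                  # (3)
      # perm = ops.convert_to_tensor(perm)               # (4)
      return perm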

@efagerho
Contributor Author

Since tf.Print() is basically tf.identity(), I wonder whether there could be some strange device placement issue going on here. However, I'm running the tests with "--config=opt", so there's really only the CPU to choose from, and I can't see how that could factor in either.

@alextp
Contributor

alextp commented Sep 17, 2018

@rmlarsen is there someone on the grappler side who can help investigate this failure?

@rmlarsen
Member

rmlarsen commented Oct 1, 2018

@efagerho thanks for the PR and sorry for the delay. Let me take a look.

@rmlarsen
Member

rmlarsen commented Oct 1, 2018

This does appear to be a Grappler bug. The tests pass when I disable all Grappler optimizations. I will hunt down and squash the bug now.
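
For reference, a minimal sketch of one way to disable all Grappler optimizations in TF 1.x (not necessarily how it was done for this run):

    import tensorflow as tf
    from tensorflow.core.protobuf import rewriter_config_pb2

    # Turn off the entire Grappler meta-optimizer for sessions using this config.
    rewriter_config = rewriter_config_pb2.RewriterConfig(disable_meta_optimizer=True)
    graph_options = tf.GraphOptions(rewrite_options=rewriter_config)
    config = tf.ConfigProto(graph_options=graph_options)

    with tf.Session(config=config) as sess:
      pass  # build and run the failing graph here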

@rmlarsen
Member

rmlarsen commented Oct 3, 2018

I believe this was caused by a bug in the shape function of Transpose. I will submit a fix shortly. Then we should be able to proceed with this PR.

@efagerho
Contributor Author

efagerho commented Oct 3, 2018

I believe this was caused by a bug in the shape function of Transpose. I will submit a fix shortly. Then we should be able to proceed with this PR.

That's quite unexpected. I would have assumed that code to be fairly well exercised. Thanks for figuring it out!

@rmlarsen
Member

rmlarsen commented Oct 3, 2018

@efagerho indeed!

@rmlarsen
Member

rmlarsen commented Oct 3, 2018

@efagerho @alextp it looks like fixing the shape function was not enough, and that there is a separate bug in the Grappler shape inference or constant folding. :-(
I'll keep digging.

@rmlarsen
Member

rmlarsen commented Oct 3, 2018

@efagerho @alextp OK, found the second bug: it's in reduction index materialization (a part of Grappler constant folding).

@rmlarsen
Member

rmlarsen commented Oct 4, 2018

@efagerho @alextp I have submitted the bugfix for Grappler and we can proceed. I have verified that this change now works, but let's keep it as a PR so you get credited for it.

@tensorflow-copybara tensorflow-copybara merged commit 864e290 into tensorflow:master Oct 4, 2018
tensorflow-copybara pushed a commit that referenced this pull request Oct 4, 2018
PiperOrigin-RevId: 215824410
@rmlarsen
Member

rmlarsen commented Oct 4, 2018

@efagerho your PR has now been merged. Thanks for the contribution!

tensorflow-copybara pushed a commit that referenced this pull request Oct 5, 2018
Automated rollback of PR #21945
END_PUBLIC
Automated rollback of commit 863f614. Revert #21945.

PiperOrigin-RevId: 215913175
@efagerho
Contributor Author

efagerho commented Oct 8, 2018

Seems like the patch got rolled back. Were the Grappler fixes checked in before the CI ran?

@alextp
Contributor

alextp commented Oct 8, 2018

We're working on resubmitting it; there were some obscure test failures triggered by this.

Labels: cla: yes, ready to pull