[Grappler] RemoveIdentityTranspose also removes conjugate #27500

Closed
jheymann85 opened this issue Apr 4, 2019 · 4 comments
Labels
  • comp:grappler: Grappler related issues
  • stale: This label marks the issue/pr stale - to be closed automatically if no activity
  • stat:awaiting response: Status - Awaiting response from author
  • stat:awaiting tensorflower: Status - Awaiting response from tensorflower
  • type:bug: Bug

Comments


jheymann85 commented Apr 4, 2019

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: N/A
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 1.13
  • Python version: 3.6
  • Bazel version (if compiling from source): N/A
  • GCC/Compiler version (if compiling from source): N/A
  • CUDA/cuDNN version: 10.0
  • GPU model and memory: N/A
  • Exact command to reproduce: see below

During graph optimization, RemoveIdentityTranspose removes a transpose whose permutation is the identity even if that transpose also conjugates its input, so the conjugation is silently dropped (see here).

Simple example:

import tensorflow as tf
with tf.Graph().as_default():
    sess = tf.Session()
    a = tf.placeholder(tf.complex64)
    data = [[1j], [1j]]
    print(sess.run(tf.transpose(a, (0, 1), conjugate=True), {a: data}))  # not optimized
    print(sess.run(tf.transpose(a, (0, 1), conjugate=True) + 1, {a: data}))  # optimized, no conjugate will be applied
    print(sess.run(tf.conj(a) + 1, {a: data}))

Output:

[[0.-1.j]
 [0.-1.j]]
[[1.+1.j]
 [1.+1.j]]
[[1.-1.j]
 [1.-1.j]]
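The second and third results should be identical, since a transpose with the identity permutation `(0, 1)` and `conjugate=True` is just an elementwise conjugate. As a sanity check that does not depend on TensorFlow, here is a minimal pure-Python sketch; `conj_transpose_2d` is a hypothetical helper written for this illustration, not a TensorFlow function:

```python
def conj_transpose_2d(rows, perm=(0, 1)):
    """Reference for a 2-D transpose with conjugate=True (illustrative only)."""
    if tuple(perm) == (1, 0):
        # Swap the two axes.
        rows = [list(col) for col in zip(*rows)]
    elif tuple(perm) != (0, 1):
        raise ValueError("perm must be (0, 1) or (1, 0) for a 2-D input")
    # Conjugate every element, regardless of the permutation.
    return [[x.conjugate() for x in row] for row in rows]


data = [[1j], [1j]]

# What `tf.transpose(a, (0, 1), conjugate=True) + 1` should produce.
expected = [[x + 1 for x in row] for row in conj_transpose_2d(data)]

# What results if the whole op is dropped as an "identity" transpose.
without_conj = [[x + 1 for x in row] for row in data]

print(expected)
print(without_conj)
```

`expected` corresponds to the third output above (`1.-1.j`), and `without_conj` to the second (`1.+1.j`), which is the value produced once the optimizer has removed the op.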

This can happen when using einsum/tensordot. A probably related issue is #19771.

A possible fix would be to also check for IsConjugateTranspose here.
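To make the proposed check concrete, the following is a rough Python model of the decision, not Grappler's actual C++ code. The function names are hypothetical, and replacing the op with `Conj` is only one plausible way to handle the conjugate case (simply skipping the optimization would also preserve correctness):

```python
def is_identity_permutation(perm):
    """True if perm leaves the axis order unchanged, e.g. (0, 1, 2)."""
    return list(perm) == list(range(len(perm)))


def simplify_transpose(op_type, perm):
    """Return the op type that should remain after the rewrite.

    op_type is assumed to be "Transpose" or "ConjugateTranspose".
    This is an illustrative model of the rewrite rule only.
    """
    if not is_identity_permutation(perm):
        # A real permutation: nothing to simplify.
        return op_type
    if op_type == "ConjugateTranspose":
        # The axes are untouched, but the values must still be conjugated.
        return "Conj"
    # A plain identity transpose is a no-op and can be bypassed.
    return "Identity"


print(simplify_transpose("Transpose", (0, 1)))
print(simplify_transpose("ConjugateTranspose", (0, 1)))
print(simplify_transpose("ConjugateTranspose", (1, 0)))
```

The behavior reported in this issue corresponds to the middle case returning `Identity` instead of keeping the conjugation.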

@mohantym
Contributor

Hi @jheymann85! PR #40223 has fixed this issue. Attaching the resolved gist for reference. Can we move this issue to closed status now? Thanks!

mohantym added the stat:awaiting response label Apr 11, 2022
@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

google-ml-butler bot added the stale label Apr 18, 2022
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.

@google-ml-butler

Are you satisfied with the resolution of your issue?

5 participants