Fix crash with tf.transpose when a is complex and conjugate is True #46973

Merged

@yongtang (Member) commented Feb 6, 2021

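This PR tries to address the issue raised in #46891, where tf.transpose crashes (aborts) when the input a is complex and conjugate is True. The crash comes from the check at:
https://github.com/tensorflow/tensorflow/blob/57bbc5e0d4b93483b8ae853352173516f1c08018/tensorflow/core/kernels/transpose_functor.h#L169

However, inputs with ndims < 2 are already handled properly, for example by the default branch in transpose_functor_cpu.cc:
https://github.com/tensorflow/tensorflow/blob/57bbc5e0d4b93483b8ae853352173516f1c08018/tensorflow/core/kernels/transpose_functor_cpu.cc#L103-L105

    // Ranks without a dedicated case (including 0 and 1) fall through here.
    default:
      TransposeSimple<T, conjugate>(d, in, perm, out);
      break;

so the check can be removed.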
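
To illustrate why ranks 0 and 1 need no special-casing, here is a minimal pure-Python sketch of a rank-agnostic, stride-based transpose, loosely modeled on the strided index mapping that TransposeSimple appears to use. It is not TensorFlow's code: the helper names compute_strides and simple_transpose are hypothetical, and the sketch assumes row-major flat buffers. For each output element it decomposes the linear output index with the output strides and accumulates the input offset via in_strides[perm[i]], conjugating when requested. With ndims == 0 the inner loop runs zero times, so the single element at offset 0 is copied (and conjugated); with ndims == 1 the only valid perm is [0], which gives an elementwise copy.

```python
from typing import List, Sequence


def compute_strides(shape: Sequence[int]) -> List[int]:
    """Row-major strides: strides[i] is the product of shape[i + 1:]."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def simple_transpose(flat: Sequence[complex], shape: Sequence[int],
                     perm: Sequence[int], conjugate: bool = False) -> List[complex]:
    """Hypothetical rank-agnostic transpose of a row-major flat buffer.

    Works for any rank, including 0 (scalar) and 1 (vector).
    """
    ndims = len(shape)
    if sorted(perm) != list(range(ndims)):
        raise ValueError("perm must be a permutation of range(len(shape))")
    out_shape = [shape[p] for p in perm]
    in_strides = compute_strides(shape)
    out_strides = compute_strides(out_shape)

    num_elements = 1
    for dim in shape:
        num_elements *= dim  # an empty shape (rank 0) leaves this at 1

    out: List[complex] = [0j] * num_elements
    for o_idx in range(num_elements):
        # Decompose the output index with out_strides and map each
        # coordinate back to the input offset via in_strides[perm[i]].
        i_idx, rem = 0, o_idx
        for i in range(ndims):  # runs zero times when ndims == 0
            coord = rem // out_strides[i]
            rem -= coord * out_strides[i]
            i_idx += coord * in_strides[perm[i]]
        value = flat[i_idx]
        out[o_idx] = value.conjugate() if conjugate else value
    return out
```

For example, simple_transpose([1 + 2j], [], [], conjugate=True) returns [(1-2j)], which is the rank-0 complex case from #46891 handled without any rank check.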

This PR fixes #46891.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

@google-ml-butler google-ml-butler bot added the size:S CL Change Size: Small label Feb 6, 2021
@google-cla google-cla bot added the cla: yes label Feb 6, 2021
@gbaned gbaned self-assigned this Feb 7, 2021
@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Feb 7, 2021
@gbaned gbaned added the awaiting review Pull request awaiting review label Feb 16, 2021
@gbaned gbaned added the prtype:bugfix PR to fix a bug label Mar 1, 2021
PR Queue automation moved this from Assigned Reviewer to Reviewer Requested Changes Mar 27, 2021
@yongtang yongtang force-pushed the 46891-tf.transpose-conjugate branch from b28c1f3 to dc69ff1 on March 28, 2021 18:04
@yongtang (Member Author) commented:

@rohan100jain @cantonios thanks for the review. The PR has been updated. Please take a look.

@tensorflowbutler tensorflowbutler removed the awaiting review Pull request awaiting review label Mar 30, 2021
@cantonios (Contributor) left a comment

LGTM.

Ideally, those extra tests would be folded into the original testComplex64/testComplex128, which already test different tensor ranks. However, given the current v1 decorator and the need to update these tests for v2 anyway, this is probably fine for now.

@google-ml-butler google-ml-butler bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Mar 31, 2021
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Mar 31, 2021
@cantonios (Contributor) left a comment

@yongtang I updated the entire test to TF v2, so can you now merge your tests into testComplex64() and testComplex128()?

Something like this would work:

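    # Rank-0 (scalar) complex input. The builtin complex() replaces
    # np.complex, which was removed in NumPy 1.24.
    self._testBoth(np.array(complex(1, 2)).astype(np.complex128))
    # Rank-1 complex input with 21 elements.
    self._testBoth(complex(1, 2) * np.arange(0, 21).astype(np.complex128))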
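
The two suggested inputs are a rank-0 scalar and a rank-1 vector. For those ranks the identity is the only valid permutation, so the expected result of a conjugate transpose is simply the elementwise complex conjugate. The pure-Python sketch below makes that concrete; the helper names valid_perms and expected_conj_transpose_low_rank are hypothetical, and it does not claim to reproduce how _testBoth computes its reference values, which are not shown here.

```python
from itertools import permutations
from typing import List, Tuple


def valid_perms(ndims: int) -> List[Tuple[int, ...]]:
    """Every permutation a transpose could accept for a rank-`ndims` input."""
    return list(permutations(range(ndims)))


def expected_conj_transpose_low_rank(values: List[complex]) -> List[complex]:
    """For rank 0 or 1 the identity is the only valid perm, so a conjugate
    transpose reduces to elementwise complex conjugation."""
    return [v.conjugate() for v in values]


# Rank-0 input from the suggestion: the scalar complex(1, 2).
scalar_expected = expected_conj_transpose_low_rank([complex(1, 2)])

# Rank-1 input from the suggestion: complex(1, 2) * [0, 1, ..., 20].
vector_expected = expected_conj_transpose_low_rank(
    [complex(1, 2) * k for k in range(21)])
```

Note that valid_perms(0) is [()] and valid_perms(1) is [(0,)]: exactly one permutation each, which is why no transpose-specific rank check is needed for these inputs.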

This PR tries to address the issue raised in 46891, where
tf.transpose crashes when a is complex and conjugate is True.
The issue comes from:
https://github.com/tensorflow/tensorflow/blob/57bbc5e0d4b93483b8ae853352173516f1c08018/tensorflow/core/kernels/transpose_functor.h#L169

However, ndims < 2 is already handled properly by:
https://github.com/tensorflow/tensorflow/blob/57bbc5e0d4b93483b8ae853352173516f1c08018/tensorflow/core/kernels/transpose_functor_cpu.cc#L103-L105
so the check can be removed.

This PR fixes 46891.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
@yongtang yongtang force-pushed the 46891-tf.transpose-conjugate branch from dc69ff1 to 20245fc on April 5, 2021 19:40
@google-ml-butler google-ml-butler bot removed the ready to pull PR ready for merge process label Apr 5, 2021
@yongtang (Member Author) commented Apr 5, 2021

Thanks @cantonios for the help. The PR has been updated. Please take a look and let me know if there are any issues.

@gbaned gbaned requested a review from cantonios April 6, 2021 14:26
…omplex128()

with additional update to sort test cases in increasing rank order

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
@yongtang yongtang force-pushed the 46891-tf.transpose-conjugate branch from 20245fc to 6cbc31d on April 6, 2021 15:40
@yongtang (Member Author) commented Apr 6, 2021

Thanks @cantonios for the help. The PR has been updated. Please let me know if there are any additional issues.

@google-ml-butler google-ml-butler bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Apr 6, 2021
@copybara-service copybara-service bot merged commit 1dc6a7c into tensorflow:master Apr 7, 2021
PR Queue automation moved this from Reviewer Requested Changes to Merged Apr 7, 2021
@yongtang yongtang deleted the 46891-tf.transpose-conjugate branch April 7, 2021 16:49
Labels
cla: yes
kokoro:force-run (Tests on submitted change)
prtype:bugfix (PR to fix a bug)
ready to pull (PR ready for merge process)
size:S (CL Change Size: Small)
Projects
PR Queue: Merged
Development
Successfully merging this pull request may close these issues:
tf.transpose crashes(abort) if a is complex
6 participants