Improve shape function check for tf.roll #18611

Merged (11 commits) on Apr 18, 2018

yongtang
Member

The tf.roll op places requirements on the shapes of its inputs. However, those shapes are currently validated only at runtime, inside the kernel.

This fix improves the shape function so that the checks can run early whenever the shapes are already known in the shape function.

The following validations have been added in the shape function with test cases:

  • The input must be 1-D or higher.
  • The shift must be scalar or 1-D.
  • The axis must be scalar or 1-D.
  • The shift and axis must be the same size.

These match the validations in the kernel.
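
The four validations listed above can be sketched in plain Python. This is only an illustration of the logic, not the TensorFlow C++ shape function: the helper names `check_roll_shapes` and `_num_elements` are hypothetical, and shapes are modeled as tuples in which `None` stands for an unknown dimension (or, in place of the whole tuple, an unknown rank). A check is skipped whenever the information it needs is not statically known, which leaves that case to the kernel at runtime.

```python
from typing import Optional, Tuple

# Hypothetical shape model: None = unknown rank; a None entry = unknown dimension.
Shape = Optional[Tuple[Optional[int], ...]]


def _num_elements(shape: Shape) -> Optional[int]:
    """Element count of a scalar or 1-D shape, or None if it is not known."""
    if shape is None:
        return None          # unknown rank
    if len(shape) == 0:
        return 1             # a scalar holds exactly one element
    return shape[0]          # 1-D: the single dimension (may itself be None)


def check_roll_shapes(input_shape: Shape, shift_shape: Shape, axis_shape: Shape) -> Shape:
    """Raise ValueError if the statically known shapes violate the roll requirements."""
    # The input must be 1-D or higher.
    if input_shape is not None and len(input_shape) < 1:
        raise ValueError("input must be 1-D or higher")
    # The shift must be scalar or 1-D.
    if shift_shape is not None and len(shift_shape) > 1:
        raise ValueError("shift must be a scalar or a 1-D vector")
    # The axis must be scalar or 1-D.
    if axis_shape is not None and len(axis_shape) > 1:
        raise ValueError("axis must be a scalar or a 1-D vector")
    # The shift and axis must be the same size (only comparable when both are known).
    shift_size = _num_elements(shift_shape)
    axis_size = _num_elements(axis_shape)
    if shift_size is not None and axis_size is not None and shift_size != axis_size:
        raise ValueError("shift and axis must have the same size")
    # Rolling does not change the shape, so the output shape equals the input shape.
    return input_shape
```

For example, `check_roll_shapes((), (), ())` raises because the input is a scalar, while `check_roll_shapes(None, None, None)` passes through unchanged because nothing is known yet.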

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>


@yzhwang left a comment

Thanks for the change. LGTM!

  def testRollInputMustVectorHigherRaises(self):
    tensor = 7
    # The input should be 1-D or higher, checked is done in kernel.

Minor nit: just for consistency, maybe change this to `checked in kernel.`

Member Author

@yzhwang The comment has been updated. Thanks for the review!

`checked is done in kernel.` -> `checked in kernel.` for review feedback.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
@martinwicke merged commit 3bc595d into tensorflow:master on Apr 18, 2018
@yongtang deleted the 04162018-roll-shape branch on April 18, 2018 22:21