Improve shape function check for tf.roll
#18611
The `tf.roll` op has requirements on the shapes of its inputs. However, these shapes are only checked at runtime, inside the kernel. This fix improves the shape function so that the checks can be done early whenever the shapes are already known to the shape function.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
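The shape requirements follow from how `tf.roll` pairs `shift[i]` with `axis[i]`. As an illustration only, NumPy's analogous `np.roll` shows the same pairing and the summing of shifts on a repeated axis. One caveat: NumPy broadcasts a scalar `shift` across a tuple `axis`, whereas the `tf.roll` kernel instead requires `shift` and `axis` to have the same size.

```python
import numpy as np

x = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# shift[i] is applied along axis[i], so the two sequences pair up element by element.
both = np.roll(x, shift=(1, 1), axis=(0, 1))
print(both.tolist())  # [[5, 3, 4], [2, 0, 1]]

# A repeated axis accumulates its shifts: two shifts of 1 on axis 1 equal one shift of 2.
repeated = np.roll(x, shift=(1, 1), axis=(1, 1))
print(repeated.tolist())  # [[1, 2, 0], [4, 5, 3]]
```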
Branch updated from `7c09fb6` to `4e6c516`.
Thanks for the change. LGTM!
```python
def testRollInputMustVectorHigherRaises(self):
    tensor = 7
    # The input should be 1-D or higher, checked is done in kernel.
```
Minor nit: for consistency, maybe change this to `checked in kernel`.
@yzhwang The comment has been updated. Thanks for the review!
Change `checked is done in kernel.` to `checked in kernel.` per review feedback.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
The `tf.roll` op has requirements on the shapes of its inputs. However, these shapes are only checked at runtime, inside the kernel. This fix improves the shape function so that the checks can be done early whenever the shapes are already known to the shape function.

The following validations have been added to the shape function, with test cases:

- `input` must be 1-D or higher.
- `shift` must be a scalar or 1-D.
- `axis` must be a scalar or 1-D.
- `shift` and `axis` must have the same size.

These match the validations in the kernel.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>