Modularize BatchNorm in cpu and webgl backend #3960
lina128
left a comment
Hi Jing, for BatchNorm, I think the kernel should expect a 4D input tensor, to align with the TF API: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fused-batch-norm In this case, we transform the tensor to 4D in the op and pass in the transformed tensor, so in the kernel we can assume x is 4D. Other than that, LGTM.
Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan, @jinjingforever, @lina128, and @tafsiri)
tfjs-backend-cpu/src/kernels/BatchNorm.ts, line 87 at r1 (raw file):
} } return backend.makeTensorInfo(get4DShapeFromX(x), x.dtype, outVals);
I think this can just be x.shape.
tfjs-backend-webgl/src/kernels/BatchNorm.ts, line 31 at r1 (raw file):
backend: MathBackendWebGL, attrs: FusedBatchNormAttrs }) => TensorInfo | TensorInfo[] = ({inputs, backend, attrs}) => {
Just TensorInfo.
tfjs-backend-webgl/src/kernels/BatchNorm.ts, line 53 at r1 (raw file):
const $x: TensorInfo = xAs4D(x, backend); const $mean = as1DOr4D(mean, backend);
These transformations will be done at the op level, because the TF API specifies the kernel input as 4D, so in the kernel we can assume these have the right shape.
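To illustrate the split described above, here is a minimal sketch in plain TypeScript with no tfjs imports. `TensorInfoLike`, `toShape4D`, `batchNormKernelSketch`, and `batchNormOpSketch` are hypothetical names invented for this example, and the padding rule (leading 1s) is an assumption about how a relaxed op input might be brought to 4D, not a quote of the actual op code.

```typescript
// Hypothetical stand-in for the tensor metadata a kernel receives.
interface TensorInfoLike {
  shape: number[];
  dtype: string;
}

// Op-level helper (assumed rule): pad a rank 0-4 shape with leading 1s.
function toShape4D(shape: number[]): [number, number, number, number] {
  if (shape.length > 4) {
    throw new Error(`Expected rank <= 4, got rank ${shape.length}`);
  }
  const padded: number[] =
      [...new Array<number>(4 - shape.length).fill(1), ...shape];
  return [padded[0], padded[1], padded[2], padded[3]];
}

// Kernel sketch: it does no reshaping and simply relies on a 4D input.
// The explicit check here is only for demonstration.
function batchNormKernelSketch(x: TensorInfoLike): TensorInfoLike {
  if (x.shape.length !== 4) {
    throw new Error(`Kernel expects a 4D tensor, got rank ${x.shape.length}`);
  }
  return {shape: x.shape, dtype: x.dtype};
}

// Op sketch: accept the relaxed input, transform to 4D, call the kernel,
// then report the result with the caller's original shape.
function batchNormOpSketch(x: TensorInfoLike): TensorInfoLike {
  const x4D: TensorInfoLike = {...x, shape: toShape4D(x.shape)};
  const out = batchNormKernelSketch(x4D);
  return {...out, shape: x.shape};
}
```

With this split, a caller that goes through the op can pass a rank-1 or rank-2 tensor, while anything that calls the kernel directly has to supply a 4D tensor itself.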
jinjingforever
left a comment
Thanks! One question: I remember you mentioned that kernelfunc can also be called directly from other places (instead of going through op)? In those cases, would it be the caller's responsibility to transform the inputs?
Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan, @jinjingforever, and @tafsiri)
jinjingforever
left a comment
Thanks Na! Removed reshape related code.
Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan, @lina128, and @tafsiri)
tfjs-backend-cpu/src/kernels/BatchNorm.ts, line 87 at r1 (raw file):
Previously, lina128 (Na Li) wrote…
I think this can just be x.shape.
Done.
tfjs-backend-webgl/src/kernels/BatchNorm.ts, line 31 at r1 (raw file):
Previously, lina128 (Na Li) wrote…
Just TensorInfo.
Done
tfjs-backend-webgl/src/kernels/BatchNorm.ts, line 53 at r1 (raw file):
Previously, lina128 (Na Li) wrote…
These transformations will be done at the op level, because the TF API specifies the kernel input as 4D, so in the kernel we can assume these have the right shape.
Done
lina128
left a comment
Hi Jing, in general, an op should only contain validation; any transformation should happen in kernels. The exception is when the TF API specifies that a tensor must have a certain shape and our op has a more relaxed requirement: then we transform the tensor into the specified shape in the op and pass that to the kernel. So when in doubt, check the TF API. There is a risk that other kernels may call this kernel without being aware of this specification, because it is not guarded by code in any way, only by documentation. We'll see whether we encounter this case in actual use. If so, we can make the type definition in kernel_names more specific. For example, right now it just extends NamedTensorInfo; we can change it to be less generic, like 4DTensorInfo. I also shared two docs with you today, which document some of these decisions made during the modularization.
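One way the "less generic" type definition mentioned above could look is sketched here in plain TypeScript. All names (`TensorInfoLike`, `TensorInfo4D`, `BatchNormInputsSketch`, `isTensorInfo4D`, `assert4D`) are hypothetical and not taken from kernel_names; the identifier is written `TensorInfo4D` because a TypeScript identifier cannot start with a digit.

```typescript
// Hypothetical stand-in for the generic tensor metadata type.
interface TensorInfoLike {
  dataId: object;
  shape: number[];
  dtype: string;
}

// Narrowed variant: the shape is a 4-tuple, so the rank is part of the type.
interface TensorInfo4D extends TensorInfoLike {
  shape: [number, number, number, number];
}

// Hypothetical inputs type in which x is guarded by the narrower type.
interface BatchNormInputsSketch {
  x: TensorInfo4D;
  mean: TensorInfoLike;
  variance: TensorInfoLike;
}

// Runtime type guard for callers that only hold a generic TensorInfoLike.
function isTensorInfo4D(t: TensorInfoLike): t is TensorInfo4D {
  return t.shape.length === 4;
}

// Convenience wrapper that fails loudly instead of silently mis-shaping.
function assert4D(t: TensorInfoLike, name: string): TensorInfo4D {
  if (!isTensorInfo4D(t)) {
    throw new Error(`${name} must be 4D, got rank ${t.shape.length}`);
  }
  return t;
}
```

Under this assumption, a kernel that calls the BatchNorm kernel directly would get a compile-time error when passing a generic tensor as `x`, and would have to go through something like `assert4D` first.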
Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan, @lina128, and @tafsiri)
lina128
left a comment
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @annxingyuan, @lina128, and @tafsiri)
For cpu backend, looks like I don't really need to do all the reshaping since the algorithm can operate on the underlying data directly. I only need to calculate (with conversion) the output shape. Please correct me if I am wrong.
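A rough sketch of why no reshape is needed on the cpu side, assuming row-major data and that `mean`, `variance`, `offset`, and `scale` either match the last (channel) dimension of x or have its full size. The function name `batchNormFlat` and its signature are hypothetical, chosen for this example only.

```typescript
// Batch normalization over the flat underlying buffer of x.
// Each parameter buffer is indexed cyclically, so a per-channel vector
// lines up with the fastest-varying (last) dimension of row-major data.
function batchNormFlat(
    xVals: Float32Array, meanVals: Float32Array, varianceVals: Float32Array,
    varianceEpsilon: number,
    offsetVals: Float32Array = new Float32Array([0]),
    scaleVals: Float32Array = new Float32Array([1])): Float32Array {
  const outVals = new Float32Array(xVals.length);
  for (let i = 0; i < xVals.length; ++i) {
    const mean = meanVals[i % meanVals.length];
    const variance = varianceVals[i % varianceVals.length];
    const offset = offsetVals[i % offsetVals.length];
    const scale = scaleVals[i % scaleVals.length];
    // out = offset + (x - mean) * scale / sqrt(variance + epsilon)
    outVals[i] =
        offset + (xVals[i] - mean) * scale / Math.sqrt(variance + varianceEpsilon);
  }
  return outVals;
}
```

Because the loop never looks at the rank of x, the output buffer can be paired with the input's shape directly, which matches the review note that the output shape can just be x.shape.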
For webgl backend, I used the Reshape kernelfunc.
Thanks!