-
Notifications
You must be signed in to change notification settings - Fork 2k
[core] Modularize batchnorm. #2996
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
tafsiri
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! LGTM with three very tiny comments.
Reviewed 13 of 13 files at r1.
Reviewable status:complete! 1 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, @lina128, and @tafsiri)
tfjs-core/src/kernel_names.ts, line 24 at r1 (raw file):
import {PixelData} from './types'; export const BatchNormalization = 'BatchNormalization';
I believe your observation about this mapping onto FusedBatchNorm is correct so let's use that as the kernel name. In the converter all versions of FusedBatchNorm mapped onto this function https://cs.opensource.google/tensorflow/tfjs/+/master:tfjs-converter/src/operations/executors/normalization_executor.ts;l=31?q=batchnorm&ss=tensorflow%2Ftfjs:tfjs-converter%2F
tfjs-core/src/ops/batchnorm.ts, line 102 at r1 (raw file):
'equal ranks.'); const x4D: Tensor4D = xAs4D($x);
Let's move this into the forward function.
tfjs-core/src/ops/batchnorm4d.ts, line 93 at r1 (raw file):
} export const batchNormalization4d = op({batchNormalization4d_});
Could you add a //todo(yassogba) to remove these later since the batchNormalization* form is already deprecated.
Thank you for the review, Yannick! Done. Also changed all the ops to be called directly instead of chained through Tensor in grad. |
|
Great! transforming away from chained ops is great to do when we can. |
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is