-
Notifications
You must be signed in to change notification settings - Fork 2k
[wasm] Add batchnorm kernel. #2264
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
dsmilkov
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan)
tfjs-core/src/ops/batchnorm.ts, line 359 at r1 (raw file):
}, {x: $x, mean: $mean, variance: $variance, scale: $scale, offset: $offset}, der, 'BatchNormalization', {varianceEpsilon});
in order to have the gradient work with the new modular kernel, pass an array of inputs [$x, $mean, $variance, $scale] at an additional argument to runKernelFunc, telling the engine which inputsToSave. See engine.runKernelFunc signature for details.
nsthorat
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 6 of 7 files at r2.
Reviewable status:complete! 1 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, and @nsthorat)
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 20 at r2 (raw file):
#include <math.h> #include <vector>
do you need vector?
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 49 at r2 (raw file):
float* out_buf = out_info.buf.f32; int offi = 0;
can you make these longer names, e.g. offset_i, mean_i etc?
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 81 at r2 (raw file):
out_buf[i] = offset_buf[offi] + (x_buf[i] - mean_buf[mi]) * scale_buf[si] / sqrt(variance_buf[vi] + variance_epsilon);
we might want to precompute the denominator here because x_size usually is much larger than the size of the variance tensor (can you double check me on that for a real model)?
tfjs-backend-wasm/src/kernels/BatchNormalization.ts, line 54 at r2 (raw file):
const meanId = backend.dataIdMap.get(mean.dataId).id; const varianceId = backend.dataIdMap.get(variance.dataId).id; const offsetId = offset ? backend.dataIdMap.get(offset.dataId).id : -1;
offset != null
scale != null
|
(wait for daniel LGTM before submitting) |
dsmilkov
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Couple comments, but nothing blocking . Nice work!!
Reviewed 2 of 7 files at r2.
Reviewable status:complete! 2 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, and @nsthorat)
tfjs-backend-wasm/src/cc/BUILD, line 67 at r2 (raw file):
":Mul", ":Prelu", ":BatchNormalization",
rename to "FusedBatchNorm" to follow the kernel name.
tfjs-backend-wasm/src/cc/BUILD, line 75 at r2 (raw file):
tfjs_cc_library( name = "BatchNormalization", srcs = ["kernels/BatchNormalization.cc"],
rename the file to FusedBatchNorm.cc
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 20 at r2 (raw file):
#include <math.h> #include <vector>
remove unused import "vector"
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 23 at r2 (raw file):
#include "src/cc/backend.h" #include "src/cc/util.h"
remove unused util.h
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 83 at r2 (raw file):
sqrt(variance_buf[vi] + variance_epsilon); offi = offi + 1;
use ++var instead of var + 1 - fortunately we are not dealing with shaders here :)
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 103 at r2 (raw file):
} } // namespace wasm
nit : extern "C" ?
tfjs-backend-wasm/src/kernels/BatchNormalization.ts, line 71 at r2 (raw file):
registerKernel({ kernelName: 'BatchNormalization',
FusedBatchNorm
annxingyuan
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status:
complete! 2 of 1 approvals obtained
tfjs-backend-wasm/src/cc/BUILD, line 67 at r2 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
rename to "FusedBatchNorm" to follow the kernel name.
Done
tfjs-backend-wasm/src/cc/BUILD, line 75 at r2 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
rename the file to FusedBatchNorm.cc
Done
tfjs-core/src/ops/batchnorm.ts, line 359 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
in order to have the gradient work with the new modular kernel, pass an array of inputs
[$x, $mean, $variance, $scale]at an additional argument torunKernelFunc, telling the engine whichinputsToSave. Seeengine.runKernelFuncsignature for details.
Done
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 20 at r2 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
remove unused import "vector"
Done
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 20 at r2 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
do you need vector?
Done
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 23 at r2 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
remove unused util.h
Done
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 49 at r2 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
can you make these longer names, e.g. offset_i, mean_i etc?
Done
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 81 at r2 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
we might want to precompute the denominator here because x_size usually is much larger than the size of the variance tensor (can you double check me on that for a real model)?
Done
For faceMesh variance is the size of the final dimension of x.
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 83 at r2 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
use
++varinstead ofvar + 1- fortunately we are not dealing with shaders here :)
Done
tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc, line 103 at r2 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
nit : extern "C" ?
Done
tfjs-backend-wasm/src/kernels/BatchNormalization.ts, line 54 at r2 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
offset != null
scale != null
Done
tfjs-backend-wasm/src/kernels/BatchNormalization.ts, line 71 at r2 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
FusedBatchNorm
Done
dsmilkov
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 4 of 4 files at r3.
Reviewable status:complete! 2 of 1 approvals obtained
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is