diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 7b5062eff97..14266d61d56 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -64,11 +64,21 @@ tfjs_cc_library( ":Div", ":Mul", ":Prelu", + ":FusedBatchNorm", ":Sigmoid", ":Sub", ] ) +tfjs_cc_library( + name = "FusedBatchNorm", + srcs = ["kernels/FusedBatchNorm.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "Abs", srcs = ["kernels/Abs.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc new file mode 100644 index 00000000000..d982d0314c2 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc @@ -0,0 +1,108 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include +#include "src/cc/backend.h" + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. 
+extern "C" {
+
+// Applies batch normalization element-wise over float32 buffers.
+//
+// For every index i in [0, x_size) this writes
+//
+//   out[i] = offset + (x[i] - mean) * scale / sqrt(variance + variance_epsilon)
+//
+// The mean, variance, scale and offset buffers are each read cyclically: every
+// buffer has its own cursor that advances by one per element of x and restarts
+// at 0 once it reaches that buffer's size, so a shorter buffer is reused
+// repeatedly across x.
+//
+// Parameters:
+//   x_id, mean_id, variance_id - ids passed to backend::get_tensor_info to
+//       look up the input, mean and variance tensors.
+//   offset_id - id of the offset tensor; a negative value selects a constant
+//       offset of 0.
+//   scale_id - id of the scale tensor; a negative value selects a constant
+//       scale of 1.
+//   variance_epsilon - added to each variance value before the square root.
+//   out_id - id of the tensor whose buffer receives the result.
+//
+// NOTE(review): no size checks are performed. This assumes the output buffer
+// holds at least x_size floats and that mean and variance are non-empty
+// whenever x is non-empty - confirm that callers guarantee this.
+#ifdef __EMSCRIPTEN__
+EMSCRIPTEN_KEEPALIVE
+#endif
+void FusedBatchNorm(int x_id, int mean_id, int variance_id, int offset_id,
+                    int scale_id, float variance_epsilon, int out_id) {
+  const auto x_info = backend::get_tensor_info(x_id);
+  const auto mean_info = backend::get_tensor_info(mean_id);
+  const auto variance_info = backend::get_tensor_info(variance_id);
+  const auto out_info = backend::get_tensor_info(out_id);
+
+  float* x_buf = x_info.buf.f32;
+  int x_size = x_info.size;
+  float* mean_buf = mean_info.buf.f32;
+  int mean_size = mean_info.size;
+  float* variance_buf = variance_info.buf.f32;
+  int variance_size = variance_info.size;
+
+  float* out_buf = out_info.buf.f32;
+
+  // Independent read cursors, one per parameter buffer.
+  int offset_i = 0;
+  int mean_i = 0;
+  int scale_i = 0;
+  int variance_i = 0;
+
+  // A missing scale (negative id) is modelled as a one-element buffer holding
+  // 1, so the main loop needs no special case for it.
+  float scale_buf_default[1] = {1};
+  float* scale_buf;
+  int scale_size;
+  if (scale_id < 0) {
+    scale_buf = scale_buf_default;
+    scale_size = 1;
+  } else {
+    const auto scale_info = backend::get_tensor_info(scale_id);
+    scale_buf = scale_info.buf.f32;
+    scale_size = scale_info.size;
+  }
+
+  // Likewise, a missing offset (negative id) becomes a one-element buffer
+  // holding 0.
+  float offset_buf_default[1] = {0};
+  float* offset_buf;
+  int offset_size;
+  if (offset_id < 0) {
+    offset_buf = offset_buf_default;
+    offset_size = 1;
+  } else {
+    const auto offset_info = backend::get_tensor_info(offset_id);
+    offset_buf = offset_info.buf.f32;
+    offset_size = offset_info.size;
+  }
+
+  // Compute sqrt(variance + epsilon) once per variance entry instead of once
+  // per element of x. The element type must be spelled out: it cannot be
+  // deduced from a size-only constructor call.
+  std::vector<float> normalization_factor(variance_size);
+  for (int i = 0; i < variance_size; ++i) {
+    normalization_factor[i] = sqrt(variance_buf[i] + variance_epsilon);
+  }
+
+  for (int i = 0; i < x_size; ++i) {
+    out_buf[i] = offset_buf[offset_i] + (x_buf[i] - mean_buf[mean_i]) *
+                                            scale_buf[scale_i] /
+                                            normalization_factor[variance_i];
+
+    ++offset_i;
+    ++mean_i;
+    ++scale_i;
+    ++variance_i;
+
+    // Wrap each cursor back to the start of its buffer.
+    if (offset_i >= offset_size) {
+      offset_i = 0;
+    }
+    if (mean_i >= mean_size) {
+      mean_i = 0;
+    }
+    if (scale_i >= scale_size) {
+      scale_i = 0;
+    }
+    if (variance_i >= variance_size) {
+      variance_i = 0;
+    }
+  }
+}
+
+}
// extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts b/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts new file mode 100644 index 00000000000..fd055b9d4d4 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts @@ -0,0 +1,76 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface BatchNormInputs extends NamedTensorInfoMap { + x: TensorInfo; + mean: TensorInfo; + variance: TensorInfo; + offset: TensorInfo; + scale: TensorInfo; +} + +interface BatchNormAttrs extends NamedAttrMap { + varianceEpsilon: number; +} + +let wasmBatchNorm: ( + xId: number, meanId: number, varianceId: number, offsetId: number, + scaleId: number, varianceEpsilon: number, outId: number) => void; + +function setup(backend: BackendWasm): void { + wasmBatchNorm = backend.wasm.cwrap( + 'FusedBatchNorm', null /* void */, + ['number', 'number', 'number', 'number', 'number', 'number', 'number']); +} + +function fusedBatchNorm( + args: + {backend: BackendWasm, inputs: BatchNormInputs, attrs: BatchNormAttrs}): + TensorInfo { + const {backend, inputs, attrs} = args; + const {varianceEpsilon} = attrs; + const 
{x, mean, variance, offset, scale} = inputs; + const xId = backend.dataIdMap.get(x.dataId).id; + const meanId = backend.dataIdMap.get(mean.dataId).id; + const varianceId = backend.dataIdMap.get(variance.dataId).id; + const offsetId = + offset != null ? backend.dataIdMap.get(offset.dataId).id : -1; + const scaleId = scale != null ? backend.dataIdMap.get(scale.dataId).id : -1; + + const out = backend.makeOutput(x.shape, x.dtype); + // Short-circuit zero-sized tensors. + if (util.sizeFromShape(x.shape) === 0) { + return out; + } + + const outId = backend.dataIdMap.get(out.dataId).id; + + wasmBatchNorm( + xId, meanId, varianceId, offsetId, scaleId, varianceEpsilon, outId); + return out; +} + +registerKernel({ + kernelName: 'BatchNormalization', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: fusedBatchNorm +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 2cb5b5f1454..193d3ccfefe 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -21,6 +21,7 @@ import './Abs'; import './Add'; import './BatchMatMul'; +import './FusedBatchNorm'; import './Cast'; import './Div'; import './Mul'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index d738c545d77..2acd1e46de8 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -27,7 +27,7 @@ const grepFilter = env.specFilter; /** Tests that have these substrings in their name will be included. */ const INCLUDE_LIST: string[] = [ 'add ', 'matmul ', 'prelu ', ' cast', 'sigmoid', 'abs ', 'sub ', 'mul ', - 'div ', 'slice ', 'square ' + 'div ', 'batchNorm', 'slice ', 'square ' ]; /** Tests that have these substrings in their name will be excluded. 
*/ const EXCLUDE_LIST: string[] = [ diff --git a/tfjs-core/src/ops/batchnorm.ts b/tfjs-core/src/ops/batchnorm.ts index 805db94f96c..b61e7c9dc7a 100644 --- a/tfjs-core/src/ops/batchnorm.ts +++ b/tfjs-core/src/ops/batchnorm.ts @@ -338,22 +338,27 @@ function batchNorm_( return offsetDer.reshape($mean.shape as ShapeMap[R]); }; return { - $x: derX, - $mean: derMean, - $variance: derVariance, - $scale: derScale, - $offset: derOffset + x: derX, + mean: derMean, + variance: derVariance, + scale: derScale, + offset: derOffset }; }; - const res = ENGINE.runKernelFunc((backend, save) => { - const res = backend.batchNormalization( - x4D, batchnormReshape4D($mean), batchnormReshape4D($variance), - varianceEpsilon, batchnormReshape4D($scale), - batchnormReshape4D($offset)); - save([$x, $mean, $variance, $scale]); - return res; - }, {$x, $mean, $variance, $scale, $offset}, der); + const inputsToSave = [$x, $mean, $variance, $scale]; + + const res = ENGINE.runKernelFunc( + (backend, save) => { + const res = backend.batchNormalization( + x4D, batchnormReshape4D($mean), batchnormReshape4D($variance), + varianceEpsilon, batchnormReshape4D($scale), + batchnormReshape4D($offset)); + save([$x, $mean, $variance, $scale]); + return res; + }, + {x: $x, mean: $mean, variance: $variance, scale: $scale, offset: $offset}, + der, 'BatchNormalization', {varianceEpsilon}, inputsToSave); return res.reshape($x.shape); }