From f6ea12060224e8f1844db9abfb0297f8a30861fa Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 24 Oct 2019 07:56:59 -0400 Subject: [PATCH 01/21] running --- tfjs-backend-wasm/src/cc/BUILD | 10 +++ .../src/cc/kernels/BatchNormalization.cc | 72 +++++++++++++++++++ tfjs-backend-wasm/src/index_test.ts | 20 ++++++ .../src/kernels/BatchNormalization.ts | 70 ++++++++++++++++++ tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + tfjs-core/src/ops/batchnorm.ts | 41 +++++++---- 6 files changed, 201 insertions(+), 13 deletions(-) create mode 100644 tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc create mode 100644 tfjs-backend-wasm/src/kernels/BatchNormalization.ts diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index ba0de7b1c84..19cbcbb9621 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -45,9 +45,19 @@ tfjs_cc_library( ":Add", ":BatchMatMul", ":Prelu", + ":BatchNormalization" ] ) +tfjs_cc_library( + name = "BatchNormalization", + srcs = ["kernels/BatchNormalization.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "Add", srcs = ["kernels/Add.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc new file mode 100644 index 00000000000..cb96b5a7903 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc @@ -0,0 +1,72 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +template +void batch_norm_impl(T* x_buf, int x_size, T* mean_buf, int mean_size, + T* variance_buf, int variance_size, T* offset_buf, + int offset_size, T* scale_buf, int scale_size, + T* variance_epsilon_buf, int variance_epsilon_size, + T* out_buf) { + int size = x_size; + for (int i = 0; i < size; ++i) { + out_buf[i] = x_buf[i]; + } +} +// Templates need explicit instantiation when implemented in a .cc file. +template void batch_norm_impl(float* x_buf, int x_size, float* mean_buf, + int mean_size, float* variance_buf, + int variance_size, float* offset_buf, + int offset_size, float* scale_buf, + int scale_size, + float* variance_epsilon_buf, + int variance_epsilon_size, float* out_buf); + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void BatchNormalization(int x_id, int mean_id, int variance_id, int offset_id, + int scale_id, int variance_epsilon_id, int out_id) { + const auto x_info = backend::get_tensor_info(x_id); + const auto mean_info = backend::get_tensor_info(mean_id); + const auto variance_info = backend::get_tensor_info(variance_id); + const auto offset_info = backend::get_tensor_info(offset_id); + const auto scale_info = backend::get_tensor_info(scale_id); + const auto variance_epsilon_info = + backend::get_tensor_info(variance_epsilon_id); + const auto out_info = backend::get_tensor_info(out_id); + batch_norm_impl(x_info.buf.f32, x_info.size, mean_info.buf.f32, + mean_info.size, variance_info.buf.f32, variance_info.size, + offset_info.buf.f32, offset_info.size, scale_info.buf.f32, + scale_info.size, variance_epsilon_info.buf.f32, + variance_epsilon_info.size, out_info.buf.f32); +} + +} // namespace wasm +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index b8a0093b88c..3c752f910fa 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -47,4 +47,24 @@ describeWithFlags('wasm', ALL_ENVS, () => { // This should fail in case of a memory leak. expect(memOffset1).toBe(memOffset2); }); + + fit('batchnorm test', () => { + console.log('testing'); + + const xT = tf.tensor4d([2, 4, 9, 23], [2, 1, 1, 2]); + const meanT = tf.tensor1d([1, 2]); + const varianceT = tf.tensor1d([2, 3]); + const varianceEpsilon = .001; + const scale = tf.tensor1d([4, 5]); + const offset = tf.tensor1d([4, 5]); + + const result = + tf.batchNorm4d(xT, meanT, varianceT, offset, scale, varianceEpsilon); + + // const x = await xT.array(); + // const mean = await meanT.array(); + // const variance = await varianceT.array(); + + console.log(Array.from(result.dataSync())); + }); }); diff --git a/tfjs-backend-wasm/src/kernels/BatchNormalization.ts b/tfjs-backend-wasm/src/kernels/BatchNormalization.ts new file mode 100644 index 00000000000..f7f49a87e3c --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/BatchNormalization.ts @@ -0,0 +1,70 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +interface BatchNormInputs extends NamedTensorInfoMap { + x: TensorInfo; + mean: TensorInfo; + variance: TensorInfo; + offset: TensorInfo; + scale: TensorInfo; + varianceEpsilon: TensorInfo; +} + +let wasmBatchNorm: ( + xId: number, meanId: number, varianceId: number, offsetId: number, + scaleId: number, varianceEpsilonId: number, outId: number) => void; + +function setup(backend: BackendWasm): void { + wasmBatchNorm = backend.wasm.cwrap( + 'BatchNormalization', null /* void */, + ['number', 'number', 'number', 'number', 'number', 'number', 'number']); +} + +function batchNormalization( + args: {backend: BackendWasm, inputs: BatchNormInputs}): TensorInfo { + const {backend, inputs} = args; + const {x, mean, variance, offset, scale, varianceEpsilon} = inputs; + const xId = backend.dataIdMap.get(x.dataId).id; + const meanId = backend.dataIdMap.get(mean.dataId).id; + const varianceId = backend.dataIdMap.get(variance.dataId).id; + const offsetId = backend.dataIdMap.get(offset.dataId).id; + const scaleId = backend.dataIdMap.get(scale.dataId).id; + const varianceEpsilonId = backend.dataIdMap.get(varianceEpsilon.dataId).id; + + const out = backend.makeOutput(x.shape, x.dtype); + // Short-circuit zero-sized tensors. + if (util.sizeFromShape(x.shape) === 0) { + return out; + } + + const outId = backend.dataIdMap.get(out.dataId).id; + + wasmBatchNorm( + xId, meanId, varianceId, offsetId, scaleId, varianceEpsilonId, outId); + return out; +} + +registerKernel({ + kernelName: 'BatchNormalization', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: batchNormalization +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 76a174c1501..057df838c10 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -20,6 +20,7 @@ // the contents of this file and import only the kernels that are needed. import './Add'; import './BatchMatMul'; +import './BatchNormalization'; import './Cast'; import './Prelu'; import './Reshape'; diff --git a/tfjs-core/src/ops/batchnorm.ts b/tfjs-core/src/ops/batchnorm.ts index 805db94f96c..687b3de0c0f 100644 --- a/tfjs-core/src/ops/batchnorm.ts +++ b/tfjs-core/src/ops/batchnorm.ts @@ -242,6 +242,8 @@ function batchNorm_( const $x = convertToTensor(x, 'x', 'batchNorm'); const $mean = convertToTensor(mean, 'mean', 'batchNorm'); const $variance = convertToTensor(variance, 'variance', 'batchNorm'); + const $varianceEpsilon = + convertToTensor(varianceEpsilon, 'varianceEpsilon', 'batchNorm'); let $scale: Tensor|Tensor1D; if (scale != null) { $scale = convertToTensor(scale, 'scale', 'batchNorm'); @@ -338,22 +340,35 @@ function batchNorm_( return offsetDer.reshape($mean.shape as ShapeMap[R]); }; return { - $x: derX, - $mean: derMean, - $variance: derVariance, - $scale: derScale, - $offset: derOffset + x: derX, + mean: derMean, + variance: derVariance, + scale: derScale, + offset: derOffset, + varianceEpsilon: () => { + return dy; + } }; }; - const res = ENGINE.runKernelFunc((backend, save) => { - const res = backend.batchNormalization( - x4D, batchnormReshape4D($mean), batchnormReshape4D($variance), - varianceEpsilon, batchnormReshape4D($scale), - batchnormReshape4D($offset)); - save([$x, $mean, $variance, $scale]); - return res; - }, {$x, $mean, $variance, $scale, $offset}, der); + const res = ENGINE.runKernelFunc( + (backend, save) => { + const res = backend.batchNormalization( + x4D, batchnormReshape4D($mean), batchnormReshape4D($variance), + varianceEpsilon, batchnormReshape4D($scale), + batchnormReshape4D($offset)); + save([$x, $mean, $variance, $scale]); + return res; + }, + { + x: $x, + mean: $mean, + variance: $variance, + scale: $scale, + offset: $offset, + varianceEpsilon: $varianceEpsilon + }, + der, 'BatchNormalization'); return res.reshape($x.shape); } From 3e3b68d26d4ac7831473eeae0d283a5b4b01388e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 24 Oct 2019 09:33:54 -0400 Subject: [PATCH 02/21] dont actually implement grad --- tfjs-core/src/ops/batchnorm.ts | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tfjs-core/src/ops/batchnorm.ts b/tfjs-core/src/ops/batchnorm.ts index 687b3de0c0f..79f2fc93b9a 100644 --- a/tfjs-core/src/ops/batchnorm.ts +++ b/tfjs-core/src/ops/batchnorm.ts @@ -344,10 +344,7 @@ function batchNorm_( mean: derMean, variance: derVariance, scale: derScale, - offset: derOffset, - varianceEpsilon: () => { - return dy; - } + offset: derOffset }; }; @@ -368,7 +365,15 @@ function batchNorm_( offset: $offset, varianceEpsilon: $varianceEpsilon }, - der, 'BatchNormalization'); + der as {} as (dy: Tensor, saved: Tensor[]) => { + x: () => Tensor, + mean: () => Tensor| Tensor1D, + variance: () => Tensor| Tensor1D, + scale: () => Tensor| Tensor1D, + offset: () => Tensor| Tensor1D, + varianceEpsilon: () => Tensor + }, + 'BatchNormalization'); return res.reshape($x.shape); } From 7f8449c48d87cddb122fb9abc142ae4594d972d7 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 24 Oct 2019 09:43:00 -0400 Subject: [PATCH 03/21] batchnorm logic --- .../src/cc/kernels/BatchNormalization.cc | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc index cb96b5a7903..05694799fdc 100644 --- a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc +++ b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc @@ -28,9 +28,33 @@ void batch_norm_impl(T* x_buf, int x_size, T* mean_buf, int mean_size, int offset_size, T* scale_buf, int scale_size, T* variance_epsilon_buf, int variance_epsilon_size, T* out_buf) { - int size = x_size; - for (int i = 0; i < size; ++i) { - out_buf[i] = x_buf[i]; + int offi = 0; + int mi = 0; + int si = 0; + int vi = 0; + + for (int i = 0; i < x_size; ++i) { + offi = offi + 1; + mi = mi + 1; + si = si + 1; + vi = vi + 1; + + out_buf[i] = + offset_buf[offi] + (x_buf[i] - mean_buf[mi]) * scale_buf[si] / + sqrt(variance_buf[vi] + variance_epsilon_buf[0]); + + if (offi >= offset_size) { + offi = 0; + } + if (mi >= mean_size) { + mi = 0; + } + if (si >= scale_size) { + si = 0; + } + if (vi >= variance_size) { + vi = 0; + } } } // Templates need explicit instantiation when implemented in a .cc file. From 20d515c836917c40dd3cb4b55edd217a829a76af Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 24 Oct 2019 12:17:01 -0400 Subject: [PATCH 04/21] pass ve as float --- .../src/cc/kernels/BatchNormalization.cc | 20 +++++-------- tfjs-backend-wasm/src/index_test.ts | 2 ++ .../src/kernels/BatchNormalization.ts | 29 ++++++++++++------- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc index 05694799fdc..4d9ecc8198a 100644 --- a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc +++ b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc @@ -26,8 +26,7 @@ template void batch_norm_impl(T* x_buf, int x_size, T* mean_buf, int mean_size, T* variance_buf, int variance_size, T* offset_buf, int offset_size, T* scale_buf, int scale_size, - T* variance_epsilon_buf, int variance_epsilon_size, - T* out_buf) { + float variance_epsilon, T* out_buf) { int offi = 0; int mi = 0; int si = 0; @@ -41,7 +40,7 @@ void batch_norm_impl(T* x_buf, int x_size, T* mean_buf, int mean_size, out_buf[i] = offset_buf[offi] + (x_buf[i] - mean_buf[mi]) * scale_buf[si] / - sqrt(variance_buf[vi] + variance_epsilon_buf[0]); + sqrt(variance_buf[vi] + variance_epsilon); if (offi >= offset_size) { offi = 0; @@ -62,9 +61,8 @@ template void batch_norm_impl(float* x_buf, int x_size, float* mean_buf, int mean_size, float* variance_buf, int variance_size, float* offset_buf, int offset_size, float* scale_buf, - int scale_size, - float* variance_epsilon_buf, - int variance_epsilon_size, float* out_buf); + int scale_size, float variance_epsilon, + float* out_buf); namespace tfjs { namespace wasm { @@ -74,21 +72,19 @@ extern "C" { #ifdef __EMSCRIPTEN__ EMSCRIPTEN_KEEPALIVE #endif -void BatchNormalization(int x_id, int mean_id, int variance_id, int offset_id, - int scale_id, int variance_epsilon_id, int out_id) { +void BatchNormalization(int x_id, int mean_id, int variance_id, int out_id, + int offset_id, int scale_id, float variance_epsilon) { const auto x_info = backend::get_tensor_info(x_id); const auto mean_info = backend::get_tensor_info(mean_id); const auto variance_info = backend::get_tensor_info(variance_id); const auto offset_info = backend::get_tensor_info(offset_id); const auto scale_info = backend::get_tensor_info(scale_id); - const auto variance_epsilon_info = - backend::get_tensor_info(variance_epsilon_id); const auto out_info = backend::get_tensor_info(out_id); + batch_norm_impl(x_info.buf.f32, x_info.size, mean_info.buf.f32, mean_info.size, variance_info.buf.f32, variance_info.size, offset_info.buf.f32, offset_info.size, scale_info.buf.f32, - scale_info.size, variance_epsilon_info.buf.f32, - variance_epsilon_info.size, out_info.buf.f32); + scale_info.size, variance_epsilon, out_info.buf.f32); } } // namespace wasm diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index 3c752f910fa..5f68a8686d4 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -60,6 +60,8 @@ describeWithFlags('wasm', ALL_ENVS, () => { const result = tf.batchNorm4d(xT, meanT, varianceT, offset, scale, varianceEpsilon); + // const result = + // tf.batchNorm4d(xT, meanT, varianceT, offset, scale, varianceEpsilon); // const x = await xT.array(); // const mean = await meanT.array(); diff --git a/tfjs-backend-wasm/src/kernels/BatchNormalization.ts b/tfjs-backend-wasm/src/kernels/BatchNormalization.ts index f7f49a87e3c..e02378ed6b2 100644 --- a/tfjs-backend-wasm/src/kernels/BatchNormalization.ts +++ b/tfjs-backend-wasm/src/kernels/BatchNormalization.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; @@ -25,12 +25,15 @@ interface BatchNormInputs extends NamedTensorInfoMap { variance: TensorInfo; offset: TensorInfo; scale: TensorInfo; - varianceEpsilon: TensorInfo; +} + +interface BatchNormAttrs extends NamedAttrMap { + varianceEpsilon: number; } let wasmBatchNorm: ( - xId: number, meanId: number, varianceId: number, offsetId: number, - scaleId: number, varianceEpsilonId: number, outId: number) => void; + xId: number, meanId: number, varianceId: number, outId: number, + offsetId: number, scaleId: number, varianceEpsilon: number) => void; function setup(backend: BackendWasm): void { wasmBatchNorm = backend.wasm.cwrap( @@ -39,15 +42,19 @@ function setup(backend: BackendWasm): void { } function batchNormalization( - args: {backend: BackendWasm, inputs: BatchNormInputs}): TensorInfo { - const {backend, inputs} = args; - const {x, mean, variance, offset, scale, varianceEpsilon} = inputs; + args: + {backend: BackendWasm, inputs: BatchNormInputs, attrs: BatchNormAttrs}): + TensorInfo { + console.log('IN BATCH NORMALIZATION'); + const {backend, inputs, attrs} = args; + const {varianceEpsilon} = attrs; + console.log(varianceEpsilon); + const {x, mean, variance, offset, scale} = inputs; const xId = backend.dataIdMap.get(x.dataId).id; const meanId = backend.dataIdMap.get(mean.dataId).id; const varianceId = backend.dataIdMap.get(variance.dataId).id; - const offsetId = backend.dataIdMap.get(offset.dataId).id; - const scaleId = backend.dataIdMap.get(scale.dataId).id; - const varianceEpsilonId = backend.dataIdMap.get(varianceEpsilon.dataId).id; + const offsetId = offset ? backend.dataIdMap.get(offset.dataId).id : null; + const scaleId = scale ? backend.dataIdMap.get(scale.dataId).id : null; const out = backend.makeOutput(x.shape, x.dtype); // Short-circuit zero-sized tensors. @@ -58,7 +65,7 @@ function batchNormalization( const outId = backend.dataIdMap.get(out.dataId).id; wasmBatchNorm( - xId, meanId, varianceId, offsetId, scaleId, varianceEpsilonId, outId); + xId, meanId, varianceId, outId, offsetId, scaleId, varianceEpsilon); return out; } From d201637a4768103cdbf7a2292106e0360ccb5b4f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 24 Oct 2019 12:17:09 -0400 Subject: [PATCH 05/21] pass ve as float --- tfjs-core/src/ops/batchnorm.ts | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/tfjs-core/src/ops/batchnorm.ts b/tfjs-core/src/ops/batchnorm.ts index 79f2fc93b9a..a0494bad8fd 100644 --- a/tfjs-core/src/ops/batchnorm.ts +++ b/tfjs-core/src/ops/batchnorm.ts @@ -242,8 +242,6 @@ function batchNorm_( const $x = convertToTensor(x, 'x', 'batchNorm'); const $mean = convertToTensor(mean, 'mean', 'batchNorm'); const $variance = convertToTensor(variance, 'variance', 'batchNorm'); - const $varianceEpsilon = - convertToTensor(varianceEpsilon, 'varianceEpsilon', 'batchNorm'); let $scale: Tensor|Tensor1D; if (scale != null) { $scale = convertToTensor(scale, 'scale', 'batchNorm'); @@ -357,23 +355,8 @@ function batchNorm_( save([$x, $mean, $variance, $scale]); return res; }, - { - x: $x, - mean: $mean, - variance: $variance, - scale: $scale, - offset: $offset, - varianceEpsilon: $varianceEpsilon - }, - der as {} as (dy: Tensor, saved: Tensor[]) => { - x: () => Tensor, - mean: () => Tensor| Tensor1D, - variance: () => Tensor| Tensor1D, - scale: () => Tensor| Tensor1D, - offset: () => Tensor| Tensor1D, - varianceEpsilon: () => Tensor - }, - 'BatchNormalization'); + {x: $x, mean: $mean, variance: $variance, scale: $scale, offset: $offset}, + der, 'BatchNormalization', {varianceEpsilon}); return res.reshape($x.shape); } From 7068073980211c1dec56fb92dd41036c0f38739c Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 25 Oct 2019 15:52:26 -0400 Subject: [PATCH 06/21] optional offset --- .../src/cc/kernels/BatchNormalization.cc | 16 +++++++++++----- .../src/kernels/BatchNormalization.ts | 6 ++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc index 4d9ecc8198a..9021b3891cb 100644 --- a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc +++ b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc @@ -32,16 +32,22 @@ void batch_norm_impl(T* x_buf, int x_size, T* mean_buf, int mean_size, int si = 0; int vi = 0; + if (offset_buf == nullptr) { + float offset_buf_replace[1] = {5}; + offset_buf = offset_buf_replace; + offset_size = 1; + } + for (int i = 0; i < x_size; ++i) { + out_buf[i] = + offset_buf[offi] + (x_buf[i] - mean_buf[mi]) * scale_buf[si] / + sqrt(variance_buf[vi] + variance_epsilon); + offi = offi + 1; mi = mi + 1; si = si + 1; vi = vi + 1; - out_buf[i] = - offset_buf[offi] + (x_buf[i] - mean_buf[mi]) * scale_buf[si] / - sqrt(variance_buf[vi] + variance_epsilon); - if (offi >= offset_size) { offi = 0; } @@ -77,9 +83,9 @@ void BatchNormalization(int x_id, int mean_id, int variance_id, int out_id, const auto x_info = backend::get_tensor_info(x_id); const auto mean_info = backend::get_tensor_info(mean_id); const auto variance_info = backend::get_tensor_info(variance_id); - const auto offset_info = backend::get_tensor_info(offset_id); const auto scale_info = backend::get_tensor_info(scale_id); const auto out_info = backend::get_tensor_info(out_id); + const auto offset_info = backend::get_tensor_info(offset_id); batch_norm_impl(x_info.buf.f32, x_info.size, mean_info.buf.f32, mean_info.size, variance_info.buf.f32, variance_info.size, diff --git a/tfjs-backend-wasm/src/kernels/BatchNormalization.ts b/tfjs-backend-wasm/src/kernels/BatchNormalization.ts index e02378ed6b2..6c5a5954289 100644 --- a/tfjs-backend-wasm/src/kernels/BatchNormalization.ts +++ b/tfjs-backend-wasm/src/kernels/BatchNormalization.ts @@ -45,16 +45,14 @@ function batchNormalization( args: {backend: BackendWasm, inputs: BatchNormInputs, attrs: BatchNormAttrs}): TensorInfo { - console.log('IN BATCH NORMALIZATION'); const {backend, inputs, attrs} = args; const {varianceEpsilon} = attrs; - console.log(varianceEpsilon); const {x, mean, variance, offset, scale} = inputs; const xId = backend.dataIdMap.get(x.dataId).id; const meanId = backend.dataIdMap.get(mean.dataId).id; const varianceId = backend.dataIdMap.get(variance.dataId).id; - const offsetId = offset ? backend.dataIdMap.get(offset.dataId).id : null; - const scaleId = scale ? backend.dataIdMap.get(scale.dataId).id : null; + const offsetId = offset ? backend.dataIdMap.get(offset.dataId).id : -1; + const scaleId = scale ? backend.dataIdMap.get(scale.dataId).id : -1; const out = backend.makeOutput(x.shape, x.dtype); // Short-circuit zero-sized tensors. From 9b9e476b45d358ea32ce81e661413d788f15ac9f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 25 Oct 2019 15:53:44 -0400 Subject: [PATCH 07/21] revive argument order --- tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc | 4 ++-- tfjs-backend-wasm/src/kernels/BatchNormalization.ts | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc index 9021b3891cb..66027f12eb5 100644 --- a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc +++ b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc @@ -78,8 +78,8 @@ extern "C" { #ifdef __EMSCRIPTEN__ EMSCRIPTEN_KEEPALIVE #endif -void BatchNormalization(int x_id, int mean_id, int variance_id, int out_id, - int offset_id, int scale_id, float variance_epsilon) { +void BatchNormalization(int x_id, int mean_id, int variance_id, int offset_id, + int scale_id, float variance_epsilon, int out_id) { const auto x_info = backend::get_tensor_info(x_id); const auto mean_info = backend::get_tensor_info(mean_id); const auto variance_info = backend::get_tensor_info(variance_id); diff --git a/tfjs-backend-wasm/src/kernels/BatchNormalization.ts b/tfjs-backend-wasm/src/kernels/BatchNormalization.ts index 6c5a5954289..1c80e065267 100644 --- a/tfjs-backend-wasm/src/kernels/BatchNormalization.ts +++ b/tfjs-backend-wasm/src/kernels/BatchNormalization.ts @@ -32,8 +32,8 @@ interface BatchNormAttrs extends NamedAttrMap { } let wasmBatchNorm: ( - xId: number, meanId: number, varianceId: number, outId: number, - offsetId: number, scaleId: number, varianceEpsilon: number) => void; + xId: number, meanId: number, varianceId: number, offsetId: number, + scaleId: number, varianceEpsilon: number, outId: number) => void; function setup(backend: BackendWasm): void { wasmBatchNorm = backend.wasm.cwrap( @@ -63,7 +63,7 @@ function batchNormalization( const outId = backend.dataIdMap.get(out.dataId).id; wasmBatchNorm( - xId, meanId, varianceId, outId, offsetId, scaleId, varianceEpsilon); + xId, meanId, varianceId, offsetId, scaleId, varianceEpsilon, outId); return out; } From e20b3ca79aa4ef67900aa63ef79767462413a80e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 25 Oct 2019 15:54:34 -0400 Subject: [PATCH 08/21] scale is also optional --- tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc index 66027f12eb5..eefdef68b15 100644 --- a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc +++ b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc @@ -33,11 +33,17 @@ void batch_norm_impl(T* x_buf, int x_size, T* mean_buf, int mean_size, int vi = 0; if (offset_buf == nullptr) { - float offset_buf_replace[1] = {5}; - offset_buf = offset_buf_replace; + float offset_buf_default[1] = {0}; + offset_buf = offset_buf_default; offset_size = 1; } + if (scale_buf == nullptr) { + float scale_buf_default[1] = {1}; + scale_buf = scale_buf_default; + scale_size = 1; + } + for (int i = 0; i < x_size; ++i) { out_buf[i] = offset_buf[offi] + (x_buf[i] - mean_buf[mi]) * scale_buf[si] / From bacb02e9cdb99a6112e7abb3a36ad24109e10338 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Fri, 25 Oct 2019 15:59:10 -0400 Subject: [PATCH 09/21] reset tests --- tfjs-backend-wasm/src/index_test.ts | 25 ++++++++++++------------- tfjs-backend-wasm/src/setup_test.ts | 2 +- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index 5f68a8686d4..d7305e73406 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -48,25 +48,24 @@ describeWithFlags('wasm', ALL_ENVS, () => { expect(memOffset1).toBe(memOffset2); }); - fit('batchnorm test', () => { - console.log('testing'); - + fit('simple batchnorm4D, no offset or scale, 2x1x1x2', async () => { const xT = tf.tensor4d([2, 4, 9, 23], [2, 1, 1, 2]); const meanT = tf.tensor1d([1, 2]); const varianceT = tf.tensor1d([2, 3]); const varianceEpsilon = .001; - const scale = tf.tensor1d([4, 5]); - const offset = tf.tensor1d([4, 5]); - - const result = - tf.batchNorm4d(xT, meanT, varianceT, offset, scale, varianceEpsilon); - // const result = - // tf.batchNorm4d(xT, meanT, varianceT, offset, scale, varianceEpsilon); - // const x = await xT.array(); - // const mean = await meanT.array(); - // const variance = await varianceT.array(); + const result = tf.batchNorm4d( + xT, meanT, varianceT, undefined, undefined, varianceEpsilon); + const x = await xT.array(); + const mean = await meanT.array(); + const variance = await varianceT.array(); console.log(Array.from(result.dataSync())); + console.log([ + (x[0][0][0][0] - mean[0]) * 1 / Math.sqrt(variance[0] + varianceEpsilon), + (x[0][0][0][1] - mean[1]) * 1 / Math.sqrt(variance[1] + varianceEpsilon), + (x[1][0][0][0] - mean[0]) * 1 / Math.sqrt(variance[0] + varianceEpsilon), + (x[1][0][0][1] - mean[1]) * 1 / Math.sqrt(variance[1] + varianceEpsilon) + ]); }); }); diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 8ac19df47b9..9a5b04c54e8 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -27,7 +27,7 @@ const grepFilter = env.specFilter; /** Tests that have these substrings in their name will be included. */ const INCLUDE_LIST: string[] = [ 'add ', 'matmul ', 'prelu ', ' cast', 'sigmoid', 'abs ', 'sub ', 'mul ', - 'div ' + 'div ', 'batchNorm' ]; /** Tests that have these substrings in their name will be excluded. */ const EXCLUDE_LIST: string[] = [ From 48f64a4627ce084e9b7605470188946c9096fc5f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 11:55:23 -0400 Subject: [PATCH 10/21] optional args --- .../src/cc/kernels/BatchNormalization.cc | 77 ++++++++----------- tfjs-backend-wasm/src/index_test.ts | 8 +- 2 files changed, 38 insertions(+), 47 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc index eefdef68b15..bc7d3a6e232 100644 --- a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc +++ b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc @@ -22,27 +22,44 @@ #include "src/cc/backend.h" #include "src/cc/util.h" -template -void batch_norm_impl(T* x_buf, int x_size, T* mean_buf, int mean_size, - T* variance_buf, int variance_size, T* offset_buf, - int offset_size, T* scale_buf, int scale_size, - float variance_epsilon, T* out_buf) { +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void BatchNormalization(int x_id, int mean_id, int variance_id, int offset_id, + int scale_id, float variance_epsilon, int out_id) { + const auto x_info = backend::get_tensor_info(x_id); + const auto mean_info = backend::get_tensor_info(mean_id); + const auto variance_info = backend::get_tensor_info(variance_id); + const auto scale_info = backend::get_tensor_info(scale_id); + const auto out_info = backend::get_tensor_info(out_id); + const auto offset_info = backend::get_tensor_info(offset_id); + + float* x_buf = x_info.buf.f32; + int x_size = x_info.size; + float* mean_buf = mean_info.buf.f32; + int mean_size = mean_info.size; + float* variance_buf = variance_info.buf.f32; + int variance_size = variance_info.size; + + float* out_buf = out_info.buf.f32; + int offi = 0; int mi = 0; int si = 0; int vi = 0; - if (offset_buf == nullptr) { - float offset_buf_default[1] = {0}; - offset_buf = offset_buf_default; - offset_size = 1; - } + float scale_buf_default[1] = {1}; + float* scale_buf = scale_id < 0 ? scale_buf_default : scale_info.buf.f32; + int scale_size = scale_id < 0 ? 1 : scale_info.size; - if (scale_buf == nullptr) { - float scale_buf_default[1] = {1}; - scale_buf = scale_buf_default; - scale_size = 1; - } + float offset_buf_default[1] = {0}; + float* offset_buf = offset_id < 0 ? offset_buf_default : offset_info.buf.f32; + int offset_size = offset_id < 0 ? 1 : offset_info.size; for (int i = 0; i < x_size; ++i) { out_buf[i] = @@ -68,36 +85,6 @@ void batch_norm_impl(T* x_buf, int x_size, T* mean_buf, int mean_size, } } } -// Templates need explicit instantiation when implemented in a .cc file. -template void batch_norm_impl(float* x_buf, int x_size, float* mean_buf, - int mean_size, float* variance_buf, - int variance_size, float* offset_buf, - int offset_size, float* scale_buf, - int scale_size, float variance_epsilon, - float* out_buf); - -namespace tfjs { -namespace wasm { -// We use C-style API to interface with Javascript. -extern "C" { - -#ifdef __EMSCRIPTEN__ -EMSCRIPTEN_KEEPALIVE -#endif -void BatchNormalization(int x_id, int mean_id, int variance_id, int offset_id, - int scale_id, float variance_epsilon, int out_id) { - const auto x_info = backend::get_tensor_info(x_id); - const auto mean_info = backend::get_tensor_info(mean_id); - const auto variance_info = backend::get_tensor_info(variance_id); - const auto scale_info = backend::get_tensor_info(scale_id); - const auto out_info = backend::get_tensor_info(out_id); - const auto offset_info = backend::get_tensor_info(offset_id); - - batch_norm_impl(x_info.buf.f32, x_info.size, mean_info.buf.f32, - mean_info.size, variance_info.buf.f32, variance_info.size, - offset_info.buf.f32, offset_info.size, scale_info.buf.f32, - scale_info.size, variance_epsilon, out_info.buf.f32); -} } // namespace wasm } // namespace wasm diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index d7305e73406..5bb8ab67fc1 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -54,12 +54,16 @@ describeWithFlags('wasm', ALL_ENVS, () => { const varianceT = tf.tensor1d([2, 3]); const varianceEpsilon = .001; - const result = tf.batchNorm4d( - xT, meanT, varianceT, undefined, undefined, varianceEpsilon); + const result = + tf.batchNorm4d(xT, meanT, varianceT, null, null, varianceEpsilon); const x = await xT.array(); const mean = await meanT.array(); const variance = await varianceT.array(); + + // should be: + // [0.706930, 1.154508, 5.655440, 12.1223354] + console.log(Array.from(result.dataSync())); console.log([ (x[0][0][0][0] - mean[0]) * 1 / Math.sqrt(variance[0] + varianceEpsilon), From d59d37856ad8475fb9014c6469cd054eb0f9be80 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 11:57:02 -0400 Subject: [PATCH 11/21] remove test --- tfjs-backend-wasm/src/index_test.ts | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/tfjs-backend-wasm/src/index_test.ts b/tfjs-backend-wasm/src/index_test.ts index 5bb8ab67fc1..b8a0093b88c 100644 --- a/tfjs-backend-wasm/src/index_test.ts +++ b/tfjs-backend-wasm/src/index_test.ts @@ -47,29 +47,4 @@ describeWithFlags('wasm', ALL_ENVS, () => { // This should fail in case of a memory leak. expect(memOffset1).toBe(memOffset2); }); - - fit('simple batchnorm4D, no offset or scale, 2x1x1x2', async () => { - const xT = tf.tensor4d([2, 4, 9, 23], [2, 1, 1, 2]); - const meanT = tf.tensor1d([1, 2]); - const varianceT = tf.tensor1d([2, 3]); - const varianceEpsilon = .001; - - const result = - tf.batchNorm4d(xT, meanT, varianceT, null, null, varianceEpsilon); - - const x = await xT.array(); - const mean = await meanT.array(); - const variance = await varianceT.array(); - - // should be: - // [0.706930, 1.154508, 5.655440, 12.1223354] - - console.log(Array.from(result.dataSync())); - console.log([ - (x[0][0][0][0] - mean[0]) * 1 / Math.sqrt(variance[0] + varianceEpsilon), - (x[0][0][0][1] - mean[1]) * 1 / Math.sqrt(variance[1] + varianceEpsilon), - (x[1][0][0][0] - mean[0]) * 1 / Math.sqrt(variance[0] + varianceEpsilon), - (x[1][0][0][1] - mean[1]) * 1 / Math.sqrt(variance[1] + varianceEpsilon) - ]); - }); }); From be1654aad86a5be15a777927131709d30e3bd800 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 12:12:24 -0400 Subject: [PATCH 12/21] save inputs --- tfjs-core/src/ops/batchnorm.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tfjs-core/src/ops/batchnorm.ts b/tfjs-core/src/ops/batchnorm.ts index a0494bad8fd..b61e7c9dc7a 100644 --- a/tfjs-core/src/ops/batchnorm.ts +++ b/tfjs-core/src/ops/batchnorm.ts @@ -346,6 +346,8 @@ function batchNorm_( }; }; + const inputsToSave = [$x, $mean, $variance, $scale]; + const res = ENGINE.runKernelFunc( (backend, save) => { const res = backend.batchNormalization( @@ -356,7 +358,7 @@ function batchNorm_( return res; }, {x: $x, mean: $mean, variance: $variance, scale: $scale, offset: $offset}, - der, 'BatchNormalization', {varianceEpsilon}); + der, 'BatchNormalization', {varianceEpsilon}, inputsToSave); return res.reshape($x.shape); } From 4b34109584f8cb38122608e642523ee324f2a241 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 12:20:14 -0400 Subject: [PATCH 13/21] default --- .../src/cc/kernels/BatchNormalization.cc | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc index bc7d3a6e232..c1f520cfbfe 100644 --- a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc +++ b/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc @@ -35,9 +35,7 @@ void BatchNormalization(int x_id, int mean_id, int variance_id, int offset_id, const auto x_info = backend::get_tensor_info(x_id); const auto mean_info = backend::get_tensor_info(mean_id); const auto variance_info = backend::get_tensor_info(variance_id); - const auto scale_info = backend::get_tensor_info(scale_id); const auto out_info = backend::get_tensor_info(out_id); - const auto offset_info = backend::get_tensor_info(offset_id); float* x_buf = x_info.buf.f32; int x_size = x_info.size; @@ -54,12 +52,28 @@ void BatchNormalization(int x_id, int mean_id, int variance_id, int offset_id, int vi = 0; float scale_buf_default[1] = {1}; - float* scale_buf = scale_id < 0 ? scale_buf_default : scale_info.buf.f32; - int scale_size = scale_id < 0 ? 1 : scale_info.size; + float* scale_buf; + int scale_size; + if (scale_id < 0) { + scale_buf = scale_buf_default; + scale_size = 1; + } else { + const auto scale_info = backend::get_tensor_info(scale_id); + scale_buf = scale_info.buf.f32; + scale_size = scale_info.size; + } float offset_buf_default[1] = {0}; - float* offset_buf = offset_id < 0 ? offset_buf_default : offset_info.buf.f32; - int offset_size = offset_id < 0 ? 1 : offset_info.size; + float* offset_buf; + int offset_size; + if (offset_id < 0) { + offset_buf = offset_buf_default; + offset_size = 1; + } else { + const auto offset_info = backend::get_tensor_info(offset_id); + offset_buf = offset_info.buf.f32; + offset_size = offset_info.size; + } for (int i = 0; i < x_size; ++i) { out_buf[i] = From abbc343249b6d758d3395bc44aa1cfe3b8ff2ede Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 14:31:05 -0400 Subject: [PATCH 14/21] rename --- tfjs-backend-wasm/src/cc/BUILD | 6 +++--- .../cc/kernels/{BatchNormalization.cc => FusedBatchNorm.cc} | 4 ++-- .../kernels/{BatchNormalization.ts => FusedBatchNorm.ts} | 6 +++--- tfjs-backend-wasm/src/kernels/all_kernels.ts | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) rename tfjs-backend-wasm/src/cc/kernels/{BatchNormalization.cc => FusedBatchNorm.cc} (94%) rename tfjs-backend-wasm/src/kernels/{BatchNormalization.ts => FusedBatchNorm.ts} (95%) diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index b02ea03e453..14266d61d56 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -64,15 +64,15 @@ tfjs_cc_library( ":Div", ":Mul", ":Prelu", - ":BatchNormalization", + ":FusedBatchNorm", ":Sigmoid", ":Sub", ] ) tfjs_cc_library( - name = "BatchNormalization", - srcs = ["kernels/BatchNormalization.cc"], + name = "FusedBatchNorm", + srcs = ["kernels/FusedBatchNorm.cc"], deps = [ ":backend", ":util", diff --git a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc similarity index 94% rename from tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc rename to tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc index c1f520cfbfe..e48f1d07dfa 100644 --- a/tfjs-backend-wasm/src/cc/kernels/BatchNormalization.cc +++ b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc @@ -30,8 +30,8 @@ extern "C" { #ifdef __EMSCRIPTEN__ EMSCRIPTEN_KEEPALIVE #endif -void BatchNormalization(int x_id, int mean_id, int variance_id, int offset_id, - int scale_id, float variance_epsilon, int out_id) { +void FusedBatchNorm(int x_id, int mean_id, int variance_id, int offset_id, + int scale_id, float variance_epsilon, int out_id) { const auto x_info = backend::get_tensor_info(x_id); const auto mean_info = backend::get_tensor_info(mean_id); const auto variance_info = backend::get_tensor_info(variance_id); diff --git a/tfjs-backend-wasm/src/kernels/BatchNormalization.ts b/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts similarity index 95% rename from tfjs-backend-wasm/src/kernels/BatchNormalization.ts rename to tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts index 1c80e065267..03b3b4f17cc 100644 --- a/tfjs-backend-wasm/src/kernels/BatchNormalization.ts +++ b/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts @@ -37,11 +37,11 @@ let wasmBatchNorm: ( function setup(backend: BackendWasm): void { wasmBatchNorm = backend.wasm.cwrap( - 'BatchNormalization', null /* void */, + 'FusedBatchNorm', null /* void */, ['number', 'number', 'number', 'number', 'number', 'number', 'number']); } -function batchNormalization( +function fusedBatchNorm( args: {backend: BackendWasm, inputs: BatchNormInputs, attrs: BatchNormAttrs}): TensorInfo { @@ -71,5 +71,5 @@ registerKernel({ kernelName: 'BatchNormalization', backendName: 'wasm', setupFunc: setup, - kernelFunc: batchNormalization + kernelFunc: fusedBatchNorm }); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 572be6bc88f..193d3ccfefe 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -21,7 +21,7 @@ import './Abs'; import './Add'; import './BatchMatMul'; -import './BatchNormalization'; +import './FusedBatchNorm'; import './Cast'; import './Div'; import './Mul'; From 731100bd200b088c0e021e0554dd42a7a25650b5 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 14:42:06 -0400 Subject: [PATCH 15/21] precompute --- .../src/cc/kernels/FusedBatchNorm.cc | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc index e48f1d07dfa..406e65d4219 100644 --- a/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc +++ b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc @@ -17,10 +17,7 @@ #endif #include -#include - #include "src/cc/backend.h" -#include "src/cc/util.h" namespace tfjs { namespace wasm { @@ -46,10 +43,10 @@ void FusedBatchNorm(int x_id, int mean_id, int variance_id, int offset_id, float* out_buf = out_info.buf.f32; - int offi = 0; - int mi = 0; - int si = 0; - int vi = 0; + int offset_i = 0; + int mean_i = 0; + int scale_i = 0; + int variance_i = 0; float scale_buf_default[1] = {1}; float* scale_buf; @@ -75,27 +72,32 @@ void FusedBatchNorm(int x_id, int mean_id, int variance_id, int offset_id, offset_size = offset_info.size; } + float normalization_factor[variance_size]; + for (int i = 0; i < variance_size; ++i) { + normalization_factor[i] = sqrt(variance_buf[i] + variance_epsilon); + } + for (int i = 0; i < x_size; ++i) { - out_buf[i] = - offset_buf[offi] + (x_buf[i] - mean_buf[mi]) * scale_buf[si] / - sqrt(variance_buf[vi] + variance_epsilon); + out_buf[i] = offset_buf[offset_i] + (x_buf[i] - mean_buf[mean_i]) * + scale_buf[scale_i] / + normalization_factor[variance_i]; - offi = offi + 1; - mi = mi + 1; - si = si + 1; - vi = vi + 1; + offset_i = offset_i + 1; + mean_i = mean_i + 1; + scale_i = scale_i + 1; + variance_i = variance_i + 1; - if (offi >= offset_size) { - offi = 0; + if (offset_i >= offset_size) { + offset_i = 0; } - if (mi >= mean_size) { - mi = 0; + if (mean_i >= mean_size) { + mean_i = 0; } - if (si >= scale_size) { - si = 0; + if (scale_i >= scale_size) { + scale_i = 0; } - if (vi >= variance_size) { - vi = 0; + if (variance_i >= variance_size) { + variance_i = 0; } } } From 2567967c6b4fc3d6404984da4c10dbc43aaef010 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 14:44:17 -0400 Subject: [PATCH 16/21] style --- tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc index 406e65d4219..b95c5dc5842 100644 --- a/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc +++ b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc @@ -82,10 +82,10 @@ void FusedBatchNorm(int x_id, int mean_id, int variance_id, int offset_id, scale_buf[scale_i] / normalization_factor[variance_i]; - offset_i = offset_i + 1; - mean_i = mean_i + 1; - scale_i = scale_i + 1; - variance_i = variance_i + 1; + ++offset_i; + ++mean_i; + ++scale_i; + ++variance_i; if (offset_i >= offset_size) { offset_i = 0; From 03da7d85430808c9649a091700da396a50156cdc Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 14:46:27 -0400 Subject: [PATCH 17/21] edits --- tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc | 2 +- tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc index b95c5dc5842..e608050c051 100644 --- a/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc +++ b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc @@ -102,6 +102,6 @@ void FusedBatchNorm(int x_id, int mean_id, int variance_id, int offset_id, } } -} // namespace wasm +} // extern "C" } // namespace wasm } // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts b/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts index 03b3b4f17cc..fd055b9d4d4 100644 --- a/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts +++ b/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts @@ -51,8 +51,9 @@ function fusedBatchNorm( const xId = backend.dataIdMap.get(x.dataId).id; const meanId = backend.dataIdMap.get(mean.dataId).id; const varianceId = backend.dataIdMap.get(variance.dataId).id; - const offsetId = offset ? backend.dataIdMap.get(offset.dataId).id : -1; - const scaleId = scale ? backend.dataIdMap.get(scale.dataId).id : -1; + const offsetId = + offset != null ? backend.dataIdMap.get(offset.dataId).id : -1; + const scaleId = scale != null ? backend.dataIdMap.get(scale.dataId).id : -1; const out = backend.makeOutput(x.shape, x.dtype); // Short-circuit zero-sized tensors. From 765bd2759144f4353a9680c79602084ac32c5f11 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 15:28:13 -0400 Subject: [PATCH 18/21] lint --- tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc index e608050c051..d982d0314c2 100644 --- a/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc +++ b/tfjs-backend-wasm/src/cc/kernels/FusedBatchNorm.cc @@ -17,6 +17,7 @@ #endif #include +#include #include "src/cc/backend.h" namespace tfjs { @@ -72,7 +73,7 @@ void FusedBatchNorm(int x_id, int mean_id, int variance_id, int offset_id, offset_size = offset_info.size; } - float normalization_factor[variance_size]; + std::vector normalization_factor(variance_size); for (int i = 0; i < variance_size; ++i) { normalization_factor[i] = sqrt(variance_buf[i] + variance_epsilon); } From 279f4a8d19e8788ee2e653aa0d750bdf4d373548 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 15:38:04 -0400 Subject: [PATCH 19/21] save From 0d26f9440e80bf922bcb5cbdd69f15e736232c5e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 15:43:39 -0400 Subject: [PATCH 20/21] save From 63ede89e94233f12501b51a3e4f43a1aa11a1ac3 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 28 Oct 2019 15:50:43 -0400 Subject: [PATCH 21/21] save