From b6f89561e77ff7ac1834ba935fb9fea3a7d7b92d Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Tue, 17 Dec 2019 14:16:01 -0500 Subject: [PATCH 1/8] [WASM] Add Less and LessEqual. --- tfjs-backend-wasm/src/cc/BUILD | 20 +++++++ tfjs-backend-wasm/src/cc/kernels/Less.cc | 59 ++++++++++++++++++ tfjs-backend-wasm/src/cc/kernels/LessEqual.cc | 60 +++++++++++++++++++ tfjs-backend-wasm/src/kernels/Greater.ts | 2 +- tfjs-backend-wasm/src/kernels/GreaterEqual.ts | 2 +- tfjs-backend-wasm/src/kernels/Less.ts | 20 +++++++ tfjs-backend-wasm/src/kernels/LessEqual.ts | 20 +++++++ 7 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 tfjs-backend-wasm/src/cc/kernels/Less.cc create mode 100644 tfjs-backend-wasm/src/cc/kernels/LessEqual.cc create mode 100644 tfjs-backend-wasm/src/kernels/Less.ts create mode 100644 tfjs-backend-wasm/src/kernels/LessEqual.ts diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index d1a5e831621..b51f3d22988 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -403,6 +403,26 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Less", + srcs = ["kernels/Less.cc"], + deps = [ + ":backend", + ":binary", + ":util", + ], +) + +tfjs_cc_library( + name = "LessEqual", + srcs = ["kernels/LessEqual.cc"], + deps = [ + ":backend", + ":binary", + ":util" + ], +) + tfjs_cc_library( name = "Log", srcs = ["kernels/Log.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Less.cc b/tfjs-backend-wasm/src/cc/kernels/Less.cc new file mode 100644 index 00000000000..73883506a64 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Less.cc @@ -0,0 +1,59 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace { +template +inline bool less(T a, T b) { + return a < b; +} +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Less(const int a_id, const size_t* a_shape_ptr, const int a_shape_len, + const int b_id, const size_t* b_shape_ptr, const int b_shape_len, + const DType input_type, const int out_id) { + switch (input_type) { + case DType::float32: + compare_f32(a_id, b_id, out_id, less); + break; + case DType::int32: + compare_i32(a_id, b_id, out_id, less); + break; + case DType::boolean: + compare_bool(a_id, b_id, out_id, less); + break; + default: + util::warn( + "Less for tensor ids %d and %d failed. Unsupported input_type %d", + a_id, b_id, input_type); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/LessEqual.cc b/tfjs-backend-wasm/src/cc/kernels/LessEqual.cc new file mode 100644 index 00000000000..bcf99fe1417 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/LessEqual.cc @@ -0,0 +1,60 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace { +template +inline bool lessEqual(T a, T b) { + return a <= b; +} +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void LessEqual(const int a_id, const size_t* a_shape_ptr, const int a_shape_len, + const int b_id, const size_t* b_shape_ptr, const int b_shape_len, + const DType input_type, const int out_id) { + switch (input_type) { + case DType::float32: + compare_f32(a_id, b_id, out_id, lessEqual); + break; + case DType::int32: + compare_i32(a_id, b_id, out_id, lessEqual); + break; + case DType::boolean: + compare_bool(a_id, b_id, out_id, lessEqual); + break; + default: + util::warn( + "LessEqual for tensor ids %d and %d failed." + "Unsupported input_type %d", + a_id, b_id, input_type); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Greater.ts b/tfjs-backend-wasm/src/kernels/Greater.ts index e145cd17d4c..3a3c2b384f8 100644 --- a/tfjs-backend-wasm/src/kernels/Greater.ts +++ b/tfjs-backend-wasm/src/kernels/Greater.ts @@ -15,6 +15,6 @@ * ============================================================================= */ -import { registerBinaryKernel } from './binary_kernel'; +import {registerBinaryKernel} from './binary_kernel'; const supportsBroadcast = true; registerBinaryKernel('Greater', supportsBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/GreaterEqual.ts b/tfjs-backend-wasm/src/kernels/GreaterEqual.ts index 9563c4b716d..3c999a01942 100644 --- a/tfjs-backend-wasm/src/kernels/GreaterEqual.ts +++ b/tfjs-backend-wasm/src/kernels/GreaterEqual.ts @@ -15,6 +15,6 @@ * ============================================================================= */ -import { registerBinaryKernel } from './binary_kernel'; +import {registerBinaryKernel} from './binary_kernel'; const supportsBroadcast = true; registerBinaryKernel('GreaterEqual', supportsBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/Less.ts b/tfjs-backend-wasm/src/kernels/Less.ts new file mode 100644 index 00000000000..256940d5802 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Less.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerBinaryKernel} from './binary_kernel'; +const supportsBroadcast = true; +registerBinaryKernel('Less', supportsBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/LessEqual.ts b/tfjs-backend-wasm/src/kernels/LessEqual.ts new file mode 100644 index 00000000000..1900758c32a --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/LessEqual.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerBinaryKernel} from './binary_kernel'; +const supportsBroadcast = true; +registerBinaryKernel('LessEqual', supportsBroadcast, 'bool'); From c504fb3e18d17499d25fa2c11193e0810cd0b790 Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Tue, 17 Dec 2019 14:17:20 -0500 Subject: [PATCH 2/8] save --- tfjs-backend-wasm/src/cc/BUILD | 2 +- tfjs-backend-wasm/src/kernels/all_kernels.ts | 2 ++ tfjs-backend-wasm/src/setup_test.ts | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index b51f3d22988..2950faab46d 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -419,7 +419,7 @@ tfjs_cc_library( deps = [ ":backend", ":binary", - ":util" + ":util", ], ) diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index a0941aa873c..4f6d0c283ce 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -38,6 +38,8 @@ import './FusedConv2D'; import './FusedDepthwiseConv2D'; import './Greater'; import './GreaterEqual'; +import './Less'; +import './LessEqual'; import './Log'; import './Max'; import './Maximum'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 2e257e6f90a..eb695ed345c 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -237,6 +237,25 @@ const TEST_FILTERS: TestFilter[] = [ 'broadcasting Tensor4D shapes' // Same as above. ] }, + { + include: 'less ', + excludes: [ + 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not + // supported yet. + 'broadcasting Tensor3D shapes', // Same as above. + 'broadcasting Tensor4D shapes' // Same as above. + ] + }, + { + include: 'lessEqual', + excludes: [ + 'gradient', // Not yet implemented. + 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not + // supported yet. + 'broadcasting Tensor3D shapes', // Same as above. + 'broadcasting Tensor4D shapes' // Same as above. + ] + }, { include: 'mean ', excludes: [ From 9c75f22830c6990e96ed8e7b13958c176a032c5c Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Tue, 17 Dec 2019 14:19:56 -0500 Subject: [PATCH 3/8] save --- tfjs-core/src/ops/compare.ts | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tfjs-core/src/ops/compare.ts b/tfjs-core/src/ops/compare.ts index ffbab4792b2..2fa8d138091 100644 --- a/tfjs-core/src/ops/compare.ts +++ b/tfjs-core/src/ops/compare.ts @@ -90,7 +90,9 @@ function less_( [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); - return ENGINE.runKernelFunc(backend => backend.less($a, $b), {$a, $b}) as T; + return ENGINE.runKernelFunc( + backend => backend.less($a, $b), {a: $a, b: $b}, null /* grad */, + 'Less') as T; } /** @@ -166,8 +168,11 @@ function lessEqual_( [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); - return ENGINE.runKernelFunc(backend => backend.lessEqual($a, $b), {$a, $b}) as - T; + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.greaterEqual($a, $b); + save([$a, $b]); + return res; + }, {a: $a, b: $b}, null /* grad */, 'LessEqual') as T; } function lessEqualStrict_( @@ -203,9 +208,8 @@ function greater_( assertAndGetBroadcastShape($a.shape, $b.shape); return ENGINE.runKernelFunc( - backend => backend.greater($a, $b), - {a: $a, b: $b}, null /* grad */, 'Greater' - ) as T; + backend => backend.greater($a, $b), {a: $a, b: $b}, + null /* grad */, 'Greater') as T; } function greaterStrict_(a: T|TensorLike, b: T|TensorLike): T { From f9371e3631941a0459bc437d5b3348c774566efb Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Tue, 17 Dec 2019 14:47:09 -0500 Subject: [PATCH 4/8] save --- tfjs-backend-wasm/src/kernels/Greater.ts | 2 +- tfjs-backend-wasm/src/kernels/GreaterEqual.ts | 2 +- tfjs-backend-wasm/src/kernels/Less.ts | 2 +- tfjs-backend-wasm/src/kernels/LessEqual.ts | 2 +- .../src/kernels/binary_kernel.ts | 8 +++++--- tfjs-backend-wasm/src/setup_test.ts | 20 ++++++++++--------- tfjs-core/src/ops/compare.ts | 2 +- 7 files changed, 21 insertions(+), 17 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Greater.ts b/tfjs-backend-wasm/src/kernels/Greater.ts index 3a3c2b384f8..e381d50eadc 100644 --- a/tfjs-backend-wasm/src/kernels/Greater.ts +++ b/tfjs-backend-wasm/src/kernels/Greater.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = true; +const supportsBroadcast = false; registerBinaryKernel('Greater', supportsBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/GreaterEqual.ts b/tfjs-backend-wasm/src/kernels/GreaterEqual.ts index 3c999a01942..555c8cf1c6a 100644 --- a/tfjs-backend-wasm/src/kernels/GreaterEqual.ts +++ b/tfjs-backend-wasm/src/kernels/GreaterEqual.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = true; +const supportsBroadcast = false; registerBinaryKernel('GreaterEqual', supportsBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/Less.ts b/tfjs-backend-wasm/src/kernels/Less.ts index 256940d5802..8eb5e2f6c56 100644 --- a/tfjs-backend-wasm/src/kernels/Less.ts +++ b/tfjs-backend-wasm/src/kernels/Less.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = true; +const supportsBroadcast = false; registerBinaryKernel('Less', supportsBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/LessEqual.ts b/tfjs-backend-wasm/src/kernels/LessEqual.ts index 1900758c32a..a113dc4bac2 100644 --- a/tfjs-backend-wasm/src/kernels/LessEqual.ts +++ b/tfjs-backend-wasm/src/kernels/LessEqual.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = true; +const supportsBroadcast = false; registerBinaryKernel('LessEqual', supportsBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/binary_kernel.ts b/tfjs-backend-wasm/src/kernels/binary_kernel.ts index 25723c5119f..da7cd22f274 100644 --- a/tfjs-backend-wasm/src/kernels/binary_kernel.ts +++ b/tfjs-backend-wasm/src/kernels/binary_kernel.ts @@ -21,7 +21,7 @@ import {BackendWasm} from '../backend_wasm'; import {CppDType} from './types'; export function registerBinaryKernel( - kernelName: string, supportsBroadcast: boolean, dtype?: DataType) { + kernelName: string, supportsFullBroadcast: boolean, dtype?: DataType) { let wasmFunc: (aId: number, aShape: Uint8Array, aShapeLen: number, bId: number, bShape: Uint8Array, bShapeLen: number, dtype: number, outId: number) => @@ -63,7 +63,7 @@ export function registerBinaryKernel( aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, CppDType[a.dtype], outId); - if (supportsBroadcast) { + if (supportsFullBroadcast) { kernelFunc(); return out; } @@ -76,7 +76,9 @@ export function registerBinaryKernel( kernelFunc(); return out; } else { - throw new Error('Broadcasting along inner dims is not yet supported'); + throw new Error( + `Broadcasting along outer dims is not yet ` + + `supported for ${kernelName}.`); } } diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index eb695ed345c..01939a3bc07 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -240,20 +240,22 @@ const TEST_FILTERS: TestFilter[] = [ { include: 'less ', excludes: [ - 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not - // supported yet. - 'broadcasting Tensor3D shapes', // Same as above. - 'broadcasting Tensor4D shapes' // Same as above. + 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not + // supported yet. + 'broadcasting Tensor3D shapes', // Same as above. + 'broadcasting Tensor3D float32', // Same as above. + 'broadcasting Tensor4D shapes' // Same as above. ] }, { include: 'lessEqual', excludes: [ - 'gradient', // Not yet implemented. - 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not - // supported yet. - 'broadcasting Tensor3D shapes', // Same as above. - 'broadcasting Tensor4D shapes' // Same as above. + 'gradient', // Not yet implemented. + 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not + // supported yet. + 'broadcasting Tensor3D shapes', // Same as above. + 'broadcasting Tensor3D float32', // Same as above. + 'broadcasting Tensor4D shapes' // Same as above. ] }, { diff --git a/tfjs-core/src/ops/compare.ts b/tfjs-core/src/ops/compare.ts index 2fa8d138091..7b32b7d8b89 100644 --- a/tfjs-core/src/ops/compare.ts +++ b/tfjs-core/src/ops/compare.ts @@ -169,7 +169,7 @@ function lessEqual_( assertAndGetBroadcastShape($a.shape, $b.shape); return ENGINE.runKernelFunc((backend, save) => { - const res = backend.greaterEqual($a, $b); + const res = backend.lessEqual($a, $b); save([$a, $b]); return res; }, {a: $a, b: $b}, null /* grad */, 'LessEqual') as T; From 2e1bb77f00e4eac9f7e3bd88411231b853ab505b Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Tue, 17 Dec 2019 14:48:57 -0500 Subject: [PATCH 5/8] save --- tfjs-backend-wasm/src/kernels/Add.ts | 4 ++-- tfjs-backend-wasm/src/kernels/Div.ts | 4 ++-- tfjs-backend-wasm/src/kernels/FloorDiv.ts | 4 ++-- tfjs-backend-wasm/src/kernels/Greater.ts | 4 ++-- tfjs-backend-wasm/src/kernels/GreaterEqual.ts | 4 ++-- tfjs-backend-wasm/src/kernels/Less.ts | 4 ++-- tfjs-backend-wasm/src/kernels/LessEqual.ts | 4 ++-- tfjs-backend-wasm/src/kernels/Maximum.ts | 4 ++-- tfjs-backend-wasm/src/kernels/Minimum.ts | 4 ++-- tfjs-backend-wasm/src/kernels/Mul.ts | 4 ++-- tfjs-backend-wasm/src/kernels/Sub.ts | 4 ++-- 11 files changed, 22 insertions(+), 22 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Add.ts b/tfjs-backend-wasm/src/kernels/Add.ts index 41b984afaf3..2a57026ea95 100644 --- a/tfjs-backend-wasm/src/kernels/Add.ts +++ b/tfjs-backend-wasm/src/kernels/Add.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = true; -registerBinaryKernel('Add', supportsBroadcast); +const supportsFullBroadcast = true; +registerBinaryKernel('Add', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Div.ts b/tfjs-backend-wasm/src/kernels/Div.ts index 0e0780d4caa..7baef07b5f5 100644 --- a/tfjs-backend-wasm/src/kernels/Div.ts +++ b/tfjs-backend-wasm/src/kernels/Div.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('Div', supportsBroadcast); +const supportsFullBroadcast = false; +registerBinaryKernel('Div', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/FloorDiv.ts b/tfjs-backend-wasm/src/kernels/FloorDiv.ts index 9e8b15d4d7e..7b3d852eda7 100644 --- a/tfjs-backend-wasm/src/kernels/FloorDiv.ts +++ b/tfjs-backend-wasm/src/kernels/FloorDiv.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('FloorDiv', supportsBroadcast); +const supportsFullBroadcast = false; +registerBinaryKernel('FloorDiv', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Greater.ts b/tfjs-backend-wasm/src/kernels/Greater.ts index e381d50eadc..286b1631d6c 100644 --- a/tfjs-backend-wasm/src/kernels/Greater.ts +++ b/tfjs-backend-wasm/src/kernels/Greater.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('Greater', supportsBroadcast, 'bool'); +const supportsFullBroadcast = false; +registerBinaryKernel('Greater', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/GreaterEqual.ts b/tfjs-backend-wasm/src/kernels/GreaterEqual.ts index 555c8cf1c6a..7ba75e543e5 100644 --- a/tfjs-backend-wasm/src/kernels/GreaterEqual.ts +++ b/tfjs-backend-wasm/src/kernels/GreaterEqual.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('GreaterEqual', supportsBroadcast, 'bool'); +const supportsFullBroadcast = false; +registerBinaryKernel('GreaterEqual', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/Less.ts b/tfjs-backend-wasm/src/kernels/Less.ts index 8eb5e2f6c56..321c07cc998 100644 --- a/tfjs-backend-wasm/src/kernels/Less.ts +++ b/tfjs-backend-wasm/src/kernels/Less.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('Less', supportsBroadcast, 'bool'); +const supportsFullBroadcast = false; +registerBinaryKernel('Less', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/LessEqual.ts b/tfjs-backend-wasm/src/kernels/LessEqual.ts index a113dc4bac2..9ca0688fa58 100644 --- a/tfjs-backend-wasm/src/kernels/LessEqual.ts +++ b/tfjs-backend-wasm/src/kernels/LessEqual.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('LessEqual', supportsBroadcast, 'bool'); +const supportsFullBroadcast = false; +registerBinaryKernel('LessEqual', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/Maximum.ts b/tfjs-backend-wasm/src/kernels/Maximum.ts index 647c54f8977..250834101a7 100644 --- a/tfjs-backend-wasm/src/kernels/Maximum.ts +++ b/tfjs-backend-wasm/src/kernels/Maximum.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('Maximum', supportsBroadcast); +const supportsFullBroadcast = false; +registerBinaryKernel('Maximum', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Minimum.ts b/tfjs-backend-wasm/src/kernels/Minimum.ts index 0c167b7c318..e266420b7e6 100644 --- a/tfjs-backend-wasm/src/kernels/Minimum.ts +++ b/tfjs-backend-wasm/src/kernels/Minimum.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('Minimum', supportsBroadcast); +const supportsFullBroadcast = false; +registerBinaryKernel('Minimum', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Mul.ts b/tfjs-backend-wasm/src/kernels/Mul.ts index f50cb0cc105..bf93abf9fb6 100644 --- a/tfjs-backend-wasm/src/kernels/Mul.ts +++ b/tfjs-backend-wasm/src/kernels/Mul.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = true; -registerBinaryKernel('Mul', supportsBroadcast); +const supportsFullBroadcast = true; +registerBinaryKernel('Mul', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Sub.ts b/tfjs-backend-wasm/src/kernels/Sub.ts index 80768ace259..935d4704d62 100644 --- a/tfjs-backend-wasm/src/kernels/Sub.ts +++ b/tfjs-backend-wasm/src/kernels/Sub.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = true; -registerBinaryKernel('Sub', supportsBroadcast); +const supportsFullBroadcast = true; +registerBinaryKernel('Sub', supportsFullBroadcast); From c3dbfebe04a81649487aee56714e2130f51940ef Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Tue, 17 Dec 2019 15:19:36 -0500 Subject: [PATCH 6/8] save --- tfjs-backend-wasm/src/cc/BUILD | 12 +++++ tfjs-backend-wasm/src/cc/binary.h | 9 ++++ .../src/cc/kernels/LogicalAnd.cc | 52 +++++++++++++++++++ tfjs-backend-wasm/src/kernels/LogicalAnd.ts | 20 +++++++ tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + tfjs-backend-wasm/src/setup_test.ts | 19 +++++-- tfjs-core/src/ops/logical_ops.ts | 3 +- 7 files changed, 110 insertions(+), 6 deletions(-) create mode 100644 tfjs-backend-wasm/src/cc/kernels/LogicalAnd.cc create mode 100644 tfjs-backend-wasm/src/kernels/LogicalAnd.ts diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 2950faab46d..e8d9b224b76 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -158,6 +158,8 @@ tfjs_cc_library( ":FusedDepthwiseConv2D", ":Greater", ":GreaterEqual", + ":Less", + ":LessEqual", ":Max", ":MaxPool", ":Maximum", @@ -423,6 +425,16 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "LogicalAnd", + srcs = ["kernels/LogicalAnd.cc"], + deps = [ + ":backend", + ":binary", + ":util", + ], +) + tfjs_cc_library( name = "Log", srcs = ["kernels/Log.cc"], diff --git a/tfjs-backend-wasm/src/cc/binary.h b/tfjs-backend-wasm/src/cc/binary.h index f6952fc4992..d005ebd10b3 100644 --- a/tfjs-backend-wasm/src/cc/binary.h +++ b/tfjs-backend-wasm/src/cc/binary.h @@ -87,6 +87,15 @@ inline void compare_bool(const int a_id, const int b_id, const int out_id, out_info.b_write(), operation); } +inline void logical(const int a_id, const int b_id, const int out_id, + bool operation(bool, bool)) { + auto& a_info = backend::get_tensor_info(a_id); + auto& b_info = backend::get_tensor_info(b_id); + auto& out_info = backend::get_tensor_info_out(out_id); + binary_impl(a_info.b(), a_info.size, b_info.b(), b_info.size, + out_info.b_write(), operation); +} + typedef xnn_status (*xnn_create_binary_op)(float, float, uint32_t, xnn_operator_t*); typedef xnn_status (*xnn_setup_binary_op)(xnn_operator_t, size_t, const size_t*, diff --git a/tfjs-backend-wasm/src/cc/kernels/LogicalAnd.cc b/tfjs-backend-wasm/src/cc/kernels/LogicalAnd.cc new file mode 100644 index 00000000000..0ee8e2946bd --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/LogicalAnd.cc @@ -0,0 +1,52 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace { +inline bool logical_and(bool a, bool b) { return a && b; } +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void LogicalAnd(const int a_id, const size_t* a_shape_ptr, + const int a_shape_len, const int b_id, + const size_t* b_shape_ptr, const int b_shape_len, + const DType input_type, const int out_id) { + switch (input_type) { + case DType::boolean: + compare_bool(a_id, b_id, out_id, logical_and); + break; + default: + util::warn( + "LogicalAnd for tensor ids %d and %d failed. Unsupported input_type " + "%d", + a_id, b_id, input_type); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/LogicalAnd.ts b/tfjs-backend-wasm/src/kernels/LogicalAnd.ts new file mode 100644 index 00000000000..4aa300e9608 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/LogicalAnd.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerBinaryKernel} from './binary_kernel'; +const supportsFullBroadcast = false; +registerBinaryKernel('LogicalAnd', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 4f6d0c283ce..1d773ae98fe 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -38,6 +38,7 @@ import './FusedConv2D'; import './FusedDepthwiseConv2D'; import './Greater'; import './GreaterEqual'; +import './LogicalAnd'; import './Less'; import './LessEqual'; import './Log'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 01939a3bc07..e2796dc2835 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -221,7 +221,7 @@ const TEST_FILTERS: TestFilter[] = [ { include: 'greater ', excludes: [ - 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not + 'broadcasting Tensor2D shapes', // Broadcasting along outer dims not // supported yet. 'broadcasting Tensor3D shapes', // Same as above. 'broadcasting Tensor4D shapes' // Same as above. @@ -231,7 +231,7 @@ const TEST_FILTERS: TestFilter[] = [ include: 'greaterEqual', excludes: [ 'gradient', // Not yet implemented. - 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not + 'broadcasting Tensor2D shapes', // Broadcasting along outer dims not // supported yet. 'broadcasting Tensor3D shapes', // Same as above. 'broadcasting Tensor4D shapes' // Same as above. @@ -240,7 +240,7 @@ const TEST_FILTERS: TestFilter[] = [ { include: 'less ', excludes: [ - 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not + 'broadcasting Tensor2D shapes', // Broadcasting along outer dims not // supported yet. 'broadcasting Tensor3D shapes', // Same as above. 'broadcasting Tensor3D float32', // Same as above. @@ -251,7 +251,7 @@ const TEST_FILTERS: TestFilter[] = [ include: 'lessEqual', excludes: [ 'gradient', // Not yet implemented. - 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not + 'broadcasting Tensor2D shapes', // Broadcasting along outer dims not // supported yet. 'broadcasting Tensor3D shapes', // Same as above. 'broadcasting Tensor3D float32', // Same as above. @@ -264,7 +264,16 @@ const TEST_FILTERS: TestFilter[] = [ 'axis=0', // Reduction not supported along inner dimensions. ] }, - {startsWith: 'sum '} + {startsWith: 'sum '}, + { + startsWith: 'logicalAnd ', + excludes: [ + 'broadcasting Tensor2D shapes', // Broadcasting along outer dimensions + // not yet supported. + 'broadcasting Tensor3D shapes', // Same as above. + 'broadcasting Tensor4D shapes', // Same as above. + ] + } ]; const customInclude = (testName: string) => { diff --git a/tfjs-core/src/ops/logical_ops.ts b/tfjs-core/src/ops/logical_ops.ts index 3b83d2aa6d9..9000a669705 100644 --- a/tfjs-core/src/ops/logical_ops.ts +++ b/tfjs-core/src/ops/logical_ops.ts @@ -63,7 +63,8 @@ function logicalAnd_( assertAndGetBroadcastShape($a.shape, $b.shape); return ENGINE.runKernelFunc( - backend => backend.logicalAnd($a, $b), {$a, $b}) as T; + backend => backend.logicalAnd($a, $b), {a: $a, b: $b}, + null /* grad */, 'LogicalAnd') as T; } /** From 6c0f1e226c619a8e723098c5d9d56eaf0d098a64 Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Tue, 17 Dec 2019 16:28:12 -0500 Subject: [PATCH 7/8] save --- tfjs-backend-wasm/src/cc/BUILD | 10 +++ tfjs-backend-wasm/src/cc/kernels/Tile.cc | 89 ++++++++++++++++++++ tfjs-backend-wasm/src/kernels/Tile.ts | 74 ++++++++++++++++ tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + tfjs-backend-wasm/src/setup_test.ts | 7 ++ tfjs-core/benchmarks/index.html | 4 +- tfjs-core/src/backends/tile_impl.ts | 4 +- tfjs-core/src/ops/array_ops.ts | 6 +- 8 files changed, 189 insertions(+), 6 deletions(-) create mode 100644 tfjs-backend-wasm/src/cc/kernels/Tile.cc create mode 100644 tfjs-backend-wasm/src/kernels/Tile.ts diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index e8d9b224b76..efa73a29516 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -174,6 +174,7 @@ tfjs_cc_library( ":ResizeBilinear", ":Sigmoid", ":Sub", + ":Tile", ":Transpose", ], ) @@ -613,6 +614,15 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Tile", + srcs = ["kernels/Tile.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "Transpose", srcs = ["kernels/Transpose.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Tile.cc b/tfjs-backend-wasm/src/cc/kernels/Tile.cc new file mode 100644 index 00000000000..3b1b95391f0 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Tile.cc @@ -0,0 +1,89 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +namespace { +template +void tile(const T* x_data, const std::vector& x_shape, + const std::vector& new_shape, T* out_data) { + const size_t x_rank = x_shape.size(); + const std::vector x_strides = tfjs::util::compute_strides(x_shape); + + const size_t out_size = tfjs::util::size_from_shape(new_shape); + const std::vector out_strides = + tfjs::util::compute_strides(new_shape); + + for (size_t i = 0; i < out_size; ++i) { + const std::vector new_loc = + tfjs::util::offset_to_loc(i, out_strides); + + std::vector original_loc(x_rank); + + for (size_t j = 0; j < original_loc.size(); ++j) { + original_loc[j] = new_loc[j] % x_shape[j]; + } + + const size_t original_index = + tfjs::util::loc_to_offset(original_loc, x_strides); + + out_data[i] = x_data[original_index]; + } +} +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Tile(const size_t x_id, const size_t* x_shape_ptr, + const size_t x_shape_length, const size_t* new_shape_ptr, + const size_t new_shape_length, const DType dtype, + const size_t out_id) { + auto x_shape = std::vector(x_shape_ptr, x_shape_ptr + x_shape_length); + auto new_shape = + std::vector(new_shape_ptr, new_shape_ptr + new_shape_length); + auto& x_info = backend::get_tensor_info(x_id); + auto& out_info = backend::get_tensor_info_out(out_id); + + switch (dtype) { + case DType::float32: + tile(x_info.f32(), x_shape, new_shape, out_info.f32_write()); + break; + case DType::int32: + tile(x_info.i32(), x_shape, new_shape, out_info.i32_write()); + break; + case DType::boolean: + tile(x_info.b(), x_shape, new_shape, out_info.b_write()); + break; + default: + util::warn("Tile for tensor id %d failed. Unknown dtype %d", x_id, dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Tile.ts b/tfjs-backend-wasm/src/kernels/Tile.ts new file mode 100644 index 00000000000..e18ed2e604e --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Tile.ts @@ -0,0 +1,74 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel} from '@tensorflow/tfjs-core'; +import {TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; +import {CppDType} from './types'; + +interface TileInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +interface TileAttrs extends NamedAttrMap { + reps: number[]; +} + +let wasmTile: ( + xId: number, xShape: Uint8Array, xShapeSize: number, newShape: Uint8Array, + newShapeSize: number, dtype: number, outId: number) => void; + +function setup(backend: BackendWasm) { + wasmTile = backend.wasm.cwrap('Tile', null /* void */, [ + 'number', // x_id + 'array', // x_shape + 'number', // x_shape.length + 'array', // new_shape + 'number', // new_shape.length + 'number' // out_id + ]); +} + +function tile( + args: {inputs: TileInputs, backend: BackendWasm, attrs: TileAttrs}) { + const {inputs, backend, attrs} = args; + const {x} = inputs; + const xId = backend.dataIdMap.get(x.dataId).id; + const {reps} = attrs; + + const newShape: number[] = new Array(x.shape.length); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = x.shape[i] * reps[i]; + } + const xShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer); + const newShapeBytes = new Uint8Array(new Int32Array(newShape).buffer); + + const out = backend.makeOutput(newShape, x.dtype); + const outId = backend.dataIdMap.get(out.dataId).id; + wasmTile( + xId, xShapeBytes, x.shape.length, newShapeBytes, newShape.length, + CppDType[out.dtype], outId); + return out; +} + +registerKernel({ + kernelName: 'Tile', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: tile +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 1d773ae98fe..626114d6e37 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -60,5 +60,6 @@ import './Slice'; import './Square'; import './Sub'; import './Sum'; +import './Tile'; import './Transpose'; import './Unpack'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index e2796dc2835..ef7af28fdc8 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -273,6 +273,13 @@ const TEST_FILTERS: TestFilter[] = [ 'broadcasting Tensor3D shapes', // Same as above. 'broadcasting Tensor4D shapes', // Same as above. ] + }, + { + startsWith: 'tile ', + excludes: [ + 'gradient', // Gradient not yet implemented. + 'string tensor' // String tensors not yet implemented. + ] } ]; diff --git a/tfjs-core/benchmarks/index.html b/tfjs-core/benchmarks/index.html index 5211d63b771..5a961bf510b 100644 --- a/tfjs-core/benchmarks/index.html +++ b/tfjs-core/benchmarks/index.html @@ -84,10 +84,10 @@

TensorFlow.js Model Benchmark

- + - + diff --git a/tfjs-core/src/backends/tile_impl.ts b/tfjs-core/src/backends/tile_impl.ts index a1cb8c9121c..dcf4e0d0258 100644 --- a/tfjs-core/src/backends/tile_impl.ts +++ b/tfjs-core/src/backends/tile_impl.ts @@ -35,8 +35,8 @@ export function tile( const newLoc = result.indexToLoc(i); const originalLoc: number[] = new Array(xBuf.rank); - for (let i = 0; i < originalLoc.length; i++) { - originalLoc[i] = newLoc[i] % xBuf.shape[i]; + for (let j = 0; j < originalLoc.length; j++) { + originalLoc[j] = newLoc[j] % xBuf.shape[j]; } const originalIndex = xBuf.locToIndex(originalLoc); diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts index 1b976b726d1..177c5670a97 100644 --- a/tfjs-core/src/ops/array_ops.ts +++ b/tfjs-core/src/ops/array_ops.ts @@ -568,13 +568,15 @@ function tile_(x: T|TensorLike, reps: number[]): T { } return xGrad as T; }; - return {$x: derX}; + return {x: derX}; }; + const inputsToSave = [$x]; + const attrs = {reps}; return ENGINE.runKernelFunc((backend, save) => { const res = backend.tile($x, reps); save([$x]); return res; - }, {$x}, grad); + }, {x: $x}, grad, 'Tile', attrs, inputsToSave); } /** From e554d8f530e75cf67f228641e504474fdff376ab Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Tue, 17 Dec 2019 16:44:02 -0500 Subject: [PATCH 8/8] save --- tfjs-backend-wasm/src/cc/kernels/Tile.cc | 11 ++++++----- tfjs-core/benchmarks/index.html | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/Tile.cc b/tfjs-backend-wasm/src/cc/kernels/Tile.cc index 3b1b95391f0..04b97d9836e 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Tile.cc +++ b/tfjs-backend-wasm/src/cc/kernels/Tile.cc @@ -24,8 +24,8 @@ namespace { template -void tile(const T* x_data, const std::vector& x_shape, - const std::vector& new_shape, T* out_data) { +void tile_slow(const T* x_data, const std::vector& x_shape, + const std::vector& new_shape, T* out_data) { const size_t x_rank = x_shape.size(); const std::vector x_strides = tfjs::util::compute_strides(x_shape); @@ -71,13 +71,14 @@ void Tile(const size_t x_id, const size_t* x_shape_ptr, switch (dtype) { case DType::float32: - tile(x_info.f32(), x_shape, new_shape, out_info.f32_write()); + tile_slow(x_info.f32(), x_shape, new_shape, out_info.f32_write()); break; case DType::int32: - tile(x_info.i32(), x_shape, new_shape, out_info.i32_write()); + tile_slow(x_info.i32(), x_shape, new_shape, + out_info.i32_write()); break; case DType::boolean: - tile(x_info.b(), x_shape, new_shape, out_info.b_write()); + tile_slow(x_info.b(), x_shape, new_shape, out_info.b_write()); break; default: util::warn("Tile for tensor id %d failed. Unknown dtype %d", x_id, dtype); diff --git a/tfjs-core/benchmarks/index.html b/tfjs-core/benchmarks/index.html index 5a961bf510b..5211d63b771 100644 --- a/tfjs-core/benchmarks/index.html +++ b/tfjs-core/benchmarks/index.html @@ -84,10 +84,10 @@

TensorFlow.js Model Benchmark

- + - +