diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index d1a5e831621..efa73a29516 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -158,6 +158,8 @@ tfjs_cc_library( ":FusedDepthwiseConv2D", ":Greater", ":GreaterEqual", + ":Less", + ":LessEqual", ":Max", ":MaxPool", ":Maximum", @@ -172,6 +174,7 @@ tfjs_cc_library( ":ResizeBilinear", ":Sigmoid", ":Sub", + ":Tile", ":Transpose", ], ) @@ -403,6 +406,36 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Less", + srcs = ["kernels/Less.cc"], + deps = [ + ":backend", + ":binary", + ":util", + ], +) + +tfjs_cc_library( + name = "LessEqual", + srcs = ["kernels/LessEqual.cc"], + deps = [ + ":backend", + ":binary", + ":util", + ], +) + +tfjs_cc_library( + name = "LogicalAnd", + srcs = ["kernels/LogicalAnd.cc"], + deps = [ + ":backend", + ":binary", + ":util", + ], +) + tfjs_cc_library( name = "Log", srcs = ["kernels/Log.cc"], @@ -581,6 +614,15 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Tile", + srcs = ["kernels/Tile.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "Transpose", srcs = ["kernels/Transpose.cc"], diff --git a/tfjs-backend-wasm/src/cc/binary.h b/tfjs-backend-wasm/src/cc/binary.h index f6952fc4992..d005ebd10b3 100644 --- a/tfjs-backend-wasm/src/cc/binary.h +++ b/tfjs-backend-wasm/src/cc/binary.h @@ -87,6 +87,15 @@ inline void compare_bool(const int a_id, const int b_id, const int out_id, out_info.b_write(), operation); } +inline void logical(const int a_id, const int b_id, const int out_id, + bool operation(bool, bool)) { + auto& a_info = backend::get_tensor_info(a_id); + auto& b_info = backend::get_tensor_info(b_id); + auto& out_info = backend::get_tensor_info_out(out_id); + binary_impl(a_info.b(), a_info.size, b_info.b(), b_info.size, + out_info.b_write(), operation); +} + typedef xnn_status (*xnn_create_binary_op)(float, float, uint32_t, xnn_operator_t*); typedef xnn_status (*xnn_setup_binary_op)(xnn_operator_t, size_t, const size_t*, diff --git a/tfjs-backend-wasm/src/cc/kernels/Less.cc b/tfjs-backend-wasm/src/cc/kernels/Less.cc new file mode 100644 index 00000000000..73883506a64 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Less.cc @@ -0,0 +1,59 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace { +template +inline bool less(T a, T b) { + return a < b; +} +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Less(const int a_id, const size_t* a_shape_ptr, const int a_shape_len, + const int b_id, const size_t* b_shape_ptr, const int b_shape_len, + const DType input_type, const int out_id) { + switch (input_type) { + case DType::float32: + compare_f32(a_id, b_id, out_id, less); + break; + case DType::int32: + compare_i32(a_id, b_id, out_id, less); + break; + case DType::boolean: + compare_bool(a_id, b_id, out_id, less); + break; + default: + util::warn( + "Less for tensor ids %d and %d failed. Unsupported input_type %d", + a_id, b_id, input_type); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/LessEqual.cc b/tfjs-backend-wasm/src/cc/kernels/LessEqual.cc new file mode 100644 index 00000000000..bcf99fe1417 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/LessEqual.cc @@ -0,0 +1,60 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace { +template +inline bool lessEqual(T a, T b) { + return a <= b; +} +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void LessEqual(const int a_id, const size_t* a_shape_ptr, const int a_shape_len, + const int b_id, const size_t* b_shape_ptr, const int b_shape_len, + const DType input_type, const int out_id) { + switch (input_type) { + case DType::float32: + compare_f32(a_id, b_id, out_id, lessEqual); + break; + case DType::int32: + compare_i32(a_id, b_id, out_id, lessEqual); + break; + case DType::boolean: + compare_bool(a_id, b_id, out_id, lessEqual); + break; + default: + util::warn( + "LessEqual for tensor ids %d and %d failed." + "Unsupported input_type %d", + a_id, b_id, input_type); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/LogicalAnd.cc b/tfjs-backend-wasm/src/cc/kernels/LogicalAnd.cc new file mode 100644 index 00000000000..0ee8e2946bd --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/LogicalAnd.cc @@ -0,0 +1,52 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/binary.h" +#include "src/cc/util.h" + +namespace { +inline bool logical_and(bool a, bool b) { return a && b; } +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void LogicalAnd(const int a_id, const size_t* a_shape_ptr, + const int a_shape_len, const int b_id, + const size_t* b_shape_ptr, const int b_shape_len, + const DType input_type, const int out_id) { + switch (input_type) { + case DType::boolean: + compare_bool(a_id, b_id, out_id, logical_and); + break; + default: + util::warn( + "LogicalAnd for tensor ids %d and %d failed. Unsupported input_type " + "%d", + a_id, b_id, input_type); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Tile.cc b/tfjs-backend-wasm/src/cc/kernels/Tile.cc new file mode 100644 index 00000000000..04b97d9836e --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Tile.cc @@ -0,0 +1,90 @@ +/* Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +namespace { +template +void tile_slow(const T* x_data, const std::vector& x_shape, + const std::vector& new_shape, T* out_data) { + const size_t x_rank = x_shape.size(); + const std::vector x_strides = tfjs::util::compute_strides(x_shape); + + const size_t out_size = tfjs::util::size_from_shape(new_shape); + const std::vector out_strides = + tfjs::util::compute_strides(new_shape); + + for (size_t i = 0; i < out_size; ++i) { + const std::vector new_loc = + tfjs::util::offset_to_loc(i, out_strides); + + std::vector original_loc(x_rank); + + for (size_t j = 0; j < original_loc.size(); ++j) { + original_loc[j] = new_loc[j] % x_shape[j]; + } + + const size_t original_index = + tfjs::util::loc_to_offset(original_loc, x_strides); + + out_data[i] = x_data[original_index]; + } +} +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Tile(const size_t x_id, const size_t* x_shape_ptr, + const size_t x_shape_length, const size_t* new_shape_ptr, + const size_t new_shape_length, const DType dtype, + const size_t out_id) { + auto x_shape = std::vector(x_shape_ptr, x_shape_ptr + x_shape_length); + auto new_shape = + std::vector(new_shape_ptr, new_shape_ptr + new_shape_length); + auto& x_info = backend::get_tensor_info(x_id); + auto& out_info = backend::get_tensor_info_out(out_id); + + switch (dtype) { + case DType::float32: + tile_slow(x_info.f32(), x_shape, new_shape, out_info.f32_write()); + break; + case DType::int32: + tile_slow(x_info.i32(), x_shape, new_shape, + out_info.i32_write()); + break; + case DType::boolean: + tile_slow(x_info.b(), x_shape, new_shape, out_info.b_write()); + break; + default: + util::warn("Tile for tensor id %d failed. Unknown dtype %d", x_id, dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Add.ts b/tfjs-backend-wasm/src/kernels/Add.ts index 41b984afaf3..2a57026ea95 100644 --- a/tfjs-backend-wasm/src/kernels/Add.ts +++ b/tfjs-backend-wasm/src/kernels/Add.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = true; -registerBinaryKernel('Add', supportsBroadcast); +const supportsFullBroadcast = true; +registerBinaryKernel('Add', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Div.ts b/tfjs-backend-wasm/src/kernels/Div.ts index 0e0780d4caa..7baef07b5f5 100644 --- a/tfjs-backend-wasm/src/kernels/Div.ts +++ b/tfjs-backend-wasm/src/kernels/Div.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('Div', supportsBroadcast); +const supportsFullBroadcast = false; +registerBinaryKernel('Div', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/FloorDiv.ts b/tfjs-backend-wasm/src/kernels/FloorDiv.ts index 9e8b15d4d7e..7b3d852eda7 100644 --- a/tfjs-backend-wasm/src/kernels/FloorDiv.ts +++ b/tfjs-backend-wasm/src/kernels/FloorDiv.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('FloorDiv', supportsBroadcast); +const supportsFullBroadcast = false; +registerBinaryKernel('FloorDiv', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Greater.ts b/tfjs-backend-wasm/src/kernels/Greater.ts index e145cd17d4c..286b1631d6c 100644 --- a/tfjs-backend-wasm/src/kernels/Greater.ts +++ b/tfjs-backend-wasm/src/kernels/Greater.ts @@ -15,6 +15,6 @@ * ============================================================================= */ -import { registerBinaryKernel } from './binary_kernel'; -const supportsBroadcast = true; -registerBinaryKernel('Greater', supportsBroadcast, 'bool'); +import {registerBinaryKernel} from './binary_kernel'; +const supportsFullBroadcast = false; +registerBinaryKernel('Greater', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/GreaterEqual.ts b/tfjs-backend-wasm/src/kernels/GreaterEqual.ts index 9563c4b716d..7ba75e543e5 100644 --- a/tfjs-backend-wasm/src/kernels/GreaterEqual.ts +++ b/tfjs-backend-wasm/src/kernels/GreaterEqual.ts @@ -15,6 +15,6 @@ * ============================================================================= */ -import { registerBinaryKernel } from './binary_kernel'; -const supportsBroadcast = true; -registerBinaryKernel('GreaterEqual', supportsBroadcast, 'bool'); +import {registerBinaryKernel} from './binary_kernel'; +const supportsFullBroadcast = false; +registerBinaryKernel('GreaterEqual', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/Less.ts b/tfjs-backend-wasm/src/kernels/Less.ts new file mode 100644 index 00000000000..321c07cc998 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Less.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerBinaryKernel} from './binary_kernel'; +const supportsFullBroadcast = false; +registerBinaryKernel('Less', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/LessEqual.ts b/tfjs-backend-wasm/src/kernels/LessEqual.ts new file mode 100644 index 00000000000..9ca0688fa58 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/LessEqual.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerBinaryKernel} from './binary_kernel'; +const supportsFullBroadcast = false; +registerBinaryKernel('LessEqual', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/LogicalAnd.ts b/tfjs-backend-wasm/src/kernels/LogicalAnd.ts new file mode 100644 index 00000000000..4aa300e9608 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/LogicalAnd.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {registerBinaryKernel} from './binary_kernel'; +const supportsFullBroadcast = false; +registerBinaryKernel('LogicalAnd', supportsFullBroadcast, 'bool'); diff --git a/tfjs-backend-wasm/src/kernels/Maximum.ts b/tfjs-backend-wasm/src/kernels/Maximum.ts index 647c54f8977..250834101a7 100644 --- a/tfjs-backend-wasm/src/kernels/Maximum.ts +++ b/tfjs-backend-wasm/src/kernels/Maximum.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('Maximum', supportsBroadcast); +const supportsFullBroadcast = false; +registerBinaryKernel('Maximum', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Minimum.ts b/tfjs-backend-wasm/src/kernels/Minimum.ts index 0c167b7c318..e266420b7e6 100644 --- a/tfjs-backend-wasm/src/kernels/Minimum.ts +++ b/tfjs-backend-wasm/src/kernels/Minimum.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = false; -registerBinaryKernel('Minimum', supportsBroadcast); +const supportsFullBroadcast = false; +registerBinaryKernel('Minimum', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Mul.ts b/tfjs-backend-wasm/src/kernels/Mul.ts index f50cb0cc105..bf93abf9fb6 100644 --- a/tfjs-backend-wasm/src/kernels/Mul.ts +++ b/tfjs-backend-wasm/src/kernels/Mul.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = true; -registerBinaryKernel('Mul', supportsBroadcast); +const supportsFullBroadcast = true; +registerBinaryKernel('Mul', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Sub.ts b/tfjs-backend-wasm/src/kernels/Sub.ts index 80768ace259..935d4704d62 100644 --- a/tfjs-backend-wasm/src/kernels/Sub.ts +++ b/tfjs-backend-wasm/src/kernels/Sub.ts @@ -16,5 +16,5 @@ */ import {registerBinaryKernel} from './binary_kernel'; -const supportsBroadcast = true; -registerBinaryKernel('Sub', supportsBroadcast); +const supportsFullBroadcast = true; +registerBinaryKernel('Sub', supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/kernels/Tile.ts b/tfjs-backend-wasm/src/kernels/Tile.ts new file mode 100644 index 00000000000..e18ed2e604e --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Tile.ts @@ -0,0 +1,74 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel} from '@tensorflow/tfjs-core'; +import {TensorInfo} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; +import {CppDType} from './types'; + +interface TileInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +interface TileAttrs extends NamedAttrMap { + reps: number[]; +} + +let wasmTile: ( + xId: number, xShape: Uint8Array, xShapeSize: number, newShape: Uint8Array, + newShapeSize: number, dtype: number, outId: number) => void; + +function setup(backend: BackendWasm) { + wasmTile = backend.wasm.cwrap('Tile', null /* void */, [ + 'number', // x_id + 'array', // x_shape + 'number', // x_shape.length + 'array', // new_shape + 'number', // new_shape.length + 'number' // out_id + ]); +} + +function tile( + args: {inputs: TileInputs, backend: BackendWasm, attrs: TileAttrs}) { + const {inputs, backend, attrs} = args; + const {x} = inputs; + const xId = backend.dataIdMap.get(x.dataId).id; + const {reps} = attrs; + + const newShape: number[] = new Array(x.shape.length); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = x.shape[i] * reps[i]; + } + const xShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer); + const newShapeBytes = new Uint8Array(new Int32Array(newShape).buffer); + + const out = backend.makeOutput(newShape, x.dtype); + const outId = backend.dataIdMap.get(out.dataId).id; + wasmTile( + xId, xShapeBytes, x.shape.length, newShapeBytes, newShape.length, + CppDType[out.dtype], outId); + return out; +} + +registerKernel({ + kernelName: 'Tile', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: tile +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index a0941aa873c..626114d6e37 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -38,6 +38,9 @@ import './FusedConv2D'; import './FusedDepthwiseConv2D'; import './Greater'; import './GreaterEqual'; +import './LogicalAnd'; +import './Less'; +import './LessEqual'; import './Log'; import './Max'; import './Maximum'; @@ -57,5 +60,6 @@ import './Slice'; import './Square'; import './Sub'; import './Sum'; +import './Tile'; import './Transpose'; import './Unpack'; diff --git a/tfjs-backend-wasm/src/kernels/binary_kernel.ts b/tfjs-backend-wasm/src/kernels/binary_kernel.ts index 25723c5119f..da7cd22f274 100644 --- a/tfjs-backend-wasm/src/kernels/binary_kernel.ts +++ b/tfjs-backend-wasm/src/kernels/binary_kernel.ts @@ -21,7 +21,7 @@ import {BackendWasm} from '../backend_wasm'; import {CppDType} from './types'; export function registerBinaryKernel( - kernelName: string, supportsBroadcast: boolean, dtype?: DataType) { + kernelName: string, supportsFullBroadcast: boolean, dtype?: DataType) { let wasmFunc: (aId: number, aShape: Uint8Array, aShapeLen: number, bId: number, bShape: Uint8Array, bShapeLen: number, dtype: number, outId: number) => @@ -63,7 +63,7 @@ export function registerBinaryKernel( aId, aShapeBytes, a.shape.length, bId, bShapeBytes, b.shape.length, CppDType[a.dtype], outId); - if (supportsBroadcast) { + if (supportsFullBroadcast) { kernelFunc(); return out; } @@ -76,7 +76,9 @@ export function registerBinaryKernel( kernelFunc(); return out; } else { - throw new Error('Broadcasting along inner dims is not yet supported'); + throw new Error( + `Broadcasting along outer dims is not yet ` + + `supported for ${kernelName}.`); } } diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 2e257e6f90a..ef7af28fdc8 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -221,7 +221,7 @@ const TEST_FILTERS: TestFilter[] = [ { include: 'greater ', excludes: [ - 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not + 'broadcasting Tensor2D shapes', // Broadcasting along outer dims not // supported yet. 'broadcasting Tensor3D shapes', // Same as above. 'broadcasting Tensor4D shapes' // Same as above. @@ -231,19 +231,56 @@ const TEST_FILTERS: TestFilter[] = [ include: 'greaterEqual', excludes: [ 'gradient', // Not yet implemented. - 'broadcasting Tensor2D shapes', // Broadcasting along inner dims not + 'broadcasting Tensor2D shapes', // Broadcasting along outer dims not // supported yet. 'broadcasting Tensor3D shapes', // Same as above. 'broadcasting Tensor4D shapes' // Same as above. ] }, + { + include: 'less ', + excludes: [ + 'broadcasting Tensor2D shapes', // Broadcasting along outer dims not + // supported yet. + 'broadcasting Tensor3D shapes', // Same as above. + 'broadcasting Tensor3D float32', // Same as above. + 'broadcasting Tensor4D shapes' // Same as above. + ] + }, + { + include: 'lessEqual', + excludes: [ + 'gradient', // Not yet implemented. + 'broadcasting Tensor2D shapes', // Broadcasting along outer dims not + // supported yet. + 'broadcasting Tensor3D shapes', // Same as above. + 'broadcasting Tensor3D float32', // Same as above. + 'broadcasting Tensor4D shapes' // Same as above. + ] + }, { include: 'mean ', excludes: [ 'axis=0', // Reduction not supported along inner dimensions. ] }, - {startsWith: 'sum '} + {startsWith: 'sum '}, + { + startsWith: 'logicalAnd ', + excludes: [ + 'broadcasting Tensor2D shapes', // Broadcasting along outer dimensions + // not yet supported. + 'broadcasting Tensor3D shapes', // Same as above. + 'broadcasting Tensor4D shapes', // Same as above. + ] + }, + { + startsWith: 'tile ', + excludes: [ + 'gradient', // Gradient not yet implemented. + 'string tensor' // String tensors not yet implemented. + ] + } ]; const customInclude = (testName: string) => { diff --git a/tfjs-core/src/backends/tile_impl.ts b/tfjs-core/src/backends/tile_impl.ts index a1cb8c9121c..dcf4e0d0258 100644 --- a/tfjs-core/src/backends/tile_impl.ts +++ b/tfjs-core/src/backends/tile_impl.ts @@ -35,8 +35,8 @@ export function tile( const newLoc = result.indexToLoc(i); const originalLoc: number[] = new Array(xBuf.rank); - for (let i = 0; i < originalLoc.length; i++) { - originalLoc[i] = newLoc[i] % xBuf.shape[i]; + for (let j = 0; j < originalLoc.length; j++) { + originalLoc[j] = newLoc[j] % xBuf.shape[j]; } const originalIndex = xBuf.locToIndex(originalLoc); diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts index 1b976b726d1..177c5670a97 100644 --- a/tfjs-core/src/ops/array_ops.ts +++ b/tfjs-core/src/ops/array_ops.ts @@ -568,13 +568,15 @@ function tile_(x: T|TensorLike, reps: number[]): T { } return xGrad as T; }; - return {$x: derX}; + return {x: derX}; }; + const inputsToSave = [$x]; + const attrs = {reps}; return ENGINE.runKernelFunc((backend, save) => { const res = backend.tile($x, reps); save([$x]); return res; - }, {$x}, grad); + }, {x: $x}, grad, 'Tile', attrs, inputsToSave); } /** diff --git a/tfjs-core/src/ops/compare.ts b/tfjs-core/src/ops/compare.ts index ffbab4792b2..7b32b7d8b89 100644 --- a/tfjs-core/src/ops/compare.ts +++ b/tfjs-core/src/ops/compare.ts @@ -90,7 +90,9 @@ function less_( [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); - return ENGINE.runKernelFunc(backend => backend.less($a, $b), {$a, $b}) as T; + return ENGINE.runKernelFunc( + backend => backend.less($a, $b), {a: $a, b: $b}, null /* grad */, + 'Less') as T; } /** @@ -166,8 +168,11 @@ function lessEqual_( [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); - return ENGINE.runKernelFunc(backend => backend.lessEqual($a, $b), {$a, $b}) as - T; + return ENGINE.runKernelFunc((backend, save) => { + const res = backend.lessEqual($a, $b); + save([$a, $b]); + return res; + }, {a: $a, b: $b}, null /* grad */, 'LessEqual') as T; } function lessEqualStrict_( @@ -203,9 +208,8 @@ function greater_( assertAndGetBroadcastShape($a.shape, $b.shape); return ENGINE.runKernelFunc( - backend => backend.greater($a, $b), - {a: $a, b: $b}, null /* grad */, 'Greater' - ) as T; + backend => backend.greater($a, $b), {a: $a, b: $b}, + null /* grad */, 'Greater') as T; } function greaterStrict_(a: T|TensorLike, b: T|TensorLike): T { diff --git a/tfjs-core/src/ops/logical_ops.ts b/tfjs-core/src/ops/logical_ops.ts index 3b83d2aa6d9..9000a669705 100644 --- a/tfjs-core/src/ops/logical_ops.ts +++ b/tfjs-core/src/ops/logical_ops.ts @@ -63,7 +63,8 @@ function logicalAnd_( assertAndGetBroadcastShape($a.shape, $b.shape); return ENGINE.runKernelFunc( - backend => backend.logicalAnd($a, $b), {$a, $b}) as T; + backend => backend.logicalAnd($a, $b), {a: $a, b: $b}, + null /* grad */, 'LogicalAnd') as T; } /**