diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index f4fa895d2dd..48463084eb4 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -258,6 +258,8 @@ tfjs_cc_library( ":SelectV2", ":Sigmoid", ":Softmax", + ":Square", + ":SquaredDifference", ":Sub", ":Tile", ":Transpose", @@ -894,6 +896,17 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "SquaredDifference", + srcs = ["kernels/SquaredDifference.cc"], + deps = [ + ":backend", + ":binary", + ":unary", + ":util", + ], +) + tfjs_cc_library( name = "Sub", srcs = ["kernels/Sub.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/SquaredDifference.cc b/tfjs-backend-wasm/src/cc/kernels/SquaredDifference.cc new file mode 100644 index 00000000000..7a40f0ee007 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/SquaredDifference.cc @@ -0,0 +1,68 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif +#include + +#include + +#include "src/cc/binary.h" +#include "src/cc/unary.h" +#include "src/cc/util.h" + +namespace { +template +inline T squared_diff(T a, T b) { + return (a - b) * (a - b); +} +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void SquaredDifference( + const size_t a_id, const size_t* a_shape_ptr, const size_t a_shape_len, + const size_t b_id, const size_t* b_shape_ptr, const size_t b_shape_len, + const DType dtype, const size_t out_id) { + switch (dtype) { + case DType::float32: + binary_xnn_f32(a_id, a_shape_ptr, a_shape_len, b_id, b_shape_ptr, + b_shape_len, out_id, xnn_create_subtract_nd_f32, + xnn_setup_subtract_nd_f32); + unary_xnn_f32(out_id, out_id, xnn_create_square_nc_f32, + xnn_setup_square_nc_f32); + break; + case DType::int32: + binary_i32(a_id, b_id, out_id, squared_diff); + break; + case DType::boolean: + binary_bool(a_id, b_id, out_id, squared_diff); + break; + default: + util::warn("SquaredDifference for tensor ids %d and %d failed. " + "Unknown dtype %d", a_id, b_id, dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs + diff --git a/tfjs-backend-wasm/src/kernels/SquaredDifference.ts b/tfjs-backend-wasm/src/kernels/SquaredDifference.ts new file mode 100644 index 00000000000..e0127f812c9 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/SquaredDifference.ts @@ -0,0 +1,21 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {KernelConfig, SquaredDifference} from '@tensorflow/tfjs-core'; +import {createBinaryKernelConfig} from './binary_kernel'; +const supportsFullBroadcast = true; +export const squaredDifferenceConfig: KernelConfig = + createBinaryKernelConfig(SquaredDifference, supportsFullBroadcast); diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts index 21b9926cacd..b21337772f5 100644 --- a/tfjs-backend-wasm/src/register_all_kernels.ts +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -85,6 +85,7 @@ import {softmaxConfig} from './kernels/Softmax'; import {splitVConfig} from './kernels/Split'; import {sqrtConfig} from './kernels/Sqrt'; import {squareConfig} from './kernels/Square'; +import {squaredDifferenceConfig} from './kernels/SquaredDifference'; import {subConfig} from './kernels/Sub'; import {sumConfig} from './kernels/Sum'; import {tanhConfig} from './kernels/Tanh'; @@ -161,6 +162,7 @@ const kernelConfigs: KernelConfig[] = [ splitVConfig, sqrtConfig, squareConfig, + squaredDifferenceConfig, subConfig, sumConfig, tanhConfig, diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 7fa56fa5254..9f84a680730 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -194,6 +194,7 @@ const TEST_FILTERS: TestFilter[] = [ {include: 'rotate '}, {include: 'flipLeftRight '}, {include: 'square '}, + {include: 'squaredDifference'}, { startsWith: 'min ', excludes: [