From 45c540182ca1467b7c962fbea73df1fc8acf20db Mon Sep 17 00:00:00 2001 From: Chunnien Chan Date: Sun, 8 Jan 2023 21:36:53 -0800 Subject: [PATCH 1/3] Add acos and acosh kernels --- tfjs-backend-wasm/src/cc/BUILD | 22 +++++++ tfjs-backend-wasm/src/cc/kernels/Acos.cc | 64 +++++++++++++++++++ tfjs-backend-wasm/src/cc/kernels/Acosh.cc | 64 +++++++++++++++++++ tfjs-backend-wasm/src/kernels/Acos.ts | 22 +++++++ tfjs-backend-wasm/src/kernels/Acosh.ts | 22 +++++++ tfjs-backend-wasm/src/register_all_kernels.ts | 4 ++ tfjs-backend-wasm/src/setup_test.ts | 2 + 7 files changed, 200 insertions(+) create mode 100644 tfjs-backend-wasm/src/cc/kernels/Acos.cc create mode 100644 tfjs-backend-wasm/src/cc/kernels/Acosh.cc create mode 100644 tfjs-backend-wasm/src/kernels/Acos.ts create mode 100644 tfjs-backend-wasm/src/kernels/Acosh.ts diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index fb20e020b64..4effdd43425 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -293,6 +293,8 @@ tfjs_cc_library( name = "all_kernels", deps = [ ":Abs", + ":Acos", + ":Acosh", ":Add", ":AddN", ":All", @@ -383,6 +385,26 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Acos", + srcs = ["kernels/Acos.cc"], + deps = [ + ":backend", + ":unary", + ":util", + ], +) + +tfjs_cc_library( + name = "Acosh", + srcs = ["kernels/Acosh.cc"], + deps = [ + ":backend", + ":unary", + ":util", + ], +) + tfjs_cc_library( name = "Add", srcs = ["kernels/Add.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Acos.cc b/tfjs-backend-wasm/src/cc/kernels/Acos.cc new file mode 100644 index 00000000000..f5ce8077d5c --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Acos.cc @@ -0,0 +1,64 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "tfjs-backend-wasm/src/cc/backend.h" +#include "tfjs-backend-wasm/src/cc/unary.h" +#include "tfjs-backend-wasm/src/cc/util.h" + +namespace { + +template +inline T acos_impl(T n) { + return static_cast(std::acosf(static_cast(n))); +} + +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Acos(const int x_id, const DType dtype, const int out_id) { + switch (dtype) { + case DType::float32: + unary_f32(x_id, out_id, acos_impl); + break; + case DType::int32: + unary_i32(x_id, out_id, acos_impl); + break; + case DType::boolean: + unary_bool(x_id, out_id, acos_impl); + break; + default: + util::warn("Acos for tensor id %d failed. Unsupported dtype %d", x_id, + dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Acosh.cc b/tfjs-backend-wasm/src/cc/kernels/Acosh.cc new file mode 100644 index 00000000000..040fb2f0a34 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Acosh.cc @@ -0,0 +1,64 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "tfjs-backend-wasm/src/cc/backend.h" +#include "tfjs-backend-wasm/src/cc/unary.h" +#include "tfjs-backend-wasm/src/cc/util.h" + +namespace { + +template +inline T acosh_impl(T n) { + return static_cast(std::acoshf(static_cast(n))); +} + +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Acosh(const int x_id, const DType dtype, const int out_id) { + switch (dtype) { + case DType::float32: + unary_f32(x_id, out_id, acosh_impl); + break; + case DType::int32: + unary_i32(x_id, out_id, acosh_impl); + break; + case DType::boolean: + unary_bool(x_id, out_id, acosh_impl); + break; + default: + util::warn("Acosh for tensor id %d failed. Unsupported dtype %d", x_id, + dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/Acos.ts b/tfjs-backend-wasm/src/kernels/Acos.ts new file mode 100644 index 00000000000..cc9beb82766 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Acos.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Acos, KernelConfig} from '@tensorflow/tfjs-core'; + +import {createUnaryKernelConfig} from './unary_kernel'; + +export const acosConfig: KernelConfig = createUnaryKernelConfig(Acos); diff --git a/tfjs-backend-wasm/src/kernels/Acosh.ts b/tfjs-backend-wasm/src/kernels/Acosh.ts new file mode 100644 index 00000000000..9fca53298c9 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Acosh.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Acosh, KernelConfig} from '@tensorflow/tfjs-core'; + +import {createUnaryKernelConfig} from './unary_kernel'; + +export const acoshConfig: KernelConfig = createUnaryKernelConfig(Acosh); diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts index 6d45bf403d5..d9390d53ab3 100644 --- a/tfjs-backend-wasm/src/register_all_kernels.ts +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -21,6 +21,8 @@ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; import {_fusedMatMulConfig} from './kernels/_FusedMatMul'; import {absConfig} from './kernels/Abs'; +import {acosConfig} from './kernels/Acos'; +import {acoshConfig} from './kernels/Acosh'; import {addConfig} from './kernels/Add'; import {addNConfig} from './kernels/AddN'; import {allConfig} from './kernels/All'; @@ -136,6 +138,8 @@ import {zerosLikeConfig} from './kernels/ZerosLike'; const kernelConfigs: KernelConfig[] = [ _fusedMatMulConfig, absConfig, + acosConfig, + acoshConfig, addConfig, addNConfig, allConfig, diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 9e28b59c693..33866139bdd 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -403,6 +403,8 @@ const TEST_FILTERS: TestFilter[] = [ {include: 'reciprocal'}, {include: 'isNaN'}, {include: 'atan '}, + {include: 'acos '}, + {include: 'acosh '}, ]; const customInclude = (testName: string) => { From c1de872216f86f231c23f72e8997ec71aef7bbac Mon Sep 17 00:00:00 2001 From: Chunnien Chan Date: Sun, 8 Jan 2023 21:37:02 -0800 Subject: [PATCH 2/3] Update atan kernel --- tfjs-backend-wasm/src/cc/kernels/Atan.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/Atan.cc b/tfjs-backend-wasm/src/cc/kernels/Atan.cc index acecfa5bc3c..c26af72abe1 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Atan.cc +++ b/tfjs-backend-wasm/src/cc/kernels/Atan.cc @@ -28,8 +28,8 @@ namespace { template -inline float atan_impl(T n) { - return std::atanf(static_cast(n)); +inline T atan_impl(T n) { + return static_cast(std::atanf(static_cast(n))); } } // namespace @@ -48,7 +48,10 @@ void Atan(const int x_id, const DType dtype, const int out_id) { unary_f32(x_id, out_id, atan_impl); break; case DType::int32: - unary_i32_with_f32_out(x_id, out_id, atan_impl); + unary_i32(x_id, out_id, atan_impl); + break; + case DType::boolean: + unary_bool(x_id, out_id, atan_impl); break; default: util::warn("Atan for tensor id %d failed. Unsupported dtype %d", x_id, From 8185b6083302c75f5f9303b45667c22a047b374d Mon Sep 17 00:00:00 2001 From: Chunnien Chan Date: Mon, 9 Jan 2023 14:43:30 -0800 Subject: [PATCH 3/3] Remove implementations for unsupported types --- tfjs-backend-wasm/src/cc/kernels/Acos.cc | 3 --- tfjs-backend-wasm/src/cc/kernels/Acosh.cc | 6 ------ tfjs-backend-wasm/src/cc/kernels/Atan.cc | 3 --- 3 files changed, 12 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/Acos.cc b/tfjs-backend-wasm/src/cc/kernels/Acos.cc index f5ce8077d5c..166b2b8c97d 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Acos.cc +++ b/tfjs-backend-wasm/src/cc/kernels/Acos.cc @@ -50,9 +50,6 @@ void Acos(const int x_id, const DType dtype, const int out_id) { case DType::int32: unary_i32(x_id, out_id, acos_impl); break; - case DType::boolean: - unary_bool(x_id, out_id, acos_impl); - break; default: util::warn("Acos for tensor id %d failed. Unsupported dtype %d", x_id, dtype); diff --git a/tfjs-backend-wasm/src/cc/kernels/Acosh.cc b/tfjs-backend-wasm/src/cc/kernels/Acosh.cc index 040fb2f0a34..60d75c36c45 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Acosh.cc +++ b/tfjs-backend-wasm/src/cc/kernels/Acosh.cc @@ -47,12 +47,6 @@ void Acosh(const int x_id, const DType dtype, const int out_id) { case DType::float32: unary_f32(x_id, out_id, acosh_impl); break; - case DType::int32: - unary_i32(x_id, out_id, acosh_impl); - break; - case DType::boolean: - unary_bool(x_id, out_id, acosh_impl); - break; default: util::warn("Acosh for tensor id %d failed. Unsupported dtype %d", x_id, dtype); diff --git a/tfjs-backend-wasm/src/cc/kernels/Atan.cc b/tfjs-backend-wasm/src/cc/kernels/Atan.cc index c26af72abe1..3744ef284b6 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Atan.cc +++ b/tfjs-backend-wasm/src/cc/kernels/Atan.cc @@ -50,9 +50,6 @@ void Atan(const int x_id, const DType dtype, const int out_id) { case DType::int32: unary_i32(x_id, out_id, atan_impl); break; - case DType::boolean: - unary_bool(x_id, out_id, atan_impl); - break; default: util::warn("Atan for tensor id %d failed. Unsupported dtype %d", x_id, dtype);