diff --git a/tfjs-backend-wasm/src/cc/BUILD.bazel b/tfjs-backend-wasm/src/cc/BUILD.bazel index fb20e020b64..4effdd43425 100644 --- a/tfjs-backend-wasm/src/cc/BUILD.bazel +++ b/tfjs-backend-wasm/src/cc/BUILD.bazel @@ -293,6 +293,8 @@ tfjs_cc_library( name = "all_kernels", deps = [ ":Abs", + ":Acos", + ":Acosh", ":Add", ":AddN", ":All", @@ -383,6 +385,26 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "Acos", + srcs = ["kernels/Acos.cc"], + deps = [ + ":backend", + ":unary", + ":util", + ], +) + +tfjs_cc_library( + name = "Acosh", + srcs = ["kernels/Acosh.cc"], + deps = [ + ":backend", + ":unary", + ":util", + ], +) + tfjs_cc_library( name = "Add", srcs = ["kernels/Add.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/Acos.cc b/tfjs-backend-wasm/src/cc/kernels/Acos.cc new file mode 100644 index 00000000000..166b2b8c97d --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Acos.cc @@ -0,0 +1,61 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "tfjs-backend-wasm/src/cc/backend.h" +#include "tfjs-backend-wasm/src/cc/unary.h" +#include "tfjs-backend-wasm/src/cc/util.h" + +namespace { + +template +inline T acos_impl(T n) { + return static_cast(std::acosf(static_cast(n))); +} + +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Acos(const int x_id, const DType dtype, const int out_id) { + switch (dtype) { + case DType::float32: + unary_f32(x_id, out_id, acos_impl); + break; + case DType::int32: + unary_i32(x_id, out_id, acos_impl); + break; + default: + util::warn("Acos for tensor id %d failed. Unsupported dtype %d", x_id, + dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Acosh.cc b/tfjs-backend-wasm/src/cc/kernels/Acosh.cc new file mode 100644 index 00000000000..60d75c36c45 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/Acosh.cc @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "tfjs-backend-wasm/src/cc/backend.h" +#include "tfjs-backend-wasm/src/cc/unary.h" +#include "tfjs-backend-wasm/src/cc/util.h" + +namespace { + +template +inline T acosh_impl(T n) { + return static_cast(std::acoshf(static_cast(n))); +} + +} // namespace + +namespace tfjs { +namespace wasm { +// We use C-style API to interface with Javascript. +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif +void Acosh(const int x_id, const DType dtype, const int out_id) { + switch (dtype) { + case DType::float32: + unary_f32(x_id, out_id, acosh_impl); + break; + default: + util::warn("Acosh for tensor id %d failed. Unsupported dtype %d", x_id, + dtype); + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/cc/kernels/Atan.cc b/tfjs-backend-wasm/src/cc/kernels/Atan.cc index acecfa5bc3c..3744ef284b6 100644 --- a/tfjs-backend-wasm/src/cc/kernels/Atan.cc +++ b/tfjs-backend-wasm/src/cc/kernels/Atan.cc @@ -28,8 +28,8 @@ namespace { template -inline float atan_impl(T n) { - return std::atanf(static_cast(n)); +inline T atan_impl(T n) { + return static_cast(std::atanf(static_cast(n))); } } // namespace @@ -48,7 +48,7 @@ void Atan(const int x_id, const DType dtype, const int out_id) { unary_f32(x_id, out_id, atan_impl); break; case DType::int32: - unary_i32_with_f32_out(x_id, out_id, atan_impl); + unary_i32(x_id, out_id, atan_impl); break; default: util::warn("Atan for tensor id %d failed. Unsupported dtype %d", x_id, diff --git a/tfjs-backend-wasm/src/kernels/Acos.ts b/tfjs-backend-wasm/src/kernels/Acos.ts new file mode 100644 index 00000000000..cc9beb82766 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Acos.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Acos, KernelConfig} from '@tensorflow/tfjs-core'; + +import {createUnaryKernelConfig} from './unary_kernel'; + +export const acosConfig: KernelConfig = createUnaryKernelConfig(Acos); diff --git a/tfjs-backend-wasm/src/kernels/Acosh.ts b/tfjs-backend-wasm/src/kernels/Acosh.ts new file mode 100644 index 00000000000..9fca53298c9 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Acosh.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Acosh, KernelConfig} from '@tensorflow/tfjs-core'; + +import {createUnaryKernelConfig} from './unary_kernel'; + +export const acoshConfig: KernelConfig = createUnaryKernelConfig(Acosh); diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts index 6d45bf403d5..d9390d53ab3 100644 --- a/tfjs-backend-wasm/src/register_all_kernels.ts +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -21,6 +21,8 @@ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; import {_fusedMatMulConfig} from './kernels/_FusedMatMul'; import {absConfig} from './kernels/Abs'; +import {acosConfig} from './kernels/Acos'; +import {acoshConfig} from './kernels/Acosh'; import {addConfig} from './kernels/Add'; import {addNConfig} from './kernels/AddN'; import {allConfig} from './kernels/All'; @@ -136,6 +138,8 @@ import {zerosLikeConfig} from './kernels/ZerosLike'; const kernelConfigs: KernelConfig[] = [ _fusedMatMulConfig, absConfig, + acosConfig, + acoshConfig, addConfig, addNConfig, allConfig, diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 9e28b59c693..33866139bdd 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -403,6 +403,8 @@ const TEST_FILTERS: TestFilter[] = [ {include: 'reciprocal'}, {include: 'isNaN'}, {include: 'atan '}, + {include: 'acos '}, + {include: 'acosh '}, ]; const customInclude = (testName: string) => {