diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 71a1bf001d2..9ee42a3401a 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -216,6 +216,7 @@ tfjs_cc_library( ":NonMaxSuppressionV3", ":NonMaxSuppressionV5", ":NotEqual", + ":OneHot", ":PadV2", ":Pow", ":Prelu", @@ -653,6 +654,14 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "OneHot", + srcs = ["kernels/OneHot.cc"], + deps = [ + ":backend", + ], +) + tfjs_cc_library( name = "PadV2", srcs = ["kernels/PadV2.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/OneHot.cc b/tfjs-backend-wasm/src/cc/kernels/OneHot.cc new file mode 100644 index 00000000000..1ff8f557759 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/OneHot.cc @@ -0,0 +1,55 @@ +/* Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/backend.h" + +namespace tfjs { +namespace wasm { +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +void OneHot(const size_t indices_id, const size_t depth, const int32_t on_value, + const int32_t off_value, const size_t out_id) { + auto& indices_info = backend::get_tensor_info(indices_id); + auto& out_info = backend::get_tensor_info_out(out_id); + + const int* indices_buf = indices_info.i32(); + const size_t indices_size = indices_info.size; + + int* out_buf = out_info.i32_write(); + const size_t out_size = out_info.size; + + // Initialize output with off_value. + std::fill(out_buf, out_buf + out_size, off_value); + + for (size_t i = 0; i < indices_size; ++i) { + if (indices_buf[i] >= 0 && indices_buf[i] < depth) { + out_buf[i * depth + indices_buf[i]] = on_value; + } + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/OneHot.ts b/tfjs-backend-wasm/src/kernels/OneHot.ts new file mode 100644 index 00000000000..4462347d4de --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/OneHot.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelFunc, OneHot, OneHotAttrs, OneHotInputs, registerKernel} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +let wasmOneHot: ( + indicesId: number, depth: number, onValue: number, offValue: number, + outId: number) => void; + +function setup(backend: BackendWasm) { + wasmOneHot = backend.wasm.cwrap('OneHot', null /* void */, [ + 'number', // indices_id + 'number', // depth, + 'number', // onValue + 'number', // offValue + 'number' // out_id + ]); +} + +function oneHot( + args: {inputs: OneHotInputs, attrs: OneHotAttrs, backend: BackendWasm}) { + const {inputs, backend, attrs} = args; + const {indices} = inputs; + const {depth, onValue, offValue} = attrs; + + const out = backend.makeOutput([...indices.shape, depth], 'int32'); + const outId = backend.dataIdMap.get(out.dataId).id; + + const indicesData = backend.dataIdMap.get(indices.dataId); + const indicesId = indicesData.id; + + wasmOneHot(indicesId, depth, onValue, offValue, outId); + + return out; +} + +registerKernel({ + kernelName: OneHot, + backendName: 'wasm', + setupFunc: setup, + kernelFunc: oneHot as {} as KernelFunc, +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index e03b9a2f3e5..5091d867ba6 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -58,6 +58,7 @@ import './Neg'; import './NonMaxSuppressionV3'; import './NonMaxSuppressionV5'; import './NotEqual'; +import './OneHot'; import './OnesLike'; import './PadV2'; import './Pow'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 8336a0eac83..794acd95e12 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -218,10 +218,8 @@ const TEST_FILTERS: TestFilter[] = [ 'gradient' // Split is not yet implemented ] }, - { - include: 'transpose', - excludes: ['oneHot'] // oneHot not yet implemented. - }, + {include: 'transpose'}, + {include: 'oneHot'}, {include: 'split'}, {include: 'pad ', excludes: ['complex', 'zerosLike']}, {include: 'clip', excludes: ['gradient']}, diff --git a/tfjs-core/src/ops/one_hot.ts b/tfjs-core/src/ops/one_hot.ts index 277bd3e1b03..61d0031dde8 100644 --- a/tfjs-core/src/ops/one_hot.ts +++ b/tfjs-core/src/ops/one_hot.ts @@ -18,7 +18,7 @@ import {ENGINE, ForwardFunc} from '../engine'; import {OneHot, OneHotAttrs, OneHotInputs} from '../kernel_names'; import {NamedAttrMap} from '../kernel_registry'; -import {Tensor, Tensor1D} from '../tensor'; +import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; @@ -50,15 +50,13 @@ function oneHot_( if (depth < 2) { throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`); } - let $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32'); + const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32'); const outShape = [...$indices.shape, depth]; - $indices = $indices.flatten(); const forward: ForwardFunc = (backend, save) => { save([$indices]); return reshape( - backend.oneHot($indices as Tensor1D, depth, onValue, offValue), - outShape); + backend.oneHot($indices.flatten(), depth, onValue, offValue), outShape); }; const inputs: OneHotInputs = {indices: $indices};