From 9941ef50e0daa4069bf384543e7111492e1a2a78 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Mon, 8 Jun 2020 11:30:59 -0700 Subject: [PATCH 1/3] {WASM} Add OneHot op --- tfjs-backend-wasm/src/cc/BUILD | 9 +++ tfjs-backend-wasm/src/cc/kernels/OneHot.cc | 55 +++++++++++++++++++ tfjs-backend-wasm/src/kernels/OneHot.ts | 58 ++++++++++++++++++++ tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 + tfjs-backend-wasm/src/setup_test.ts | 6 +- tfjs-core/src/ops/one_hot.ts | 7 +-- 6 files changed, 128 insertions(+), 8 deletions(-) create mode 100644 tfjs-backend-wasm/src/cc/kernels/OneHot.cc create mode 100644 tfjs-backend-wasm/src/kernels/OneHot.ts diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 71a1bf001d2..26113879d93 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -216,6 +216,7 @@ tfjs_cc_library( ":NonMaxSuppressionV3", ":NonMaxSuppressionV5", ":NotEqual", + ":OneHot", ":PadV2", ":Pow", ":Prelu", @@ -643,6 +644,14 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "OneHot", + srcs = ["kernels/OneHot.cc"], + deps = [ + ":backend", + ], +) + tfjs_cc_library( name = "NonMaxSuppressionV5", srcs = ["kernels/NonMaxSuppressionV5.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/OneHot.cc b/tfjs-backend-wasm/src/cc/kernels/OneHot.cc new file mode 100644 index 00000000000..1ff8f557759 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/OneHot.cc @@ -0,0 +1,55 @@ +/* Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include "src/cc/backend.h" + +namespace tfjs { +namespace wasm { +extern "C" { + +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +void OneHot(const size_t indices_id, const size_t depth, const int32_t on_value, + const int32_t off_value, const size_t out_id) { + auto& indices_info = backend::get_tensor_info(indices_id); + auto& out_info = backend::get_tensor_info_out(out_id); + + const int* indices_buf = indices_info.i32(); + const size_t indices_size = indices_info.size; + + int* out_buf = out_info.i32_write(); + const size_t out_size = out_info.size; + + // Initialize output with off_value. + std::fill(out_buf, out_buf + out_size, off_value); + + for (size_t i = 0; i < indices_size; ++i) { + if (indices_buf[i] >= 0 && indices_buf[i] < depth) { + out_buf[i * depth + indices_buf[i]] = on_value; + } + } +} + +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/OneHot.ts b/tfjs-backend-wasm/src/kernels/OneHot.ts new file mode 100644 index 00000000000..380d1fb59a7 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/OneHot.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelFunc, OneHotAttrs, OneHotInputs, registerKernel} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +let wasmOneHot: ( + indicesId: number, depth: number, onValue: number, offValue: number, + outId: number) => void; + +function setup(backend: BackendWasm) { + wasmOneHot = backend.wasm.cwrap('OneHot', null /* void */, [ + 'number', // indices_id + 'number', // depth, + 'number', // onValue + 'number', // offValue + 'number' // out_id + ]); +} + +function oneHot( + args: {inputs: OneHotInputs, attrs: OneHotAttrs, backend: BackendWasm}) { + const {inputs, backend, attrs} = args; + const {indices} = inputs; + const {depth, onValue, offValue} = attrs; + + const out = backend.makeOutput([...indices.shape, depth], 'int32'); + const outId = backend.dataIdMap.get(out.dataId).id; + + const indicesData = backend.dataIdMap.get(indices.dataId); + const indicesId = indicesData.id; + + wasmOneHot(indicesId, depth, onValue, offValue, outId); + + return out; +} + +registerKernel({ + kernelName: 'OneHot', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: oneHot as {} as KernelFunc, +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index e03b9a2f3e5..5091d867ba6 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -58,6 +58,7 @@ import './Neg'; import './NonMaxSuppressionV3'; import './NonMaxSuppressionV5'; import './NotEqual'; +import './OneHot'; import './OnesLike'; import './PadV2'; import './Pow'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 8336a0eac83..794acd95e12 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -218,10 +218,8 @@ const TEST_FILTERS: TestFilter[] = [ 'gradient' // Split is not yet implemented ] }, - { - include: 'transpose', - excludes: ['oneHot'] // oneHot not yet implemented. - }, + {include: 'transpose'}, + {include: 'oneHot'}, {include: 'split'}, {include: 'pad ', excludes: ['complex', 'zerosLike']}, {include: 'clip', excludes: ['gradient']}, diff --git a/tfjs-core/src/ops/one_hot.ts b/tfjs-core/src/ops/one_hot.ts index 277bd3e1b03..d0d4ecb8baa 100644 --- a/tfjs-core/src/ops/one_hot.ts +++ b/tfjs-core/src/ops/one_hot.ts @@ -56,17 +56,16 @@ function oneHot_( const forward: ForwardFunc = (backend, save) => { save([$indices]); - return reshape( - backend.oneHot($indices as Tensor1D, depth, onValue, offValue), - outShape); + return backend.oneHot($indices as Tensor1D, depth, onValue, offValue); }; const inputs: OneHotInputs = {indices: $indices}; const attrs: OneHotAttrs = {depth, onValue, offValue}; - return ENGINE.runKernelFunc( + const res = ENGINE.runKernelFunc( forward, inputs as unknown as NamedTensorMap, null /* grad */, OneHot, attrs as unknown as NamedAttrMap); + return reshape(res, outShape); } export const oneHot = op({oneHot_}); From f3d6b39b377934b2432fa145a402a57b2f094900 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Mon, 15 Jun 2020 11:08:04 -0700 Subject: [PATCH 2/3] Reorder onehot in build file --- tfjs-backend-wasm/src/cc/BUILD | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 26113879d93..9ee42a3401a 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -645,20 +645,20 @@ tfjs_cc_library( ) tfjs_cc_library( - name = "OneHot", - srcs = ["kernels/OneHot.cc"], + name = "NonMaxSuppressionV5", + srcs = ["kernels/NonMaxSuppressionV5.cc"], deps = [ ":backend", + ":non_max_suppression_impl", + ":util", ], ) tfjs_cc_library( - name = "NonMaxSuppressionV5", - srcs = ["kernels/NonMaxSuppressionV5.cc"], + name = "OneHot", + srcs = ["kernels/OneHot.cc"], deps = [ ":backend", - ":non_max_suppression_impl", - ":util", ], ) From b0615b792150ac4a1c21f46a14b390343d01bf17 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Wed, 17 Jun 2020 11:56:56 -0700 Subject: [PATCH 3/3] Minor adjustments --- tfjs-backend-wasm/src/kernels/OneHot.ts | 4 ++-- tfjs-core/src/ops/one_hot.ts | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/OneHot.ts b/tfjs-backend-wasm/src/kernels/OneHot.ts index 380d1fb59a7..4462347d4de 100644 --- a/tfjs-backend-wasm/src/kernels/OneHot.ts +++ b/tfjs-backend-wasm/src/kernels/OneHot.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {KernelFunc, OneHotAttrs, OneHotInputs, registerKernel} from '@tensorflow/tfjs-core'; +import {KernelFunc, OneHot, OneHotAttrs, OneHotInputs, registerKernel} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; @@ -51,7 +51,7 @@ function oneHot( } registerKernel({ - kernelName: 'OneHot', + kernelName: OneHot, backendName: 'wasm', setupFunc: setup, kernelFunc: oneHot as {} as KernelFunc, diff --git a/tfjs-core/src/ops/one_hot.ts b/tfjs-core/src/ops/one_hot.ts index d0d4ecb8baa..61d0031dde8 100644 --- a/tfjs-core/src/ops/one_hot.ts +++ b/tfjs-core/src/ops/one_hot.ts @@ -18,7 +18,7 @@ import {ENGINE, ForwardFunc} from '../engine'; import {OneHot, OneHotAttrs, OneHotInputs} from '../kernel_names'; import {NamedAttrMap} from '../kernel_registry'; -import {Tensor, Tensor1D} from '../tensor'; +import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; @@ -50,22 +50,21 @@ function oneHot_( if (depth < 2) { throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`); } - let $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32'); + const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32'); const outShape = [...$indices.shape, depth]; - $indices = $indices.flatten(); const forward: ForwardFunc = (backend, save) => { save([$indices]); - return backend.oneHot($indices as Tensor1D, depth, onValue, offValue); + return reshape( + backend.oneHot($indices.flatten(), depth, onValue, offValue), outShape); }; const inputs: OneHotInputs = {indices: $indices}; const attrs: OneHotAttrs = {depth, onValue, offValue}; - const res = ENGINE.runKernelFunc( + return ENGINE.runKernelFunc( forward, inputs as unknown as NamedTensorMap, null /* grad */, OneHot, attrs as unknown as NamedAttrMap); - return reshape(res, outShape); } export const oneHot = op({oneHot_});