diff --git a/tfjs-node/src/index.ts b/tfjs-node/src/index.ts index 2ca6bbf87fd..6533cde569f 100644 --- a/tfjs-node/src/index.ts +++ b/tfjs-node/src/index.ts @@ -24,6 +24,9 @@ import {NodeJSKernelBackend} from './nodejs_kernel_backend'; import {TFJSBinding} from './tfjs_binding'; import * as nodeVersion from './version'; +// Import all kernels. +import './kernels/all_kernels'; + // tslint:disable-next-line:no-require-imports const binary = require('node-pre-gyp'); const bindingPath = diff --git a/tfjs-node/src/kernels/all_kernels.ts b/tfjs-node/src/kernels/all_kernels.ts new file mode 100644 index 00000000000..c6eea842964 --- /dev/null +++ b/tfjs-node/src/kernels/all_kernels.ts @@ -0,0 +1,21 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// We explicitly import the modular kernels so they get registered in the +// global registry when we compile the library. A modular build would replace +// the contents of this file and import only the kernels that are needed. +import './non_max_suppression_v5'; diff --git a/tfjs-node/src/kernels/non_max_suppression_v5.ts b/tfjs-node/src/kernels/non_max_suppression_v5.ts new file mode 100644 index 00000000000..6f6dbb46b89 --- /dev/null +++ b/tfjs-node/src/kernels/non_max_suppression_v5.ts @@ -0,0 +1,67 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, scalar, Tensor1D, Tensor2D, TensorInfo} from '@tensorflow/tfjs-core'; + +import {createTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; + +interface NonMaxSuppressionWithScoreInputs extends NamedTensorInfoMap { + boxes: TensorInfo; + scores: TensorInfo; +} + +interface NonMaxSuppressionWithScoreAttrs extends NamedAttrMap { + maxOutputSize: number; + iouThreshold: number; + scoreThreshold: number; + softNmsSigma: number; +} + +// TODO(nsthorat, dsmilkov): Remove dependency on tensors, use dataId. +registerKernel({ + kernelName: 'NonMaxSuppressionV5', + backendName: 'tensorflow', + kernelFunc: ({inputs, backend, attrs}) => { + const {boxes, scores} = inputs as NonMaxSuppressionWithScoreInputs; + const {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma} = + attrs as NonMaxSuppressionWithScoreAttrs; + const maxOutputSizeTensor = scalar(maxOutputSize, 'int32'); + const iouThresholdTensor = scalar(iouThreshold); + const scoreThresholdTensor = scalar(scoreThreshold); + const softNmsSigmaTensor = scalar(softNmsSigma); + const opAttrs = [createTypeOpAttr('T', boxes.dtype)]; + + const nodeBackend = backend as NodeJSKernelBackend; + + const [selectedIndices, selectedScores, validOutputs] = + nodeBackend.executeMultipleOutputs( + 'NonMaxSuppressionV5', opAttrs, + [ + boxes as Tensor2D, scores as Tensor1D, maxOutputSizeTensor, + iouThresholdTensor, scoreThresholdTensor, softNmsSigmaTensor + ], + 3); + + maxOutputSizeTensor.dispose(); + iouThresholdTensor.dispose(); + scoreThresholdTensor.dispose(); + softNmsSigmaTensor.dispose(); + validOutputs.dispose(); + + return [selectedIndices, selectedScores]; + } +});