diff --git a/tfjs-core/src/backends/backend.ts b/tfjs-core/src/backends/backend.ts index c3c212626f1..fd3dbbaab3d 100644 --- a/tfjs-core/src/backends/backend.ts +++ b/tfjs-core/src/backends/backend.ts @@ -590,8 +590,7 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { nonMaxSuppression( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold?: number, - softNmsSigma?: number): Tensor1D { + iouThreshold: number, scoreThreshold?: number): Tensor1D { return notYetImplemented('nonMaxSuppression'); } diff --git a/tfjs-core/src/backends/cpu/all_kernels.ts b/tfjs-core/src/backends/cpu/all_kernels.ts index 8d59f9ee9dd..75920b35b88 100644 --- a/tfjs-core/src/backends/cpu/all_kernels.ts +++ b/tfjs-core/src/backends/cpu/all_kernels.ts @@ -19,3 +19,4 @@ // global registry when we compile the library. A modular build would replace // the contents of this file and import only the kernels that are needed. import './square'; +import './non_max_suppression_v5'; diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index 5b04218c3ac..ee2a1e6772a 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -42,7 +42,7 @@ import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend'; import * as backend_util from '../backend_util'; import * as complex_util from '../complex_util'; -import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; +import {nonMaxSuppressionV3} from '../non_max_suppression_impl'; import {split} from '../split_shared'; import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; @@ -3270,15 +3270,13 @@ export class MathBackendCPU extends KernelBackend { nonMaxSuppression( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold: number, - softNmsSigma: number): Tensor1D { + iouThreshold: number, scoreThreshold: number): Tensor1D { assertNotComplex(boxes, 'nonMaxSuppression'); const boxesVals = this.readSync(boxes.dataId) as TypedArray; const scoresVals = this.readSync(scores.dataId) as TypedArray; - return nonMaxSuppressionImpl( - boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma); + return nonMaxSuppressionV3( + boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); } fft(x: Tensor2D): Tensor2D { diff --git a/tfjs-core/src/backends/cpu/non_max_suppression_v5.ts b/tfjs-core/src/backends/cpu/non_max_suppression_v5.ts new file mode 100644 index 00000000000..62fdb3781cf --- /dev/null +++ b/tfjs-core/src/backends/cpu/non_max_suppression_v5.ts @@ -0,0 +1,63 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '../../kernel_registry'; +import {TypedArray} from '../../types'; +import {nonMaxSuppressionV5} from '../non_max_suppression_impl'; + +import {MathBackendCPU} from './backend_cpu'; +import {assertNotComplex} from './cpu_util'; + +interface NonMaxSuppressionWithScoreInputs extends NamedTensorInfoMap { + boxes: TensorInfo; + scores: TensorInfo; +} + +interface NonMaxSuppressionWithScoreAttrs extends NamedAttrMap { + maxOutputSize: number; + iouThreshold: number; + scoreThreshold: number; + softNmsSigma: number; +} + +registerKernel({ + kernelName: 'NonMaxSuppressionV5', + backendName: 'cpu', + kernelFunc: ({inputs, backend, attrs}) => { + const {boxes, scores} = inputs as NonMaxSuppressionWithScoreInputs; + const {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma} = + attrs as NonMaxSuppressionWithScoreAttrs; + + const cpuBackend = backend as MathBackendCPU; + + assertNotComplex(boxes, 'NonMaxSuppressionWithScore'); + + const boxesVals = cpuBackend.data.get(boxes.dataId).values as TypedArray; + const scoresVals = cpuBackend.data.get(scores.dataId).values as TypedArray; + + const maxOutputSizeVal = maxOutputSize; + const iouThresholdVal = iouThreshold; + const scoreThresholdVal = scoreThreshold; + const softNmsSigmaVal = softNmsSigma; + + const {selectedIndices, selectedScores} = nonMaxSuppressionV5( + boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, + scoreThresholdVal, softNmsSigmaVal); + + return [selectedIndices, selectedScores]; + } +}); diff --git a/tfjs-core/src/backends/non_max_suppression_impl.ts b/tfjs-core/src/backends/non_max_suppression_impl.ts index ac6d5d02476..3f7bb8d8961 100644 --- a/tfjs-core/src/backends/non_max_suppression_impl.ts +++ b/tfjs-core/src/backends/non_max_suppression_impl.ts @@ -19,8 +19,9 @@ * Implementation of the NonMaxSuppression kernel shared between webgl and cpu. */ -import {tensor1d} from '../ops/tensor_ops'; +import {scalar, tensor1d} from '../ops/tensor_ops'; import {Tensor1D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {TypedArray} from '../types'; import {binaryInsert} from './array_util'; @@ -31,10 +32,41 @@ interface Candidate { suppressBeginIndex: number; } -export function nonMaxSuppressionImpl( +export function nonMaxSuppressionV3( + boxes: TypedArray, scores: TypedArray, maxOutputSize: number, + iouThreshold: number, scoreThreshold: number): Tensor1D { + const dummySoftNmsSigma = 0.0; + + return nonMaxSuppressionImpl_( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, + dummySoftNmsSigma) + .selectedIndices as Tensor1D; +} + +export function nonMaxSuppressionV5( boxes: TypedArray, scores: TypedArray, maxOutputSize: number, iouThreshold: number, scoreThreshold: number, - softNmsSigma: number): Tensor1D { + softNmsSigma: number): NamedTensorMap { + // For NonMaxSuppressionV5Op, we always return a second output holding + // corresponding scores. + const returnScoresTensor = true; + + const result = nonMaxSuppressionImpl_( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, + returnScoresTensor); + + result.numValidOutputs.dispose(); + + return { + selectedIndices: result.selectedIndices, + selectedScores: result.selectedScores + }; +} + +function nonMaxSuppressionImpl_( + boxes: TypedArray, scores: TypedArray, maxOutputSize: number, + iouThreshold: number, scoreThreshold: number, softNmsSigma: number, + returnScoresTensor = false, padToMaxOutputSize = false): NamedTensorMap { // The list is sorted in ascending order, so that we can always pop the // candidate with the largest score in O(1) time. const candidates = @@ -47,9 +79,10 @@ export function nonMaxSuppressionImpl( // before. const scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0; - const selected: number[] = []; + const selectedIndices: number[] = []; + const selectedScores: number[] = []; - while (selected.length < maxOutputSize && candidates.length > 0) { + while (selectedIndices.length < maxOutputSize && candidates.length > 0) { const candidate = candidates.pop(); const {score: originalScore, boxIndex, suppressBeginIndex} = candidate; @@ -59,13 +92,13 @@ export function nonMaxSuppressionImpl( // Overlapping boxes are likely to have similar scores, therefore we // iterate through the previously selected boxes backwards in order to - // see if candidate's score should be suppressed. We use suppressBeginIndex - // to track and ensure a candidate can be suppressed by a selected box no - // more than once. Also, if the overlap exceeds iouThreshold, we simply - // ignore the candidate. + // see if candidate's score should be suppressed. We use + // suppressBeginIndex to track and ensure a candidate can be suppressed + // by a selected box no more than once. Also, if the overlap exceeds + // iouThreshold, we simply ignore the candidate. let ignoreCandidate = false; - for (let j = selected.length - 1; j >= suppressBeginIndex; --j) { - const iou = intersectionOverUnion(boxes, boxIndex, selected[j]); + for (let j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) { + const iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]); if (iou >= iouThreshold) { ignoreCandidate = true; @@ -81,19 +114,20 @@ export function nonMaxSuppressionImpl( } // At this point, if `candidate.score` has not dropped below - // `scoreThreshold`, then we know that we went through all of the previous - // selections and can safely update `suppressBeginIndex` to the end of the - // selected array. Then we can re-insert the candidate with the updated - // score and suppressBeginIndex back in the candidate list. If on the other - // hand, `candidate.score` has dropped below the score threshold, we will - // not add it back to the candidates list. - candidate.suppressBeginIndex = selected.length; + // `scoreThreshold`, then we know that we went through all of the + // previous selections and can safely update `suppressBeginIndex` to the + // end of the selected array. Then we can re-insert the candidate with + // the updated score and suppressBeginIndex back in the candidate list. + // If on the other hand, `candidate.score` has dropped below the score + // threshold, we will not add it back to the candidates list. + candidate.suppressBeginIndex = selectedIndices.length; if (!ignoreCandidate) { - // Candidate has passed all the tests, and is not suppressed, so select - // the candidate. + // Candidate has passed all the tests, and is not suppressed, so + // select the candidate. if (candidate.score === originalScore) { - selected.push(boxIndex); + selectedIndices.push(boxIndex); + selectedScores.push(candidate.score); } else if (candidate.score > scoreThreshold) { // Candidate's score is suppressed but is still high enough to be // considered, so add back to the candidates list. @@ -102,7 +136,18 @@ export function nonMaxSuppressionImpl( } } - return tensor1d(selected, 'int32'); + // NonMaxSuppressionV4 feature: padding output to maxOutputSize. + const numValidOutputs = selectedIndices.length; + if (padToMaxOutputSize) { + selectedIndices.fill(0, numValidOutputs); + selectedScores.fill(0.0, numValidOutputs); + } + + return { + selectedIndices: tensor1d(selectedIndices, 'int32'), + selectedScores: tensor1d(selectedScores, 'float32'), + numValidOutputs: scalar(numValidOutputs, 'int32') + }; } function intersectionOverUnion(boxes: TypedArray, i: number, j: number) { diff --git a/tfjs-core/src/backends/webgl/all_kernels.ts b/tfjs-core/src/backends/webgl/all_kernels.ts index 973dacddad2..724e83ce034 100644 --- a/tfjs-core/src/backends/webgl/all_kernels.ts +++ b/tfjs-core/src/backends/webgl/all_kernels.ts @@ -20,3 +20,4 @@ // the contents of this file and import only the kernels that are needed. import './square'; import './fromPixels'; +import './non_max_suppression_v5'; diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index 99b38ae9dd8..2bce838b3aa 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -45,7 +45,7 @@ import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} fr import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend'; import * as backend_util from '../backend_util'; import {mergeRealAndImagArrays} from '../complex_util'; -import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; +import {nonMaxSuppressionV3} from '../non_max_suppression_impl'; import {split} from '../split_shared'; import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; @@ -2244,16 +2244,14 @@ export class MathBackendWebGL extends KernelBackend { nonMaxSuppression( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold: number, - softNmsSigma: number): Tensor1D { + iouThreshold: number, scoreThreshold: number): Tensor1D { warn( 'tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead'); const boxesVals = boxes.dataSync(); const scoresVals = scores.dataSync(); - return nonMaxSuppressionImpl( - boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma); + return nonMaxSuppressionV3( + boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); } cropAndResize( diff --git a/tfjs-core/src/backends/webgl/non_max_suppression_v5.ts b/tfjs-core/src/backends/webgl/non_max_suppression_v5.ts new file mode 100644 index 00000000000..4a8c0ac4156 --- /dev/null +++ b/tfjs-core/src/backends/webgl/non_max_suppression_v5.ts @@ -0,0 +1,65 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '../../kernel_registry'; +import {warn} from '../../log'; +import {TypedArray} from '../../types'; +import {nonMaxSuppressionV5} from '../non_max_suppression_impl'; + +import {MathBackendWebGL} from './backend_webgl'; + +interface NonMaxSuppressionWithScoreInputs extends NamedTensorInfoMap { + boxes: TensorInfo; + scores: TensorInfo; +} + +interface NonMaxSuppressionWithScoreAttrs extends NamedAttrMap { + maxOutputSize: number; + iouThreshold: number; + scoreThreshold: number; + softNmsSigma: number; +} + +registerKernel({ + kernelName: 'NonMaxSuppressionV5', + backendName: 'webgl', + kernelFunc: ({inputs, backend, attrs}) => { + warn( + 'tf.nonMaxSuppression() in webgl locks the UI thread. ' + + 'Call tf.nonMaxSuppressionAsync() instead'); + + const {boxes, scores} = inputs as NonMaxSuppressionWithScoreInputs; + const {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma} = + attrs as NonMaxSuppressionWithScoreAttrs; + + const gpuBackend = backend as MathBackendWebGL; + + const boxesVals = gpuBackend.readSync(boxes.dataId) as TypedArray; + const scoresVals = gpuBackend.readSync(scores.dataId) as TypedArray; + + const maxOutputSizeVal = maxOutputSize; + const iouThresholdVal = iouThreshold; + const scoreThresholdVal = scoreThreshold; + const softNmsSigmaVal = softNmsSigma; + + const {selectedIndices, selectedScores} = nonMaxSuppressionV5( + boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, + scoreThresholdVal, softNmsSigmaVal); + + return [selectedIndices, selectedScores]; + } +}); diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index 5cebd9ced2a..195569ee54b 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -15,9 +15,10 @@ * ============================================================================= */ -import {nonMaxSuppressionImpl} from '../backends/non_max_suppression_impl'; +import {nonMaxSuppressionV3, nonMaxSuppressionV5} from '../backends/non_max_suppression_impl'; import {ENGINE, ForwardFunc} from '../engine'; import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; @@ -147,12 +148,6 @@ function resizeNearestNeighbor_( * Performs non maximum suppression of bounding boxes based on * iou (intersection over union). * - * This op also supports a Soft-NMS mode (c.f. - * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score - * of other overlapping boxes, therefore favoring different regions of the image - * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma` - * parameter to be larger than 0. - * * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of * the bounding box. @@ -163,20 +158,18 @@ function resizeNearestNeighbor_( * Defaults to 0.5 (50% box overlap). * @param scoreThreshold A threshold for deciding when to remove boxes based * on score. Defaults to -inf, which means any score is accepted. - * @param softNmsSigma A float representing the sigma parameter for Soft NMS. * @return A 1D tensor with the selected box indices. */ /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ function nonMaxSuppression_( boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, maxOutputSize: number, iouThreshold = 0.5, - scoreThreshold = Number.NEGATIVE_INFINITY, softNmsSigma = 0): Tensor1D { + scoreThreshold = Number.NEGATIVE_INFINITY): Tensor1D { const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression'); const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression'); const inputs = nonMaxSuppSanityCheck( - $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma); + $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold); maxOutputSize = inputs.maxOutputSize; iouThreshold = inputs.iouThreshold; scoreThreshold = inputs.scoreThreshold; @@ -184,18 +177,99 @@ function nonMaxSuppression_( const attrs = {maxOutputSize, iouThreshold, scoreThreshold}; return ENGINE.runKernelFunc( b => b.nonMaxSuppression( - $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma), + $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold), {boxes: $boxes, scores: $scores}, null /* grad */, 'NonMaxSuppressionV3', attrs); } /** This is the async version of `nonMaxSuppression` */ async function nonMaxSuppressionAsync_( + boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, + maxOutputSize: number, iouThreshold = 0.5, + scoreThreshold = Number.NEGATIVE_INFINITY): Promise { + const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync'); + const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync'); + + const inputs = nonMaxSuppSanityCheck( + $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold); + maxOutputSize = inputs.maxOutputSize; + iouThreshold = inputs.iouThreshold; + scoreThreshold = inputs.scoreThreshold; + + const boxesAndScores = await Promise.all([$boxes.data(), $scores.data()]); + const boxesVals = boxesAndScores[0]; + const scoresVals = boxesAndScores[1]; + + const res = nonMaxSuppressionV3( + boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); + if ($boxes !== boxes) { + $boxes.dispose(); + } + if ($scores !== scores) { + $scores.dispose(); + } + return res; +} + +/** + * Performs non maximum suppression of bounding boxes based on + * iou (intersection over union). + * + * This op also supports a Soft-NMS mode (c.f. + * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score + * of other overlapping boxes, therefore favoring different regions of the image + * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma` + * parameter to be larger than 0. + * + * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is + * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of + * the bounding box. + * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`. + * @param maxOutputSize The maximum number of boxes to be selected. + * @param iouThreshold A float representing the threshold for deciding whether + * boxes overlap too much with respect to IOU. Must be between [0, 1]. + * Defaults to 0.5 (50% box overlap). + * @param scoreThreshold A threshold for deciding when to remove boxes based + * on score. Defaults to -inf, which means any score is accepted. + * @param softNmsSigma A float representing the sigma parameter for Soft NMS. + * When sigma is 0, it falls back to nonMaxSuppression. + * @return A map with the following properties: + * - selectedIndices: A 1D tensor with the selected box indices. + * - selectedScores: A 1D tensor with the corresponding scores for each + * selected box. + */ +/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ +function nonMaxSuppressionWithScore_( boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, maxOutputSize: number, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, - softNmsSigma = 0.0): Promise { + softNmsSigma = 0.0): NamedTensorMap { + const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression'); + const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression'); + + const inputs = nonMaxSuppSanityCheck( + $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); + maxOutputSize = inputs.maxOutputSize; + iouThreshold = inputs.iouThreshold; + scoreThreshold = inputs.scoreThreshold; + softNmsSigma = inputs.softNmsSigma; + + const attrs = {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma}; + + const result = ENGINE.runKernel( + 'NonMaxSuppressionV5', {boxes: $boxes, scores: $scores}, + attrs) as Tensor[]; + + return {selectedIndices: result[0], selectedScores: result[1]}; +} + +/** This is the async version of `nonMaxSuppressionWithScore` */ +async function nonMaxSuppressionWithScoreAsync_( + boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, + maxOutputSize: number, iouThreshold = 0.5, + scoreThreshold = Number.NEGATIVE_INFINITY, + softNmsSigma = 0.0): Promise { const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync'); const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync'); @@ -211,9 +285,10 @@ async function nonMaxSuppressionAsync_( const boxesVals = boxesAndScores[0]; const scoresVals = boxesAndScores[1]; - const res = nonMaxSuppressionImpl( + const res = nonMaxSuppressionV5( boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma); + if ($boxes !== boxes) { $boxes.dispose(); } @@ -225,7 +300,7 @@ async function nonMaxSuppressionAsync_( function nonMaxSuppSanityCheck( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold: number, softNmsSigma: number): { + iouThreshold: number, scoreThreshold: number, softNmsSigma?: number): { maxOutputSize: number, iouThreshold: number, scoreThreshold: number, @@ -238,7 +313,7 @@ function nonMaxSuppSanityCheck( scoreThreshold = Number.NEGATIVE_INFINITY; } if (softNmsSigma == null) { - softNmsSigma = 0; + softNmsSigma = 0.0; } const numBoxes = boxes.shape[0]; @@ -340,4 +415,6 @@ export const resizeBilinear = op({resizeBilinear_}); export const resizeNearestNeighbor = op({resizeNearestNeighbor_}); export const nonMaxSuppression = op({nonMaxSuppression_}); export const nonMaxSuppressionAsync = nonMaxSuppressionAsync_; +export const nonMaxSuppressionWithScore = op({nonMaxSuppressionWithScore_}); +export const nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_; export const cropAndResize = op({cropAndResize_}); diff --git a/tfjs-core/src/ops/image_ops_test.ts b/tfjs-core/src/ops/image_ops_test.ts index ae5d4cc6354..0e40a73124c 100644 --- a/tfjs-core/src/ops/image_ops_test.ts +++ b/tfjs-core/src/ops/image_ops_test.ts @@ -19,227 +19,251 @@ import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose, expectArraysEqual} from '../test_util'; describeWithFlags('nonMaxSuppression', ALL_ENVS, () => { - it('select from three clusters', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([3]); - expectArraysEqual(await indices.data(), [3, 0, 5]); - }); - - it('select from three clusters flipped coordinates', async () => { - const boxes = tf.tensor2d( - [ - 1, 1, 0, 0, 0, 0.1, 1, 1.1, 0, .9, 1, -0.1, - 0, 10, 1, 11, 1, 10.1, 0, 11.1, 1, 101, 0, 100 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([3]); - expectArraysEqual(await indices.data(), [3, 0, 5]); - }); - - it('select at most two boxes from three clusters', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 2; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([2]); - expectArraysEqual(await indices.data(), [3, 0]); - }); - - it('select at most thirty boxes from three clusters', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 30; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([3]); - expectArraysEqual(await indices.data(), [3, 0, 5]); - }); - - it('select single box', async () => { - const boxes = tf.tensor2d([0, 0, 1, 1], [1, 4]); - const scores = tf.tensor1d([0.9]); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([1]); - expectArraysEqual(await indices.data(), [0]); - }); - - it('select from ten identical boxes', async () => { - const numBoxes = 10; - const corners = new Array(numBoxes) - .fill(0) - .map(_ => [0, 0, 1, 1]) - .reduce((arr, curr) => arr.concat(curr)); - const boxes = tf.tensor2d(corners, [numBoxes, 4]); - const scores = tf.tensor1d(Array(numBoxes).fill(0.9)); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([1]); - expectArraysEqual(await indices.data(), [0]); - }); - - it('inconsistent box and score shapes', () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5]); - const maxOutputSize = 30; - const iouThreshold = 0.5; - const scoreThreshold = 0; - expect( - () => tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)) - .toThrowError(/scores has incompatible shape with boxes/); - }); - - it('invalid iou threshold', () => { - const boxes = tf.tensor2d([0, 0, 1, 1], [1, 4]); - const scores = tf.tensor1d([0.9]); - const maxOutputSize = 3; - const iouThreshold = 1.2; - const scoreThreshold = 0; - expect( - () => tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)) - .toThrowError(/iouThreshold must be in \[0, 1\]/); - }); - - it('empty input', async () => { - const boxes = tf.tensor2d([], [0, 4]); - const scores = tf.tensor1d([]); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([0]); - expectArraysEqual(await indices.data(), []); - }); - - it('accepts a tensor-like object', async () => { - const boxes = [[0, 0, 1, 1], [0, 1, 1, 2]]; - const scores = [1, 2]; - const indices = tf.image.nonMaxSuppression(boxes, scores, 10); - expect(indices.shape).toEqual([2]); - expect(indices.dtype).toEqual('int32'); - expectArraysEqual(await indices.data(), [1, 0]); - }); - - it('select from three clusters with SoftNMS', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 6; - const iouThreshold = 1.0; - const scoreThreshold = 0; - const softNmsSigma = 0.5; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma); - - expect(indices.shape).toEqual([6]); - expectArraysEqual(await indices.data(), [3, 0, 1, 5, 4, 2]); + describe('NonMaxSuppression Basic', () => { + it('select from three clusters', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([3]); + expectArraysEqual(await indices.data(), [3, 0, 5]); + }); + + it('select from three clusters flipped coordinates', async () => { + const boxes = tf.tensor2d( + [ + 1, 1, 0, 0, 0, 0.1, 1, 1.1, 0, .9, 1, -0.1, + 0, 10, 1, 11, 1, 10.1, 0, 11.1, 1, 101, 0, 100 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([3]); + expectArraysEqual(await indices.data(), [3, 0, 5]); + }); + + it('select at most two boxes from three clusters', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 2; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([2]); + expectArraysEqual(await indices.data(), [3, 0]); + }); + + it('select at most thirty boxes from three clusters', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 30; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([3]); + expectArraysEqual(await indices.data(), [3, 0, 5]); + }); + + it('select single box', async () => { + const boxes = tf.tensor2d([0, 0, 1, 1], [1, 4]); + const scores = tf.tensor1d([0.9]); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([1]); + expectArraysEqual(await indices.data(), [0]); + }); + + it('select from ten identical boxes', async () => { + const numBoxes = 10; + const corners = new Array(numBoxes) + .fill(0) + .map(_ => [0, 0, 1, 1]) + .reduce((arr, curr) => arr.concat(curr)); + const boxes = tf.tensor2d(corners, [numBoxes, 4]); + const scores = tf.tensor1d(Array(numBoxes).fill(0.9)); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([1]); + expectArraysEqual(await indices.data(), [0]); + }); + + it('inconsistent box and score shapes', () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5]); + const maxOutputSize = 30; + const iouThreshold = 0.5; + const scoreThreshold = 0; + expect( + () => tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)) + .toThrowError(/scores has incompatible shape with boxes/); + }); + + it('invalid iou threshold', () => { + const boxes = tf.tensor2d([0, 0, 1, 1], [1, 4]); + const scores = tf.tensor1d([0.9]); + const maxOutputSize = 3; + const iouThreshold = 1.2; + const scoreThreshold = 0; + expect( + () => tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)) + .toThrowError(/iouThreshold must be in \[0, 1\]/); + }); + + it('empty input', async () => { + const boxes = tf.tensor2d([], [0, 4]); + const scores = tf.tensor1d([]); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([0]); + expectArraysEqual(await indices.data(), []); + }); + + it('accepts a tensor-like object', async () => { + const boxes = [[0, 0, 1, 1], [0, 1, 1, 2]]; + const scores = [1, 2]; + const indices = tf.image.nonMaxSuppression(boxes, scores, 10); + expect(indices.shape).toEqual([2]); + expect(indices.dtype).toEqual('int32'); + expectArraysEqual(await indices.data(), [1, 0]); + }); + }); + + describe('NonMaxSuppressionWithScore', () => { + it('select from three clusters with SoftNMS', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 6; + const iouThreshold = 1.0; + const scoreThreshold = 0; + const softNmsSigma = 0.5; + + const {selectedIndices, selectedScores} = + tf.image.nonMaxSuppressionWithScore( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); + + expectArraysEqual(await selectedIndices.data(), [3, 0, 1, 5, 4, 2]); + + expectArraysClose( + await selectedScores.data(), [0.95, 0.9, 0.384, 0.3, 0.256, 0.197]); + }); }); }); describeWithFlags('nonMaxSuppressionAsync', ALL_ENVS, () => { - it('select from three clusters', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = await tf.image.nonMaxSuppressionAsync( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([3]); - expectArraysEqual(await indices.data(), [3, 0, 5]); - }); - - it('accepts a tensor-like object', async () => { - const boxes = [[0, 0, 1, 1], [0, 1, 1, 2]]; - const scores = [1, 2]; - const indices = await tf.image.nonMaxSuppressionAsync(boxes, scores, 10); - expect(indices.shape).toEqual([2]); - expect(indices.dtype).toEqual('int32'); - expectArraysEqual(await indices.data(), [1, 0]); - }); - - it('select from three clusters with SoftNMS', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 6; - const iouThreshold = 1.0; - const scoreThreshold = 0; - const softNmsSigma = 0.5; - const indices = await tf.image.nonMaxSuppressionAsync( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma); - - expect(indices.shape).toEqual([6]); - expectArraysEqual(await indices.data(), [3, 0, 1, 5, 4, 2]); + describe('NonMaxSuppressionAsync basic', () => { + it('select from three clusters', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = await tf.image.nonMaxSuppressionAsync( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([3]); + expectArraysEqual(await indices.data(), [3, 0, 5]); + }); + + it('accepts a tensor-like object', async () => { + const boxes = [[0, 0, 1, 1], [0, 1, 1, 2]]; + const scores = [1, 2]; + const indices = await tf.image.nonMaxSuppressionAsync(boxes, scores, 10); + expect(indices.shape).toEqual([2]); + expect(indices.dtype).toEqual('int32'); + expectArraysEqual(await indices.data(), [1, 0]); + }); + }); + + describe('NonMaxSuppressionWithScoreAsync', () => { + it('select from three clusters with SoftNMS', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 6; + const iouThreshold = 1.0; + const scoreThreshold = 0; + const softNmsSigma = 0.5; + + const numTensorsBefore = tf.memory().numTensors; + + const {selectedIndices, selectedScores} = + await tf.image.nonMaxSuppressionWithScoreAsync( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); + + const numTensorsAfter = tf.memory().numTensors; + + expectArraysEqual(await selectedIndices.data(), [3, 0, 1, 5, 4, 2]); + + expectArraysClose( + await selectedScores.data(), [0.95, 0.9, 0.384, 0.3, 0.256, 0.197]); + + // The number of tensors should increase by the number of tensors + // returned (i.e. selectedIndices and selectedScores). + expect(numTensorsAfter).toEqual(numTensorsBefore + 2); + }); }); });