From b29187b3c1dd95bb7eb6cd9cd35d85f73583086f Mon Sep 17 00:00:00 2001 From: Na Li Date: Fri, 13 Dec 2019 17:17:40 -0800 Subject: [PATCH 1/5] Split NonMaxSuppressionV3 and NonMaxSuppresionV5 to two APIs --- tfjs-core/src/backends/backend.ts | 8 +- tfjs-core/src/backends/cpu/backend_cpu.ts | 19 +- .../src/backends/non_max_suppression_impl.ts | 70 ++- tfjs-core/src/backends/webgl/backend_webgl.ts | 18 +- tfjs-core/src/ops/image_ops.ts | 99 +++- tfjs-core/src/ops/image_ops_test.ts | 456 +++++++++--------- 6 files changed, 416 insertions(+), 254 deletions(-) diff --git a/tfjs-core/src/backends/backend.ts b/tfjs-core/src/backends/backend.ts index c3c212626f1..4c11114780f 100644 --- a/tfjs-core/src/backends/backend.ts +++ b/tfjs-core/src/backends/backend.ts @@ -589,9 +589,15 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { } nonMaxSuppression( + boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, + iouThreshold: number, scoreThreshold?: number): Tensor1D { + return notYetImplemented('nonMaxSuppression'); + } + + nonMaxSuppressionWithScore( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, iouThreshold: number, scoreThreshold?: number, - softNmsSigma?: number): Tensor1D { + softNmsSigma?: number): [Tensor1D, Tensor1D, Scalar] { return notYetImplemented('nonMaxSuppression'); } diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index 5b04218c3ac..ec9af4eb7f0 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -42,7 +42,7 @@ import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend'; import * as backend_util from '../backend_util'; import * as complex_util from '../complex_util'; -import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; +import {nonMaxSuppressionV3, nonMaxSuppressionV5} from '../non_max_suppression_impl'; import {split} from '../split_shared'; import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; @@ -3270,13 +3270,24 @@ export class MathBackendCPU extends KernelBackend { nonMaxSuppression( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold: number, - softNmsSigma: number): Tensor1D { + iouThreshold: number, scoreThreshold: number): Tensor1D { assertNotComplex(boxes, 'nonMaxSuppression'); const boxesVals = this.readSync(boxes.dataId) as TypedArray; const scoresVals = this.readSync(scores.dataId) as TypedArray; - return nonMaxSuppressionImpl( + return nonMaxSuppressionV3( + boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); + } + + nonMaxSuppressionWithScore( + boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, + iouThreshold: number, scoreThreshold: number, + softNmsSigma: number): [Tensor1D, Tensor1D, Scalar] { + assertNotComplex(boxes, 'nonMaxSuppressionWithScore'); + + const boxesVals = this.readSync(boxes.dataId) as TypedArray; + const scoresVals = this.readSync(scores.dataId) as TypedArray; + return nonMaxSuppressionV5( boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma); } diff --git a/tfjs-core/src/backends/non_max_suppression_impl.ts b/tfjs-core/src/backends/non_max_suppression_impl.ts index ac6d5d02476..5fbc2b05bf7 100644 --- a/tfjs-core/src/backends/non_max_suppression_impl.ts +++ b/tfjs-core/src/backends/non_max_suppression_impl.ts @@ -19,8 +19,8 @@ * Implementation of the NonMaxSuppression kernel shared between webgl and cpu. */ -import {tensor1d} from '../ops/tensor_ops'; -import {Tensor1D} from '../tensor'; +import {scalar, tensor1d} from '../ops/tensor_ops'; +import {Scalar, Tensor1D} from '../tensor'; import {TypedArray} from '../types'; import {binaryInsert} from './array_util'; @@ -31,10 +31,34 @@ interface Candidate { suppressBeginIndex: number; } -export function nonMaxSuppressionImpl( +export function nonMaxSuppressionV3( + boxes: TypedArray, scores: TypedArray, maxOutputSize: number, + iouThreshold: number, scoreThreshold: number): Tensor1D { + const dummySoftNmsSigma = 0.0; + + return nonMaxSuppressionImpl_( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, + dummySoftNmsSigma)[0]; +} + +export function nonMaxSuppressionV5( boxes: TypedArray, scores: TypedArray, maxOutputSize: number, iouThreshold: number, scoreThreshold: number, - softNmsSigma: number): Tensor1D { + softNmsSigma: number): [Tensor1D, Tensor1D, Scalar] { + // For NonMaxSuppressionV5Op, we always return a second output holding + // corresponding scores. + const returnScoresTensor = true; + + return nonMaxSuppressionImpl_( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, + returnScoresTensor); +} + +function nonMaxSuppressionImpl_( + boxes: TypedArray, scores: TypedArray, maxOutputSize: number, + iouThreshold: number, scoreThreshold: number, softNmsSigma: number, + returnScoresTensor = false, + padToMaxOutputSize = false): [Tensor1D, Tensor1D, Scalar] { // The list is sorted in ascending order, so that we can always pop the // candidate with the largest score in O(1) time. const candidates = @@ -48,6 +72,7 @@ export function nonMaxSuppressionImpl( const scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0; const selected: number[] = []; + const selectedScore: number[] = []; while (selected.length < maxOutputSize && candidates.length > 0) { const candidate = candidates.pop(); @@ -59,10 +84,10 @@ export function nonMaxSuppressionImpl( // Overlapping boxes are likely to have similar scores, therefore we // iterate through the previously selected boxes backwards in order to - // see if candidate's score should be suppressed. We use suppressBeginIndex - // to track and ensure a candidate can be suppressed by a selected box no - // more than once. Also, if the overlap exceeds iouThreshold, we simply - // ignore the candidate. + // see if candidate's score should be suppressed. We use + // suppressBeginIndex to track and ensure a candidate can be suppressed + // by a selected box no more than once. Also, if the overlap exceeds + // iouThreshold, we simply ignore the candidate. let ignoreCandidate = false; for (let j = selected.length - 1; j >= suppressBeginIndex; --j) { const iou = intersectionOverUnion(boxes, boxIndex, selected[j]); @@ -81,19 +106,20 @@ export function nonMaxSuppressionImpl( } // At this point, if `candidate.score` has not dropped below - // `scoreThreshold`, then we know that we went through all of the previous - // selections and can safely update `suppressBeginIndex` to the end of the - // selected array. Then we can re-insert the candidate with the updated - // score and suppressBeginIndex back in the candidate list. If on the other - // hand, `candidate.score` has dropped below the score threshold, we will - // not add it back to the candidates list. + // `scoreThreshold`, then we know that we went through all of the + // previous selections and can safely update `suppressBeginIndex` to the + // end of the selected array. Then we can re-insert the candidate with + // the updated score and suppressBeginIndex back in the candidate list. + // If on the other hand, `candidate.score` has dropped below the score + // threshold, we will not add it back to the candidates list. candidate.suppressBeginIndex = selected.length; if (!ignoreCandidate) { - // Candidate has passed all the tests, and is not suppressed, so select - // the candidate. + // Candidate has passed all the tests, and is not suppressed, so + // select the candidate. if (candidate.score === originalScore) { selected.push(boxIndex); + selectedScore.push(candidate.score); } else if (candidate.score > scoreThreshold) { // Candidate's score is suppressed but is still high enough to be // considered, so add back to the candidates list. @@ -102,7 +128,17 @@ export function nonMaxSuppressionImpl( } } - return tensor1d(selected, 'int32'); + // NonMaxSuppressionV4 feature: padding output to maxOutputSize. + const numValidOutputs = selected.length; + if (padToMaxOutputSize) { + selected.fill(0, numValidOutputs); + selectedScore.fill(0.0, numValidOutputs); + } + + return [ + tensor1d(selected, 'int32'), tensor1d(selectedScore, 'float32'), + scalar(numValidOutputs, 'int32') + ]; } function intersectionOverUnion(boxes: TypedArray, i: number, j: number) { diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index 99b38ae9dd8..f0fe05c5f46 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -45,7 +45,7 @@ import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} fr import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend'; import * as backend_util from '../backend_util'; import {mergeRealAndImagArrays} from '../complex_util'; -import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; +import {nonMaxSuppressionV3, nonMaxSuppressionV5} from '../non_max_suppression_impl'; import {split} from '../split_shared'; import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; @@ -2243,15 +2243,27 @@ export class MathBackendWebGL extends KernelBackend { } nonMaxSuppression( + boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, + iouThreshold: number, scoreThreshold: number): Tensor1D { + warn( + 'tf.nonMaxSuppression() in webgl locks the UI thread. ' + + 'Call tf.nonMaxSuppressionAsync() instead'); + const boxesVals = boxes.dataSync(); + const scoresVals = scores.dataSync(); + return nonMaxSuppressionV3( + boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); + } + + nonMaxSuppressionWithScore( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, iouThreshold: number, scoreThreshold: number, - softNmsSigma: number): Tensor1D { + softNmsSigma: number): [Tensor1D, Tensor1D, Scalar] { warn( 'tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead'); const boxesVals = boxes.dataSync(); const scoresVals = scores.dataSync(); - return nonMaxSuppressionImpl( + return nonMaxSuppressionV5( boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma); } diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index 5cebd9ced2a..eea8d6fdd9d 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {nonMaxSuppressionImpl} from '../backends/non_max_suppression_impl'; +import {nonMaxSuppressionV3, nonMaxSuppressionV5} from '../backends/non_max_suppression_impl'; import {ENGINE, ForwardFunc} from '../engine'; import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; @@ -143,6 +143,73 @@ function resizeNearestNeighbor_( return res as T; } +/** + * Performs non maximum suppression of bounding boxes based on + * iou (intersection over union). + * + * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is + * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of + * the bounding box. + * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`. + * @param maxOutputSize The maximum number of boxes to be selected. + * @param iouThreshold A float representing the threshold for deciding whether + * boxes overlap too much with respect to IOU. Must be between [0, 1]. + * Defaults to 0.5 (50% box overlap). + * @param scoreThreshold A threshold for deciding when to remove boxes based + * on score. Defaults to -inf, which means any score is accepted. + * @return A 1D tensor with the selected box indices. + */ +/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ +function nonMaxSuppression_( + boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, + maxOutputSize: number, iouThreshold = 0.5, + scoreThreshold = Number.NEGATIVE_INFINITY): Tensor1D { + const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression'); + const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression'); + + const inputs = nonMaxSuppSanityCheck( + $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold); + maxOutputSize = inputs.maxOutputSize; + iouThreshold = inputs.iouThreshold; + scoreThreshold = inputs.scoreThreshold; + + const attrs = {maxOutputSize, iouThreshold, scoreThreshold}; + return ENGINE.runKernelFunc( + b => b.nonMaxSuppression( + $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold), + {boxes: $boxes, scores: $scores}, null /* grad */, 'NonMaxSuppressionV3', + attrs); +} + +/** This is the async version of `nonMaxSuppression` */ +async function nonMaxSuppressionAsync_( + boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, + maxOutputSize: number, iouThreshold = 0.5, + scoreThreshold = Number.NEGATIVE_INFINITY): Promise { + const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync'); + const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync'); + + const inputs = nonMaxSuppSanityCheck( + $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold); + maxOutputSize = inputs.maxOutputSize; + iouThreshold = inputs.iouThreshold; + scoreThreshold = inputs.scoreThreshold; + + const boxesAndScores = await Promise.all([$boxes.data(), $scores.data()]); + const boxesVals = boxesAndScores[0]; + const scoresVals = boxesAndScores[1]; + + const res = nonMaxSuppressionV3( + boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); + if ($boxes !== boxes) { + $boxes.dispose(); + } + if ($scores !== scores) { + $scores.dispose(); + } + return res; +} + /** * Performs non maximum suppression of bounding boxes based on * iou (intersection over union). @@ -164,13 +231,19 @@ function resizeNearestNeighbor_( * @param scoreThreshold A threshold for deciding when to remove boxes based * on score. Defaults to -inf, which means any score is accepted. * @param softNmsSigma A float representing the sigma parameter for Soft NMS. - * @return A 1D tensor with the selected box indices. + * When sigma is 0, it falls back to nonMaxSuppression. + * @return An `Array` of + * - A 1D tensor with the selected box indices. + * - A 1D tensor with the corresponding scores for each selected box. + * - A number representing the number of valid elements in the selected + * box indices. */ /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ -function nonMaxSuppression_( +function nonMaxSuppressionWithScore_( boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, maxOutputSize: number, iouThreshold = 0.5, - scoreThreshold = Number.NEGATIVE_INFINITY, softNmsSigma = 0): Tensor1D { + scoreThreshold = Number.NEGATIVE_INFINITY, + softNmsSigma = 0.0): [Tensor1D, Tensor1D, Tensor] { const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression'); const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression'); @@ -183,19 +256,19 @@ function nonMaxSuppression_( const attrs = {maxOutputSize, iouThreshold, scoreThreshold}; return ENGINE.runKernelFunc( - b => b.nonMaxSuppression( + b => b.nonMaxSuppressionWithScore( $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma), - {boxes: $boxes, scores: $scores}, null /* grad */, 'NonMaxSuppressionV3', + {boxes: $boxes, scores: $scores}, null /* grad */, 'NonMaxSuppressionV5', attrs); } -/** This is the async version of `nonMaxSuppression` */ -async function nonMaxSuppressionAsync_( +/** This is the async version of `nonMaxSuppressionWithScore` */ +async function nonMaxSuppressionWithScoreAsync_( boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, maxOutputSize: number, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, - softNmsSigma = 0.0): Promise { + softNmsSigma = 0.0): Promise<[Tensor1D, Tensor1D, Tensor]> { const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync'); const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync'); @@ -211,7 +284,7 @@ async function nonMaxSuppressionAsync_( const boxesVals = boxesAndScores[0]; const scoresVals = boxesAndScores[1]; - const res = nonMaxSuppressionImpl( + const res = nonMaxSuppressionV5( boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma); if ($boxes !== boxes) { @@ -225,7 +298,7 @@ async function nonMaxSuppressionAsync_( function nonMaxSuppSanityCheck( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold: number, softNmsSigma: number): { + iouThreshold: number, scoreThreshold: number, softNmsSigma?: number): { maxOutputSize: number, iouThreshold: number, scoreThreshold: number, @@ -238,7 +311,7 @@ function nonMaxSuppSanityCheck( scoreThreshold = Number.NEGATIVE_INFINITY; } if (softNmsSigma == null) { - softNmsSigma = 0; + softNmsSigma = 0.0; } const numBoxes = boxes.shape[0]; @@ -340,4 +413,6 @@ export const resizeBilinear = op({resizeBilinear_}); export const resizeNearestNeighbor = op({resizeNearestNeighbor_}); export const nonMaxSuppression = op({nonMaxSuppression_}); export const nonMaxSuppressionAsync = nonMaxSuppressionAsync_; +export const nonMaxSuppressionWithScore = op({nonMaxSuppressionWithScore_}); +export const nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_; export const cropAndResize = op({cropAndResize_}); diff --git a/tfjs-core/src/ops/image_ops_test.ts b/tfjs-core/src/ops/image_ops_test.ts index ae5d4cc6354..0e1908d865d 100644 --- a/tfjs-core/src/ops/image_ops_test.ts +++ b/tfjs-core/src/ops/image_ops_test.ts @@ -19,227 +19,249 @@ import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose, expectArraysEqual} from '../test_util'; describeWithFlags('nonMaxSuppression', ALL_ENVS, () => { - it('select from three clusters', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([3]); - expectArraysEqual(await indices.data(), [3, 0, 5]); - }); - - it('select from three clusters flipped coordinates', async () => { - const boxes = tf.tensor2d( - [ - 1, 1, 0, 0, 0, 0.1, 1, 1.1, 0, .9, 1, -0.1, - 0, 10, 1, 11, 1, 10.1, 0, 11.1, 1, 101, 0, 100 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([3]); - expectArraysEqual(await indices.data(), [3, 0, 5]); - }); - - it('select at most two boxes from three clusters', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 2; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([2]); - expectArraysEqual(await indices.data(), [3, 0]); - }); - - it('select at most thirty boxes from three clusters', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 30; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([3]); - expectArraysEqual(await indices.data(), [3, 0, 5]); - }); - - it('select single box', async () => { - const boxes = tf.tensor2d([0, 0, 1, 1], [1, 4]); - const scores = tf.tensor1d([0.9]); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([1]); - expectArraysEqual(await indices.data(), [0]); - }); - - it('select from ten identical boxes', async () => { - const numBoxes = 10; - const corners = new Array(numBoxes) - .fill(0) - .map(_ => [0, 0, 1, 1]) - .reduce((arr, curr) => arr.concat(curr)); - const boxes = tf.tensor2d(corners, [numBoxes, 4]); - const scores = tf.tensor1d(Array(numBoxes).fill(0.9)); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([1]); - expectArraysEqual(await indices.data(), [0]); - }); - - it('inconsistent box and score shapes', () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5]); - const maxOutputSize = 30; - const iouThreshold = 0.5; - const scoreThreshold = 0; - expect( - () => tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)) - .toThrowError(/scores has incompatible shape with boxes/); - }); - - it('invalid iou threshold', () => { - const boxes = tf.tensor2d([0, 0, 1, 1], [1, 4]); - const scores = tf.tensor1d([0.9]); - const maxOutputSize = 3; - const iouThreshold = 1.2; - const scoreThreshold = 0; - expect( - () => tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)) - .toThrowError(/iouThreshold must be in \[0, 1\]/); - }); - - it('empty input', async () => { - const boxes = tf.tensor2d([], [0, 4]); - const scores = tf.tensor1d([]); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([0]); - expectArraysEqual(await indices.data(), []); - }); - - it('accepts a tensor-like object', async () => { - const boxes = [[0, 0, 1, 1], [0, 1, 1, 2]]; - const scores = [1, 2]; - const indices = tf.image.nonMaxSuppression(boxes, scores, 10); - expect(indices.shape).toEqual([2]); - expect(indices.dtype).toEqual('int32'); - expectArraysEqual(await indices.data(), [1, 0]); - }); - - it('select from three clusters with SoftNMS', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 6; - const iouThreshold = 1.0; - const scoreThreshold = 0; - const softNmsSigma = 0.5; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma); - - expect(indices.shape).toEqual([6]); - expectArraysEqual(await indices.data(), [3, 0, 1, 5, 4, 2]); + describe('NonMaxSuppression Basic', () => { + it('select from three clusters', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([3]); + expectArraysEqual(await indices.data(), [3, 0, 5]); + }); + + it('select from three clusters flipped coordinates', async () => { + const boxes = tf.tensor2d( + [ + 1, 1, 0, 0, 0, 0.1, 1, 1.1, 0, .9, 1, -0.1, + 0, 10, 1, 11, 1, 10.1, 0, 11.1, 1, 101, 0, 100 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([3]); + expectArraysEqual(await indices.data(), [3, 0, 5]); + }); + + it('select at most two boxes from three clusters', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 2; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([2]); + expectArraysEqual(await indices.data(), [3, 0]); + }); + + it('select at most thirty boxes from three clusters', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 30; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([3]); + expectArraysEqual(await indices.data(), [3, 0, 5]); + }); + + it('select single box', async () => { + const boxes = tf.tensor2d([0, 0, 1, 1], [1, 4]); + const scores = tf.tensor1d([0.9]); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([1]); + expectArraysEqual(await indices.data(), [0]); + }); + + it('select from ten identical boxes', async () => { + const numBoxes = 10; + const corners = new Array(numBoxes) + .fill(0) + .map(_ => [0, 0, 1, 1]) + .reduce((arr, curr) => arr.concat(curr)); + const boxes = tf.tensor2d(corners, [numBoxes, 4]); + const scores = tf.tensor1d(Array(numBoxes).fill(0.9)); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([1]); + expectArraysEqual(await indices.data(), [0]); + }); + + it('inconsistent box and score shapes', () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5]); + const maxOutputSize = 30; + const iouThreshold = 0.5; + const scoreThreshold = 0; + expect( + () => tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)) + .toThrowError(/scores has incompatible shape with boxes/); + }); + + it('invalid iou threshold', () => { + const boxes = tf.tensor2d([0, 0, 1, 1], [1, 4]); + const scores = tf.tensor1d([0.9]); + const maxOutputSize = 3; + const iouThreshold = 1.2; + const scoreThreshold = 0; + expect( + () => tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)) + .toThrowError(/iouThreshold must be in \[0, 1\]/); + }); + + it('empty input', async () => { + const boxes = tf.tensor2d([], [0, 4]); + const scores = tf.tensor1d([]); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([0]); + expectArraysEqual(await indices.data(), []); + }); + + it('accepts a tensor-like object', async () => { + const boxes = [[0, 0, 1, 1], [0, 1, 1, 2]]; + const scores = [1, 2]; + const indices = tf.image.nonMaxSuppression(boxes, scores, 10); + expect(indices.shape).toEqual([2]); + expect(indices.dtype).toEqual('int32'); + expectArraysEqual(await indices.data(), [1, 0]); + }); + }); + + describe('NonMaxSuppressionWithScore', () => { + it('select from three clusters with SoftNMS', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 6; + const iouThreshold = 1.0; + const scoreThreshold = 0; + const softNmsSigma = 0.5; + const indices = tf.image.nonMaxSuppressionWithScore( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); + + expect(indices[0].shape).toEqual([6]); + expectArraysEqual(await indices[0].data(), [3, 0, 1, 5, 4, 2]); + + expect(indices[1].shape).toEqual([6]); + expectArraysClose( + await indices[1].data(), [0.95, 0.9, 0.384, 0.3, 0.256, 0.197]); + + expect(indices[2].shape).toEqual([]); + expectArraysEqual(await indices[2].data(), [6]); + }); }); }); describeWithFlags('nonMaxSuppressionAsync', ALL_ENVS, () => { - it('select from three clusters', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = await tf.image.nonMaxSuppressionAsync( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([3]); - expectArraysEqual(await indices.data(), [3, 0, 5]); - }); - - it('accepts a tensor-like object', async () => { - const boxes = [[0, 0, 1, 1], [0, 1, 1, 2]]; - const scores = [1, 2]; - const indices = await tf.image.nonMaxSuppressionAsync(boxes, scores, 10); - expect(indices.shape).toEqual([2]); - expect(indices.dtype).toEqual('int32'); - expectArraysEqual(await indices.data(), [1, 0]); - }); - - it('select from three clusters with SoftNMS', async () => { - const boxes = tf.tensor2d( - [ - 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, - 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 - ], - [6, 4]); - const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); - const maxOutputSize = 6; - const iouThreshold = 1.0; - const scoreThreshold = 0; - const softNmsSigma = 0.5; - const indices = await tf.image.nonMaxSuppressionAsync( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma); - - expect(indices.shape).toEqual([6]); - expectArraysEqual(await indices.data(), [3, 0, 1, 5, 4, 2]); + describe('NonMaxSuppressionAsync basic', () => { + it('select from three clusters', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 3; + const iouThreshold = 0.5; + const scoreThreshold = 0; + const indices = await tf.image.nonMaxSuppressionAsync( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); + + expect(indices.shape).toEqual([3]); + expectArraysEqual(await indices.data(), [3, 0, 5]); + }); + + it('accepts a tensor-like object', async () => { + const boxes = [[0, 0, 1, 1], [0, 1, 1, 2]]; + const scores = [1, 2]; + const indices = await tf.image.nonMaxSuppressionAsync(boxes, scores, 10); + expect(indices.shape).toEqual([2]); + expect(indices.dtype).toEqual('int32'); + expectArraysEqual(await indices.data(), [1, 0]); + }); + }); + + describe('NonMaxSuppressionWithScoreAsync', () => { + it('select from three clusters with SoftNMS', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 6; + const iouThreshold = 1.0; + const scoreThreshold = 0; + const softNmsSigma = 0.5; + const indices = await tf.image.nonMaxSuppressionWithScoreAsync( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); + + expect(indices[0].shape).toEqual([6]); + expectArraysEqual(await indices[0].data(), [3, 0, 1, 5, 4, 2]); + + expect(indices[1].shape).toEqual([6]); + expectArraysClose( + await indices[1].data(), [0.95, 0.9, 0.384, 0.3, 0.256, 0.197]); + + expect(indices[2].shape).toEqual([]); + expectArraysEqual(await indices[2].data(), [6]); + }); }); }); From 817de6bc216cdbdfe27ebef364969bfcf12bc1ff Mon Sep 17 00:00:00 2001 From: Na Li Date: Mon, 16 Dec 2019 14:46:04 -0800 Subject: [PATCH 2/5] Change output to named object instead of array; Use modularized pattern to create new op; Fix output to two instead of three --- tfjs-core/src/backends/backend.ts | 7 -- tfjs-core/src/backends/cpu/all_kernels.ts | 1 + tfjs-core/src/backends/cpu/backend_cpu.ts | 15 +---- .../cpu/non_max_suppression_with_score.ts | 63 ++++++++++++++++++ .../src/backends/non_max_suppression_impl.ts | 53 ++++++++------- tfjs-core/src/backends/webgl/all_kernels.ts | 1 + tfjs-core/src/backends/webgl/backend_webgl.ts | 16 +---- .../webgl/non_max_suppression_with_score.ts | 65 +++++++++++++++++++ tfjs-core/src/ops/image_ops.ts | 31 +++++---- tfjs-core/src/ops/image_ops_test.ts | 33 ++++------ 10 files changed, 193 insertions(+), 92 deletions(-) create mode 100644 tfjs-core/src/backends/cpu/non_max_suppression_with_score.ts create mode 100644 tfjs-core/src/backends/webgl/non_max_suppression_with_score.ts diff --git a/tfjs-core/src/backends/backend.ts b/tfjs-core/src/backends/backend.ts index 4c11114780f..fd3dbbaab3d 100644 --- a/tfjs-core/src/backends/backend.ts +++ b/tfjs-core/src/backends/backend.ts @@ -594,13 +594,6 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { return notYetImplemented('nonMaxSuppression'); } - nonMaxSuppressionWithScore( - boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold?: number, - softNmsSigma?: number): [Tensor1D, Tensor1D, Scalar] { - return notYetImplemented('nonMaxSuppression'); - } - fft(x: Tensor2D): Tensor2D { return notYetImplemented('fft'); } diff --git a/tfjs-core/src/backends/cpu/all_kernels.ts b/tfjs-core/src/backends/cpu/all_kernels.ts index 8d59f9ee9dd..6a7b3b68e81 100644 --- a/tfjs-core/src/backends/cpu/all_kernels.ts +++ b/tfjs-core/src/backends/cpu/all_kernels.ts @@ -19,3 +19,4 @@ // global registry when we compile the library. A modular build would replace // the contents of this file and import only the kernels that are needed. import './square'; +import './non_max_suppression_with_score'; diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index ec9af4eb7f0..ee2a1e6772a 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -42,7 +42,7 @@ import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend'; import * as backend_util from '../backend_util'; import * as complex_util from '../complex_util'; -import {nonMaxSuppressionV3, nonMaxSuppressionV5} from '../non_max_suppression_impl'; +import {nonMaxSuppressionV3} from '../non_max_suppression_impl'; import {split} from '../split_shared'; import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; @@ -3279,19 +3279,6 @@ export class MathBackendCPU extends KernelBackend { boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); } - nonMaxSuppressionWithScore( - boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold: number, - softNmsSigma: number): [Tensor1D, Tensor1D, Scalar] { - assertNotComplex(boxes, 'nonMaxSuppressionWithScore'); - - const boxesVals = this.readSync(boxes.dataId) as TypedArray; - const scoresVals = this.readSync(scores.dataId) as TypedArray; - return nonMaxSuppressionV5( - boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma); - } - fft(x: Tensor2D): Tensor2D { return this.fftBatch(x, false); } diff --git a/tfjs-core/src/backends/cpu/non_max_suppression_with_score.ts b/tfjs-core/src/backends/cpu/non_max_suppression_with_score.ts new file mode 100644 index 00000000000..1a65c2f6682 --- /dev/null +++ b/tfjs-core/src/backends/cpu/non_max_suppression_with_score.ts @@ -0,0 +1,63 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '../../kernel_registry'; +import {TypedArray} from '../../types'; +import {nonMaxSuppressionV5} from '../non_max_suppression_impl'; + +import {MathBackendCPU} from './backend_cpu'; +import {assertNotComplex} from './cpu_util'; + +interface NonMaxSuppressionWithScoreInputs extends NamedTensorInfoMap { + boxes: TensorInfo; + scores: TensorInfo; +} + +interface NonMaxSuppressionWithScoreAttrs extends NamedAttrMap { + maxOutputSize: number; + iouThreshold: number; + scoreThreshold: number; + softNmsSigma: number; +} + +registerKernel({ + kernelName: 'NonMaxSuppressionWithScore', + backendName: 'cpu', + kernelFunc: ({inputs, backend, attrs}) => { + const {boxes, scores} = inputs as NonMaxSuppressionWithScoreInputs; + const {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma} = + attrs as NonMaxSuppressionWithScoreAttrs; + + const cpuBackend = backend as MathBackendCPU; + + assertNotComplex(boxes, 'NonMaxSuppressionWithScore'); + + const boxesVals = cpuBackend.data.get(boxes.dataId).values as TypedArray; + const scoresVals = cpuBackend.data.get(scores.dataId).values as TypedArray; + + const maxOutputSizeVal = maxOutputSize; + const iouThresholdVal = iouThreshold; + const scoreThresholdVal = scoreThreshold; + const softNmsSigmaVal = softNmsSigma; + + const {selectedIndices, selectedScores} = nonMaxSuppressionV5( + boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, + scoreThresholdVal, softNmsSigmaVal); + + return [selectedIndices, selectedScores]; + } +}); diff --git a/tfjs-core/src/backends/non_max_suppression_impl.ts b/tfjs-core/src/backends/non_max_suppression_impl.ts index 5fbc2b05bf7..3f7bb8d8961 100644 --- a/tfjs-core/src/backends/non_max_suppression_impl.ts +++ b/tfjs-core/src/backends/non_max_suppression_impl.ts @@ -20,7 +20,8 @@ */ import {scalar, tensor1d} from '../ops/tensor_ops'; -import {Scalar, Tensor1D} from '../tensor'; +import {Tensor1D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {TypedArray} from '../types'; import {binaryInsert} from './array_util'; @@ -37,28 +38,35 @@ export function nonMaxSuppressionV3( const dummySoftNmsSigma = 0.0; return nonMaxSuppressionImpl_( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, - dummySoftNmsSigma)[0]; + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, + dummySoftNmsSigma) + .selectedIndices as Tensor1D; } export function nonMaxSuppressionV5( boxes: TypedArray, scores: TypedArray, maxOutputSize: number, iouThreshold: number, scoreThreshold: number, - softNmsSigma: number): [Tensor1D, Tensor1D, Scalar] { + softNmsSigma: number): NamedTensorMap { // For NonMaxSuppressionV5Op, we always return a second output holding // corresponding scores. const returnScoresTensor = true; - return nonMaxSuppressionImpl_( + const result = nonMaxSuppressionImpl_( boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor); + + result.numValidOutputs.dispose(); + + return { + selectedIndices: result.selectedIndices, + selectedScores: result.selectedScores + }; } function nonMaxSuppressionImpl_( boxes: TypedArray, scores: TypedArray, maxOutputSize: number, iouThreshold: number, scoreThreshold: number, softNmsSigma: number, - returnScoresTensor = false, - padToMaxOutputSize = false): [Tensor1D, Tensor1D, Scalar] { + returnScoresTensor = false, padToMaxOutputSize = false): NamedTensorMap { // The list is sorted in ascending order, so that we can always pop the // candidate with the largest score in O(1) time. const candidates = @@ -71,10 +79,10 @@ function nonMaxSuppressionImpl_( // before. const scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0; - const selected: number[] = []; - const selectedScore: number[] = []; + const selectedIndices: number[] = []; + const selectedScores: number[] = []; - while (selected.length < maxOutputSize && candidates.length > 0) { + while (selectedIndices.length < maxOutputSize && candidates.length > 0) { const candidate = candidates.pop(); const {score: originalScore, boxIndex, suppressBeginIndex} = candidate; @@ -89,8 +97,8 @@ function nonMaxSuppressionImpl_( // by a selected box no more than once. Also, if the overlap exceeds // iouThreshold, we simply ignore the candidate. let ignoreCandidate = false; - for (let j = selected.length - 1; j >= suppressBeginIndex; --j) { - const iou = intersectionOverUnion(boxes, boxIndex, selected[j]); + for (let j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) { + const iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]); if (iou >= iouThreshold) { ignoreCandidate = true; @@ -112,14 +120,14 @@ function nonMaxSuppressionImpl_( // the updated score and suppressBeginIndex back in the candidate list. // If on the other hand, `candidate.score` has dropped below the score // threshold, we will not add it back to the candidates list. - candidate.suppressBeginIndex = selected.length; + candidate.suppressBeginIndex = selectedIndices.length; if (!ignoreCandidate) { // Candidate has passed all the tests, and is not suppressed, so // select the candidate. if (candidate.score === originalScore) { - selected.push(boxIndex); - selectedScore.push(candidate.score); + selectedIndices.push(boxIndex); + selectedScores.push(candidate.score); } else if (candidate.score > scoreThreshold) { // Candidate's score is suppressed but is still high enough to be // considered, so add back to the candidates list. @@ -129,16 +137,17 @@ function nonMaxSuppressionImpl_( } // NonMaxSuppressionV4 feature: padding output to maxOutputSize. - const numValidOutputs = selected.length; + const numValidOutputs = selectedIndices.length; if (padToMaxOutputSize) { - selected.fill(0, numValidOutputs); - selectedScore.fill(0.0, numValidOutputs); + selectedIndices.fill(0, numValidOutputs); + selectedScores.fill(0.0, numValidOutputs); } - return [ - tensor1d(selected, 'int32'), tensor1d(selectedScore, 'float32'), - scalar(numValidOutputs, 'int32') - ]; + return { + selectedIndices: tensor1d(selectedIndices, 'int32'), + selectedScores: tensor1d(selectedScores, 'float32'), + numValidOutputs: scalar(numValidOutputs, 'int32') + }; } function intersectionOverUnion(boxes: TypedArray, i: number, j: number) { diff --git a/tfjs-core/src/backends/webgl/all_kernels.ts b/tfjs-core/src/backends/webgl/all_kernels.ts index 973dacddad2..64339bf260e 100644 --- a/tfjs-core/src/backends/webgl/all_kernels.ts +++ b/tfjs-core/src/backends/webgl/all_kernels.ts @@ -20,3 +20,4 @@ // the contents of this file and import only the kernels that are needed. import './square'; import './fromPixels'; +import './non_max_suppression_with_score'; diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index f0fe05c5f46..2bce838b3aa 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -45,7 +45,7 @@ import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} fr import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend'; import * as backend_util from '../backend_util'; import {mergeRealAndImagArrays} from '../complex_util'; -import {nonMaxSuppressionV3, nonMaxSuppressionV5} from '../non_max_suppression_impl'; +import {nonMaxSuppressionV3} from '../non_max_suppression_impl'; import {split} from '../split_shared'; import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; @@ -2254,20 +2254,6 @@ export class MathBackendWebGL extends KernelBackend { boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); } - nonMaxSuppressionWithScore( - boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold: number, - softNmsSigma: number): [Tensor1D, Tensor1D, Scalar] { - warn( - 'tf.nonMaxSuppression() in webgl locks the UI thread. ' + - 'Call tf.nonMaxSuppressionAsync() instead'); - const boxesVals = boxes.dataSync(); - const scoresVals = scores.dataSync(); - return nonMaxSuppressionV5( - boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma); - } - cropAndResize( image: Tensor4D, boxes: Tensor2D, boxIndex: Tensor1D, cropSize: [number, number], method: 'bilinear'|'nearest', diff --git a/tfjs-core/src/backends/webgl/non_max_suppression_with_score.ts b/tfjs-core/src/backends/webgl/non_max_suppression_with_score.ts new file mode 100644 index 00000000000..291c1610433 --- /dev/null +++ b/tfjs-core/src/backends/webgl/non_max_suppression_with_score.ts @@ -0,0 +1,65 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '../../kernel_registry'; +import {warn} from '../../log'; +import {TypedArray} from '../../types'; +import {nonMaxSuppressionV5} from '../non_max_suppression_impl'; + +import {MathBackendWebGL} from './backend_webgl'; + +interface NonMaxSuppressionWithScoreInputs extends NamedTensorInfoMap { + boxes: TensorInfo; + scores: TensorInfo; +} + +interface NonMaxSuppressionWithScoreAttrs extends NamedAttrMap { + maxOutputSize: number; + iouThreshold: number; + scoreThreshold: number; + softNmsSigma: number; +} + +registerKernel({ + kernelName: 'NonMaxSuppressionWithScore', + backendName: 'webgl', + kernelFunc: ({inputs, backend, attrs}) => { + warn( + 'tf.nonMaxSuppression() in webgl locks the UI thread. ' + + 'Call tf.nonMaxSuppressionAsync() instead'); + + const {boxes, scores} = inputs as NonMaxSuppressionWithScoreInputs; + const {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma} = + attrs as NonMaxSuppressionWithScoreAttrs; + + const gpuBackend = backend as MathBackendWebGL; + + const boxesVals = gpuBackend.readSync(boxes.dataId) as TypedArray; + const scoresVals = gpuBackend.readSync(scores.dataId) as TypedArray; + + const maxOutputSizeVal = maxOutputSize; + const iouThresholdVal = iouThreshold; + const scoreThresholdVal = scoreThreshold; + const softNmsSigmaVal = softNmsSigma; + + const {selectedIndices, selectedScores} = nonMaxSuppressionV5( + boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, + scoreThresholdVal, softNmsSigmaVal); + + return [selectedIndices, selectedScores]; + } +}); diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index eea8d6fdd9d..f4ab2fb6f1e 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -18,12 +18,14 @@ import {nonMaxSuppressionV3, nonMaxSuppressionV5} from '../backends/non_max_suppression_impl'; import {ENGINE, ForwardFunc} from '../engine'; import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; import {op} from './operation'; + /** * Bilinear resize a batch of 3D images to a new shape. * @@ -232,18 +234,17 @@ async function nonMaxSuppressionAsync_( * on score. Defaults to -inf, which means any score is accepted. * @param softNmsSigma A float representing the sigma parameter for Soft NMS. * When sigma is 0, it falls back to nonMaxSuppression. - * @return An `Array` of - * - A 1D tensor with the selected box indices. - * - A 1D tensor with the corresponding scores for each selected box. - * - A number representing the number of valid elements in the selected - * box indices. + * @return A map with the following properties: + * - selectedIndices: A 1D tensor with the selected box indices. + * - selectedScores: A 1D tensor with the corresponding scores for each + * selected box. */ /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ function nonMaxSuppressionWithScore_( boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, maxOutputSize: number, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, - softNmsSigma = 0.0): [Tensor1D, Tensor1D, Tensor] { + softNmsSigma = 0.0): NamedTensorMap { const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression'); const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression'); @@ -253,14 +254,15 @@ function nonMaxSuppressionWithScore_( maxOutputSize = inputs.maxOutputSize; iouThreshold = inputs.iouThreshold; scoreThreshold = inputs.scoreThreshold; + softNmsSigma = inputs.softNmsSigma; - const attrs = {maxOutputSize, iouThreshold, scoreThreshold}; - return ENGINE.runKernelFunc( - b => b.nonMaxSuppressionWithScore( - $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma), - {boxes: $boxes, scores: $scores}, null /* grad */, 'NonMaxSuppressionV5', - attrs); + const attrs = {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma}; + + const result = ENGINE.runKernel( + 'NonMaxSuppressionWithScore', + {boxes: $boxes, scores: $scores}, attrs) as Tensor[]; + + return {selectedIndices: result[0], selectedScores: result[1]}; } /** This is the async version of `nonMaxSuppressionWithScore` */ @@ -268,7 +270,7 @@ async function nonMaxSuppressionWithScoreAsync_( boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, maxOutputSize: number, iouThreshold = 0.5, scoreThreshold = Number.NEGATIVE_INFINITY, - softNmsSigma = 0.0): Promise<[Tensor1D, Tensor1D, Tensor]> { + softNmsSigma = 0.0): Promise { const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync'); const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync'); @@ -287,6 +289,7 @@ async function nonMaxSuppressionWithScoreAsync_( const res = nonMaxSuppressionV5( boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma); + if ($boxes !== boxes) { $boxes.dispose(); } diff --git a/tfjs-core/src/ops/image_ops_test.ts b/tfjs-core/src/ops/image_ops_test.ts index 0e1908d865d..39905d525e8 100644 --- a/tfjs-core/src/ops/image_ops_test.ts +++ b/tfjs-core/src/ops/image_ops_test.ts @@ -188,19 +188,16 @@ describeWithFlags('nonMaxSuppression', ALL_ENVS, () => { const iouThreshold = 1.0; const scoreThreshold = 0; const softNmsSigma = 0.5; - const indices = tf.image.nonMaxSuppressionWithScore( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma); - expect(indices[0].shape).toEqual([6]); - expectArraysEqual(await indices[0].data(), [3, 0, 1, 5, 4, 2]); + const {selectedIndices, selectedScores} = + tf.image.nonMaxSuppressionWithScore( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); - expect(indices[1].shape).toEqual([6]); - expectArraysClose( - await indices[1].data(), [0.95, 0.9, 0.384, 0.3, 0.256, 0.197]); + expectArraysEqual(await selectedIndices.data(), [3, 0, 1, 5, 4, 2]); - expect(indices[2].shape).toEqual([]); - expectArraysEqual(await indices[2].data(), [6]); + expectArraysClose( + await selectedScores.data(), [0.95, 0.9, 0.384, 0.3, 0.256, 0.197]); }); }); }); @@ -248,19 +245,15 @@ describeWithFlags('nonMaxSuppressionAsync', ALL_ENVS, () => { const iouThreshold = 1.0; const scoreThreshold = 0; const softNmsSigma = 0.5; - const indices = await tf.image.nonMaxSuppressionWithScoreAsync( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, - softNmsSigma); + const {selectedIndices, selectedScores} = + await tf.image.nonMaxSuppressionWithScoreAsync( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); - expect(indices[0].shape).toEqual([6]); - expectArraysEqual(await indices[0].data(), [3, 0, 1, 5, 4, 2]); + expectArraysEqual(await selectedIndices.data(), [3, 0, 1, 5, 4, 2]); - expect(indices[1].shape).toEqual([6]); expectArraysClose( - await indices[1].data(), [0.95, 0.9, 0.384, 0.3, 0.256, 0.197]); - - expect(indices[2].shape).toEqual([]); - expectArraysEqual(await indices[2].data(), [6]); + await selectedScores.data(), [0.95, 0.9, 0.384, 0.3, 0.256, 0.197]); }); }); }); From 6700aaf7d4e7d905ad207091b6a9048335bec900 Mon Sep 17 00:00:00 2001 From: Na Li Date: Mon, 16 Dec 2019 15:01:02 -0800 Subject: [PATCH 3/5] Add memory leak test for async method. --- tfjs-core/src/ops/image_ops_test.ts | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tfjs-core/src/ops/image_ops_test.ts b/tfjs-core/src/ops/image_ops_test.ts index 39905d525e8..0e40a73124c 100644 --- a/tfjs-core/src/ops/image_ops_test.ts +++ b/tfjs-core/src/ops/image_ops_test.ts @@ -245,15 +245,24 @@ describeWithFlags('nonMaxSuppressionAsync', ALL_ENVS, () => { const iouThreshold = 1.0; const scoreThreshold = 0; const softNmsSigma = 0.5; + + const numTensorsBefore = tf.memory().numTensors; + const {selectedIndices, selectedScores} = await tf.image.nonMaxSuppressionWithScoreAsync( boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma); + const numTensorsAfter = tf.memory().numTensors; + expectArraysEqual(await selectedIndices.data(), [3, 0, 1, 5, 4, 2]); expectArraysClose( await selectedScores.data(), [0.95, 0.9, 0.384, 0.3, 0.256, 0.197]); + + // The number of tensors should increase by the number of tensors + // returned (i.e. selectedIndices and selectedScores). + expect(numTensorsAfter).toEqual(numTensorsBefore + 2); }); }); }); From ac933fa29c2487779ae2d676ebd81d7dffcf3033 Mon Sep 17 00:00:00 2001 From: Na Li Date: Mon, 16 Dec 2019 16:42:52 -0800 Subject: [PATCH 4/5] Fix lint --- tfjs-core/src/ops/image_ops.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index f4ab2fb6f1e..5c0effb4bc3 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -25,7 +25,6 @@ import * as util from '../util'; import {op} from './operation'; - /** * Bilinear resize a batch of 3D images to a new shape. * From 026409b273fd72a4d42265676a6e680a83b647ca Mon Sep 17 00:00:00 2001 From: Na Li Date: Tue, 17 Dec 2019 13:05:49 -0800 Subject: [PATCH 5/5] Use same kernal name as TF C++ version --- tfjs-core/src/backends/cpu/all_kernels.ts | 2 +- ...ax_suppression_with_score.ts => non_max_suppression_v5.ts} | 2 +- tfjs-core/src/backends/webgl/all_kernels.ts | 2 +- ...ax_suppression_with_score.ts => non_max_suppression_v5.ts} | 2 +- tfjs-core/src/ops/image_ops.ts | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) rename tfjs-core/src/backends/cpu/{non_max_suppression_with_score.ts => non_max_suppression_v5.ts} (98%) rename tfjs-core/src/backends/webgl/{non_max_suppression_with_score.ts => non_max_suppression_v5.ts} (98%) diff --git a/tfjs-core/src/backends/cpu/all_kernels.ts b/tfjs-core/src/backends/cpu/all_kernels.ts index 6a7b3b68e81..75920b35b88 100644 --- a/tfjs-core/src/backends/cpu/all_kernels.ts +++ b/tfjs-core/src/backends/cpu/all_kernels.ts @@ -19,4 +19,4 @@ // global registry when we compile the library. A modular build would replace // the contents of this file and import only the kernels that are needed. import './square'; -import './non_max_suppression_with_score'; +import './non_max_suppression_v5'; diff --git a/tfjs-core/src/backends/cpu/non_max_suppression_with_score.ts b/tfjs-core/src/backends/cpu/non_max_suppression_v5.ts similarity index 98% rename from tfjs-core/src/backends/cpu/non_max_suppression_with_score.ts rename to tfjs-core/src/backends/cpu/non_max_suppression_v5.ts index 1a65c2f6682..62fdb3781cf 100644 --- a/tfjs-core/src/backends/cpu/non_max_suppression_with_score.ts +++ b/tfjs-core/src/backends/cpu/non_max_suppression_v5.ts @@ -35,7 +35,7 @@ interface NonMaxSuppressionWithScoreAttrs extends NamedAttrMap { } registerKernel({ - kernelName: 'NonMaxSuppressionWithScore', + kernelName: 'NonMaxSuppressionV5', backendName: 'cpu', kernelFunc: ({inputs, backend, attrs}) => { const {boxes, scores} = inputs as NonMaxSuppressionWithScoreInputs; diff --git a/tfjs-core/src/backends/webgl/all_kernels.ts b/tfjs-core/src/backends/webgl/all_kernels.ts index 64339bf260e..724e83ce034 100644 --- a/tfjs-core/src/backends/webgl/all_kernels.ts +++ b/tfjs-core/src/backends/webgl/all_kernels.ts @@ -20,4 +20,4 @@ // the contents of this file and import only the kernels that are needed. import './square'; import './fromPixels'; -import './non_max_suppression_with_score'; +import './non_max_suppression_v5'; diff --git a/tfjs-core/src/backends/webgl/non_max_suppression_with_score.ts b/tfjs-core/src/backends/webgl/non_max_suppression_v5.ts similarity index 98% rename from tfjs-core/src/backends/webgl/non_max_suppression_with_score.ts rename to tfjs-core/src/backends/webgl/non_max_suppression_v5.ts index 291c1610433..4a8c0ac4156 100644 --- a/tfjs-core/src/backends/webgl/non_max_suppression_with_score.ts +++ b/tfjs-core/src/backends/webgl/non_max_suppression_v5.ts @@ -35,7 +35,7 @@ interface NonMaxSuppressionWithScoreAttrs extends NamedAttrMap { } registerKernel({ - kernelName: 'NonMaxSuppressionWithScore', + kernelName: 'NonMaxSuppressionV5', backendName: 'webgl', kernelFunc: ({inputs, backend, attrs}) => { warn( diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index 5c0effb4bc3..195569ee54b 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -258,8 +258,8 @@ function nonMaxSuppressionWithScore_( const attrs = {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma}; const result = ENGINE.runKernel( - 'NonMaxSuppressionWithScore', - {boxes: $boxes, scores: $scores}, attrs) as Tensor[]; + 'NonMaxSuppressionV5', {boxes: $boxes, scores: $scores}, + attrs) as Tensor[]; return {selectedIndices: result[0], selectedScores: result[1]}; }