From 02738e2bd50366fe7a6b659da46638000863fa54 Mon Sep 17 00:00:00 2001 From: Na Li Date: Mon, 9 Dec 2019 16:06:25 -0800 Subject: [PATCH 1/7] Add support to NonMaxSuppressionV5 --- tfjs-core/src/backends/array_util.ts | 72 +++++++ tfjs-core/src/backends/array_util_test.ts | 176 ++++++++++++++++++ tfjs-core/src/backends/backend.ts | 3 +- tfjs-core/src/backends/cpu/backend_cpu.ts | 6 +- .../src/backends/non_max_suppression_impl.ts | 83 +++++++-- tfjs-core/src/backends/webgl/backend_webgl.ts | 6 +- tfjs-core/src/ops/image_ops.ts | 38 +++- tfjs-core/src/ops/image_ops_test.ts | 55 ++++-- 8 files changed, 398 insertions(+), 41 deletions(-) create mode 100644 tfjs-core/src/backends/array_util.ts create mode 100644 tfjs-core/src/backends/array_util_test.ts diff --git a/tfjs-core/src/backends/array_util.ts b/tfjs-core/src/backends/array_util.ts new file mode 100644 index 00000000000..b615b735330 --- /dev/null +++ b/tfjs-core/src/backends/array_util.ts @@ -0,0 +1,72 @@ +/** + * Inserts a value into a sorted array. This method allows duplicate, meaning it + * allows inserting duplicate value, in which case, the element will be inserted + * at the lowest index of the value. + * @param arr The array to modify. + * @param element The element to insert. + * @param comparator Optional. If no comparator is specified, elements are + * compared using array_util.defaultComparator, which is suitable for Strings + * and Numbers in ascending arrays. If the array contains multiple instances of + * the target value, the left-most instance will be returned. To provide a + * comparator, it should take 2 arguments to compare and return a negative, + * zero, or a positive number. + */ +export function binaryInsert( + arr: T[], element: T, comparator?: (a: T, b: T) => number) { + const index = binarySearch(arr, element, comparator); + const insertionPoint = index < 0 ? -(index + 1) : index; + arr.splice(insertionPoint, 0, element); +} + +/** + * Searches the array for the target using binary search, returns the index + * of the found element, or position to insert if element not found. If no + * comparator is specified, elements are compared using array_ + * util.defaultComparator, which is suitable for Strings and Numbers in + * ascending arrays. If the array contains multiple instances of the target + * value, the left-most instance will be returned. + * @param arr The array to be searched in. + * @param target The target to be searched for. + * @param comparator Should take 2 arguments to compare and return a negative, + * zero, or a positive number. + * @return Lowest index of the target value if found, otherwise the insertion + * point where the target should be inserted, in the form of + * (-insertionPoint - 1). + */ +export function binarySearch( + arr: T[], target: T, comparator?: (a: T, b: T) => number) { + return binarySearch_(arr, target, comparator || defaultComparator); +} + +/** + * Compares its two arguments for order. + * @param a The first element to be compared. + * @param b The second element to be compared. + * @return A negative number, zero, or a positive number as the first + * argument is less than, equal to, or greater than the second. + */ +function defaultComparator(a: T, b: T): number { + return a > b ? 1 : a < b ? -1 : 0; +} + +function binarySearch_( + arr: T[], target: T, comparator: (a: T, b: T) => number) { + let left = 0; + let right = arr.length; + let middle = 0; + let found = false; + while (left < right) { + middle = left + ((right - left) >>> 1); + const compareResult = comparator(target, arr[middle]); + if (compareResult > 0) { + left = middle + 1; + } else { + right = middle; + // If compareResult is 0, the value is found. We record it is found, + // and then keep looking because there may be duplicate. + found = !compareResult; + } + } + + return found ? left : -left - 1; +} diff --git a/tfjs-core/src/backends/array_util_test.ts b/tfjs-core/src/backends/array_util_test.ts new file mode 100644 index 00000000000..8d765a0a06d --- /dev/null +++ b/tfjs-core/src/backends/array_util_test.ts @@ -0,0 +1,176 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as array_util from './array_util'; + +describe('array_util', () => { + const insertionPoint = (i: number) => -(i + 1); + + describe('binarySearch', () => { + const d = [ + -897123.9, -321434.58758, -1321.3124, -324, -9, -3, 0, 0, 0, 0.31255, 5, + 142.88888708, 334, 342, 453, 54254 + ]; + + it('-897123.9 should be found at index 0', () => { + const result = array_util.binarySearch(d, -897123.9); + expect(result).toBe(0); + }); + + it(`54254 should be found at index ${d.length - 1}`, () => { + const result = array_util.binarySearch(d, 54254); + expect(result).toBe(d.length - 1); + }); + + it('-3 should be found at index 5', () => { + const result = array_util.binarySearch(d, -3); + expect(result).toBe(5); + }); + + it('0 should be found at index 6', () => { + const result = array_util.binarySearch(d, 0); + expect(result).toBe(6); + }); + + it('-900000 should have an insertion point of 0', () => { + const result = array_util.binarySearch(d, -900000); + expect(result).toBeLessThan(0); + expect(insertionPoint(result)).toBe(0); + }); + + it(`54255 should have an insertion point of ${d.length}`, () => { + const result = array_util.binarySearch(d, 54255); + expect(result).toBeLessThan(0); + expect(insertionPoint(result)).toEqual(d.length); + }); + + it('1.1 should have an insertion point of 10', () => { + const result = array_util.binarySearch(d, 1.1); + expect(result).toBeLessThan(0); + expect(insertionPoint(result)).toEqual(10); + }); + }); + + describe('binarySearch with custom comparator', () => { + const e = [ + 54254, + 453, + 342, + 334, + 142.88888708, + 5, + 0.31255, + 0, + 0, + 0, + -3, + -9, + -324, + -1321.3124, + -321434.58758, + -897123.9, + ]; + + const revComparator = (a: number, b: number) => (b - a); + + it('54254 should be found at index 0', () => { + const result = array_util.binarySearch(e, 54254, revComparator); + expect(result).toBe(0); + }); + + it(`-897123.9 should be found at index ${e.length - 1}`, () => { + const result = array_util.binarySearch(e, -897123.9, revComparator); + expect(result).toBe(e.length - 1); + }); + + it('-3 should be found at index 10', () => { + const result = array_util.binarySearch(e, -3, revComparator); + expect(result).toBe(10); + }); + + it('0 should be found at index 7', () => { + const result = array_util.binarySearch(e, 0, revComparator); + expect(result).toBe(7); + }); + + it('54254.1 should have an insertion point of 0', () => { + const result = array_util.binarySearch(e, 54254.1, revComparator); + expect(result).toBeLessThan(0); + expect(insertionPoint(result)).toBe(0); + }); + + it(`-897124 should have an insertion point of ${e.length}`, () => { + const result = array_util.binarySearch(e, -897124, revComparator); + expect(result).toBeLessThan(0); + expect(insertionPoint(result)).toBe(e.length); + }); + }); + + describe( + 'binarySearch with custom comparator with single element array', () => { + const g = [1]; + + const revComparator = (a: number, b: number) => (b - a); + + it('1 should be found at index 0', () => { + const result = array_util.binarySearch(g, 1, revComparator); + expect(result).toBe(0); + }); + + it('2 should have an insertion point of 0', () => { + const result = array_util.binarySearch(g, 2, revComparator); + expect(result).toBeLessThan(0); + expect(insertionPoint(result)).toBe(0); + }); + + it('0 should have an insertion point of 1', () => { + const result = array_util.binarySearch(g, 0, revComparator); + expect(result).toBeLessThan(0); + expect(insertionPoint(result)).toBe(1); + }); + }); + + describe('binarySearch test left-most duplicated element', () => { + it('should find the index of the first 0', () => { + const result = array_util.binarySearch([0, 0, 1], 0); + expect(result).toBe(0); + }); + + it('should find the index of the first 1', () => { + const result = array_util.binarySearch([0, 1, 1], 1); + expect(result).toBe(1); + }); + }); + + describe('binaryInsert', () => { + it('inserts correctly', () => { + const a: number[] = []; + + array_util.binaryInsert(a, 3); + expect(a).toEqual([3]); + + array_util.binaryInsert(a, 3); + expect(a).toEqual([3, 3]); + + array_util.binaryInsert(a, 1); + expect(a).toEqual([1, 3, 3]); + + array_util.binaryInsert(a, 5); + expect(a).toEqual([1, 3, 3, 5]); + }); + }); +}); diff --git a/tfjs-core/src/backends/backend.ts b/tfjs-core/src/backends/backend.ts index fd3dbbaab3d..c3c212626f1 100644 --- a/tfjs-core/src/backends/backend.ts +++ b/tfjs-core/src/backends/backend.ts @@ -590,7 +590,8 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { nonMaxSuppression( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold?: number): Tensor1D { + iouThreshold: number, scoreThreshold?: number, + softNmsSigma?: number): Tensor1D { return notYetImplemented('nonMaxSuppression'); } diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index 81c502aca01..5b04218c3ac 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -3270,13 +3270,15 @@ export class MathBackendCPU extends KernelBackend { nonMaxSuppression( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold: number): Tensor1D { + iouThreshold: number, scoreThreshold: number, + softNmsSigma: number): Tensor1D { assertNotComplex(boxes, 'nonMaxSuppression'); const boxesVals = this.readSync(boxes.dataId) as TypedArray; const scoresVals = this.readSync(scores.dataId) as TypedArray; return nonMaxSuppressionImpl( - boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); + boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); } fft(x: Tensor2D): Tensor2D { diff --git a/tfjs-core/src/backends/non_max_suppression_impl.ts b/tfjs-core/src/backends/non_max_suppression_impl.ts index 2a2a7982b15..ddc2134c330 100644 --- a/tfjs-core/src/backends/non_max_suppression_impl.ts +++ b/tfjs-core/src/backends/non_max_suppression_impl.ts @@ -23,35 +23,81 @@ import {tensor1d} from '../ops/tensor_ops'; import {Tensor1D} from '../tensor'; import {TypedArray} from '../types'; +import {binaryInsert} from './array_util'; + +interface Candidate { + score: number; + boxIndex: number; + suppressBeginIndex: number; +} + export function nonMaxSuppressionImpl( boxes: TypedArray, scores: TypedArray, maxOutputSize: number, - iouThreshold: number, scoreThreshold: number): Tensor1D { - const candidates = Array.from(scores) - .map((score, boxIndex) => ({score, boxIndex})) - .filter(c => c.score > scoreThreshold) - .sort((c1, c2) => c2.score - c1.score); + iouThreshold: number, scoreThreshold: number, + softNmsSigma: number): Tensor1D { + // The list is sorted in ascending order, so that we can always pop the + // candidate with the largest score in O(1) time. + const candidates = + Array.from(scores) + .map((score, boxIndex) => ({score, boxIndex, suppressBeginIndex: 0})) + .filter(c => c.score > scoreThreshold) + .sort(ascendingComparator); + + // If softNmsSigma is 0, the outcome of this algorithm is exactly same as + // before. + const scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0; const selected: number[] = []; - for (let i = 0; i < candidates.length; i++) { - const {score, boxIndex} = candidates[i]; - if (score < scoreThreshold) { + while (selected.length < maxOutputSize && candidates.length > 0) { + const candidate = candidates.pop(); + const {score: originalScore, boxIndex, suppressBeginIndex} = candidate; + + if (originalScore < scoreThreshold) { break; } + // Overlapping boxes are likely to have similar scores, therefore we + // iterate through the previously selected boxes backwards in order to + // see if candidate's score should be suppressed. We use suppressBeginIndex + // to track and ensure a candidate can be suppressed by a selected box no + // more than once. Also, if the overlap exceeds iouThreshold, we simply + // ignore the candidate. let ignoreCandidate = false; - for (let j = selected.length - 1; j >= 0; --j) { + for (let j = selected.length - 1; j >= suppressBeginIndex; --j) { const iou = intersectionOverUnion(boxes, boxIndex, selected[j]); + if (iou >= iouThreshold) { ignoreCandidate = true; break; } + + candidate.score = + candidate.score * suppressWeight(iouThreshold, scale, iou); + + if (candidate.score <= scoreThreshold) { + break; + } } + // At this point, if `candidate.score` has not dropped below + // `scoreThreshold`, then we know that we went through all of the previous + // selections and can safely update `suppressBeginIndex` to the end of the + // selected array. Then we can re-insert the candidate with the updated + // score and suppressBeginIndex back in the candidate list. If on the other + // hand, `candidate.score` has dropped below the score threshold, we will + // not add it back to the candidates list. + candidate.suppressBeginIndex = selected.length; + if (!ignoreCandidate) { - selected.push(boxIndex); - if (selected.length >= maxOutputSize) { - break; + // Candidate has passed all the tests, and is not suppressed, so select + // the candidate. + if (candidate.score === originalScore) { + selected.push(boxIndex); + } else if (candidate.score > scoreThreshold) { + // Candidate's score is suppressed but is still high enough to be + // considered, so add back to the candidates list. + binaryInsert(candidates, candidate, ascendingComparator); } } } @@ -83,3 +129,16 @@ function intersectionOverUnion(boxes: TypedArray, i: number, j: number) { Math.max(intersectionXmax - intersectionXmin, 0.0); return intersectionArea / (areaI + areaJ - intersectionArea); } + +// A Gaussian penalty function, this method always returns values in [0, 1]. +// The weight is a function of similarity, the more overlap two boxes are, the +// smaller the weight is, meaning highly overlapping boxe will be significantly +// penalized. On the other hand, a non-overlapping box will not be penalized. +function suppressWeight(iouThreshold: number, scale: number, iou: number) { + const weight = Math.exp(scale * iou * iou); + return iou <= iouThreshold ? weight : 0.0; +} + +function ascendingComparator(c1: Candidate, c2: Candidate) { + return c1.score - c2.score; +} diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index 5e9e6c1db9a..3bec48180ce 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -2244,14 +2244,16 @@ export class MathBackendWebGL extends KernelBackend { nonMaxSuppression( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold: number): Tensor1D { + iouThreshold: number, scoreThreshold: number, + softNmsSigma: number): Tensor1D { warn( 'tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead'); const boxesVals = boxes.dataSync(); const scoresVals = scores.dataSync(); return nonMaxSuppressionImpl( - boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); + boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); } cropAndResize( diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index 336c98e7c43..2f0871129f4 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -145,7 +145,11 @@ function resizeNearestNeighbor_( /** * Performs non maximum suppression of bounding boxes based on - * iou (intersection over union) + * iou (intersection over union). This op also supports a Soft-NMS mode (c.f. + * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score + * of other overlapping boxes, therefore favoring different regions of the image + * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma` + * parameter to be larger than 0. * * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of @@ -163,12 +167,13 @@ function resizeNearestNeighbor_( function nonMaxSuppression_( boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, maxOutputSize: number, iouThreshold = 0.5, - scoreThreshold = Number.NEGATIVE_INFINITY): Tensor1D { + scoreThreshold = Number.NEGATIVE_INFINITY, softNmsSigma = 0): Tensor1D { const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression'); const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression'); const inputs = nonMaxSuppSanityCheck( - $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold); + $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); maxOutputSize = inputs.maxOutputSize; iouThreshold = inputs.iouThreshold; scoreThreshold = inputs.scoreThreshold; @@ -176,7 +181,8 @@ function nonMaxSuppression_( const attrs = {maxOutputSize, iouThreshold, scoreThreshold}; return ENGINE.runKernelFunc( b => b.nonMaxSuppression( - $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold), + $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma), {boxes: $boxes, scores: $scores}, null /* grad */, 'NonMaxSuppressionV3', attrs); } @@ -185,22 +191,26 @@ function nonMaxSuppression_( async function nonMaxSuppressionAsync_( boxes: Tensor2D|TensorLike, scores: Tensor1D|TensorLike, maxOutputSize: number, iouThreshold = 0.5, - scoreThreshold = Number.NEGATIVE_INFINITY): Promise { + scoreThreshold = Number.NEGATIVE_INFINITY, + softNmsSigma = 0.0): Promise { const $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync'); const $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync'); const inputs = nonMaxSuppSanityCheck( - $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold); + $boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); maxOutputSize = inputs.maxOutputSize; iouThreshold = inputs.iouThreshold; scoreThreshold = inputs.scoreThreshold; + softNmsSigma = inputs.softNmsSigma; const boxesAndScores = await Promise.all([$boxes.data(), $scores.data()]); const boxesVals = boxesAndScores[0]; const scoresVals = boxesAndScores[1]; const res = nonMaxSuppressionImpl( - boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); + boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); if ($boxes !== boxes) { $boxes.dispose(); } @@ -212,14 +222,22 @@ async function nonMaxSuppressionAsync_( function nonMaxSuppSanityCheck( boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number, - iouThreshold: number, scoreThreshold: number): - {maxOutputSize: number, iouThreshold: number, scoreThreshold: number} { + iouThreshold: number, scoreThreshold: number, softNmsSigma: number): { + maxOutputSize: number, + iouThreshold: number, + scoreThreshold: number, + softNmsSigma: number +} { if (iouThreshold == null) { iouThreshold = 0.5; } if (scoreThreshold == null) { scoreThreshold = Number.NEGATIVE_INFINITY; } + if (softNmsSigma == null) { + softNmsSigma = 0; + } + const numBoxes = boxes.shape[0]; maxOutputSize = Math.min(maxOutputSize, numBoxes); @@ -238,7 +256,7 @@ function nonMaxSuppSanityCheck( scores.shape[0] === numBoxes, () => `scores has incompatible shape with boxes. Expected ${numBoxes}, ` + `but was ${scores.shape[0]}`); - return {maxOutputSize, iouThreshold, scoreThreshold}; + return {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma}; } /** diff --git a/tfjs-core/src/ops/image_ops_test.ts b/tfjs-core/src/ops/image_ops_test.ts index 3adb4e2a1ec..56edbdac3a9 100644 --- a/tfjs-core/src/ops/image_ops_test.ts +++ b/tfjs-core/src/ops/image_ops_test.ts @@ -104,19 +104,6 @@ describeWithFlags('nonMaxSuppression', ALL_ENVS, () => { expectArraysEqual(await indices.data(), [0]); }); - it('select from ten identical boxes', async () => { - const boxes = tf.tensor2d([0, 0, 1, 1], [1, 4]); - const scores = tf.tensor1d([0.9]); - const maxOutputSize = 3; - const iouThreshold = 0.5; - const scoreThreshold = 0; - const indices = tf.image.nonMaxSuppression( - boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); - - expect(indices.shape).toEqual([1]); - expectArraysEqual(await indices.data(), [0]); - }); - it('select from ten identical boxes', async () => { const numBoxes = 10; const corners = new Array(numBoxes) @@ -132,7 +119,7 @@ describeWithFlags('nonMaxSuppression', ALL_ENVS, () => { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); expect(indices.shape).toEqual([1]); - expectArraysEqual(await indices.data(), [0]); + expectArraysEqual(await indices.data(), [9]); }); it('inconsistent box and score shapes', () => { @@ -185,6 +172,26 @@ describeWithFlags('nonMaxSuppression', ALL_ENVS, () => { expect(indices.dtype).toEqual('int32'); expectArraysEqual(await indices.data(), [1, 0]); }); + + it('select from three clusters with SoftNMS', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 6; + const iouThreshold = 1.0; + const scoreThreshold = 0; + const softNmsSigma = 0.5; + const indices = tf.image.nonMaxSuppression( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); + + expect(indices.shape).toEqual([6]); + expectArraysEqual(await indices.data(), [3, 0, 1, 5, 4, 2]); + }); }); describeWithFlags('nonMaxSuppressionAsync', ALL_ENVS, () => { @@ -214,6 +221,26 @@ describeWithFlags('nonMaxSuppressionAsync', ALL_ENVS, () => { expect(indices.dtype).toEqual('int32'); expectArraysEqual(await indices.data(), [1, 0]); }); + + it('select from three clusters with SoftNMS', async () => { + const boxes = tf.tensor2d( + [ + 0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, + 0, 10, 1, 11, 0, 10.1, 1, 11.1, 0, 100, 1, 101 + ], + [6, 4]); + const scores = tf.tensor1d([0.9, 0.75, 0.6, 0.95, 0.5, 0.3]); + const maxOutputSize = 6; + const iouThreshold = 1.0; + const scoreThreshold = 0; + const softNmsSigma = 0.5; + const indices = await tf.image.nonMaxSuppressionAsync( + boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, + softNmsSigma); + + expect(indices.shape).toEqual([6]); + expectArraysEqual(await indices.data(), [3, 0, 1, 5, 4, 2]); + }); }); describeWithFlags('cropAndResize', ALL_ENVS, () => { From e35e897454ee21d592a2ed722fbc1ede61078ea7 Mon Sep 17 00:00:00 2001 From: Na Li Date: Mon, 9 Dec 2019 16:24:56 -0800 Subject: [PATCH 2/7] Add license info to array_util --- tfjs-core/src/backends/array_util.ts | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tfjs-core/src/backends/array_util.ts b/tfjs-core/src/backends/array_util.ts index b615b735330..a59112a491c 100644 --- a/tfjs-core/src/backends/array_util.ts +++ b/tfjs-core/src/backends/array_util.ts @@ -1,3 +1,20 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** * Inserts a value into a sorted array. This method allows duplicate, meaning it * allows inserting duplicate value, in which case, the element will be inserted From 337896d1b210840cdc7be951f49fb4491408b5f4 Mon Sep 17 00:00:00 2001 From: Na Li Date: Tue, 10 Dec 2019 10:44:06 -0800 Subject: [PATCH 3/7] Make sorting stable for duplicates --- tfjs-core/src/backends/non_max_suppression_impl.ts | 7 ++++++- tfjs-core/src/ops/image_ops_test.ts | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/backends/non_max_suppression_impl.ts b/tfjs-core/src/backends/non_max_suppression_impl.ts index ddc2134c330..ac6d5d02476 100644 --- a/tfjs-core/src/backends/non_max_suppression_impl.ts +++ b/tfjs-core/src/backends/non_max_suppression_impl.ts @@ -140,5 +140,10 @@ function suppressWeight(iouThreshold: number, scale: number, iou: number) { } function ascendingComparator(c1: Candidate, c2: Candidate) { - return c1.score - c2.score; + // For objects with same scores, we make the object with the larger index go + // first. In an array that pops from the end, this means that the object with + // the smaller index will be popped first. This ensures the same output as + // the TensorFlow python version. + return (c1.score - c2.score) || + ((c1.score === c2.score) && (c2.boxIndex - c1.boxIndex)); } diff --git a/tfjs-core/src/ops/image_ops_test.ts b/tfjs-core/src/ops/image_ops_test.ts index 56edbdac3a9..ae5d4cc6354 100644 --- a/tfjs-core/src/ops/image_ops_test.ts +++ b/tfjs-core/src/ops/image_ops_test.ts @@ -119,7 +119,7 @@ describeWithFlags('nonMaxSuppression', ALL_ENVS, () => { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); expect(indices.shape).toEqual([1]); - expectArraysEqual(await indices.data(), [9]); + expectArraysEqual(await indices.data(), [0]); }); it('inconsistent box and score shapes', () => { From 717f4a2afc8608d3059f1274254de961a92922b1 Mon Sep 17 00:00:00 2001 From: Na Li Date: Mon, 9 Dec 2019 16:06:25 -0800 Subject: [PATCH 4/7] Add support to NonMaxSuppressionV5 --- tfjs-core/src/ops/image_ops_test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-core/src/ops/image_ops_test.ts b/tfjs-core/src/ops/image_ops_test.ts index ae5d4cc6354..56edbdac3a9 100644 --- a/tfjs-core/src/ops/image_ops_test.ts +++ b/tfjs-core/src/ops/image_ops_test.ts @@ -119,7 +119,7 @@ describeWithFlags('nonMaxSuppression', ALL_ENVS, () => { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); expect(indices.shape).toEqual([1]); - expectArraysEqual(await indices.data(), [0]); + expectArraysEqual(await indices.data(), [9]); }); it('inconsistent box and score shapes', () => { From b225a85391c506875d54c035768b4e42a9ecdb81 Mon Sep 17 00:00:00 2001 From: Na Li Date: Tue, 10 Dec 2019 10:53:07 -0800 Subject: [PATCH 5/7] Update test --- tfjs-core/src/ops/image_ops_test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-core/src/ops/image_ops_test.ts b/tfjs-core/src/ops/image_ops_test.ts index 56edbdac3a9..ae5d4cc6354 100644 --- a/tfjs-core/src/ops/image_ops_test.ts +++ b/tfjs-core/src/ops/image_ops_test.ts @@ -119,7 +119,7 @@ describeWithFlags('nonMaxSuppression', ALL_ENVS, () => { boxes, scores, maxOutputSize, iouThreshold, scoreThreshold); expect(indices.shape).toEqual([1]); - expectArraysEqual(await indices.data(), [9]); + expectArraysEqual(await indices.data(), [0]); }); it('inconsistent box and score shapes', () => { From c23800ed3fa0c532b445b8d59c0d07118150b89a Mon Sep 17 00:00:00 2001 From: Na Li Date: Tue, 10 Dec 2019 11:02:29 -0800 Subject: [PATCH 6/7] Add range check for softNmsSigma --- tfjs-core/src/ops/image_ops.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index 2f0871129f4..71e8988e10d 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -256,6 +256,9 @@ function nonMaxSuppSanityCheck( scores.shape[0] === numBoxes, () => `scores has incompatible shape with boxes. Expected ${numBoxes}, ` + `but was ${scores.shape[0]}`); + util.assert( + 0 <= softNmsSigma && softNmsSigma <= 1, + () => `softNmsSigma must be in [0, 1], but was '${softNmsSigma}'`); return {maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma}; } From 4ec917710334d6d4dc197b5810ad3c91654c480d Mon Sep 17 00:00:00 2001 From: Na Li Date: Tue, 10 Dec 2019 11:13:41 -0800 Subject: [PATCH 7/7] Doccomment change --- tfjs-core/src/backends/array_util.ts | 2 +- tfjs-core/src/backends/array_util_test.ts | 2 +- tfjs-core/src/ops/image_ops.ts | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/backends/array_util.ts b/tfjs-core/src/backends/array_util.ts index a59112a491c..0a93f435562 100644 --- a/tfjs-core/src/backends/array_util.ts +++ b/tfjs-core/src/backends/array_util.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google LLC. All Rights Reserved. + * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at diff --git a/tfjs-core/src/backends/array_util_test.ts b/tfjs-core/src/backends/array_util_test.ts index 8d765a0a06d..d0f83382be9 100644 --- a/tfjs-core/src/backends/array_util_test.ts +++ b/tfjs-core/src/backends/array_util_test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google LLC. All Rights Reserved. + * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index 71e8988e10d..98f5a04af86 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -145,7 +145,9 @@ function resizeNearestNeighbor_( /** * Performs non maximum suppression of bounding boxes based on - * iou (intersection over union). This op also supports a Soft-NMS mode (c.f. + * iou (intersection over union). + * + * This op also supports a Soft-NMS mode (c.f. * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score * of other overlapping boxes, therefore favoring different regions of the image * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`