Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions tfjs-core/src/backends/array_util.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

/**
* Inserts a value into a sorted array. This method allows duplicate, meaning it
* allows inserting duplicate value, in which case, the element will be inserted
* at the lowest index of the value.
* @param arr The array to modify.
* @param element The element to insert.
* @param comparator Optional. If no comparator is specified, elements are
* compared using array_util.defaultComparator, which is suitable for Strings
* and Numbers in ascending arrays. If the array contains multiple instances of
* the target value, the left-most instance will be returned. To provide a
* comparator, it should take 2 arguments to compare and return a negative,
* zero, or a positive number.
*/
export function binaryInsert<T>(
arr: T[], element: T, comparator?: (a: T, b: T) => number) {
const index = binarySearch(arr, element, comparator);
const insertionPoint = index < 0 ? -(index + 1) : index;
arr.splice(insertionPoint, 0, element);
}

/**
* Searches the array for the target using binary search, returns the index
* of the found element, or position to insert if element not found. If no
* comparator is specified, elements are compared using array_
* util.defaultComparator, which is suitable for Strings and Numbers in
* ascending arrays. If the array contains multiple instances of the target
* value, the left-most instance will be returned.
* @param arr The array to be searched in.
* @param target The target to be searched for.
* @param comparator Should take 2 arguments to compare and return a negative,
* zero, or a positive number.
* @return Lowest index of the target value if found, otherwise the insertion
* point where the target should be inserted, in the form of
* (-insertionPoint - 1).
*/
export function binarySearch<T>(
arr: T[], target: T, comparator?: (a: T, b: T) => number) {
return binarySearch_(arr, target, comparator || defaultComparator);
}

/**
* Compares its two arguments for order.
* @param a The first element to be compared.
* @param b The second element to be compared.
* @return A negative number, zero, or a positive number as the first
* argument is less than, equal to, or greater than the second.
*/
function defaultComparator<T>(a: T, b: T): number {
return a > b ? 1 : a < b ? -1 : 0;
}

function binarySearch_<T>(
arr: T[], target: T, comparator: (a: T, b: T) => number) {
let left = 0;
let right = arr.length;
let middle = 0;
let found = false;
while (left < right) {
middle = left + ((right - left) >>> 1);
const compareResult = comparator(target, arr[middle]);
if (compareResult > 0) {
left = middle + 1;
} else {
right = middle;
// If compareResult is 0, the value is found. We record it is found,
// and then keep looking because there may be duplicate.
found = !compareResult;
}
}

return found ? left : -left - 1;
}
176 changes: 176 additions & 0 deletions tfjs-core/src/backends/array_util_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as array_util from './array_util';

describe('array_util', () => {
const insertionPoint = (i: number) => -(i + 1);

describe('binarySearch', () => {
const d = [
-897123.9, -321434.58758, -1321.3124, -324, -9, -3, 0, 0, 0, 0.31255, 5,
142.88888708, 334, 342, 453, 54254
];

it('-897123.9 should be found at index 0', () => {
const result = array_util.binarySearch(d, -897123.9);
expect(result).toBe(0);
});

it(`54254 should be found at index ${d.length - 1}`, () => {
const result = array_util.binarySearch(d, 54254);
expect(result).toBe(d.length - 1);
});

it('-3 should be found at index 5', () => {
const result = array_util.binarySearch(d, -3);
expect(result).toBe(5);
});

it('0 should be found at index 6', () => {
const result = array_util.binarySearch(d, 0);
expect(result).toBe(6);
});

it('-900000 should have an insertion point of 0', () => {
const result = array_util.binarySearch(d, -900000);
expect(result).toBeLessThan(0);
expect(insertionPoint(result)).toBe(0);
});

it(`54255 should have an insertion point of ${d.length}`, () => {
const result = array_util.binarySearch(d, 54255);
expect(result).toBeLessThan(0);
expect(insertionPoint(result)).toEqual(d.length);
});

it('1.1 should have an insertion point of 10', () => {
const result = array_util.binarySearch(d, 1.1);
expect(result).toBeLessThan(0);
expect(insertionPoint(result)).toEqual(10);
});
});

describe('binarySearch with custom comparator', () => {
const e = [
54254,
453,
342,
334,
142.88888708,
5,
0.31255,
0,
0,
0,
-3,
-9,
-324,
-1321.3124,
-321434.58758,
-897123.9,
];

const revComparator = (a: number, b: number) => (b - a);

it('54254 should be found at index 0', () => {
const result = array_util.binarySearch(e, 54254, revComparator);
expect(result).toBe(0);
});

it(`-897123.9 should be found at index ${e.length - 1}`, () => {
const result = array_util.binarySearch(e, -897123.9, revComparator);
expect(result).toBe(e.length - 1);
});

it('-3 should be found at index 10', () => {
const result = array_util.binarySearch(e, -3, revComparator);
expect(result).toBe(10);
});

it('0 should be found at index 7', () => {
const result = array_util.binarySearch(e, 0, revComparator);
expect(result).toBe(7);
});

it('54254.1 should have an insertion point of 0', () => {
const result = array_util.binarySearch(e, 54254.1, revComparator);
expect(result).toBeLessThan(0);
expect(insertionPoint(result)).toBe(0);
});

it(`-897124 should have an insertion point of ${e.length}`, () => {
const result = array_util.binarySearch(e, -897124, revComparator);
expect(result).toBeLessThan(0);
expect(insertionPoint(result)).toBe(e.length);
});
});

describe(
'binarySearch with custom comparator with single element array', () => {
const g = [1];

const revComparator = (a: number, b: number) => (b - a);

it('1 should be found at index 0', () => {
const result = array_util.binarySearch(g, 1, revComparator);
expect(result).toBe(0);
});

it('2 should have an insertion point of 0', () => {
const result = array_util.binarySearch(g, 2, revComparator);
expect(result).toBeLessThan(0);
expect(insertionPoint(result)).toBe(0);
});

it('0 should have an insertion point of 1', () => {
const result = array_util.binarySearch(g, 0, revComparator);
expect(result).toBeLessThan(0);
expect(insertionPoint(result)).toBe(1);
});
});

describe('binarySearch test left-most duplicated element', () => {
it('should find the index of the first 0', () => {
const result = array_util.binarySearch([0, 0, 1], 0);
expect(result).toBe(0);
});

it('should find the index of the first 1', () => {
const result = array_util.binarySearch([0, 1, 1], 1);
expect(result).toBe(1);
});
});

describe('binaryInsert', () => {
it('inserts correctly', () => {
const a: number[] = [];

array_util.binaryInsert(a, 3);
expect(a).toEqual([3]);

array_util.binaryInsert(a, 3);
expect(a).toEqual([3, 3]);

array_util.binaryInsert(a, 1);
expect(a).toEqual([1, 3, 3]);

array_util.binaryInsert(a, 5);
expect(a).toEqual([1, 3, 3, 5]);
});
});
});
3 changes: 2 additions & 1 deletion tfjs-core/src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,8 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {

nonMaxSuppression(
boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number,
iouThreshold: number, scoreThreshold?: number): Tensor1D {
iouThreshold: number, scoreThreshold?: number,
softNmsSigma?: number): Tensor1D {
return notYetImplemented('nonMaxSuppression');
}

Expand Down
6 changes: 4 additions & 2 deletions tfjs-core/src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3270,13 +3270,15 @@ export class MathBackendCPU extends KernelBackend {

nonMaxSuppression(
boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number,
iouThreshold: number, scoreThreshold: number): Tensor1D {
iouThreshold: number, scoreThreshold: number,
softNmsSigma: number): Tensor1D {
assertNotComplex(boxes, 'nonMaxSuppression');

const boxesVals = this.readSync(boxes.dataId) as TypedArray;
const scoresVals = this.readSync(scores.dataId) as TypedArray;
return nonMaxSuppressionImpl(
boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold,
softNmsSigma);
}

fft(x: Tensor2D): Tensor2D {
Expand Down
Loading