/**
 * CPU kernel function for `Unique`.
 *
 * Rejects complex inputs, reads the values stored for `x` from the backend
 * and delegates the computation to `uniqueImpl`.
 *
 * @param args Kernel arguments: `inputs.x` is the tensor to de-duplicate,
 *     `attrs.axis` the axis to de-duplicate along, `backend` the CPU backend
 *     that owns the data.
 * @returns Two tensor infos: the unique values (same dtype as `x`) and a 1D
 *     int32 tensor of indices.
 */
export function unique(
    args: {inputs: UniqueInputs, attrs: UniqueAttrs, backend: MathBackendCPU}):
    TensorInfo[] {
  const {backend} = args;
  const x = args.inputs.x;
  const axis = args.attrs.axis;
  assertNotComplex(x, 'unique');

  const xValues = backend.data.get(x.dataId).values;
  const result = uniqueImpl(xValues, axis, x.shape, x.dtype);

  const valuesInfo =
      backend.makeTensorInfo(result.outputShape, x.dtype, result.outputValues);
  const indicesInfo =
      backend.makeTensorInfo([result.indices.length], 'int32', result.indices);
  return [valuesInfo, indicesInfo];
}

/** Registration entry binding the `Unique` kernel to the 'cpu' backend. */
export const uniqueConfig: KernelConfig = {
  kernelName: Unique,
  backendName: 'cpu',
  kernelFunc: unique as {} as KernelFunc,
};
/**
 * Finds the unique slices of a tensor along one axis.
 *
 * Slices are compared by a string key built from their elements. The first
 * occurrence of every distinct slice is kept, in the order in which the
 * slices appear along the axis.
 *
 * @param values Flat backing values of the input tensor.
 * @param axis Axis along which to de-duplicate. It is normalized and
 *     validated with `util.parseAxisParam`.
 * @param shape Shape of the input tensor.
 * @param dtype Dtype of the input tensor; also used for the output buffer.
 * @returns `outputValues`: flat values of the de-duplicated tensor.
 *     `outputShape`: `shape` with the size of `axis` replaced by the number
 *     of unique slices. `indices`: an Int32Array with one entry per position
 *     along `axis`, holding the index of that slice within the unique output.
 */
export function uniqueImpl(
    values: BackendValues, axis: number, shape: number[], dtype: DataType): {
  outputValues: BackendValues,
  outputShape: number[],
  indices: BackendValues
} {
  // Normalize and validate axis.
  const $axis = util.parseAxisParam(axis, shape)[0];

  // Collapse the input into a rank-3 view [before, axis, after]:
  //   - dim 0 is the product of all dimensions before the axis,
  //   - dim 1 is the size of the axis itself,
  //   - dim 2 is the product of all dimensions after the axis.
  // For example shape=[2, 3, 5, 4] with axis=2 gives [2*3, 5, 4].
  //
  // This is not the output shape; it is only the shape of the intermediate
  // buffers. With it, the slice at position i along the axis is exactly the
  // set of elements at [m, i, n] for every m and n. E.g. for the input
  // [[[1,2,3],[4,5,6]]] (shape [1, 2, 3]) and axis=2 the view is [2, 3, 1]
  // and the three slices are [1,4], [2,5] and [3,6].
  const newShape = [1, shape[$axis], 1];
  for (let i = 0; i < $axis; i++) {
    newShape[0] *= shape[i];
  }
  for (let i = $axis + 1; i < shape.length; i++) {
    newShape[2] *= shape[i];
  }

  // Maps the string key of a slice to its position in the unique output.
  // A Map (rather than a plain object) gives an O(1) size and cannot be
  // confused by keys that exist on Object.prototype.
  const uniqueElements = new Map<string, number>();
  // For every position along the axis, the index of its slice in the unique
  // output. It is 1D and has the same size as the given axis.
  const indices = new Int32Array(shape[$axis]);
  // Buffer used to read the value at a given [m, i, n] location.
  const inputBuffer = new TensorBuffer(newShape, dtype, values as TypedArray);
  // Positions along the axis at which a slice is seen for the first time.
  const uniqueIndices: number[] = [];
  const is1DTensor = newShape[0] === 1 && newShape[2] === 1;
  for (let i = 0; i < shape[$axis]; i++) {
    // Build the key of the slice at position i.
    let element: string;
    if (is1DTensor) {
      // Fast path: the slice is a single value, so its string form is
      // already an unambiguous key.
      element = values[i].toString();
    } else {
      // Length-prefix every part ("<length>:<text>") so that the
      // concatenation can be decoded in only one way. A plain join(',') is
      // ambiguous whenever a part itself contains a comma, which made
      // different slices compare as equal.
      const parts: string[] = [];
      for (let m = 0; m < newShape[0]; m++) {
        for (let n = 0; n < newShape[2]; n++) {
          const part = String(inputBuffer.get(m, i, n));
          parts.push(`${part.length}:${part}`);
        }
      }
      element = parts.join('');
    }

    // Dedup and update the indices.
    const existingIndex = uniqueElements.get(element);
    if (existingIndex !== undefined) {
      indices[i] = existingIndex;
    } else {
      const uniqueIndex = uniqueElements.size;
      uniqueElements.set(element, uniqueIndex);
      indices[i] = uniqueIndex;
      uniqueIndices.push(i);
    }
  }

  // uniqueIndices now says where each unique slice lives along the axis.
  // Copy those slices from the input buffer into the output buffer.
  const outputTmpShape = newShape.slice();
  outputTmpShape[1] = uniqueElements.size;
  const outputBuffer = new TensorBuffer(outputTmpShape, dtype);
  uniqueIndices.forEach((uniqueElementIndex, i) => {
    for (let m = 0; m < newShape[0]; m++) {
      for (let n = 0; n < newShape[2]; n++) {
        outputBuffer.set(inputBuffer.get(m, uniqueElementIndex, n), m, i, n);
      }
    }
  });

  // The output shape is the input shape with the size of the given axis
  // replaced by the number of unique slices along that axis.
  const outputShape = shape.slice();
  outputShape[$axis] = outputTmpShape[1];

  return {
    outputValues: outputBuffer.values as BackendValues,
    outputShape,
    indices,
  };
}
export {maxImpl} from './kernels/Max_impl'; export {transposeImpl} from './kernels/Transpose_impl'; +export {uniqueImpl} from './kernels/Unique_impl'; diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index d6a41def911..d07bfcb6930 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -2360,8 +2360,9 @@ export class MathBackendWebGL extends KernelBackend { return backend_util.linspaceImpl(start, stop, num); } - makeTensorInfo(shape: number[], dtype: DataType): TensorInfo { - const dataId = this.write(null /* values */, shape, dtype); + makeTensorInfo(shape: number[], dtype: DataType, values?: BackendValues): + TensorInfo { + const dataId = this.write(values, shape, dtype); this.texData.get(dataId).usage = null; return {dataId, shape, dtype}; } diff --git a/tfjs-backend-webgl/src/kernel_utils/shared.ts b/tfjs-backend-webgl/src/kernel_utils/shared.ts index 0ac71e981c0..6c2941d0e83 100644 --- a/tfjs-backend-webgl/src/kernel_utils/shared.ts +++ b/tfjs-backend-webgl/src/kernel_utils/shared.ts @@ -20,6 +20,10 @@ // tslint:disable-next-line: no-imports-from-dist import * as shared from '@tensorflow/tfjs-backend-cpu/dist/shared'; -const {maxImpl: maxImplCPU, transposeImpl: transposeImplCPU} = shared; +const { + maxImpl: maxImplCPU, + transposeImpl: transposeImplCPU, + uniqueImpl: uniqueImplCPU, +} = shared; -export {maxImplCPU, transposeImplCPU}; +export {maxImplCPU, transposeImplCPU, uniqueImplCPU}; diff --git a/tfjs-backend-webgl/src/kernels/Unique.ts b/tfjs-backend-webgl/src/kernels/Unique.ts new file mode 100644 index 00000000000..4fc78a809ec --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/Unique.ts @@ -0,0 +1,50 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the License); + * you may not use this file except in compliance with the License. 
/**
 * WebGL kernel function for `Unique`.
 *
 * There is no shader for this op here: the values of `x` are read back
 * synchronously and the shared CPU implementation does the work. A console
 * warning is printed on every call because of that synchronous read.
 *
 * @param args Kernel arguments: `inputs.x` is the tensor to de-duplicate,
 *     `attrs.axis` the axis to de-duplicate along, `backend` the WebGL
 *     backend that owns the data.
 * @returns Two tensor infos: the unique values (same dtype as `x`) and a 1D
 *     int32 tensor of indices.
 */
export function unique(
    args:
        {inputs: UniqueInputs, attrs: UniqueAttrs, backend: MathBackendWebGL}):
    TensorInfo[] {
  const {backend} = args;
  const x = args.inputs.x;
  const axis = args.attrs.axis;
  assertNotComplex(x, 'unique');

  // For now, always forward calculation to the CPU backend.
  console.warn(
      'WARNING: ',
      'UI might be locked temporarily as data is being downloaded');
  const xValues = backend.readSync(x.dataId);
  const result = uniqueImplCPU(xValues, axis, x.shape, x.dtype);

  const valuesInfo =
      backend.makeTensorInfo(result.outputShape, x.dtype, result.outputValues);
  const indicesInfo =
      backend.makeTensorInfo([result.indices.length], 'int32', result.indices);
  return [valuesInfo, indicesInfo];
}

/** Registration entry binding the `Unique` kernel to the 'webgl' backend. */
export const uniqueConfig: KernelConfig = {
  kernelName: Unique,
  backendName: 'webgl',
  kernelFunc: unique as {} as KernelFunc,
};
"Unique": [ + "unique" + ], + "UniqueV2": [ + "unique" + ], "Unpack": [ "unstack" ], diff --git a/tfjs-converter/python/tensorflowjs/op_list/evaluation.json b/tfjs-converter/python/tensorflowjs/op_list/evaluation.json index 8a14458a4ed..5de11223d6d 100644 --- a/tfjs-converter/python/tensorflowjs/op_list/evaluation.json +++ b/tfjs-converter/python/tensorflowjs/op_list/evaluation.json @@ -21,5 +21,32 @@ "type": "bool" } ] + }, + { + "tfOpName": "Unique", + "category": "evaluation", + "inputs": [ + { + "start": 0, + "name": "x", + "type": "tensor" + } + ] + }, + { + "tfOpName": "UniqueV2", + "category": "evaluation", + "inputs": [ + { + "start": 0, + "name": "x", + "type": "tensor" + }, + { + "start": 1, + "name": "axis", + "type": "number" + } + ] } -] \ No newline at end of file +] diff --git a/tfjs-converter/src/operations/executors/evaluation_executor.ts b/tfjs-converter/src/operations/executors/evaluation_executor.ts index 84ad0cf5601..5a46646155e 100644 --- a/tfjs-converter/src/operations/executors/evaluation_executor.ts +++ b/tfjs-converter/src/operations/executors/evaluation_executor.ts @@ -37,6 +37,18 @@ export const executeOp: InternalOpExecutor = const result = tfOps.topk(x, k, sorted); return [result.values, result.indices]; } + case 'Unique': { + const x = getParamValue('x', node, tensorMap, context) as Tensor; + const result = tfOps.unique(x); + return [result.values, result.indices]; + } + case 'UniqueV2': { + const x = getParamValue('x', node, tensorMap, context) as Tensor; + const axis = + getParamValue('axis', node, tensorMap, context) as number; + const result = tfOps.unique(x, axis); + return [result.values, result.indices]; + } default: throw TypeError(`Node type ${node.op} is not implemented`); } diff --git a/tfjs-converter/src/operations/executors/evaluation_executor_test.ts b/tfjs-converter/src/operations/executors/evaluation_executor_test.ts index 13e7cfec569..35edbbb01d3 100644 --- 
a/tfjs-converter/src/operations/executors/evaluation_executor_test.ts +++ b/tfjs-converter/src/operations/executors/evaluation_executor_test.ts @@ -54,5 +54,28 @@ describe('evaluation', () => { expect(tfOps.topk).toHaveBeenCalledWith(input1[0], 1, true); }); }); + + describe('Unique', () => { + it('should get called correctly', () => { + node.op = 'Unique'; + node.inputParams['x'] = createTensorAttr(0); + spyOn(tfOps, 'unique').and.callThrough(); + executeOp(node, {input1}, context); + expect(tfOps.unique).toHaveBeenCalledWith(input1[0]); + }); + }); + + describe('UniqueV2', () => { + it('should get called correctly', () => { + node.op = 'UniqueV2'; + node.inputParams['x'] = createTensorAttr(0); + node.inputParams['axis'] = createNumberAttrFromIndex(1); + spyOn(tfOps, 'unique').and.callThrough(); + const xInput = [tfOps.tensor2d([[1], [2]])]; + const axisInput = [tfOps.scalar(1)]; + executeOp(node, {'input1': xInput, 'input2': axisInput}, context); + expect(tfOps.unique).toHaveBeenCalledWith(xInput[0], 1); + }); + }); }); }); diff --git a/tfjs-converter/src/operations/op_list/evaluation.ts b/tfjs-converter/src/operations/op_list/evaluation.ts index 5e80b493969..2d09bdb5b0f 100644 --- a/tfjs-converter/src/operations/op_list/evaluation.ts +++ b/tfjs-converter/src/operations/op_list/evaluation.ts @@ -17,12 +17,29 @@ import {OpMapper} from '../types'; * ============================================================================= */ -export const json: OpMapper[] = [{ - 'tfOpName': 'TopKV2', - 'category': 'evaluation', - 'inputs': [ - {'start': 0, 'name': 'x', 'type': 'tensor'}, - {'start': 1, 'name': 'k', 'type': 'number'}, - ], - 'attrs': [{'tfName': 'sorted', 'name': 'sorted', 'type': 'bool'}] -}]; +export const json: OpMapper[] = [ + { + 'tfOpName': 'TopKV2', + 'category': 'evaluation', + 'inputs': [ + {'start': 0, 'name': 'x', 'type': 'tensor'}, + {'start': 1, 'name': 'k', 'type': 'number'}, + ], + 'attrs': [{'tfName': 'sorted', 'name': 'sorted', 'type': 'bool'}] 
+ }, + { + 'tfOpName': 'Unique', + 'category': 'evaluation', + 'inputs': [ + {'start': 0, 'name': 'x', 'type': 'tensor'}, + ], + }, + { + 'tfOpName': 'UniqueV2', + 'category': 'evaluation', + 'inputs': [ + {'start': 0, 'name': 'x', 'type': 'tensor'}, + {'start': 1, 'name': 'axis', 'type': 'number'}, + ], + }, +]; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 25628ffb978..7584a924576 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -746,6 +746,12 @@ export interface TransposeAttrs { perm: number[]; } +export const Unique = 'Unique'; +export type UniqueInputs = Pick; +export interface UniqueAttrs { + axis: number; +} + export type UnaryInputs = Pick; export const Unpack = 'Unpack'; diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 86c907532d9..9aab3f68276 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -185,6 +185,7 @@ export {tensor6d} from './tensor6d'; export {tile} from './tile'; export {topk} from './topk'; export {truncatedNormal} from './truncated_normal'; +export {unique} from './unique'; export {unsortedSegmentSum} from './unsorted_segment_sum'; export {unstack} from './unstack'; export {variable} from './variable'; diff --git a/tfjs-core/src/ops/unique.ts b/tfjs-core/src/ops/unique.ts new file mode 100644 index 00000000000..e7d35876a4c --- /dev/null +++ b/tfjs-core/src/ops/unique.ts @@ -0,0 +1,92 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
/**
 * Finds unique elements along an axis of a tensor.
 *
 * It returns a tensor `values` containing all of the unique elements along the
 * `axis` of the given tensor `x` in the same order that they occur along the
 * `axis` in `x`; `x` does not need to be sorted. It also returns a tensor
 * `indices` the same size as the number of the elements in `x` along the `axis`
 * dimension. Each entry of `indices` is the index of the corresponding element
 * in the unique output `values`.
 *
 * ```js
 * // A 1-D tensor
 * const a = tf.tensor1d([1, 1, 2, 4, 4, 4, 7, 8, 8]);
 * const {values, indices} = tf.unique(a);
 * values.print();   // [1, 2, 4, 7, 8]
 * indices.print();  // [0, 0, 1, 2, 2, 2, 3, 4, 4]
 * ```
 *
 * ```js
 * // A 2-D tensor with axis=0
 * //
 * // 'a' is: [[1, 0, 0],
 * //          [1, 0, 0],
 * //          [2, 0, 0]]
 * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
 * const {values, indices} = tf.unique(a, 0)
 * values.print();   // [[1, 0, 0],
 *                   //  [2, 0, 0]]
 * indices.print();  // [0, 0, 1]
 * ```
 *
 * ```js
 * // A 2-D tensor with axis=1
 * //
 * // 'a' is: [[1, 0, 0],
 * //          [1, 0, 0],
 * //          [2, 0, 0]]
 * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
 * const {values, indices} = tf.unique(a, 1)
 * values.print();   // [[1, 0],
 * //                    [1, 0],
 * //                    [2, 0]]
 * indices.print();  // [0, 1, 1]
 * ```
 * @param x A tensor of rank 1 or higher. Any dtype is accepted by this op;
 *     the CPU and WebGL kernels reject complex inputs.
 * @param axis The axis of the tensor to find the unique elements. Defaults
 *     to 0.
 * @returns An object `{values, indices}` (see above for details).
 *
 * @doc {heading: 'Operations', subheading: 'Evaluation'}
 */
function unique_<T extends Tensor>(
    x: T|TensorLike, axis = 0): {values: T, indices: Tensor1D} {
  // x can be of any dtype, thus null as the last argument.
  const $x = convertToTensor(x, 'x', 'unique', null);
  assert($x.rank > 0, () => 'The input tensor must be at least 1D');

  const inputs: UniqueInputs = {x: $x};
  const attrs: UniqueAttrs = {axis};
  // The kernel returns exactly two tensors, in this order: values, indices.
  const [values, indices] = ENGINE.runKernel(
                                Unique, inputs as {} as NamedTensorMap,
                                attrs as {} as NamedAttrMap) as [T, Tensor1D];
  return {values, indices};
}

export const unique = op({unique_});
// Test suite for tf.unique, run against every registered environment
// (ALL_ENVS). Each case checks the dtype and shape of `indices`, the shape
// of `values`, and the data of both outputs.
describeWithFlags('unique', ALL_ENVS, () => {
  // NOTE(review): no dtype is passed to tensor1d here, so whether the input
  // is really int32 depends on tensor1d's default dtype; confirm.
  it('1d tensor with int32', async () => {
    const x = tensor1d([1, 1, 2, 4, 4, 4, 7, 8, 8]);
    const {values, indices} = tf.unique(x);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual(x.shape);
    expect(values.shape).toEqual([5]);
    expectArraysEqual(await values.data(), [1, 2, 4, 7, 8]);
    expectArraysEqual(await indices.data(), [0, 0, 1, 2, 2, 2, 3, 4, 4]);
  });

  it('1d tensor with string', async () => {
    const x = tensor1d(['a', 'b', 'b', 'c', 'c']);
    const {values, indices} = tf.unique(x);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual(x.shape);
    expect(values.dtype).toEqual('string');
    expect(values.shape).toEqual([3]);
    expectArraysEqual(await values.data(), ['a', 'b', 'c']);
    expectArraysEqual(await indices.data(), [0, 1, 1, 2, 2]);
  });

  it('1d tensor with bool', async () => {
    const x = tensor1d([true, true, false]);
    const {values, indices} = tf.unique(x);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual(x.shape);
    expect(values.dtype).toEqual('bool');
    expect(values.shape).toEqual([2]);
    expectArraysEqual(await values.data(), [true, false]);
    expectArraysEqual(await indices.data(), [0, 0, 1]);
  });

  // Repeated NaN values are expected to collapse into a single unique entry.
  it('1d tensor with NaN and Infinity', async () => {
    const x = tensor1d([NaN, Infinity, NaN, Infinity]);
    const {values, indices} = tf.unique(x);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual(x.shape);
    expect(values.shape).toEqual([2]);
    expectArraysEqual(await values.data(), [NaN, Infinity]);
    expectArraysEqual(await indices.data(), [0, 1, 0, 1]);
  });

  // axis=0: whole rows are compared.
  it('2d tensor with axis=0', async () => {
    const x = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
    const {values, indices} = tf.unique(x, 0);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual([x.shape[0]]);
    expect(values.shape).toEqual([2, 3]);
    expectArraysEqual(await values.data(), [1, 0, 0, 2, 0, 0]);
    expectArraysEqual(await indices.data(), [0, 0, 1]);
  });

  // axis=1: whole columns are compared.
  it('2d tensor with axis=1', async () => {
    const x = tf.tensor2d([[1, 0, 0, 1], [1, 0, 0, 1], [2, 0, 0, 2]]);
    const {values, indices} = tf.unique(x, 1);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual([x.shape[1]]);
    expect(values.shape).toEqual([3, 2]);
    expectArraysEqual(await values.data(), [[1, 0], [1, 0], [2, 0]]);
    expectArraysEqual(await indices.data(), [0, 1, 1, 0]);
  });

  it('2d tensor with string', async () => {
    const x = tf.tensor2d([['a', 'b', 'b'], ['a', 'b', 'b'], ['c', 'b', 'b']]);
    const {values, indices} = tf.unique(x, 0);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual([x.shape[0]]);
    expect(values.dtype).toEqual('string');
    expect(values.shape).toEqual([2, 3]);
    expectArraysEqual(await values.data(), ['a', 'b', 'b', 'c', 'b', 'b']);
    expectArraysEqual(await indices.data(), [0, 0, 1]);
  });

  // The two rows contain the same characters and commas overall but split
  // them into elements differently; they must stay distinct.
  it('2d tensor with strings that have comma', async () => {
    const x = tf.tensor2d([['a', 'b,c', 'd'], ['a', 'b', 'c,d']]);
    const {values, indices} = tf.unique(x, 0);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual([x.shape[0]]);
    expect(values.dtype).toEqual('string');
    expect(values.shape).toEqual([2, 3]);
    expectArraysEqual(await values.data(), ['a', 'b,c', 'd', 'a', 'b', 'c,d']);
    expectArraysEqual(await indices.data(), [0, 1]);
  });

  it('3d tensor with axis=0', async () => {
    const x =
        tf.tensor3d([[[1, 0], [1, 0]], [[1, 0], [1, 0]], [[1, 1], [1, 1]]]);
    const {values, indices} = tf.unique(x, 0);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual([x.shape[0]]);
    expect(values.shape).toEqual([2, 2, 2]);
    expectArraysEqual(await values.data(), [1, 0, 1, 0, 1, 1, 1, 1]);
    expectArraysEqual(await indices.data(), [0, 0, 1]);
  });

  it('3d tensor with axis=1', async () => {
    const x =
        tf.tensor3d([[[1, 0], [1, 0]], [[1, 0], [1, 0]], [[1, 1], [1, 1]]]);
    const {values, indices} = tf.unique(x, 1);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual([x.shape[1]]);
    expect(values.shape).toEqual([3, 1, 2]);
    expectArraysEqual(await values.data(), [[[1, 0]], [[1, 0]], [[1, 1]]]);
    expectArraysEqual(await indices.data(), [0, 0]);
  });

  it('3d tensor with axis=2', async () => {
    const x = tf.tensor3d([[[1, 0, 1]], [[1, 0, 1]]]);
    const {values, indices} = tf.unique(x, 2);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual([x.shape[2]]);
    expect(values.shape).toEqual([2, 1, 2]);
    expectArraysEqual(await values.data(), [1, 0, 1, 0]);
    expectArraysEqual(await indices.data(), [0, 1, 0]);
  });

  it('3d tensor with string', async () => {
    const x = tf.tensor3d([
      [['a', 'b'], ['a', 'b']], [['a', 'b'], ['a', 'b']],
      [['a', 'a'], ['a', 'a']]
    ]);
    const {values, indices} = tf.unique(x, 0);

    expect(indices.dtype).toBe('int32');
    expect(indices.shape).toEqual([x.shape[0]]);
    expect(values.dtype).toEqual('string');
    expect(values.shape).toEqual([2, 2, 2]);
    expectArraysEqual(
        await values.data(), ['a', 'b', 'a', 'b', 'a', 'a', 'a', 'a']);
    expectArraysEqual(await indices.data(), [0, 0, 1]);
  });
});
'./unstack'; import './where'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 5549cfb6cab..9b6e559661d 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -169,6 +169,7 @@ const CHAINED_OPS = [ 'toInt', 'topk', 'transpose', + 'unique', 'unsortedSegmentSum', 'unstack', 'where', diff --git a/tfjs-core/src/public/chained_ops/unique.ts b/tfjs-core/src/public/chained_ops/unique.ts new file mode 100644 index 00000000000..7f31819a43e --- /dev/null +++ b/tfjs-core/src/public/chained_ops/unique.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
// Module augmentation that adds the chained `unique` method to the Tensor
// type.
// NOTE(review): the type-parameter lists were missing from this text and are
// restored here. The interface's parameter list has to match the one on the
// Tensor class for the declarations to merge; confirm against '../../tensor'.
declare module '../../tensor' {
  interface Tensor<R extends Rank = Rank> {
    unique<T extends Tensor>(this: T, axis?: number):
        {values: T, indices: T};
  }
}

/**
 * Chained form of the `unique` op: `t.unique(axis)` forwards to
 * `unique(t, axis)`.
 *
 * @param axis Optional axis to de-duplicate along; when omitted the op's own
 *     default is used.
 * @returns An object `{values, indices}` as produced by the `unique` op.
 * @throws Whatever `throwIfDisposed` raises when the tensor has been
 *     disposed.
 */
Tensor.prototype.unique = function<T extends Tensor>(
    this: T, axis?: number): {values: T, indices: T} {
  this.throwIfDisposed();
  return unique(this, axis) as {values: T, indices: T};
};
- 'argmax test-tensorflow {} 6D, axis=0', 'diag test-tensorflow {} complex', + 'argmax test-tensorflow {} 6D, axis=0', + 'diag test-tensorflow {} complex', 'diag test-tensorflow {} bool', // See https://github.com/tensorflow/tfjs/issues/1891 'conv2d test-tensorflow {} x=[2,1,2,2] f=[1,1,1,1] s=1 d=1 p=0 NCHW', @@ -79,7 +80,10 @@ const IGNORE_LIST: string[] = [ 'conv2d test-tensorflow {} x=[2,1,2,2] f=[2,2,1,1] s=1 d=1 p=same NCHW', 'conv2d test-tensorflow {} gradient x=[1,1,3,3] f=[2,2,1,1] s=1 p=0 NCHW', 'conv2d test-tensorflow {} gradient x=[2,1,3,3] f=[2,2,1,1] s=1 p=0 NCHW', - 'maxPoolWithArgmax', 'rotate', 'flipLeftRight' + 'maxPoolWithArgmax', + 'rotate', + 'flipLeftRight', + 'unique', ]; if (process.platform === 'win32') {