From e9131549c7ca6dcdca76590a1eb1c68c08e312ce Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Tue, 20 Apr 2021 21:18:56 -0600 Subject: [PATCH] added support for sparseReshape op for converter (#4963) FEATURE * added support for sparseReshape op for converter * fix lint * update kernel2op --- tfjs-converter/metadata/kernel2op.json | 3 + .../python/tensorflowjs/op_list/sparse.json | 31 ++++++++ .../operations/executors/sparse_executor.ts | 45 +++++++++++ .../executors/sparse_executor_test.ts | 76 +++++++++++++++++++ .../src/operations/op_list/sparse.ts | 30 ++++++++ .../src/operations/op_mapper_schema.ts | 8 +- .../src/operations/operation_executor.ts | 3 + tfjs-converter/src/operations/types.ts | 2 +- 8 files changed, 192 insertions(+), 6 deletions(-) create mode 100644 tfjs-converter/python/tensorflowjs/op_list/sparse.json create mode 100644 tfjs-converter/src/operations/executors/sparse_executor.ts create mode 100644 tfjs-converter/src/operations/executors/sparse_executor_test.ts create mode 100644 tfjs-converter/src/operations/op_list/sparse.ts diff --git a/tfjs-converter/metadata/kernel2op.json b/tfjs-converter/metadata/kernel2op.json index 9f21aabddd..8fe5af57f5 100644 --- a/tfjs-converter/metadata/kernel2op.json +++ b/tfjs-converter/metadata/kernel2op.json @@ -455,6 +455,9 @@ "SpaceToBatchND": [ "spaceToBatchND" ], + "SparseReshape": [ + "sparse.sparseReshape" + ], "SparseToDense": [ "sparseToDense", "cast" diff --git a/tfjs-converter/python/tensorflowjs/op_list/sparse.json b/tfjs-converter/python/tensorflowjs/op_list/sparse.json new file mode 100644 index 0000000000..43c1d02041 --- /dev/null +++ b/tfjs-converter/python/tensorflowjs/op_list/sparse.json @@ -0,0 +1,31 @@ +[ + { + "tfOpName": "SparseReshape", + "category": "sparse", + "inputs": [ + { + "start": 0, + "name": "inputIndices", + "type": "tensor" + }, + { + "start": 1, + "name": "inputShape", + "type": "tensor" + }, + { + "start": 2, + "name": "newShape", + "type": "tensor" + } + ], + "attrs": [ + { + "tfName": "T", + "name": "dtype", + "type": "dtype", + "notSupported": true + } + ] + } +] \ No newline at end of file diff --git a/tfjs-converter/src/operations/executors/sparse_executor.ts b/tfjs-converter/src/operations/executors/sparse_executor.ts new file mode 100644 index 0000000000..2590361b60 --- /dev/null +++ b/tfjs-converter/src/operations/executors/sparse_executor.ts @@ -0,0 +1,45 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor, Tensor1D, Tensor2D} from '@tensorflow/tfjs-core'; +// tslint:disable-next-line: no-imports-from-dist +import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter'; + +import {NamedTensorsMap} from '../../data/types'; +import {ExecutionContext} from '../../executor/execution_context'; +import {InternalOpExecutor, Node} from '../types'; + +import {getParamValue} from './utils'; + +export const executeOp: InternalOpExecutor = + (node: Node, tensorMap: NamedTensorsMap, + context: ExecutionContext): Tensor[] => { + switch (node.op) { + case 'SparseReshape': { + const {outputIndices, outputShape} = tfOps.sparse.sparseReshape( + getParamValue('inputIndices', node, tensorMap, context) as + Tensor2D, + getParamValue('inputShape', node, tensorMap, context) as Tensor1D, + getParamValue('newShape', node, tensorMap, context) as Tensor1D); + return [outputIndices, outputShape]; + } + default: + throw TypeError(`Node type ${node.op} is not implemented`); + } + }; + +export const CATEGORY = 'convolution'; diff --git a/tfjs-converter/src/operations/executors/sparse_executor_test.ts b/tfjs-converter/src/operations/executors/sparse_executor_test.ts new file mode 100644 index 0000000000..324a4eae94 --- /dev/null +++ b/tfjs-converter/src/operations/executors/sparse_executor_test.ts @@ -0,0 +1,76 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Tensor, test_util} from '@tensorflow/tfjs-core'; +// tslint:disable-next-line: no-imports-from-dist +import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter'; + +import {ExecutionContext} from '../../executor/execution_context'; +import * as sparse from '../op_list/sparse'; +import {Node} from '../types'; + +import {executeOp} from './sparse_executor'; +import {createTensorAttr, validateParam} from './test_helper'; + +describe('sparse', () => { + let node: Node; + const inputIndices = [tfOps.tensor2d( + [0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 2, 3], [5, 3], 'int32')]; + const inputShape = [tfOps.tensor1d([2, 3, 6], 'int32')]; + const newShape = [tfOps.tensor1d([9, -1], 'int32')]; + const context = new ExecutionContext({}, {}, {}); + + beforeEach(() => { + node = { + name: 'test', + op: '', + category: 'sparse', + inputNames: ['inputIndices', 'inputShape', 'newShape'], + inputs: [], + inputParams: { + inputIndices: createTensorAttr(0), + inputShape: createTensorAttr(1), + newShape: createTensorAttr(2) + }, + attrParams: {}, + children: [] + }; + }); + + describe('executeOp', () => { + describe('SparseReshape', () => { + it('should call tfOps.sparse.sparseReshape', async () => { + spyOn(tfOps.sparse, 'sparseReshape').and.callThrough(); + node.op = 'SparseReshape'; + const result = + executeOp(node, {inputIndices, inputShape, newShape}, context) as + Tensor[]; + + expect(tfOps.sparse.sparseReshape) + .toHaveBeenCalledWith(inputIndices[0], inputShape[0], newShape[0]); + test_util.expectArraysClose( + await result[0].data(), [0, 0, 0, 1, 1, 2, 4, 2, 8, 1]); + test_util.expectArraysClose(await result[1].data(), [9, 4]); + }); + + it('should match json def', () => { + node.op = 'SparseReshape'; + + expect(validateParam(node, sparse.json)).toBeTruthy(); + }); + }); + }); +}); diff --git a/tfjs-converter/src/operations/op_list/sparse.ts b/tfjs-converter/src/operations/op_list/sparse.ts new file mode 100644 index 0000000000..2cc1927ac5 --- /dev/null +++ b/tfjs-converter/src/operations/op_list/sparse.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {OpMapper} from '../types'; + +export const json: OpMapper[] = [{ + 'tfOpName': 'SparseReshape', + 'category': 'sparse', + 'inputs': [ + {'start': 0, 'name': 'inputIndices', 'type': 'tensor'}, + {'start': 1, 'name': 'inputShape', 'type': 'tensor'}, + {'start': 2, 'name': 'newShape', 'type': 'tensor'}, + ], + 'attrs': + [{'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}] +}]; diff --git a/tfjs-converter/src/operations/op_mapper_schema.ts b/tfjs-converter/src/operations/op_mapper_schema.ts index 96e45d8e85..5daabce201 100644 --- a/tfjs-converter/src/operations/op_mapper_schema.ts +++ b/tfjs-converter/src/operations/op_mapper_schema.ts @@ -42,7 +42,7 @@ export const json = { 'arithmetic', 'basic_math', 'control', 'convolution', 'custom', 'dynamic', 'evaluation', 'image', 'creation', 'graph', 'logical', 'matrices', 'normalization', 'reduction', 'slice_join', 'spectral', - 'transformation' + 'transformation', 'sparse' ] }, 'InputParamMapper': { @@ -54,8 +54,7 @@ export const json = { 'anyOf': [ {'type': 'string'}, {'type': 'array', 'items': {'type': 'string'}}, {'type': 'number'}, {'type': 'array', 'items': {'type': 'number'}}, - {'type': 'boolean'}, - {'type': 'array', 'items': {'type': 'boolean'}} + {'type': 'boolean'}, {'type': 'array', 'items': {'type': 'boolean'}} ] }, 'notSupported': {'type': 'boolean'}, @@ -81,8 +80,7 @@ export const json = { 'anyOf': [ {'type': 'string'}, {'type': 'array', 'items': {'type': 'string'}}, {'type': 'number'}, {'type': 'array', 'items': {'type': 'number'}}, - {'type': 'boolean'}, - {'type': 'array', 'items': {'type': 'boolean'}} + {'type': 'boolean'}, {'type': 'array', 'items': {'type': 'boolean'}} ] }, 'notSupported': {'type': 'boolean'}, diff --git a/tfjs-converter/src/operations/operation_executor.ts b/tfjs-converter/src/operations/operation_executor.ts index 0241e85172..28df52933c 100644 --- a/tfjs-converter/src/operations/operation_executor.ts +++ b/tfjs-converter/src/operations/operation_executor.ts @@ -38,6 +38,7 @@ import * as matrices from './executors/matrices_executor'; import * as normalization from './executors/normalization_executor'; import * as reduction from './executors/reduction_executor'; import * as sliceJoin from './executors/slice_join_executor'; +import * as sparse from './executors/sparse_executor'; import * as spectral from './executors/spectral_executor'; import * as transformation from './executors/transformation_executor'; import {Node} from './types'; @@ -90,6 +91,8 @@ export function executeOp( case 'slice_join': return tfc.tidy( () => sliceJoin.executeOp(node, tensorMap, context)); + case 'sparse': + return tfc.tidy(() => sparse.executeOp(node, tensorMap, context)); case 'spectral': return tfc.tidy(() => spectral.executeOp(node, tensorMap, context)); case 'transformation': diff --git a/tfjs-converter/src/operations/types.ts b/tfjs-converter/src/operations/types.ts index 054f6364ec..b686bfab09 100644 --- a/tfjs-converter/src/operations/types.ts +++ b/tfjs-converter/src/operations/types.ts @@ -26,7 +26,7 @@ export type ParamType = 'number'|'string'|'string[]'|'number[]'|'bool'|'bool[]'| export type Category = 'arithmetic'|'basic_math'|'control'|'convolution'| 'custom'|'dynamic'|'evaluation'|'image'|'creation'|'graph'|'logical'| 'matrices'|'normalization'|'reduction'|'slice_join'|'spectral'| - 'transformation'|'hash_table'; + 'transformation'|'hash_table'|'sparse'; // For mapping input or attributes of NodeDef into TensorFlow.js op param. export declare interface ParamMapper {