Skip to content

Commit

Permalink
added support for sparseReshape op for converter (#4963)
Browse files Browse the repository at this point in the history
FEATURE
* added support for sparseReshape op for converter

* fix lint

* update kernel2op
  • Loading branch information
pyu10055 committed Apr 21, 2021
1 parent ffb4f24 commit e913154
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 6 deletions.
3 changes: 3 additions & 0 deletions tfjs-converter/metadata/kernel2op.json
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,9 @@
"SpaceToBatchND": [
"spaceToBatchND"
],
"SparseReshape": [
"sparse.sparseReshape"
],
"SparseToDense": [
"sparseToDense",
"cast"
Expand Down
31 changes: 31 additions & 0 deletions tfjs-converter/python/tensorflowjs/op_list/sparse.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[
{
"tfOpName": "SparseReshape",
"category": "sparse",
"inputs": [
{
"start": 0,
"name": "inputIndices",
"type": "tensor"
},
{
"start": 1,
"name": "inputShape",
"type": "tensor"
},
{
"start": 2,
"name": "newShape",
"type": "tensor"
}
],
"attrs": [
{
"tfName": "T",
"name": "dtype",
"type": "dtype",
"notSupported": true
}
]
}
]
45 changes: 45 additions & 0 deletions tfjs-converter/src/operations/executors/sparse_executor.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Tensor, Tensor1D, Tensor2D} from '@tensorflow/tfjs-core';
// tslint:disable-next-line: no-imports-from-dist
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';

import {NamedTensorsMap} from '../../data/types';
import {ExecutionContext} from '../../executor/execution_context';
import {InternalOpExecutor, Node} from '../types';

import {getParamValue} from './utils';

export const executeOp: InternalOpExecutor =
(node: Node, tensorMap: NamedTensorsMap,
context: ExecutionContext): Tensor[] => {
switch (node.op) {
case 'SparseReshape': {
const {outputIndices, outputShape} = tfOps.sparse.sparseReshape(
getParamValue('inputIndices', node, tensorMap, context) as
Tensor2D,
getParamValue('inputShape', node, tensorMap, context) as Tensor1D,
getParamValue('newShape', node, tensorMap, context) as Tensor1D);
return [outputIndices, outputShape];
}
default:
throw TypeError(`Node type ${node.op} is not implemented`);
}
};

export const CATEGORY = 'convolution';
76 changes: 76 additions & 0 deletions tfjs-converter/src/operations/executors/sparse_executor_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {Tensor, test_util} from '@tensorflow/tfjs-core';
// tslint:disable-next-line: no-imports-from-dist
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';

import {ExecutionContext} from '../../executor/execution_context';
import * as sparse from '../op_list/sparse';
import {Node} from '../types';

import {executeOp} from './sparse_executor';
import {createTensorAttr, validateParam} from './test_helper';

describe('sparse', () => {
let node: Node;
const inputIndices = [tfOps.tensor2d(
[0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 2, 3], [5, 3], 'int32')];
const inputShape = [tfOps.tensor1d([2, 3, 6], 'int32')];
const newShape = [tfOps.tensor1d([9, -1], 'int32')];
const context = new ExecutionContext({}, {}, {});

beforeEach(() => {
node = {
name: 'test',
op: '',
category: 'sparse',
inputNames: ['inputIndices', 'inputShape', 'newShape'],
inputs: [],
inputParams: {
inputIndices: createTensorAttr(0),
inputShape: createTensorAttr(1),
newShape: createTensorAttr(2)
},
attrParams: {},
children: []
};
});

describe('executeOp', () => {
describe('SparseReshape', () => {
it('should call tfOps.sparse.sparseReshape', async () => {
spyOn(tfOps.sparse, 'sparseReshape').and.callThrough();
node.op = 'SparseReshape';
const result =
executeOp(node, {inputIndices, inputShape, newShape}, context) as
Tensor[];

expect(tfOps.sparse.sparseReshape)
.toHaveBeenCalledWith(inputIndices[0], inputShape[0], newShape[0]);
test_util.expectArraysClose(
await result[0].data(), [0, 0, 0, 1, 1, 2, 4, 2, 8, 1]);
test_util.expectArraysClose(await result[1].data(), [9, 4]);
});

it('should match json def', () => {
node.op = 'SparseReshape';

expect(validateParam(node, sparse.json)).toBeTruthy();
});
});
});
});
30 changes: 30 additions & 0 deletions tfjs-converter/src/operations/op_list/sparse.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {OpMapper} from '../types';

export const json: OpMapper[] = [{
'tfOpName': 'SparseReshape',
'category': 'sparse',
'inputs': [
{'start': 0, 'name': 'inputIndices', 'type': 'tensor'},
{'start': 1, 'name': 'inputShape', 'type': 'tensor'},
{'start': 2, 'name': 'newShape', 'type': 'tensor'},
],
'attrs':
[{'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}]
}];
8 changes: 3 additions & 5 deletions tfjs-converter/src/operations/op_mapper_schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export const json = {
'arithmetic', 'basic_math', 'control', 'convolution', 'custom',
'dynamic', 'evaluation', 'image', 'creation', 'graph', 'logical',
'matrices', 'normalization', 'reduction', 'slice_join', 'spectral',
'transformation'
'transformation', 'sparse'
]
},
'InputParamMapper': {
Expand All @@ -54,8 +54,7 @@ export const json = {
'anyOf': [
{'type': 'string'}, {'type': 'array', 'items': {'type': 'string'}},
{'type': 'number'}, {'type': 'array', 'items': {'type': 'number'}},
{'type': 'boolean'},
{'type': 'array', 'items': {'type': 'boolean'}}
{'type': 'boolean'}, {'type': 'array', 'items': {'type': 'boolean'}}
]
},
'notSupported': {'type': 'boolean'},
Expand All @@ -81,8 +80,7 @@ export const json = {
'anyOf': [
{'type': 'string'}, {'type': 'array', 'items': {'type': 'string'}},
{'type': 'number'}, {'type': 'array', 'items': {'type': 'number'}},
{'type': 'boolean'},
{'type': 'array', 'items': {'type': 'boolean'}}
{'type': 'boolean'}, {'type': 'array', 'items': {'type': 'boolean'}}
]
},
'notSupported': {'type': 'boolean'},
Expand Down
3 changes: 3 additions & 0 deletions tfjs-converter/src/operations/operation_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import * as matrices from './executors/matrices_executor';
import * as normalization from './executors/normalization_executor';
import * as reduction from './executors/reduction_executor';
import * as sliceJoin from './executors/slice_join_executor';
import * as sparse from './executors/sparse_executor';
import * as spectral from './executors/spectral_executor';
import * as transformation from './executors/transformation_executor';
import {Node} from './types';
Expand Down Expand Up @@ -90,6 +91,8 @@ export function executeOp(
case 'slice_join':
return tfc.tidy(
() => sliceJoin.executeOp(node, tensorMap, context));
case 'sparse':
return tfc.tidy(() => sparse.executeOp(node, tensorMap, context));
case 'spectral':
return tfc.tidy(() => spectral.executeOp(node, tensorMap, context));
case 'transformation':
Expand Down
2 changes: 1 addition & 1 deletion tfjs-converter/src/operations/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export type ParamType = 'number'|'string'|'string[]'|'number[]'|'bool'|'bool[]'|
export type Category = 'arithmetic'|'basic_math'|'control'|'convolution'|
'custom'|'dynamic'|'evaluation'|'image'|'creation'|'graph'|'logical'|
'matrices'|'normalization'|'reduction'|'slice_join'|'spectral'|
'transformation'|'hash_table';
'transformation'|'hash_table'|'sparse';

// For mapping input or attributes of NodeDef into TensorFlow.js op param.
export declare interface ParamMapper {
Expand Down

0 comments on commit e913154

Please sign in to comment.