Skip to content

Commit

Permalink
Add Kernel SparseReshape for CPU and WebGL backend (#4956)
Browse files Browse the repository at this point in the history
FEATURE
* initial checkin for sparse reshape

* added SparseReshape kernels to cpu and webgl

* udpated the filename

* fix lint errors

* fixed failed snippets

* fixed tests

* fix lint

* fix lint;

* fix node tests

* addressed comments

* fix failing tests

* fix lint error
  • Loading branch information
pyu10055 committed Apr 20, 2021
1 parent cd7fa52 commit ffb4f24
Show file tree
Hide file tree
Showing 14 changed files with 500 additions and 13 deletions.
64 changes: 64 additions & 0 deletions tfjs-backend-cpu/src/kernels/SparseReshape.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, SparseReshape, SparseReshapeInputs, TensorInfo, TypedArray} from '@tensorflow/tfjs-core';

import {MathBackendCPU} from '../backend_cpu';

import {sparseReshapeImpl} from './SparseReshape_impl';

export function sparseReshape(
args: {inputs: SparseReshapeInputs, backend: MathBackendCPU}):
[TensorInfo, TensorInfo] {
const {inputs, backend} = args;
const {inputIndices, inputShape, newShape} = inputs;
if (inputIndices.shape.length !== 2) {
throw new Error(`Input indices should be a matrix but received shape
${inputIndices.shape}`);
}
if (inputShape.shape.length !== 1) {
throw new Error(`Input shape should be a vector but received shape
${inputShape.shape}`);
}

if (newShape.shape.length !== 1) {
throw new Error(
`Target shape should be a vector but received shape ${newShape.shape}`);
}

const $inputShape =
Array.from(backend.data.get(inputShape.dataId).values as TypedArray);
const $inputIndices =
backend.data.get(inputIndices.dataId).values as TypedArray;
const targetShape =
Array.from(backend.data.get(newShape.dataId).values as TypedArray);

const [newIndices, indicesShape, outputShape] = sparseReshapeImpl(
$inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape,
targetShape);
return [
backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
backend.makeTensorInfo(
[outputShape.length], newShape.dtype, new Int32Array(outputShape)),
];
}

export const sparseReshapeConfig: KernelConfig = {
kernelName: SparseReshape,
backendName: 'cpu',
kernelFunc: sparseReshape,
};
105 changes: 105 additions & 0 deletions tfjs-backend-cpu/src/kernels/SparseReshape_impl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {DataType, TypedArray, util} from '@tensorflow/tfjs-core';

export function sparseReshapeImpl(
inputIndices: TypedArray, inputIndicesShape: number[], inputDType: DataType,
inputShape: number[],
targetShape: number[]): [TypedArray, number[], number[]] {
const denseSize = util.sizeFromShape(inputShape);
const nnz = inputIndicesShape[0];
const outputRank = targetShape.length;

// Compute the output shape. Determine product of specified dimensions, and
// find the index of the unspecified one.
const outputShape: number[] = [];
let product = 1;
let unknownIndex = -1;
for (let d = 0; d < outputRank; ++d) {
const size = targetShape[d];
if (size === -1) {
if (unknownIndex !== -1) {
throw new Error(`only one output dimension may be -1, not both ${
unknownIndex} and ${d}`);
}
unknownIndex = d;
outputShape.push(1);
} else {
if (size < 0) {
throw new Error(`size ${d} must be non-negative, not ${size}`);
}
product *= size;
outputShape.push(size);
}
}
if (unknownIndex !== -1) {
if (product <= 0) {
throw new Error(
'reshape cannot infer the missing ' +
'input size for an empty tensor unless all ' +
'specified input sizes are non-zero');
}
const missing = Math.trunc(denseSize / product);
if (product * missing !== denseSize) {
throw new Error(`Input to reshape is a SparseTensor with ${denseSize}
dense values, but the requested shape requires a multiple of ${
product}. inputShape=${inputShape} outputShape= ${outputShape}`);
}

outputShape[unknownIndex] = missing;
}
const outputSize = util.sizeFromShape(outputShape);
if (outputSize !== denseSize) {
throw new Error(`Input to reshape is a tensor with ${
denseSize} dense values, but the requested shape has ${
outputSize}. inputShape=${inputShape} outputShape=${outputShape}`);
}

const inputRank = inputShape.length;
const inputStrides: number[] = [];
if (inputRank > 0) {
inputStrides[inputRank - 1] = 1;
for (let d = inputRank - 2; d >= 0; --d) {
inputStrides[d] = inputStrides[d + 1] * inputShape[d + 1];
}
}

const outputStrides: number[] = [];
if (outputRank > 0) {
outputStrides[outputRank - 1] = 1;
for (let d = outputRank - 2; d >= 0; --d) {
outputStrides[d] = outputStrides[d + 1] * outputShape[d + 1];
}
}

const newIndices =
util.getArrayFromDType(inputDType, nnz * outputRank) as TypedArray;
for (let i = 0; i < nnz; ++i) {
let id = 0;
for (let j = 0; j < inputRank; ++j) {
// inputIndices is a 2d tensor with shape of [nnz, inputRank]
id += inputIndices[i * inputRank + j] * inputStrides[j];
}
for (let j = 0; j < outputRank; ++j) {
// newIndices is a 2d tensor with shape of [nnz, outputRank]
newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]);
id %= outputStrides[j];
}
}
return [newIndices, [nnz, outputRank], outputShape];
}
2 changes: 2 additions & 0 deletions tfjs-backend-cpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ import {sliceConfig} from './kernels/Slice';
import {softmaxConfig} from './kernels/Softmax';
import {softplusConfig} from './kernels/Softplus';
import {spaceToBatchNDConfig} from './kernels/SpaceToBatchND';
import {sparseReshapeConfig} from './kernels/SparseReshape';
import {sparseToDenseConfig} from './kernels/SparseToDense';
import {splitVConfig} from './kernels/SplitV';
import {sqrtConfig} from './kernels/Sqrt';
Expand Down Expand Up @@ -313,6 +314,7 @@ const kernelConfigs: KernelConfig[] = [
softmaxConfig,
softplusConfig,
spaceToBatchNDConfig,
sparseReshapeConfig,
sparseToDenseConfig,
splitVConfig,
sqrtConfig,
Expand Down
1 change: 1 addition & 0 deletions tfjs-backend-cpu/src/shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export {prodImpl} from './kernels/Prod';
export {rangeImpl} from './kernels/Range_impl';
export {rsqrtImpl} from './kernels/Rsqrt';
export {sliceImpl} from './kernels/Slice';
export {sparseReshapeImpl} from './kernels/SparseReshape_impl';
export {squaredDifferenceImpl} from './kernels/SquaredDifference';
export {stridedSliceImpl} from './kernels/StridedSlice_impl';
export {subImpl} from './kernels/Sub';
Expand Down
2 changes: 2 additions & 0 deletions tfjs-backend-webgl/src/kernel_utils/shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const {
rsqrtImpl: rsqrtImplCPU,
simpleAbsImpl: simpleAbsImplCPU,
sliceImpl: sliceImplCPU,
sparseReshapeImpl: sparseReshapeImplCPU,
stridedSliceImpl: stridedSliceImplCPU,
subImpl: subImplCPU,
tileImpl: tileImplCPU,
Expand Down Expand Up @@ -81,6 +82,7 @@ export {
prodImplCPU,
simpleAbsImplCPU,
sliceImplCPU,
sparseReshapeImplCPU,
stridedSliceImplCPU,
subImplCPU,
rangeImplCPU,
Expand Down
62 changes: 62 additions & 0 deletions tfjs-backend-webgl/src/kernels/SparseReshape.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, SparseReshape, SparseReshapeInputs, TensorInfo, TypedArray} from '@tensorflow/tfjs-core';

import {MathBackendWebGL} from '../backend_webgl';
import {sparseReshapeImplCPU} from '../kernel_utils/shared';

export function sparseReshape(
args: {inputs: SparseReshapeInputs, backend: MathBackendWebGL}):
[TensorInfo, TensorInfo] {
const {inputs, backend} = args;
const {inputIndices, inputShape, newShape} = inputs;
if (inputIndices.shape.length !== 2) {
throw new Error(`Input indices should be a matrix but received shape ${
inputIndices.shape}`);
}
if (inputShape.shape.length !== 1) {
throw new Error(`Input shape should be a vector but received shape ${
inputShape.shape}`);
}

if (newShape.shape.length !== 1) {
throw new Error(
`Target shape should be a vector but received shape ${newShape.shape}`);
}

const $inputShape =
Array.from(backend.readSync(inputShape.dataId) as TypedArray);
const $inputIndices = backend.readSync(inputIndices.dataId) as TypedArray;
const targetShape =
Array.from(backend.readSync(newShape.dataId) as TypedArray);

const [newIndices, indicesShape, outputShape] = sparseReshapeImplCPU(
$inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape,
targetShape);
return [
backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
backend.makeTensorInfo(
[outputShape.length], newShape.dtype, new Int32Array(outputShape)),
];
}

export const sparseReshapeConfig: KernelConfig = {
kernelName: SparseReshape,
backendName: 'webgl',
kernelFunc: sparseReshape,
};
2 changes: 2 additions & 0 deletions tfjs-backend-webgl/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ import {sliceConfig} from './kernels/Slice';
import {softmaxConfig} from './kernels/Softmax';
import {softplusConfig} from './kernels/Softplus';
import {spaceToBatchNDConfig} from './kernels/SpaceToBatchND';
import {sparseReshapeConfig} from './kernels/SparseReshape';
import {sparseToDenseConfig} from './kernels/SparseToDense';
import {splitVConfig} from './kernels/SplitV';
import {sqrtConfig} from './kernels/Sqrt';
Expand Down Expand Up @@ -308,6 +309,7 @@ const kernelConfigs: KernelConfig[] = [
softmaxConfig,
softplusConfig,
spaceToBatchNDConfig,
sparseReshapeConfig,
sparseToDenseConfig,
splitVConfig,
sqrtConfig,
Expand Down
2 changes: 1 addition & 1 deletion tfjs-converter/metadata/kernel2op.json
Original file line number Diff line number Diff line change
Expand Up @@ -555,4 +555,4 @@
"_FusedMatMul": [
"fused.matMul"
]
}
}
20 changes: 12 additions & 8 deletions tfjs-core/src/kernel_names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -766,14 +766,9 @@ export interface SoftmaxAttrs {
dim: number;
}

export const SquaredDifference = 'SquaredDifference';
export type SquaredDifferenceInputs = BinaryInputs;

export const Square = 'Square';
export type SquareInputs = Pick<NamedTensorInfoMap, 'x'>;

export const Sub = 'Sub';
export type SubInputs = BinaryInputs;
export const SparseReshape = 'SparseReshape';
export type SparseReshapeInputs =
Pick<NamedTensorInfoMap, 'inputIndices'|'inputShape'|'newShape'>;

export const SparseToDense = 'SparseToDense';
export type SparseToDenseInputs =
Expand All @@ -782,6 +777,12 @@ export interface SparseToDenseAttrs {
outputShape: number[];
}

export const SquaredDifference = 'SquaredDifference';
export type SquaredDifferenceInputs = BinaryInputs;

export const Square = 'Square';
export type SquareInputs = Pick<NamedTensorInfoMap, 'x'>;

export const StridedSlice = 'StridedSlice';
export type StridedSliceInputs = Pick<NamedTensorInfoMap, 'x'>;
export interface StridedSliceAttrs {
Expand All @@ -795,6 +796,9 @@ export interface StridedSliceAttrs {
shrinkAxisMask: number;
}

export const Sub = 'Sub';
export type SubInputs = BinaryInputs;

export const Tan = 'Tan';
export type TanInputs = UnaryInputs;

Expand Down
5 changes: 4 additions & 1 deletion tfjs-core/src/ops/ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -296,5 +296,8 @@ const losses = {
softmaxCrossEntropy
};

import {sparseReshape} from './sparse/sparse_reshape';
const sparse = {sparseReshape};

// Second level exports.
export {image, linalg, losses, spectral, fused, signal};
export {image, linalg, losses, spectral, fused, signal, sparse};
Loading

0 comments on commit ffb4f24

Please sign in to comment.