From 4b27e7b8c6c789be3ad9674e2c0a91b7c1b59cd3 Mon Sep 17 00:00:00 2001 From: Na Li Date: Mon, 14 Sep 2020 13:25:01 -0700 Subject: [PATCH 1/3] Modularize NotEqual. --- tfjs-backend-cpu/src/kernels/Div.ts | 14 ++++- tfjs-backend-cpu/src/kernels/NotEqual.ts | 30 ++++++++++ .../src/kernels/SquaredDifference.ts | 16 ++++-- tfjs-backend-cpu/src/register_all_kernels.ts | 6 +- tfjs-backend-cpu/src/utils/kernel_utils.ts | 55 +++++++++++-------- 5 files changed, 87 insertions(+), 34 deletions(-) create mode 100644 tfjs-backend-cpu/src/kernels/NotEqual.ts diff --git a/tfjs-backend-cpu/src/kernels/Div.ts b/tfjs-backend-cpu/src/kernels/Div.ts index 57907fb6e28..33cbada577d 100644 --- a/tfjs-backend-cpu/src/kernels/Div.ts +++ b/tfjs-backend-cpu/src/kernels/Div.ts @@ -15,8 +15,16 @@ * ============================================================================= */ -import {Div} from '@tensorflow/tfjs-core'; -import {createBinaryKernelConfig} from '../utils/kernel_utils'; +import {Div, KernelConfig} from '@tensorflow/tfjs-core'; + +import {binaryKernelFunc} from '../utils/kernel_utils'; + import {divImpl} from './Div_impl'; -export const divConfig = createBinaryKernelConfig(Div, divImpl); +export const div = binaryKernelFunc(Div, divImpl); + +export const divConfig: KernelConfig = { + kernelName: Div, + backendName: 'cpu', + kernelFunc: div +}; diff --git a/tfjs-backend-cpu/src/kernels/NotEqual.ts b/tfjs-backend-cpu/src/kernels/NotEqual.ts new file mode 100644 index 00000000000..5a69a988eb1 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/NotEqual.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, NotEqual} from '@tensorflow/tfjs-core'; + +import {binaryKernelFunc, createBinaryKernelImpl} from '../utils/kernel_utils'; + +const notEqualImpl = createBinaryKernelImpl((a, b) => (a !== b) ? 1 : 0); + +export const notEqual = binaryKernelFunc(NotEqual, notEqualImpl, 'bool'); + +export const notEqualConfig: KernelConfig = { + kernelName: NotEqual, + backendName: 'cpu', + kernelFunc: notEqual +}; diff --git a/tfjs-backend-cpu/src/kernels/SquaredDifference.ts b/tfjs-backend-cpu/src/kernels/SquaredDifference.ts index 8f1ba615942..edb512ecf37 100644 --- a/tfjs-backend-cpu/src/kernels/SquaredDifference.ts +++ b/tfjs-backend-cpu/src/kernels/SquaredDifference.ts @@ -15,14 +15,20 @@ * ============================================================================= */ -import {SquaredDifference} from '@tensorflow/tfjs-core'; -import {createBinaryKernelImpl} from '../utils/kernel_utils'; -import {createBinaryKernelConfig} from '../utils/kernel_utils'; +import {KernelConfig, SquaredDifference} from '@tensorflow/tfjs-core'; + +import {binaryKernelFunc, createBinaryKernelImpl} from '../utils/kernel_utils'; const squaredDifferenceImpl = createBinaryKernelImpl((aVal, bVal) => { const diff = aVal - bVal; return diff * diff; }); -export const squaredDifferenceConfig = - createBinaryKernelConfig(SquaredDifference, squaredDifferenceImpl); +export const squaredDifference = + binaryKernelFunc(SquaredDifference, squaredDifferenceImpl); + +export const squaredDifferenceConfig: KernelConfig = { + kernelName: SquaredDifference, + backendName: 'cpu', + kernelFunc: squaredDifference +}; diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index bdf9e54a3b3..c761093e4ad 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -30,6 +30,7 @@ import {maxConfig} from './kernels/Max'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; import {nonMaxSuppressionV4Config} from './kernels/NonMaxSuppressionV4'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; +import {notEqualConfig} from './kernels/NotEqual'; import {padV2Config} from './kernels/PadV2'; import {reshapeConfig} from './kernels/Reshape'; import {rotateWithOffsetConfig} from './kernels/RotateWithOffset'; @@ -43,8 +44,9 @@ const kernelConfigs: KernelConfig[] = [ cosConfig, dilation2dConfig, dilation2dBackpropInputConfig, dilation2dBackpropFilterConfig, divConfig, flipLeftRightConfig, identityConfig, maxPoolWithArgmaxConfig, maxConfig, nonMaxSuppressionV4Config, - nonMaxSuppressionV5Config, padV2Config, reshapeConfig, rotateWithOffsetConfig, - spaceToBatchNDConfig, squareConfig, squaredDifferenceConfig, transposeConfig + nonMaxSuppressionV5Config, notEqualConfig, padV2Config, reshapeConfig, + rotateWithOffsetConfig, spaceToBatchNDConfig, squareConfig, + squaredDifferenceConfig, transposeConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-backend-cpu/src/utils/kernel_utils.ts b/tfjs-backend-cpu/src/utils/kernel_utils.ts index 8c263201c15..0448cc3f069 100644 --- a/tfjs-backend-cpu/src/utils/kernel_utils.ts +++ b/tfjs-backend-cpu/src/utils/kernel_utils.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {BinaryInputs, KernelConfig} from '@tensorflow/tfjs-core'; +import {BinaryInputs, KernelFunc} from '@tensorflow/tfjs-core'; import {DataType, NumericDataType, TypedArray} from '@tensorflow/tfjs-core'; import {backend_util} from '@tensorflow/tfjs-core'; @@ -23,29 +23,36 @@ import {util} from '@tensorflow/tfjs-core'; import {MathBackendCPU} from '../backend_cpu'; import {assertNotComplex} from '../cpu_util'; -export function createBinaryKernelConfig( - name: string, - op: ( - aShape: number[], bShape: number[], aVals: TypedArray, - bVals: TypedArray, - dtype: DataType) => [TypedArray, number[]]): KernelConfig { - return { - kernelName: name, - backendName: 'cpu', - kernelFunc: ({inputs, backend}) => { - const {a, b} = inputs as BinaryInputs; - const cpuBackend = backend as MathBackendCPU; - assertNotComplex([a, b], name); - - const aVals = cpuBackend.data.get(a.dataId).values as TypedArray; - const bVals = cpuBackend.data.get(b.dataId).values as TypedArray; - - const [resultData, resultShape] = - op(a.shape, b.shape, aVals, bVals, a.dtype); - - const dataId = cpuBackend.write(resultData, resultShape, a.dtype); - return {dataId, shape: resultShape, dtype: a.dtype}; - } +export type SimpleBinaryOperation = (a: number, b: number) => number; +export type SimpleBinaryKernelImpl = + (aShape: number[], bShape: number[], aVals: TypedArray, bVals: TypedArray, + dtype: DataType) => [TypedArray, number[]]; + +/** + * Template that creates a `KernelFunc` for binary ops. + * @param name Kernel name. + * @param op A `SimpleBinaryKernelImpl` of the kernel. + * @param dtype Optional. If set, the result has this dtype. Otherwise, the + * result has the same dtype as the the first input. This is mainly used + * in comparison kernels, such as Equal, Less, Greater, etc. + */ +export function binaryKernelFunc( + name: string, op: SimpleBinaryKernelImpl, dtype?: DataType): KernelFunc { + return ({inputs, backend}) => { + const {a, b} = inputs as BinaryInputs; + const cpuBackend = backend as MathBackendCPU; + assertNotComplex([a, b], name); + + const aVals = cpuBackend.data.get(a.dataId).values as TypedArray; + const bVals = cpuBackend.data.get(b.dataId).values as TypedArray; + + const $dtype = dtype || a.dtype; + + const [resultData, resultShape] = + op(a.shape, b.shape, aVals, bVals, $dtype); + + const dataId = cpuBackend.write(resultData, resultShape, $dtype); + return {dataId, shape: resultShape, dtype: $dtype}; }; } From af790fe47db50413fc61c5ebde9e3bf36e4e1239 Mon Sep 17 00:00:00 2001 From: Na Li Date: Mon, 14 Sep 2020 14:15:34 -0700 Subject: [PATCH 2/3] Address comments. --- tfjs-backend-cpu/src/utils/kernel_utils.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-cpu/src/utils/kernel_utils.ts b/tfjs-backend-cpu/src/utils/kernel_utils.ts index 0448cc3f069..c407b716e07 100644 --- a/tfjs-backend-cpu/src/utils/kernel_utils.ts +++ b/tfjs-backend-cpu/src/utils/kernel_utils.ts @@ -56,7 +56,8 @@ export function binaryKernelFunc( }; } -export function createBinaryKernelImpl(op: (a: number, b: number) => number) { +export function createBinaryKernelImpl(op: SimpleBinaryOperation): + SimpleBinaryKernelImpl { return (aShape: number[], bShape: number[], aVals: TypedArray, bVals: TypedArray, dtype: DataType): [TypedArray, number[]] => { const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape); From 43a33c3e125eae805019c0cf042108f6cfb2aa07 Mon Sep 17 00:00:00 2001 From: Na Li Date: Tue, 15 Sep 2020 23:01:20 -0700 Subject: [PATCH 3/3] Addressed comments. --- tfjs-backend-cpu/src/kernels/Div.ts | 4 +--- tfjs-backend-cpu/src/kernels/Div_impl.ts | 20 ------------------- tfjs-backend-cpu/src/kernels/NotEqual.ts | 7 +++---- .../src/kernels/SquaredDifference.ts | 12 +++++------ tfjs-backend-cpu/src/utils/kernel_utils.ts | 4 ++-- 5 files changed, 11 insertions(+), 36 deletions(-) delete mode 100644 tfjs-backend-cpu/src/kernels/Div_impl.ts diff --git a/tfjs-backend-cpu/src/kernels/Div.ts b/tfjs-backend-cpu/src/kernels/Div.ts index 33cbada577d..32f5c6e7f8d 100644 --- a/tfjs-backend-cpu/src/kernels/Div.ts +++ b/tfjs-backend-cpu/src/kernels/Div.ts @@ -19,9 +19,7 @@ import {Div, KernelConfig} from '@tensorflow/tfjs-core'; import {binaryKernelFunc} from '../utils/kernel_utils'; -import {divImpl} from './Div_impl'; - -export const div = binaryKernelFunc(Div, divImpl); +export const div = binaryKernelFunc(Div, (a: number, b: number) => a / b); export const divConfig: KernelConfig = { kernelName: Div, diff --git a/tfjs-backend-cpu/src/kernels/Div_impl.ts b/tfjs-backend-cpu/src/kernels/Div_impl.ts deleted file mode 100644 index 2b4c060d8ff..00000000000 --- a/tfjs-backend-cpu/src/kernels/Div_impl.ts +++ /dev/null @@ -1,20 +0,0 @@ -/** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {createBinaryKernelImpl} from '../utils/kernel_utils'; - -export const divImpl = createBinaryKernelImpl((a: number, b: number) => a / b); diff --git a/tfjs-backend-cpu/src/kernels/NotEqual.ts b/tfjs-backend-cpu/src/kernels/NotEqual.ts index 5a69a988eb1..bd8ad16d284 100644 --- a/tfjs-backend-cpu/src/kernels/NotEqual.ts +++ b/tfjs-backend-cpu/src/kernels/NotEqual.ts @@ -17,11 +17,10 @@ import {KernelConfig, NotEqual} from '@tensorflow/tfjs-core'; -import {binaryKernelFunc, createBinaryKernelImpl} from '../utils/kernel_utils'; +import {binaryKernelFunc} from '../utils/kernel_utils'; -const notEqualImpl = createBinaryKernelImpl((a, b) => (a !== b) ? 1 : 0); - -export const notEqual = binaryKernelFunc(NotEqual, notEqualImpl, 'bool'); +export const notEqual = + binaryKernelFunc(NotEqual, ((a, b) => (a !== b) ? 1 : 0), 'bool'); export const notEqualConfig: KernelConfig = { kernelName: NotEqual, diff --git a/tfjs-backend-cpu/src/kernels/SquaredDifference.ts b/tfjs-backend-cpu/src/kernels/SquaredDifference.ts index edb512ecf37..12f8424dc06 100644 --- a/tfjs-backend-cpu/src/kernels/SquaredDifference.ts +++ b/tfjs-backend-cpu/src/kernels/SquaredDifference.ts @@ -17,15 +17,13 @@ import {KernelConfig, SquaredDifference} from '@tensorflow/tfjs-core'; -import {binaryKernelFunc, createBinaryKernelImpl} from '../utils/kernel_utils'; - -const squaredDifferenceImpl = createBinaryKernelImpl((aVal, bVal) => { - const diff = aVal - bVal; - return diff * diff; -}); +import {binaryKernelFunc} from '../utils/kernel_utils'; export const squaredDifference = - binaryKernelFunc(SquaredDifference, squaredDifferenceImpl); + binaryKernelFunc(SquaredDifference, (aVal, bVal) => { + const diff = aVal - bVal; + return diff * diff; + }); export const squaredDifferenceConfig: KernelConfig = { kernelName: SquaredDifference, diff --git a/tfjs-backend-cpu/src/utils/kernel_utils.ts b/tfjs-backend-cpu/src/utils/kernel_utils.ts index c407b716e07..2a88dc3722f 100644 --- a/tfjs-backend-cpu/src/utils/kernel_utils.ts +++ b/tfjs-backend-cpu/src/utils/kernel_utils.ts @@ -37,7 +37,7 @@ export type SimpleBinaryKernelImpl = * in comparison kernels, such as Equal, Less, Greater, etc. */ export function binaryKernelFunc( - name: string, op: SimpleBinaryKernelImpl, dtype?: DataType): KernelFunc { + name: string, op: SimpleBinaryOperation, dtype?: DataType): KernelFunc { return ({inputs, backend}) => { const {a, b} = inputs as BinaryInputs; const cpuBackend = backend as MathBackendCPU; @@ -49,7 +49,7 @@ export function binaryKernelFunc( const $dtype = dtype || a.dtype; const [resultData, resultShape] = - op(a.shape, b.shape, aVals, bVals, $dtype); + createBinaryKernelImpl(op)(a.shape, b.shape, aVals, bVals, $dtype); const dataId = cpuBackend.write(resultData, resultShape, $dtype); return {dataId, shape: resultShape, dtype: $dtype};