diff --git a/tfjs-backend-cpu/src/kernels/Div.ts b/tfjs-backend-cpu/src/kernels/Div.ts index 57907fb6e28..32f5c6e7f8d 100644 --- a/tfjs-backend-cpu/src/kernels/Div.ts +++ b/tfjs-backend-cpu/src/kernels/Div.ts @@ -15,8 +15,14 @@ * ============================================================================= */ -import {Div} from '@tensorflow/tfjs-core'; -import {createBinaryKernelConfig} from '../utils/kernel_utils'; -import {divImpl} from './Div_impl'; +import {Div, KernelConfig} from '@tensorflow/tfjs-core'; -export const divConfig = createBinaryKernelConfig(Div, divImpl); +import {binaryKernelFunc} from '../utils/kernel_utils'; + +export const div = binaryKernelFunc(Div, (a: number, b: number) => a / b); + +export const divConfig: KernelConfig = { + kernelName: Div, + backendName: 'cpu', + kernelFunc: div +}; diff --git a/tfjs-backend-cpu/src/kernels/Div_impl.ts b/tfjs-backend-cpu/src/kernels/NotEqual.ts similarity index 67% rename from tfjs-backend-cpu/src/kernels/Div_impl.ts rename to tfjs-backend-cpu/src/kernels/NotEqual.ts index 2b4c060d8ff..bd8ad16d284 100644 --- a/tfjs-backend-cpu/src/kernels/Div_impl.ts +++ b/tfjs-backend-cpu/src/kernels/NotEqual.ts @@ -15,6 +15,15 @@ * ============================================================================= */ -import {createBinaryKernelImpl} from '../utils/kernel_utils'; +import {KernelConfig, NotEqual} from '@tensorflow/tfjs-core'; -export const divImpl = createBinaryKernelImpl((a: number, b: number) => a / b); +import {binaryKernelFunc} from '../utils/kernel_utils'; + +export const notEqual = + binaryKernelFunc(NotEqual, ((a, b) => (a !== b) ? 1 : 0), 'bool'); + +export const notEqualConfig: KernelConfig = { + kernelName: NotEqual, + backendName: 'cpu', + kernelFunc: notEqual +}; diff --git a/tfjs-backend-cpu/src/kernels/SquaredDifference.ts b/tfjs-backend-cpu/src/kernels/SquaredDifference.ts index 8f1ba615942..12f8424dc06 100644 --- a/tfjs-backend-cpu/src/kernels/SquaredDifference.ts +++ b/tfjs-backend-cpu/src/kernels/SquaredDifference.ts @@ -15,14 +15,18 @@ * ============================================================================= */ -import {SquaredDifference} from '@tensorflow/tfjs-core'; -import {createBinaryKernelImpl} from '../utils/kernel_utils'; -import {createBinaryKernelConfig} from '../utils/kernel_utils'; +import {KernelConfig, SquaredDifference} from '@tensorflow/tfjs-core'; -const squaredDifferenceImpl = createBinaryKernelImpl((aVal, bVal) => { - const diff = aVal - bVal; - return diff * diff; -}); +import {binaryKernelFunc} from '../utils/kernel_utils'; -export const squaredDifferenceConfig = - createBinaryKernelConfig(SquaredDifference, squaredDifferenceImpl); +export const squaredDifference = + binaryKernelFunc(SquaredDifference, (aVal, bVal) => { + const diff = aVal - bVal; + return diff * diff; + }); + +export const squaredDifferenceConfig: KernelConfig = { + kernelName: SquaredDifference, + backendName: 'cpu', + kernelFunc: squaredDifference +}; diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index bdf9e54a3b3..c761093e4ad 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -30,6 +30,7 @@ import {maxConfig} from './kernels/Max'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; import {nonMaxSuppressionV4Config} from './kernels/NonMaxSuppressionV4'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; +import {notEqualConfig} from './kernels/NotEqual'; import {padV2Config} from './kernels/PadV2'; import {reshapeConfig} from './kernels/Reshape'; import {rotateWithOffsetConfig} from './kernels/RotateWithOffset'; @@ -43,8 +44,9 @@ const kernelConfigs: KernelConfig[] = [ cosConfig, dilation2dConfig, dilation2dBackpropInputConfig, dilation2dBackpropFilterConfig, divConfig, flipLeftRightConfig, identityConfig, maxPoolWithArgmaxConfig, maxConfig, nonMaxSuppressionV4Config, - nonMaxSuppressionV5Config, padV2Config, reshapeConfig, rotateWithOffsetConfig, - spaceToBatchNDConfig, squareConfig, squaredDifferenceConfig, transposeConfig + nonMaxSuppressionV5Config, notEqualConfig, padV2Config, reshapeConfig, + rotateWithOffsetConfig, spaceToBatchNDConfig, squareConfig, + squaredDifferenceConfig, transposeConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-backend-cpu/src/utils/kernel_utils.ts b/tfjs-backend-cpu/src/utils/kernel_utils.ts index 8c263201c15..2a88dc3722f 100644 --- a/tfjs-backend-cpu/src/utils/kernel_utils.ts +++ b/tfjs-backend-cpu/src/utils/kernel_utils.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {BinaryInputs, KernelConfig} from '@tensorflow/tfjs-core'; +import {BinaryInputs, KernelFunc} from '@tensorflow/tfjs-core'; import {DataType, NumericDataType, TypedArray} from '@tensorflow/tfjs-core'; import {backend_util} from '@tensorflow/tfjs-core'; @@ -23,33 +23,41 @@ import {util} from '@tensorflow/tfjs-core'; import {MathBackendCPU} from '../backend_cpu'; import {assertNotComplex} from '../cpu_util'; -export function createBinaryKernelConfig( - name: string, - op: ( - aShape: number[], bShape: number[], aVals: TypedArray, - bVals: TypedArray, - dtype: DataType) => [TypedArray, number[]]): KernelConfig { - return { - kernelName: name, - backendName: 'cpu', - kernelFunc: ({inputs, backend}) => { - const {a, b} = inputs as BinaryInputs; - const cpuBackend = backend as MathBackendCPU; - assertNotComplex([a, b], name); - - const aVals = cpuBackend.data.get(a.dataId).values as TypedArray; - const bVals = cpuBackend.data.get(b.dataId).values as TypedArray; - - const [resultData, resultShape] = - op(a.shape, b.shape, aVals, bVals, a.dtype); - - const dataId = cpuBackend.write(resultData, resultShape, a.dtype); - return {dataId, shape: resultShape, dtype: a.dtype}; - } +export type SimpleBinaryOperation = (a: number, b: number) => number; +export type SimpleBinaryKernelImpl = + (aShape: number[], bShape: number[], aVals: TypedArray, bVals: TypedArray, + dtype: DataType) => [TypedArray, number[]]; + +/** + * Template that creates a `KernelFunc` for binary ops. + * @param name Kernel name. + * @param op A `SimpleBinaryKernelImpl` of the kernel. + * @param dtype Optional. If set, the result has this dtype. Otherwise, the + * result has the same dtype as the the first input. This is mainly used + * in comparison kernels, such as Equal, Less, Greater, etc. + */ +export function binaryKernelFunc( + name: string, op: SimpleBinaryOperation, dtype?: DataType): KernelFunc { + return ({inputs, backend}) => { + const {a, b} = inputs as BinaryInputs; + const cpuBackend = backend as MathBackendCPU; + assertNotComplex([a, b], name); + + const aVals = cpuBackend.data.get(a.dataId).values as TypedArray; + const bVals = cpuBackend.data.get(b.dataId).values as TypedArray; + + const $dtype = dtype || a.dtype; + + const [resultData, resultShape] = + createBinaryKernelImpl(op)(a.shape, b.shape, aVals, bVals, $dtype); + + const dataId = cpuBackend.write(resultData, resultShape, $dtype); + return {dataId, shape: resultShape, dtype: $dtype}; }; } -export function createBinaryKernelImpl(op: (a: number, b: number) => number) { +export function createBinaryKernelImpl(op: SimpleBinaryOperation): + SimpleBinaryKernelImpl { return (aShape: number[], bShape: number[], aVals: TypedArray, bVals: TypedArray, dtype: DataType): [TypedArray, number[]] => { const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);