Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions tfjs-backend-cpu/src/kernels/Div.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@
* =============================================================================
*/

import {Div} from '@tensorflow/tfjs-core';
import {createBinaryKernelConfig} from '../utils/kernel_utils';
import {divImpl} from './Div_impl';
import {Div, KernelConfig} from '@tensorflow/tfjs-core';

export const divConfig = createBinaryKernelConfig(Div, divImpl);
import {binaryKernelFunc} from '../utils/kernel_utils';

export const div = binaryKernelFunc(Div, (a: number, b: number) => a / b);

export const divConfig: KernelConfig = {
kernelName: Div,
backendName: 'cpu',
kernelFunc: div
};
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
* =============================================================================
*/

import {createBinaryKernelImpl} from '../utils/kernel_utils';
import {KernelConfig, NotEqual} from '@tensorflow/tfjs-core';

export const divImpl = createBinaryKernelImpl((a: number, b: number) => a / b);
import {binaryKernelFunc} from '../utils/kernel_utils';

export const notEqual =
binaryKernelFunc(NotEqual, ((a, b) => (a !== b) ? 1 : 0), 'bool');

export const notEqualConfig: KernelConfig = {
kernelName: NotEqual,
backendName: 'cpu',
kernelFunc: notEqual
};
22 changes: 13 additions & 9 deletions tfjs-backend-cpu/src/kernels/SquaredDifference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@
* =============================================================================
*/

import {SquaredDifference} from '@tensorflow/tfjs-core';
import {createBinaryKernelImpl} from '../utils/kernel_utils';
import {createBinaryKernelConfig} from '../utils/kernel_utils';
import {KernelConfig, SquaredDifference} from '@tensorflow/tfjs-core';

const squaredDifferenceImpl = createBinaryKernelImpl((aVal, bVal) => {
const diff = aVal - bVal;
return diff * diff;
});
import {binaryKernelFunc} from '../utils/kernel_utils';

export const squaredDifferenceConfig =
createBinaryKernelConfig(SquaredDifference, squaredDifferenceImpl);
export const squaredDifference =
binaryKernelFunc(SquaredDifference, (aVal, bVal) => {
const diff = aVal - bVal;
return diff * diff;
});

export const squaredDifferenceConfig: KernelConfig = {
kernelName: SquaredDifference,
backendName: 'cpu',
kernelFunc: squaredDifference
};
6 changes: 4 additions & 2 deletions tfjs-backend-cpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {maxConfig} from './kernels/Max';
import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax';
import {nonMaxSuppressionV4Config} from './kernels/NonMaxSuppressionV4';
import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
import {notEqualConfig} from './kernels/NotEqual';
import {padV2Config} from './kernels/PadV2';
import {reshapeConfig} from './kernels/Reshape';
import {rotateWithOffsetConfig} from './kernels/RotateWithOffset';
Expand All @@ -43,8 +44,9 @@ const kernelConfigs: KernelConfig[] = [
cosConfig, dilation2dConfig, dilation2dBackpropInputConfig,
dilation2dBackpropFilterConfig, divConfig, flipLeftRightConfig,
identityConfig, maxPoolWithArgmaxConfig, maxConfig, nonMaxSuppressionV4Config,
nonMaxSuppressionV5Config, padV2Config, reshapeConfig, rotateWithOffsetConfig,
spaceToBatchNDConfig, squareConfig, squaredDifferenceConfig, transposeConfig
nonMaxSuppressionV5Config, notEqualConfig, padV2Config, reshapeConfig,
rotateWithOffsetConfig, spaceToBatchNDConfig, squareConfig,
squaredDifferenceConfig, transposeConfig
];

for (const kernelConfig of kernelConfigs) {
Expand Down
58 changes: 33 additions & 25 deletions tfjs-backend-cpu/src/utils/kernel_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,49 @@
* =============================================================================
*/

import {BinaryInputs, KernelConfig} from '@tensorflow/tfjs-core';
import {BinaryInputs, KernelFunc} from '@tensorflow/tfjs-core';
import {DataType, NumericDataType, TypedArray} from '@tensorflow/tfjs-core';
import {backend_util} from '@tensorflow/tfjs-core';

import {util} from '@tensorflow/tfjs-core';
import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';

export function createBinaryKernelConfig(
name: string,
op: (
aShape: number[], bShape: number[], aVals: TypedArray,
bVals: TypedArray,
dtype: DataType) => [TypedArray, number[]]): KernelConfig {
return {
kernelName: name,
backendName: 'cpu',
kernelFunc: ({inputs, backend}) => {
const {a, b} = inputs as BinaryInputs;
const cpuBackend = backend as MathBackendCPU;
assertNotComplex([a, b], name);

const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;
const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;

const [resultData, resultShape] =
op(a.shape, b.shape, aVals, bVals, a.dtype);

const dataId = cpuBackend.write(resultData, resultShape, a.dtype);
return {dataId, shape: resultShape, dtype: a.dtype};
}
export type SimpleBinaryOperation = (a: number, b: number) => number;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: might be useful to use this type to replace the type of the "op" parameter in createBinaryKernelImpl below.

export type SimpleBinaryKernelImpl =
(aShape: number[], bShape: number[], aVals: TypedArray, bVals: TypedArray,
dtype: DataType) => [TypedArray, number[]];

/**
* Template that creates a `KernelFunc` for binary ops.
* @param name Kernel name.
* @param op A `SimpleBinaryKernelImpl` of the kernel.
* @param dtype Optional. If set, the result has this dtype. Otherwise, the
* result has the same dtype as the the first input. This is mainly used
* in comparison kernels, such as Equal, Less, Greater, etc.
*/
export function binaryKernelFunc(
name: string, op: SimpleBinaryOperation, dtype?: DataType): KernelFunc {
return ({inputs, backend}) => {
const {a, b} = inputs as BinaryInputs;
const cpuBackend = backend as MathBackendCPU;
assertNotComplex([a, b], name);

const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;
const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;

const $dtype = dtype || a.dtype;

const [resultData, resultShape] =
createBinaryKernelImpl(op)(a.shape, b.shape, aVals, bVals, $dtype);

const dataId = cpuBackend.write(resultData, resultShape, $dtype);
return {dataId, shape: resultShape, dtype: $dtype};
};
}

export function createBinaryKernelImpl(op: (a: number, b: number) => number) {
export function createBinaryKernelImpl(op: SimpleBinaryOperation):
SimpleBinaryKernelImpl {
return (aShape: number[], bShape: number[], aVals: TypedArray,
bVals: TypedArray, dtype: DataType): [TypedArray, number[]] => {
const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
Expand Down