Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions tfjs-core/src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ import * as seedrandom from 'seedrandom';

import {ENGINE} from '../../engine';
import {env} from '../../environment';

import {warn} from '../../log';
import * as array_ops_util from '../../ops/array_ops_util';
import * as axis_util from '../../ops/axis_util';
import * as broadcast_util from '../../ops/broadcast_util';
import {complex, imag, real} from '../../ops/complex_ops';
import * as concat_util from '../../ops/concat_util';
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
import {div} from '../../ops/div';
import * as erf_util from '../../ops/erf_util';
import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
import * as gather_nd_util from '../../ops/gather_nd_util';
Expand All @@ -47,6 +47,7 @@ import {split} from '../split_shared';
import {tile} from '../tile_impl';
import {topkImpl} from '../topk_impl';
import {whereImpl} from '../where_impl';

import {assertNotComplex} from './cpu_util';

function mapActivation(
Expand Down Expand Up @@ -385,7 +386,9 @@ export class MathBackendCPU extends KernelBackend {
const b = this.exp(a);
const sumExp = this.sum(b, axes).reshape(expandedShape);

return this.realDivide(b, sumExp) as T;
// TODO(annxingyuan): Call divImpl rather than op as part of softmax kernel
// modularization.
return div(b, sumExp);
}

subtract(a: Tensor, b: Tensor): Tensor {
Expand Down Expand Up @@ -493,14 +496,6 @@ export class MathBackendCPU extends KernelBackend {
(aValue, bValue) => aValue * bValue);
}

realDivide(a: Tensor, b: Tensor): Tensor {
assertNotComplex([a, b], 'realDivide');

const op = (a: number, b: number) => a / b;
const outputDtype = 'float32';
return this.broadcastedBinaryOp(a, b, outputDtype, op);
}

floorDiv(a: Tensor, b: Tensor): Tensor {
assertNotComplex([a, b], 'floorDiv');

Expand Down
22 changes: 22 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Div.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Div} from '../../../kernel_names';
import {createBinaryKernelConfig} from '../utils/kernel_utils';
import {divImpl} from './Div_impl';

// CPU kernel config for Div: pairs the shared binary-op config factory with
// the element-wise division implementation from Div_impl.
export const divConfig = createBinaryKernelConfig(Div, divImpl);
20 changes: 20 additions & 0 deletions tfjs-core/src/backends/cpu/kernels/Div_impl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {createBinaryKernelImpl} from '../utils/kernel_utils';

/**
 * CPU implementation of element-wise division. Broadcasting and output
 * buffer management are handled by the shared binary-kernel helper.
 */
export const divImpl =
    createBinaryKernelImpl((aValue: number, bValue: number) => aValue / bValue);
35 changes: 9 additions & 26 deletions tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,14 @@
* =============================================================================
*/

import {SquaredDifference, SquaredDifferenceInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import {TypedArray} from '../../../types';
import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';
import {broadcastedBinaryOp} from '../utils/kernel_utils';
import {SquaredDifference} from '../../../kernel_names';
import {createBinaryKernelImpl} from '../utils/kernel_utils';
import {createBinaryKernelConfig} from '../utils/kernel_utils';

export const squaredDifferenceConfig: KernelConfig = {
kernelName: SquaredDifference,
backendName: 'cpu',
kernelFunc: ({inputs, backend}) => {
const {a, b} = inputs as SquaredDifferenceInputs;
const cpuBackend = backend as MathBackendCPU;
assertNotComplex([a, b], SquaredDifference);
// Element-wise (a - b)^2; broadcasting is handled by the shared
// binary-kernel helper.
const squaredDifferenceImpl = createBinaryKernelImpl((aValue, bValue) => {
  const difference = aValue - bValue;
  return difference * difference;
});

const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;
const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;

const [resultData, resultShape] = broadcastedBinaryOp(
a.shape, b.shape, aVals, bVals, a.dtype, (aVal, bVal) => {
const diff = aVal - bVal;
return diff * diff;
});

const dataId = cpuBackend.write(resultData, resultShape, a.dtype);
return {dataId, shape: resultShape, dtype: a.dtype};
}
};
// CPU kernel config for SquaredDifference, built from the shared
// binary-kernel factory so registration matches the other binary ops.
export const squaredDifferenceConfig =
    createBinaryKernelConfig(SquaredDifference, squaredDifferenceImpl);
5 changes: 2 additions & 3 deletions tfjs-core/src/backends/cpu/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
// the contents of this file and import only the kernels that are needed.
import {KernelConfig, registerKernel} from '../../kernel_registry';

import {divConfig} from './kernels/Div';
import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
import {squareConfig} from './kernels/Square';
import {squaredDifferenceConfig} from './kernels/SquaredDifference';

// List all kernel configs here
const kernelConfigs: KernelConfig[] = [
nonMaxSuppressionV5Config,
squareConfig,
squaredDifferenceConfig,
nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig
];

for (const kernelConfig of kernelConfigs) {
Expand Down
94 changes: 62 additions & 32 deletions tfjs-core/src/backends/cpu/utils/kernel_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,50 +16,80 @@
*/

import * as backend_util from '../../../backends/backend_util';
import {BinaryInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import {DataType, NumericDataType, TypedArray} from '../../../types';
import * as util from '../../../util';
import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';

export function broadcastedBinaryOp(
aShape: number[], bShape: number[], aVals: TypedArray, bVals: TypedArray,
dtype: DataType,
op: (a: number, b: number) => number): [TypedArray, number[]] {
const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
/**
 * Creates a CPU KernelConfig for a broadcasting binary op.
 *
 * The returned config rejects complex inputs, reads both inputs' backing
 * TypedArrays from the CPU backend, invokes `op` to compute the broadcasted
 * result, writes the result back to the backend, and returns the output
 * tensor info.
 *
 * @param name Kernel name to register the config under.
 * @param op Implementation taking both input shapes and values plus the
 *     output dtype, returning `[resultValues, resultShape]`.
 */
export function createBinaryKernelConfig(
    name: string,
    op: (
        aShape: number[], bShape: number[], aVals: TypedArray,
        bVals: TypedArray,
        dtype: DataType) => [TypedArray, number[]]): KernelConfig {
  return {
    kernelName: name,
    backendName: 'cpu',
    kernelFunc: ({inputs, backend}) => {
      const {a, b} = inputs as BinaryInputs;
      const cpuBackend = backend as MathBackendCPU;
      // These element-wise impls operate on plain TypedArrays and do not
      // support complex tensors.
      assertNotComplex([a, b], name);

      const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;
      const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;

      const [resultData, resultShape] =
          op(a.shape, b.shape, aVals, bVals, a.dtype);

      const dataId = cpuBackend.write(resultData, resultShape, a.dtype);
      // NOTE(review): the output dtype mirrors input `a`'s dtype here,
      // whereas the removed realDivide forced 'float32' — confirm int/int
      // division still produces the intended dtype through this path.
      return {dataId, shape: resultShape, dtype: a.dtype};
    }
  };
}

const aStrides = util.computeStrides(aShape);
const bStrides = util.computeStrides(bShape);
export function createBinaryKernelImpl(op: (a: number, b: number) => number) {
return (aShape: number[], bShape: number[], aVals: TypedArray,
bVals: TypedArray, dtype: DataType): [TypedArray, number[]] => {
const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);

const aBroadcastDims = backend_util.getBroadcastDims(aShape, newShape);
const bBroadcastDims = backend_util.getBroadcastDims(bShape, newShape);
const resultRank = newShape.length;
const resultStrides = util.computeStrides(newShape);
const resultSize = util.sizeFromShape(newShape);

if (aBroadcastDims.length + bBroadcastDims.length === 0) {
for (let i = 0; i < result.length; ++i) {
result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
}
} else {
for (let i = 0; i < result.length; ++i) {
const loc = util.indexToLoc(i, resultRank, resultStrides);
const result =
util.getTypedArrayFromDType(dtype as NumericDataType, resultSize);

const aRank = aShape.length;
const bRank = bShape.length;

const aStrides = util.computeStrides(aShape);
const bStrides = util.computeStrides(bShape);

const aBroadcastDims = backend_util.getBroadcastDims(aShape, newShape);
const bBroadcastDims = backend_util.getBroadcastDims(bShape, newShape);

if (aBroadcastDims.length + bBroadcastDims.length === 0) {
for (let i = 0; i < result.length; ++i) {
result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
}
} else {
for (let i = 0; i < result.length; ++i) {
const loc = util.indexToLoc(i, resultRank, resultStrides);

const aLoc = loc.slice(-aRank);
aBroadcastDims.forEach(d => aLoc[d] = 0);
const aIndex = util.locToIndex(aLoc, aRank, aStrides);
const aLoc = loc.slice(-aRank);
aBroadcastDims.forEach(d => aLoc[d] = 0);
const aIndex = util.locToIndex(aLoc, aRank, aStrides);

const bLoc = loc.slice(-bRank);
bBroadcastDims.forEach(d => bLoc[d] = 0);
const bIndex = util.locToIndex(bLoc, bRank, bStrides);
const bLoc = loc.slice(-bRank);
bBroadcastDims.forEach(d => bLoc[d] = 0);
const bIndex = util.locToIndex(bLoc, bRank, bStrides);

result[i] = op(aVals[aIndex], bVals[bIndex]);
result[i] = op(aVals[aIndex], bVals[bIndex]);
}
}
}

return [result, newShape];
return [result, newShape];
};
}
17 changes: 4 additions & 13 deletions tfjs-core/src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import * as axis_util from '../../ops/axis_util';
import {complex, imag, real} from '../../ops/complex_ops';
import {computeOutShape} from '../../ops/concat_util';
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
import {div} from '../../ops/div';
import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
import * as gather_nd_util from '../../ops/gather_nd_util';
import * as reduce_util from '../../ops/reduce_util';
Expand Down Expand Up @@ -1372,18 +1373,6 @@ export class MathBackendWebGL extends KernelBackend {
return this.reduce(a2D, 'any', a2D.dtype).reshape(outShape);
}

realDivide(a: Tensor, b: Tensor): Tensor {
const op = binaryop_gpu.DIV;
const outputDtype = 'float32';
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
const checkOutOfBounds = true;
return this.packedBinaryOp(
a, b, binaryop_packed_gpu.DIV, outputDtype, checkOutOfBounds);
}
const program = new BinaryOpProgram(op, a.shape, b.shape);
return this.compileAndRun<Tensor>(program, [a, b], outputDtype);
}

floorDiv(a: Tensor, b: Tensor): Tensor {
const op = binaryop_gpu.INT_DIV;
const outputDtype = 'int32';
Expand Down Expand Up @@ -1597,7 +1586,9 @@ export class MathBackendWebGL extends KernelBackend {
const b = this.exp(a);
const sumExp = this.sum(b, axes).reshape(expandedShape);

return this.realDivide(b, sumExp) as T;
// TODO(annxingyuan): Call divImpl rather than op as part of softmax kernel
// modularization.
return div(b, sumExp);
}

log<T extends Tensor>(x: T): T {
Expand Down
33 changes: 33 additions & 0 deletions tfjs-core/src/backends/webgl/kernels/Div.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Div, DivInputs} from '../../../kernel_names';
import {KernelConfig} from '../../../kernel_registry';
import {MathBackendWebGL} from '../backend_webgl';
import {divImpl} from './Div_impl';

/**
 * WebGL kernel config for Div. Unpacks the binary-op inputs and delegates
 * the GPU work to the shared divImpl.
 */
export const divConfig: KernelConfig = {
  kernelName: Div,
  backendName: 'webgl',
  kernelFunc: ({inputs, backend}) => {
    const {a, b} = inputs as DivInputs;
    return divImpl(a, b, backend as MathBackendWebGL);
  }
};
35 changes: 35 additions & 0 deletions tfjs-core/src/backends/webgl/kernels/Div_impl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {env} from '../../../environment';
import {TensorInfo} from '../../../kernel_registry';
import {MathBackendWebGL} from '../backend_webgl';
import * as binaryop_gpu from '../binaryop_gpu';
import {BinaryOpProgram} from '../binaryop_gpu';
import * as binaryop_packed_gpu from '../binaryop_packed_gpu';
import {BinaryOpPackedProgram} from '../binaryop_packed_gpu';

/**
 * Runs element-wise division on the WebGL backend, selecting the packed
 * shader variant when WEBGL_PACK_BINARY_OPERATIONS is enabled. Output dtype
 * is always float32.
 */
export function divImpl(
    a: TensorInfo, b: TensorInfo, backend: MathBackendWebGL): TensorInfo {
  const usePackedOp = env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
  // The trailing `true` is the checkOutOfBounds flag for the packed program.
  const program = usePackedOp ?
      new BinaryOpPackedProgram(
          binaryop_packed_gpu.DIV, a.shape, b.shape, true) :
      new BinaryOpProgram(binaryop_gpu.DIV, a.shape, b.shape);
  return backend.runWebGLProgram(program, [a, b], 'float32');
}
2 changes: 2 additions & 0 deletions tfjs-core/src/backends/webgl/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
import {KernelConfig, registerKernel} from '../../kernel_registry';

import {divConfig} from './kernels/Div';
import {fromPixelsConfig} from './kernels/FromPixels';
import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
import {squareConfig} from './kernels/Square';
Expand All @@ -24,6 +25,7 @@ import {squaredDifferenceConfig} from './kernels/SquaredDifference';
// List all kernel configs here
const kernelConfigs: KernelConfig[] = [
fromPixelsConfig,
divConfig,
nonMaxSuppressionV5Config,
squareConfig,
squaredDifferenceConfig,
Expand Down
Loading