diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts
index f45855309a8..6880be3269b 100644
--- a/tfjs-core/src/backends/cpu/backend_cpu.ts
+++ b/tfjs-core/src/backends/cpu/backend_cpu.ts
@@ -19,7 +19,6 @@ import * as seedrandom from 'seedrandom';
 
 import {ENGINE} from '../../engine';
 import {env} from '../../environment';
-
 import {warn} from '../../log';
 import * as array_ops_util from '../../ops/array_ops_util';
 import * as axis_util from '../../ops/axis_util';
@@ -27,6 +26,7 @@ import * as broadcast_util from '../../ops/broadcast_util';
 import {complex, imag, real} from '../../ops/complex_ops';
 import * as concat_util from '../../ops/concat_util';
 import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
+import {div} from '../../ops/div';
 import * as erf_util from '../../ops/erf_util';
 import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
 import * as gather_nd_util from '../../ops/gather_nd_util';
@@ -47,6 +47,7 @@ import {split} from '../split_shared';
 import {tile} from '../tile_impl';
 import {topkImpl} from '../topk_impl';
 import {whereImpl} from '../where_impl';
+
 import {assertNotComplex} from './cpu_util';
 
 function mapActivation(
@@ -385,7 +386,9 @@ export class MathBackendCPU extends KernelBackend {
 
     const b = this.exp(a);
     const sumExp = this.sum(b, axes).reshape(expandedShape);
-    return this.realDivide(b, sumExp) as T;
+    // TODO(annxingyuan): Call divImpl rather than op as part of softmax kernel
+    // modularization.
+    return div(b, sumExp);
   }
 
   subtract(a: Tensor, b: Tensor): Tensor {
@@ -493,14 +496,6 @@ export class MathBackendCPU extends KernelBackend {
         (aValue, bValue) => aValue * bValue);
   }
 
-  realDivide(a: Tensor, b: Tensor): Tensor {
-    assertNotComplex([a, b], 'realDivide');
-
-    const op = (a: number, b: number) => a / b;
-    const outputDtype = 'float32';
-    return this.broadcastedBinaryOp(a, b, outputDtype, op);
-  }
-
   floorDiv(a: Tensor, b: Tensor): Tensor {
     assertNotComplex([a, b], 'floorDiv');
 
diff --git a/tfjs-core/src/backends/cpu/kernels/Div.ts b/tfjs-core/src/backends/cpu/kernels/Div.ts
new file mode 100644
index 00000000000..d248a487131
--- /dev/null
+++ b/tfjs-core/src/backends/cpu/kernels/Div.ts
@@ -0,0 +1,22 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Div} from '../../../kernel_names';
+import {createBinaryKernelConfig} from '../utils/kernel_utils';
+import {divImpl} from './Div_impl';
+
+export const divConfig = createBinaryKernelConfig(Div, divImpl);
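With `realDivide` gone from `MathBackendCPU`, softmax now routes through the modular `div` op (per the TODO above, pending `divImpl`). For reference, the computation the backend performs, expressed with public ops. A minimal sketch, assuming `@tensorflow/tfjs-core` is imported as `tf`; illustrative only, not code from this PR:

```ts
import * as tf from '@tensorflow/tfjs-core';

// Numerically stable softmax: shift by the max, exponentiate, then
// normalize with the modular div op.
const logits = tf.tensor1d([1, 2, 3]);
const e = logits.sub(logits.max()).exp();
tf.div(e, e.sum()).print();  // ~[0.09, 0.24, 0.67]
```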
diff --git a/tfjs-core/src/backends/cpu/kernels/Div_impl.ts b/tfjs-core/src/backends/cpu/kernels/Div_impl.ts
new file mode 100644
index 00000000000..2b4c060d8ff
--- /dev/null
+++ b/tfjs-core/src/backends/cpu/kernels/Div_impl.ts
@@ -0,0 +1,20 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {createBinaryKernelImpl} from '../utils/kernel_utils';
+
+export const divImpl = createBinaryKernelImpl((a: number, b: number) => a / b);
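`divImpl` is a plain function over shapes and flat typed arrays, so it can be exercised without creating tensors. A quick sketch of its broadcasting contract (signature as defined by `createBinaryKernelImpl` in `kernel_utils.ts`; the relative import assumes a sibling test file):

```ts
import {divImpl} from './Div_impl';

// [2, 2] / [2]: b is broadcast across the rows of a.
const [vals, outShape] = divImpl(
    [2, 2], [2], new Float32Array([1, 2, 3, 4]), new Float32Array([2, 4]),
    'float32');
console.log(outShape);  // [2, 2]
console.log(vals);      // Float32Array [0.5, 0.5, 1.5, 1]
```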
diff --git a/tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts b/tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts
index a57bf1a156a..f7ad0205821 100644
--- a/tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts
+++ b/tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts
@@ -15,31 +15,14 @@
  * =============================================================================
  */
 
-import {SquaredDifference, SquaredDifferenceInputs} from '../../../kernel_names';
-import {KernelConfig} from '../../../kernel_registry';
-import {TypedArray} from '../../../types';
-import {MathBackendCPU} from '../backend_cpu';
-import {assertNotComplex} from '../cpu_util';
-import {broadcastedBinaryOp} from '../utils/kernel_utils';
+import {SquaredDifference} from '../../../kernel_names';
+import {createBinaryKernelImpl} from '../utils/kernel_utils';
+import {createBinaryKernelConfig} from '../utils/kernel_utils';
 
-export const squaredDifferenceConfig: KernelConfig = {
-  kernelName: SquaredDifference,
-  backendName: 'cpu',
-  kernelFunc: ({inputs, backend}) => {
-    const {a, b} = inputs as SquaredDifferenceInputs;
-    const cpuBackend = backend as MathBackendCPU;
-    assertNotComplex([a, b], SquaredDifference);
+const squaredDifferenceImpl = createBinaryKernelImpl((aVal, bVal) => {
+  const diff = aVal - bVal;
+  return diff * diff;
+});
 
-    const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;
-    const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;
-
-    const [resultData, resultShape] = broadcastedBinaryOp(
-        a.shape, b.shape, aVals, bVals, a.dtype, (aVal, bVal) => {
-          const diff = aVal - bVal;
-          return diff * diff;
-        });
-
-    const dataId = cpuBackend.write(resultData, resultShape, a.dtype);
-    return {dataId, shape: resultShape, dtype: a.dtype};
-  }
-};
+export const squaredDifferenceConfig =
+    createBinaryKernelConfig(SquaredDifference, squaredDifferenceImpl);
diff --git a/tfjs-core/src/backends/cpu/register_all_kernels.ts b/tfjs-core/src/backends/cpu/register_all_kernels.ts
index 9516f9af311..4d7f2e983cf 100644
--- a/tfjs-core/src/backends/cpu/register_all_kernels.ts
+++ b/tfjs-core/src/backends/cpu/register_all_kernels.ts
@@ -19,15 +19,14 @@
 // the contents of this file and import only the kernels that are needed.
 import {KernelConfig, registerKernel} from '../../kernel_registry';
 
+import {divConfig} from './kernels/Div';
 import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
 import {squareConfig} from './kernels/Square';
 import {squaredDifferenceConfig} from './kernels/SquaredDifference';
 
 // List all kernel configs here
 const kernelConfigs: KernelConfig[] = [
-  nonMaxSuppressionV5Config,
-  squareConfig,
-  squaredDifferenceConfig,
+  nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig
 ];
 
 for (const kernelConfig of kernelConfigs) {
diff --git a/tfjs-core/src/backends/cpu/utils/kernel_utils.ts b/tfjs-core/src/backends/cpu/utils/kernel_utils.ts
index ce21517dba4..68279392173 100644
--- a/tfjs-core/src/backends/cpu/utils/kernel_utils.ts
+++ b/tfjs-core/src/backends/cpu/utils/kernel_utils.ts
@@ -16,50 +16,80 @@
  */
 
 import * as backend_util from '../../../backends/backend_util';
+import {BinaryInputs} from '../../../kernel_names';
+import {KernelConfig} from '../../../kernel_registry';
 import {DataType, NumericDataType, TypedArray} from '../../../types';
 import * as util from '../../../util';
+import {MathBackendCPU} from '../backend_cpu';
+import {assertNotComplex} from '../cpu_util';
 
-export function broadcastedBinaryOp(
-    aShape: number[], bShape: number[], aVals: TypedArray, bVals: TypedArray,
-    dtype: DataType,
-    op: (a: number, b: number) => number): [TypedArray, number[]] {
-  const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
+export function createBinaryKernelConfig(
+    name: string,
+    op: (
+        aShape: number[], bShape: number[], aVals: TypedArray,
+        bVals: TypedArray,
+        dtype: DataType) => [TypedArray, number[]]): KernelConfig {
+  return {
+    kernelName: name,
+    backendName: 'cpu',
+    kernelFunc: ({inputs, backend}) => {
+      const {a, b} = inputs as BinaryInputs;
+      const cpuBackend = backend as MathBackendCPU;
+      assertNotComplex([a, b], name);
 
-  const resultRank = newShape.length;
-  const resultStrides = util.computeStrides(newShape);
-  const resultSize = util.sizeFromShape(newShape);
+      const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;
+      const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;
 
-  const result =
-      util.getTypedArrayFromDType(dtype as NumericDataType, resultSize);
+      const [resultData, resultShape] =
+          op(a.shape, b.shape, aVals, bVals, a.dtype);
 
-  const aRank = aShape.length;
-  const bRank = bShape.length;
+      const dataId = cpuBackend.write(resultData, resultShape, a.dtype);
+      return {dataId, shape: resultShape, dtype: a.dtype};
+    }
+  };
+}
 
-  const aStrides = util.computeStrides(aShape);
-  const bStrides = util.computeStrides(bShape);
+export function createBinaryKernelImpl(op: (a: number, b: number) => number) {
+  return (aShape: number[], bShape: number[], aVals: TypedArray,
+          bVals: TypedArray, dtype: DataType): [TypedArray, number[]] => {
+    const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
 
-  const aBroadcastDims = backend_util.getBroadcastDims(aShape, newShape);
-  const bBroadcastDims = backend_util.getBroadcastDims(bShape, newShape);
+    const resultRank = newShape.length;
+    const resultStrides = util.computeStrides(newShape);
+    const resultSize = util.sizeFromShape(newShape);
 
-  if (aBroadcastDims.length + bBroadcastDims.length === 0) {
-    for (let i = 0; i < result.length; ++i) {
-      result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
-    }
-  } else {
-    for (let i = 0; i < result.length; ++i) {
-      const loc = util.indexToLoc(i, resultRank, resultStrides);
+    const result =
+        util.getTypedArrayFromDType(dtype as NumericDataType, resultSize);
+
+    const aRank = aShape.length;
+    const bRank = bShape.length;
+
+    const aStrides = util.computeStrides(aShape);
+    const bStrides = util.computeStrides(bShape);
+
+    const aBroadcastDims = backend_util.getBroadcastDims(aShape, newShape);
+    const bBroadcastDims = backend_util.getBroadcastDims(bShape, newShape);
+
+    if (aBroadcastDims.length + bBroadcastDims.length === 0) {
+      for (let i = 0; i < result.length; ++i) {
+        result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
+      }
+    } else {
+      for (let i = 0; i < result.length; ++i) {
+        const loc = util.indexToLoc(i, resultRank, resultStrides);
 
-      const aLoc = loc.slice(-aRank);
-      aBroadcastDims.forEach(d => aLoc[d] = 0);
-      const aIndex = util.locToIndex(aLoc, aRank, aStrides);
+        const aLoc = loc.slice(-aRank);
+        aBroadcastDims.forEach(d => aLoc[d] = 0);
+        const aIndex = util.locToIndex(aLoc, aRank, aStrides);
 
-      const bLoc = loc.slice(-bRank);
-      bBroadcastDims.forEach(d => bLoc[d] = 0);
-      const bIndex = util.locToIndex(bLoc, bRank, bStrides);
+        const bLoc = loc.slice(-bRank);
+        bBroadcastDims.forEach(d => bLoc[d] = 0);
+        const bIndex = util.locToIndex(bLoc, bRank, bStrides);
 
-      result[i] = op(aVals[aIndex], bVals[bIndex]);
+        result[i] = op(aVals[aIndex], bVals[bIndex]);
+      }
     }
-  }
 
-  return [result, newShape];
+    return [result, newShape];
+  };
 }
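Together, the two factories make a new CPU binary kernel a two-liner: the impl does the math over typed arrays, and the config adapts it to the kernel registry's calling convention. A hypothetical sketch (the `Maximum` kernel name and config are illustrative only, not part of this PR):

```ts
import {createBinaryKernelConfig, createBinaryKernelImpl} from './kernel_utils';

// Element-wise maximum, assembled from the same helpers Div uses.
const maximumImpl =
    createBinaryKernelImpl((aVal, bVal) => Math.max(aVal, bVal));
// Registration would happen in register_all_kernels.ts via
// registerKernel(maximumConfig).
export const maximumConfig = createBinaryKernelConfig('Maximum', maximumImpl);
```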
diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts
index f513265c7aa..64d29147b21 100644
--- a/tfjs-core/src/backends/webgl/backend_webgl.ts
+++ b/tfjs-core/src/backends/webgl/backend_webgl.ts
@@ -30,6 +30,7 @@ import * as axis_util from '../../ops/axis_util';
 import {complex, imag, real} from '../../ops/complex_ops';
 import {computeOutShape} from '../../ops/concat_util';
 import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
+import {div} from '../../ops/div';
 import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
 import * as gather_nd_util from '../../ops/gather_nd_util';
 import * as reduce_util from '../../ops/reduce_util';
@@ -1372,18 +1373,6 @@ export class MathBackendWebGL extends KernelBackend {
     return this.reduce(a2D, 'any', a2D.dtype).reshape(outShape);
   }
 
-  realDivide(a: Tensor, b: Tensor): Tensor {
-    const op = binaryop_gpu.DIV;
-    const outputDtype = 'float32';
-    if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
-      const checkOutOfBounds = true;
-      return this.packedBinaryOp(
-          a, b, binaryop_packed_gpu.DIV, outputDtype, checkOutOfBounds);
-    }
-    const program = new BinaryOpProgram(op, a.shape, b.shape);
-    return this.compileAndRun(program, [a, b], outputDtype);
-  }
-
   floorDiv(a: Tensor, b: Tensor): Tensor {
     const op = binaryop_gpu.INT_DIV;
     const outputDtype = 'int32';
@@ -1597,7 +1586,9 @@ export class MathBackendWebGL extends KernelBackend {
 
     const b = this.exp(a);
     const sumExp = this.sum(b, axes).reshape(expandedShape);
-    return this.realDivide(b, sumExp) as T;
+    // TODO(annxingyuan): Call divImpl rather than op as part of softmax kernel
+    // modularization.
+    return div(b, sumExp);
   }
 
   log<T extends Tensor>(x: T): T {
diff --git a/tfjs-core/src/backends/webgl/kernels/Div.ts b/tfjs-core/src/backends/webgl/kernels/Div.ts
new file mode 100644
index 00000000000..0c88c4b9738
--- /dev/null
+++ b/tfjs-core/src/backends/webgl/kernels/Div.ts
@@ -0,0 +1,33 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Div, DivInputs} from '../../../kernel_names';
+import {KernelConfig} from '../../../kernel_registry';
+import {MathBackendWebGL} from '../backend_webgl';
+import {divImpl} from './Div_impl';
+
+export const divConfig: KernelConfig = {
+  kernelName: Div,
+  backendName: 'webgl',
+  kernelFunc: ({inputs, backend}) => {
+    const {a, b} = inputs as DivInputs;
+
+    const webglBackend = backend as MathBackendWebGL;
+
+    return divImpl(a, b, webglBackend);
+  }
+};
diff --git a/tfjs-core/src/backends/webgl/kernels/Div_impl.ts b/tfjs-core/src/backends/webgl/kernels/Div_impl.ts
new file mode 100644
index 00000000000..01a2c145164
--- /dev/null
+++ b/tfjs-core/src/backends/webgl/kernels/Div_impl.ts
@@ -0,0 +1,35 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {env} from '../../../environment';
+import {TensorInfo} from '../../../kernel_registry';
+import {MathBackendWebGL} from '../backend_webgl';
+import * as binaryop_gpu from '../binaryop_gpu';
+import {BinaryOpProgram} from '../binaryop_gpu';
+import * as binaryop_packed_gpu from '../binaryop_packed_gpu';
+import {BinaryOpPackedProgram} from '../binaryop_packed_gpu';
+
+export function divImpl(
+    a: TensorInfo, b: TensorInfo, backend: MathBackendWebGL): TensorInfo {
+  let program = new BinaryOpProgram(binaryop_gpu.DIV, a.shape, b.shape);
+  if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
+    program = new BinaryOpPackedProgram(
+        binaryop_packed_gpu.DIV, a.shape, b.shape, true);
+  }
+  const output = backend.runWebGLProgram(program, [a, b], 'float32');
+  return output;
+}
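`divImpl` reads the `WEBGL_PACK_BINARY_OPERATIONS` flag at call time to pick the packed or unpacked shader. A sketch of toggling the flag for debugging, assuming the standard `tf.env().set` flag API (illustrative only):

```ts
import * as tf from '@tensorflow/tfjs-core';

// Force the unpacked BinaryOpProgram path, run a div, then restore.
tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', false);
tf.div(tf.tensor1d([1, 2]), tf.scalar(2)).print();  // [0.5, 1]
tf.env().set('WEBGL_PACK_BINARY_OPERATIONS', true);
```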
diff --git a/tfjs-core/src/backends/webgl/register_all_kernels.ts b/tfjs-core/src/backends/webgl/register_all_kernels.ts
index f7913ac6b21..fe04734d2cb 100644
--- a/tfjs-core/src/backends/webgl/register_all_kernels.ts
+++ b/tfjs-core/src/backends/webgl/register_all_kernels.ts
@@ -16,6 +16,7 @@
  */
 
 import {KernelConfig, registerKernel} from '../../kernel_registry';
+import {divConfig} from './kernels/Div';
 import {fromPixelsConfig} from './kernels/FromPixels';
 import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
 import {squareConfig} from './kernels/Square';
@@ -24,6 +25,7 @@ import {squaredDifferenceConfig} from './kernels/SquaredDifference';
 // List all kernel configs here
 const kernelConfigs: KernelConfig[] = [
   fromPixelsConfig,
+  divConfig,
   nonMaxSuppressionV5Config,
   squareConfig,
   squaredDifferenceConfig,
diff --git a/tfjs-core/src/gradients/Div_grad.ts b/tfjs-core/src/gradients/Div_grad.ts
new file mode 100644
index 00000000000..578f48dbdd0
--- /dev/null
+++ b/tfjs-core/src/gradients/Div_grad.ts
@@ -0,0 +1,53 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Div} from '../kernel_names';
+import {GradConfig} from '../kernel_registry';
+import * as broadcast_util from '../ops/broadcast_util';
+import {div} from '../ops/div';
+import {sum} from '../ops/reduction_ops';
+import {square} from '../ops/square';
+import {neg} from '../ops/unary_ops';
+import {Tensor} from '../tensor';
+
+export const divGradConfig: GradConfig = {
+  kernelName: Div,
+  inputsToSave: ['a', 'b'],
+  gradFunc: (dy: Tensor, saved: Tensor[]) => {
+    const [a, b] = saved;
+    const outShape =
+        broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);
+    const derA = () => {
+      const res = div(dy, b.toFloat());
+      const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape);
+      if (reduceAxes.length > 0) {
+        return sum(res, reduceAxes).reshape(a.shape);
+      }
+      return res;
+    };
+    const derB = () => {
+      let res = dy.mul(a.toFloat());
+      const reduceAxes = broadcast_util.getReductionAxes(b.shape, outShape);
+      if (reduceAxes.length > 0) {
+        res = sum(res, reduceAxes).reshape(b.shape);
+      }
+      const tmp = square(b);
+      return neg(div(res, tmp.toFloat()));
+    };
+    return {a: derA, b: derB};
+  }
+};
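The config implements the usual quotient derivatives: for `c = a / b`, `dc/da = 1 / b` and `dc/db = -a / b^2`, with `sum` + `reshape` undoing broadcasting. A numeric spot-check, as a sketch assuming `@tensorflow/tfjs-core` imported as `tf`:

```ts
import * as tf from '@tensorflow/tfjs-core';

const f = (a: tf.Tensor, b: tf.Tensor) => tf.div(a, b);
const [da, db] = tf.grads(f)([tf.scalar(4), tf.scalar(2)]);
da.print();  //  1 / b   =  0.5
db.print();  // -a / b^2 = -1
```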
diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts
index 58d715bc94c..0b9927ef1c8 100644
--- a/tfjs-core/src/kernel_names.ts
+++ b/tfjs-core/src/kernel_names.ts
@@ -21,8 +21,13 @@
 import {NamedTensorInfoMap} from './kernel_registry';
 import {PixelData} from './types';
 
+export type BinaryInputs = Pick<NamedTensorInfoMap, 'a'|'b'>;
+
+export const Div = 'Div';
+export type DivInputs = BinaryInputs;
+
 export const SquaredDifference = 'SquaredDifference';
-export type SquaredDifferenceInputs = Pick<NamedTensorInfoMap, 'a'|'b'>;
+export type SquaredDifferenceInputs = BinaryInputs;
 
 export const Square = 'Square';
 export type SquareInputs = Pick<NamedTensorInfoMap, 'x'>;
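`BinaryInputs` gives every two-input kernel the same `{a, b}` input map. A hypothetical sketch of a future kernel reusing it (the `Atan2` name here is illustrative only, not part of this PR):

```ts
import {BinaryInputs} from './kernel_names';

export const Atan2 = 'Atan2';            // hypothetical kernel name
export type Atan2Inputs = BinaryInputs;  // same Pick<NamedTensorInfoMap, 'a'|'b'>
```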
diff --git a/tfjs-core/src/ops/binary_ops.ts b/tfjs-core/src/ops/binary_ops.ts
index 9aea464bea4..ac5e0525ce4 100644
--- a/tfjs-core/src/ops/binary_ops.ts
+++ b/tfjs-core/src/ops/binary_ops.ts
@@ -23,7 +23,6 @@ import {convertToTensor} from '../tensor_util_env';
 import {TensorLike} from '../types';
 import * as util from '../util';
 import * as broadcast_util from './broadcast_util';
-import {where} from './logical_ops';
 import {op} from './operation';
 import {scalar, zerosLike} from './tensor_ops';
 import {neg} from './unary_ops';
@@ -378,114 +377,6 @@ function mulStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
   return $a.mul($b);
 }
 
-/**
- * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
- *
- * We also expose `tf.divStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 9, 16]);
- * const b = tf.tensor1d([1, 2, 3, 4]);
- *
- * a.div(b).print();  // or tf.div(a, b)
- * ```
- *
- * ```js
- * // Broadcast div a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(2);
- *
- * a.div(b).print();  // or tf.div(a, b)
- * ```
- *
- * @param a The first tensor as the numerator.
- * @param b The second tensor as the denominator. Must have the same dtype as
- * `a`.
- */
-/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
-function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  let $a = convertToTensor(a, 'a', 'div');
-  let $b = convertToTensor(b, 'b', 'div');
-  [$a, $b] = makeTypesMatch($a, $b);
-
-  if ($a.dtype === 'int32' && $b.dtype === 'int32') {
-    return floorDiv($a, $b);
-  }
-
-  const outShape =
-      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
-  const der = (dy: Tensor, saved: Tensor[]) => {
-    const [$a, $b] = saved;
-    const derA = () => {
-      const res = dy.div($b.toFloat());
-      const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);
-      if (reduceAxes.length > 0) {
-        return res.sum(reduceAxes).reshape($a.shape);
-      }
-      return res;
-    };
-    const derB = () => {
-      let res = dy.mul($a.toFloat());
-      const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);
-      if (reduceAxes.length > 0) {
-        res = res.sum(reduceAxes).reshape($b.shape);
-      }
-      const tmp = $b.square();
-      return res.div(tmp.toFloat()).neg();
-    };
-    return {a: derA, b: derB};
-  };
-  return ENGINE.runKernelFunc((backend, save) => {
-    const res = backend.realDivide($a, $b);
-    save([$a, $b]);
-    return res;
-  }, {a: $a, b: $b}, der, 'Div') as T;
-}
-
-/**
- * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Return 0
- * if denominator is 0.
- *
- * We also expose `tf.divStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 9, 16]);
- * const b = tf.tensor1d([1, 2, 3, 4]);
- * const c = tf.tensor1d([0, 0, 0, 0]);
- *
- * a.divNoNan(b).print();  // or tf.divNoNan(a, b)
- * a.divNoNan(c).print();  // or tf.divNoNan(a, c)
- * ```
- *
- * ```js
- * // Broadcast div a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(2);
- * const c = tf.scalar(0);
- *
- * a.divNoNan(b).print();  // or tf.divNoNan(a, b)
- * a.divNoNan(c).print();  // or tf.divNoNan(a, c)
- * ```
- *
- * @param a The first tensor as the numerator.
- * @param b The second tensor as the denominator. Must have the same dtype as
- * `a`.
- */
-/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
-function divNoNan_<T extends Tensor>(
-    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  let $a = convertToTensor(a, 'a', 'div');
-  let $b = convertToTensor(b, 'b', 'div');
-  [$a, $b] = makeTypesMatch($a, $b);
-
-  const divResult = div($a, $b);
-  const zeros = zerosLike(divResult);
-  const bEqualsZero = $b.equal(zeros);
-  return where(bEqualsZero, zeros, divResult) as T;
-}
-
 /**
  * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
  * The result is rounded with floor function.
@@ -841,8 +732,6 @@ export const add = op({add_});
 export const addN = op({addN_});
 export const addStrict = op({addStrict_});
 export const atan2 = op({atan2_});
-export const div = op({div_});
-export const divNoNan = op({divNoNan_});
 export const divStrict = op({divStrict_});
 export const floorDiv = op({floorDiv_});
 export const maximum = op({maximum_});
diff --git a/tfjs-core/src/ops/div.ts b/tfjs-core/src/ops/div.ts
new file mode 100644
index 00000000000..54364f18f05
--- /dev/null
+++ b/tfjs-core/src/ops/div.ts
@@ -0,0 +1,78 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE, ForwardFunc} from '../engine';
+import {Div, DivInputs} from '../kernel_names';
+import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
+import {makeTypesMatch} from '../tensor_util';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {floorDiv} from './binary_ops';
+import {op} from './operation';
+
+/**
+ * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
+ *
+ * We also expose `tf.divStrict` which has the same signature as this op and
+ * asserts that `a` and `b` are the same shape (does not broadcast).
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 9, 16]);
+ * const b = tf.tensor1d([1, 2, 3, 4]);
+ *
+ * a.div(b).print();  // or tf.div(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast div a with b.
+ * const a = tf.tensor1d([2, 4, 6, 8]);
+ * const b = tf.scalar(2);
+ *
+ * a.div(b).print();  // or tf.div(a, b)
+ * ```
+ *
+ * @param a The first tensor as the numerator.
+ * @param b The second tensor as the denominator. Must have the same dtype as
+ * `a`.
+ */
+/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
+function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
+  let $a = convertToTensor(a, 'a', 'div');
+  let $b = convertToTensor(b, 'b', 'div');
+  [$a, $b] = makeTypesMatch($a, $b);
+
+  if ($a.dtype === 'int32' && $b.dtype === 'int32') {
+    return floorDiv($a, $b);
+  }
+
+  const forward: ForwardFunc<Tensor> = (backend, save) => {
+    const res = backend.realDivide($a, $b);
+    save([$a, $b]);
+    return res;
+  };
+
+  const inputs: DivInputs = {a: $a, b: $b};
+  const attrs = {};
+
+  return ENGINE.runKernelFunc(
+      forward, inputs as {} as NamedTensorMap, null /* gradient */, Div,
+      attrs) as T;
+}
+
+export const div = op({div_});
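Note the dtype dispatch at the top of `div_`: int32 operands short-circuit to `floorDiv` and never reach the `Div` kernel. A usage sketch (assuming `@tensorflow/tfjs-core` as `tf`):

```ts
import * as tf from '@tensorflow/tfjs-core';

tf.div(tf.tensor1d([7, -7], 'int32'), tf.tensor1d([2, 2], 'int32')).print();
// [3, -4]   (floorDiv: floor(3.5), floor(-3.5))
tf.div(tf.tensor1d([7, -7]), tf.tensor1d([2, 2])).print();
// [3.5, -3.5]   (float inputs run the registered Div kernel)
```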
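`divNoNan` stays a composite op for now (see the TODO above): it runs `div`, then uses `where` to zero out results at zero denominators. The observable difference, as a sketch assuming `@tensorflow/tfjs-core` as `tf`:

```ts
import * as tf from '@tensorflow/tfjs-core';

const a = tf.tensor1d([1, 2]);
tf.div(a, tf.scalar(0)).print();       // [Infinity, Infinity]
tf.divNoNan(a, tf.scalar(0)).print();  // [0, 0]
```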
diff --git a/tfjs-core/src/ops/div_no_nan.ts b/tfjs-core/src/ops/div_no_nan.ts
new file mode 100644
index 00000000000..b9737316cad
--- /dev/null
+++ b/tfjs-core/src/ops/div_no_nan.ts
@@ -0,0 +1,72 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Tensor} from '../tensor';
+import {makeTypesMatch} from '../tensor_util';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {div} from './div';
+import {where} from './logical_ops';
+import {op} from './operation';
+import {zerosLike} from './tensor_ops';
+
+/**
+ * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Returns
+ * 0 if denominator is 0.
+ *
+ * We also expose `tf.divStrict` which has the same signature as this op and
+ * asserts that `a` and `b` are the same shape (does not broadcast).
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 9, 16]);
+ * const b = tf.tensor1d([1, 2, 3, 4]);
+ * const c = tf.tensor1d([0, 0, 0, 0]);
+ *
+ * a.divNoNan(b).print();  // or tf.divNoNan(a, b)
+ * a.divNoNan(c).print();  // or tf.divNoNan(a, c)
+ * ```
+ *
+ * ```js
+ * // Broadcast div a with b.
+ * const a = tf.tensor1d([2, 4, 6, 8]);
+ * const b = tf.scalar(2);
+ * const c = tf.scalar(0);
+ *
+ * a.divNoNan(b).print();  // or tf.divNoNan(a, b)
+ * a.divNoNan(c).print();  // or tf.divNoNan(a, c)
+ * ```
+ *
+ * @param a The first tensor as the numerator.
+ * @param b The second tensor as the denominator. Must have the same dtype as
+ * `a`.
+ */
+/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
+function divNoNan_<T extends Tensor>(
+    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
+  // TODO: Make this into its own kernel.
+  let $a = convertToTensor(a, 'a', 'div');
+  let $b = convertToTensor(b, 'b', 'div');
+  [$a, $b] = makeTypesMatch($a, $b);
+
+  const divResult = div($a, $b);
+  const zeros = zerosLike(divResult);
+  const bEqualsZero = $b.equal(zeros);
+  return where(bEqualsZero, zeros, divResult) as T;
+}
+
+export const divNoNan = op({divNoNan_});
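Both entry points now route to the same modular `div_`: the functional export via `ops.ts` and the chained method installed on `Tensor.prototype` above. A usage sketch (assuming `@tensorflow/tfjs-core` as `tf`):

```ts
import * as tf from '@tensorflow/tfjs-core';

const x = tf.tensor1d([2, 4]);
tf.div(x, tf.scalar(2)).print();  // functional form -> [1, 2]
x.div(tf.scalar(2)).print();      // chained form    -> [1, 2]
```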
diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts
index c5f915f271d..7f4268fdee2 100644
--- a/tfjs-core/src/ops/ops.ts
+++ b/tfjs-core/src/ops/ops.ts
@@ -18,6 +18,8 @@
 // Modularized ops.
 export {broadcastTo} from './broadcast_to';
 export {clone} from './clone';
+export {div} from './div';
+export {divNoNan} from './div_no_nan';
 export {eye} from './eye';
 export {multinomial} from './multinomial';
 export {oneHot} from './one_hot';
diff --git a/tfjs-core/src/public/chained_ops/div.ts b/tfjs-core/src/public/chained_ops/div.ts
new file mode 100644
index 00000000000..5491e286965
--- /dev/null
+++ b/tfjs-core/src/public/chained_ops/div.ts
@@ -0,0 +1,30 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {div} from '../../ops/div';
+import {Tensor} from '../../tensor';
+import {Rank, TensorLike} from '../../types';
+
+declare module '../../tensor' {
+  interface Tensor<R extends Rank = Rank> {
+    div<T extends Tensor>(b: Tensor|TensorLike): T;
+  }
+}
+
+Tensor.prototype.div = function<T extends Tensor>(b: Tensor|TensorLike): T {
+  return div(this, b);
+};
diff --git a/tfjs-core/src/public/chained_ops/div_no_nan.ts b/tfjs-core/src/public/chained_ops/div_no_nan.ts
new file mode 100644
index 00000000000..ad09451be0f
--- /dev/null
+++ b/tfjs-core/src/public/chained_ops/div_no_nan.ts
@@ -0,0 +1,31 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {divNoNan} from '../../ops/div_no_nan';
+import {Tensor} from '../../tensor';
+import {Rank, TensorLike} from '../../types';
+
+declare module '../../tensor' {
+  interface Tensor<R extends Rank = Rank> {
+    divNoNan<T extends Tensor>(b: Tensor|TensorLike): T;
+  }
+}
+
+Tensor.prototype.divNoNan = function<T extends Tensor>(
+    b: Tensor|TensorLike): T {
+  return divNoNan(this, b);
+};
diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
index 47d2f4fed92..2c4114a7d1b 100644
--- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
+++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
@@ -15,7 +15,9 @@
  * =============================================================================
  */
 
-import './squared_difference';
 import './broadcast_to';
+import './div';
+import './div_no_nan';
+import './squared_difference';
 import './tile';
 import './one_hot';
diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
index bedd64a0518..bc54a8a6122 100644
--- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
+++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
@@ -23,7 +23,8 @@ import {ALL_ENVS, describeWithFlags} from '../../jasmine_util';
 // (And karma will always load the chain augmentor files). But this gives us
 // flexibility to change in future.
 
-const CHAINED_OPS = ['square', 'broadcastTo', 'tile', 'oneHot'];
+const CHAINED_OPS =
+    ['square', 'broadcastTo', 'tile', 'oneHot', 'div', 'divNoNan'];
 
 describeWithFlags('chained ops', ALL_ENVS, () => {
   it('all chained ops should exist on tensor ', async () => {
diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts
index b5af288535a..29532946a12 100644
--- a/tfjs-core/src/register_all_gradients.ts
+++ b/tfjs-core/src/register_all_gradients.ts
@@ -15,6 +15,7 @@
  * =============================================================================
  */
 import {broadcastToGradConfig} from './gradients/BroadcastTo_grad';
+import {divGradConfig} from './gradients/Div_grad';
 import {identityGradConfig} from './gradients/Identity_grad';
 import {oneHotGradConfig} from './gradients/OneHot_grad';
 import {squareGradConfig} from './gradients/Square_grad';
@@ -25,8 +26,8 @@ import {registerGradient} from './kernel_registry';
 
 // Export all kernel configs here so that the package can auto register them
 const gradConfigs: GradConfig[] = [
-  squareGradConfig, squaredDifferenceGradConfig, broadcastToGradConfig,
-  identityGradConfig, tileGradConfig, oneHotGradConfig
+  divGradConfig, squareGradConfig, squaredDifferenceGradConfig,
+  broadcastToGradConfig, identityGradConfig, tileGradConfig, oneHotGradConfig
 ];
 
 for (const gradientConfig of gradConfigs) {
diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts
index 278feb70518..d3294000a01 100644
--- a/tfjs-core/src/tensor.ts
+++ b/tfjs-core/src/tensor.ts
@@ -225,8 +225,6 @@ export interface OpHandler {
   powStrict<T extends Tensor>(base: T, exp: Tensor|TensorLike): T;
   mul<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
   mulStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
-  div<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
-  divNoNan<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
   floorDiv<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
   divStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
   mod<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
@@ -940,14 +938,6 @@ export class Tensor<R extends Rank = Rank> {
     this.throwIfDisposed();
     return opHandler.mulStrict(this, x);
   }
-  div<T extends Tensor>(x: Tensor|TensorLike): T {
-    this.throwIfDisposed();
-    return opHandler.div(this, x);
-  }
-  divNoNan<T extends Tensor>(x: Tensor|TensorLike): T {
-    this.throwIfDisposed();
-    return opHandler.divNoNan(this, x);
-  }
   floorDiv<T extends Tensor>(x: Tensor|TensorLike): T {
     this.throwIfDisposed();
     return opHandler.floorDiv(this, x);
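End to end, the pieces registered in this PR cooperate: `div` dispatches to the backend's registered `Div` kernel, and autodiff finds `divGradConfig` by kernel name instead of the inline `der` closure removed from `binary_ops.ts`. A final sketch (assuming `@tensorflow/tfjs-core` as `tf`):

```ts
import * as tf from '@tensorflow/tfjs-core';

// d(1/x)/dx = -1/x^2, computed through the registered Div gradient.
const g = tf.grad((x: tf.Tensor) => tf.div(tf.scalar(1), x));
g(tf.tensor1d([2, 4])).print();  // [-0.25, -0.0625]
```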