diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts
index f45855309a8..3280e5c089a 100644
--- a/tfjs-core/src/backends/cpu/backend_cpu.ts
+++ b/tfjs-core/src/backends/cpu/backend_cpu.ts
@@ -19,7 +19,6 @@ import * as seedrandom from 'seedrandom';
 
 import {ENGINE} from '../../engine';
 import {env} from '../../environment';
-
 import {warn} from '../../log';
 import * as array_ops_util from '../../ops/array_ops_util';
 import * as axis_util from '../../ops/axis_util';
@@ -27,6 +26,7 @@ import * as broadcast_util from '../../ops/broadcast_util';
 import {complex, imag, real} from '../../ops/complex_ops';
 import * as concat_util from '../../ops/concat_util';
 import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
+import {div} from '../../ops/div';
 import * as erf_util from '../../ops/erf_util';
 import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
 import * as gather_nd_util from '../../ops/gather_nd_util';
@@ -47,6 +47,7 @@ import {split} from '../split_shared';
 import {tile} from '../tile_impl';
 import {topkImpl} from '../topk_impl';
 import {whereImpl} from '../where_impl';
+
 import {assertNotComplex} from './cpu_util';
 
 function mapActivation(
@@ -377,17 +378,6 @@ export class MathBackendCPU extends KernelBackend {
     return result.toTensor() as T;
   }
 
-  softmax<T extends Tensor>(logits: T, dim: number): T {
-    const axes = util.parseAxisParam([dim], logits.shape);
-    const maxLogit = this.max(logits, axes);
-    const expandedShape = axis_util.expandShapeToKeepDim(maxLogit.shape, axes);
-    const a = this.subtract(logits, maxLogit.reshape(expandedShape));
-    const b = this.exp(a);
-    const sumExp = this.sum(b, axes).reshape(expandedShape);
-
-    return this.realDivide(b, sumExp) as T;
-  }
-
   subtract(a: Tensor, b: Tensor): Tensor {
     if (a.dtype === 'complex64' || b.dtype === 'complex64') {
       return this.broadcastedBinaryComplexOp(
@@ -493,14 +483,6 @@ export class MathBackendCPU extends KernelBackend {
         (aValue, bValue) => aValue * bValue);
   }
 
-  realDivide(a: Tensor, b: Tensor): Tensor {
-    assertNotComplex([a, b], 'realDivide');
-
-    const op = (a: number, b: number) => a / b;
-    const outputDtype = 'float32';
-    return this.broadcastedBinaryOp(a, b, outputDtype, op);
-  }
-
   floorDiv(a: Tensor, b: Tensor): Tensor {
     assertNotComplex([a, b], 'floorDiv');
 
diff --git a/tfjs-core/src/backends/cpu/kernels/Div.ts b/tfjs-core/src/backends/cpu/kernels/Div.ts
new file mode 100644
index 00000000000..ca59177542d
--- /dev/null
+++ b/tfjs-core/src/backends/cpu/kernels/Div.ts
@@ -0,0 +1,23 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Div} from '../../../kernel_names';
+import {createBinaryKernelConfig} from '../utils/kernel_utils';
+import {createBinaryKernelImpl} from '../utils/kernel_utils';
+
+export const div = createBinaryKernelImpl((a: number, b: number) => a / b);
+export const divConfig = createBinaryKernelConfig(Div, div);
diff --git a/tfjs-core/src/backends/cpu/kernels/Exp.ts b/tfjs-core/src/backends/cpu/kernels/Exp.ts
new file mode 100644
index 00000000000..a6d40fdd1c4
--- /dev/null
+++ b/tfjs-core/src/backends/cpu/kernels/Exp.ts
@@ -0,0 +1,50 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Exp, ExpInputs} from '../../../kernel_names';
+import {KernelConfig} from '../../../kernel_registry';
+import {TypedArray} from '../../../types';
+import * as util from '../../../util';
+import {MathBackendCPU} from '../backend_cpu';
+
+export const exp = (x: TypedArray): TypedArray => {
+  const outValues = util.getTypedArrayFromDType('float32', x.length);
+
+  for (let i = 0; i < x.length; ++i) {
+    outValues[i] = Math.exp(x[i]);
+  }
+
+  return outValues;
+};
+
+export const expConfig: KernelConfig = {
+  kernelName: Exp,
+  backendName: 'cpu',
+  kernelFunc: ({inputs, backend}) => {
+    const {x} = inputs as ExpInputs;
+    const cpuBackend = backend as MathBackendCPU;
+
+    const xVals = cpuBackend.data.get(x.dataId).values as Float32Array;
+
+    const result = exp(xVals);
+
+    // `exp` always produces float32 values, so write the result as float32
+    // rather than echoing the input dtype.
+    const dataId = cpuBackend.write(result, x.shape, 'float32');
+    return {dataId, shape: x.shape, dtype: 'float32'};
+  }
+};
diff --git a/tfjs-core/src/backends/cpu/kernels/Max.ts b/tfjs-core/src/backends/cpu/kernels/Max.ts
new file mode 100644
index 00000000000..2d1a589da10
--- /dev/null
+++ b/tfjs-core/src/backends/cpu/kernels/Max.ts
@@ -0,0 +1,71 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Max, MaxAttrs, MaxInputs} from '../../../kernel_names';
+import {KernelConfig} from '../../../kernel_registry';
+import * as axis_util from '../../../ops/axis_util';
+import {DataType, NumericDataType, TypedArray} from '../../../types';
+import * as util from '../../../util';
+import {sizeFromShape} from '../../../util';
+import {MathBackendCPU} from '../backend_cpu';
+import {assertNotComplex} from '../cpu_util';
+
+export const max =
+    (x: TypedArray, reduceSize: number, outShape: number[],
+     dtype: DataType): TypedArray => {
+      const outValues = util.getTypedArrayFromDType(
+          dtype as NumericDataType, util.sizeFromShape(outShape));
+
+      // Iterate over the output elements, not the input: each output value
+      // reduces a window of `reduceSize` input values.
+      for (let i = 0; i < outValues.length; ++i) {
+        const offset = i * reduceSize;
+        let max = x[offset];
+        for (let j = 0; j < reduceSize; ++j) {
+          const value = x[offset + j];
+          if (value > max) {
+            max = value;
+          }
+        }
+        outValues[i] = max;
+      }
+
+      return outValues;
+    };
+
+export const maxConfig: KernelConfig = {
+  kernelName: Max,
+  backendName: 'cpu',
+  kernelFunc: ({inputs, attrs, backend}) => {
+    const {x} = inputs as MaxInputs;
+    const {axes} = attrs as {} as MaxAttrs;
+    const cpuBackend = backend as MathBackendCPU;
+
+    assertNotComplex(x, 'max');
+
+    axis_util.assertAxesAreInnerMostDims('max', axes, x.shape.length);
+
+    const [outShape, reduceShape] =
+        axis_util.computeOutAndReduceShapes(x.shape, axes);
+
+    const xVals = cpuBackend.data.get(x.dataId).values as Float32Array;
+    const result = max(xVals, sizeFromShape(reduceShape), outShape, x.dtype);
+
+    const dataId = cpuBackend.write(result, outShape, x.dtype);
+    return {dataId, shape: outShape, dtype: x.dtype};
+  }
+};
diff --git a/tfjs-core/src/backends/cpu/kernels/Softmax.ts b/tfjs-core/src/backends/cpu/kernels/Softmax.ts
new file mode 100644
index 00000000000..8b4ed3f9c13
--- /dev/null
+++ b/tfjs-core/src/backends/cpu/kernels/Softmax.ts
@@ -0,0 +1,66 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Softmax, SoftmaxAttrs, SoftmaxInputs} from '../../../kernel_names';
+import {KernelConfig} from '../../../kernel_registry';
+import * as axis_util from '../../../ops/axis_util';
+import {parseAxisParam, sizeFromShape} from '../../../util';
+import {MathBackendCPU} from '../backend_cpu';
+import {assertNotComplex} from '../cpu_util';
+
+import {div} from './Div';
+import {exp} from './Exp';
+import {max} from './Max';
+import {sub} from './Sub';
+import {sum} from './Sum';
+
+export const softmaxConfig: KernelConfig = {
+  kernelName: Softmax,
+  backendName: 'cpu',
+  kernelFunc: ({inputs, attrs, backend}) => {
+    const {logits} = inputs as SoftmaxInputs;
+    const {dim} = attrs as {} as SoftmaxAttrs;
+    const cpuBackend = backend as MathBackendCPU;
+    assertNotComplex(logits, 'softmax');
+
+    const axes = parseAxisParam([dim], logits.shape);
+
+    const [reduceOutShape, reduceShape] =
+        axis_util.computeOutAndReduceShapes(logits.shape, axes);
+    const logitsValues =
+        cpuBackend.data.get(logits.dataId).values as Float32Array;
+    const maxLogit = max(
+        logitsValues, sizeFromShape(reduceShape), reduceOutShape, logits.dtype);
+
+    const expandedShape = axis_util.expandShapeToKeepDim(reduceOutShape, axes);
+
+    const [aValues] =
+        sub(logits.shape, expandedShape, logitsValues, maxLogit, logits.dtype);
+
+    const b = exp(aValues);
+
+    const sumExp =
+        sum(b, sizeFromShape(reduceShape), reduceOutShape, logits.dtype);
+
+    // sumExp has one value per reduced slice, so broadcast it against the
+    // logits using the keep-dims shape, matching the `sub` call above.
+    const [resultData, resultShape] =
+        div(logits.shape, expandedShape, b, sumExp, logits.dtype);
+    const dataId = cpuBackend.write(resultData, resultShape, logits.dtype);
+    return {dataId, shape: resultShape, dtype: logits.dtype};
+  }
+};
diff --git a/tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts b/tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts
index a57bf1a156a..f7ad0205821 100644
--- a/tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts
+++ b/tfjs-core/src/backends/cpu/kernels/SquaredDifference.ts
@@ -15,31 +15,14 @@
  * =============================================================================
  */
 
-import {SquaredDifference, SquaredDifferenceInputs} from '../../../kernel_names';
-import {KernelConfig} from '../../../kernel_registry';
-import {TypedArray} from '../../../types';
-import {MathBackendCPU} from '../backend_cpu';
-import {assertNotComplex} from '../cpu_util';
-import {broadcastedBinaryOp} from '../utils/kernel_utils';
+import {SquaredDifference} from '../../../kernel_names';
+import {createBinaryKernelImpl} from '../utils/kernel_utils';
+import {createBinaryKernelConfig} from '../utils/kernel_utils';
 
-export const squaredDifferenceConfig: KernelConfig = {
-  kernelName: SquaredDifference,
-  backendName: 'cpu',
-  kernelFunc: ({inputs, backend}) => {
-    const {a, b} = inputs as SquaredDifferenceInputs;
-    const cpuBackend = backend as MathBackendCPU;
-    assertNotComplex([a, b], SquaredDifference);
+const squaredDifferenceImpl = createBinaryKernelImpl((aVal, bVal) => {
+  const diff = aVal - bVal;
+  return diff * diff;
+});
 
-    const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;
-    const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;
-
-    const [resultData, resultShape] = broadcastedBinaryOp(
-        a.shape, b.shape, aVals, bVals, a.dtype, (aVal, bVal) => {
-          const diff = aVal - bVal;
-          return diff * diff;
-        });
-
-    const dataId = cpuBackend.write(resultData, resultShape, a.dtype);
-    return {dataId, shape: resultShape, dtype: a.dtype};
-  }
-};
+export const squaredDifferenceConfig =
+    createBinaryKernelConfig(SquaredDifference, squaredDifferenceImpl);
diff --git a/tfjs-core/src/backends/cpu/kernels/Sub.ts b/tfjs-core/src/backends/cpu/kernels/Sub.ts
new file mode 100644
index 00000000000..f37ebde371f
--- /dev/null
+++ b/tfjs-core/src/backends/cpu/kernels/Sub.ts
@@ -0,0 +1,24 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Sub} from '../../../kernel_names';
+import {createBinaryKernelConfig} from '../utils/kernel_utils';
+import {createBinaryKernelImpl} from '../utils/kernel_utils';
+
+export const sub = createBinaryKernelImpl((a: number, b: number) => a - b);
+
+export const subConfig = createBinaryKernelConfig(Sub, sub);
diff --git a/tfjs-core/src/backends/cpu/kernels/Sum.ts b/tfjs-core/src/backends/cpu/kernels/Sum.ts
new file mode 100644
index 00000000000..7a567ef3d9e
--- /dev/null
+++ b/tfjs-core/src/backends/cpu/kernels/Sum.ts
@@ -0,0 +1,72 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Sum, SumAttrs, SumInputs} from '../../../kernel_names';
+import {KernelConfig} from '../../../kernel_registry';
+import * as axis_util from '../../../ops/axis_util';
+import {upcastType} from '../../../types';
+import {DataType, NumericDataType, TypedArray} from '../../../types';
+import * as util from '../../../util';
+import {sizeFromShape} from '../../../util';
+import {MathBackendCPU} from '../backend_cpu';
+import {assertNotComplex} from '../cpu_util';
+
+export const sum =
+    (x: TypedArray, reduceSize: number, outShape: number[],
+     dtype: DataType): TypedArray => {
+      const outValues = util.getTypedArrayFromDType(
+          dtype as NumericDataType, util.sizeFromShape(outShape));
+
+      // One output value per window of `reduceSize` input values.
+      for (let i = 0; i < outValues.length; ++i) {
+        const offset = i * reduceSize;
+        let sum = 0;
+        for (let j = 0; j < reduceSize; ++j) {
+          const value = x[offset + j];
+          sum += value;
+        }
+        outValues[i] = sum;
+      }
+
+      return outValues;
+    };
+
+export const sumConfig: KernelConfig = {
+  kernelName: Sum,
+  backendName: 'cpu',
+  kernelFunc: ({inputs, attrs, backend}) => {
+    const {x} = inputs as SumInputs;
+    const {axes} = attrs as {} as SumAttrs;
+    const cpuBackend = backend as MathBackendCPU;
+
+    assertNotComplex(x, 'sum');
+
+    axis_util.assertAxesAreInnerMostDims('sum', axes, x.shape.length);
+
+    const [outShape, reduceShape] =
+        axis_util.computeOutAndReduceShapes(x.shape, axes);
+    const resultDtype = upcastType(x.dtype, 'int32');
+
+    const xVals = cpuBackend.data.get(x.dataId).values as Float32Array;
+    // Accumulate into the upcast result dtype so bool/int inputs sum correctly.
+    const result =
+        sum(xVals, sizeFromShape(reduceShape), outShape, resultDtype);
+
+    const dataId = cpuBackend.write(result, outShape, resultDtype);
+    return {dataId, shape: outShape, dtype: resultDtype};
+  }
+};
diff --git a/tfjs-core/src/backends/cpu/register_all_kernels.ts b/tfjs-core/src/backends/cpu/register_all_kernels.ts
index 9516f9af311..ef9282a38a2 100644
--- a/tfjs-core/src/backends/cpu/register_all_kernels.ts
+++ b/tfjs-core/src/backends/cpu/register_all_kernels.ts
@@ -19,15 +19,16 @@
 // the contents of this file and import only the kernels that are needed.
 import {KernelConfig, registerKernel} from '../../kernel_registry';
 
+import {divConfig} from './kernels/Div';
 import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
+import {softmaxConfig} from './kernels/Softmax';
 import {squareConfig} from './kernels/Square';
 import {squaredDifferenceConfig} from './kernels/SquaredDifference';
 
 // List all kernel configs here
 const kernelConfigs: KernelConfig[] = [
-  nonMaxSuppressionV5Config,
-  squareConfig,
-  squaredDifferenceConfig,
+  nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig,
+  softmaxConfig
 ];
 
 for (const kernelConfig of kernelConfigs) {
diff --git a/tfjs-core/src/backends/cpu/utils/kernel_utils.ts b/tfjs-core/src/backends/cpu/utils/kernel_utils.ts
index ce21517dba4..68279392173 100644
--- a/tfjs-core/src/backends/cpu/utils/kernel_utils.ts
+++ b/tfjs-core/src/backends/cpu/utils/kernel_utils.ts
@@ -16,50 +16,80 @@
  */
 
 import * as backend_util from '../../../backends/backend_util';
+import {BinaryInputs} from '../../../kernel_names';
+import {KernelConfig} from '../../../kernel_registry';
 import {DataType, NumericDataType, TypedArray} from '../../../types';
 import * as util from '../../../util';
+import {MathBackendCPU} from '../backend_cpu';
+import {assertNotComplex} from '../cpu_util';
 
-export function broadcastedBinaryOp(
-    aShape: number[], bShape: number[], aVals: TypedArray, bVals: TypedArray,
-    dtype: DataType,
-    op: (a: number, b: number) => number): [TypedArray, number[]] {
-  const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
+export function createBinaryKernelConfig(
+    name: string,
+    op: (
+        aShape: number[], bShape: number[], aVals: TypedArray,
+        bVals: TypedArray,
+        dtype: DataType) => [TypedArray, number[]]): KernelConfig {
+  return {
+    kernelName: name,
+    backendName: 'cpu',
+    kernelFunc: ({inputs, backend}) => {
+      const {a, b} = inputs as BinaryInputs;
+      const cpuBackend = backend as MathBackendCPU;
+      assertNotComplex([a, b], name);
 
-  const resultRank = newShape.length;
-  const resultStrides = util.computeStrides(newShape);
-  const resultSize = util.sizeFromShape(newShape);
+      const aVals = cpuBackend.data.get(a.dataId).values as TypedArray;
+      const bVals = cpuBackend.data.get(b.dataId).values as TypedArray;
 
-  const result =
-      util.getTypedArrayFromDType(dtype as NumericDataType, resultSize);
+      const [resultData, resultShape] =
+          op(a.shape, b.shape, aVals, bVals, a.dtype);
 
-  const aRank = aShape.length;
-  const bRank = bShape.length;
+      const dataId = cpuBackend.write(resultData, resultShape, a.dtype);
+      return {dataId, shape: resultShape, dtype: a.dtype};
+    }
+  };
+}
 
-  const aStrides = util.computeStrides(aShape);
-  const bStrides = util.computeStrides(bShape);
+export function createBinaryKernelImpl(op: (a: number, b: number) => number) {
+  return (aShape: number[], bShape: number[], aVals: TypedArray,
+          bVals: TypedArray, dtype: DataType): [TypedArray, number[]] => {
+    const newShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
 
-  const aBroadcastDims = backend_util.getBroadcastDims(aShape, newShape);
-  const bBroadcastDims = backend_util.getBroadcastDims(bShape, newShape);
+    const resultRank = newShape.length;
+    const resultStrides = util.computeStrides(newShape);
+    const resultSize = util.sizeFromShape(newShape);
 
-  if (aBroadcastDims.length + bBroadcastDims.length === 0) {
-    for (let i = 0; i < result.length; ++i) {
-      result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
-    }
-  } else {
-    for (let i = 0; i < result.length; ++i) {
      const loc = util.indexToLoc(i, resultRank, resultStrides);
+    const result =
+        util.getTypedArrayFromDType(dtype as NumericDataType, resultSize);
+
+    const aRank = aShape.length;
+    const bRank = bShape.length;
+
+    const aStrides = util.computeStrides(aShape);
+    const bStrides = util.computeStrides(bShape);
+
+    const aBroadcastDims = backend_util.getBroadcastDims(aShape, newShape);
+    const bBroadcastDims = backend_util.getBroadcastDims(bShape, newShape);
+
+    if (aBroadcastDims.length + bBroadcastDims.length === 0) {
+      for (let i = 0; i < result.length; ++i) {
+        result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
+      }
+    } else {
+      for (let i = 0; i < result.length; ++i) {
+        const loc = util.indexToLoc(i, resultRank, resultStrides);
 
-      const aLoc = loc.slice(-aRank);
-      aBroadcastDims.forEach(d => aLoc[d] = 0);
-      const aIndex = util.locToIndex(aLoc, aRank, aStrides);
+        const aLoc = loc.slice(-aRank);
+        aBroadcastDims.forEach(d => aLoc[d] = 0);
+        const aIndex = util.locToIndex(aLoc, aRank, aStrides);
 
-      const bLoc = loc.slice(-bRank);
-      bBroadcastDims.forEach(d => bLoc[d] = 0);
-      const bIndex = util.locToIndex(bLoc, bRank, bStrides);
+        const bLoc = loc.slice(-bRank);
+        bBroadcastDims.forEach(d => bLoc[d] = 0);
+        const bIndex = util.locToIndex(bLoc, bRank, bStrides);
 
-      result[i] = op(aVals[aIndex], bVals[bIndex]);
+        result[i] = op(aVals[aIndex], bVals[bIndex]);
+      }
     }
-  }
-  return [result, newShape];
+    return [result, newShape];
+  };
 }
diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts
index f513265c7aa..64d29147b21 100644
--- a/tfjs-core/src/backends/webgl/backend_webgl.ts
+++ b/tfjs-core/src/backends/webgl/backend_webgl.ts
@@ -30,6 +30,7 @@ import * as axis_util from '../../ops/axis_util';
 import {complex, imag, real} from '../../ops/complex_ops';
 import {computeOutShape} from '../../ops/concat_util';
 import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
+import {div} from '../../ops/div';
 import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
 import * as gather_nd_util from '../../ops/gather_nd_util';
 import * as reduce_util from '../../ops/reduce_util';
@@ -1372,18 +1373,6 @@ export class MathBackendWebGL extends KernelBackend {
     return this.reduce(a2D, 'any', a2D.dtype).reshape(outShape);
   }
 
-  realDivide(a: Tensor, b: Tensor): Tensor {
-    const op = binaryop_gpu.DIV;
-    const outputDtype = 'float32';
-    if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
-      const checkOutOfBounds = true;
-      return this.packedBinaryOp(
-          a, b, binaryop_packed_gpu.DIV, outputDtype, checkOutOfBounds);
-    }
-    const program = new BinaryOpProgram(op, a.shape, b.shape);
-    return this.compileAndRun(program, [a, b], outputDtype);
-  }
-
   floorDiv(a: Tensor, b: Tensor): Tensor {
     const op = binaryop_gpu.INT_DIV;
     const outputDtype = 'int32';
@@ -1597,7 +1586,9 @@ export class MathBackendWebGL extends KernelBackend {
     const b = this.exp(a);
     const sumExp = this.sum(b, axes).reshape(expandedShape);
 
-    return this.realDivide(b, sumExp) as T;
+    // TODO(annxingyuan): Call divImpl rather than op as part of softmax kernel
+    // modularization.
+    return div(b, sumExp);
   }
 
   log<T extends Tensor>(x: T): T {
diff --git a/tfjs-core/src/backends/webgl/kernels/Div.ts b/tfjs-core/src/backends/webgl/kernels/Div.ts
new file mode 100644
index 00000000000..05ce2c2717b
--- /dev/null
+++ b/tfjs-core/src/backends/webgl/kernels/Div.ts
@@ -0,0 +1,50 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {env} from '../../../environment';
+import {Div, DivInputs} from '../../../kernel_names';
+import {KernelConfig, TensorInfo} from '../../../kernel_registry';
+import {MathBackendWebGL} from '../backend_webgl';
+import * as binaryop_gpu from '../binaryop_gpu';
+import {BinaryOpProgram} from '../binaryop_gpu';
+import * as binaryop_packed_gpu from '../binaryop_packed_gpu';
+import {BinaryOpPackedProgram} from '../binaryop_packed_gpu';
+
+export function divImpl(
+    a: TensorInfo, b: TensorInfo, backend: MathBackendWebGL): TensorInfo {
+  let program = new BinaryOpProgram(binaryop_gpu.DIV, a.shape, b.shape);
+  if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
+    program = new BinaryOpPackedProgram(
+        binaryop_packed_gpu.DIV, a.shape, b.shape, true);
+  }
+  const output = backend.runWebGLProgram(program, [a, b], 'float32');
+  return output;
+}
+
+export const divConfig: KernelConfig = {
+  kernelName: Div,
+  backendName: 'webgl',
+  kernelFunc: ({inputs, backend}) => {
+    const {a, b} = inputs as DivInputs;
+
+    const webglBackend = backend as MathBackendWebGL;
+
+    const out = divImpl(a, b, webglBackend);
+
+    return {dataId: out.dataId, shape: out.shape, dtype: out.dtype};
+  }
+};
diff --git a/tfjs-core/src/backends/webgl/kernels/SquaredDifference.ts b/tfjs-core/src/backends/webgl/kernels/SquaredDifference.ts
index 41463b2c1cb..9ff08538a74 100644
--- a/tfjs-core/src/backends/webgl/kernels/SquaredDifference.ts
+++ b/tfjs-core/src/backends/webgl/kernels/SquaredDifference.ts
@@ -16,7 +16,7 @@
  */
 
 import {env} from '../../../environment';
-import {SquaredDifference, SquaredDifferenceInputs} from '../../../kernel_names';
+import {BinaryInputs, SquaredDifference} from '../../../kernel_names';
 import {KernelConfig} from '../../../kernel_registry';
 import {MathBackendWebGL} from '../backend_webgl';
 import {BinaryOpProgram} from '../binaryop_gpu';
@@ -26,7 +26,7 @@ export const squaredDifferenceConfig: KernelConfig = {
   kernelName: SquaredDifference,
   backendName: 'webgl',
   kernelFunc: ({inputs, backend}) => {
-    const {a, b} = inputs as SquaredDifferenceInputs;
+    const {a, b} = inputs as BinaryInputs;
 
     const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
     const webGLBackend = backend as MathBackendWebGL;
diff --git a/tfjs-core/src/backends/webgl/register_all_kernels.ts b/tfjs-core/src/backends/webgl/register_all_kernels.ts
index f7913ac6b21..fe04734d2cb 100644
--- a/tfjs-core/src/backends/webgl/register_all_kernels.ts
+++ b/tfjs-core/src/backends/webgl/register_all_kernels.ts
@@ -16,6 +16,7 @@
  */
 
 import {KernelConfig, registerKernel} from '../../kernel_registry';
+import {divConfig} from './kernels/Div';
 import {fromPixelsConfig} from './kernels/FromPixels';
 import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
 import {squareConfig} from './kernels/Square';
@@ -24,6 +25,7 @@ import {squaredDifferenceConfig} from './kernels/SquaredDifference';
 
 // List all kernel configs here
 const kernelConfigs: KernelConfig[] = [
   fromPixelsConfig,
+  divConfig,
   nonMaxSuppressionV5Config,
   squareConfig,
   squaredDifferenceConfig,
diff --git a/tfjs-core/src/gradients/Div_grad.ts b/tfjs-core/src/gradients/Div_grad.ts
new file mode 100644
index 00000000000..8b4fc15d985
--- /dev/null
+++ b/tfjs-core/src/gradients/Div_grad.ts
@@ -0,0 +1,50 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Div} from '../kernel_names';
+import {GradConfig} from '../kernel_registry';
+import * as broadcast_util from '../ops/broadcast_util';
+import {div} from '../ops/div';
+import {Tensor} from '../tensor';
+
+export const divGradConfig: GradConfig = {
+  kernelName: Div,
+  inputsToSave: ['a', 'b'],
+  gradFunc: (dy: Tensor, saved: Tensor[]) => {
+    const [a, b] = saved;
+    const outShape =
+        broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape);
+    const derA = () => {
+      const res = div(dy, b.toFloat());
+      const reduceAxes = broadcast_util.getReductionAxes(a.shape, outShape);
+      if (reduceAxes.length > 0) {
+        return res.sum(reduceAxes).reshape(a.shape);
+      }
+      return res;
+    };
+    const derB = () => {
+      let res = dy.mul(a.toFloat());
+      const reduceAxes = broadcast_util.getReductionAxes(b.shape, outShape);
+      if (reduceAxes.length > 0) {
+        res = res.sum(reduceAxes).reshape(b.shape);
+      }
+      const tmp = b.square();
+      return div(res, tmp.toFloat()).neg();
+    };
+    return {a: derA, b: derB};
+  }
+};
diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts
index 201d73ac05f..13acbeaf7a0 100644
--- a/tfjs-core/src/kernel_names.ts
+++ b/tfjs-core/src/kernel_names.ts
@@ -21,12 +21,40 @@
 import {NamedTensorInfoMap} from './kernel_registry';
 import {PixelData} from './types';
 
+export type BinaryInputs = Pick<NamedTensorInfoMap, 'a'|'b'>;
+
+export const Div = 'Div';
+export type DivInputs = BinaryInputs;
+
+export const Exp = 'Exp';
+export type ExpInputs = Pick<NamedTensorInfoMap, 'x'>;
+
+export const Max = 'Max';
+export type MaxInputs = Pick<NamedTensorInfoMap, 'x'>;
+export interface MaxAttrs {
+  axes: number[];
+}
+
 export const SquaredDifference = 'SquaredDifference';
-export type SquaredDifferenceInputs = Pick<NamedTensorInfoMap, 'a'|'b'>;
+export type SquaredDifferenceInputs = BinaryInputs;
 
 export const Square = 'Square';
 export type SquareInputs = Pick<NamedTensorInfoMap, 'x'>;
 
+export const Softmax = 'Softmax';
+export type SoftmaxInputs = Pick<NamedTensorInfoMap, 'logits'>;
+export interface SoftmaxAttrs {
+  dim: number;
+}
+
+export const Sub = 'Sub';
+
+export const Sum = 'Sum';
+export type SumInputs = Pick<NamedTensorInfoMap, 'x'>;
+export interface SumAttrs {
+  axes: number[];
+}
+
 export const NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
 export type NonMaxSuppressionV5Inputs =
     Pick<NamedTensorInfoMap, 'boxes'|'scores'>;
diff --git a/tfjs-core/src/ops/binary_ops.ts b/tfjs-core/src/ops/binary_ops.ts
index 9aea464bea4..ac5e0525ce4 100644
--- a/tfjs-core/src/ops/binary_ops.ts
+++ b/tfjs-core/src/ops/binary_ops.ts
@@ -23,7 +23,6 @@ import {convertToTensor} from '../tensor_util_env';
 import {TensorLike} from '../types';
 import * as util from '../util';
 import * as broadcast_util from './broadcast_util';
-import {where} from './logical_ops';
 import {op} from './operation';
 import {scalar, zerosLike} from './tensor_ops';
 import {neg} from './unary_ops';
@@ -378,114 +377,6 @@ function mulStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
   return $a.mul($b);
 }
 
-/**
- * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
- *
- * We also expose `tf.divStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 9, 16]);
- * const b = tf.tensor1d([1, 2, 3, 4]);
- *
- * a.div(b).print();  // or tf.div(a, b)
- * ```
- *
- * ```js
- * // Broadcast div a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(2);
- *
- * a.div(b).print();  // or tf.div(a, b)
- * ```
- *
- * @param a The first tensor as the numerator.
- * @param b The second tensor as the denominator. Must have the same dtype as
- * `a`.
- */
-/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
-function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  let $a = convertToTensor(a, 'a', 'div');
-  let $b = convertToTensor(b, 'b', 'div');
-  [$a, $b] = makeTypesMatch($a, $b);
-
-  if ($a.dtype === 'int32' && $b.dtype === 'int32') {
-    return floorDiv($a, $b);
-  }
-
-  const outShape =
-      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
-  const der = (dy: Tensor, saved: Tensor[]) => {
-    const [$a, $b] = saved;
-    const derA = () => {
-      const res = dy.div($b.toFloat());
-      const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);
-      if (reduceAxes.length > 0) {
-        return res.sum(reduceAxes).reshape($a.shape);
-      }
-      return res;
-    };
-    const derB = () => {
-      let res = dy.mul($a.toFloat());
-      const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);
-      if (reduceAxes.length > 0) {
-        res = res.sum(reduceAxes).reshape($b.shape);
-      }
-      const tmp = $b.square();
-      return res.div(tmp.toFloat()).neg();
-    };
-    return {a: derA, b: derB};
-  };
-  return ENGINE.runKernelFunc((backend, save) => {
-    const res = backend.realDivide($a, $b);
-    save([$a, $b]);
-    return res;
-  }, {a: $a, b: $b}, der, 'Div') as T;
-}
-
-/**
- * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Return 0
- * if denominator is 0.
- *
- * We also expose `tf.divStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 9, 16]);
- * const b = tf.tensor1d([1, 2, 3, 4]);
- * const c = tf.tensor1d([0, 0, 0, 0]);
- *
- * a.divNoNan(b).print();  // or tf.divNoNan(a, b)
- * a.divNoNan(c).print();  // or tf.divNoNan(a, c)
- * ```
- *
- * ```js
- * // Broadcast div a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(2);
- * const c = tf.scalar(0);
- *
- * a.divNoNan(b).print();  // or tf.divNoNan(a, b)
- * a.divNoNan(c).print();  // or tf.divNoNan(a, c)
- * ```
- *
- * @param a The first tensor as the numerator.
- * @param b The second tensor as the denominator. Must have the same dtype as
- * `a`.
- */
-/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
-function divNoNan_<T extends Tensor>(
-    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
-  let $a = convertToTensor(a, 'a', 'div');
-  let $b = convertToTensor(b, 'b', 'div');
-  [$a, $b] = makeTypesMatch($a, $b);
-
-  const divResult = div($a, $b);
-  const zeros = zerosLike(divResult);
-  const bEqualsZero = $b.equal(zeros);
-  return where(bEqualsZero, zeros, divResult) as T;
-}
-
 /**
  * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
  * The result is rounded with floor function.
@@ -841,8 +732,6 @@ export const add = op({add_});
 export const addN = op({addN_});
 export const addStrict = op({addStrict_});
 export const atan2 = op({atan2_});
-export const div = op({div_});
-export const divNoNan = op({divNoNan_});
 export const divStrict = op({divStrict_});
 export const floorDiv = op({floorDiv_});
 export const maximum = op({maximum_});
diff --git a/tfjs-core/src/ops/div.ts b/tfjs-core/src/ops/div.ts
new file mode 100644
index 00000000000..194202f343a
--- /dev/null
+++ b/tfjs-core/src/ops/div.ts
@@ -0,0 +1,102 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ENGINE, ForwardFunc} from '../engine';
+import {Div, DivInputs} from '../kernel_names';
+import {Tensor} from '../tensor';
+import {NamedTensorMap} from '../tensor_types';
+import {makeTypesMatch} from '../tensor_util';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {floorDiv} from './binary_ops';
+import * as broadcast_util from './broadcast_util';
+import {op} from './operation';
+
+/**
+ * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
+ *
+ * We also expose `tf.divStrict` which has the same signature as this op and
+ * asserts that `a` and `b` are the same shape (does not broadcast).
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 9, 16]);
+ * const b = tf.tensor1d([1, 2, 3, 4]);
+ *
+ * a.div(b).print();  // or tf.div(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast div a with b.
+ * const a = tf.tensor1d([2, 4, 6, 8]);
+ * const b = tf.scalar(2);
+ *
+ * a.div(b).print();  // or tf.div(a, b)
+ * ```
+ *
+ * @param a The first tensor as the numerator.
+ * @param b The second tensor as the denominator. Must have the same dtype as
+ * `a`.
+ */
+/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
+function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
+  let $a = convertToTensor(a, 'a', 'div');
+  let $b = convertToTensor(b, 'b', 'div');
+  [$a, $b] = makeTypesMatch($a, $b);
+
+  if ($a.dtype === 'int32' && $b.dtype === 'int32') {
+    return floorDiv($a, $b);
+  }
+
+  const outShape =
+      broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
+  const der = (dy: Tensor, saved: Tensor[]) => {
+    const [$a, $b] = saved;
+    const derA = () => {
+      const res = dy.div($b.toFloat());
+      const reduceAxes = broadcast_util.getReductionAxes($a.shape, outShape);
+      if (reduceAxes.length > 0) {
+        return res.sum(reduceAxes).reshape($a.shape);
+      }
+      return res;
+    };
+    const derB = () => {
+      let res = dy.mul($a.toFloat());
+      const reduceAxes = broadcast_util.getReductionAxes($b.shape, outShape);
+      if (reduceAxes.length > 0) {
+        res = res.sum(reduceAxes).reshape($b.shape);
+      }
+      const tmp = $b.square();
+      return res.div(tmp.toFloat()).neg();
+    };
+    return {a: derA, b: derB};
+  };
+
+  const forward: ForwardFunc<Tensor> = (backend, save) => {
+    const res = backend.realDivide($a, $b);
+    save([$a, $b]);
+    return res;
+  };
+
+  const inputs: DivInputs = {a: $a, b: $b};
+  const attrs = {};
+
+  return ENGINE.runKernelFunc(
+             forward, inputs as {} as NamedTensorMap, der, Div, attrs) as T;
+}
+
+export const div = op({div_});
diff --git a/tfjs-core/src/ops/divNoNan.ts b/tfjs-core/src/ops/divNoNan.ts
new file mode 100644
index 00000000000..d329c28ec48
--- /dev/null
+++ b/tfjs-core/src/ops/divNoNan.ts
@@ -0,0 +1,71 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Tensor} from '../tensor';
+import {makeTypesMatch} from '../tensor_util';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
+
+import {div} from './div';
+import {where} from './logical_ops';
+import {op} from './operation';
+import {zerosLike} from './tensor_ops';
+
+/**
+ * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Return 0
+ * if denominator is 0.
+ *
+ * We also expose `tf.divStrict` which has the same signature as this op and
+ * asserts that `a` and `b` are the same shape (does not broadcast).
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 9, 16]);
+ * const b = tf.tensor1d([1, 2, 3, 4]);
+ * const c = tf.tensor1d([0, 0, 0, 0]);
+ *
+ * a.divNoNan(b).print();  // or tf.divNoNan(a, b)
+ * a.divNoNan(c).print();  // or tf.divNoNan(a, c)
+ * ```
+ *
+ * ```js
+ * // Broadcast div a with b.
+ * const a = tf.tensor1d([2, 4, 6, 8]);
+ * const b = tf.scalar(2);
+ * const c = tf.scalar(0);
+ *
+ * a.divNoNan(b).print();  // or tf.divNoNan(a, b)
+ * a.divNoNan(c).print();  // or tf.divNoNan(a, c)
+ * ```
+ *
+ * @param a The first tensor as the numerator.
+ * @param b The second tensor as the denominator. Must have the same dtype as
+ * `a`.
+ */
+/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
+function divNoNan_<T extends Tensor>(
+    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
+  let $a = convertToTensor(a, 'a', 'div');
+  let $b = convertToTensor(b, 'b', 'div');
+  [$a, $b] = makeTypesMatch($a, $b);
+
+  const divResult = div($a, $b);
+  const zeros = zerosLike(divResult);
+  const bEqualsZero = $b.equal(zeros);
+  return where(bEqualsZero, zeros, divResult) as T;
+}
+
+export const divNoNan = op({divNoNan_});
diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts
index bbb169d6293..adf20011993 100644
--- a/tfjs-core/src/ops/ops.ts
+++ b/tfjs-core/src/ops/ops.ts
@@ -18,6 +18,8 @@
 // Modularized ops.
 export {broadcastTo} from './broadcast_to';
 export {clone} from './clone';
+export {div} from './div';
+export {divNoNan} from './divNoNan';
 export {multinomial} from './multinomial';
 export {rand} from './rand';
 export {randomGamma} from './random_gamma';
diff --git a/tfjs-core/src/ops/reduce_util.ts b/tfjs-core/src/ops/reduce_util.ts
index cb9f27d158a..a772aa03372 100644
--- a/tfjs-core/src/ops/reduce_util.ts
+++ b/tfjs-core/src/ops/reduce_util.ts
@@ -19,7 +19,7 @@
 /**
  * Inputs of size above this threshold will be parallelized by calling multiple
  * shader programs.
  */
-import {nearestDivisor} from '../util';
+import {nearestDivisor, sizeFromShape} from '../util';
 
 export const PARALLELIZE_THRESHOLD = 30;
 
@@ -35,3 +35,13 @@ export function computeOptimalWindowSize(inSize: number): number {
   }
   return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
 }
+
+export function reduceOutShapeFromInShape(
+    xShape: number[], reduceShape: number[]): number[] {
+  const xSize = sizeFromShape(xShape);
+  const inSize = sizeFromShape(reduceShape);
+  const batchSize = xSize / inSize;
+  const windowSize = computeOptimalWindowSize(inSize);
+  const outSize = Math.ceil(inSize / windowSize);
+  return [batchSize, outSize];
+}
diff --git a/tfjs-core/src/ops/squared_difference.ts b/tfjs-core/src/ops/squared_difference.ts
index 0653191326c..264fdb2d9f9 100644
--- a/tfjs-core/src/ops/squared_difference.ts
+++ b/tfjs-core/src/ops/squared_difference.ts
@@ -16,7 +16,7 @@
  */
 
 import {ENGINE, ForwardFunc} from '../engine';
-import {SquaredDifference, SquaredDifferenceInputs} from '../kernel_names';
+import {BinaryInputs, SquaredDifference} from '../kernel_names';
 import {Tensor} from '../tensor';
 import {NamedTensorMap} from '../tensor_types';
 import {makeTypesMatch} from '../tensor_util';
@@ -27,6 +27,7 @@
 import {assertAndGetBroadcastShape} from './broadcast_util';
 import {op} from './operation';
 import {scalar} from './tensor_ops';
 
+
 /**
  * Returns (a - b) * (a - b) element-wise.
  * Supports broadcasting.
@@ -74,7 +75,7 @@ function squaredDifference_<T extends Tensor>(
     return res;
   };
 
-  const inputs: SquaredDifferenceInputs = {a: $a, b: $b};
+  const inputs: BinaryInputs = {a: $a, b: $b};
   const attrs = {};
 
   const inputsToSave = [$a, $b];
diff --git a/tfjs-core/src/public/chained_ops/div.ts b/tfjs-core/src/public/chained_ops/div.ts
new file mode 100644
index 00000000000..5491e286965
--- /dev/null
+++ b/tfjs-core/src/public/chained_ops/div.ts
@@ -0,0 +1,30 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {div} from '../../ops/div';
+import {Tensor} from '../../tensor';
+import {Rank, TensorLike} from '../../types';
+
+declare module '../../tensor' {
+  interface Tensor<R extends Rank = Rank> {
+    div<T extends Tensor>(b: Tensor|TensorLike): T;
+  }
+}
+
+Tensor.prototype.div = function<T extends Tensor>(b: Tensor|TensorLike): T {
+  return div(this, b);
+};
diff --git a/tfjs-core/src/public/chained_ops/divNoNan.ts b/tfjs-core/src/public/chained_ops/divNoNan.ts
new file mode 100644
index 00000000000..b1fce46b8e0
--- /dev/null
+++ b/tfjs-core/src/public/chained_ops/divNoNan.ts
@@ -0,0 +1,31 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {divNoNan} from '../../ops/divNoNan';
+import {Tensor} from '../../tensor';
+import {Rank, TensorLike} from '../../types';
+
+declare module '../../tensor' {
+  interface Tensor<R extends Rank = Rank> {
+    divNoNan<T extends Tensor>(b: Tensor|TensorLike): T;
+  }
+}
+
+Tensor.prototype.divNoNan = function<T extends Tensor>(
+    b: Tensor|TensorLike): T {
+  return divNoNan(this, b);
+};
diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
index 3a79a8dd6d9..bf68f168aa0 100644
--- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
+++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
@@ -15,5 +15,7 @@
  * =============================================================================
  */
 
-import './squared_difference';
 import './broadcast_to';
+import './div';
+import './divNoNan';
+import './squared_difference';
diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
index 0ce556c9aa0..e111b20a729 100644
--- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
+++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
@@ -24,8 +24,10 @@ import {ALL_ENVS, describeWithFlags} from '../../jasmine_util';
 
 // flexibility to change in future.
 const CHAINED_OPS = [
-  'square',
   'broadcastTo',
+  'div',
+  'divNoNan',
+  'square',
 ];
 
 describeWithFlags('chained ops', ALL_ENVS, () => {
diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts
index f7a934358a1..c6dbd465fb4 100644
--- a/tfjs-core/src/register_all_gradients.ts
+++ b/tfjs-core/src/register_all_gradients.ts
@@ -15,6 +15,7 @@
  * =============================================================================
  */
 import {broadcastToGradConfig} from './gradients/BroadcastTo_grad';
+import {divGradConfig} from './gradients/Div_grad';
 import {identityGradConfig} from './gradients/Identity_grad';
 import {squareGradConfig} from './gradients/Square_grad';
 import {squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad';
@@ -25,6 +26,7 @@ import {registerGradient} from './kernel_registry';
 const gradConfigs: GradConfig[] = [
   squareGradConfig,
   squaredDifferenceGradConfig,
+  divGradConfig,
   broadcastToGradConfig,
   identityGradConfig,
 ];
diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts
index cee26c4cde3..b4606e44bcc 100644
--- a/tfjs-core/src/tensor.ts
+++ b/tfjs-core/src/tensor.ts
@@ -229,8 +229,6 @@ export interface OpHandler {
   powStrict<T extends Tensor>(base: T, exp: Tensor|TensorLike): T;
   mul<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
   mulStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
-  div<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
-  divNoNan<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
   floorDiv<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
   divStrict<T extends Tensor>(a: T, b: T|TensorLike): T;
   mod<T extends Tensor>(a: Tensor, b: Tensor|TensorLike): T;
@@ -955,14 +953,6 @@ export class Tensor {
     this.throwIfDisposed();
     return opHandler.mulStrict(this, x);
   }
-  div<T extends Tensor>(x: Tensor|TensorLike): T {
-    this.throwIfDisposed();
-    return opHandler.div(this, x);
-  }
-  divNoNan<T extends Tensor>(x: Tensor|TensorLike): T {
-    this.throwIfDisposed();
-    return opHandler.divNoNan(this, x);
-  }
   floorDiv<T extends Tensor>(x: Tensor|TensorLike): T {
     this.throwIfDisposed();
     return opHandler.floorDiv(this, x);
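
---

Reviewer notes (not part of the patch):

The helper pair added in kernel_utils.ts splits each binary kernel into two layers: `createBinaryKernelImpl` captures only the scalar math plus broadcasting and operates on raw `TypedArray`s, so other kernels (the new CPU Softmax) can reuse it without a tensor round-trip, while `createBinaryKernelConfig` supplies the backend plumbing and yields a registerable `KernelConfig`. A minimal sketch of how a future binary kernel would use the pair; the `Pow` kernel and file below are hypothetical, not part of this diff:

```ts
// Hypothetical tfjs-core/src/backends/cpu/kernels/Pow.ts, mirroring Div.ts.
import {createBinaryKernelConfig} from '../utils/kernel_utils';
import {createBinaryKernelImpl} from '../utils/kernel_utils';

// The impl has the shape (aShape, bShape, aVals, bVals, dtype) =>
// [outVals, outShape]; broadcasting is handled inside createBinaryKernelImpl.
export const pow =
    createBinaryKernelImpl((a: number, b: number) => Math.pow(a, b));

// The config adapts the impl to the registry's {inputs, backend} calling
// convention; registration is then a one-liner in register_all_kernels.ts.
export const powConfig = createBinaryKernelConfig('Pow', pow);
```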
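The CPU Softmax kernel chains these modular impls directly on `TypedArray`s. A worked example of the data flow, with assumed values for a logits buffer of shape [2, 3] and `dim = 1` (illustration only; import paths mirror Softmax.ts):

```ts
import {div} from './kernels/Div';
import {exp} from './kernels/Exp';
import {max} from './kernels/Max';
import {sub} from './kernels/Sub';
import {sum} from './kernels/Sum';

const logits = new Float32Array([1, 2, 3, 2, 4, 8]);  // shape [2, 3]
// axes = [1], so reduceOutShape = [2], reduceShape = [3], expandedShape = [2, 1].
const maxLogit = max(logits, 3, [2], 'float32');       // [3, 8]
// Subtract each row's max before exponentiating; softmax(x - c) == softmax(x),
// and the shift keeps Math.exp from overflowing for large logits.
const [aVals] = sub([2, 3], [2, 1], logits, maxLogit, 'float32');
const e = exp(aVals);                                  // elementwise exp
const sumExp = sum(e, 3, [2], 'float32');              // per-row sums
const [result] = div([2, 3], [2, 1], e, sumExp, 'float32');  // rows sum to 1
```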
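At the op layer nothing user-visible changes: `tf.div` still routes int32/int32 through `floorDiv`, `tf.divNoNan` masks division by zero, and the gradient now comes from the registered Div_grad.ts config instead of the inline `der` in binary_ops.ts. A quick exercise through the public API:

```ts
import * as tf from '@tensorflow/tfjs-core';

const a = tf.tensor1d([1, 4, 9, 16]);
const b = tf.tensor1d([1, 2, 3, 4]);
a.div(b).print();                  // [1, 2, 3, 4]

// divNoNan returns 0 where the denominator is 0 instead of Infinity/NaN.
a.divNoNan(tf.scalar(0)).print();  // [0, 0, 0, 0]

// The registered Div gradient: d(x / 2)/dx = 0.5 for every element.
const g = tf.grad(x => tf.div(x, tf.scalar(2)));
g(tf.tensor1d([6, 8])).print();    // [0.5, 0.5]
```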