diff --git a/tfjs-core/src/ops/fused_ops.ts b/tfjs-core/src/ops/fused_ops.ts
index 1a88446655b..a08d17b6061 100644
--- a/tfjs-core/src/ops/fused_ops.ts
+++ b/tfjs-core/src/ops/fused_ops.ts
@@ -25,8 +25,55 @@ import {convertToTensor} from '../tensor_util_env';
 import {TensorLike} from '../types';
 import * as util from '../util';
 
+import {add} from './binary_ops';
 import * as broadcast_util from './broadcast_util';
-import {Activation} from './fused_util';
+import {conv2d as unfusedConv2d, depthwiseConv2d as unfusedDepthwiseConv2d} from './conv';
+import {Activation, shouldFuse} from './fused_util';
+import {matMul as unfusedMatMul} from './matmul';
+
+import {elu, prelu, relu, relu6} from './relu_ops';
+
+// Returns gradient for fused activation.
+const getFusedDyActivation =
+    (dy: Tensor, y: Tensor, activation: Activation): Tensor => {
+      if (activation == null || activation === 'linear') {
+        return dy;
+      }
+      if (activation === 'relu') {
+        return dy.mul(y.step());
+      }
+      throw new Error(
+          `Gradient for activation ${activation} has not been ` +
+          `implemented yet.`);
+    };
+
+// Returns gradient for fused bias.
+const getFusedBiasGradient = (bias: Tensor, dyActivation: Tensor): Tensor => {
+  let res = dyActivation;
+  const reduceAxes =
+      broadcast_util.getReductionAxes(bias.shape, dyActivation.shape);
+  if (reduceAxes.length > 0) {
+    res = res.sum(reduceAxes);
+  }
+  return res.reshape(bias.shape);
+};
+
+const applyActivation =
+    (x: Tensor, activation: Activation, preluActivationWeights?: Tensor):
+        Tensor => {
+          if (activation === 'linear') {
+            return x;
+          } else if (activation === 'relu') {
+            return relu(x);
+          } else if (activation === 'elu') {
+            return elu(x);
+          } else if (activation === 'relu6') {
+            return relu6(x);
+          } else if (activation === 'prelu') {
+            return prelu(x, preluActivationWeights);
+          }
+          throw new Error(`Unknown fused activation ${activation}.`);
+        };
 
 /**
  * Computes the dot product of two matrices with optional activation and bias.
@@ -65,6 +112,15 @@ function matMul_({
   activation?: Activation,
   preluActivationWeights?: Tensor
 }): T {
+  if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
+    let result = unfusedMatMul(a, b, transposeA, transposeB);
+    if (bias != null) {
+      result = add(result, bias);
+    }
+
+    return applyActivation(result, activation, preluActivationWeights) as T;
+  }
+
   let $a = convertToTensor(a, 'a', 'fused matMul');
   let $b = convertToTensor(b, 'b', 'fused matMul');
   [$a, $b] = makeTypesMatch($a, $b);
@@ -126,34 +182,11 @@ function matMul_({
 
   const grad = (dy: Tensor3D, saved: Tensor[]) => {
     const [a3D, b3D, y] = saved;
-
-    let dyActivation: Tensor3D;
-    if (activation == null || activation === 'linear') {
-      dyActivation = dy;
-    } else if (activation === 'relu') {
-      dyActivation = dy.mul(y.step());
-    } else {
-      throw new Error(
-          `Gradient for activation ${activation} has not been ` +
-          `implemented yet.`);
-    }
+    const dyActivation = getFusedDyActivation(dy, y, activation);
 
     let biasGradient = {};
     if (bias != null) {
-      biasGradient = {
-        $bias: () => {
-          let res = dyActivation;
-          // Using dyActivation as reference shape because outputShape does not
-          // account for the fact that we temporarily reshape inputs to 3D as
-          // part of batched matMul.
-          const reduceAxes =
-              broadcast_util.getReductionAxes($bias.shape, dyActivation.shape);
-          if (reduceAxes.length > 0) {
-            res = res.sum(reduceAxes);
-          }
-          return res.reshape($bias.shape);
-        }
-      };
+      biasGradient = {$bias: () => getFusedBiasGradient($bias, dyActivation)};
     }
 
     if (!transposeA && !transposeB) {
@@ -295,6 +328,16 @@ function conv2d_({
   activation?: Activation,
   preluActivationWeights?: Tensor
 }): T {
+  if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
+    let result = unfusedConv2d(
+        x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
+    if (bias != null) {
+      result = add(result, bias);
+    }
+
+    return applyActivation(result, activation, preluActivationWeights) as T;
+  }
+
   const $x = convertToTensor(x, 'x', 'conv2d');
   const $filter = convertToTensor(filter, 'filter', 'conv2d');
 
@@ -353,16 +396,7 @@ function conv2d_({
 
   const grad = (dy: Tensor4D, saved: Tensor[]) => {
     const [$filter, x4D, y] = saved as [Tensor4D, Tensor4D, Tensor4D];
-    let dyActivation: Tensor4D;
-    if (activation == null || activation === 'linear') {
-      dyActivation = dy;
-    } else if (activation === 'relu') {
-      dyActivation = dy.mul(y.step());
-    } else {
-      throw new Error(
-          `Gradient for activation ${activation} has not been ` +
-          `implemented yet.`);
-    }
+    const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D;
 
     util.assert(
         conv_util.tupleValuesAreOne(dilations),
@@ -372,17 +406,7 @@ function conv2d_({
 
     let biasGradient = {};
     if (bias != null) {
-      biasGradient = {
-        $bias: () => {
-          let res = dyActivation;
-          const reduceAxes =
-              broadcast_util.getReductionAxes($bias.shape, dyActivation.shape);
-          if (reduceAxes.length > 0) {
-            res = res.sum(reduceAxes);
-          }
-          return res.reshape($bias.shape);
-        }
-      };
+      biasGradient = {$bias: () => getFusedBiasGradient($bias, dyActivation)};
     }
 
     return Object.assign(
@@ -500,6 +524,16 @@ function depthwiseConv2d_({
   activation?: Activation,
   preluActivationWeights?: Tensor
 }): T {
+  if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
+    let result = unfusedDepthwiseConv2d(
+        x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
+    if (bias != null) {
+      result = add(result, bias);
+    }
+
+    return applyActivation(result, activation, preluActivationWeights) as T;
+  }
+
   const $x = convertToTensor(x, 'x', 'depthwiseConv2d');
   const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d');
 
@@ -564,30 +598,11 @@ function depthwiseConv2d_({
             `'${dilations}'`);
 
     const [x4D, $filter, y] = saved;
-    let dyActivation: Tensor4D;
-    if (activation == null || activation === 'linear') {
-      dyActivation = dy;
-    } else if (activation === 'relu') {
-      dyActivation = dy.mul(y.step());
-    } else {
-      throw new Error(
-          `Gradient for activation ${activation} has not been ` +
-          `implemented yet.`);
-    }
+    const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D;
 
     let biasGradient = {};
     if (bias != null) {
-      biasGradient = {
-        $bias: () => {
-          let res = dyActivation;
-          const reduceAxes =
-              broadcast_util.getReductionAxes($bias.shape, dyActivation.shape);
-          if (reduceAxes.length > 0) {
-            res = res.sum(reduceAxes);
-          }
-          return res.reshape($bias.shape);
-        }
-      };
+      biasGradient = {$bias: () => getFusedBiasGradient($bias, dyActivation)};
     }
 
     return Object.assign(
diff --git a/tfjs-core/src/ops/fused_test.ts b/tfjs-core/src/ops/fused_test.ts
index 0ca1ede5678..8d324393298 100644
--- a/tfjs-core/src/ops/fused_test.ts
+++ b/tfjs-core/src/ops/fused_test.ts
@@ -904,7 +904,7 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => {
        expectArraysClose(await dbiasFused.array(), await dbias.array());
      });
 
-  it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias and activation',
+  it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias and relu',
      async () => {
        const inputDepth = 1;
        const outputDepth = 1;
@@ -948,4 +948,70 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => {
        expectArraysClose(await dfilterFused.array(), await dfilter.array());
        expectArraysClose(await dbiasFused.array(), await dbias.array());
      });
+
+  it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias and elu', async () => {
+    const inputDepth = 1;
+    const outputDepth = 1;
+    const inputShape: [number, number, number, number] = [2, 3, 3, inputDepth];
+    const filterSize = 2;
+    const strides = 1;
+    const pad = 0;
+
+    const filterShape: [number, number, number, number] =
+        [filterSize, filterSize, inputDepth, outputDepth];
+    const filter = tf.tensor4d([-1, 1, -2, 0.5], filterShape);
+    const bias = tf.ones([2, 2, 2, 1]);
+
+    const x = tf.tensor4d(
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9], inputShape);
+    const dy = tf.tensor4d([3, 1, 2, 0, 3, 1, 2, 0], [2, 2, 2, 1]);
+
+    const fusedGrads =
+        tf.grads((x: tf.Tensor4D, w: tf.Tensor4D, b) => tf.fused.conv2d({
+          x,
+          filter: w,
+          strides,
+          pad,
+          dataFormat: 'NHWC',
+          dilations: [1, 1],
+          bias: b,
+          activation: 'elu'
+        }));
+    const [dxFused, dfilterFused, dbiasFused] =
+        fusedGrads([x, filter, bias], dy);
+
+    const grads = tf.grads((x: tf.Tensor4D, filter: tf.Tensor4D, bias) => {
+      const conv = tf.conv2d(x, filter, strides, pad);
+      const sum = tf.add(conv, bias);
+      return tf.elu(sum);
+    });
+    const [dx, dfilter, dbias] = grads([x, filter, bias], dy);
+
+    expectArraysClose(await dxFused.array(), await dx.array());
+    expectArraysClose(await dfilterFused.array(), await dfilter.array());
+    expectArraysClose(await dbiasFused.array(), await dbias.array());
+  });
+
+  it('fused matmul with relu6', async () => {
+    const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]);
+    const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
+    const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
+    const transposeA = false;
+    const transposeB = false;
+
+    const fusedGrads = tf.grads((a, b) => {
+      return tf.fused.matMul(
+          {a, b, transposeA, transposeB, bias: null, activation: 'relu6'});
+    });
+    const [fusedDa, fusedDb] = fusedGrads([a, b], dy);
+
+    const grads = tf.grads((a, b) => {
+      const prod = tf.matMul(a, b, transposeA, transposeB);
+      return tf.relu6(prod);
+    });
+    const [da, db] = grads([a, b], dy);
+
+    expectArraysClose(await da.array(), await fusedDa.array());
+    expectArraysClose(await db.data(), await fusedDb.array());
+  });
 });
diff --git a/tfjs-core/src/ops/fused_util.ts b/tfjs-core/src/ops/fused_util.ts
index c2ad2d2d3f6..73f1fc6d9dc 100644
--- a/tfjs-core/src/ops/fused_util.ts
+++ b/tfjs-core/src/ops/fused_util.ts
@@ -16,6 +16,7 @@
  */
 
 import {Tensor, Tensor3D, Tensor4D} from '../tensor';
+
 import {Conv2DInfo} from './conv_util';
 
 export type Activation = 'linear'|'relu'|'prelu'|'elu'|'relu6';
@@ -38,3 +39,9 @@ export type FusedConv2DConfig = {
   activation?: Activation,
   preluActivationWeights?: Tensor
 };
+
+// Whether we should call fused ops.
+export const shouldFuse = (gradientDepth: number, activation: Activation) => {
+  const gradientMode = gradientDepth > 0;
+  return !gradientMode && (activation === 'linear' || activation === 'relu');
+};
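A minimal usage sketch of the fallback path this patch introduces, assuming a tfjs-core build that includes the change (the tensor values and the run() helper below are illustrative only, not part of the patch): because shouldFuse only keeps the fused kernel for 'linear' and 'relu' outside of gradient mode, taking gradients through an 'elu' activation now routes through the unfused matMul, bias add, and elu ops instead of throwing on the unimplemented fused gradient.

import * as tf from '@tensorflow/tfjs-core';

async function run() {
  const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]);
  const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
  const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);

  // 'elu' has no fused gradient, so shouldFuse() returns false inside
  // tf.grads and tf.fused.matMul falls back to unfused matMul + elu.
  const grads = tf.grads(
      (a: tf.Tensor2D, b: tf.Tensor2D) => tf.fused.matMul(
          {a, b, transposeA: false, transposeB: false, bias: null,
           activation: 'elu'}));
  const [da, db] = grads([a, b], dy);
  console.log(await da.array(), await db.array());
}

run();

The same fallback applies to tf.fused.conv2d and tf.fused.depthwiseConv2d; the new 'elu' conv2d test and 'relu6' matMul test above exercise it.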