151 changes: 83 additions & 68 deletions tfjs-core/src/ops/fused_ops.ts
@@ -25,8 +25,55 @@ import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {add} from './binary_ops';
import * as broadcast_util from './broadcast_util';
import {Activation} from './fused_util';
import {conv2d as unfusedConv2d, depthwiseConv2d as unfusedDepthwiseConv2d} from './conv';
import {Activation, shouldFuse} from './fused_util';
import {matMul as unfusedMatMul} from './matmul';

import {elu, prelu, relu, relu6} from './relu_ops';

// Returns gradient for fused activation.
const getFusedDyActivation =
    (dy: Tensor, y: Tensor, activation: Activation): Tensor => {
      if (activation == null || activation === 'linear') {
        return dy;
      }
      if (activation === 'relu') {
        return dy.mul(y.step());
      }
      throw new Error(
          `Gradient for activation ${activation} has not been ` +
          `implemented yet.`);
    };

// Returns gradient for fused bias.
const getFusedBiasGradient = (bias: Tensor, dyActivation: Tensor): Tensor => {
  let res = dyActivation;
  const reduceAxes =
      broadcast_util.getReductionAxes(bias.shape, dyActivation.shape);
  if (reduceAxes.length > 0) {
    res = res.sum(reduceAxes);
  }
  return res.reshape(bias.shape);
};

const applyActivation =
    (x: Tensor, activation: Activation,
     preluActivationWeights?: Tensor): Tensor => {
      if (activation === 'linear') {
        return x;
      } else if (activation === 'relu') {
        return relu(x);
      } else if (activation === 'elu') {
        return elu(x);
      } else if (activation === 'relu6') {
        return relu6(x);
      } else if (activation === 'prelu') {
        return prelu(x, preluActivationWeights);
      }
      throw new Error(`Unknown fused activation ${activation}.`);
    };

/**
* Computes the dot product of two matrices with optional activation and bias.
@@ -65,6 +112,15 @@ function matMul_<T extends Tensor>({
activation?: Activation,
preluActivationWeights?: Tensor
}): T {
if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
  let result = unfusedMatMul(a, b, transposeA, transposeB);
  if (bias != null) {
    result = add(result, bias);
  }

  return applyActivation(result, activation, preluActivationWeights) as T;
}

let $a = convertToTensor(a, 'a', 'fused matMul');
let $b = convertToTensor(b, 'b', 'fused matMul');
[$a, $b] = makeTypesMatch($a, $b);
@@ -126,34 +182,11 @@ function matMul_<T extends Tensor>({

const grad = (dy: Tensor3D, saved: Tensor[]) => {
const [a3D, b3D, y] = saved;

let dyActivation: Tensor3D;
if (activation == null || activation === 'linear') {
dyActivation = dy;
} else if (activation === 'relu') {
dyActivation = dy.mul(y.step());
} else {
throw new Error(
`Gradient for activation ${activation} has not been ` +
`implemented yet.`);
}
const dyActivation = getFusedDyActivation(dy, y, activation);

let biasGradient = {};
if (bias != null) {
biasGradient = {
$bias: () => {
let res = dyActivation;
// Using dyActivation as reference shape because outputShape does not
// account for the fact that we temporarily reshape inputs to 3D as
// part of batched matMul.
const reduceAxes =
broadcast_util.getReductionAxes($bias.shape, dyActivation.shape);
if (reduceAxes.length > 0) {
res = res.sum(reduceAxes);
}
return res.reshape($bias.shape);
}
};
biasGradient = {$bias: () => getFusedBiasGradient($bias, dyActivation)};
}

if (!transposeA && !transposeB) {
@@ -295,6 +328,16 @@ function conv2d_<T extends Tensor3D|Tensor4D>({
activation?: Activation,
preluActivationWeights?: Tensor
}): T {
if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
  let result = unfusedConv2d(
      x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
  if (bias != null) {
    result = add(result, bias);
  }

  return applyActivation(result, activation, preluActivationWeights) as T;
}

const $x = convertToTensor(x, 'x', 'conv2d');
const $filter = convertToTensor(filter, 'filter', 'conv2d');

@@ -353,16 +396,7 @@ function conv2d_<T extends Tensor3D|Tensor4D>({
const grad = (dy: Tensor4D, saved: Tensor[]) => {
const [$filter, x4D, y] = saved as [Tensor4D, Tensor4D, Tensor4D];

let dyActivation: Tensor4D;
if (activation == null || activation === 'linear') {
dyActivation = dy;
} else if (activation === 'relu') {
dyActivation = dy.mul(y.step());
} else {
throw new Error(
`Gradient for activation ${activation} has not been ` +
`implemented yet.`);
}
const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D;

util.assert(
conv_util.tupleValuesAreOne(dilations),
Expand All @@ -372,17 +406,7 @@ function conv2d_<T extends Tensor3D|Tensor4D>({

let biasGradient = {};
if (bias != null) {
biasGradient = {
$bias: () => {
let res = dyActivation;
const reduceAxes =
broadcast_util.getReductionAxes($bias.shape, dyActivation.shape);
if (reduceAxes.length > 0) {
res = res.sum(reduceAxes);
}
return res.reshape($bias.shape);
}
};
biasGradient = {$bias: () => getFusedBiasGradient($bias, dyActivation)};
}

return Object.assign(
@@ -500,6 +524,16 @@ function depthwiseConv2d_<T extends Tensor3D|Tensor4D>({
activation?: Activation,
preluActivationWeights?: Tensor
}): T {
if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
  let result = unfusedDepthwiseConv2d(
      x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
  if (bias != null) {
    result = add(result, bias);
  }

  return applyActivation(result, activation, preluActivationWeights) as T;
}

const $x = convertToTensor(x, 'x', 'depthwiseConv2d');
const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d');

@@ -564,30 +598,11 @@ function depthwiseConv2d_<T extends Tensor3D|Tensor4D>({
`'${dilations}'`);
const [x4D, $filter, y] = saved;

let dyActivation: Tensor4D;
if (activation == null || activation === 'linear') {
dyActivation = dy;
} else if (activation === 'relu') {
dyActivation = dy.mul(y.step());
} else {
throw new Error(
`Gradient for activation ${activation} has not been ` +
`implemented yet.`);
}
const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D;

let biasGradient = {};
if (bias != null) {
biasGradient = {
$bias: () => {
let res = dyActivation;
const reduceAxes =
broadcast_util.getReductionAxes($bias.shape, dyActivation.shape);
if (reduceAxes.length > 0) {
res = res.sum(reduceAxes);
}
return res.reshape($bias.shape);
}
};
biasGradient = {$bias: () => getFusedBiasGradient($bias, dyActivation)};
}

return Object.assign(
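Not part of the diff above — a minimal TypeScript sketch of what the new `shouldFuse` fallback amounts to for a caller, using the public tfjs-core API and made-up tensor values. Because 'elu' has no fused gradient, `shouldFuse` returns false and the fused call decomposes into the unfused ops shown below; the same happens for any activation while gradients are being recorded.

```ts
import * as tf from '@tensorflow/tfjs-core';

const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const b = tf.tensor2d([1, 0, 0, 1], [2, 2]);
const bias = tf.scalar(1);

// Fused call with an activation that has no fused gradient ('elu'):
const fused = tf.fused.matMul(
    {a, b, transposeA: false, transposeB: false, bias, activation: 'elu'});

// Equivalent unfused composition that the fallback branch builds internally:
const unfused = tf.elu(tf.add(tf.matMul(a, b), bias));

// Both produce the same values; because the fallback is built from ordinary
// ops, tf.grads can differentiate it even though the fused kernels only
// implement gradients for 'linear' and 'relu'.
```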
68 changes: 67 additions & 1 deletion tfjs-core/src/ops/fused_test.ts
@@ -904,7 +904,7 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => {
expectArraysClose(await dbiasFused.array(), await dbias.array());
});

it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias and activation',
it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias and relu',
async () => {
const inputDepth = 1;
const outputDepth = 1;
@@ -948,4 +948,70 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => {
expectArraysClose(await dfilterFused.array(), await dfilter.array());
expectArraysClose(await dbiasFused.array(), await dbias.array());
});

it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias and elu', async () => {
  const inputDepth = 1;
  const outputDepth = 1;
  const inputShape: [number, number, number, number] = [2, 3, 3, inputDepth];
  const filterSize = 2;
  const strides = 1;
  const pad = 0;

  const filterShape: [number, number, number, number] =
      [filterSize, filterSize, inputDepth, outputDepth];
  const filter = tf.tensor4d([-1, 1, -2, 0.5], filterShape);
  const bias = tf.ones([2, 2, 2, 1]);

  const x = tf.tensor4d(
      [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9], inputShape);
  const dy = tf.tensor4d([3, 1, 2, 0, 3, 1, 2, 0], [2, 2, 2, 1]);

  const fusedGrads =
      tf.grads((x: tf.Tensor4D, w: tf.Tensor4D, b) => tf.fused.conv2d({
        x,
        filter: w,
        strides,
        pad,
        dataFormat: 'NHWC',
        dilations: [1, 1],
        bias: b,
        activation: 'elu'
      }));
  const [dxFused, dfilterFused, dbiasFused] =
      fusedGrads([x, filter, bias], dy);

  const grads = tf.grads((x: tf.Tensor4D, filter: tf.Tensor4D, bias) => {
    const conv = tf.conv2d(x, filter, strides, pad);
    const sum = tf.add(conv, bias);
    return tf.elu(sum);
  });
  const [dx, dfilter, dbias] = grads([x, filter, bias], dy);

  expectArraysClose(await dxFused.array(), await dx.array());
  expectArraysClose(await dfilterFused.array(), await dfilter.array());
  expectArraysClose(await dbiasFused.array(), await dbias.array());
});

it('fused matmul with relu6', async () => {
  const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]);
  const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]);
  const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
  const transposeA = false;
  const transposeB = false;

  const fusedGrads = tf.grads((a, b) => {
    return tf.fused.matMul(
        {a, b, transposeA, transposeB, bias: null, activation: 'relu6'});
  });
  const [fusedDa, fusedDb] = fusedGrads([a, b], dy);

  const grads = tf.grads((a, b) => {
    const prod = tf.matMul(a, b, transposeA, transposeB);
    return tf.relu6(prod);
  });
  const [da, db] = grads([a, b], dy);

  expectArraysClose(await da.array(), await fusedDa.array());
  expectArraysClose(await db.array(), await fusedDb.array());
});
});
7 changes: 7 additions & 0 deletions tfjs-core/src/ops/fused_util.ts
@@ -16,6 +16,7 @@
*/

import {Tensor, Tensor3D, Tensor4D} from '../tensor';

import {Conv2DInfo} from './conv_util';

export type Activation = 'linear'|'relu'|'prelu'|'elu'|'relu6';
@@ -38,3 +39,9 @@ export type FusedConv2DConfig = {
activation?: Activation,
preluActivationWeights?: Tensor
};

// Whether we should call fused ops.
export const shouldFuse = (gradientDepth: number, activation: Activation) => {
  const gradientMode = gradientDepth > 0;
  return !gradientMode && (activation === 'linear' || activation === 'relu');
};
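For reference (not part of the change), a small sketch of how this predicate routes calls; the return values follow from the definition above, and a nonzero gradientDepth stands in for the engine state while tf.grads is recording.

```ts
import {shouldFuse} from './fused_util';

// Outside gradient mode, only activations with a fused gradient
// ('linear' and 'relu') stay on the fused kernel:
shouldFuse(0, 'linear');  // true  -> fused kernel
shouldFuse(0, 'relu');    // true  -> fused kernel
shouldFuse(0, 'elu');     // false -> unfused op + add + activation

// While gradients are being recorded (gradientDepth > 0), every fused op
// falls back, so autodiff differentiates the unfused composition instead:
shouldFuse(1, 'relu');    // false -> unfused path
```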