diff --git a/tfjs-core/src/backends/backend.ts b/tfjs-core/src/backends/backend.ts index 7f5a14c4328..22ba73dfead 100644 --- a/tfjs-core/src/backends/backend.ts +++ b/tfjs-core/src/backends/backend.ts @@ -16,7 +16,7 @@ */ import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util'; -import {Activation, FusedBatchMatMulConfig} from '../ops/fused_util'; +import {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_util'; import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor'; import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types'; @@ -410,8 +410,8 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { } fusedConv2d( - x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D, - activation?: Activation, preluActivationWeights?: Tensor): Tensor4D { + {input, filter, convInfo, bias, activation, preluActivationWeights}: + FusedConv2DConfig): Tensor4D { throw new Error('Not yet implemented'); } @@ -426,6 +426,12 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { throw new Error('Not yet implemented'); } + fusedDepthwiseConv2D( + {input, filter, convInfo, bias, activation, preluActivationWeights}: + FusedConv2DConfig): Tensor4D { + throw new Error('Not yet implemented'); + } + depthwiseConv2D(input: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D { throw new Error('Not yet implemented'); diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index df06676976d..0abdc19a8a2 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -27,7 +27,7 @@ import {complex, imag, real} from '../../ops/complex_ops'; import * as concat_util from '../../ops/concat_util'; import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util'; import * as erf_util from '../../ops/erf_util'; -import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util'; +import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util'; import * as gather_nd_util from '../../ops/gather_nd_util'; import * as ops from '../../ops/ops'; import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops'; @@ -1531,9 +1531,9 @@ export class MathBackendCPU implements KernelBackend { } fusedConv2d( - x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D, - activation?: Activation, preluActivationWeights?: Tensor): Tensor4D { - let result = this.conv2d(x, filter, convInfo); + {input, filter, convInfo, bias, activation, preluActivationWeights}: + FusedConv2DConfig): Tensor4D { + let result = this.conv2d(input, filter, convInfo); if (bias) { result = this.add(result, bias) as Tensor4D; @@ -1973,6 +1973,22 @@ export class MathBackendCPU implements KernelBackend { return dw.toTensor(); } + fusedDepthwiseConv2D( + {input, filter, convInfo, bias, activation, preluActivationWeights}: + FusedConv2DConfig): Tensor4D { + let result = this.depthwiseConv2D(input, filter, convInfo); + + if (bias) { + result = this.add(result, bias) as Tensor4D; + } + if (activation) { + result = + mapActivation(this, result, activation, preluActivationWeights) as + Tensor4D; + } + return result; + } + depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D { this.assertNotComplex([x, filter], 'depthwiseConv2D'); diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index 492f3ac6d59..96f16b37371 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -29,7 +29,7 @@ import * as axis_util from '../../ops/axis_util'; import {complex, imag, real} from '../../ops/complex_ops'; import {computeOutShape} from '../../ops/concat_util'; import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util'; -import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util'; +import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util'; import * as gather_nd_util from '../../ops/gather_nd_util'; import * as reduce_util from '../../ops/reduce_util'; import * as scatter_nd_util from '../../ops/scatter_nd_util'; @@ -1909,7 +1909,7 @@ export class MathBackendWebGL implements KernelBackend { } private conv2dByMatMul( - x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D, + x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor, activation?: Activation, preluActivationWeights?: Tensor): Tensor4D { // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the // result from 2D to 4D. @@ -2008,7 +2008,7 @@ export class MathBackendWebGL implements KernelBackend { } private conv2dWithIm2Row( - x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D, + x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor, activation?: Activation, preluActivationWeights?: Tensor): Tensor4D { // Rearranges conv2d input so each block to be convolved over forms the // column of a new matrix with shape [filterWidth * filterHeight * @@ -2067,19 +2067,19 @@ export class MathBackendWebGL implements KernelBackend { } fusedConv2d( - x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D, - activation?: Activation, preluActivationWeights?: Tensor): Tensor4D { + {input, filter, convInfo, bias, activation, preluActivationWeights}: + FusedConv2DConfig): Tensor4D { if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) { return this.conv2dByMatMul( - x, filter, convInfo, bias, activation, preluActivationWeights); + input, filter, convInfo, bias, activation, preluActivationWeights); } - if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) { + if (ENV.getBool('WEBGL_CONV_IM2COL') && input.shape[0] === 1) { return this.conv2dWithIm2Row( - x, filter, convInfo, bias, activation, preluActivationWeights); + input, filter, convInfo, bias, activation, preluActivationWeights); } const hasBias = bias != null; @@ -2088,7 +2088,7 @@ export class MathBackendWebGL implements KernelBackend { activation ? mapActivationToShaderProgram(activation, false) : null; const program = new Conv2DProgram( convInfo, hasBias, fusedActivation, hasPreluActivationWeights); - const inputs: TensorHandle[] = [x, filter]; + const inputs: TensorHandle[] = [input, filter]; if (bias) { inputs.push(bias); } @@ -2124,6 +2124,40 @@ export class MathBackendWebGL implements KernelBackend { return this.compileAndRun(program, [x, dy]); } + fusedDepthwiseConv2D( + {input, filter, convInfo, bias, activation, preluActivationWeights}: + FusedConv2DConfig): Tensor4D { + const shouldPackDepthwiseConv = ENV.getBool('WEBGL_PACK_DEPTHWISECONV') && + convInfo.strideWidth <= 2 && + convInfo.outChannels / convInfo.inChannels === 1; + const fusedActivation = activation ? + mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) : + null; + const inputs: Tensor[] = [input, filter]; + + const hasBias = bias != null; + const hasPreluActivationWeights = preluActivationWeights != null; + if (hasBias) { + inputs.push(bias); + } + if (hasPreluActivationWeights) { + inputs.push(preluActivationWeights); + } + + let program: DepthwiseConv2DProgram|DepthwiseConvPacked2DProgram; + if (shouldPackDepthwiseConv) { + program = new DepthwiseConvPacked2DProgram( + convInfo, hasBias, fusedActivation, hasPreluActivationWeights); + return this.compileAndRun( + program, inputs, + this.makePackedTensor(convInfo.outShape, input.dtype)); + } + + program = new DepthwiseConv2DProgram( + convInfo, hasBias, fusedActivation, hasPreluActivationWeights); + return this.compileAndRun(program, inputs); + } + depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D { let program: DepthwiseConv2DProgram|DepthwiseConvPacked2DProgram; diff --git a/tfjs-core/src/backends/webgl/conv_gpu_depthwise.ts b/tfjs-core/src/backends/webgl/conv_gpu_depthwise.ts index 8a89c789b78..5d88d94eaf5 100644 --- a/tfjs-core/src/backends/webgl/conv_gpu_depthwise.ts +++ b/tfjs-core/src/backends/webgl/conv_gpu_depthwise.ts @@ -23,7 +23,9 @@ export class DepthwiseConv2DProgram implements GPGPUProgram { outputShape: number[]; userCode: string; - constructor(convInfo: Conv2DInfo) { + constructor( + convInfo: Conv2DInfo, addBias = false, activation: string = null, + hasPreluActivation = false) { this.outputShape = convInfo.outShape; const xNumRows = convInfo.inHeight; @@ -38,7 +40,36 @@ export class DepthwiseConv2DProgram implements GPGPUProgram { const filterWidth = convInfo.filterWidth; const channelMul = convInfo.outChannels / convInfo.inChannels; + let activationSnippet = '', applyActivationSnippet = ''; + if (activation) { + if (hasPreluActivation) { + activationSnippet = `float activation(float a) { + float b = getPreluActivationWeightsAtOutCoords(); + ${activation} + }`; + } else { + activationSnippet = ` + float activation(float x) { + ${activation} + } + `; + } + + applyActivationSnippet = `result = activation(result);`; + } + + const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; + if (addBias) { + this.variableNames.push('bias'); + } + + if (hasPreluActivation) { + this.variableNames.push('preluActivationWeights'); + } + this.userCode = ` + ${activationSnippet} + const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); const ivec2 pads = ivec2(${padTop}, ${padLeft}); @@ -76,7 +107,11 @@ export class DepthwiseConv2DProgram implements GPGPUProgram { dotProd += xVal * wVal; } } - setOutput(dotProd); + + float result = dotProd; + ${addBiasSnippet} + ${applyActivationSnippet} + setOutput(result); } `; } diff --git a/tfjs-core/src/backends/webgl/conv_packed_gpu_depthwise.ts b/tfjs-core/src/backends/webgl/conv_packed_gpu_depthwise.ts index c2e977ac1f4..8a6cfcd4196 100644 --- a/tfjs-core/src/backends/webgl/conv_packed_gpu_depthwise.ts +++ b/tfjs-core/src/backends/webgl/conv_packed_gpu_depthwise.ts @@ -26,7 +26,9 @@ export class DepthwiseConvPacked2DProgram implements GPGPUProgram { outputShape: number[]; userCode: string; - constructor(convInfo: Conv2DInfo) { + constructor( + convInfo: Conv2DInfo, addBias = false, activation: string = null, + hasPreluActivation = false) { this.outputShape = convInfo.outShape; const xNumRows = convInfo.inHeight; @@ -257,11 +259,38 @@ export class DepthwiseConvPacked2DProgram implements GPGPUProgram { for (let r = 0; r < filterHeight; r++) { for (let c = 0; c < filterWidth; c++) { - mainLoop += `result += xR${r}C${c} * wR${r}C${c};`; + mainLoop += `dotProd += xR${r}C${c} * wR${r}C${c};`; } } + let activationSnippet = '', applyActivationSnippet = ''; + if (activation) { + if (hasPreluActivation) { + activationSnippet = `vec4 activation(vec4 a) { + vec4 b = getPreluActivationWeightsAtOutCoords(); + ${activation} + }`; + } else { + activationSnippet = `vec4 activation(vec4 x) { + ${activation} + }`; + } + + applyActivationSnippet = `result = activation(result);`; + } + + const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; + if (addBias) { + this.variableNames.push('bias'); + } + + if (hasPreluActivation) { + this.variableNames.push('preluActivationWeights'); + } + this.userCode = ` + ${activationSnippet} + const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); const ivec2 pads = ivec2(${padTop}, ${padLeft}); @@ -276,10 +305,13 @@ export class DepthwiseConvPacked2DProgram implements GPGPUProgram { int xRCorner = xRCCorner.x; int xCCorner = xRCCorner.y; - vec4 result = vec4(0.); + vec4 dotProd = vec4(0.); ${mainLoop} + vec4 result = dotProd; + ${addBiasSnippet} + ${applyActivationSnippet} setOutput(result); } `; diff --git a/tfjs-core/src/ops/conv.ts b/tfjs-core/src/ops/conv.ts index 3e702f0ee62..55e0559e376 100644 --- a/tfjs-core/src/ops/conv.ts +++ b/tfjs-core/src/ops/conv.ts @@ -194,10 +194,9 @@ function conv2d_( `are not yet supported in gradients. Got dilations '${dilations}'`); return { - x: () => - conv2dDerInput_(x4D.shape, dy, $filter, strides, pad, dataFormat), + x: () => conv2dDerInput(x4D.shape, dy, $filter, strides, pad, dataFormat), $filter: () => - conv2dDerFilter_(x4D, dy, $filter.shape, strides, pad, dataFormat) + conv2dDerFilter(x4D, dy, $filter.shape, strides, pad, dataFormat) }; }; @@ -675,7 +674,7 @@ function eitherStridesOrDilationsAreOne( return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations); } -function depthwiseConv2dDerInput( +function depthwiseConv2dDerInput_( xShape: [number, number, number, number]|[number, number, number], dy: T, filter: Tensor4D, convInfo: conv_util.Conv2DInfo): T { let dy4D = dy as Tensor4D; @@ -693,7 +692,7 @@ function depthwiseConv2dDerInput( return res as T; } -function depthwiseConv2dDerFilter( +function depthwiseConv2dDerFilter_( x: T, dy: T, filterShape: [number, number, number, number], convInfo: conv_util.Conv2DInfo): Tensor4D { let x4D = x as Tensor4D; @@ -973,6 +972,8 @@ export const conv3d = op({conv3d_}); export const conv2dDerFilter = op({conv2dDerFilter_}); export const conv2dDerInput = op({conv2dDerInput_}); export const depthwiseConv2d = op({depthwiseConv2d_}); +export const depthwiseConv2dDerInput = op({depthwiseConv2dDerInput_}); +export const depthwiseConv2dDerFilter = op({depthwiseConv2dDerFilter_}); export const separableConv2d = op({separableConv2d_}); export const conv2dTranspose = op({conv2dTranspose_}); export const conv3dTranspose = op({conv3dTranspose_}); diff --git a/tfjs-core/src/ops/fused_ops.ts b/tfjs-core/src/ops/fused_ops.ts index a2f87b459e3..dfb2b07cf38 100644 --- a/tfjs-core/src/ops/fused_ops.ts +++ b/tfjs-core/src/ops/fused_ops.ts @@ -16,7 +16,7 @@ */ import {ENGINE} from '../engine'; -import {conv2dDerFilter, conv2dDerInput} from '../ops/conv'; +import {conv2dDerFilter, conv2dDerInput, depthwiseConv2dDerFilter, depthwiseConv2dDerInput} from '../ops/conv'; import * as conv_util from '../ops/conv_util'; import {op} from '../ops/operation'; import {Tensor, Tensor3D, Tensor4D} from '../tensor'; @@ -239,40 +239,43 @@ function matMul_({ * ``` * * @param obj An object with the following properties: - * - `x` The input tensor, of rank 4 or rank 3, of shape + * @param x The input tensor, of rank 4 or rank 3, of shape * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is * assumed. - * - `filter` The filter, rank 4, of shape + * @param filter The filter, rank 4, of shape * `[filterHeight, filterWidth, inDepth, outDepth]`. - * - `strides` The strides of the convolution: `[strideHeight, + * @param strides The strides of the convolution: `[strideHeight, * strideWidth]`. - * - `pad` The type of padding algorithm. - * - `same` and stride 1: output will be of same size as input, + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, * regardless of filter size. - * - `valid`: output will be smaller than input if filter is larger + * - `valid` output will be smaller than input if filter is larger * than 1x1. * - For more info, see this guide: * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( * https://www.tensorflow.org/api_guides/python/nn#Convolution) - * - `dataFormat` An optional string from: "NHWC", "NCHW". Defaults to + * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to * "NHWC". Specify the data format of the input and output data. With the * default format "NHWC", the data is stored in the order of: [batch, * height, width, channels]. Only "NHWC" is currently supported. - * - `dilations` The dilation rates: `[dilationHeight, dilationWidth]` + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` * in which we sample input values across the height and width dimensions * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single * number, then `dilationHeight == dilationWidth`. If it is greater than * 1, then all values of `strides` must be 1. - * - `dimRoundingMode` The rounding mode used when computing output + * @param dimRoundingMode` The rounding mode used when computing output * dimensions if pad is a number. If none is provided, it will not round * and error if the output is of fractional size. - * - `bias` Tensor to be added to the result. - * - `activation` Name of activation kernel (defaults to `linear`) to be applied + * @param bias` Tensor to be added to the result. + * @param activation Name of activation kernel (defaults to `linear`) to be + * applied * after biasAdd. - * - `preluActivationWeights` Tensor of prelu weights to be applied as part of a - * `prelu` activation, typically the same shape as `x`. + * @param preluActivationWeights Tensor of prelu weights to be applied as part + * of a `prelu` activation, typically the same shape as `x`. + */ +/** + * @doc {heading: 'Operations', subheading: 'Convolution', namespace: 'fused'} */ -/** @doc {heading: 'Operations', subheading: 'Convolution'} */ function conv2d_({ x, filter, @@ -410,11 +413,15 @@ function conv2d_({ } const res = ENGINE.runKernel((backend, save) => { - const res = backend.fusedConv2d( - x4D, $filter, convInfo, $bias as Tensor4D, activation, - $preluActivationWeights); + const res = backend.fusedConv2d({ + input: x4D, + filter: $filter, + convInfo, + bias: $bias, + activation, + preluActivationWeights: $preluActivationWeights + }); save([$filter, x4D, res]); - return res; }, inputs, grad); @@ -424,7 +431,217 @@ function conv2d_({ return res as T; } +/** + * Computes depthwise 2D convolution, optionally fused with adding a + * bias and applying an activation. + * + * Given a 4D `input` array and a `filter` array of shape + * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing + * `inChannels` convolutional filters of depth 1, this op applies a + * different filter to each input channel (expanding from 1 channel to + * `channelMultiplier` channels for each), then concatenates the results + * together. The output has `inChannels * channelMultiplier` channels. + * + * See + * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d]( + * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d) + * for more details. + * + * @param obj An object with the following properties: + * @param x The input tensor, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is + * assumed. + * @param filter The filter tensor, rank 4, of shape + * `[filterHeight, filterWidth, inChannels, channelMultiplier]`. + * @param strides The strides of the convolution: `[strideHeight, + * strideWidth]`. If strides is a single number, then `strideHeight == + * strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single + * number, then `dilationHeight == dilationWidth`. If it is greater than + * 1, then all values of `strides` must be 1. + * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. Only "NHWC" is currently supported. + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + * @param bias Tensor to be added to the result. + * @param activation Name of activation kernel (defaults to `linear`). + * @param preluActivationWeights Tensor of prelu weights to be applied as part + * of a `prelu` activation, typically the same shape as `x`. + */ +/** + * @doc {heading: 'Operations', subheading: 'Convolution', namespace: 'fused'} + */ +function depthwiseConv2d_({ + x, + filter, + strides, + pad, + dataFormat = 'NHWC', + dilations = [1, 1], + dimRoundingMode, + bias, + activation = 'linear', + preluActivationWeights +}: { + x: T|TensorLike, + filter: Tensor4D|TensorLike, + strides: [number, number]|number, + pad: 'valid'|'same'|number, + dataFormat?: 'NHWC'|'NCHW', + dilations?: [number, number]|number, + dimRoundingMode?: 'floor'|'round'|'ceil', + bias?: Tensor|TensorLike, + activation?: Activation, + preluActivationWeights?: Tensor +}): T { + const $x = convertToTensor(x, 'x', 'depthwiseConv2d'); + const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d'); + + let x4D = $x as Tensor4D; + let reshapedTo4D = false; + if ($x.rank === 3) { + reshapedTo4D = true; + x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); + } + util.assert( + x4D.rank === 4, + () => `Error in fused depthwiseConv2d: input must be rank 4, but got ` + + `rank ${x4D.rank}.`); + util.assert( + $filter.rank === 4, + () => `Error in fused depthwiseConv2d: filter must be rank 4, ` + + `but got rank ${$filter.rank}.`); + util.assert( + x4D.shape[3] === $filter.shape[2], + () => `Error in fused depthwiseConv2d: number of input channels ` + + `(${x4D.shape[3]}) must match the inChannels dimension in ` + + `filter ${$filter.shape[2]}.`); + if (dilations == null) { + dilations = [1, 1]; + } + util.assert( + conv_util.eitherStridesOrDilationsAreOne(strides, dilations), + () => + 'Error in fused depthwiseConv2d: Either strides or dilations must ' + + `be 1. Got strides ${strides} and dilations '${dilations}'`); + + if (dimRoundingMode != null) { + util.assert( + util.isInt(pad as number), + () => `Error in fused depthwiseConv2d: pad must be an integer when ` + + `using dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); + } + + const convInfo = conv_util.computeConv2DInfo( + x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, + true /* depthwise */); + + let $bias: Tensor; + if (bias != null) { + $bias = convertToTensor(bias, 'bias', 'fused conv2d'); + [$bias] = makeTypesMatch($bias, $x); + + broadcast_util.assertAndGetBroadcastShape(convInfo.outShape, $bias.shape); + } + + let $preluActivationWeights: Tensor; + if (preluActivationWeights != null) { + $preluActivationWeights = convertToTensor( + preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d'); + } + + const grad = (dy: Tensor4D, saved: Tensor[]) => { + util.assert( + conv_util.tupleValuesAreOne(dilations), + () => 'Error in gradient of fused depthwiseConv2d: dilation rates ' + + `greater than 1 are not yet supported. Got dilations ` + + `'${dilations}'`); + const [x4D, $filter, y] = saved; + + let dyActivation: Tensor4D; + if (activation == null || activation === 'linear') { + dyActivation = dy; + } else if (activation === 'relu') { + dyActivation = dy.mul(y.step()); + } else { + throw new Error( + `Gradient for activation ${activation} has not been ` + + `implemented yet.`); + } + + let biasGradient = {}; + if (bias != null) { + biasGradient = { + $bias: () => { + let res = dyActivation; + const reduceAxes = + broadcast_util.getReductionAxes($bias.shape, dyActivation.shape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($bias.shape); + } + }; + } + + return Object.assign( + { + x: () => depthwiseConv2dDerInput( + (x4D as Tensor4D).shape, dyActivation, $filter as Tensor4D, + convInfo), + $filter: () => depthwiseConv2dDerFilter( + x4D as Tensor4D, dyActivation, ($filter as Tensor4D).shape, + convInfo), + }, + biasGradient); + }; + + const inputs: { + x: Tensor, + $filter: Tensor, + $bias?: Tensor, + $preluActivationWeights?: Tensor + } = {x: x4D, $filter}; + if (bias != null) { + inputs.$bias = $bias; + } + if (preluActivationWeights != null) { + inputs.$preluActivationWeights = $preluActivationWeights; + } + + const res = ENGINE.runKernel((backend, save) => { + const res = backend.fusedDepthwiseConv2D({ + input: x4D, + filter: $filter, + convInfo, + bias: $bias, + activation, + preluActivationWeights: $preluActivationWeights + }); + save([x4D, $filter, res]); + return res; + }, inputs, grad); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; + } + return res as T; +} + export const matMul = op({matMul_}); export const conv2d = op({conv2d_}); +export const depthwiseConv2d = op({depthwiseConv2d_}); export {Activation}; diff --git a/tfjs-core/src/ops/fused_test.ts b/tfjs-core/src/ops/fused_test.ts index bf8673c8acb..bf99b7013cc 100644 --- a/tfjs-core/src/ops/fused_test.ts +++ b/tfjs-core/src/ops/fused_test.ts @@ -296,6 +296,240 @@ describeWithFlags('fused matmul', ALL_ENVS, () => { }); }); +describeWithFlags('fused depthwiseConv2d', ALL_ENVS, () => { + it('basic', async () => { + const fSize = 2; + const pad = 'valid'; + const strides = 1; + const chMul = 1; + const inDepth = 1; + + const x = tf.tensor4d( + [ + 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, + 0.0741907, 0.409265, 0.351377 + ], + [1, 3, 3, inDepth]); + const w = tf.tensor4d( + [-0.303873, -0.229223, 0.144333, 0.803373], + [fSize, fSize, inDepth, chMul], + ); + + const result = tf.fused.depthwiseConv2d({x, filter: w, strides, pad}); + expect(result.shape).toEqual([1, 2, 2, 1]); + const expected = [0.47737, 0.40018, 0.00859, -0.09615]; + expectArraysClose(await result.data(), expected); + }); + + it('basic with relu', async () => { + const fSize = 2; + const pad = 'valid'; + const strides = 1; + const chMul = 1; + const inDepth = 1; + + const x = tf.tensor4d( + [ + 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, + 0.0741907, 0.409265, 0.351377 + ], + [1, 3, 3, inDepth]); + const w = tf.tensor4d( + [-0.303873, -0.229223, 0.144333, 0.803373], + [fSize, fSize, inDepth, chMul], + ); + + const result = tf.fused.depthwiseConv2d( + {x, filter: w, strides, pad, activation: 'relu'}); + expect(result.shape).toEqual([1, 2, 2, 1]); + const expected = [0.47737, 0.40018, 0.00859, 0]; + expectArraysClose(await result.data(), expected); + }); + + it('basic with bias and relu', async () => { + const fSize = 2; + const pad = 'valid'; + const strides = 1; + const chMul = 1; + const inDepth = 1; + + const x = tf.tensor4d( + [ + 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, + 0.0741907, 0.409265, 0.351377 + ], + [1, 3, 3, inDepth]); + const w = tf.tensor4d( + [-0.303873, -0.229223, 0.144333, 0.803373], + [fSize, fSize, inDepth, chMul], + ); + + const result = tf.fused.depthwiseConv2d( + {x, filter: w, strides, pad, bias: tf.scalar(1), activation: 'relu'}); + expect(result.shape).toEqual([1, 2, 2, 1]); + const expected = [1.47737, 1.40018, 1.00859, 0.90385]; + expectArraysClose(await result.data(), expected); + }); + + it('prelu', async () => { + const fSize = 3; + const pad = 'valid'; + const strides = 1; + const chMul = 1; + const inDepth = 1; + + const x = tf.tensor4d( + [ + 0.149194, 0.089009, 0.654891, 0.083324, 0.537043, 0.644331, 0.563037, + 0.211859, 0.633501, 0.186427, 0.777034, 0.50001, 0.607341, 0.95303, + 0.696479, 0.050387, 0.62045, 0.728049, 0.028043, 0.437009, 0.712881, + 0.741935, 0.974474, 0.621102, 0.171411 + ], + [1, 5, 5, inDepth]); + const alpha = tf.tensor4d( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [1, 3, 3, 1]); + const w = tf.tensor4d( + [ + -0.125386, -0.975199, -0.640437, -0.281895, -0.990968, -0.347208, + -0.889702, -0.180695, -0.691992 + ], + [fSize, fSize, inDepth, chMul], + ); + + const result = tf.fused.depthwiseConv2d({ + x, + filter: w, + strides, + pad, + activation: 'prelu', + preluActivationWeights: alpha + }); + expect(result.shape).toEqual([1, 3, 3, 1]); + const expected = [ + -0.25400, -0.50118, -0.73622, -0.94068, -1.2298, -1.84585, -2.3089, + -2.7499, -2.64077 + ]; + expectArraysClose(await result.data(), expected); + }); + + it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', async () => { + const inputDepth = 1; + const outputDepth = 1; + const inputShape: [number, number, number, number] = [2, 3, 3, inputDepth]; + const filterSize = 2; + const strides = 1; + const pad = 0; + + const filterShape: [number, number, number, number] = + [filterSize, filterSize, inputDepth, outputDepth]; + const filter = tf.tensor4d([-1, 1, -2, 0.5], filterShape); + + const x = tf.tensor4d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9], inputShape); + const dy = tf.tensor4d([3, 1, 2, 0, 3, 1, 2, 0], [2, 2, 2, 1]); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor4D) => + tf.fused.depthwiseConv2d({x, filter, strides, pad})); + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + expectArraysClose( + await dx.data(), + [-3, 2, 1, -8, 1.5, 0.5, -4, 1, 0, -3, 2, 1, -8, 1.5, 0.5, -4, 1, 0]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose(await dfilter.data(), [26, 38, 62, 74]); + }); + + it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias', async () => { + const inputDepth = 1; + const outputDepth = 1; + const inputShape: [number, number, number, number] = [2, 3, 3, inputDepth]; + const filterSize = 2; + const strides = 1; + const pad = 0; + + const filterShape: [number, number, number, number] = + [filterSize, filterSize, inputDepth, outputDepth]; + const filter = tf.tensor4d([-1, 1, -2, 0.5], filterShape); + const bias = tf.ones([2, 2, 2, 1]); + + const x = tf.tensor4d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9], inputShape); + const dy = tf.tensor4d([3, 1, 2, 0, 3, 1, 2, 0], [2, 2, 2, 1]); + + const fusedGrads = tf.grads( + (x: tf.Tensor4D, w: tf.Tensor4D, b) => tf.fused.depthwiseConv2d({ + x, + filter: w, + strides, + pad, + dataFormat: 'NHWC', + dilations: [1, 1], + bias: b + })); + const [dxFused, dfilterFused, dbiasFused] = + fusedGrads([x, filter, bias], dy); + + const grads = tf.grads((x: tf.Tensor4D, filter: tf.Tensor4D, bias) => { + const conv = tf.depthwiseConv2d(x, filter, strides, pad); + const sum = tf.add(conv, bias); + return sum; + }); + const [dx, dfilter, dbias] = grads([x, filter, bias], dy); + + expectArraysClose(await dxFused.array(), await dx.array()); + expectArraysClose(await dfilterFused.array(), await dfilter.array()); + expectArraysClose(await dbiasFused.array(), await dbias.array()); + }); + + it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias and activation', + async () => { + const inputDepth = 1; + const outputDepth = 1; + const inputShape: [number, number, number, number] = + [2, 3, 3, inputDepth]; + const filterSize = 2; + const strides = 1; + const pad = 0; + + const filterShape: [number, number, number, number] = + [filterSize, filterSize, inputDepth, outputDepth]; + const filter = tf.tensor4d([-1, 1, -2, 0.5], filterShape); + const bias = tf.ones([2, 2, 2, 1]); + + const x = tf.tensor4d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9], inputShape); + const dy = tf.tensor4d([3, 1, 2, 0, 3, 1, 2, 0], [2, 2, 2, 1]); + + const fusedGrads = tf.grads( + (x: tf.Tensor4D, w: tf.Tensor4D, b) => tf.fused.depthwiseConv2d({ + x, + filter: w, + strides, + pad, + dataFormat: 'NHWC', + dilations: [1, 1], + bias: b, + activation: 'relu' + })); + const [dxFused, dfilterFused, dbiasFused] = + fusedGrads([x, filter, bias], dy); + + const grads = tf.grads((x: tf.Tensor4D, filter: tf.Tensor4D, bias) => { + const conv = tf.depthwiseConv2d(x, filter, strides, pad); + const sum = tf.add(conv, bias); + return tf.relu(sum); + }); + const [dx, dfilter, dbias] = grads([x, filter, bias], dy); + + expectArraysClose(await dxFused.array(), await dx.array()); + expectArraysClose(await dfilterFused.array(), await dfilter.array()); + expectArraysClose(await dbiasFused.array(), await dbias.array()); + }); +}); + describeWithFlags('fused conv2d', ALL_ENVS, () => { it('basic', async () => { const inputDepth = 2; diff --git a/tfjs-core/src/ops/fused_util.ts b/tfjs-core/src/ops/fused_util.ts index a38a3ada9fc..6f76a87f19d 100644 --- a/tfjs-core/src/ops/fused_util.ts +++ b/tfjs-core/src/ops/fused_util.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {Tensor, Tensor3D} from '../tensor'; +import {Tensor, Tensor3D, Tensor4D} from '../tensor'; +import {Conv2DInfo} from './conv_util'; export type Activation = 'linear'|'relu'|'prelu'|'elu'; @@ -28,3 +29,12 @@ export type FusedBatchMatMulConfig = { activation?: Activation, preluActivationWeights?: Tensor }; + +export type FusedConv2DConfig = { + input: Tensor4D, + filter: Tensor4D, + convInfo: Conv2DInfo, + bias?: Tensor, + activation?: Activation, + preluActivationWeights?: Tensor +};