diff --git a/tfjs-backend-wasm/src/kernels/FusedConv2D.ts b/tfjs-backend-wasm/src/kernels/FusedConv2D.ts index e22ef305a73..77e426a8f4c 100644 --- a/tfjs-backend-wasm/src/kernels/FusedConv2D.ts +++ b/tfjs-backend-wasm/src/kernels/FusedConv2D.ts @@ -15,18 +15,12 @@ * ============================================================================= */ -import {backend_util, KernelConfig, KernelFunc, NamedTensorInfoMap, TensorInfo} from '@tensorflow/tfjs-core'; +import {backend_util, FusedConv2D, FusedConv2DAttrs, FusedConv2DInputs, KernelConfig, KernelFunc, Tensor4D} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; import {FusableActivation} from './types'; -interface FusedConv2DInputs extends NamedTensorInfoMap { - x: TensorInfo; - filter: TensorInfo; - bias?: TensorInfo; -} - let wasmFusedConv2d: ( xId: number, batchSize: number, inputHeight: number, inputWidth: number, filterId: number, filterHeight: number, filterWidth: number, biasId: number, @@ -66,11 +60,17 @@ function setup(backend: BackendWasm) { function fusedConv2d(args: { inputs: FusedConv2DInputs, backend: BackendWasm, - attrs: - {convInfo: backend_util.Conv2DInfo, activation: backend_util.Activation} + attrs: FusedConv2DAttrs }) { const {inputs, attrs, backend} = args; - const {convInfo, activation} = attrs; + const {x, filter, bias, preluActivationWeights} = inputs; + const {strides, pad, dilations, dataFormat, dimRoundingMode, activation} = + attrs; + + const convInfo = backend_util.computeConv2DInfo( + (x as Tensor4D).shape, (filter as Tensor4D).shape, strides, dilations, + pad, dimRoundingMode); + const fusedActivation = FusableActivation[activation as {} as keyof typeof FusableActivation]; if (fusedActivation == null) { @@ -79,7 +79,6 @@ function fusedConv2d(args: { `in the wasm backend.`); } - const {x, filter, bias, preluActivationWeights} = inputs; const xId = backend.dataIdMap.get(x.dataId).id; const filterId = backend.dataIdMap.get(filter.dataId).id; @@ -117,10 +116,10 @@ function fusedConv2d(args: { const inHeight = convInfo.inHeight; const inWidth = convInfo.inWidth; - if (convInfo.dataFormat !== 'channelsLast') { + if (dataFormat !== 'NHWC') { throw new Error( `wasm backend FusedConv2D does not support dataFormat:'` + - `${convInfo.dataFormat}'. Please use 'channelsLast'.`); + `${dataFormat}'. 
Please use 'NHWC'.`); } const out = backend.makeOutput(convInfo.outShape, 'float32'); @@ -137,7 +136,7 @@ function fusedConv2d(args: { } export const fusedConv2DConfig: KernelConfig = { - kernelName: 'FusedConv2D', + kernelName: FusedConv2D, backendName: 'wasm', setupFunc: setup, kernelFunc: fusedConv2d as {} as KernelFunc diff --git a/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts b/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts index e0b95f65c06..8231ccc521b 100644 --- a/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts +++ b/tfjs-backend-wasm/src/kernels/FusedDepthwiseConv2D.ts @@ -15,18 +15,12 @@ * ============================================================================= */ -import {backend_util, KernelConfig, KernelFunc, NamedTensorInfoMap, TensorInfo} from '@tensorflow/tfjs-core'; +import {backend_util, FusedDepthwiseConv2D, FusedDepthwiseConv2DAttrs, FusedDepthwiseConv2DInputs, KernelConfig, KernelFunc, Tensor4D} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; import {FusableActivation} from './types'; -interface FusedDepthwiseConv2DInputs extends NamedTensorInfoMap { - x: TensorInfo; - filter: TensorInfo; - bias?: TensorInfo; -} - let wasmFusedDepthwiseConv2d: ( xId: number, batchSize: number, inputHeight: number, inputWidth: number, filterId: number, filterHeight: number, filterWidth: number, biasId: number, @@ -38,7 +32,7 @@ let wasmFusedDepthwiseConv2d: ( function setup(backend: BackendWasm) { wasmFusedDepthwiseConv2d = - backend.wasm.cwrap('FusedDepthwiseConv2D', null /* void */, [ + backend.wasm.cwrap(FusedDepthwiseConv2D, null /* void */, [ 'number', // xId 'number', // batchSize 'number', // inputHeight @@ -67,11 +61,17 @@ function setup(backend: BackendWasm) { function fusedDepthwiseConv2d(args: { inputs: FusedDepthwiseConv2DInputs, backend: BackendWasm, - attrs: - {convInfo: backend_util.Conv2DInfo, activation: backend_util.Activation} + attrs: FusedDepthwiseConv2DAttrs }) { const {inputs, attrs, backend} = args; - const {convInfo, activation} = attrs; + const {x, filter, bias, preluActivationWeights} = inputs; + const {strides, pad, dilations, dataFormat, dimRoundingMode, activation} = + attrs; + + const convInfo = backend_util.computeConv2DInfo( + (x as Tensor4D).shape, (filter as Tensor4D).shape, strides, dilations, + pad, dimRoundingMode); + const fusedActivation = FusableActivation[activation as {} as keyof typeof FusableActivation]; if (fusedActivation == null) { @@ -80,7 +80,6 @@ function fusedDepthwiseConv2d(args: { `in the wasm backend.`); } - const {x, filter, bias, preluActivationWeights} = inputs; const xId = backend.dataIdMap.get(x.dataId).id; const filterId = backend.dataIdMap.get(filter.dataId).id; @@ -118,10 +117,10 @@ function fusedDepthwiseConv2d(args: { const inHeight = convInfo.inHeight; const inWidth = convInfo.inWidth; - if (convInfo.dataFormat !== 'channelsLast') { + if (dataFormat !== 'NHWC') { throw new Error( `wasm backend FusedDepthwiseConv2D does not support dataFormat:'` + - `${convInfo.dataFormat}'. Please use 'channelsLast'.`); + `${dataFormat}'. 
Please use 'NHWC'.`); } const out = backend.makeOutput(convInfo.outShape, 'float32'); @@ -138,7 +137,7 @@ function fusedDepthwiseConv2d(args: { } export const fusedDepthwiseConv2DConfig: KernelConfig = { - kernelName: 'FusedDepthwiseConv2D', + kernelName: FusedDepthwiseConv2D, backendName: 'wasm', setupFunc: setup, kernelFunc: fusedDepthwiseConv2d as {} as KernelFunc diff --git a/tfjs-backend-wasm/src/kernels/_FusedMatMul.ts b/tfjs-backend-wasm/src/kernels/_FusedMatMul.ts index 16349df81a9..3fcb21745a7 100644 --- a/tfjs-backend-wasm/src/kernels/_FusedMatMul.ts +++ b/tfjs-backend-wasm/src/kernels/_FusedMatMul.ts @@ -15,25 +15,12 @@ * ============================================================================= */ -import {KernelConfig, NamedAttrMap, NamedTensorInfoMap, TensorInfo} from '@tensorflow/tfjs-core'; +import {_FusedMatMul, _FusedMatMulAttrs, _FusedMatMulInputs, KernelConfig, KernelFunc} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; import {FusableActivation} from './types'; -interface FusedMatMulInputs extends NamedTensorInfoMap { - a: TensorInfo; - b: TensorInfo; - bias?: TensorInfo; - preluActivationWeights?: TensorInfo; -} - -interface FusedMatMulAttrs extends NamedAttrMap { - transposeA: boolean; - transposeB: boolean; - activation: FusableActivation; -} - let wasmFusedMatMul: ( aId: number, aShape: Uint8Array, aShapeSize: number, bId: number, bShape: Uint8Array, bShapeSize: number, transposeA: boolean, @@ -41,7 +28,7 @@ let wasmFusedMatMul: ( preluActivationWeightsId: number, outId: number) => void; function setup(backend: BackendWasm) { - wasmFusedMatMul = backend.wasm.cwrap('_FusedMatMul', null /* void */, [ + wasmFusedMatMul = backend.wasm.cwrap(_FusedMatMul, null /* void */, [ 'number', // a_id 'array', // a_shape 'number', // a_shape.length @@ -58,9 +45,9 @@ function setup(backend: BackendWasm) { } function fusedBatchMatMul(args: { - inputs: FusedMatMulInputs, + inputs: _FusedMatMulInputs, backend: BackendWasm, - attrs: FusedMatMulAttrs + attrs: _FusedMatMulAttrs }) { const {inputs, backend, attrs} = args; const {a, b, bias, preluActivationWeights} = inputs; @@ -114,8 +101,8 @@ function fusedBatchMatMul(args: { } export const fusedMatMulConfig: KernelConfig = { - kernelName: '_FusedMatMul', + kernelName: _FusedMatMul, backendName: 'wasm', setupFunc: setup, - kernelFunc: fusedBatchMatMul + kernelFunc: fusedBatchMatMul as {} as KernelFunc }; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index c21a3382585..507703a9298 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -117,7 +117,9 @@ const TEST_FILTERS: TestFilter[] = [ 'basic with elu', // Only fused relu, relu6, prelu activations // supported. 'gradient', // Gradients not defined yet. - 'NCHW', // xnn pack does not support channels first. + 'backProp input x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', // Gradients not + // defined. + 'NCHW', // xnn pack does not support channels first. // Issue: https://github.com/tensorflow/tfjs/issues/3104. // Actual != expected. 
'relu bias stride 2 x=[1,8,8,16] f=[3,3,16,1] s=[2,2] d=8 p=same', diff --git a/tfjs-core/src/backends/backend.ts b/tfjs-core/src/backends/backend.ts index a93f3552b15..66907f2fa8d 100644 --- a/tfjs-core/src/backends/backend.ts +++ b/tfjs-core/src/backends/backend.ts @@ -16,7 +16,7 @@ */ import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util'; -import {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_util'; +import {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_types'; import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor'; import {BackendValues, DataType, Rank, ShapeMap} from '../types'; diff --git a/tfjs-core/src/backends/backend_util.ts b/tfjs-core/src/backends/backend_util.ts index 6baa8b629f5..c1302627922 100644 --- a/tfjs-core/src/backends/backend_util.ts +++ b/tfjs-core/src/backends/backend_util.ts @@ -31,7 +31,8 @@ export * from '../ops/axis_util'; export * from '../ops/broadcast_util'; export * from '../ops/concat_util'; export * from '../ops/conv_util'; -export {Activation, FusedConv2DConfig} from '../ops/fused_util'; +export * from '../ops/fused_util'; +export * from '../ops/fused_types'; export * from '../ops/reduce_util'; export {BackendValues, TypedArray, upcastType, PixelData} from '../types'; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 6cd52eaa22c..425e35c14ce 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -21,6 +21,7 @@ import {ExplicitPadding} from '../src/ops/conv_util'; import {NamedTensorInfoMap, TensorInfo} from './kernel_registry'; +import {Activation} from './ops/fused_types'; import {DataType, PixelData} from './types'; export const Abs = 'Abs'; @@ -786,3 +787,50 @@ export interface RotateWithOffsetAttrs { fillValue: number|[number, number, number]; center: number|[number, number]; } + +export const _FusedMatMul = '_FusedMatMul'; +// tslint:disable-next-line: class-name +export interface _FusedMatMulInputs extends NamedTensorInfoMap { + a: TensorInfo; + b: TensorInfo; + bias?: TensorInfo; + preluActivationWeights?: TensorInfo; +} +// tslint:disable-next-line: class-name +export interface _FusedMatMulAttrs { + transposeA: boolean; + transposeB: boolean; + activation: Activation; +} + +export const FusedConv2D = 'FusedConv2D'; +export interface FusedConv2DInputs extends NamedTensorInfoMap { + x: TensorInfo; + filter: TensorInfo; + bias?: TensorInfo; + preluActivationWeights?: TensorInfo; +} +export interface FusedConv2DAttrs { + strides: [number, number]|number; + pad: 'valid'|'same'|number|ExplicitPadding; + dataFormat: 'NHWC'|'NCHW'; + dilations: [number, number]|number; + dimRoundingMode: 'floor'|'round'|'ceil'; + activation: Activation; +} + +export const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D'; +export interface FusedDepthwiseConv2DInputs extends NamedTensorInfoMap { + x: TensorInfo; + filter: TensorInfo; + bias?: TensorInfo; + preluActivationWeights?: TensorInfo; +} +export interface FusedDepthwiseConv2DAttrs { + strides: [number, number]|number; + pad: 'valid'|'same'|number; + dataFormat: 'NHWC'|'NCHW'; + dilations: [number, number]|number; + dimRoundingMode: 'floor'|'round'|'ceil'; + activation: Activation; +} diff --git a/tfjs-core/src/ops/fused_conv2d.ts b/tfjs-core/src/ops/fused_conv2d.ts new file mode 100644 index 00000000000..c1e86df2c6c --- /dev/null +++ b/tfjs-core/src/ops/fused_conv2d.ts @@ -0,0 +1,269 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE, ForwardFunc} from '../engine'; +import {customGrad} from '../gradients'; +import {FusedConv2D, FusedConv2DAttrs, FusedConv2DInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {conv2DBackpropFilter} from '../ops/conv2d_backprop_filter'; +import {conv2DBackpropInput} from '../ops/conv2d_backprop_input'; +import {Tensor, Tensor3D, Tensor4D} from '../tensor'; +import {GradSaveFunc, NamedTensorMap} from '../tensor_types'; +import {makeTypesMatch} from '../tensor_util'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {add} from './add'; +import * as broadcast_util from './broadcast_util'; +import {conv2d as unfusedConv2d} from './conv2d'; +import * as conv_util from './conv_util'; +import {Activation} from './fused_types'; +import {applyActivation, getFusedBiasGradient, getFusedDyActivation, shouldFuse} from './fused_util'; +import {op} from './operation'; + +/** + * Computes a 2D convolution over the input x, optionally fused with adding a + * bias and applying an activation. + * + * ```js + * const inputDepth = 2; + * const inShape = [2, 2, 2, inputDepth]; + * const outputDepth = 2; + * const fSize = 1; + * const pad = 0; + * const strides = 1; + * + * const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + * 16], inShape); + * const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth, + * outputDepth]); + * + * tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC', + * dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print(); + * ``` + * + * @param obj An object with the following properties: + * @param x The input tensor, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is + * assumed. + * @param filter The filter, rank 4, of shape + * `[filterHeight, filterWidth, inDepth, outDepth]`. + * @param strides The strides of the convolution: `[strideHeight, + * strideWidth]`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid` output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. Only "NHWC" is currently supported. + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * in atrous convolution. 
Defaults to `[1, 1]`. If `dilations` is a single
+ *     number, then `dilationHeight == dilationWidth`. If it is greater than
+ *     1, then all values of `strides` must be 1.
+ * @param dimRoundingMode The rounding mode used when computing output
+ *     dimensions if pad is a number. If none is provided, it will not round
+ *     and error if the output is of fractional size.
+ * @param bias Tensor to be added to the result.
+ * @param activation Name of activation kernel (defaults to `linear`) to be
+ *     applied
+ *     after biasAdd.
+ * @param preluActivationWeights Tensor of prelu weights to be applied as part
+ *     of a `prelu` activation, typically the same shape as `x`.
+ */
+function fusedConv2d_<T extends Tensor3D|Tensor4D>({
+  x,
+  filter,
+  strides,
+  pad,
+  dataFormat = 'NHWC',
+  dilations = [1, 1],
+  dimRoundingMode,
+  bias,
+  activation = 'linear',
+  preluActivationWeights
+}: {
+  x: T|TensorLike,
+  filter: Tensor4D|TensorLike,
+  strides: [number, number]|number,
+  pad: 'valid'|'same'|number|conv_util.ExplicitPadding,
+  dataFormat?: 'NHWC'|'NCHW',
+  dilations?: [number, number]|number,
+  dimRoundingMode?: 'floor'|'round'|'ceil',
+  bias?: Tensor|TensorLike,
+  activation?: Activation,
+  preluActivationWeights?: Tensor
+}): T {
+  activation = activation || 'linear';
+
+  if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
+    let result = unfusedConv2d(
+        x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
+    if (bias != null) {
+      result = add(result, bias);
+    }
+
+    return applyActivation(result, activation, preluActivationWeights) as T;
+  }
+
+  const $x = convertToTensor(x, 'x', 'conv2d');
+  const $filter = convertToTensor(filter, 'filter', 'conv2d');
+
+  let x4D = $x as Tensor4D;
+  let reshapedTo4D = false;
+
+  if ($x.rank === 3) {
+    reshapedTo4D = true;
+    x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);
+  }
+  util.assert(
+      x4D.rank === 4,
+      () => `Error in fused conv2d: input must be rank 4, but got rank ` +
+          `${x4D.rank}.`);
+  util.assert(
+      $filter.rank === 4,
+      () => `Error in fused conv2d: filter must be rank 4, but got rank ` +
+          `${$filter.rank}.`);
+  if (dimRoundingMode != null) {
+    util.assert(
+        util.isInt(pad as number),
+        () => `Error in fused conv2d: pad must be an integer when using, ` +
+            `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
+  }
+
+  util.assert(
+      x4D.shape[3] === $filter.shape[2],
+      () => `Error in conv2d: depth of input (${x4D.shape[3]}) must match ` +
+          `input depth for filter ${$filter.shape[2]}.`);
+  util.assert(
+      conv_util.eitherStridesOrDilationsAreOne(strides, dilations),
+      () => 'Error in conv2D: Either strides or dilations must be 1. ' +
+          `Got strides ${strides} and dilations '${dilations}'`);
+  util.assert(
+      dataFormat === 'NHWC',
+      () => `Error in conv2d: got dataFormat of ${
+          dataFormat} but only NHWC is currently supported.`);
+
+  const convInfo = conv_util.computeConv2DInfo(
+      x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode);
+
+  let $bias: Tensor;
+  if (bias != null) {
+    $bias = convertToTensor(bias, 'bias', 'fused conv2d');
+    [$bias] = makeTypesMatch($bias, $x);
+
+    broadcast_util.assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
+  }
+
+  let $preluActivationWeights: Tensor;
+  if (preluActivationWeights != null) {
+    $preluActivationWeights = convertToTensor(
+        preluActivationWeights, 'prelu weights', 'fused conv2d');
+  }
+
+  const grad = (dy: Tensor4D, saved: Tensor[]) => {
+    const [$filter, x4D, y, $bias] =
+        saved as [Tensor4D, Tensor4D, Tensor4D, Tensor];
+
+    const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D;
+
+    util.assert(
+        conv_util.tupleValuesAreOne(dilations),
+        () => 'Error in gradient of fused conv2D: ' +
+            `dilation rates greater than 1 ` +
+            `are not yet supported in gradients. Got dilations '${dilations}'`);
+
+    const xDer =
+        conv2DBackpropInput(x4D.shape, dyActivation, $filter, strides, pad);
+    const filterDer =
+        conv2DBackpropFilter(x4D, dyActivation, $filter.shape, strides, pad);
+    const der: Tensor[] = [xDer, filterDer];
+
+    if ($bias != null) {
+      const biasDer = getFusedBiasGradient($bias, dyActivation);
+      der.push(biasDer);
+    }
+    return der;
+  };
+
+  const forward: ForwardFunc<Tensor4D> = (backend) => {
+    const res = backend.fusedConv2d({
+      input: x4D,
+      filter: $filter,
+      convInfo,
+      bias: $bias,
+      activation,
+      preluActivationWeights: $preluActivationWeights
+    });
+    return res;
+  };
+
+  const inputs: FusedConv2DInputs = {
+    x: x4D,
+    filter: $filter,
+    bias: $bias,
+    preluActivationWeights: $preluActivationWeights
+  };
+
+  const attrs: FusedConv2DAttrs =
+      {strides, pad, dataFormat, dilations, dimRoundingMode, activation};
+
+  // Depending on the params passed in we will have a different number of
+  // inputs and thus a different number of elements in the gradient.
+  if (bias == null) {
+    const customOp =
+        customGrad((x4D: Tensor4D, filter: Tensor4D, save: GradSaveFunc) => {
+          let res = ENGINE.runKernelFunc(
+              forward, inputs as {} as NamedTensorMap, null /* grad */,
+              FusedConv2D, attrs as {} as NamedAttrMap);
+
+          save([filter, x4D, res]);
+
+          if (reshapedTo4D) {
+            res = res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;
+          }
+
+          return {value: res, gradFunc: grad};
+        });
+    return customOp(x4D, $filter) as T;
+  } else {
+    const customOpWithBias = customGrad(
+        (x4D: Tensor4D, filter: Tensor4D, bias: Tensor, save: GradSaveFunc) => {
+          let res = ENGINE.runKernelFunc(
+              forward, inputs as {} as NamedTensorMap, null /* grad */,
+              FusedConv2D, attrs as {} as NamedAttrMap);
+
+          save([filter, x4D, res, bias]);
+
+          if (reshapedTo4D) {
+            res = res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T;
+          }
+
+          return {value: res, gradFunc: grad};
+        });
+
+    return customOpWithBias(x4D, $filter, $bias) as T;
+  }
+}
+export const conv2d = op({fusedConv2d_});
diff --git a/tfjs-core/src/ops/fused_test.ts b/tfjs-core/src/ops/fused_conv2d_test.ts
similarity index 62%
rename from tfjs-core/src/ops/fused_test.ts
rename to tfjs-core/src/ops/fused_conv2d_test.ts
index 7003c3070de..fa20cd0d87d 100644
--- a/tfjs-core/src/ops/fused_test.ts
+++ b/tfjs-core/src/ops/fused_conv2d_test.ts
@@ -1,6 +1,6 @@
 /**
  * @license
- * Copyright 2019 Google LLC. 
All Rights Reserved. + * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -34,530 +34,6 @@ function generateCaseInputs(totalSizeTensor: number, totalSizeFilter: number) { return {input: inp, filter: filt}; } -describeWithFlags('fused matmul', ALL_ENVS, () => { - it('fused A x B', async () => { - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); - - const c = tf.fused.matMul({a, b}); - - expect(c.shape).toEqual([2, 2]); - expectArraysClose(await c.data(), [0, 8, -3, 20]); - }); - - it('fused A x B with relu', async () => { - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); - const transposeA = false; - const transposeB = false; - - const c = tf.fused.matMul( - {a, b, transposeA, transposeB, bias: null, activation: 'relu'}); - - expect(c.shape).toEqual([2, 2]); - expectArraysClose(await c.data(), [0, 8, 0, 20]); - }); - - it('fused A x B with elu', async () => { - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); - const transposeA = false; - const transposeB = false; - - const c = tf.fused.matMul( - {a, b, transposeA, transposeB, bias: null, activation: 'elu'}); - - expect(c.shape).toEqual([2, 2]); - expectArraysClose(await c.data(), [0, 8, -0.9502, 20]); - }); - - it('fused A x B with relu6', async () => { - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); - const transposeA = false; - const transposeB = false; - - const c = tf.fused.matMul( - {a, b, transposeA, transposeB, bias: null, activation: 'relu6'}); - - expect(c.shape).toEqual([2, 2]); - expectArraysClose(await c.data(), [0, 6, 0, 6]); - }); - - it('fused A x B with prelu', async () => { - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); - const alpha = tf.tensor2d([0.5, 0.5], [1, 2]); - const transposeA = false; - const transposeB = false; - - const c = tf.fused.matMul({ - a, - b, - transposeA, - transposeB, - bias: null, - activation: 'prelu', - preluActivationWeights: alpha - }); - - expect(c.shape).toEqual([2, 2]); - expectArraysClose(await c.data(), [0, 8, -1.5, 20]); - }); - - it('fused A x B with relu transpose', async () => { - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [2, 3]); - const transposeA = false; - const transposeB = true; - - const c = tf.fused.matMul( - {a, b, transposeA, transposeB, bias: null, activation: 'relu'}); - - expect(c.shape).toEqual([2, 2]); - expectArraysClose(await c.data(), [0, 9, 0, 24]); - }); - - it('fused A x B with 2d bias and relu', async () => { - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); - const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); - const transposeA = false; - const transposeB = false; - - const d = tf.fused.matMul( - {a, b, transposeA, transposeB, bias: c, activation: 'relu'}); - - expect(d.shape).toEqual([2, 2]); - expectArraysClose(await d.data(), [1, 9, 0, 21]); - }); - - it('fused A x B with relu and broadcasted bias', async () => { - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); - const c = tf.tensor1d([1, 1]); - const act: tf.fused.Activation = 'relu'; - const 
transposeA = false; - const transposeB = false; - - const d = tf.fused.matMul( - {a, b, transposeA, transposeB, bias: c, activation: act}); - - expect(d.shape).toEqual([2, 2]); - expectArraysClose(await d.data(), [1, 9, 0, 21]); - }); - - it('fused A x B with elu and broadcasted bias', async () => { - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); - const c = tf.tensor1d([1, 1]); - const act: tf.fused.Activation = 'elu'; - const transposeA = false; - const transposeB = false; - - const d = tf.fused.matMul( - {a, b, transposeA, transposeB, bias: c, activation: act}); - - expect(d.shape).toEqual([2, 2]); - expectArraysClose(await d.data(), [1, 9, -0.8647, 21]); - }); - - it('fused A x B with relu and broadcasted bias different rank', async () => { - const a = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [2, 2, 3]); - const b = tf.tensor3d([0, 1, -3, 2, 2, 1, 0, 1, -3, 2, 2, 1], [2, 3, 2]); - const c = tf.tensor2d([1, 2], [1, 2]); - const act: tf.fused.Activation = 'relu'; - const transposeA = false; - const transposeB = false; - - const d = tf.fused.matMul( - {a, b, transposeA, transposeB, bias: c, activation: act}); - - expect(d.shape).toEqual([2, 2, 2]); - expectArraysClose(await d.data(), [2, 6, 0, 18, 0, 30, 0, 42]); - }); - - it('fused A x B with 2d bias only', async () => { - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); - const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); - const transposeA = false; - const transposeB = false; - - const d = tf.fused.matMul( - {a, b, transposeA, transposeB, bias: c, activation: 'linear'}); - - expect(d.shape).toEqual([2, 2]); - expectArraysClose(await d.data(), [1, 9, -2, 21]); - }); - - it('fused A x B with relu gradient', async () => { - const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); - const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); - const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); - const transposeA = false; - const transposeB = false; - - const grads = tf.grads((a, b) => { - const prod = tf.matMul(a, b, transposeA, transposeB); - return tf.relu(prod); - }); - - const fusedGrads = tf.grads((a, b) => { - return tf.fused.matMul( - {a, b, transposeA, transposeB, bias: null, activation: 'relu'}); - }); - - const [da, db] = grads([a, b], dy); - const [fusedDa, fusedDb] = fusedGrads([a, b], dy); - expectArraysClose(await da.array(), await fusedDa.array()); - expectArraysClose(await db.data(), await fusedDb.array()); - }); - - it('gradient with clones A x B with relu', () => { - const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); - const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); - const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); - const transposeA = false; - const transposeB = false; - - const fusedGrads = tf.grads((a, b) => { - return tf.fused - .matMul({ - a: a.clone(), - b: b.clone(), - transposeA, - transposeB, - bias: null, - activation: 'relu' - }) - .clone(); - }); - - const [fusedDa, fusedDb] = fusedGrads([a, b], dy); - expect(fusedDa.shape).toEqual(a.shape); - expect(fusedDb.shape).toEqual(b.shape); - }); - - it('fused A x B with relu bias gradient', async () => { - const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); - const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); - const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); - const transposeA = false; - const transposeB = false; - - const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); - - const grads = tf.grads((a, b, c) => { - const prod = tf.matMul(a, b, transposeA, 
transposeB); - const sum = tf.add(prod, c); - return tf.relu(sum); - }); - - const fusedGrads = tf.grads((a, b, c) => { - return tf.fused.matMul( - {a, b, transposeA, transposeB, bias: c, activation: 'relu'}); - }); - - const [da, db, dc] = grads([a, b, c], dy); - const [fusedDa, fusedDb, fusedDc] = fusedGrads([a, b, c], dy); - - expectArraysClose(await da.array(), await fusedDa.array()); - expectArraysClose(await db.array(), await fusedDb.array()); - expectArraysClose(await dc.array(), await fusedDc.array()); - }); - - it('fused A x B with relu bias gradient transpose', async () => { - const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [3, 2]); - const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); - const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); - const transposeA = true; - const transposeB = false; - - const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); - - const grads = tf.grads((a, b, c) => { - const prod = tf.matMul(a, b, transposeA, transposeB); - const sum = tf.add(prod, c); - return tf.relu(sum); - }); - - const fusedGrads = tf.grads((a, b, c) => { - return tf.fused.matMul( - {a, b, transposeA, transposeB, bias: c, activation: 'relu'}); - }); - - const [da, db, dc] = grads([a, b, c], dy); - const [fusedDa, fusedDb, fusedDc] = fusedGrads([a, b, c], dy); - - expectArraysClose(await da.array(), await fusedDa.array()); - expectArraysClose(await db.array(), await fusedDb.array()); - expectArraysClose(await dc.array(), await fusedDc.array()); - }); - - it('fused A x B with relu and broadcasted bias gradient', async () => { - const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); - const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); - const c = tf.tensor2d([[1]]); - const transposeA = false; - const transposeB = false; - - const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); - - const grads = tf.grads((a, b, c) => { - const prod = tf.matMul(a, b, transposeA, transposeB); - const sum = tf.add(prod, c); - return tf.relu(sum); - }); - - const fusedGrads = tf.grads((a, b, c) => { - return tf.fused.matMul( - {a, b, transposeA, transposeB, bias: c, activation: 'relu'}); - }); - - const [da, db, dc] = grads([a, b, c], dy); - const [fusedDa, fusedDb, fusedDc] = fusedGrads([a, b, c], dy); - - expectArraysClose(await da.array(), await fusedDa.array()); - expectArraysClose(await db.array(), await fusedDb.array()); - expectArraysClose(await dc.array(), await fusedDc.array()); - }); -}); - -describeWithFlags('fused depthwiseConv2D', ALL_ENVS, () => { - it('basic', async () => { - const fSize = 2; - const pad = 'valid'; - const strides = 1; - const chMul = 1; - const inDepth = 1; - - const x = tf.tensor4d( - [ - 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, - 0.0741907, 0.409265, 0.351377 - ], - [1, 3, 3, inDepth]); - const w = tf.tensor4d( - [-0.303873, -0.229223, 0.144333, 0.803373], - [fSize, fSize, inDepth, chMul], - ); - - const result = tf.fused.depthwiseConv2d({x, filter: w, strides, pad}); - expect(result.shape).toEqual([1, 2, 2, 1]); - const expected = [0.47737, 0.40018, 0.00859, -0.09615]; - expectArraysClose(await result.data(), expected); - }); - - it('basic with relu', async () => { - const fSize = 2; - const pad = 'valid'; - const strides = 1; - const chMul = 1; - const inDepth = 1; - - const x = tf.tensor4d( - [ - 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, - 0.0741907, 0.409265, 0.351377 - ], - [1, 3, 3, inDepth]); - const w = tf.tensor4d( - [-0.303873, -0.229223, 0.144333, 0.803373], - [fSize, fSize, inDepth, chMul], - ); - - const result = tf.fused.depthwiseConv2d( - 
{x, filter: w, strides, pad, activation: 'relu'}); - expect(result.shape).toEqual([1, 2, 2, 1]); - const expected = [0.47737, 0.40018, 0.00859, 0]; - expectArraysClose(await result.data(), expected); - }); - - it('basic with broadcasted bias and relu', async () => { - const fSize = 2; - const pad = 'valid'; - const strides = 1; - const chMul = 1; - const inDepth = 1; - - const x = tf.tensor4d( - [ - 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, - 0.0741907, 0.409265, 0.351377 - ], - [1, 3, 3, inDepth]); - const w = tf.tensor4d( - [-0.303873, -0.229223, 0.144333, 0.803373], - [fSize, fSize, inDepth, chMul], - ); - - const result = tf.fused.depthwiseConv2d( - {x, filter: w, strides, pad, bias: tf.scalar(1), activation: 'relu'}); - expect(result.shape).toEqual([1, 2, 2, 1]); - const expected = [1.47737, 1.40018, 1.00859, 0.90385]; - expectArraysClose(await result.data(), expected); - }); - - it('prelu', async () => { - const fSize = 3; - const pad = 'valid'; - const strides = 1; - const chMul = 1; - const inDepth = 1; - - const x = tf.tensor4d( - [ - 0.149194, 0.089009, 0.654891, 0.083324, 0.537043, 0.644331, 0.563037, - 0.211859, 0.633501, 0.186427, 0.777034, 0.50001, 0.607341, 0.95303, - 0.696479, 0.050387, 0.62045, 0.728049, 0.028043, 0.437009, 0.712881, - 0.741935, 0.974474, 0.621102, 0.171411 - ], - [1, 5, 5, inDepth]); - const alpha = tf.tensor4d( - [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [1, 3, 3, 1]); - const w = tf.tensor4d( - [ - -0.125386, -0.975199, -0.640437, -0.281895, -0.990968, -0.347208, - -0.889702, -0.180695, -0.691992 - ], - [fSize, fSize, inDepth, chMul], - ); - - const result = tf.fused.depthwiseConv2d({ - x, - filter: w, - strides, - pad, - activation: 'prelu', - preluActivationWeights: alpha - }); - expect(result.shape).toEqual([1, 3, 3, 1]); - const expected = [ - -0.25400, -0.50118, -0.73622, -0.94068, -1.2298, -1.84585, -2.3089, - -2.7499, -2.64077 - ]; - expectArraysClose(await result.data(), expected); - }); - - it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', async () => { - const inputDepth = 1; - const outputDepth = 1; - const inputShape: [number, number, number, number] = [2, 3, 3, inputDepth]; - const filterSize = 2; - const strides = 1; - const pad = 0; - - const filterShape: [number, number, number, number] = - [filterSize, filterSize, inputDepth, outputDepth]; - const filter = tf.tensor4d([-1, 1, -2, 0.5], filterShape); - - const x = tf.tensor4d( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9], inputShape); - const dy = tf.tensor4d([3, 1, 2, 0, 3, 1, 2, 0], [2, 2, 2, 1]); - - const grads = tf.grads( - (x: tf.Tensor4D, filter: tf.Tensor4D) => - tf.fused.depthwiseConv2d({x, filter, strides, pad})); - const [dx, dfilter] = grads([x, filter], dy); - - expect(dx.shape).toEqual(x.shape); - expectArraysClose( - await dx.data(), - [-3, 2, 1, -8, 1.5, 0.5, -4, 1, 0, -3, 2, 1, -8, 1.5, 0.5, -4, 1, 0]); - - expect(dfilter.shape).toEqual(filterShape); - expectArraysClose(await dfilter.data(), [26, 38, 62, 74]); - }); - - it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias', async () => { - const inputDepth = 1; - const outputDepth = 1; - const inputShape: [number, number, number, number] = [2, 3, 3, inputDepth]; - const filterSize = 2; - const strides = 1; - const pad = 0; - - const filterShape: [number, number, number, number] = - [filterSize, filterSize, inputDepth, outputDepth]; - const filter = tf.tensor4d([-1, 1, -2, 0.5], filterShape); - const bias = tf.ones([2, 2, 2, 1]); - - const x = tf.tensor4d( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 
1, 2, 3, 4, 5, 6, 7, 8, 9], inputShape); - const dy = tf.tensor4d([3, 1, 2, 0, 3, 1, 2, 0], [2, 2, 2, 1]); - - const fusedGrads = tf.grads( - (x: tf.Tensor4D, w: tf.Tensor4D, b) => tf.fused.depthwiseConv2d({ - x, - filter: w, - strides, - pad, - dataFormat: 'NHWC', - dilations: [1, 1], - bias: b - })); - const [dxFused, dfilterFused, dbiasFused] = - fusedGrads([x, filter, bias], dy); - - const grads = tf.grads((x: tf.Tensor4D, filter: tf.Tensor4D, bias) => { - const conv = tf.depthwiseConv2d(x, filter, strides, pad); - const sum = tf.add(conv, bias); - return sum; - }); - const [dx, dfilter, dbias] = grads([x, filter, bias], dy); - - expectArraysClose(await dxFused.array(), await dx.array()); - expectArraysClose(await dfilterFused.array(), await dfilter.array()); - expectArraysClose(await dbiasFused.array(), await dbias.array()); - }); - - it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias and activation', - async () => { - const inputDepth = 1; - const outputDepth = 1; - const inputShape: [number, number, number, number] = - [2, 3, 3, inputDepth]; - const filterSize = 2; - const strides = 1; - const pad = 0; - - const filterShape: [number, number, number, number] = - [filterSize, filterSize, inputDepth, outputDepth]; - const filter = tf.tensor4d([-1, 1, -2, 0.5], filterShape); - const bias = tf.ones([2, 2, 2, 1]); - - const x = tf.tensor4d( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9], inputShape); - const dy = tf.tensor4d([3, 1, 2, 0, 3, 1, 2, 0], [2, 2, 2, 1]); - - const fusedGrads = tf.grads( - (x: tf.Tensor4D, w: tf.Tensor4D, b) => tf.fused.depthwiseConv2d({ - x, - filter: w, - strides, - pad, - dataFormat: 'NHWC', - dilations: [1, 1], - bias: b, - activation: 'relu' - })); - const [dxFused, dfilterFused, dbiasFused] = - fusedGrads([x, filter, bias], dy); - - const grads = tf.grads((x: tf.Tensor4D, filter: tf.Tensor4D, bias) => { - const conv = tf.depthwiseConv2d(x, filter, strides, pad); - const sum = tf.add(conv, bias); - return tf.relu(sum); - }); - const [dx, dfilter, dbias] = grads([x, filter, bias], dy); - - expectArraysClose(await dxFused.array(), await dx.array()); - expectArraysClose(await dfilterFused.array(), await dfilter.array()); - expectArraysClose(await dbiasFused.array(), await dbias.array()); - }); -}); - describeWithFlags('fused conv2d', ALL_ENVS, () => { it('basic', async () => { const inputDepth = 2; @@ -1365,27 +841,4 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => { expectArraysClose(await dfilterFused.array(), await dfilter.array()); expectArraysClose(await dbiasFused.array(), await dbias.array()); }); - - it('fused matmul with relu6 and gradients', async () => { - const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); - const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); - const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); - const transposeA = false; - const transposeB = false; - - const fusedGrads = tf.grads((a, b) => { - return tf.fused.matMul( - {a, b, transposeA, transposeB, bias: null, activation: 'relu6'}); - }); - const [fusedDa, fusedDb] = fusedGrads([a, b], dy); - - const grads = tf.grads((a, b) => { - const prod = tf.matMul(a, b, transposeA, transposeB); - return tf.relu6(prod); - }); - const [da, db] = grads([a, b], dy); - - expectArraysClose(await da.array(), await fusedDa.array()); - expectArraysClose(await db.data(), await fusedDb.array()); - }); }); diff --git a/tfjs-core/src/ops/fused_depthwise_conv2d.ts b/tfjs-core/src/ops/fused_depthwise_conv2d.ts new file mode 100644 index 00000000000..5f62f1617c0 --- /dev/null +++ 
b/tfjs-core/src/ops/fused_depthwise_conv2d.ts @@ -0,0 +1,258 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE, ForwardFunc} from '../engine'; +import {customGrad} from '../gradients'; +import {FusedDepthwiseConv2D, FusedDepthwiseConv2DAttrs, FusedDepthwiseConv2DInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor, Tensor3D, Tensor4D} from '../tensor'; +import {GradSaveFunc, NamedTensorMap} from '../tensor_types'; +import {makeTypesMatch} from '../tensor_util'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {add} from './add'; +import * as broadcast_util from './broadcast_util'; +import * as conv_util from './conv_util'; +import {depthwiseConv2d as unfusedDepthwiseConv2d} from './depthwise_conv2d'; +import {depthwiseConv2dNativeBackpropFilter} from './depthwise_conv2d_native_backprop_filter'; +import {depthwiseConv2dNativeBackpropInput} from './depthwise_conv2d_native_backprop_input'; +import {Activation} from './fused_types'; +import {applyActivation, getFusedBiasGradient, getFusedDyActivation, shouldFuse} from './fused_util'; +import {op} from './operation'; + +/** + * Computes depthwise 2D convolution, optionally fused with adding a + * bias and applying an activation. + * + * Given a 4D `input` array and a `filter` array of shape + * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing + * `inChannels` convolutional filters of depth 1, this op applies a + * different filter to each input channel (expanding from 1 channel to + * `channelMultiplier` channels for each), then concatenates the results + * together. The output has `inChannels * channelMultiplier` channels. + * + * See + * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d]( + * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d) + * for more details. + * + * @param obj An object with the following properties: + * @param x The input tensor, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is + * assumed. + * @param filter The filter tensor, rank 4, of shape + * `[filterHeight, filterWidth, inChannels, channelMultiplier]`. + * @param strides The strides of the convolution: `[strideHeight, + * strideWidth]`. If strides is a single number, then `strideHeight == + * strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. 
+ *   - For more info, see this guide:
+ *     [https://www.tensorflow.org/api_guides/python/nn#Convolution](
+ *          https://www.tensorflow.org/api_guides/python/nn#Convolution)
+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
+ *     in which we sample input values across the height and width dimensions
+ *     in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
+ *     number, then `dilationHeight == dilationWidth`. If it is greater than
+ *     1, then all values of `strides` must be 1.
+ * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
+ *     "NHWC". Specify the data format of the input and output data. With the
+ *     default format "NHWC", the data is stored in the order of: [batch,
+ *     height, width, channels]. Only "NHWC" is currently supported.
+ * @param dimRoundingMode The rounding mode used when computing output
+ *     dimensions if pad is a number. If none is provided, it will not round
+ *     and error if the output is of fractional size.
+ * @param bias Tensor to be added to the result.
+ * @param activation Name of activation kernel (defaults to `linear`).
+ * @param preluActivationWeights Tensor of prelu weights to be applied as part
+ *     of a `prelu` activation, typically the same shape as `x`.
+ */
+function fusedDepthwiseConv2d_<T extends Tensor3D|Tensor4D>({
+  x,
+  filter,
+  strides,
+  pad,
+  dataFormat = 'NHWC',
+  dilations = [1, 1],
+  dimRoundingMode,
+  bias,
+  activation = 'linear',
+  preluActivationWeights
+}: {
+  x: T|TensorLike,
+  filter: Tensor4D|TensorLike,
+  strides: [number, number]|number,
+  pad: 'valid'|'same'|number,
+  dataFormat?: 'NHWC'|'NCHW',
+  dilations?: [number, number]|number,
+  dimRoundingMode?: 'floor'|'round'|'ceil',
+  bias?: Tensor|TensorLike,
+  activation?: Activation,
+  preluActivationWeights?: Tensor
+}): T {
+  if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
+    let result = unfusedDepthwiseConv2d(
+        x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
+    if (bias != null) {
+      result = add(result, bias);
+    }
+
+    return applyActivation(result, activation, preluActivationWeights) as T;
+  }
+
+  const $x = convertToTensor(x, 'x', 'depthwiseConv2d');
+  const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d');
+
+  let x4D = $x as Tensor4D;
+  let reshapedTo4D = false;
+  if ($x.rank === 3) {
+    reshapedTo4D = true;
+    x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);
+  }
+  util.assert(
+      x4D.rank === 4,
+      () => `Error in fused depthwiseConv2d: input must be rank 4, but got ` +
+          `rank ${x4D.rank}.`);
+  util.assert(
+      $filter.rank === 4,
+      () => `Error in fused depthwiseConv2d: filter must be rank 4, ` +
+          `but got rank ${$filter.rank}.`);
+  util.assert(
+      x4D.shape[3] === $filter.shape[2],
+      () => `Error in fused depthwiseConv2d: number of input channels ` +
+          `(${x4D.shape[3]}) must match the inChannels dimension in ` +
+          `filter ${$filter.shape[2]}.`);
+  if (dilations == null) {
+    dilations = [1, 1];
+  }
+  util.assert(
+      conv_util.eitherStridesOrDilationsAreOne(strides, dilations),
+      () =>
+          'Error in fused depthwiseConv2d: Either strides or dilations must ' +
+          `be 1. Got strides ${strides} and dilations '${dilations}'`);
+
+  if (dimRoundingMode != null) {
+    util.assert(
+        util.isInt(pad as number),
+        () => `Error in fused depthwiseConv2d: pad must be an integer when ` +
+            `using dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
+  }
+
+  const convInfo = conv_util.computeConv2DInfo(
+      x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode,
+      true /* depthwise */);
+
+  let $bias: Tensor;
+  if (bias != null) {
+    $bias = convertToTensor(bias, 'bias', 'fused conv2d');
+    [$bias] = makeTypesMatch($bias, $x);
+
+    broadcast_util.assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
+  }
+
+  let $preluActivationWeights: Tensor;
+  if (preluActivationWeights != null) {
+    $preluActivationWeights = convertToTensor(
+        preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d');
+  }
+
+  const grad = (dy: Tensor4D, saved: Tensor[]) => {
+    util.assert(
+        conv_util.tupleValuesAreOne(dilations),
+        () => 'Error in gradient of fused depthwiseConv2d: dilation rates ' +
+            `greater than 1 are not yet supported. Got dilations ` +
+            `'${dilations}'`);
+    const [$filter, x4D, y, bias] = saved;
+
+    const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D;
+
+    const xDer = depthwiseConv2dNativeBackpropInput(
+        (x4D as Tensor4D).shape, dyActivation, $filter as Tensor4D, convInfo);
+    const filterDer = depthwiseConv2dNativeBackpropFilter(
+        x4D as Tensor4D, dyActivation, ($filter as Tensor4D).shape, convInfo);
+
+    if (bias != null) {
+      const biasDer = getFusedBiasGradient($bias, dyActivation);
+      return [xDer, filterDer, biasDer];
+    }
+    return [xDer, filterDer];
+  };
+
+  const forward: ForwardFunc<Tensor4D> = (backend) => {
+    const res = backend.fusedDepthwiseConv2D({
+      input: x4D,
+      filter: $filter,
+      convInfo,
+      bias: $bias,
+      activation,
+      preluActivationWeights: $preluActivationWeights
+    });
+    return res;
+  };
+
+  const inputs: FusedDepthwiseConv2DInputs = {
+    x: x4D,
+    filter: $filter,
+    bias: $bias,
+    preluActivationWeights: $preluActivationWeights
+  };
+  const attrs: FusedDepthwiseConv2DAttrs =
+      {strides, pad, dataFormat, dilations, dimRoundingMode, activation};
+
+  // Depending on the params passed in we will have a different number of
+  // inputs and thus a different number of elements in the gradient.
+ if (bias == null) { + const customOp = + customGrad((x4D: Tensor4D, filter: Tensor4D, save: GradSaveFunc) => { + let res = ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* grad */, + FusedDepthwiseConv2D, attrs as {} as NamedAttrMap); + + save([filter, x4D, res]); + + if (reshapedTo4D) { + res = res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; + } + + return {value: res, gradFunc: grad}; + }); + return customOp(x4D, $filter) as T; + } else { + const customOpWithBias = customGrad( + (x4D: Tensor4D, filter: Tensor4D, bias: Tensor, save: GradSaveFunc) => { + let res = ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* grad */, + FusedDepthwiseConv2D, attrs as {} as NamedAttrMap); + + save([filter, x4D, res, bias]); + + if (reshapedTo4D) { + res = res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; + } + + return {value: res, gradFunc: grad}; + }); + + return customOpWithBias(x4D, $filter, $bias) as T; + } +} +export const depthwiseConv2d = op({fusedDepthwiseConv2d_}); diff --git a/tfjs-core/src/ops/fused_depthwise_conv2d_test.ts b/tfjs-core/src/ops/fused_depthwise_conv2d_test.ts new file mode 100644 index 00000000000..ada8531b9c8 --- /dev/null +++ b/tfjs-core/src/ops/fused_depthwise_conv2d_test.ts @@ -0,0 +1,254 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('fused depthwiseConv2D', ALL_ENVS, () => { + it('basic', async () => { + const fSize = 2; + const pad = 'valid'; + const strides = 1; + const chMul = 1; + const inDepth = 1; + + const x = tf.tensor4d( + [ + 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, + 0.0741907, 0.409265, 0.351377 + ], + [1, 3, 3, inDepth]); + const w = tf.tensor4d( + [-0.303873, -0.229223, 0.144333, 0.803373], + [fSize, fSize, inDepth, chMul], + ); + + const result = tf.fused.depthwiseConv2d({x, filter: w, strides, pad}); + expect(result.shape).toEqual([1, 2, 2, 1]); + const expected = [0.47737, 0.40018, 0.00859, -0.09615]; + expectArraysClose(await result.data(), expected); + }); + + it('basic with relu', async () => { + const fSize = 2; + const pad = 'valid'; + const strides = 1; + const chMul = 1; + const inDepth = 1; + + const x = tf.tensor4d( + [ + 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, + 0.0741907, 0.409265, 0.351377 + ], + [1, 3, 3, inDepth]); + const w = tf.tensor4d( + [-0.303873, -0.229223, 0.144333, 0.803373], + [fSize, fSize, inDepth, chMul], + ); + + const result = tf.fused.depthwiseConv2d( + {x, filter: w, strides, pad, activation: 'relu'}); + expect(result.shape).toEqual([1, 2, 2, 1]); + const expected = [0.47737, 0.40018, 0.00859, 0]; + expectArraysClose(await result.data(), expected); + }); + + it('basic with broadcasted bias and relu', async () => { + const fSize = 2; + const pad = 'valid'; + const strides = 1; + const chMul = 1; + const inDepth = 1; + + const x = tf.tensor4d( + [ + 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, + 0.0741907, 0.409265, 0.351377 + ], + [1, 3, 3, inDepth]); + const w = tf.tensor4d( + [-0.303873, -0.229223, 0.144333, 0.803373], + [fSize, fSize, inDepth, chMul], + ); + + const result = tf.fused.depthwiseConv2d( + {x, filter: w, strides, pad, bias: tf.scalar(1), activation: 'relu'}); + expect(result.shape).toEqual([1, 2, 2, 1]); + const expected = [1.47737, 1.40018, 1.00859, 0.90385]; + expectArraysClose(await result.data(), expected); + }); + + it('prelu', async () => { + const fSize = 3; + const pad = 'valid'; + const strides = 1; + const chMul = 1; + const inDepth = 1; + + const x = tf.tensor4d( + [ + 0.149194, 0.089009, 0.654891, 0.083324, 0.537043, 0.644331, 0.563037, + 0.211859, 0.633501, 0.186427, 0.777034, 0.50001, 0.607341, 0.95303, + 0.696479, 0.050387, 0.62045, 0.728049, 0.028043, 0.437009, 0.712881, + 0.741935, 0.974474, 0.621102, 0.171411 + ], + [1, 5, 5, inDepth]); + const alpha = tf.tensor4d( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [1, 3, 3, 1]); + const w = tf.tensor4d( + [ + -0.125386, -0.975199, -0.640437, -0.281895, -0.990968, -0.347208, + -0.889702, -0.180695, -0.691992 + ], + [fSize, fSize, inDepth, chMul], + ); + + const result = tf.fused.depthwiseConv2d({ + x, + filter: w, + strides, + pad, + activation: 'prelu', + preluActivationWeights: alpha + }); + expect(result.shape).toEqual([1, 3, 3, 1]); + const expected = [ + -0.25400, -0.50118, -0.73622, -0.94068, -1.2298, -1.84585, -2.3089, + -2.7499, -2.64077 + ]; + expectArraysClose(await result.data(), expected); + }); + + it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', async () => { + const inputDepth = 1; + const outputDepth = 1; + const inputShape: [number, number, number, number] = [2, 3, 3, 
inputDepth]; + const filterSize = 2; + const strides = 1; + const pad = 0; + + const filterShape: [number, number, number, number] = + [filterSize, filterSize, inputDepth, outputDepth]; + const filter = tf.tensor4d([-1, 1, -2, 0.5], filterShape); + + const x = tf.tensor4d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9], inputShape); + const dy = tf.tensor4d([3, 1, 2, 0, 3, 1, 2, 0], [2, 2, 2, 1]); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor4D) => + tf.fused.depthwiseConv2d({x, filter, strides, pad})); + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + expectArraysClose( + await dx.data(), + [-3, 2, 1, -8, 1.5, 0.5, -4, 1, 0, -3, 2, 1, -8, 1.5, 0.5, -4, 1, 0]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose(await dfilter.data(), [26, 38, 62, 74]); + }); + + it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias', async () => { + const inputDepth = 1; + const outputDepth = 1; + const inputShape: [number, number, number, number] = [2, 3, 3, inputDepth]; + const filterSize = 2; + const strides = 1; + const pad = 0; + + const filterShape: [number, number, number, number] = + [filterSize, filterSize, inputDepth, outputDepth]; + const filter = tf.tensor4d([-1, 1, -2, 0.5], filterShape); + const bias = tf.ones([2, 2, 2, 1]); + + const x = tf.tensor4d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9], inputShape); + const dy = tf.tensor4d([3, 1, 2, 0, 3, 1, 2, 0], [2, 2, 2, 1]); + + const fusedGrads = tf.grads( + (x: tf.Tensor4D, w: tf.Tensor4D, b) => tf.fused.depthwiseConv2d({ + x, + filter: w, + strides, + pad, + dataFormat: 'NHWC', + dilations: [1, 1], + bias: b + })); + const [dxFused, dfilterFused, dbiasFused] = + fusedGrads([x, filter, bias], dy); + + const grads = tf.grads((x: tf.Tensor4D, filter: tf.Tensor4D, bias) => { + const conv = tf.depthwiseConv2d(x, filter, strides, pad); + const sum = tf.add(conv, bias); + return sum; + }); + const [dx, dfilter, dbias] = grads([x, filter, bias], dy); + + expectArraysClose(await dxFused.array(), await dx.array()); + expectArraysClose(await dfilterFused.array(), await dfilter.array()); + expectArraysClose(await dbiasFused.array(), await dbias.array()); + }); + + it('gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias and activation', + async () => { + const inputDepth = 1; + const outputDepth = 1; + const inputShape: [number, number, number, number] = + [2, 3, 3, inputDepth]; + const filterSize = 2; + const strides = 1; + const pad = 0; + + const filterShape: [number, number, number, number] = + [filterSize, filterSize, inputDepth, outputDepth]; + const filter = tf.tensor4d([-1, 1, -2, 0.5], filterShape); + const bias = tf.ones([2, 2, 2, 1]); + + const x = tf.tensor4d( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9], inputShape); + const dy = tf.tensor4d([3, 1, 2, 0, 3, 1, 2, 0], [2, 2, 2, 1]); + + const fusedGrads = tf.grads( + (x: tf.Tensor4D, w: tf.Tensor4D, b) => tf.fused.depthwiseConv2d({ + x, + filter: w, + strides, + pad, + dataFormat: 'NHWC', + dilations: [1, 1], + bias: b, + activation: 'relu' + })); + const [dxFused, dfilterFused, dbiasFused] = + fusedGrads([x, filter, bias], dy); + + const grads = tf.grads((x: tf.Tensor4D, filter: tf.Tensor4D, bias) => { + const conv = tf.depthwiseConv2d(x, filter, strides, pad); + const sum = tf.add(conv, bias); + return tf.relu(sum); + }); + const [dx, dfilter, dbias] = grads([x, filter, bias], dy); + + expectArraysClose(await dxFused.array(), await dx.array()); + expectArraysClose(await 
dfilterFused.array(), await dfilter.array()); + expectArraysClose(await dbiasFused.array(), await dbias.array()); + }); +}); diff --git a/tfjs-core/src/ops/fused_mat_mul.ts b/tfjs-core/src/ops/fused_mat_mul.ts new file mode 100644 index 00000000000..7abf4e65e00 --- /dev/null +++ b/tfjs-core/src/ops/fused_mat_mul.ts @@ -0,0 +1,225 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE, ForwardFunc} from '../engine'; +import {customGrad} from '../gradients'; +import {_FusedMatMul, _FusedMatMulAttrs, _FusedMatMulInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor, Tensor3D} from '../tensor'; +import {GradSaveFunc, NamedTensorMap} from '../tensor_types'; +import {makeTypesMatch} from '../tensor_util'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {add} from './add'; +import * as broadcast_util from './broadcast_util'; +import {Activation} from './fused_types'; +import {applyActivation, getFusedBiasGradient, getFusedDyActivation, shouldFuse} from './fused_util'; +import {matMul as unfusedMatMul} from './mat_mul'; +import {op} from './operation'; +import {reshape} from './reshape'; + +/** + * Computes the dot product of two matrices with optional activation and bias. + * + * ```js + * const a = tf.tensor2d([-1, -2], [1, 2]); + * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * const bias = tf.tensor2d([1, 2], [1, 2]); + * + * tf.fused.matMul({a, b, bias, activation: 'relu'}).print(); + * ``` + * + * @param obj An object with the following properties: + * - `a` First matrix in dot product operation. + * - `b` Second matrix in dot product operation. + * - `transposeA` If true, `a` is transposed before multiplication. + * - `transposeB` If true, `b` is transposed before multiplication. + * - `bias` Matrix to be added to the result. + * - `activation` Name of activation kernel (defaults to `linear`). + * - `preluActivationWeights` Tensor of prelu weights. + */ +function fusedMatMul_({ + a, + b, + transposeA = false, + transposeB = false, + bias, + activation = 'linear', + preluActivationWeights +}: { + a: T|TensorLike, + b: T|TensorLike, + transposeA?: boolean, + transposeB?: boolean, + bias?: Tensor|TensorLike, + activation?: Activation, + preluActivationWeights?: Tensor +}): T { + if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) { + let result = unfusedMatMul(a, b, transposeA, transposeB); + if (bias != null) { + result = add(result, bias); + } + + return applyActivation(result, activation, preluActivationWeights) as T; + } + + let $a = convertToTensor(a, 'a', 'fused matMul'); + let $b = convertToTensor(b, 'b', 'fused matMul'); + [$a, $b] = makeTypesMatch($a, $b); + + const innerShapeA = + transposeA ? 
$a.shape[$a.rank - 2] : $a.shape[$a.rank - 1]; + const innerShapeB = + transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2]; + + const outerShapeA = + transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2]; + const outerShapeB = + transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1]; + + const outerDimsA = $a.shape.slice(0, -2); + const outerDimsB = $b.shape.slice(0, -2); + const batchDimA = util.sizeFromShape(outerDimsA); + const batchDimB = util.sizeFromShape(outerDimsB); + + util.assert( + $a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, + () => + `Error in fused matMul: inputs must have the same rank of at least ` + + `2, got ranks ${$a.rank} and ${$b.rank}.`); + + util.assert( + util.arraysEqual(outerDimsA, outerDimsB), + () => `Error in fused matMul: outer dimensions (${outerDimsA}) and (` + + `${outerDimsB}) of Tensors with shapes ${$a.shape} and ` + + `${$b.shape} must match.`); + + util.assert( + innerShapeA === innerShapeB, + () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` + + `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` + + `${$b.shape} and transposeA=${transposeA}` + + ` and transposeB=${transposeB} must match.`); + + const outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]); + + const a3D = transposeA ? $a.as3D(batchDimA, innerShapeA, outerShapeA) : + $a.as3D(batchDimA, outerShapeA, innerShapeA); + const b3D = transposeB ? $b.as3D(batchDimB, outerShapeB, innerShapeB) : + $b.as3D(batchDimB, innerShapeB, outerShapeB); + + let $bias: Tensor; + if (bias != null) { + $bias = convertToTensor(bias, 'bias', 'fused matMul'); + [$bias] = makeTypesMatch($bias, $a); + + broadcast_util.assertAndGetBroadcastShape(outShape, $bias.shape); + } + + let $preluActivationWeights: Tensor; + if (preluActivationWeights != null) { + $preluActivationWeights = convertToTensor( + preluActivationWeights, 'prelu weights', 'fused matMul'); + } + + const grad = (dy: Tensor3D, saved: Tensor[]) => { + const [a3D, b3D, y, $bias] = saved; + // we reshape dy because the result of the forward is not + // necessarily going to be a 3d tensor due to a reshape done at the end of + // the customOp. 
+ const dyActivation = + getFusedDyActivation(reshape(dy, y.shape), y, activation); + let aDer: Tensor; + let bDer: Tensor; + + if (!transposeA && !transposeB) { + aDer = unfusedMatMul(dyActivation, b3D, false, true); + bDer = unfusedMatMul(a3D, dyActivation, true, false); + } else if (!transposeA && transposeB) { + aDer = unfusedMatMul(dyActivation, b3D, false, false); + bDer = unfusedMatMul(dyActivation, a3D, true, false); + } else if (transposeA && !transposeB) { + aDer = unfusedMatMul(b3D, dyActivation, false, true); + bDer = unfusedMatMul(a3D, dyActivation, false, false); + } else { + aDer = unfusedMatMul(b3D, dyActivation, true, true); + bDer = unfusedMatMul(dyActivation, a3D, true, true); + } + + if (bias != null) { + const biasDer = getFusedBiasGradient($bias, dyActivation); + return [aDer, bDer, biasDer]; + } else { + return [aDer, bDer]; + } + }; + + const forward: ForwardFunc = (backend) => { + const y = backend.fusedBatchMatMul({ + a: a3D, + b: b3D, + transposeA, + transposeB, + bias: $bias, + activation, + preluActivationWeights: $preluActivationWeights + }); + return y; + }; + + const inputs: _FusedMatMulInputs = { + a: a3D, + b: b3D, + bias: $bias, + preluActivationWeights: $preluActivationWeights + }; + const attrs: _FusedMatMulAttrs = {transposeA, transposeB, activation}; + + // Depending on the the params passed in we will have different number of + // inputs and thus a a different number of elements in the gradient. + if (bias == null) { + const customOp = + customGrad((a3D: Tensor3D, b3D: Tensor3D, save: GradSaveFunc) => { + const res = ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* grad */, + _FusedMatMul, attrs as {} as NamedAttrMap); + + save([a3D, b3D, res]); + + return {value: reshape(res, outShape), gradFunc: grad}; + }); + return customOp(a3D, b3D) as T; + } else { + const customOpWithBias = customGrad( + (a3D: Tensor3D, b3D: Tensor3D, $bias: Tensor, save: GradSaveFunc) => { + const res = ENGINE.runKernelFunc( + forward, inputs as {} as NamedTensorMap, null /* grad */, + _FusedMatMul, attrs as {} as NamedAttrMap); + + save([a3D, b3D, res, $bias]); + + return {value: reshape(res, outShape), gradFunc: grad}; + }); + + return customOpWithBias(a3D, b3D, $bias) as T; + } +} + +export const matMul = op({fusedMatMul_}); diff --git a/tfjs-core/src/ops/fused_mat_mul_test.ts b/tfjs-core/src/ops/fused_mat_mul_test.ts new file mode 100644 index 00000000000..cb15f706b42 --- /dev/null +++ b/tfjs-core/src/ops/fused_mat_mul_test.ts @@ -0,0 +1,333 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('fused matmul', ALL_ENVS, () => { + it('fused A x B', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); + + const c = tf.fused.matMul({a, b}); + + expect(c.shape).toEqual([2, 2]); + expectArraysClose(await c.data(), [0, 8, -3, 20]); + }); + + it('fused A x B with relu', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); + const transposeA = false; + const transposeB = false; + + const c = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: null, activation: 'relu'}); + + expect(c.shape).toEqual([2, 2]); + expectArraysClose(await c.data(), [0, 8, 0, 20]); + }); + + it('fused A x B with elu', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); + const transposeA = false; + const transposeB = false; + + const c = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: null, activation: 'elu'}); + + expect(c.shape).toEqual([2, 2]); + expectArraysClose(await c.data(), [0, 8, -0.9502, 20]); + }); + + it('fused A x B with relu6', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); + const transposeA = false; + const transposeB = false; + + const c = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: null, activation: 'relu6'}); + + expect(c.shape).toEqual([2, 2]); + expectArraysClose(await c.data(), [0, 6, 0, 6]); + }); + + it('fused A x B with prelu', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); + const alpha = tf.tensor2d([0.5, 0.5], [1, 2]); + const transposeA = false; + const transposeB = false; + + const c = tf.fused.matMul({ + a, + b, + transposeA, + transposeB, + bias: null, + activation: 'prelu', + preluActivationWeights: alpha + }); + + expect(c.shape).toEqual([2, 2]); + expectArraysClose(await c.data(), [0, 8, -1.5, 20]); + }); + + it('fused A x B with relu transpose', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [2, 3]); + const transposeA = false; + const transposeB = true; + + const c = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: null, activation: 'relu'}); + + expect(c.shape).toEqual([2, 2]); + expectArraysClose(await c.data(), [0, 9, 0, 24]); + }); + + it('fused A x B with 2d bias and relu', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); + const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); + const transposeA = false; + const transposeB = false; + + const d = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: c, activation: 'relu'}); + + expect(d.shape).toEqual([2, 2]); + expectArraysClose(await d.data(), [1, 9, 0, 21]); + }); + + it('fused A x B with relu and broadcasted bias', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); + const c = tf.tensor1d([1, 1]); + const act: tf.fused.Activation = 'relu'; + const transposeA = false; + const transposeB = false; + + const d = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: c, activation: act}); + + 
expect(d.shape).toEqual([2, 2]); + expectArraysClose(await d.data(), [1, 9, 0, 21]); + }); + + it('fused A x B with elu and broadcasted bias', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); + const c = tf.tensor1d([1, 1]); + const act: tf.fused.Activation = 'elu'; + const transposeA = false; + const transposeB = false; + + const d = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: c, activation: act}); + + expect(d.shape).toEqual([2, 2]); + expectArraysClose(await d.data(), [1, 9, -0.8647, 21]); + }); + + it('fused A x B with relu and broadcasted bias different rank', async () => { + const a = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [2, 2, 3]); + const b = tf.tensor3d([0, 1, -3, 2, 2, 1, 0, 1, -3, 2, 2, 1], [2, 3, 2]); + const c = tf.tensor2d([1, 2], [1, 2]); + const act: tf.fused.Activation = 'relu'; + const transposeA = false; + const transposeB = false; + + const d = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: c, activation: act}); + + expect(d.shape).toEqual([2, 2, 2]); + expectArraysClose(await d.data(), [2, 6, 0, 18, 0, 30, 0, 42]); + }); + + it('fused A x B with 2d bias only', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); + const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); + const transposeA = false; + const transposeB = false; + + const d = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: c, activation: 'linear'}); + + expect(d.shape).toEqual([2, 2]); + expectArraysClose(await d.data(), [1, 9, -2, 21]); + }); + + it('fused A x B with relu gradient', async () => { + const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); + const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); + const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); + const transposeA = false; + const transposeB = false; + + const grads = tf.grads((a, b) => { + const prod = tf.matMul(a, b, transposeA, transposeB); + return tf.relu(prod); + }); + + const fusedGrads = tf.grads((a, b) => { + return tf.fused.matMul( + {a, b, transposeA, transposeB, bias: null, activation: 'relu'}); + }); + + const [da, db] = grads([a, b], dy); + const [fusedDa, fusedDb] = fusedGrads([a, b], dy); + expectArraysClose(await da.array(), await fusedDa.array()); + expectArraysClose(await db.data(), await fusedDb.array()); + }); + + it('gradient with clones A x B with relu', () => { + const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); + const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); + const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); + const transposeA = false; + const transposeB = false; + + const fusedGrads = tf.grads((a, b) => { + return tf.fused + .matMul({ + a: a.clone(), + b: b.clone(), + transposeA, + transposeB, + bias: null, + activation: 'relu' + }) + .clone(); + }); + + const [fusedDa, fusedDb] = fusedGrads([a, b], dy); + expect(fusedDa.shape).toEqual(a.shape); + expect(fusedDb.shape).toEqual(b.shape); + }); + + it('fused A x B with relu bias gradient', async () => { + const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); + const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); + const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); + const transposeA = false; + const transposeB = false; + + const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); + + const grads = tf.grads((a, b, c) => { + const prod = tf.matMul(a, b, transposeA, transposeB); + const sum = tf.add(prod, c); + return tf.relu(sum); + }); + + const fusedGrads = tf.grads((a, b, c) => { + return tf.fused.matMul( 
+ {a, b, transposeA, transposeB, bias: c, activation: 'relu'}); + }); + + const [da, db, dc] = grads([a, b, c], dy); + const [fusedDa, fusedDb, fusedDc] = fusedGrads([a, b, c], dy); + + expectArraysClose(await da.array(), await fusedDa.array()); + expectArraysClose(await db.array(), await fusedDb.array()); + expectArraysClose(await dc.array(), await fusedDc.array()); + }); + + it('fused A x B with relu bias gradient transpose', async () => { + const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [3, 2]); + const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); + const c = tf.tensor2d([1, 1, 1, 1], [2, 2]); + const transposeA = true; + const transposeB = false; + + const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); + + const grads = tf.grads((a, b, c) => { + const prod = tf.matMul(a, b, transposeA, transposeB); + const sum = tf.add(prod, c); + return tf.relu(sum); + }); + + const fusedGrads = tf.grads((a, b, c) => { + return tf.fused.matMul( + {a, b, transposeA, transposeB, bias: c, activation: 'relu'}); + }); + + const [da, db, dc] = grads([a, b, c], dy); + const [fusedDa, fusedDb, fusedDc] = fusedGrads([a, b, c], dy); + + expectArraysClose(await da.array(), await fusedDa.array()); + expectArraysClose(await db.array(), await fusedDb.array()); + expectArraysClose(await dc.array(), await fusedDc.array()); + }); + + it('fused A x B with relu and broadcasted bias gradient', async () => { + const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); + const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); + const c = tf.tensor2d([[1]]); + const transposeA = false; + const transposeB = false; + + const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); + + const grads = tf.grads((a, b, c) => { + const prod = tf.matMul(a, b, transposeA, transposeB); + const sum = tf.add(prod, c); + return tf.relu(sum); + }); + + const fusedGrads = tf.grads((a, b, c) => { + return tf.fused.matMul( + {a, b, transposeA, transposeB, bias: c, activation: 'relu'}); + }); + + const [da, db, dc] = grads([a, b, c], dy); + const [fusedDa, fusedDb, fusedDc] = fusedGrads([a, b, c], dy); + + expectArraysClose(await da.array(), await fusedDa.array()); + expectArraysClose(await db.array(), await fusedDb.array()); + expectArraysClose(await dc.array(), await fusedDc.array()); + }); + + it('fused matmul with relu6 and gradients', async () => { + const a = tf.tensor2d([1, 2, 3, 10, 20, -30], [2, 3]); + const b = tf.tensor2d([2, 3, 4, -1, 2, 3], [3, 2]); + const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]); + const transposeA = false; + const transposeB = false; + + const fusedGrads = tf.grads((a, b) => { + return tf.fused.matMul( + {a, b, transposeA, transposeB, bias: null, activation: 'relu6'}); + }); + const [fusedDa, fusedDb] = fusedGrads([a, b], dy); + + const grads = tf.grads((a, b) => { + const prod = tf.matMul(a, b, transposeA, transposeB); + return tf.relu6(prod); + }); + const [da, db] = grads([a, b], dy); + + expectArraysClose(await da.array(), await fusedDa.array()); + expectArraysClose(await db.data(), await fusedDb.array()); + }); +}); diff --git a/tfjs-core/src/ops/fused_ops.ts b/tfjs-core/src/ops/fused_ops.ts index f416d224a7d..32d8b26770c 100644 --- a/tfjs-core/src/ops/fused_ops.ts +++ b/tfjs-core/src/ops/fused_ops.ts @@ -15,663 +15,9 @@ * ============================================================================= */ -import {ENGINE} from '../engine'; -import * as conv_util from '../ops/conv_util'; -import {op} from '../ops/operation'; -import {Tensor, Tensor3D, Tensor4D} from '../tensor'; -import {makeTypesMatch} from '../tensor_util'; -import 
{convertToTensor} from '../tensor_util_env'; -import {TensorLike} from '../types'; -import * as util from '../util'; +import {conv2d} from './fused_conv2d'; +import {depthwiseConv2d} from './fused_depthwise_conv2d'; +import {matMul} from './fused_mat_mul'; +import {Activation} from './fused_types'; -import {add} from './add'; -import * as broadcast_util from './broadcast_util'; -import {conv2d as unfusedConv2d} from './conv2d'; -import {conv2DBackpropFilter} from './conv2d_backprop_filter'; -import {conv2DBackpropInput} from './conv2d_backprop_input'; -import {depthwiseConv2d as unfusedDepthwiseConv2d} from './depthwise_conv2d'; -import {depthwiseConv2dNativeBackpropFilter} from './depthwise_conv2d_native_backprop_filter'; -import {depthwiseConv2dNativeBackpropInput} from './depthwise_conv2d_native_backprop_input'; -import {elu} from './elu'; -import {Activation, shouldFuse} from './fused_util'; -import {matMul as unfusedMatMul} from './mat_mul'; -import {prelu} from './prelu'; -import {relu} from './relu'; -import {relu6} from './relu6'; - -// Returns gradient for fused activation. -const getFusedDyActivation = - (dy: Tensor, y: Tensor, activation: Activation): Tensor => { - if (activation == null || activation === 'linear') { - return dy; - } - if (activation === 'relu') { - return dy.mul(y.step()); - } - throw new Error( - `Gradient for activation ${activation} has not been ` + - `implemented yet.`); - }; - -// Returns gradient for fused bias. -const getFusedBiasGradient = (bias: Tensor, dyActivation: Tensor): Tensor => { - let res = dyActivation; - const reduceAxes = - broadcast_util.getReductionAxes(bias.shape, dyActivation.shape); - if (reduceAxes.length > 0) { - res = res.sum(reduceAxes); - } - return res.reshape(bias.shape); -}; - -const applyActivation = - (x: Tensor, activation: Activation, preluActivationWeights?: Tensor): - Tensor => { - if (activation === 'linear') { - return x; - } else if (activation === 'relu') { - return relu(x); - } else if (activation === 'elu') { - return elu(x); - } else if (activation === 'relu6') { - return relu6(x); - } else if (activation === 'prelu') { - return prelu(x, preluActivationWeights); - } - throw new Error(`Unknown fused activation ${activation}.`); - }; - -/** - * Computes the dot product of two matrices with optional activation and bias. - * - * ```js - * const a = tf.tensor2d([-1, -2], [1, 2]); - * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]); - * const bias = tf.tensor2d([1, 2], [1, 2]); - * - * tf.fused.matMul({a, b, bias, activation: 'relu'}).print(); - * ``` - * - * @param obj An object with the following properties: - * - `a` First matrix in dot product operation. - * - `b` Second matrix in dot product operation. - * - `transposeA` If true, `a` is transposed before multiplication. - * - `transposeB` If true, `b` is transposed before multiplication. - * - `bias` Matrix to be added to the result. - * - `activation` Name of activation kernel (defaults to `linear`). - * - `preluActivationWeights` Tensor of prelu weights. 
- */ -function fusedMatMul_({ - a, - b, - transposeA = false, - transposeB = false, - bias, - activation = 'linear', - preluActivationWeights -}: { - a: T|TensorLike, - b: T|TensorLike, - transposeA?: boolean, - transposeB?: boolean, - bias?: Tensor|TensorLike, - activation?: Activation, - preluActivationWeights?: Tensor -}): T { - if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) { - let result = unfusedMatMul(a, b, transposeA, transposeB); - if (bias != null) { - result = add(result, bias); - } - - return applyActivation(result, activation, preluActivationWeights) as T; - } - - let $a = convertToTensor(a, 'a', 'fused matMul'); - let $b = convertToTensor(b, 'b', 'fused matMul'); - [$a, $b] = makeTypesMatch($a, $b); - - const innerShapeA = - transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1]; - const innerShapeB = - transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2]; - - const outerShapeA = - transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2]; - const outerShapeB = - transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1]; - - const outerDimsA = $a.shape.slice(0, -2); - const outerDimsB = $b.shape.slice(0, -2); - const batchDimA = util.sizeFromShape(outerDimsA); - const batchDimB = util.sizeFromShape(outerDimsB); - - util.assert( - $a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, - () => - `Error in fused matMul: inputs must have the same rank of at least ` + - `2, got ranks ${$a.rank} and ${$b.rank}.`); - - util.assert( - util.arraysEqual(outerDimsA, outerDimsB), - () => `Error in fused matMul: outer dimensions (${outerDimsA}) and (` + - `${outerDimsB}) of Tensors with shapes ${$a.shape} and ` + - `${$b.shape} must match.`); - - util.assert( - innerShapeA === innerShapeB, - () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` + - `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` + - `${$b.shape} and transposeA=${transposeA}` + - ` and transposeB=${transposeB} must match.`); - - const outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]); - - const a3D = transposeA ? $a.as3D(batchDimA, innerShapeA, outerShapeA) : - $a.as3D(batchDimA, outerShapeA, innerShapeA); - const b3D = transposeB ? 
$b.as3D(batchDimB, outerShapeB, innerShapeB) : - $b.as3D(batchDimB, innerShapeB, outerShapeB); - - let $bias: Tensor; - if (bias != null) { - $bias = convertToTensor(bias, 'bias', 'fused matMul'); - [$bias] = makeTypesMatch($bias, $a); - - broadcast_util.assertAndGetBroadcastShape(outShape, $bias.shape); - } - - let $preluActivationWeights: Tensor; - if (preluActivationWeights != null) { - $preluActivationWeights = convertToTensor( - preluActivationWeights, 'prelu weights', 'fused matMul'); - } - - const grad = (dy: Tensor3D, saved: Tensor[]) => { - const [a3D, b3D, y] = saved; - const dyActivation = getFusedDyActivation(dy, y, activation); - - let biasGradient = {}; - if (bias != null) { - biasGradient = {bias: () => getFusedBiasGradient($bias, dyActivation)}; - } - - if (!transposeA && !transposeB) { - return Object.assign( - { - a: () => dyActivation.matMul(b3D as Tensor3D, false, true), - b: () => a3D.matMul(dyActivation, true, false) - }, - biasGradient); - } else if (!transposeA && transposeB) { - return Object.assign( - { - a: () => dyActivation.matMul(b3D as Tensor3D, false, false), - b: () => dyActivation.matMul(a3D as Tensor3D, true, false) - }, - biasGradient); - } else if (transposeA && !transposeB) { - return Object.assign( - { - a: () => b3D.matMul(dyActivation, false, true), - b: () => a3D.matMul(dyActivation, false, false) - }, - biasGradient); - } else { - return Object.assign( - { - a: () => b3D.matMul(dyActivation, true, true), - b: () => dyActivation.matMul(a3D as Tensor3D, true, true) - }, - biasGradient); - } - }; - - const inputs: - {a: Tensor, b: Tensor, - bias?: Tensor, - preluActivationWeights?: Tensor} = {a: a3D, b: b3D}; - if (bias != null) { - inputs.bias = $bias; - } - if (preluActivationWeights != null) { - inputs.preluActivationWeights = $preluActivationWeights; - } - - const inputsToSave = [a3D, b3D]; - const outputsToSave = [true]; - - const res = ENGINE.runKernelFunc( - (backend, save) => { - const y = backend.fusedBatchMatMul({ - a: a3D, - b: b3D, - transposeA, - transposeB, - bias: $bias, - activation, - preluActivationWeights: $preluActivationWeights - }); - save([a3D, b3D, y]); - return y; - }, - inputs, grad, '_FusedMatMul', {transposeA, transposeB, activation}, - inputsToSave, outputsToSave); - return res.reshape(outShape); -} - -/** - * Computes a 2D convolution over the input x, optionally fused with adding a - * bias and applying an activation. - * - * ```js - * const inputDepth = 2; - * const inShape = [2, 2, 2, inputDepth]; - * const outputDepth = 2; - * const fSize = 1; - * const pad = 0; - * const strides = 1; - * - * const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - * 16], inShape); - * const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth, - * outputDepth]); - * - * tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC', - * dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print(); - * ``` - * - * @param obj An object with the following properties: - * @param x The input tensor, of rank 4 or rank 3, of shape - * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is - * assumed. - * @param filter The filter, rank 4, of shape - * `[filterHeight, filterWidth, inDepth, outDepth]`. - * @param strides The strides of the convolution: `[strideHeight, - * strideWidth]`. - * @param pad The type of padding algorithm. - * - `same` and stride 1: output will be of same size as input, - * regardless of filter size. 
- * - `valid` output will be smaller than input if filter is larger - * than 1x1. - * - For more info, see this guide: - * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( - * https://www.tensorflow.org/api_guides/python/nn#Convolution) - * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to - * "NHWC". Specify the data format of the input and output data. With the - * default format "NHWC", the data is stored in the order of: [batch, - * height, width, channels]. Only "NHWC" is currently supported. - * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` - * in which we sample input values across the height and width dimensions - * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single - * number, then `dilationHeight == dilationWidth`. If it is greater than - * 1, then all values of `strides` must be 1. - * @param dimRoundingMode The rounding mode used when computing output - * dimensions if pad is a number. If none is provided, it will not round - * and error if the output is of fractional size. - * @param bias Tensor to be added to the result. - * @param activation Name of activation kernel (defaults to `linear`) to be - * applied - * after biasAdd. - * @param preluActivationWeights Tensor of prelu weights to be applied as part - * of a `prelu` activation, typically the same shape as `x`. - */ -function fusedConv2d_({ - x, - filter, - strides, - pad, - dataFormat = 'NHWC', - dilations = [1, 1], - dimRoundingMode, - bias, - activation = 'linear', - preluActivationWeights -}: { - x: T|TensorLike, - filter: Tensor4D|TensorLike, - strides: [number, number]|number, - pad: 'valid'|'same'|number|conv_util.ExplicitPadding, - dataFormat?: 'NHWC'|'NCHW', - dilations?: [number, number]|number, - dimRoundingMode?: 'floor'|'round'|'ceil', - bias?: Tensor|TensorLike, - activation?: Activation, - preluActivationWeights?: Tensor -}): T { - activation = activation || 'linear'; - if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) { - let result = unfusedConv2d( - x, filter, strides, pad, dataFormat, dilations, dimRoundingMode); - if (bias != null) { - result = add(result, bias); - } - - return applyActivation(result, activation, preluActivationWeights) as T; - } - - const $x = convertToTensor(x, 'x', 'conv2d'); - const $filter = convertToTensor(filter, 'filter', 'conv2d'); - - let x4D = $x as Tensor4D; - let reshapedTo4D = false; - - if ($x.rank === 3) { - reshapedTo4D = true; - x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); - } - util.assert( - x4D.rank === 4, - () => `Error in fused conv2d: input must be rank 4, but got rank ` + - `${x4D.rank}.`); - util.assert( - $filter.rank === 4, - () => `Error in fused conv2d: filter must be rank 4, but got rank ` + - `${$filter.rank}.`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in fused conv2d: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - - util.assert( - x4D.shape[3] === $filter.shape[2], - () => `Error in conv2d: depth of input (${x4D.shape[3]}) must match ` + - `input depth for filter ${$filter.shape[2]}.`); - util.assert( - conv_util.eitherStridesOrDilationsAreOne(strides, dilations), - () => 'Error in conv2D: Either strides or dilations must be 1. 
' + - `Got strides ${strides} and dilations '${dilations}'`); - util.assert( - dataFormat === 'NHWC', - () => `Error in conv2d: got dataFormat of ${ - dataFormat} but only NHWC is currently supported.`); - - const convInfo = conv_util.computeConv2DInfo( - x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode); - - let $bias: Tensor; - if (bias != null) { - $bias = convertToTensor(bias, 'bias', 'fused conv2d'); - [$bias] = makeTypesMatch($bias, $x); - - broadcast_util.assertAndGetBroadcastShape(convInfo.outShape, $bias.shape); - } - - let $preluActivationWeights: Tensor; - if (preluActivationWeights != null) { - $preluActivationWeights = convertToTensor( - preluActivationWeights, 'prelu weights', 'fused conv2d'); - } - - const grad = (dy: Tensor4D, saved: Tensor[]) => { - const [$filter, x4D, y] = saved as [Tensor4D, Tensor4D, Tensor4D]; - - const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D; - - util.assert( - conv_util.tupleValuesAreOne(dilations), - () => 'Error in gradient of fused conv2D: ' + - `dilation rates greater than 1 ` + - `are not yet supported in gradients. Got dilations '${dilations}'`); - - let biasGradient = {}; - if (bias != null) { - biasGradient = {bias: () => getFusedBiasGradient($bias, dyActivation)}; - } - - return Object.assign( - { - x: () => conv2DBackpropInput( - x4D.shape, dyActivation, $filter, strides, pad), - filter: () => conv2DBackpropFilter( - x4D, dyActivation, $filter.shape, strides, pad) - }, - biasGradient); - }; - - const inputs: { - x: Tensor, - filter: Tensor, - bias?: Tensor, - preluActivationWeights?: Tensor - } = {x: x4D, filter: $filter}; - if (bias != null) { - inputs.bias = $bias; - } - if (preluActivationWeights != null) { - inputs.preluActivationWeights = $preluActivationWeights; - } - - const inputsToSave = [$filter, x4D]; - const outputsToSave = [true]; // Save the only output. - const res = ENGINE.runKernelFunc( - (backend, save) => { - const res = backend.fusedConv2d({ - input: x4D, - filter: $filter, - convInfo, - bias: $bias, - activation, - preluActivationWeights: $preluActivationWeights - }); - save([$filter, x4D, res]); - return res; - }, - inputs, grad, 'FusedConv2D', {convInfo, activation}, inputsToSave, - outputsToSave); - - if (reshapedTo4D) { - return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; - } - - return res as T; -} - -/** - * Computes depthwise 2D convolution, optionally fused with adding a - * bias and applying an activation. - * - * Given a 4D `input` array and a `filter` array of shape - * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing - * `inChannels` convolutional filters of depth 1, this op applies a - * different filter to each input channel (expanding from 1 channel to - * `channelMultiplier` channels for each), then concatenates the results - * together. The output has `inChannels * channelMultiplier` channels. - * - * See - * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d]( - * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d) - * for more details. - * - * @param obj An object with the following properties: - * @param x The input tensor, of rank 4 or rank 3, of shape - * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is - * assumed. - * @param filter The filter tensor, rank 4, of shape - * `[filterHeight, filterWidth, inChannels, channelMultiplier]`. - * @param strides The strides of the convolution: `[strideHeight, - * strideWidth]`. 
If strides is a single number, then `strideHeight == - * strideWidth`. - * @param pad The type of padding algorithm. - * - `same` and stride 1: output will be of same size as input, - * regardless of filter size. - * - `valid`: output will be smaller than input if filter is larger - * than 1x1. - * - For more info, see this guide: - * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( - * https://www.tensorflow.org/api_guides/python/nn#Convolution) - * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` - * in which we sample input values across the height and width dimensions - * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single - * number, then `dilationHeight == dilationWidth`. If it is greater than - * 1, then all values of `strides` must be 1. - * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to - * "NHWC". Specify the data format of the input and output data. With the - * default format "NHWC", the data is stored in the order of: [batch, - * height, width, channels]. Only "NHWC" is currently supported. - * @param dimRoundingMode The rounding mode used when computing output - * dimensions if pad is a number. If none is provided, it will not round - * and error if the output is of fractional size. - * @param bias Tensor to be added to the result. - * @param activation Name of activation kernel (defaults to `linear`). - * @param preluActivationWeights Tensor of prelu weights to be applied as part - * of a `prelu` activation, typically the same shape as `x`. - */ -function fusedDepthwiseConv2d_({ - x, - filter, - strides, - pad, - dataFormat = 'NHWC', - dilations = [1, 1], - dimRoundingMode, - bias, - activation = 'linear', - preluActivationWeights -}: { - x: T|TensorLike, - filter: Tensor4D|TensorLike, - strides: [number, number]|number, - pad: 'valid'|'same'|number, - dataFormat?: 'NHWC'|'NCHW', - dilations?: [number, number]|number, - dimRoundingMode?: 'floor'|'round'|'ceil', - bias?: Tensor|TensorLike, - activation?: Activation, - preluActivationWeights?: Tensor -}): T { - if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) { - let result = unfusedDepthwiseConv2d( - x, filter, strides, pad, dataFormat, dilations, dimRoundingMode); - if (bias != null) { - result = add(result, bias); - } - - return applyActivation(result, activation, preluActivationWeights) as T; - } - - const $x = convertToTensor(x, 'x', 'depthwiseConv2d'); - const $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d'); - - let x4D = $x as Tensor4D; - let reshapedTo4D = false; - if ($x.rank === 3) { - reshapedTo4D = true; - x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); - } - util.assert( - x4D.rank === 4, - () => `Error in fused depthwiseConv2d: input must be rank 4, but got ` + - `rank ${x4D.rank}.`); - util.assert( - $filter.rank === 4, - () => `Error in fused depthwiseConv2d: filter must be rank 4, ` + - `but got rank ${$filter.rank}.`); - util.assert( - x4D.shape[3] === $filter.shape[2], - () => `Error in fused depthwiseConv2d: number of input channels ` + - `(${x4D.shape[3]}) must match the inChannels dimension in ` + - `filter ${$filter.shape[2]}.`); - if (dilations == null) { - dilations = [1, 1]; - } - util.assert( - conv_util.eitherStridesOrDilationsAreOne(strides, dilations), - () => - 'Error in fused depthwiseConv2d: Either strides or dilations must ' + - `be 1. 
Got strides ${strides} and dilations '${dilations}'`); - - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in fused depthwiseConv2d: pad must be an integer when ` + - `using dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - - const convInfo = conv_util.computeConv2DInfo( - x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, - true /* depthwise */); - - let $bias: Tensor; - if (bias != null) { - $bias = convertToTensor(bias, 'bias', 'fused conv2d'); - [$bias] = makeTypesMatch($bias, $x); - - broadcast_util.assertAndGetBroadcastShape(convInfo.outShape, $bias.shape); - } - - let $preluActivationWeights: Tensor; - if (preluActivationWeights != null) { - $preluActivationWeights = convertToTensor( - preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d'); - } - - const grad = (dy: Tensor4D, saved: Tensor[]) => { - util.assert( - conv_util.tupleValuesAreOne(dilations), - () => 'Error in gradient of fused depthwiseConv2d: dilation rates ' + - `greater than 1 are not yet supported. Got dilations ` + - `'${dilations}'`); - const [$filter, x4D, y] = saved; - - const dyActivation = getFusedDyActivation(dy, y, activation) as Tensor4D; - - let biasGradient = {}; - if (bias != null) { - biasGradient = {bias: () => getFusedBiasGradient($bias, dyActivation)}; - } - - return Object.assign( - { - x: () => depthwiseConv2dNativeBackpropInput( - (x4D as Tensor4D).shape, dyActivation, $filter as Tensor4D, - convInfo), - filter: () => depthwiseConv2dNativeBackpropFilter( - x4D as Tensor4D, dyActivation, ($filter as Tensor4D).shape, - convInfo), - }, - biasGradient); - }; - - const inputs: { - x: Tensor, - filter: Tensor, - bias?: Tensor, - preluActivationWeights?: Tensor - } = {x: x4D, filter: $filter}; - if (bias != null) { - inputs.bias = $bias; - } - if (preluActivationWeights != null) { - inputs.preluActivationWeights = $preluActivationWeights; - } - - const inputsToSave = [$filter, x4D]; - const outputsToSave = [true]; - const res = ENGINE.runKernelFunc( - (backend, save) => { - const res = backend.fusedDepthwiseConv2D({ - input: x4D, - filter: $filter, - convInfo, - bias: $bias, - activation, - preluActivationWeights: $preluActivationWeights - }); - save([$filter, x4D, res]); - return res; - }, - inputs, grad, 'FusedDepthwiseConv2D', {convInfo, activation}, - inputsToSave, outputsToSave); - if (reshapedTo4D) { - return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T; - } - return res as T; -} - -export const matMul = op({fusedMatMul_}); -export const conv2d = op({fusedConv2d_}); -export const depthwiseConv2d = op({fusedDepthwiseConv2d_}); - -export {Activation}; +export {Activation, conv2d, depthwiseConv2d, matMul}; diff --git a/tfjs-core/src/ops/fused_types.ts b/tfjs-core/src/ops/fused_types.ts new file mode 100644 index 00000000000..894e2708869 --- /dev/null +++ b/tfjs-core/src/ops/fused_types.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor, Tensor3D, Tensor4D} from '../tensor'; +import {Conv2DInfo} from './conv_util'; + +export type FusedConv2DConfig = { + input: Tensor4D, + filter: Tensor4D, + convInfo: Conv2DInfo, + bias?: Tensor, + activation?: Activation, + preluActivationWeights?: Tensor +}; + +export type FusedBatchMatMulConfig = { + a: Tensor3D, + b: Tensor3D, + transposeA: boolean, + transposeB: boolean, + bias?: Tensor, + activation?: Activation, + preluActivationWeights?: Tensor +}; + +export type Activation = 'linear'|'relu'|'prelu'|'elu'|'relu6'; diff --git a/tfjs-core/src/ops/fused_util.ts b/tfjs-core/src/ops/fused_util.ts index d9f16133966..a41c7a574bb 100644 --- a/tfjs-core/src/ops/fused_util.ts +++ b/tfjs-core/src/ops/fused_util.ts @@ -15,30 +15,56 @@ * ============================================================================= */ -import {Tensor, Tensor3D, Tensor4D} from '../tensor'; +import {Tensor} from '../tensor'; -import {Conv2DInfo} from './conv_util'; +import * as broadcast_util from './broadcast_util'; +import {elu} from './elu'; +import {Activation} from './fused_types'; +import {prelu} from './prelu'; +import {relu} from './relu'; +import {relu6} from './relu6'; -export type Activation = 'linear'|'relu'|'prelu'|'elu'|'relu6'; +// Returns gradient for fused activation. +export function getFusedDyActivation( + dy: Tensor, y: Tensor, activation: Activation): Tensor { + if (activation == null || activation === 'linear') { + return dy; + } + if (activation === 'relu') { + return dy.mul(y.step()); + } + throw new Error( + `Cannot compute gradient for fused activation ${activation}.`); +} -export type FusedBatchMatMulConfig = { - a: Tensor3D, - b: Tensor3D, - transposeA: boolean, - transposeB: boolean, - bias?: Tensor, - activation?: Activation, - preluActivationWeights?: Tensor -}; +// Returns gradient for fused bias. +export function getFusedBiasGradient( + bias: Tensor, dyActivation: Tensor): Tensor { + let res = dyActivation; + const reduceAxes = + broadcast_util.getReductionAxes(bias.shape, dyActivation.shape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape(bias.shape); +} -export type FusedConv2DConfig = { - input: Tensor4D, - filter: Tensor4D, - convInfo: Conv2DInfo, - bias?: Tensor, - activation?: Activation, - preluActivationWeights?: Tensor -}; +export function applyActivation( + x: Tensor, activation: Activation, + preluActivationWeights?: Tensor): Tensor { + if (activation === 'linear') { + return x; + } else if (activation === 'relu') { + return relu(x); + } else if (activation === 'elu') { + return elu(x); + } else if (activation === 'relu6') { + return relu6(x); + } else if (activation === 'prelu') { + return prelu(x, preluActivationWeights); + } + throw new Error(`Unknown fused activation ${activation}.`); +} // Whether we should call fused ops. 
export const shouldFuse = (gradientDepth: number, activation: Activation) => { diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index 0e08d5965d2..e24f27b30a9 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -103,7 +103,9 @@ import './ops/fill_test'; import './ops/floor_test'; import './ops/frame_test'; import './ops/from_pixels_test'; -import './ops/fused_test'; +import './ops/fused_conv2d_test'; +import './ops/fused_depthwise_conv2d_test'; +import './ops/fused_mat_mul_test'; import './ops/gather_nd_test'; import './ops/gather_test'; import './ops/gram_schmidt_test';
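For quick reference, the object-style API that the new fused_mat_mul.ts exposes (and that the tests above exercise) can be driven end to end as in the sketch below. This is a minimal illustration, not part of the change: the tensor values mirror the 'fused A x B with relu' and 'relu bias gradient' tests, and the `tf` namespace import assumes a consumer of `@tensorflow/tfjs-core`.

```ts
import * as tf from '@tensorflow/tfjs-core';

async function demoFusedMatMul() {
  const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
  const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);

  // The forward path hands bias and activation to the backend's
  // fusedBatchMatMul; when shouldFuse() (which depends on gradient depth and
  // activation) rejects the fusion, fusedMatMul_ falls back to
  // matMul + add + applyActivation, as shown at the top of the op.
  const c = tf.fused.matMul({a, b, bias: null, activation: 'relu'});
  console.log(await c.data());  // [0, 8, 0, 20], per the test expectation.

  // Gradients flow through the customGrad wrapper; when a bias is supplied,
  // its derivative is returned as well (cf. the 'relu bias gradient' test).
  const bias = tf.tensor2d([1, 1, 1, 1], [2, 2]);
  const dy = tf.tensor2d([1, 10, 20, 30], [2, 2]);
  const grads = tf.grads(
      (a, b, bias) => tf.fused.matMul({a, b, bias, activation: 'relu'}));
  const [da, db, dbias] = grads([a, b, bias], dy);
  console.log(da.shape, db.shape, dbias.shape);  // [2, 3], [3, 2], [2, 2]
}
```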
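Because the activation helpers now live in fused_util.ts as standalone exported functions, their behavior is easy to check in isolation. A minimal sketch, assuming the snippet sits next to fused_util.ts under tfjs-core/src/ops (the imports mirror those used by fused_mat_mul.ts above; this snippet is illustrative only and not part of the change):

```ts
import * as tf from '../index';
import {applyActivation, getFusedDyActivation} from './fused_util';

// applyActivation routes to the matching op ('linear', 'relu', 'elu',
// 'relu6', 'prelu') and throws on any other name.
const y = applyActivation(tf.tensor1d([-1, 0, 2]), 'relu6');  // [0, 0, 2]

// For the fused backward pass, 'linear' returns dy unchanged and 'relu'
// masks dy with step(y); any other activation throws
// ("Cannot compute gradient for fused activation ...").
const dy = tf.tensor1d([1, 1, 1]);
const dyActivation = getFusedDyActivation(dy, y, 'relu');     // [0, 0, 1]
```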