diff --git a/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts b/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts index 03c82decb15..ac68da3bee7 100644 --- a/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts +++ b/tfjs-backend-wasm/src/kernels/FusedBatchNorm.ts @@ -32,19 +32,19 @@ interface BatchNormAttrs extends NamedAttrMap { } let wasmBatchNorm: ( - xId: number, meanId: number, varianceId: number, offsetId: number, - scaleId: number, varianceEpsilon: number, outId: number) => void; + xId: number, meanId: number, varianceId: number, offsetId: number, + scaleId: number, varianceEpsilon: number, outId: number) => void; function setup(backend: BackendWasm): void { wasmBatchNorm = backend.wasm.cwrap( - 'FusedBatchNorm', null /* void */, - ['number', 'number', 'number', 'number', 'number', 'number', 'number']); + 'FusedBatchNorm', null /* void */, + ['number', 'number', 'number', 'number', 'number', 'number', 'number']); } function fusedBatchNorm( - args: - {backend: BackendWasm, inputs: BatchNormInputs, attrs: BatchNormAttrs}): - TensorInfo { + args: + {backend: BackendWasm, inputs: BatchNormInputs, attrs: BatchNormAttrs}): + TensorInfo { const {backend, inputs, attrs} = args; const {varianceEpsilon} = attrs; const {x, mean, variance, offset, scale} = inputs; @@ -63,12 +63,12 @@ function fusedBatchNorm( const outId = backend.dataIdMap.get(out.dataId).id; wasmBatchNorm( - xId, meanId, varianceId, offsetId, scaleId, varianceEpsilon, outId); + xId, meanId, varianceId, offsetId, scaleId, varianceEpsilon, outId); return out; } registerKernel({ - kernelName: 'BatchNormalization', + kernelName: 'FusedBatchNorm', backendName: 'wasm', setupFunc: setup, kernelFunc: fusedBatchNorm diff --git a/tfjs-core/src/gradients/FusedBatchNorm_grad.ts b/tfjs-core/src/gradients/FusedBatchNorm_grad.ts new file mode 100644 index 00000000000..630fe68e2fc --- /dev/null +++ b/tfjs-core/src/gradients/FusedBatchNorm_grad.ts @@ -0,0 +1,111 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * =============================================================================
+ */
+import {FusedBatchNorm, FusedBatchNormAttrs} from '../kernel_names';
+import {GradConfig, NamedAttrMap} from '../kernel_registry';
+import {xAs4D} from '../ops/batchnorm_util';
+import {getReductionAxes} from '../ops/broadcast_util';
+import {add, mul, reshape, sub} from '../ops/ops';
+import {sum} from '../ops/reduction_ops';
+import {scalar} from '../ops/tensor_ops';
+import {tile} from '../ops/tile';
+import {rsqrt} from '../ops/unary_ops';
+import {Tensor, Tensor4D} from '../tensor';
+import {Rank, ShapeMap} from '../types';
+
+export const fusedBatchNormGradConfig: GradConfig = {
+  kernelName: FusedBatchNorm,
+  inputsToSave: ['x', 'mean', 'variance', 'scale'],
+  gradFunc: <R extends Rank>(
+      dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
+    const batchNormalizationAttrs: FusedBatchNormAttrs =
+        attrs as {} as FusedBatchNormAttrs;
+    const {varianceEpsilon} = batchNormalizationAttrs;
+    const [x, mean, variance, scale] = saved;
+
+    const x4D: Tensor4D = xAs4D(x);
+
+    const scaleValue = scale == null ? scalar(1) : scale;
+    const reductionAxes = getReductionAxes(mean.shape, x4D.shape);
+    const tileShape: number[] = [];
+    if (mean.rank === 1) {
+      for (let i = 0; i < x4D.shape.length - 1; ++i) {
+        tileShape.push(x4D.shape[i]);
+      }
+      tileShape.push(1);
+    }
+
+    const xMinusMean = sub(x, mean);
+    const dyTimesScaleValue = mul(dy, scaleValue);
+    const oneOverSqrtVariance = rsqrt(add(variance, scalar(varianceEpsilon)));
+    const minusHalfRCube = mul(
+        mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance),
+        scalar(-0.5));
+
+    const derX = () => {
+      if (mean.rank === 1) {
+        return reshape(
+            mul(mul(dy,
+                    tile(
+                        oneOverSqrtVariance.as4D(1, 1, 1, mean.shape[0]),
+                        tileShape)),
+                scaleValue),
+            x.shape);
+      } else {
+        return reshape(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);
+      }
+    };
+    const derMean = () => {
+      let meanDer =
+          mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue);
+      if (mean.rank === 1) {
+        meanDer = sum(meanDer, reductionAxes);
+      }
+      return reshape(meanDer, mean.shape as ShapeMap[R]);
+    };
+    const derVariance = () => {
+      let varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue);
+
+      if (mean.rank === 1) {
+        varianceDer = sum(varianceDer, reductionAxes);
+      }
+      return reshape(varianceDer, mean.shape as ShapeMap[R]);
+    };
+    const derScale = () => {
+      const xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance);
+
+      let scaleDer = mul(dy, xMinusMean2TimesRsqrt);
+      if (mean.rank === 1) {
+        scaleDer = sum(scaleDer, reductionAxes);
+      }
+      return reshape(scaleDer, mean.shape as ShapeMap[R]);
+    };
+    const derOffset = () => {
+      let offsetDer = dy;
+      if (mean.rank === 1) {
+        offsetDer = sum(offsetDer, reductionAxes);
+      }
+      return reshape(offsetDer, mean.shape as ShapeMap[R]);
+    };
+    return {
+      x: derX,
+      mean: derMean,
+      variance: derVariance,
+      scale: derScale,
+      offset: derOffset
+    };
+  }
+};
diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts
index 239d9794a46..2c680a2c387 100644
--- a/tfjs-core/src/kernel_names.ts
+++ b/tfjs-core/src/kernel_names.ts
@@ -21,6 +21,13 @@
 import {NamedTensorInfoMap} from './kernel_registry';
 import {PixelData} from './types';
 
+export const FusedBatchNorm = 'FusedBatchNorm';
+export type FusedBatchNormInputs =
+    Pick<NamedTensorInfoMap, 'x'|'scale'|'offset'|'mean'|'variance'>;
+export interface FusedBatchNormAttrs {
+  varianceEpsilon: number;
+}
+
 export type BinaryInputs = Pick<NamedTensorInfoMap, 'a'|'b'>;
 
 export const Div = 'Div';
diff --git a/tfjs-core/src/ops/batchnorm.ts
b/tfjs-core/src/ops/batchnorm.ts index 94d58fb79a4..d707356ceaf 100644 --- a/tfjs-core/src/ops/batchnorm.ts +++ b/tfjs-core/src/ops/batchnorm.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -15,182 +15,17 @@ * ============================================================================= */ -import {ENGINE} from '../engine'; -import {deprecationWarn} from '../globals'; -import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'; +import {ENGINE, ForwardFunc} from '../engine'; +import {FusedBatchNormAttrs, FusedBatchNormInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor, Tensor1D, Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; -import {Rank, ShapeMap, TensorLike} from '../types'; +import {Rank, TensorLike} from '../types'; import * as util from '../util'; -import {getReductionAxes} from './broadcast_util'; +import {warnDeprecation, xAs4D} from './batchnorm_util'; import {op} from './operation'; -import {scalar} from './tensor_ops'; -import {tile} from './tile'; -import {rsqrt} from './unary_ops'; - -/** - * Batch normalization, strictly for 2D. For the more relaxed version, see - * `tf.batchNorm`. - * - * @param x The input Tensor. - * @param mean A mean Tensor. - * @param variance A variance Tensor. - * @param offset An offset Tensor. - * @param scale A scale Tensor. - * @param varianceEpsilon A small float number to avoid dividing by 0. - */ -function batchNorm2d_( - x: Tensor2D|TensorLike, mean: Tensor2D|Tensor1D|TensorLike, - variance: Tensor2D|Tensor1D|TensorLike, - offset?: Tensor2D|Tensor1D|TensorLike, scale?: Tensor2D|Tensor1D|TensorLike, - varianceEpsilon?: number): Tensor2D { - const $x = convertToTensor(x, 'x', 'batchNorm'); - const $mean = convertToTensor(mean, 'mean', 'batchNorm'); - const $variance = convertToTensor(variance, 'variance', 'batchNorm'); - let $scale: Tensor2D|Tensor1D; - if (scale != null) { - $scale = convertToTensor(scale, 'scale', 'batchNorm'); - } - let $offset: Tensor2D|Tensor1D; - if (offset != null) { - $offset = convertToTensor(offset, 'offset', 'batchNorm'); - } - util.assert( - $x.rank === 2, - () => `Error in batchNorm3D: x must be rank 3 but got rank ` + - `${$x.rank}.`); - util.assert( - $mean.rank === 2 || $mean.rank === 1, - () => `Error in batchNorm2D: mean must be rank 2 or rank 1 but ` + - `got rank ${$mean.rank}.`); - util.assert( - $variance.rank === 2 || $variance.rank === 1, - () => `Error in batchNorm2D: variance must be rank 2 or rank 1 ` + - `but got rank ${$variance.rank}.`); - if ($scale != null) { - util.assert( - $scale.rank === 2 || $scale.rank === 1, - () => `Error in batchNorm2D: scale must be rank 2 or rank 1 ` + - `but got rank ${$scale.rank}.`); - } - if ($offset != null) { - util.assert( - $offset.rank === 2 || $offset.rank === 1, - () => `Error in batchNorm2D: offset must be rank 2 or rank 1 ` + - `but got rank ${$offset.rank}.`); - } - - return batchNorm_($x, $mean, $variance, $offset, $scale, varianceEpsilon); -} - -/** - * Batch normalization, strictly for 3D. For the more relaxed version, see - * `tf.batchNorm`. - * - * @param x The input Tensor. - * @param mean A mean Tensor. - * @param variance A variance Tensor. 
- * @param offset An offset Tensor. - * @param scale A scale Tensor. - * @param varianceEpsilon A small float number to avoid dividing by 0. - */ -function batchNorm3d_( - x: Tensor3D|TensorLike, mean: Tensor3D|Tensor1D|TensorLike, - variance: Tensor3D|Tensor1D|TensorLike, - offset?: Tensor3D|Tensor1D|TensorLike, scale?: Tensor3D|Tensor1D|TensorLike, - varianceEpsilon?: number): Tensor3D { - const $x = convertToTensor(x, 'x', 'batchNorm'); - const $mean = convertToTensor(mean, 'mean', 'batchNorm'); - const $variance = convertToTensor(variance, 'variance', 'batchNorm'); - let $scale: Tensor3D|Tensor1D; - if (scale != null) { - $scale = convertToTensor(scale, 'scale', 'batchNorm'); - } - let $offset: Tensor3D|Tensor1D; - if (offset != null) { - $offset = convertToTensor(offset, 'offset', 'batchNorm'); - } - util.assert( - $x.rank === 3, - () => `Error in batchNorm3D: x must be rank 3 but got rank ` + - `${$x.rank}.`); - util.assert( - $mean.rank === 3 || $mean.rank === 1, - () => `Error in batchNorm3D: mean must be rank 3 or rank 1 but ` + - `got rank ${$mean.rank}.`); - util.assert( - $variance.rank === 3 || $variance.rank === 1, - () => `Error in batchNorm3D: variance must be rank 3 or rank 1 ` + - `but got rank ${$variance.rank}.`); - if ($scale != null) { - util.assert( - $scale.rank === 3 || $scale.rank === 1, - () => `Error in batchNorm3D: scale must be rank 3 or rank 1 ` + - `but got rank ${$scale.rank}.`); - } - if ($offset != null) { - util.assert( - $offset.rank === 3 || $offset.rank === 1, - () => `Error in batchNorm3D: offset must be rank 3 or rank 1 ` + - `but got rank ${$offset.rank}.`); - } - - return batchNorm_($x, $mean, $variance, $offset, $scale, varianceEpsilon); -} - -/** - * Batch normalization, strictly for 4D. For the more relaxed version, see - * `tf.batchNorm`. - * - * @param x The input Tensor. - * @param mean A mean Tensor. - * @param variance A variance Tensor. - * @param offset An offset Tensor. - * @param scale A scale Tensor. - * @param varianceEpsilon A small float number to avoid dividing by 0. 
- */
-function batchNorm4d_(
-    x: Tensor4D|TensorLike, mean: Tensor4D|Tensor1D|TensorLike,
-    variance: Tensor4D|Tensor1D|TensorLike,
-    offset?: Tensor4D|Tensor1D|TensorLike, scale?: Tensor4D|Tensor1D|TensorLike,
-    varianceEpsilon?: number): Tensor4D {
-  const $x = convertToTensor(x, 'x', 'batchNorm');
-  const $mean = convertToTensor(mean, 'mean', 'batchNorm');
-  const $variance = convertToTensor(variance, 'variance', 'batchNorm');
-  let $scale: Tensor4D|Tensor1D;
-  if (scale != null) {
-    $scale = convertToTensor(scale, 'scale', 'batchNorm');
-  }
-  let $offset: Tensor4D|Tensor1D;
-  if (offset != null) {
-    $offset = convertToTensor(offset, 'offset', 'batchNorm');
-  }
-  util.assert(
-      $x.rank === 4,
-      () => `Error in batchNorm4D: x must be rank 4 but got rank ` +
-          `${$x.rank}.`);
-  util.assert(
-      $mean.rank === 4 || $mean.rank === 1,
-      () => `Error in batchNorm4D: mean must be rank 4 or rank 1 but ` +
-          `got rank ${$mean.rank}.`);
-  util.assert(
-      $variance.rank === 4 || $variance.rank === 1,
-      () => `Error in batchNorm4D: variance must be rank 4 or rank 1 ` +
-          `but got rank ${$variance.rank}.`);
-  if ($scale != null) {
-    util.assert(
-        $scale.rank === 4 || $scale.rank === 1,
-        () => `Error in batchNorm4D: scale must be rank 4 or rank 1 ` +
-            `but got rank ${$scale.rank}.`);
-  }
-  if ($offset != null) {
-    util.assert(
-        $offset.rank === 4 || $offset.rank === 1,
-        () => `Error in batchNorm4D: offset must be rank 4 or rank 1 ` +
-            `but got rank ${$offset.rank}.`);
-  }
-  return batchNorm_($x, $mean, $variance, $offset, $scale, varianceEpsilon);
-}
 
 /**
  * @deprecated Please use `tf.batchNorm` instead and note the positional
@@ -264,105 +99,31 @@ function batchNorm_(
       () => 'Batch normalization gradient requires mean and scale to have ' +
           'equal ranks.');
 
-  let x4D: Tensor4D;
-  if ($x.rank === 0 || $x.rank === 1) {
-    x4D = $x.as4D(1, 1, 1, $x.size);
-  } else if ($x.rank === 2) {
-    x4D = $x.as4D(1, 1, $x.shape[0], $x.shape[1]);
-  } else if ($x.rank === 3) {
-    x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]);
-  } else {
-    x4D = $x as Tensor4D;
-  }
+  const forward: ForwardFunc<Tensor> = (backend, save) => {
+    const x4D: Tensor4D = xAs4D($x);
 
-  const der = (dy: Tensor, saved: Tensor[]) => {
-    type Saved = [
-      Tensor, Tensor| Tensor1D, Tensor| Tensor1D, Tensor| Tensor1D
-    ];
-    const [$x, $mean, $variance, $scale] = saved as Saved;
-    const scaleValue = $scale == null ?
scalar(1) : $scale; - const reductionAxes = getReductionAxes($mean.shape, x4D.shape); - const tileShape: number[] = []; - if ($mean.rank === 1) { - for (let i = 0; i < x4D.shape.length - 1; ++i) { - tileShape.push(x4D.shape[i]); - } - tileShape.push(1); - } + const res = backend.batchNormalization( + x4D, as1DOr4D($mean), as1DOr4D($variance), varianceEpsilon, + as1DOr4D($scale), as1DOr4D($offset)); - const xMinusMean = $x.sub($mean); - const dyTimesScaleValue = dy.mul(scaleValue); - const oneOverSqrtVariance = rsqrt($variance.add(scalar(varianceEpsilon))); - const minusHalfRCube = oneOverSqrtVariance.mul(oneOverSqrtVariance) - .mul(oneOverSqrtVariance) - .mul(scalar(-0.5)); + save([$x, $mean, $variance, $scale]); - const derX = () => { - if ($mean.rank === 1) { - return dy - .mul(tile( - oneOverSqrtVariance.as4D(1, 1, 1, $mean.shape[0]), tileShape)) - .mul(scaleValue) - .reshape($x.shape); - } else { - return dy.mul(oneOverSqrtVariance).mul(scaleValue).reshape($x.shape); - } - }; - const derMean = () => { - let meanDer = oneOverSqrtVariance.mul(scalar(-1)).mul(dyTimesScaleValue); - if ($mean.rank === 1) { - meanDer = meanDer.sum(reductionAxes); - } - return meanDer.reshape($mean.shape as ShapeMap[R]); - }; - const derVariance = () => { - let varianceDer = minusHalfRCube.mul(xMinusMean).mul(dyTimesScaleValue); - if ($mean.rank === 1) { - varianceDer = varianceDer.sum(reductionAxes); - } - return varianceDer.reshape($mean.shape as ShapeMap[R]); - }; - const derScale = () => { - const xMinusMean2TimesRsqrt = xMinusMean.mul(oneOverSqrtVariance); - let scaleDer = dy.mul(xMinusMean2TimesRsqrt); - if ($mean.rank === 1) { - scaleDer = scaleDer.sum(reductionAxes); - } - return scaleDer.reshape($mean.shape as ShapeMap[R]); - }; - const derOffset = () => { - let offsetDer = dy; - if ($mean.rank === 1) { - offsetDer = offsetDer.sum(reductionAxes); - } - return offsetDer.reshape($mean.shape as ShapeMap[R]); - }; - return { - x: derX, - mean: derMean, - variance: derVariance, - scale: derScale, - offset: derOffset - }; + return res; }; - const inputsToSave = [$x, $mean, $variance, $scale]; + const inputs: FusedBatchNormInputs = + {x: $x, scale: $scale, offset: $offset, mean: $mean, variance: $variance}; + + const attrs: FusedBatchNormAttrs = {varianceEpsilon}; const res = ENGINE.runKernelFunc( - (backend, save) => { - const res = backend.batchNormalization( - x4D, batchnormReshape4D($mean), batchnormReshape4D($variance), - varianceEpsilon, batchnormReshape4D($scale), - batchnormReshape4D($offset)); - save([$x, $mean, $variance, $scale]); - return res; - }, - {x: $x, mean: $mean, variance: $variance, scale: $scale, offset: $offset}, - der, 'BatchNormalization', {varianceEpsilon}, inputsToSave); + forward, inputs as {} as NamedTensorMap, null /* gradient */, + 'FusedBatchNorm', attrs as {} as NamedAttrMap); + return res.reshape($x.shape); } -function batchnormReshape4D(x: Tensor): Tensor4D|Tensor1D { +function as1DOr4D(x: Tensor): Tensor4D|Tensor1D { if (x == null) { return null; } @@ -378,58 +139,6 @@ function batchnormReshape4D(x: Tensor): Tensor4D|Tensor1D { return x as Tensor4D; } -/** - * @deprecated Please use `tf.batchNorm2d` instead and note the positional - * argument change of scale, offset, and varianceEpsilon. 
- */ -function batchNormalization2d_( - x: Tensor2D|TensorLike, mean: Tensor2D|Tensor1D|TensorLike, - variance: Tensor2D|Tensor1D|TensorLike, varianceEpsilon = .001, - scale?: Tensor2D|Tensor1D|TensorLike, - offset?: Tensor2D|Tensor1D|TensorLike): Tensor2D { - warnDeprecation(); - return batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon); -} - -/** - * @deprecated Please use `tf.batchNorm3d` instead and note the positional - * argument change of scale, offset, and varianceEpsilon. - */ -function batchNormalization3d_( - x: Tensor3D|TensorLike, mean: Tensor3D|Tensor1D|TensorLike, - variance: Tensor3D|Tensor1D|TensorLike, varianceEpsilon = .001, - scale?: Tensor3D|Tensor1D|TensorLike, - offset?: Tensor3D|Tensor1D|TensorLike): Tensor3D { - warnDeprecation(); - return batchNorm3d_(x, mean, variance, offset, scale, varianceEpsilon); -} - -/** - * @deprecated Please use `tf.batchNorm4d` instead and note the positional - * argument change of scale, offset, and varianceEpsilon. - */ -function batchNormalization4d_( - x: Tensor4D|TensorLike, mean: Tensor4D|Tensor1D|TensorLike, - variance: Tensor4D|Tensor1D|TensorLike, varianceEpsilon = .001, - scale?: Tensor4D|Tensor1D|TensorLike, - offset?: Tensor4D|Tensor1D|TensorLike): Tensor4D { - warnDeprecation(); - return batchNorm4d_(x, mean, variance, offset, scale, varianceEpsilon); -} - -function warnDeprecation() { - deprecationWarn( - 'tf.batchNormalization() is going away. ' + - 'Use tf.batchNorm() instead, and note the positional argument change ' + - 'of scale, offset, and varianceEpsilon'); -} - -export const batchNormalization2d = op({batchNormalization2d_}); -export const batchNormalization3d = op({batchNormalization3d_}); -export const batchNormalization4d = op({batchNormalization4d_}); +// todo(yassogba): Remove batchNormalization since it is deprecated. export const batchNormalization = op({batchNormalization_}); - export const batchNorm = op({batchNorm_}); -export const batchNorm2d = op({batchNorm2d_}); -export const batchNorm3d = op({batchNorm3d_}); -export const batchNorm4d = op({batchNorm4d_}); diff --git a/tfjs-core/src/ops/batchnorm2d.ts b/tfjs-core/src/ops/batchnorm2d.ts new file mode 100644 index 00000000000..4f80363307a --- /dev/null +++ b/tfjs-core/src/ops/batchnorm2d.ts @@ -0,0 +1,96 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Tensor1D, Tensor2D} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {batchNorm} from './batchnorm'; +import {warnDeprecation} from './batchnorm_util'; +import {op} from './operation'; + +/** + * Batch normalization, strictly for 2D. For the more relaxed version, see + * `tf.batchNorm`. + * + * @param x The input Tensor. + * @param mean A mean Tensor. + * @param variance A variance Tensor. 
+ * @param offset An offset Tensor.
+ * @param scale A scale Tensor.
+ * @param varianceEpsilon A small float number to avoid dividing by 0.
+ */
+function batchNorm2d_(
+    x: Tensor2D|TensorLike, mean: Tensor2D|Tensor1D|TensorLike,
+    variance: Tensor2D|Tensor1D|TensorLike,
+    offset?: Tensor2D|Tensor1D|TensorLike, scale?: Tensor2D|Tensor1D|TensorLike,
+    varianceEpsilon?: number): Tensor2D {
+  const $x = convertToTensor(x, 'x', 'batchNorm');
+  const $mean = convertToTensor(mean, 'mean', 'batchNorm');
+  const $variance = convertToTensor(variance, 'variance', 'batchNorm');
+  let $scale: Tensor2D|Tensor1D;
+  if (scale != null) {
+    $scale = convertToTensor(scale, 'scale', 'batchNorm');
+  }
+  let $offset: Tensor2D|Tensor1D;
+  if (offset != null) {
+    $offset = convertToTensor(offset, 'offset', 'batchNorm');
+  }
+  util.assert(
+      $x.rank === 2,
+      () => `Error in batchNorm2D: x must be rank 2 but got rank ` +
+          `${$x.rank}.`);
+  util.assert(
+      $mean.rank === 2 || $mean.rank === 1,
+      () => `Error in batchNorm2D: mean must be rank 2 or rank 1 but ` +
+          `got rank ${$mean.rank}.`);
+  util.assert(
+      $variance.rank === 2 || $variance.rank === 1,
+      () => `Error in batchNorm2D: variance must be rank 2 or rank 1 ` +
+          `but got rank ${$variance.rank}.`);
+  if ($scale != null) {
+    util.assert(
+        $scale.rank === 2 || $scale.rank === 1,
+        () => `Error in batchNorm2D: scale must be rank 2 or rank 1 ` +
+            `but got rank ${$scale.rank}.`);
+  }
+  if ($offset != null) {
+    util.assert(
+        $offset.rank === 2 || $offset.rank === 1,
+        () => `Error in batchNorm2D: offset must be rank 2 or rank 1 ` +
+            `but got rank ${$offset.rank}.`);
+  }
+
+  return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
+}
+
+/**
+ * @deprecated Please use `tf.batchNorm2d` instead and note the positional
+ * argument change of scale, offset, and varianceEpsilon.
+ */
+function batchNormalization2d_(
+    x: Tensor2D|TensorLike, mean: Tensor2D|Tensor1D|TensorLike,
+    variance: Tensor2D|Tensor1D|TensorLike, varianceEpsilon = .001,
+    scale?: Tensor2D|Tensor1D|TensorLike,
+    offset?: Tensor2D|Tensor1D|TensorLike): Tensor2D {
+  warnDeprecation();
+  return batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon);
+}
+
+// todo(yassogba): Remove batchNormalization2d since it is deprecated.
+export const batchNormalization2d = op({batchNormalization2d_});
+export const batchNorm2d = op({batchNorm2d_});
diff --git a/tfjs-core/src/ops/batchnorm3d.ts b/tfjs-core/src/ops/batchnorm3d.ts
new file mode 100644
index 00000000000..f8cf991e59d
--- /dev/null
+++ b/tfjs-core/src/ops/batchnorm3d.ts
@@ -0,0 +1,96 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ============================================================================= + */ +import {Tensor1D, Tensor3D} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {batchNorm} from './batchnorm'; +import {warnDeprecation} from './batchnorm_util'; +import {op} from './operation'; + +/** + * Batch normalization, strictly for 3D. For the more relaxed version, see + * `tf.batchNorm`. + * + * @param x The input Tensor. + * @param mean A mean Tensor. + * @param variance A variance Tensor. + * @param offset An offset Tensor. + * @param scale A scale Tensor. + * @param varianceEpsilon A small float number to avoid dividing by 0. + */ +function batchNorm3d_( + x: Tensor3D|TensorLike, mean: Tensor3D|Tensor1D|TensorLike, + variance: Tensor3D|Tensor1D|TensorLike, + offset?: Tensor3D|Tensor1D|TensorLike, scale?: Tensor3D|Tensor1D|TensorLike, + varianceEpsilon?: number): Tensor3D { + const $x = convertToTensor(x, 'x', 'batchNorm'); + const $mean = convertToTensor(mean, 'mean', 'batchNorm'); + const $variance = convertToTensor(variance, 'variance', 'batchNorm'); + let $scale: Tensor3D|Tensor1D; + if (scale != null) { + $scale = convertToTensor(scale, 'scale', 'batchNorm'); + } + let $offset: Tensor3D|Tensor1D; + if (offset != null) { + $offset = convertToTensor(offset, 'offset', 'batchNorm'); + } + util.assert( + $x.rank === 3, + () => `Error in batchNorm3D: x must be rank 3 but got rank ` + + `${$x.rank}.`); + util.assert( + $mean.rank === 3 || $mean.rank === 1, + () => `Error in batchNorm3D: mean must be rank 3 or rank 1 but ` + + `got rank ${$mean.rank}.`); + util.assert( + $variance.rank === 3 || $variance.rank === 1, + () => `Error in batchNorm3D: variance must be rank 3 or rank 1 ` + + `but got rank ${$variance.rank}.`); + if ($scale != null) { + util.assert( + $scale.rank === 3 || $scale.rank === 1, + () => `Error in batchNorm3D: scale must be rank 3 or rank 1 ` + + `but got rank ${$scale.rank}.`); + } + if ($offset != null) { + util.assert( + $offset.rank === 3 || $offset.rank === 1, + () => `Error in batchNorm3D: offset must be rank 3 or rank 1 ` + + `but got rank ${$offset.rank}.`); + } + + return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon); +} + +/** + * @deprecated Please use `tf.batchNorm3d` instead and note the positional + * argument change of scale, offset, and varianceEpsilon. + */ +function batchNormalization3d_( + x: Tensor3D|TensorLike, mean: Tensor3D|Tensor1D|TensorLike, + variance: Tensor3D|Tensor1D|TensorLike, varianceEpsilon = .001, + scale?: Tensor3D|Tensor1D|TensorLike, + offset?: Tensor3D|Tensor1D|TensorLike): Tensor3D { + warnDeprecation(); + return batchNorm3d_(x, mean, variance, offset, scale, varianceEpsilon); +} + +// todo(yassogba): Remove batchNormalization3d since it is deprecated. +export const batchNormalization3d = op({batchNormalization3d_}); +export const batchNorm3d = op({batchNorm3d_}); diff --git a/tfjs-core/src/ops/batchnorm4d.ts b/tfjs-core/src/ops/batchnorm4d.ts new file mode 100644 index 00000000000..bbb56d0bb33 --- /dev/null +++ b/tfjs-core/src/ops/batchnorm4d.ts @@ -0,0 +1,95 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Tensor1D, Tensor4D} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {batchNorm} from './batchnorm'; +import {warnDeprecation} from './batchnorm_util'; +import {op} from './operation'; + +/** + * Batch normalization, strictly for 4D. For the more relaxed version, see + * `tf.batchNorm`. + * + * @param x The input Tensor. + * @param mean A mean Tensor. + * @param variance A variance Tensor. + * @param offset An offset Tensor. + * @param scale A scale Tensor. + * @param varianceEpsilon A small float number to avoid dividing by 0. + */ +function batchNorm4d_( + x: Tensor4D|TensorLike, mean: Tensor4D|Tensor1D|TensorLike, + variance: Tensor4D|Tensor1D|TensorLike, + offset?: Tensor4D|Tensor1D|TensorLike, scale?: Tensor4D|Tensor1D|TensorLike, + varianceEpsilon?: number): Tensor4D { + const $x = convertToTensor(x, 'x', 'batchNorm'); + const $mean = convertToTensor(mean, 'mean', 'batchNorm'); + const $variance = convertToTensor(variance, 'variance', 'batchNorm'); + let $scale: Tensor4D|Tensor1D; + if (scale != null) { + $scale = convertToTensor(scale, 'scale', 'batchNorm'); + } + let $offset: Tensor4D|Tensor1D; + if (offset != null) { + $offset = convertToTensor(offset, 'offset', 'batchNorm'); + } + util.assert( + $x.rank === 4, + () => `Error in batchNorm4D: x must be rank 4 but got rank ` + + `${$x.rank}.`); + util.assert( + $mean.rank === 4 || $mean.rank === 1, + () => `Error in batchNorm4D: mean must be rank 4 or rank 1 but ` + + `got rank ${$mean.rank}.`); + util.assert( + $variance.rank === 4 || $variance.rank === 1, + () => `Error in batchNorm4D: variance must be rank 4 or rank 1 ` + + `but got rank ${$variance.rank}.`); + if ($scale != null) { + util.assert( + $scale.rank === 4 || $scale.rank === 1, + () => `Error in batchNorm4D: scale must be rank 4 or rank 1 ` + + `but got rank ${$scale.rank}.`); + } + if ($offset != null) { + util.assert( + $offset.rank === 4 || $offset.rank === 1, + () => `Error in batchNorm4D: offset must be rank 4 or rank 1 ` + + `but got rank ${$offset.rank}.`); + } + return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon); +} + +/** + * @deprecated Please use `tf.batchNorm4d` instead and note the positional + * argument change of scale, offset, and varianceEpsilon. + */ +function batchNormalization4d_( + x: Tensor4D|TensorLike, mean: Tensor4D|Tensor1D|TensorLike, + variance: Tensor4D|Tensor1D|TensorLike, varianceEpsilon = .001, + scale?: Tensor4D|Tensor1D|TensorLike, + offset?: Tensor4D|Tensor1D|TensorLike): Tensor4D { + warnDeprecation(); + return batchNorm4d_(x, mean, variance, offset, scale, varianceEpsilon); +} + +// todo(yassogba): Remove batchNormalization4d since it is deprecated. 
+export const batchNormalization4d = op({batchNormalization4d_});
+export const batchNorm4d = op({batchNorm4d_});
diff --git a/tfjs-core/src/ops/batchnorm_util.ts b/tfjs-core/src/ops/batchnorm_util.ts
new file mode 100644
index 00000000000..715a28579d3
--- /dev/null
+++ b/tfjs-core/src/ops/batchnorm_util.ts
@@ -0,0 +1,41 @@
+/**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {deprecationWarn} from '../globals';
+import {Tensor, Tensor4D} from '../tensor';
+import {Rank} from '../types';
+
+export function warnDeprecation(): void {
+  deprecationWarn(
+      'tf.batchNormalization() is going away. ' +
+      'Use tf.batchNorm() instead, and note the positional argument change ' +
+      'of scale, offset, and varianceEpsilon');
+}
+
+export function xAs4D<R extends Rank>(x: Tensor<R>) {
+  let x4D: Tensor4D;
+  if (x.rank === 0 || x.rank === 1) {
+    x4D = x.as4D(1, 1, 1, x.size);
+  } else if (x.rank === 2) {
+    x4D = x.as4D(1, 1, x.shape[0], x.shape[1]);
+  } else if (x.rank === 3) {
+    x4D = x.as4D(1, x.shape[0], x.shape[1], x.shape[2]);
+  } else {
+    x4D = x as Tensor4D;
+  }
+
+  return x4D;
+}
diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts
index 404e85097c4..0f45e6e42bf 100644
--- a/tfjs-core/src/ops/ops.ts
+++ b/tfjs-core/src/ops/ops.ts
@@ -16,6 +16,10 @@
  */
 
 // Modularized ops.
+export {batchNormalization, batchNorm} from './batchnorm';
+export {batchNormalization2d, batchNorm2d} from './batchnorm2d';
+export {batchNormalization3d, batchNorm3d} from './batchnorm3d';
+export {batchNormalization4d, batchNorm4d} from './batchnorm4d';
 export {broadcastTo} from './broadcast_to';
 export {clone} from './clone';
 export {div} from './div';
@@ -37,7 +41,6 @@ export {squaredDifference} from './squared_difference';
 export {tile} from './tile';
 export {truncatedNormal} from './truncated_normal';
 
-export * from './batchnorm';
 export * from './boolean_mask';
 export * from './complex_ops';
 export * from './concat_split';
diff --git a/tfjs-core/src/public/chained_ops/batchnorm.ts b/tfjs-core/src/public/chained_ops/batchnorm.ts
new file mode 100644
index 00000000000..ce621c8eb04
--- /dev/null
+++ b/tfjs-core/src/public/chained_ops/batchnorm.ts
@@ -0,0 +1,40 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import {batchNorm} from '../../ops/batchnorm';
+import {Tensor, Tensor1D} from '../../tensor';
+import {Rank, TensorLike} from '../../types';
+
+declare module '../../tensor' {
+  interface Tensor<R extends Rank = Rank> {
+    batchNorm(
+        mean: Tensor|Tensor1D|TensorLike,
+        variance: Tensor|Tensor1D|TensorLike,
+        offset?: Tensor|Tensor1D|TensorLike,
+        scale?: Tensor|Tensor1D|TensorLike,
+        varianceEpsilon?: number): Tensor;
+  }
+}
+
+Tensor.prototype.batchNorm = function(
+    this: Tensor|TensorLike, mean: Tensor|Tensor1D|TensorLike,
+    variance: Tensor|Tensor1D|TensorLike,
+    offset?: Tensor|Tensor1D|TensorLike,
+    scale?: Tensor|Tensor1D|TensorLike,
+    varianceEpsilon?: number): Tensor {
+  return batchNorm(this, mean, variance, offset, scale, varianceEpsilon);
+};
diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
index d398ba1944d..ca8aa8194db 100644
--- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
+++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts
@@ -23,3 +23,4 @@ import './tile';
 import './one_hot';
 import './transpose';
 import './pad';
+import './batchnorm';
diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
index 92aa87067c0..8f1630653d2 100644
--- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
+++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts
@@ -25,7 +25,7 @@ import {ALL_ENVS, describeWithFlags} from '../../jasmine_util';
 
 const CHAINED_OPS = [
   'square', 'broadcastTo', 'tile', 'oneHot', 'div', 'divNoNan', 'transpose',
-  'pad'
+  'pad', 'batchNorm'
 ];
 
 describeWithFlags('chained ops', ALL_ENVS, () => {
@@ -34,7 +34,7 @@ describeWithFlags('chained ops', ALL_ENVS, () => {
     for (const opName of CHAINED_OPS) {
       //@ts-ignore
       expect(typeof tensor[opName])
-          .toBe('function', `${opName} chained op not found`);
+          .toBe('function', `${opName} chained op not found`);
     }
   });
 });
diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts
index a35522b2998..35974b826ac 100644
--- a/tfjs-core/src/register_all_gradients.ts
+++ b/tfjs-core/src/register_all_gradients.ts
@@ -16,6 +16,7 @@
  */
 import {broadcastToGradConfig} from './gradients/BroadcastTo_grad';
 import {divGradConfig} from './gradients/Div_grad';
+import {fusedBatchNormGradConfig} from './gradients/FusedBatchNorm_grad';
 import {identityGradConfig} from './gradients/Identity_grad';
 import {oneHotGradConfig} from './gradients/OneHot_grad';
 import {padV2GradConfig} from './gradients/PadV2_grad';
@@ -28,9 +29,9 @@ import {registerGradient} from './kernel_registry';
 
 // Export all kernel configs here so that the package can auto register them
 const gradConfigs: GradConfig[] = [
-  divGradConfig, squareGradConfig, squaredDifferenceGradConfig,
-  broadcastToGradConfig, identityGradConfig, tileGradConfig, oneHotGradConfig,
-  transposeGradConfig, padV2GradConfig
+  broadcastToGradConfig, divGradConfig, fusedBatchNormGradConfig,
+  identityGradConfig, oneHotGradConfig, padV2GradConfig, squareGradConfig,
+  squaredDifferenceGradConfig, tileGradConfig, transposeGradConfig
 ];
 
 for (const gradientConfig of gradConfigs) {
diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts
index c7036e2f3a5..83345f02555 100644
--- a/tfjs-core/src/tensor.ts
+++ b/tfjs-core/src/tensor.ts
@@ -195,12 +195,6 @@ export interface OpHandler {
   concat<T extends Tensor>(tensors: Array<T|TensorLike>, axis: number): T;
   stack<T extends Tensor>(tensors: Array<T|TensorLike>, axis: number): Tensor;
   unstack<T extends Tensor>(value: T, axis: number): Tensor[];
-  batchNorm(
-      x: Tensor, mean: Tensor|Tensor1D|TensorLike,
-      variance: Tensor|Tensor1D|TensorLike,
-      offset?: Tensor|Tensor1D|TensorLike,
-      scale?: Tensor|Tensor1D|TensorLike,
-      varianceEpsilon?: number): Tensor;
   all<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean): T;
   any<T extends Tensor>(x: Tensor, axis: number|number[], keepDims: boolean): T;
   logSumExp<T extends Tensor>(
@@ -833,17 +827,6 @@
     return this.batchNorm(mean, variance, offset, scale, varianceEpsilon);
   }
 
-  batchNorm(
-      mean: Tensor|Tensor1D|TensorLike,
-      variance: Tensor|Tensor1D|TensorLike,
-      offset?: Tensor|Tensor1D|TensorLike,
-      scale?: Tensor|Tensor1D|TensorLike,
-      varianceEpsilon = .001,
-      ): Tensor {
-    this.throwIfDisposed();
-    return opHandler.batchNorm(
-        this, mean, variance, offset, scale, varianceEpsilon);
-  }
   // Reduction ops.
   all<T extends Tensor>(axis: number|number[] = null, keepDims = false): T {
     this.throwIfDisposed();
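
For reviewers: a minimal usage sketch of the modularized pieces, illustrative only and not part of the patch. It assumes `@tensorflow/tfjs-core` with this change applied; tensor values are arbitrary.

```ts
import * as tf from '@tensorflow/tfjs-core';

// Forward pass: tf.batchNorm now dispatches to whichever backend kernel is
// registered under the 'FusedBatchNorm' name (e.g. the wasm kernel above).
const x = tf.tensor2d([[1, 2], [3, 4]]);
const mean = tf.tensor1d([1, 2]);
const variance = tf.tensor1d([1, 1]);
const y = tf.batchNorm(x, mean, variance, undefined /* offset */,
                       undefined /* scale */, 1e-3 /* varianceEpsilon */);

// Chained form, wired up by public/chained_ops/batchnorm.ts.
const y2 = x.batchNorm(mean, variance);

// Backward pass: resolved through fusedBatchNormGradConfig, registered by
// register_all_gradients.ts, rather than the op-local `der` closure this
// patch deletes from batchnorm.ts.
const dx = tf.grad((t: tf.Tensor) => tf.batchNorm(t, mean, variance))(x);

y.print();
y2.print();
dx.print();
```

This mirrors what register_all_chained_ops_test.ts asserts (`tensor.batchNorm` is a function after registration) and exercises the gradient path that previously lived inline in `batchNorm_`.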