From 55333499bfa4fbb1d2c99675136d5bf517b91442 Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Tue, 22 Sep 2020 11:39:18 -0700 Subject: [PATCH 1/7] batchnorm --- tfjs-backend-cpu/src/backend_cpu.ts | 44 -------- tfjs-backend-cpu/src/kernels/BatchNorm.ts | 107 +++++++++++++++++++ tfjs-backend-cpu/src/register_all_kernels.ts | 2 + 3 files changed, 109 insertions(+), 44 deletions(-) create mode 100644 tfjs-backend-cpu/src/kernels/BatchNorm.ts diff --git a/tfjs-backend-cpu/src/backend_cpu.ts b/tfjs-backend-cpu/src/backend_cpu.ts index 5c4cddec857..d9ecd0bb6be 100644 --- a/tfjs-backend-cpu/src/backend_cpu.ts +++ b/tfjs-backend-cpu/src/backend_cpu.ts @@ -2306,50 +2306,6 @@ export class MathBackendCPU extends KernelBackend { return tf.tensor4d(output, x.shape, x.dtype); } - batchNorm( - x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D, - offset?: Tensor4D|Tensor1D, scale?: Tensor4D|Tensor1D, - varianceEpsilon?: number): Tensor4D { - assertNotComplex([x, mean, variance, scale, offset], 'batchNorm'); - - const xVals = this.readSync(x.dataId) as TypedArray; - const mVals = this.readSync(mean.dataId) as TypedArray; - const varVals = this.readSync(variance.dataId) as TypedArray; - const sVals = scale ? this.readSync(scale.dataId) as TypedArray : - new Float32Array([1]); - const offVals = offset ? this.readSync(offset.dataId) as TypedArray : - new Float32Array([0]); - const outVals = new Float32Array(xVals.length); - - const offValsLength = offVals.length; - const sValsLength = sVals.length; - const varValsLength = varVals.length; - const mValsLength = mVals.length; - - let offi = 0; - let mi = 0; - let si = 0; - let vi = 0; - for (let i = 0; i < xVals.length; ++i) { - outVals[i] = offVals[offi++] + - (xVals[i] - mVals[mi++]) * sVals[si++] / - Math.sqrt(varVals[vi++] + varianceEpsilon); - if (offi >= offValsLength) { - offi = 0; - } - if (mi >= mValsLength) { - mi = 0; - } - if (si >= sValsLength) { - si = 0; - } - if (vi >= varValsLength) { - vi = 0; - } - } - return tf.tensor4d(outVals, x.shape); - } - localResponseNormalization4D( x: Tensor4D, depthRadius: number, bias: number, alpha: number, beta: number): Tensor4D { diff --git a/tfjs-backend-cpu/src/kernels/BatchNorm.ts b/tfjs-backend-cpu/src/kernels/BatchNorm.ts new file mode 100644 index 00000000000..d9bdc91c9da --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/BatchNorm.ts @@ -0,0 +1,107 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {FusedBatchNorm, FusedBatchNormAttrs, FusedBatchNormInputs, KernelConfig, KernelFunc, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; +import {assertNotComplex} from '../cpu_util'; + +export function batchNormKernelFunc(args: { + inputs: FusedBatchNormInputs, + backend: MathBackendCPU, + attrs: FusedBatchNormAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + let {varianceEpsilon} = attrs; + const {x, scale, offset, mean, variance} = inputs; + + if (varianceEpsilon == null) { + varianceEpsilon = 0.001; + } + + util.assert( + mean.shape.length === variance.shape.length, + () => 'Batch normalization gradient requires mean and variance to have ' + + 'equal ranks.'); + util.assert( + offset == null || mean.shape.length === offset.shape.length, + () => 'Batch normalization gradient requires mean and offset to have ' + + 'equal ranks.'); + util.assert( + scale == null || mean.shape.length === scale.shape.length, + () => 'Batch normalization gradient requires mean and scale to have ' + + 'equal ranks.'); + + assertNotComplex([x, mean, variance, scale, offset], 'batchNorm'); + + const xVals = backend.data.get(x.dataId).values as TypedArray; + const mVals = backend.data.get(mean.dataId).values as TypedArray; + const varVals = backend.data.get(variance.dataId).values as TypedArray; + const sVals = scale ? backend.data.get(scale.dataId).values as TypedArray : + new Float32Array([1]); + const offVals = offset ? + backend.data.get(offset.dataId).values as TypedArray : + new Float32Array([0]); + const outVals = new Float32Array(xVals.length); + + const offValsLength = offVals.length; + const sValsLength = sVals.length; + const varValsLength = varVals.length; + const mValsLength = mVals.length; + + let offi = 0; + let mi = 0; + let si = 0; + let vi = 0; + for (let i = 0; i < xVals.length; ++i) { + outVals[i] = offVals[offi++] + + (xVals[i] - mVals[mi++]) * sVals[si++] / + Math.sqrt(varVals[vi++] + varianceEpsilon); + if (offi >= offValsLength) { + offi = 0; + } + if (mi >= mValsLength) { + mi = 0; + } + if (si >= sValsLength) { + si = 0; + } + if (vi >= varValsLength) { + vi = 0; + } + } + return backend.makeTensorInfo(get4DShapeFromX(x), x.dtype, outVals); +} + +export const batchNormConfig: KernelConfig = { + kernelName: FusedBatchNorm, + backendName: 'cpu', + kernelFunc: batchNormKernelFunc as {} as KernelFunc, +}; + +function get4DShapeFromX(x: TensorInfo): number[] { + const xRank = x.shape.length; + if (xRank === 0 || xRank === 1) { + return [1, 1, 1, util.sizeFromShape(x.shape)]; + } else if (xRank === 2) { + return [1, 1, x.shape[0], x.shape[1]]; + } else if (xRank === 3) { + return [1, x.shape[0], x.shape[1], x.shape[2]]; + } else { + return x.shape; + } +} diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index 9403e7754aa..12cda333dd9 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -27,6 +27,7 @@ import {asinConfig} from './kernels/Asin'; import {asinhConfig} from './kernels/Asinh'; import {atanConfig} from './kernels/Atan'; import {atanhConfig} from './kernels/Atanh'; +import {batchNormConfig} from './kernels/Batchnorm'; import {castConfig} from './kernels/Cast'; import {ceilConfig} from './kernels/Ceil'; import {clipConfig} from './kernels/Clip'; @@ -94,6 +95,7 @@ const kernelConfigs: KernelConfig[] = [ asinhConfig, atanConfig, atanhConfig, + batchNormConfig, castConfig, ceilConfig, clipConfig, From b9e3ab83c72c6b66f16f4a35e604c04341632b49 Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Tue, 22 Sep 2020 11:43:40 -0700 Subject: [PATCH 2/7] refactor --- tfjs-backend-cpu/src/kernels/BatchNorm.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/BatchNorm.ts b/tfjs-backend-cpu/src/kernels/BatchNorm.ts index d9bdc91c9da..4bc2d91711b 100644 --- a/tfjs-backend-cpu/src/kernels/BatchNorm.ts +++ b/tfjs-backend-cpu/src/kernels/BatchNorm.ts @@ -26,13 +26,8 @@ export function batchNormKernelFunc(args: { attrs: FusedBatchNormAttrs }): TensorInfo { const {inputs, backend, attrs} = args; - let {varianceEpsilon} = attrs; const {x, scale, offset, mean, variance} = inputs; - if (varianceEpsilon == null) { - varianceEpsilon = 0.001; - } - util.assert( mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' + @@ -48,6 +43,11 @@ export function batchNormKernelFunc(args: { assertNotComplex([x, mean, variance, scale, offset], 'batchNorm'); + let {varianceEpsilon} = attrs; + if (varianceEpsilon == null) { + varianceEpsilon = 0.001; + } + const xVals = backend.data.get(x.dataId).values as TypedArray; const mVals = backend.data.get(mean.dataId).values as TypedArray; const varVals = backend.data.get(variance.dataId).values as TypedArray; From ebd39097a7e24c11879e33308db614f4743c4861 Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Tue, 22 Sep 2020 16:20:42 -0700 Subject: [PATCH 3/7] batchnorm in webgl --- tfjs-backend-webgl/src/backend_webgl.ts | 62 ++++--- tfjs-backend-webgl/src/kernels/BatchNorm.ts | 155 ++++++++++++++++++ .../src/register_all_kernels.ts | 11 +- 3 files changed, 191 insertions(+), 37 deletions(-) create mode 100644 tfjs-backend-webgl/src/kernels/BatchNorm.ts diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index 05f22edc232..063a5ded2bd 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -34,8 +34,6 @@ import {AddNPackedProgram} from './addn_packed_gpu'; import {ArgMinMaxProgram} from './argminmax_gpu'; import {ArgMinMaxPackedProgram} from './argminmax_packed_gpu'; import {AvgPool2DBackpropProgram, AvgPool3DBackpropProgram} from './avg_pool_backprop_gpu'; -import {BatchNormProgram} from './batchnorm_gpu'; -import {BatchNormPackedProgram} from './batchnorm_packed_gpu'; import * as binaryop_complex_gpu from './binaryop_complex_gpu'; import {BinaryOpComplexProgram} from './binaryop_complex_gpu'; import * as binaryop_gpu from './binaryop_gpu'; @@ -943,36 +941,36 @@ export class MathBackendWebGL extends KernelBackend { return this.compileAndRun(program, [a, b], a.dtype); } - batchNorm( - x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D, - offset?: Tensor4D|Tensor1D, scale?: Tensor4D|Tensor1D, - varianceEpsilon?: number): Tensor4D { - const inputs = [x, mean, variance]; - - let offsetShape = null; - if (offset != null) { - offsetShape = offset.shape; - inputs.push(offset); - } - - let scaleShape = null; - if (scale != null) { - scaleShape = scale.shape; - inputs.push(scale); - } - - if (env().getBool('WEBGL_PACK_NORMALIZATION')) { - const batchNormPackedProgram = new BatchNormPackedProgram( - x.shape, mean.shape, variance.shape, offsetShape, scaleShape, - varianceEpsilon); - return this.compileAndRun(batchNormPackedProgram, inputs); - } - - const batchNormProgram = new BatchNormProgram( - x.shape, mean.shape, variance.shape, offsetShape, scaleShape, - varianceEpsilon); - return this.compileAndRun(batchNormProgram, inputs); - } + // batchNorm( + // x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D, + // offset?: Tensor4D|Tensor1D, scale?: Tensor4D|Tensor1D, + // varianceEpsilon?: number): Tensor4D { + // const inputs = [x, mean, variance]; + + // let offsetShape = null; + // if (offset != null) { + // offsetShape = offset.shape; + // inputs.push(offset); + // } + + // let scaleShape = null; + // if (scale != null) { + // scaleShape = scale.shape; + // inputs.push(scale); + // } + + // if (env().getBool('WEBGL_PACK_NORMALIZATION')) { + // const batchNormPackedProgram = new BatchNormPackedProgram( + // x.shape, mean.shape, variance.shape, offsetShape, scaleShape, + // varianceEpsilon); + // return this.compileAndRun(batchNormPackedProgram, inputs); + // } + + // const batchNormProgram = new BatchNormProgram( + // x.shape, mean.shape, variance.shape, offsetShape, scaleShape, + // varianceEpsilon); + // return this.compileAndRun(batchNormProgram, inputs); + // } localResponseNormalization4D( x: Tensor4D, radius: number, bias: number, alpha: number, diff --git a/tfjs-backend-webgl/src/kernels/BatchNorm.ts b/tfjs-backend-webgl/src/kernels/BatchNorm.ts new file mode 100644 index 00000000000..d7225f3fffc --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/BatchNorm.ts @@ -0,0 +1,155 @@ + +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {env, FusedBatchNorm, FusedBatchNormAttrs, FusedBatchNormInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {BatchNormProgram} from '../batchnorm_gpu'; +import {BatchNormPackedProgram} from '../batchnorm_packed_gpu'; + +import {reshape} from './Reshape'; + +export const batchNormKernelFunc: (params: { + inputs: FusedBatchNormInputs, + backend: MathBackendWebGL, + attrs: FusedBatchNormAttrs +}) => TensorInfo | TensorInfo[] = ({inputs, backend, attrs}) => { + const {x, mean, variance, offset, scale} = inputs; + + util.assert( + mean.shape.length === variance.shape.length, + () => 'Batch normalization gradient requires mean and variance to have ' + + 'equal ranks.'); + util.assert( + offset == null || mean.shape.length === offset.shape.length, + () => 'Batch normalization gradient requires mean and offset to have ' + + 'equal ranks.'); + util.assert( + scale == null || mean.shape.length === scale.shape.length, + () => 'Batch normalization gradient requires mean and scale to have ' + + 'equal ranks.'); + + let {varianceEpsilon} = attrs; + if (varianceEpsilon == null) { + varianceEpsilon = 0.001; + } + + const $x: TensorInfo = xAs4D(x, backend); + const $mean = as1DOr4D(mean, backend); + const $variance = as1DOr4D(variance, backend); + const $offset = as1DOr4D(offset, backend); + const $scale = as1DOr4D(scale, backend); + + const finalInputs = [$x, $mean, $variance]; + + let offsetShape = null; + if ($offset != null) { + offsetShape = $offset.shape; + finalInputs.push($offset); + } + + let scaleShape = null; + if ($scale != null) { + scaleShape = $scale.shape; + finalInputs.push($scale); + } + + const program = env().getBool('WEBGL_PACK_NORMALIZATION') ? + new BatchNormPackedProgram( + $x.shape, $mean.shape, $variance.shape, offsetShape, scaleShape, + varianceEpsilon) : + new BatchNormProgram( + $x.shape, $mean.shape, $variance.shape, offsetShape, scaleShape, + varianceEpsilon); + const output = + backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype); + + backend.disposeIntermediateTensorInfo($x); + backend.disposeIntermediateTensorInfo($mean); + backend.disposeIntermediateTensorInfo($variance); + if ($offset != null) { + backend.disposeIntermediateTensorInfo($offset); + } + if ($scale != null) { + backend.disposeIntermediateTensorInfo($scale); + } + + return output; +}; + +export const batchNormConfig: KernelConfig = { + kernelName: FusedBatchNorm, + backendName: 'webgl', + kernelFunc: batchNormKernelFunc as {} as KernelFunc, +}; + +function xAs4D(x: TensorInfo, backend: MathBackendWebGL): TensorInfo { + const xRank = x.shape.length; + if (xRank === 0 || xRank === 1) { + return reshape({ + inputs: {x}, + attrs: {shape: [1, 1, 1, util.sizeFromShape(x.shape)]}, + backend, + }); + } else if (xRank === 2) { + return reshape({ + inputs: {x}, + attrs: {shape: [1, 1, x.shape[0], x.shape[1]]}, + backend, + }); + } else if (xRank === 3) { + return reshape({ + inputs: {x}, + attrs: {shape: [1, x.shape[0], x.shape[1], x.shape[2]]}, + backend, + }); + } else { + backend.incRef(x.dataId); + return {...x}; + } +} + +function as1DOr4D(x: TensorInfo, backend: MathBackendWebGL): TensorInfo { + if (x == null) { + return null; + } + const xRank = x.shape.length; + if (xRank === 0) { + return reshape({ + inputs: {x}, + attrs: {shape: [util.sizeFromShape(x.shape)]}, + backend, + }); + } else if (xRank === 2) { + return reshape({ + inputs: {x}, + attrs: {shape: [1, 1, x.shape[0], x.shape[1]]}, + backend, + }); + } else if (xRank === 3) { + // tslint:disable-next-line:no-unnecessary-type-assertion + return reshape({ + inputs: {x}, + attrs: {shape: [1, x.shape[0], x.shape[1], x.shape[2]]}, + backend, + }); + } else { + backend.incRef(x.dataId); + return {...x}; + } +} diff --git a/tfjs-backend-webgl/src/register_all_kernels.ts b/tfjs-backend-webgl/src/register_all_kernels.ts index 8ef929aa25e..ccf46782f38 100644 --- a/tfjs-backend-webgl/src/register_all_kernels.ts +++ b/tfjs-backend-webgl/src/register_all_kernels.ts @@ -17,6 +17,7 @@ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; import {atan2Config} from './kernels/Atan2'; +import {batchNormConfig} from './kernels/BatchNorm'; import {cosConfig} from './kernels/Cos'; import {divConfig} from './kernels/Div'; import {flipLeftRightConfig} from './kernels/FlipLeftRight'; @@ -36,11 +37,11 @@ import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - atan2Config, cosConfig, maxConfig, flipLeftRightConfig, fromPixelsConfig, - divConfig, maxPoolWithArgmaxConfig, nonMaxSuppressionV3Config, - nonMaxSuppressionV4Config, nonMaxSuppressionV5Config, reshapeConfig, - rotateWithOffsetConfig, sinConfig, squareConfig, squaredDifferenceConfig, - tanConfig, transposeConfig + atan2Config, batchNormConfig, cosConfig, maxConfig, flipLeftRightConfig, + fromPixelsConfig, divConfig, maxPoolWithArgmaxConfig, + nonMaxSuppressionV3Config, nonMaxSuppressionV4Config, + nonMaxSuppressionV5Config, reshapeConfig, rotateWithOffsetConfig, sinConfig, + squareConfig, squaredDifferenceConfig, tanConfig, transposeConfig ]; for (const kernelConfig of kernelConfigs) { From 24fffa3fc98c1356bc26026a86fb43dcb384d1ed Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Tue, 22 Sep 2020 16:22:46 -0700 Subject: [PATCH 4/7] update --- tfjs-backend-webgl/src/backend_webgl.ts | 31 ------------------------- 1 file changed, 31 deletions(-) diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index 063a5ded2bd..18c238c6c1b 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -941,37 +941,6 @@ export class MathBackendWebGL extends KernelBackend { return this.compileAndRun(program, [a, b], a.dtype); } - // batchNorm( - // x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D, - // offset?: Tensor4D|Tensor1D, scale?: Tensor4D|Tensor1D, - // varianceEpsilon?: number): Tensor4D { - // const inputs = [x, mean, variance]; - - // let offsetShape = null; - // if (offset != null) { - // offsetShape = offset.shape; - // inputs.push(offset); - // } - - // let scaleShape = null; - // if (scale != null) { - // scaleShape = scale.shape; - // inputs.push(scale); - // } - - // if (env().getBool('WEBGL_PACK_NORMALIZATION')) { - // const batchNormPackedProgram = new BatchNormPackedProgram( - // x.shape, mean.shape, variance.shape, offsetShape, scaleShape, - // varianceEpsilon); - // return this.compileAndRun(batchNormPackedProgram, inputs); - // } - - // const batchNormProgram = new BatchNormProgram( - // x.shape, mean.shape, variance.shape, offsetShape, scaleShape, - // varianceEpsilon); - // return this.compileAndRun(batchNormProgram, inputs); - // } - localResponseNormalization4D( x: Tensor4D, radius: number, bias: number, alpha: number, beta: number): Tensor4D { From 154233420c49e4f3ad34fa78074fb729a48f392e Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Tue, 22 Sep 2020 16:24:23 -0700 Subject: [PATCH 5/7] remove comment --- tfjs-backend-webgl/src/kernels/BatchNorm.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-backend-webgl/src/kernels/BatchNorm.ts b/tfjs-backend-webgl/src/kernels/BatchNorm.ts index d7225f3fffc..2ed6e8a468b 100644 --- a/tfjs-backend-webgl/src/kernels/BatchNorm.ts +++ b/tfjs-backend-webgl/src/kernels/BatchNorm.ts @@ -142,7 +142,6 @@ function as1DOr4D(x: TensorInfo, backend: MathBackendWebGL): TensorInfo { backend, }); } else if (xRank === 3) { - // tslint:disable-next-line:no-unnecessary-type-assertion return reshape({ inputs: {x}, attrs: {shape: [1, x.shape[0], x.shape[1], x.shape[2]]}, From a9e9a0dda5f37177a673d01159a438fee1d6f180 Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Wed, 23 Sep 2020 16:17:42 -0700 Subject: [PATCH 6/7] address comment --- tfjs-backend-cpu/src/kernels/BatchNorm.ts | 15 +--- tfjs-backend-webgl/src/kernels/BatchNorm.ts | 93 +++------------------ 2 files changed, 11 insertions(+), 97 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/BatchNorm.ts b/tfjs-backend-cpu/src/kernels/BatchNorm.ts index 4bc2d91711b..0389ac7850a 100644 --- a/tfjs-backend-cpu/src/kernels/BatchNorm.ts +++ b/tfjs-backend-cpu/src/kernels/BatchNorm.ts @@ -84,7 +84,7 @@ export function batchNormKernelFunc(args: { vi = 0; } } - return backend.makeTensorInfo(get4DShapeFromX(x), x.dtype, outVals); + return backend.makeTensorInfo(x.shape, x.dtype, outVals); } export const batchNormConfig: KernelConfig = { @@ -92,16 +92,3 @@ export const batchNormConfig: KernelConfig = { backendName: 'cpu', kernelFunc: batchNormKernelFunc as {} as KernelFunc, }; - -function get4DShapeFromX(x: TensorInfo): number[] { - const xRank = x.shape.length; - if (xRank === 0 || xRank === 1) { - return [1, 1, 1, util.sizeFromShape(x.shape)]; - } else if (xRank === 2) { - return [1, 1, x.shape[0], x.shape[1]]; - } else if (xRank === 3) { - return [1, x.shape[0], x.shape[1], x.shape[2]]; - } else { - return x.shape; - } -} diff --git a/tfjs-backend-webgl/src/kernels/BatchNorm.ts b/tfjs-backend-webgl/src/kernels/BatchNorm.ts index 2ed6e8a468b..84caf95b6d3 100644 --- a/tfjs-backend-webgl/src/kernels/BatchNorm.ts +++ b/tfjs-backend-webgl/src/kernels/BatchNorm.ts @@ -22,13 +22,11 @@ import {MathBackendWebGL} from '../backend_webgl'; import {BatchNormProgram} from '../batchnorm_gpu'; import {BatchNormPackedProgram} from '../batchnorm_packed_gpu'; -import {reshape} from './Reshape'; - export const batchNormKernelFunc: (params: { inputs: FusedBatchNormInputs, backend: MathBackendWebGL, attrs: FusedBatchNormAttrs -}) => TensorInfo | TensorInfo[] = ({inputs, backend, attrs}) => { +}) => TensorInfo = ({inputs, backend, attrs}) => { const {x, mean, variance, offset, scale} = inputs; util.assert( @@ -49,46 +47,30 @@ export const batchNormKernelFunc: (params: { varianceEpsilon = 0.001; } - const $x: TensorInfo = xAs4D(x, backend); - const $mean = as1DOr4D(mean, backend); - const $variance = as1DOr4D(variance, backend); - const $offset = as1DOr4D(offset, backend); - const $scale = as1DOr4D(scale, backend); - - const finalInputs = [$x, $mean, $variance]; + const finalInputs = [x, mean, variance]; let offsetShape = null; - if ($offset != null) { - offsetShape = $offset.shape; - finalInputs.push($offset); + if (offset != null) { + offsetShape = offset.shape; + finalInputs.push(offset); } let scaleShape = null; - if ($scale != null) { - scaleShape = $scale.shape; - finalInputs.push($scale); + if (scale != null) { + scaleShape = scale.shape; + finalInputs.push(scale); } const program = env().getBool('WEBGL_PACK_NORMALIZATION') ? new BatchNormPackedProgram( - $x.shape, $mean.shape, $variance.shape, offsetShape, scaleShape, + x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon) : new BatchNormProgram( - $x.shape, $mean.shape, $variance.shape, offsetShape, scaleShape, + x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon); const output = backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype); - backend.disposeIntermediateTensorInfo($x); - backend.disposeIntermediateTensorInfo($mean); - backend.disposeIntermediateTensorInfo($variance); - if ($offset != null) { - backend.disposeIntermediateTensorInfo($offset); - } - if ($scale != null) { - backend.disposeIntermediateTensorInfo($scale); - } - return output; }; @@ -97,58 +79,3 @@ export const batchNormConfig: KernelConfig = { backendName: 'webgl', kernelFunc: batchNormKernelFunc as {} as KernelFunc, }; - -function xAs4D(x: TensorInfo, backend: MathBackendWebGL): TensorInfo { - const xRank = x.shape.length; - if (xRank === 0 || xRank === 1) { - return reshape({ - inputs: {x}, - attrs: {shape: [1, 1, 1, util.sizeFromShape(x.shape)]}, - backend, - }); - } else if (xRank === 2) { - return reshape({ - inputs: {x}, - attrs: {shape: [1, 1, x.shape[0], x.shape[1]]}, - backend, - }); - } else if (xRank === 3) { - return reshape({ - inputs: {x}, - attrs: {shape: [1, x.shape[0], x.shape[1], x.shape[2]]}, - backend, - }); - } else { - backend.incRef(x.dataId); - return {...x}; - } -} - -function as1DOr4D(x: TensorInfo, backend: MathBackendWebGL): TensorInfo { - if (x == null) { - return null; - } - const xRank = x.shape.length; - if (xRank === 0) { - return reshape({ - inputs: {x}, - attrs: {shape: [util.sizeFromShape(x.shape)]}, - backend, - }); - } else if (xRank === 2) { - return reshape({ - inputs: {x}, - attrs: {shape: [1, 1, x.shape[0], x.shape[1]]}, - backend, - }); - } else if (xRank === 3) { - return reshape({ - inputs: {x}, - attrs: {shape: [1, x.shape[0], x.shape[1], x.shape[2]]}, - backend, - }); - } else { - backend.incRef(x.dataId); - return {...x}; - } -} From 64c3c29aa7f07dba7d56ae62ae3306e9ae80089e Mon Sep 17 00:00:00 2001 From: Jing Jin <8752427+jinjingforever@users.noreply.github.com> Date: Thu, 24 Sep 2020 09:42:43 -0700 Subject: [PATCH 7/7] fix typo --- tfjs-backend-cpu/src/register_all_kernels.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index 12cda333dd9..499a271523b 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -27,7 +27,7 @@ import {asinConfig} from './kernels/Asin'; import {asinhConfig} from './kernels/Asinh'; import {atanConfig} from './kernels/Atan'; import {atanhConfig} from './kernels/Atanh'; -import {batchNormConfig} from './kernels/Batchnorm'; +import {batchNormConfig} from './kernels/BatchNorm'; import {castConfig} from './kernels/Cast'; import {ceilConfig} from './kernels/Ceil'; import {clipConfig} from './kernels/Clip';