diff --git a/tfjs-backend-cpu/src/backend_cpu.ts b/tfjs-backend-cpu/src/backend_cpu.ts
index 5c4cddec857..d9ecd0bb6be 100644
--- a/tfjs-backend-cpu/src/backend_cpu.ts
+++ b/tfjs-backend-cpu/src/backend_cpu.ts
@@ -2306,50 +2306,6 @@ export class MathBackendCPU extends KernelBackend {
     return tf.tensor4d(output, x.shape, x.dtype);
   }
 
-  batchNorm(
-      x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D,
-      offset?: Tensor4D|Tensor1D, scale?: Tensor4D|Tensor1D,
-      varianceEpsilon?: number): Tensor4D {
-    assertNotComplex([x, mean, variance, scale, offset], 'batchNorm');
-
-    const xVals = this.readSync(x.dataId) as TypedArray;
-    const mVals = this.readSync(mean.dataId) as TypedArray;
-    const varVals = this.readSync(variance.dataId) as TypedArray;
-    const sVals = scale ? this.readSync(scale.dataId) as TypedArray :
-                          new Float32Array([1]);
-    const offVals = offset ? this.readSync(offset.dataId) as TypedArray :
-                             new Float32Array([0]);
-    const outVals = new Float32Array(xVals.length);
-
-    const offValsLength = offVals.length;
-    const sValsLength = sVals.length;
-    const varValsLength = varVals.length;
-    const mValsLength = mVals.length;
-
-    let offi = 0;
-    let mi = 0;
-    let si = 0;
-    let vi = 0;
-    for (let i = 0; i < xVals.length; ++i) {
-      outVals[i] = offVals[offi++] +
-          (xVals[i] - mVals[mi++]) * sVals[si++] /
-              Math.sqrt(varVals[vi++] + varianceEpsilon);
-      if (offi >= offValsLength) {
-        offi = 0;
-      }
-      if (mi >= mValsLength) {
-        mi = 0;
-      }
-      if (si >= sValsLength) {
-        si = 0;
-      }
-      if (vi >= varValsLength) {
-        vi = 0;
-      }
-    }
-    return tf.tensor4d(outVals, x.shape);
-  }
-
   localResponseNormalization4D(
       x: Tensor4D, depthRadius: number, bias: number, alpha: number,
       beta: number): Tensor4D {
diff --git a/tfjs-backend-cpu/src/kernels/BatchNorm.ts b/tfjs-backend-cpu/src/kernels/BatchNorm.ts
new file mode 100644
index 00000000000..0389ac7850a
--- /dev/null
+++ b/tfjs-backend-cpu/src/kernels/BatchNorm.ts
@@ -0,0 +1,94 @@
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {FusedBatchNorm, FusedBatchNormAttrs, FusedBatchNormInputs, KernelConfig, KernelFunc, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';
+
+import {MathBackendCPU} from '../backend_cpu';
+import {assertNotComplex} from '../cpu_util';
+
+export function batchNormKernelFunc(args: {
+  inputs: FusedBatchNormInputs,
+  backend: MathBackendCPU,
+  attrs: FusedBatchNormAttrs
+}): TensorInfo {
+  const {inputs, backend, attrs} = args;
+  const {x, scale, offset, mean, variance} = inputs;
+
+  util.assert(
+      mean.shape.length === variance.shape.length,
+      () => 'Batch normalization gradient requires mean and variance to have ' +
+          'equal ranks.');
+  util.assert(
+      offset == null || mean.shape.length === offset.shape.length,
+      () => 'Batch normalization gradient requires mean and offset to have ' +
+          'equal ranks.');
+  util.assert(
+      scale == null || mean.shape.length === scale.shape.length,
+      () => 'Batch normalization gradient requires mean and scale to have ' +
+          'equal ranks.');
+
+  assertNotComplex([x, mean, variance, scale, offset], 'batchNorm');
+
+  let {varianceEpsilon} = attrs;
+  if (varianceEpsilon == null) {
+    varianceEpsilon = 0.001;
+  }
+
+  const xVals = backend.data.get(x.dataId).values as TypedArray;
+  const mVals = backend.data.get(mean.dataId).values as TypedArray;
+  const varVals = backend.data.get(variance.dataId).values as TypedArray;
+  const sVals = scale ? backend.data.get(scale.dataId).values as TypedArray :
+                        new Float32Array([1]);
+  const offVals = offset ?
+      backend.data.get(offset.dataId).values as TypedArray :
+      new Float32Array([0]);
+  const outVals = new Float32Array(xVals.length);
+
+  const offValsLength = offVals.length;
+  const sValsLength = sVals.length;
+  const varValsLength = varVals.length;
+  const mValsLength = mVals.length;
+
+  let offi = 0;
+  let mi = 0;
+  let si = 0;
+  let vi = 0;
+  for (let i = 0; i < xVals.length; ++i) {
+    outVals[i] = offVals[offi++] +
+        (xVals[i] - mVals[mi++]) * sVals[si++] /
+            Math.sqrt(varVals[vi++] + varianceEpsilon);
+    if (offi >= offValsLength) {
+      offi = 0;
+    }
+    if (mi >= mValsLength) {
+      mi = 0;
+    }
+    if (si >= sValsLength) {
+      si = 0;
+    }
+    if (vi >= varValsLength) {
+      vi = 0;
+    }
+  }
+  return backend.makeTensorInfo(x.shape, x.dtype, outVals);
+}
+
+export const batchNormConfig: KernelConfig = {
+  kernelName: FusedBatchNorm,
+  backendName: 'cpu',
+  kernelFunc: batchNormKernelFunc as {} as KernelFunc,
+};
diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts
index 9403e7754aa..499a271523b 100644
--- a/tfjs-backend-cpu/src/register_all_kernels.ts
+++ b/tfjs-backend-cpu/src/register_all_kernels.ts
@@ -27,6 +27,7 @@ import {asinConfig} from './kernels/Asin';
 import {asinhConfig} from './kernels/Asinh';
 import {atanConfig} from './kernels/Atan';
 import {atanhConfig} from './kernels/Atanh';
+import {batchNormConfig} from './kernels/BatchNorm';
 import {castConfig} from './kernels/Cast';
 import {ceilConfig} from './kernels/Ceil';
 import {clipConfig} from './kernels/Clip';
@@ -94,6 +95,7 @@ const kernelConfigs: KernelConfig[] = [
   asinhConfig,
   atanConfig,
   atanhConfig,
+  batchNormConfig,
   castConfig,
   ceilConfig,
   clipConfig,
diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts
index 05f22edc232..18c238c6c1b 100644
--- a/tfjs-backend-webgl/src/backend_webgl.ts
+++ b/tfjs-backend-webgl/src/backend_webgl.ts
@@ -34,8 +34,6 @@ import {AddNPackedProgram} from './addn_packed_gpu';
 import {ArgMinMaxProgram} from './argminmax_gpu';
 import {ArgMinMaxPackedProgram} from './argminmax_packed_gpu';
 import {AvgPool2DBackpropProgram, AvgPool3DBackpropProgram} from './avg_pool_backprop_gpu';
-import {BatchNormProgram} from './batchnorm_gpu';
-import {BatchNormPackedProgram} from './batchnorm_packed_gpu';
 import * as binaryop_complex_gpu from './binaryop_complex_gpu';
 import {BinaryOpComplexProgram} from './binaryop_complex_gpu';
 import * as binaryop_gpu from './binaryop_gpu';
@@ -943,37 +941,6 @@ export class MathBackendWebGL extends KernelBackend {
     return this.compileAndRun(program, [a, b], a.dtype);
   }
 
-  batchNorm(
-      x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D,
-      offset?: Tensor4D|Tensor1D, scale?: Tensor4D|Tensor1D,
-      varianceEpsilon?: number): Tensor4D {
-    const inputs = [x, mean, variance];
-
-    let offsetShape = null;
-    if (offset != null) {
-      offsetShape = offset.shape;
-      inputs.push(offset);
-    }
-
-    let scaleShape = null;
-    if (scale != null) {
-      scaleShape = scale.shape;
-      inputs.push(scale);
-    }
-
-    if (env().getBool('WEBGL_PACK_NORMALIZATION')) {
-      const batchNormPackedProgram = new BatchNormPackedProgram(
-          x.shape, mean.shape, variance.shape, offsetShape, scaleShape,
-          varianceEpsilon);
-      return this.compileAndRun(batchNormPackedProgram, inputs);
-    }
-
-    const batchNormProgram = new BatchNormProgram(
-        x.shape, mean.shape, variance.shape, offsetShape, scaleShape,
-        varianceEpsilon);
-    return this.compileAndRun(batchNormProgram, inputs);
-  }
-
   localResponseNormalization4D(
       x: Tensor4D, radius: number, bias: number, alpha: number,
       beta: number): Tensor4D {
diff --git a/tfjs-backend-webgl/src/kernels/BatchNorm.ts b/tfjs-backend-webgl/src/kernels/BatchNorm.ts
new file mode 100644
index 00000000000..84caf95b6d3
--- /dev/null
+++ b/tfjs-backend-webgl/src/kernels/BatchNorm.ts
@@ -0,0 +1,81 @@
+
+/**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {env, FusedBatchNorm, FusedBatchNormAttrs, FusedBatchNormInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core';
+
+import {MathBackendWebGL} from '../backend_webgl';
+import {BatchNormProgram} from '../batchnorm_gpu';
+import {BatchNormPackedProgram} from '../batchnorm_packed_gpu';
+
+export const batchNormKernelFunc: (params: {
+  inputs: FusedBatchNormInputs,
+  backend: MathBackendWebGL,
+  attrs: FusedBatchNormAttrs
+}) => TensorInfo = ({inputs, backend, attrs}) => {
+  const {x, mean, variance, offset, scale} = inputs;
+
+  util.assert(
+      mean.shape.length === variance.shape.length,
+      () => 'Batch normalization gradient requires mean and variance to have ' +
+          'equal ranks.');
+  util.assert(
+      offset == null || mean.shape.length === offset.shape.length,
+      () => 'Batch normalization gradient requires mean and offset to have ' +
+          'equal ranks.');
+  util.assert(
+      scale == null || mean.shape.length === scale.shape.length,
+      () => 'Batch normalization gradient requires mean and scale to have ' +
+          'equal ranks.');
+
+  let {varianceEpsilon} = attrs;
+  if (varianceEpsilon == null) {
+    varianceEpsilon = 0.001;
+  }
+
+  const finalInputs = [x, mean, variance];
+
+  let offsetShape = null;
+  if (offset != null) {
+    offsetShape = offset.shape;
+    finalInputs.push(offset);
+  }
+
+  let scaleShape = null;
+  if (scale != null) {
+    scaleShape = scale.shape;
+    finalInputs.push(scale);
+  }
+
+  const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
+      new BatchNormPackedProgram(
+          x.shape, mean.shape, variance.shape, offsetShape, scaleShape,
+          varianceEpsilon) :
+      new BatchNormProgram(
+          x.shape, mean.shape, variance.shape, offsetShape, scaleShape,
+          varianceEpsilon);
+  const output =
+      backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype);
+
+  return output;
+};
+
+export const batchNormConfig: KernelConfig = {
+  kernelName: FusedBatchNorm,
+  backendName: 'webgl',
+  kernelFunc: batchNormKernelFunc as {} as KernelFunc,
+};
diff --git a/tfjs-backend-webgl/src/register_all_kernels.ts b/tfjs-backend-webgl/src/register_all_kernels.ts
index 8ef929aa25e..ccf46782f38 100644
--- a/tfjs-backend-webgl/src/register_all_kernels.ts
+++ b/tfjs-backend-webgl/src/register_all_kernels.ts
@@ -17,6 +17,7 @@
 import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core';
 
 import {atan2Config} from './kernels/Atan2';
+import {batchNormConfig} from './kernels/BatchNorm';
 import {cosConfig} from './kernels/Cos';
 import {divConfig} from './kernels/Div';
 import {flipLeftRightConfig} from './kernels/FlipLeftRight';
@@ -36,11 +37,11 @@ import {transposeConfig} from './kernels/Transpose';
 
 // List all kernel configs here
 const kernelConfigs: KernelConfig[] = [
-  atan2Config, cosConfig, maxConfig, flipLeftRightConfig, fromPixelsConfig,
-  divConfig, maxPoolWithArgmaxConfig, nonMaxSuppressionV3Config,
-  nonMaxSuppressionV4Config, nonMaxSuppressionV5Config, reshapeConfig,
-  rotateWithOffsetConfig, sinConfig, squareConfig, squaredDifferenceConfig,
-  tanConfig, transposeConfig
+  atan2Config, batchNormConfig, cosConfig, maxConfig, flipLeftRightConfig,
+  fromPixelsConfig, divConfig, maxPoolWithArgmaxConfig,
+  nonMaxSuppressionV3Config, nonMaxSuppressionV4Config,
+  nonMaxSuppressionV5Config, reshapeConfig, rotateWithOffsetConfig, sinConfig,
+  squareConfig, squaredDifferenceConfig, tanConfig, transposeConfig
 ];
 
 for (const kernelConfig of kernelConfigs) {
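Usage sketch (illustrative, not part of the patch): with batchNormConfig registered for the 'cpu' and 'webgl' backends through the register_all_kernels.ts changes above, tf.batchNorm resolves to the new FusedBatchNorm kernel functions via the kernel registry rather than the backend.batchNorm methods removed from the monolithic backends. A minimal TypeScript example, assuming the standard @tensorflow/tfjs-core and @tensorflow/tfjs-backend-cpu package entry points; the tensor values are made up for illustration.

import * as tf from '@tensorflow/tfjs-core';
// Importing the backend package registers the backend and its modular
// kernels, including the FusedBatchNorm config added in this change.
import '@tensorflow/tfjs-backend-cpu';

async function main() {
  await tf.setBackend('cpu');

  // x is 4D; mean/variance/offset/scale are 1D over the channel dimension
  // and broadcast across x, matching the cycling indices in the CPU kernel.
  const x = tf.ones([1, 2, 2, 3]);
  const mean = tf.tensor1d([0.5, 0.5, 0.5]);
  const variance = tf.tensor1d([0.25, 0.25, 0.25]);
  const offset = tf.tensor1d([0, 0, 0]);
  const scale = tf.tensor1d([1, 1, 1]);

  // Dispatches to the kernel func registered under the FusedBatchNorm
  // kernel name for the active backend.
  const y = tf.batchNorm(x, mean, variance, offset, scale, 0.001);
  y.print();
}

main();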