diff --git a/tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts b/tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts new file mode 100644 index 00000000000..c68e382b348 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {FusedBatchNorm, FusedBatchNormAttrs, FusedBatchNormInputs, KernelConfig, Tensor} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; + +import {BatchNormProgram} from './batchnorm_webgpu'; + +export const fusedBatchNormConfig: KernelConfig = { + kernelName: FusedBatchNorm, + backendName: 'webgpu', + kernelFunc: ({inputs, attrs, backend}) => { + const {x, scale, offset, mean, variance} = inputs as FusedBatchNormInputs; + const {varianceEpsilon} = attrs as unknown as FusedBatchNormAttrs; + const webGPUBackend = backend as WebGPUBackend; + const batchNormInputs = [x as Tensor, mean as Tensor, variance as Tensor]; + let offsetShape = null; + if (offset != null) { + offsetShape = offset.shape; + batchNormInputs.push(offset as Tensor); + } + let scaleShape = null; + if (scale != null) { + scaleShape = scale.shape; + batchNormInputs.push(scale as Tensor); + } + const program = new BatchNormProgram( + x.shape, mean.shape, variance.shape, offsetShape, scaleShape, + varianceEpsilon); + return webGPUBackend.compileAndRun(program, batchNormInputs); + } +}; diff --git a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts new file mode 100644 index 00000000000..525ed5b870b --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts @@ -0,0 +1,90 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util} from '@tensorflow/tfjs-core'; + +import {getCoordsDataType} from '../shader_preprocessor'; +import {computeDispatch} from '../webgpu_util'; + +import {WebGPUProgram} from './webgpu_program'; + +export class BatchNormProgram implements WebGPUProgram { + outputShape: number[]; + shaderKey: string; + userCode: string; + dispatchLayout: {x: number[], y: number[], z: number[]}; + dispatch: [number, number, number]; + variableNames: string[]; + workGroupSize: [4, 4, 4]; + + constructor( + xShape: number[], meanShape: number[], varianceShape: number[], + offsetShape: number[]|null, scaleShape: number[]|null, + varianceEpsilon: number) { + this.variableNames = ['x', 'mean', 'variance']; + backend_util.assertAndGetBroadcastShape(xShape, meanShape); + backend_util.assertAndGetBroadcastShape(xShape, varianceShape); + this.outputShape = xShape; + this.dispatchLayout = {x: [1, 2], y: [0], z: [3]}; + const dim = this.outputShape.length; + const coordsDataType = getCoordsDataType(dim); + let setOutput = + 'setOutput(coords[0], coords[1], coords[2], coords[3], value);'; + if (dim === 2) { + this.dispatchLayout = {x: [1], y: [0], z: []}; + setOutput = 'setOutput(coords[0], coords[1], value);'; + } + if (dim === 3) { + this.dispatchLayout = {x: [1, 2], y: [0], z: []}; + setOutput = 'setOutput(coords[0], coords[1], coords[2], value);'; + } + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workGroupSize); + + let offsetSnippet = '0.0'; + if (offsetShape != null) { + backend_util.assertAndGetBroadcastShape(xShape, offsetShape); + this.variableNames.push('offset'); + offsetSnippet = 'getOffsetAtOutCoords()'; + } + + let scaleSnippet = '1.0'; + if (scaleShape != null) { + backend_util.assertAndGetBroadcastShape(xShape, scaleShape); + this.variableNames.push('scale'); + scaleSnippet = 'getScaleAtOutCoords()'; + } + + this.userCode = ` + void writeResult(${coordsDataType} coords,float value) { + if (coordsInBounds(coords, outShape)) { + ${setOutput} + } + } + void main() { + ${coordsDataType} coords = getOutputCoords(); + float x = getXAtOutCoords(); + float mean = getMeanAtOutCoords(); + float variance = getVarianceAtOutCoords(); + float offset = ${offsetSnippet}; + float scale = ${scaleSnippet}; + float inv = scale * inversesqrt(variance + float(${varianceEpsilon})); + writeResult(coords,dot(vec3(x, -mean, offset), vec3(inv, inv, 1))); + } + `; + } +} diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 0661369792e..f4f0367b600 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -16,12 +16,13 @@ */ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; +import {fusedBatchNormConfig} from './kernels/FusedBatchNorm'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; // List all kernel configs here const kernelConfigs: KernelConfig[] = - [squareConfig, squaredDifferenceConfig]; + [squareConfig, squaredDifferenceConfig, fusedBatchNormConfig]; for (const kernelConfig of kernelConfigs) { registerKernel(kernelConfig); diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 6d3c976930e..5b290a5139e 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -354,6 +354,12 @@ const TEST_FILTERS: TestFilter[] = [ // reason 'MultipleBoxes-DifferentBoxes', // TimeOut ] + }, + { + include: 'batchNorm', + excludes: [ + 'gradient', + ] } ]; diff --git a/tfjs-backend-webgpu/src/shader_preprocessor.ts b/tfjs-backend-webgpu/src/shader_preprocessor.ts index 9490afeb218..33a51f6e0dc 100644 --- a/tfjs-backend-webgpu/src/shader_preprocessor.ts +++ b/tfjs-backend-webgpu/src/shader_preprocessor.ts @@ -143,6 +143,10 @@ const SHADER_PREFIX = `#version 450 all(lessThan(coord, shape)); } + bool coordsInBounds(ivec3 coord, ivec3 shape) { + return all(greaterThanEqual(coord, ivec3(0))) && + all(lessThan(coord, shape)); + bool coordsInBounds(ivec2 coord, ivec2 shape) { return all(greaterThanEqual(coord, ivec2(0))) && all(lessThan(coord, shape));