From 104ff306ae1a4dced35601bfce12966b921aedca Mon Sep 17 00:00:00 2001 From: NALLEIN Date: Wed, 8 Jan 2020 13:36:54 +0800 Subject: [PATCH 1/8] webgpu Add batchNormal kernel --- tfjs-backend-webgpu/src/backend_webgpu.ts | 25 ++ .../src/kernels/batchnorm_webgpu.ts | 77 ++++++ tfjs-backend-webgpu/src/setup_test.ts | 249 +----------------- 3 files changed, 103 insertions(+), 248 deletions(-) create mode 100644 tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 6fff463bcd9..15f4b79bc5f 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -24,6 +24,7 @@ import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile'; import {BufferManager} from './buffer_manager'; import {ArgMinMaxProgram} from './kernels/argminmax_webgpu'; +import {BatchNormProgram} from './kernels/batchnorm_webgpu'; import * as binary_op from './kernels/binary_op_webgpu'; import {BinaryOpProgram} from './kernels/binary_op_webgpu'; import {ClipProgram} from './kernels/clip_webgpu'; @@ -549,6 +550,30 @@ export class WebGPUBackend extends KernelBackend { input.size < sizeThreshold); } + batchNormalization( + x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D, + varianceEpsilon: number, scale?: Tensor4D|Tensor1D, + offset?: Tensor4D|Tensor1D): Tensor4D { + const inputs = [x, mean, variance]; + + let offsetShape = null; + if (offset != null) { + offsetShape = offset.shape; + inputs.push(offset); + } + + let scaleShape = null; + if (scale != null) { + scaleShape = scale.shape; + inputs.push(scale); + } + + const batchNormProgram = new BatchNormProgram( + x.shape, mean.shape, variance.shape, offsetShape, scaleShape, + varianceEpsilon); + return this.compileAndRun(batchNormProgram, inputs); + } + pad( x: T, paddings: Array<[number, number]>, constantValue: number): T { const program = new PadProgram(x.shape, paddings, constantValue); diff --git a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts new file mode 100644 index 00000000000..8fe2ef7e534 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts @@ -0,0 +1,77 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util} from '@tensorflow/tfjs-core' + +import {computeDispatch} from '../webgpu_util'; + +import {WebGPUProgram} from './webgpu_program'; + +export class BatchNormProgram implements WebGPUProgram { + outputShape: number[]; + shaderKey: string; + userCode: string; + dispatchLayout: {x: number[], y: number[], z: number[]}; + dispatch: [number, number, number]; + variableNames: string[]; + workGroupSize: [4, 4, 4]; + + constructor( + xShape: number[], meanShape: number[], varianceShape: number[], + offsetShape: number[]|null, scaleShape: number[]|null, + varianceEpsilon: number) { + this.variableNames = ['x', 'mean', 'variance']; + backend_util.assertAndGetBroadcastShape(xShape, meanShape); + backend_util.assertAndGetBroadcastShape(xShape, varianceShape); + this.outputShape = xShape; + this.dispatchLayout = {x: [0, 1], y: [2], z: [3]}; + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workGroupSize); + + let offsetSnippet = '0.0'; + if (offsetShape != null) { + backend_util.assertAndGetBroadcastShape(xShape, offsetShape); + this.variableNames.push('offset'); + offsetSnippet = 'getOffsetAtOutCoords()'; + } + + let scaleSnippet = '1.0'; + if (scaleShape != null) { + backend_util.assertAndGetBroadcastShape(xShape, scaleShape); + this.variableNames.push('scale'); + scaleSnippet = 'getScaleAtOutCoords()'; + } + + this.userCode = ` + void writeResult(ivec4 coords,float value) { + if (coordsInBounds(coords, outShape)) { + setOutput(coords[0], coords[1], coords[2], coords[3], value); + } + } + void main() { + ivec4 coords = getOutputCoords(); + float x = getXAtOutCoords(); + float mean = getMeanAtOutCoords(); + float variance = getVarianceAtOutCoords(); + float offset = ${offsetSnippet}; + float scale = ${scaleSnippet}; + float inv = scale * inversesqrt(variance + float(${varianceEpsilon})); + writeResult(coords,dot(vec3(x, -mean, offset), vec3(inv, inv, 1))); + } + `; + } +} diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index ed08a305af6..4f95d62e816 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -84,254 +84,7 @@ setTestEnvs([{ }]); const TEST_FILTERS: TestFilter[] = [ - { - include: 'less', - excludes: [ - 'upcasts when dtypes dont match', // Actual != expected. - 'NaNs in', // Actual != expected. - 'broadcasting Tensor2D shapes', // Actual != expected. - 'derivat', // logicalAnd not yet implemented. - ] - }, - { - include: 'clip', - excludes: [ - 'derivat', // logicalAnd not yet implemented. - 'gradient', // logicalAnd not yet implemented. - ] - }, - { - include: 'greater', - excludes: [ - 'upcasts when dtypes dont match', // Actual != expected. - 'NaNs in', // Actual != expected. - 'broadcasting Tensor2D shapes', // Actual != expected. - 'works with 0 sized tensors', // Timeout. - 'gradient', // zerosLike not yet implemented. - ] - }, - { - include: 'div', - excludes: [ - 'broadcast 2D + 1D', // Actual != expected. - 'upcasts when dtypes dont match', // Actual != expected. - 'gradient', // square, sum not yet implemented. - 'divNoNan' // Equal not yet implemented. - ] - }, - { - include: 'depthwise', - excludes: [ - 'gradient', // depthwiseConv2DDerInput not yet implemented. - 'fused', // Not yet implemented. - ] - }, - { - include: 'fromPixels', - excludes: [ - 'HTMLVideolement', // Failed to execute 'getImageData' on - // 'CanvasRenderingContext2D': The source width is 0 - ] - }, - { - include: 'argmax', - excludes: [ - '5D', // Rank 5 is not yet implemented. - '6D', // Rank 5 is not yet implemented. - 'accepts tensor with bool', // Actual != Expected. - 'gradient', // zerosLike not yet implemented. - ] - }, - { - include: 'concat', - excludes: [ - 'complex', // No complex support yet. - 'concat a large number of tensors', // Actual != Expected. - 'gradient', // split not yet implemented. - ] - }, - { - include: 'transpose', - excludes: [ - 'oneHot', // Not yet implemented. - 'fused', // Not yet implemented. - 'batched matmul', // Actual != expected, shape mismatch. - 'shape has ones', // Actual != expected. - '5D', // Rank 5 is not yet implemented. - '6D', // Rank 5 is not yet implemented. - ] - }, - { - include: 'relu', - excludes: [ - 'valueAndGradients', // sum not yet implemented. - 'gradient', // sum not yet implemented. - 'fused', // Not yet implemented. - '5D', // Rank 5 is not yet implemented. - '6D', // Rank 5 is not yet implemented. - 'propagates NaNs', // Arrays differ. - 'derivative', // sum not yet implemented. - 'gradient with clones', // sum not yet implemented. - 'derivative where alpha got broadcasted', // sum not yet implemented. - ] - }, - { - include: 'resizeBilinear', - excludes: [ - 'gradient', // Not yet implemented. - 'works for ints' // Actual != expected. - ] - }, - {include: 'floor divide ', excludes: []}, - { - include: 'fused', - excludes: [ - 'A x B', // fusedBatchMatMul not yet implemented. - 'elu', // elu not yet implemented. - 'A x B with bias only', // fusedBatchMatMul not yet implemented. - 'basic with bias', // Actual != expected. - 'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', // conv2dDerInput not yet - // implemented. - 'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias', // conv2dDerInput - // not yet - // implemented. - ] - }, - { - include: 'maxPool', - excludes: [ - 'maxPoolBackprop', // Not yet implemented. - 'maxPool3d', // Not yet implemented. - ] - }, - { - include: 'pool', - excludes: [ - 'avg x=[', // backend.avgPool not implemented. - 'max x=[4,3,1] f=[2,2] s=1 d=2', // spaceToBatchND not yet implemented. - 'max x=[2,4,4,1] f=[2,2] s=1 d=2', // spaceToBatchND not yet implemented. - 'poolBackprop', // maxPoolBackprop not yet implemented. - ] - }, - { - include: 'matmul', - excludes: [ - 'matmulBatch', // Shape mismatch. - 'fused matmul', // FusedMatmul not yet implemented. - 'gradient', // Various: sum not yet implemented. - 'has zero in its shape', // Test times out. - 'valueAndGradients', // backend.sum() not yet implemented. - 'upcasts when dtypes dont match', // Missing cast(). - 'batched matmul', // Actual != expected, shape mismatch. - ] - }, - { - include: 'add ', - excludes: [ - 'complex', // No complex support yet. - 'upcasts when dtypes dont match', // Missing cast(). - 'accepts a tensor-like object', // Timeout. - 'broadcast inner dim of b', // Arrays differ. - '6D', // Rank 6 is not yet implemented. - 'add tensors with 0 in shape', // Timeout. - 'gradient', // sum not yet implemented. - ] - }, - {include: 'subtract ', excludes: []}, - { - include: 'slice ', - excludes: [ - 'square a sliced texture', // abs not yet implemented. - 'square a non-sliced texture', // abs not not yet implemented. - 'flatten a sliced tensor not continuous', // square not yet implemented. - 'reshape a sliced 1d into a 2d tensor and', // square not yet - // implemented. - '5D', // Rank 5 is not yet implemented. - '6D', // Rank 6 is not yet implemented. - 'strided slice with', // Rank 6 is not yet implemented. - ] - }, - { - include: 'stridedSlice', - excludes: [ - 'strided slice with several new axes', // Rank 6 is not yet implemented. - 'strided slice with new axes and', // Rank 6 is not yet implemented. - ] - }, - { - include: 'mul ', - excludes: [ - 'int32 * int32', // Actual != Expected. - 'broadcast', // Various: Actual != Expected, compile fails, etc. - 'gradient', // Various: sum not yet implemented. - 'complex', // No complex support yet. - 'upcasts when dtypes dont match', // Actual != expected. - ] - }, - { - include: 'conv2d', - excludes: [ - 'NCHW', // Not yet implemented. - 'gradient', // 'conv2dDerInput' not yet implemented - 'conv2dTranspose', // DerInput is not Implemented. - ] - }, - { - include: 'pad', - excludes: [ - 'RFFT', // 'zerosLike' not yet implemented. - 'frame', // Slice not yet implemented. - 'grad', // 'depthwiseConv2DDerFilter' not yet implemented, slice not yet - // implemented - ] - }, - { - include: 'Reduction: max', - excludes: [ - '5D', // Rank 5 is not yet implemented. - '6D', // Rank 5 is not yet implemented. - 'accepts tensor with bool', // Actual != Expected. - 'gradient', // zerosLike not yet implemented. - ] - }, - { - include: 'Reduction: min', - excludes: [ - '5D', // Rank 5 is not yet implemented. - '6D', // Rank 5 is not yet implemented. - 'accepts tensor with bool', // Actual != Expected. - 'gradient', // zerosLike not yet implemented. - ] - }, - { - include: 'Reduction: sum', - excludes: [ - 'dtype bool', // not support dtype bool yet. - '5D', // Rank 5 is not yet implemented. - '6D', // Rank 5 is not yet implemented. - 'accepts tensor with bool', // Actual != Expected. - 'gradient', // zerosLike not yet implemented. - ] - }, - { - include: 'abs', - excludes: [ - 'complex', // No complex support yet. - '5D', // Rank 5 is not yet implemented. - '6D', // Rank 5 is not yet implemented. - 'accepts tensor with bool', // Actual != Expected. - 'gradient', // zerosLike not yet implemented. - 'absoluteDifference', // absoluteDifference not yet implemented - ] - }, - { - include: 'cropAndResize', - excludes: [ - '2x2to3x3-NoCrop', // The operation failed for an operation-specific - // reason - 'MultipleBoxes-DifferentBoxes', // TimeOut - ] - } + {include: 'deprecated batchNormalization', excludes: []}, ]; const customInclude = (testName: string) => { From 977e0b87400cd9639ebbe8f0e3bf74720417214b Mon Sep 17 00:00:00 2001 From: NALLEIN Date: Wed, 8 Jan 2020 13:46:05 +0800 Subject: [PATCH 2/8] webgpu Add batchnormal kernel --- .../src/kernels/batchnorm_webgpu.ts | 2 +- tfjs-backend-webgpu/src/setup_test.ts | 250 +++++++++++++++++- 2 files changed, 250 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts index 8fe2ef7e534..a7f704ab00d 100644 --- a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts @@ -38,7 +38,7 @@ export class BatchNormProgram implements WebGPUProgram { backend_util.assertAndGetBroadcastShape(xShape, meanShape); backend_util.assertAndGetBroadcastShape(xShape, varianceShape); this.outputShape = xShape; - this.dispatchLayout = {x: [0, 1], y: [2], z: [3]}; + this.dispatchLayout = {x: [1, 2], y: [0], z: [3]}; this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize); diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 4f95d62e816..2f949b2bd0b 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -84,7 +84,255 @@ setTestEnvs([{ }]); const TEST_FILTERS: TestFilter[] = [ - {include: 'deprecated batchNormalization', excludes: []}, + { + include: 'less', + excludes: [ + 'upcasts when dtypes dont match', // Actual != expected. + 'NaNs in', // Actual != expected. + 'broadcasting Tensor2D shapes', // Actual != expected. + 'derivat', // logicalAnd not yet implemented. + ] + }, + { + include: 'clip', + excludes: [ + 'derivat', // logicalAnd not yet implemented. + 'gradient', // logicalAnd not yet implemented. + ] + }, + { + include: 'greater', + excludes: [ + 'upcasts when dtypes dont match', // Actual != expected. + 'NaNs in', // Actual != expected. + 'broadcasting Tensor2D shapes', // Actual != expected. + 'works with 0 sized tensors', // Timeout. + 'gradient', // zerosLike not yet implemented. + ] + }, + { + include: 'div', + excludes: [ + 'broadcast 2D + 1D', // Actual != expected. + 'upcasts when dtypes dont match', // Actual != expected. + 'gradient', // square, sum not yet implemented. + 'divNoNan' // Equal not yet implemented. + ] + }, + { + include: 'depthwise', + excludes: [ + 'gradient', // depthwiseConv2DDerInput not yet implemented. + 'fused', // Not yet implemented. + ] + }, + { + include: 'fromPixels', + excludes: [ + 'HTMLVideolement', // Failed to execute 'getImageData' on + // 'CanvasRenderingContext2D': The source width is 0 + ] + }, + { + include: 'argmax', + excludes: [ + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 5 is not yet implemented. + 'accepts tensor with bool', // Actual != Expected. + 'gradient', // zerosLike not yet implemented. + ] + }, + { + include: 'concat', + excludes: [ + 'complex', // No complex support yet. + 'concat a large number of tensors', // Actual != Expected. + 'gradient', // split not yet implemented. + ] + }, + { + include: 'transpose', + excludes: [ + 'oneHot', // Not yet implemented. + 'fused', // Not yet implemented. + 'batched matmul', // Actual != expected, shape mismatch. + 'shape has ones', // Actual != expected. + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 5 is not yet implemented. + ] + }, + { + include: 'relu', + excludes: [ + 'valueAndGradients', // sum not yet implemented. + 'gradient', // sum not yet implemented. + 'fused', // Not yet implemented. + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 5 is not yet implemented. + 'propagates NaNs', // Arrays differ. + 'derivative', // sum not yet implemented. + 'gradient with clones', // sum not yet implemented. + 'derivative where alpha got broadcasted', // sum not yet implemented. + ] + }, + { + include: 'resizeBilinear', + excludes: [ + 'gradient', // Not yet implemented. + 'works for ints' // Actual != expected. + ] + }, + {include: 'floor divide ', excludes: []}, + { + include: 'fused', + excludes: [ + 'A x B', // fusedBatchMatMul not yet implemented. + 'elu', // elu not yet implemented. + 'A x B with bias only', // fusedBatchMatMul not yet implemented. + 'basic with bias', // Actual != expected. + 'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', // conv2dDerInput not yet + // implemented. + 'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias', // conv2dDerInput + // not yet + // implemented. + ] + }, + { + include: 'maxPool', + excludes: [ + 'maxPoolBackprop', // Not yet implemented. + 'maxPool3d', // Not yet implemented. + ] + }, + { + include: 'pool', + excludes: [ + 'avg x=[', // backend.avgPool not implemented. + 'max x=[4,3,1] f=[2,2] s=1 d=2', // spaceToBatchND not yet implemented. + 'max x=[2,4,4,1] f=[2,2] s=1 d=2', // spaceToBatchND not yet implemented. + 'poolBackprop', // maxPoolBackprop not yet implemented. + ] + }, + { + include: 'matmul', + excludes: [ + 'matmulBatch', // Shape mismatch. + 'fused matmul', // FusedMatmul not yet implemented. + 'gradient', // Various: sum not yet implemented. + 'has zero in its shape', // Test times out. + 'valueAndGradients', // backend.sum() not yet implemented. + 'upcasts when dtypes dont match', // Missing cast(). + 'batched matmul', // Actual != expected, shape mismatch. + ] + }, + { + include: 'add ', + excludes: [ + 'complex', // No complex support yet. + 'upcasts when dtypes dont match', // Missing cast(). + 'accepts a tensor-like object', // Timeout. + 'broadcast inner dim of b', // Arrays differ. + '6D', // Rank 6 is not yet implemented. + 'add tensors with 0 in shape', // Timeout. + 'gradient', // sum not yet implemented. + ] + }, + {include: 'subtract ', excludes: []}, + { + include: 'slice ', + excludes: [ + 'square a sliced texture', // abs not yet implemented. + 'square a non-sliced texture', // abs not not yet implemented. + 'flatten a sliced tensor not continuous', // square not yet implemented. + 'reshape a sliced 1d into a 2d tensor and', // square not yet + // implemented. + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 6 is not yet implemented. + 'strided slice with', // Rank 6 is not yet implemented. + ] + }, + { + include: 'stridedSlice', + excludes: [ + 'strided slice with several new axes', // Rank 6 is not yet implemented. + 'strided slice with new axes and', // Rank 6 is not yet implemented. + ] + }, + { + include: 'mul ', + excludes: [ + 'int32 * int32', // Actual != Expected. + 'broadcast', // Various: Actual != Expected, compile fails, etc. + 'gradient', // Various: sum not yet implemented. + 'complex', // No complex support yet. + 'upcasts when dtypes dont match', // Actual != expected. + ] + }, + { + include: 'conv2d', + excludes: [ + 'NCHW', // Not yet implemented. + 'gradient', // 'conv2dDerInput' not yet implemented + 'conv2dTranspose', // DerInput is not Implemented. + ] + }, + { + include: 'pad', + excludes: [ + 'RFFT', // 'zerosLike' not yet implemented. + 'frame', // Slice not yet implemented. + 'grad', // 'depthwiseConv2DDerFilter' not yet implemented, slice not yet + // implemented + ] + }, + { + include: 'Reduction: max', + excludes: [ + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 5 is not yet implemented. + 'accepts tensor with bool', // Actual != Expected. + 'gradient', // zerosLike not yet implemented. + ] + }, + { + include: 'Reduction: min', + excludes: [ + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 5 is not yet implemented. + 'accepts tensor with bool', // Actual != Expected. + 'gradient', // zerosLike not yet implemented. + ] + }, + { + include: 'Reduction: sum', + excludes: [ + 'dtype bool', // not support dtype bool yet. + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 5 is not yet implemented. + 'accepts tensor with bool', // Actual != Expected. + 'gradient', // zerosLike not yet implemented. + ] + }, + { + include: 'abs', + excludes: [ + 'complex', // No complex support yet. + '5D', // Rank 5 is not yet implemented. + '6D', // Rank 5 is not yet implemented. + 'accepts tensor with bool', // Actual != Expected. + 'gradient', // zerosLike not yet implemented. + 'absoluteDifference', // absoluteDifference not yet implemented + ] + }, + { + include: 'cropAndResize', + excludes: [ + '2x2to3x3-NoCrop', // The operation failed for an operation-specific + // reason + 'MultipleBoxes-DifferentBoxes', // TimeOut + ] + }, + {include: 'deprecated batchNormalization', excludes: []} ]; const customInclude = (testName: string) => { From 8d5360629b4d13dd754bcc5c620d47c0a444321a Mon Sep 17 00:00:00 2001 From: NALLEIN Date: Wed, 8 Jan 2020 13:59:23 +0800 Subject: [PATCH 3/8] webgpu Add batchnormal kernel --- tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts index a7f704ab00d..fbb5bebb90e 100644 --- a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {backend_util} from '@tensorflow/tfjs-core' +import {backend_util} from '@tensorflow/tfjs-core'; import {computeDispatch} from '../webgpu_util'; From 1184e2ac72fac34856ce4bd752c8bc2b113d0c3d Mon Sep 17 00:00:00 2001 From: NALLEIN Date: Tue, 7 Apr 2020 00:11:09 +0800 Subject: [PATCH 4/8] Modularize batchNorm --- tfjs-backend-webgpu/src/backend_webgpu.ts | 25 ----------- .../src/kernels/BatchNormal.ts | 45 +++++++++++++++++++ tfjs-backend-webgpu/src/kernels/Rsqrt.ts | 31 +++++++++++++ .../src/kernels/unary_op_webgpu.ts | 1 + .../src/register_all_kernels.ts | 4 +- tfjs-backend-webgpu/src/setup_test.ts | 13 +++++- tfjs-core/src/kernel_names.ts | 3 ++ 7 files changed, 95 insertions(+), 27 deletions(-) create mode 100644 tfjs-backend-webgpu/src/kernels/BatchNormal.ts create mode 100644 tfjs-backend-webgpu/src/kernels/Rsqrt.ts diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index caa9d4ccf68..55e3d38eb27 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -24,7 +24,6 @@ import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile'; import {BufferManager} from './buffer_manager'; import {ArgMinMaxProgram} from './kernels/argminmax_webgpu'; -import {BatchNormProgram} from './kernels/batchnorm_webgpu'; import * as binary_op from './kernels/binary_op_webgpu'; import {BinaryOpProgram} from './kernels/binary_op_webgpu'; import {ClipProgram} from './kernels/clip_webgpu'; @@ -553,30 +552,6 @@ export class WebGPUBackend extends KernelBackend { input.size < sizeThreshold); } - batchNormalization( - x: Tensor4D, mean: Tensor4D|Tensor1D, variance: Tensor4D|Tensor1D, - varianceEpsilon: number, scale?: Tensor4D|Tensor1D, - offset?: Tensor4D|Tensor1D): Tensor4D { - const inputs = [x, mean, variance]; - - let offsetShape = null; - if (offset != null) { - offsetShape = offset.shape; - inputs.push(offset); - } - - let scaleShape = null; - if (scale != null) { - scaleShape = scale.shape; - inputs.push(scale); - } - - const batchNormProgram = new BatchNormProgram( - x.shape, mean.shape, variance.shape, offsetShape, scaleShape, - varianceEpsilon); - return this.compileAndRun(batchNormProgram, inputs); - } - pad( x: T, paddings: Array<[number, number]>, constantValue: number): T { const program = new PadProgram(x.shape, paddings, constantValue); diff --git a/tfjs-backend-webgpu/src/kernels/BatchNormal.ts b/tfjs-backend-webgpu/src/kernels/BatchNormal.ts new file mode 100644 index 00000000000..9b1ad51438d --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/BatchNormal.ts @@ -0,0 +1,45 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, FusedBatchNorm, FusedBatchNormInputs,FusedBatchNormAttrs, Tensor} from '@tensorflow/tfjs-core'; +import {WebGPUBackend} from '../backend_webgpu'; +import {BatchNormProgram} from './batchnorm_webgpu'; + +export const batchNromConfig : KernelConfig = { + kernelName: FusedBatchNorm, + backendName: 'webgpu', + kernelFunc: ({inputs,attrs,backend}) => { + const {x,scale,offset,mean,variance} = inputs as FusedBatchNormInputs; + const {varianceEpsilon} = attrs as unknown as FusedBatchNormAttrs; + const webGPUBackend = backend as WebGPUBackend; + let batchNormInputs = [x as Tensor, mean as Tensor, variance as Tensor]; + let offsetShape = null; + if(offset != null) { + offsetShape = offset.shape; + batchNormInputs.push(offset as Tensor); + } + let scaleShape = null; + if(scale != null) { + scaleShape = scale.shape; + batchNormInputs.push(scale as Tensor); + } + const program = new BatchNormProgram( + x.shape,mean.shape,variance.shape,offsetShape,scaleShape, + varianceEpsilon); + return webGPUBackend.compileAndRun(program, batchNormInputs); + } +}; \ No newline at end of file diff --git a/tfjs-backend-webgpu/src/kernels/Rsqrt.ts b/tfjs-backend-webgpu/src/kernels/Rsqrt.ts new file mode 100644 index 00000000000..7811f9001ab --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/Rsqrt.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Rsqrt, RsqrtInputs, Tensor} from '@tensorflow/tfjs-core'; +import {WebGPUBackend} from '../backend_webgpu'; +import {RSQRT, UnaryOpProgram} from './unary_op_webgpu'; + +export const rsqrtConfig: KernelConfig = { + kernelName: Rsqrt, + backendName: 'webgpu', + kernelFunc: ({inputs, backend}) => { + const {x} = inputs as RsqrtInputs; + const webGPUBackend = backend as WebGPUBackend; + const program = new UnaryOpProgram(x.shape, RSQRT); + return webGPUBackend.compileAndRun(program, [x as Tensor]); + } +}; \ No newline at end of file diff --git a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts index 481069eb369..0295e064aeb 100644 --- a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts @@ -29,6 +29,7 @@ export const ELU = `return (a >= 0.0) ? a : (exp(a) - 1.0);`; export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`; export const ABS = `return abs(a);`; export const SQUARE = `return a * a;`; +export const RSQRT = `return inversesqrt(a);`; export class UnaryOpProgram implements WebGPUProgram { outputShape: number[]; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 0661369792e..3d2ff41e902 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -16,12 +16,14 @@ */ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; +import {batchNromConfig} from './kernels/BatchNormal'; +import {rsqrtConfig} from './kernels/Rsqrt'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; // List all kernel configs here const kernelConfigs: KernelConfig[] = - [squareConfig, squaredDifferenceConfig]; + [squareConfig, squaredDifferenceConfig, rsqrtConfig, batchNromConfig]; for (const kernelConfig of kernelConfigs) { registerKernel(kernelConfig); diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 30ddb2d2ae0..c22705a9a38 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -353,7 +353,18 @@ const TEST_FILTERS: TestFilter[] = [ 'MultipleBoxes-DifferentBoxes', // TimeOut ] }, - {include: 'deprecated batchNormalization', excludes: []} + { + include: 'batchNorm', + excludes: [ + 'gradient', + ] + }, + { + include: 'rsqrt', + excludes: [ + 'gradient', + ] + } ]; const customInclude = (testName: string) => { diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 20c0b8b3bfb..4e9956d60d8 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -45,6 +45,9 @@ export type SquaredDifferenceInputs = BinaryInputs; export const Square = 'Square'; export type SquareInputs = Pick; +export const Rsqrt = 'Rsqrt'; +export type RsqrtInputs = Pick; + export const Transpose = 'Transpose'; export type TransposeInputs = Pick; export interface TransposeAttrs { From 6eaea89ef01465df9f12ad7caf4e64d3714fd1d9 Mon Sep 17 00:00:00 2001 From: NALLEIN Date: Wed, 15 Apr 2020 10:02:05 +0800 Subject: [PATCH 5/8] Modularize batchNorm --- .../src/kernels/BatchNormal.ts | 45 ------------------ .../src/kernels/FusedBatchNorm.ts | 47 +++++++++++++++++++ tfjs-backend-webgpu/src/kernels/Rsqrt.ts | 31 ------------ .../src/kernels/batchnorm_webgpu.ts | 30 ++++++++++-- .../src/kernels/unary_op_webgpu.ts | 1 - .../src/register_all_kernels.ts | 5 +- .../src/shader_preprocessor.ts | 4 ++ tfjs-core/src/kernel_names.ts | 3 -- 8 files changed, 80 insertions(+), 86 deletions(-) delete mode 100644 tfjs-backend-webgpu/src/kernels/BatchNormal.ts create mode 100644 tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts delete mode 100644 tfjs-backend-webgpu/src/kernels/Rsqrt.ts diff --git a/tfjs-backend-webgpu/src/kernels/BatchNormal.ts b/tfjs-backend-webgpu/src/kernels/BatchNormal.ts deleted file mode 100644 index 9b1ad51438d..00000000000 --- a/tfjs-backend-webgpu/src/kernels/BatchNormal.ts +++ /dev/null @@ -1,45 +0,0 @@ -/** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelConfig, FusedBatchNorm, FusedBatchNormInputs,FusedBatchNormAttrs, Tensor} from '@tensorflow/tfjs-core'; -import {WebGPUBackend} from '../backend_webgpu'; -import {BatchNormProgram} from './batchnorm_webgpu'; - -export const batchNromConfig : KernelConfig = { - kernelName: FusedBatchNorm, - backendName: 'webgpu', - kernelFunc: ({inputs,attrs,backend}) => { - const {x,scale,offset,mean,variance} = inputs as FusedBatchNormInputs; - const {varianceEpsilon} = attrs as unknown as FusedBatchNormAttrs; - const webGPUBackend = backend as WebGPUBackend; - let batchNormInputs = [x as Tensor, mean as Tensor, variance as Tensor]; - let offsetShape = null; - if(offset != null) { - offsetShape = offset.shape; - batchNormInputs.push(offset as Tensor); - } - let scaleShape = null; - if(scale != null) { - scaleShape = scale.shape; - batchNormInputs.push(scale as Tensor); - } - const program = new BatchNormProgram( - x.shape,mean.shape,variance.shape,offsetShape,scaleShape, - varianceEpsilon); - return webGPUBackend.compileAndRun(program, batchNormInputs); - } -}; \ No newline at end of file diff --git a/tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts b/tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts new file mode 100644 index 00000000000..9babb7ffd0f --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {FusedBatchNorm, FusedBatchNormAttrs, FusedBatchNormInputs, KernelConfig, Tensor} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; + +import {BatchNormProgram} from './batchnorm_webgpu'; + +export const fusedBatchNromConfig: KernelConfig = { + kernelName: FusedBatchNorm, + backendName: 'webgpu', + kernelFunc: ({inputs, attrs, backend}) => { + const {x, scale, offset, mean, variance} = inputs as FusedBatchNormInputs; + const {varianceEpsilon} = attrs as unknown as FusedBatchNormAttrs; + const webGPUBackend = backend as WebGPUBackend; + const batchNormInputs = [x as Tensor, mean as Tensor, variance as Tensor]; + let offsetShape = null; + if (offset != null) { + offsetShape = offset.shape; + batchNormInputs.push(offset as Tensor); + } + let scaleShape = null; + if (scale != null) { + scaleShape = scale.shape; + batchNormInputs.push(scale as Tensor); + } + const program = new BatchNormProgram( + x.shape, mean.shape, variance.shape, offsetShape, scaleShape, + varianceEpsilon); + return webGPUBackend.compileAndRun(program, batchNormInputs); + } +}; diff --git a/tfjs-backend-webgpu/src/kernels/Rsqrt.ts b/tfjs-backend-webgpu/src/kernels/Rsqrt.ts deleted file mode 100644 index 7811f9001ab..00000000000 --- a/tfjs-backend-webgpu/src/kernels/Rsqrt.ts +++ /dev/null @@ -1,31 +0,0 @@ -/** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {KernelConfig, Rsqrt, RsqrtInputs, Tensor} from '@tensorflow/tfjs-core'; -import {WebGPUBackend} from '../backend_webgpu'; -import {RSQRT, UnaryOpProgram} from './unary_op_webgpu'; - -export const rsqrtConfig: KernelConfig = { - kernelName: Rsqrt, - backendName: 'webgpu', - kernelFunc: ({inputs, backend}) => { - const {x} = inputs as RsqrtInputs; - const webGPUBackend = backend as WebGPUBackend; - const program = new UnaryOpProgram(x.shape, RSQRT); - return webGPUBackend.compileAndRun(program, [x as Tensor]); - } -}; \ No newline at end of file diff --git a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts index fbb5bebb90e..9cab697d6b9 100644 --- a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts @@ -17,6 +17,7 @@ import {backend_util} from '@tensorflow/tfjs-core'; +import {getCoordsDataType} from '../shader_preprocessor'; import {computeDispatch} from '../webgpu_util'; import {WebGPUProgram} from './webgpu_program'; @@ -41,6 +42,29 @@ export class BatchNormProgram implements WebGPUProgram { this.dispatchLayout = {x: [1, 2], y: [0], z: [3]}; this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize); + const dim = this.outputShape.length; + const coordsDataType = getCoordsDataType(dim); + let setOutput = + 'setOutput(coords[0], coords[1], coords[2], coords[3], value);'; + switch (dim) { + case 2: + this.dispatchLayout = {x: [1], y: [0], z: []}; + setOutput = 'setOutput(coords[0], coords[1], value);'; + break; + case 3: + this.dispatchLayout = {x: [1, 2], y: [0], z: []}; + setOutput = 'setOutput(coords[0], coords[1], coords[2], value);'; + break; + case 4: + this.dispatchLayout = {x: [1, 2], y: [0], z: [3]}; + setOutput = + 'setOutput(coords[0], coords[1], coords[2], coords[3],value);'; + break; + default: + this.dispatchLayout = {x: [1, 2], y: [0], z: [3]}; + setOutput = + 'setOutput(coords[0], coords[1], coords[2], coords[3],value);'; + } let offsetSnippet = '0.0'; if (offsetShape != null) { @@ -57,13 +81,13 @@ export class BatchNormProgram implements WebGPUProgram { } this.userCode = ` - void writeResult(ivec4 coords,float value) { + void writeResult(${coordsDataType} coords,float value) { if (coordsInBounds(coords, outShape)) { - setOutput(coords[0], coords[1], coords[2], coords[3], value); + ${setOutput} } } void main() { - ivec4 coords = getOutputCoords(); + ${coordsDataType} coords = getOutputCoords(); float x = getXAtOutCoords(); float mean = getMeanAtOutCoords(); float variance = getVarianceAtOutCoords(); diff --git a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts index 0295e064aeb..481069eb369 100644 --- a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts @@ -29,7 +29,6 @@ export const ELU = `return (a >= 0.0) ? a : (exp(a) - 1.0);`; export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`; export const ABS = `return abs(a);`; export const SQUARE = `return a * a;`; -export const RSQRT = `return inversesqrt(a);`; export class UnaryOpProgram implements WebGPUProgram { outputShape: number[]; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 3d2ff41e902..d882fc8b6cf 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -16,14 +16,13 @@ */ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; -import {batchNromConfig} from './kernels/BatchNormal'; -import {rsqrtConfig} from './kernels/Rsqrt'; +import {fusedBatchNromConfig} from './kernels/FusedBatchNorm'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; // List all kernel configs here const kernelConfigs: KernelConfig[] = - [squareConfig, squaredDifferenceConfig, rsqrtConfig, batchNromConfig]; + [squareConfig, squaredDifferenceConfig, fusedBatchNromConfig]; for (const kernelConfig of kernelConfigs) { registerKernel(kernelConfig); diff --git a/tfjs-backend-webgpu/src/shader_preprocessor.ts b/tfjs-backend-webgpu/src/shader_preprocessor.ts index 9490afeb218..33a51f6e0dc 100644 --- a/tfjs-backend-webgpu/src/shader_preprocessor.ts +++ b/tfjs-backend-webgpu/src/shader_preprocessor.ts @@ -143,6 +143,10 @@ const SHADER_PREFIX = `#version 450 all(lessThan(coord, shape)); } + bool coordsInBounds(ivec3 coord, ivec3 shape) { + return all(greaterThanEqual(coord, ivec3(0))) && + all(lessThan(coord, shape)); + bool coordsInBounds(ivec2 coord, ivec2 shape) { return all(greaterThanEqual(coord, ivec2(0))) && all(lessThan(coord, shape)); diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 4e9956d60d8..20c0b8b3bfb 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -45,9 +45,6 @@ export type SquaredDifferenceInputs = BinaryInputs; export const Square = 'Square'; export type SquareInputs = Pick; -export const Rsqrt = 'Rsqrt'; -export type RsqrtInputs = Pick; - export const Transpose = 'Transpose'; export type TransposeInputs = Pick; export interface TransposeAttrs { From 87f37cb67904650205dc8509aed4dbede9518878 Mon Sep 17 00:00:00 2001 From: NALLEIN Date: Wed, 15 Apr 2020 10:23:19 +0800 Subject: [PATCH 6/8] Modularize batchNorm --- tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts | 4 ++-- tfjs-backend-webgpu/src/setup_test.ts | 6 ------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts index 9cab697d6b9..e4b98cfc441 100644 --- a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts @@ -40,8 +40,6 @@ export class BatchNormProgram implements WebGPUProgram { backend_util.assertAndGetBroadcastShape(xShape, varianceShape); this.outputShape = xShape; this.dispatchLayout = {x: [1, 2], y: [0], z: [3]}; - this.dispatch = computeDispatch( - this.dispatchLayout, this.outputShape, this.workGroupSize); const dim = this.outputShape.length; const coordsDataType = getCoordsDataType(dim); let setOutput = @@ -65,6 +63,8 @@ export class BatchNormProgram implements WebGPUProgram { setOutput = 'setOutput(coords[0], coords[1], coords[2], coords[3],value);'; } + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workGroupSize); let offsetSnippet = '0.0'; if (offsetShape != null) { diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index c22705a9a38..25ab0159d56 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -358,12 +358,6 @@ const TEST_FILTERS: TestFilter[] = [ excludes: [ 'gradient', ] - }, - { - include: 'rsqrt', - excludes: [ - 'gradient', - ] } ]; From b7ff6e5ab08f323a48da41e38413d3e8c8799e82 Mon Sep 17 00:00:00 2001 From: NALLEIN Date: Sun, 19 Apr 2020 14:19:37 +0800 Subject: [PATCH 7/8] Fix syntax --- .../src/kernels/batchnorm_webgpu.ts | 25 ++++++------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts index e4b98cfc441..525ed5b870b 100644 --- a/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/batchnorm_webgpu.ts @@ -44,24 +44,13 @@ export class BatchNormProgram implements WebGPUProgram { const coordsDataType = getCoordsDataType(dim); let setOutput = 'setOutput(coords[0], coords[1], coords[2], coords[3], value);'; - switch (dim) { - case 2: - this.dispatchLayout = {x: [1], y: [0], z: []}; - setOutput = 'setOutput(coords[0], coords[1], value);'; - break; - case 3: - this.dispatchLayout = {x: [1, 2], y: [0], z: []}; - setOutput = 'setOutput(coords[0], coords[1], coords[2], value);'; - break; - case 4: - this.dispatchLayout = {x: [1, 2], y: [0], z: [3]}; - setOutput = - 'setOutput(coords[0], coords[1], coords[2], coords[3],value);'; - break; - default: - this.dispatchLayout = {x: [1, 2], y: [0], z: [3]}; - setOutput = - 'setOutput(coords[0], coords[1], coords[2], coords[3],value);'; + if (dim === 2) { + this.dispatchLayout = {x: [1], y: [0], z: []}; + setOutput = 'setOutput(coords[0], coords[1], value);'; + } + if (dim === 3) { + this.dispatchLayout = {x: [1, 2], y: [0], z: []}; + setOutput = 'setOutput(coords[0], coords[1], coords[2], value);'; } this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize); From e58fb943b2ea3116a0da6a19e0ad9ca4ce2f63a3 Mon Sep 17 00:00:00 2001 From: NALLEIN Date: Wed, 6 May 2020 22:21:48 +0800 Subject: [PATCH 8/8] Fix name --- tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts | 2 +- tfjs-backend-webgpu/src/register_all_kernels.ts | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts b/tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts index 9babb7ffd0f..c68e382b348 100644 --- a/tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts +++ b/tfjs-backend-webgpu/src/kernels/FusedBatchNorm.ts @@ -21,7 +21,7 @@ import {WebGPUBackend} from '../backend_webgpu'; import {BatchNormProgram} from './batchnorm_webgpu'; -export const fusedBatchNromConfig: KernelConfig = { +export const fusedBatchNormConfig: KernelConfig = { kernelName: FusedBatchNorm, backendName: 'webgpu', kernelFunc: ({inputs, attrs, backend}) => { diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index d882fc8b6cf..f4f0367b600 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -16,13 +16,13 @@ */ import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; -import {fusedBatchNromConfig} from './kernels/FusedBatchNorm'; +import {fusedBatchNormConfig} from './kernels/FusedBatchNorm'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; // List all kernel configs here const kernelConfigs: KernelConfig[] = - [squareConfig, squaredDifferenceConfig, fusedBatchNromConfig]; + [squareConfig, squaredDifferenceConfig, fusedBatchNormConfig]; for (const kernelConfig of kernelConfigs) { registerKernel(kernelConfig);