diff --git a/tfjs-backend-webgl/src/conv_backprop_packed_gpu.ts b/tfjs-backend-webgl/src/conv_backprop_packed_gpu.ts new file mode 100644 index 00000000000..ccdae6a70cf --- /dev/null +++ b/tfjs-backend-webgl/src/conv_backprop_packed_gpu.ts @@ -0,0 +1,104 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util} from '@tensorflow/tfjs-core'; +import {GPGPUProgram, useShapeUniforms} from './gpgpu_math'; + +export class Conv2DDerInputPackedProgram implements GPGPUProgram { + variableNames = ['dy', 'W']; + packedInputs = true; + packedOutput = true; + outputShape: number[]; + userCode: string; + enableShapeUniforms: boolean; + customUniforms = [ + {name: 'strides', type: 'vec2' as const }, + ]; + + constructor(convInfo: backend_util.Conv2DInfo) { + this.outputShape = convInfo.inShape; + this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); + + const filterHeight = convInfo.filterHeight; + const filterWidth = convInfo.filterWidth; + + const padTop = filterHeight - 1 - convInfo.padInfo.top; + const padLeft = filterWidth - 1 - convInfo.padInfo.left; + + this.userCode = ` + const ivec2 pads = ivec2(${padTop}, ${padLeft}); + + void main() { + ivec4 coords = getOutputCoords(); + int batch = coords[0]; + int d1 = coords[3]; + + ivec2 dyCorner = ivec2(coords[1], coords[2]) - pads; + int dyRCorner = dyCorner.x; + int dyCCorner = dyCorner.y; + + vec4 result = vec4(0.); + for (int wR = 0; wR < ${filterHeight}; wR++) { + float dyR = float(dyRCorner + wR) / strides[0]; + if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { + continue; + } + int idyR = int(dyR); + int wRPerm = ${filterHeight} - 1 - wR; + + for (int wC = 0; wC < ${filterWidth}; wC++) { + int wCPerm = ${filterWidth} - 1 - wC; + + float dyC = float(dyCCorner + wC) / strides[1]; + float idyCVal = dyC < 0.0 ? 0. : + dyC >= ${convInfo.outWidth}.0 ? 0. : + fract(dyC) > 0.0 ? 0. : 1.; + int idyC = int(dyC); + + float dyC2 = float(dyCCorner + wC + 1) / strides[1]; + float idyCVal2 = dyC2 < 0.0 ? 0. : + dyC2 >= ${convInfo.outWidth}.0 ? 0. : + fract(dyC2) > 0.0 ? 0. : 1.; + int idyC2 = int(dyC2); + + if (idyCVal + idyCVal2 == 0.) { + continue; + } + + for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) { + vec4 wValue = getW(wRPerm, wCPerm, d1, d2); + vec4 dySample = getDy(batch, idyR, idyC, d2); + vec4 dySample2 = (idyC / 2 == idyC2 / 2) ? + dySample : getDy(batch, idyR, idyC2, d2); + + vec2 dyValue = mod(float(idyC), 2.) == 0. ? + dySample.xy : dySample.zw; + result.xy += vec2(dot(dyValue, wValue.xy), + dot(dyValue, wValue.zw)) * idyCVal; + + dyValue = mod(float(idyC2), 2.) == 0. ? + dySample2.xy : dySample2.zw; + result.zw += vec2(dot(dyValue, wValue.xy), + dot(dyValue, wValue.zw)) * idyCVal2; + } + } + } + setOutput(result); + } + `; + } +} diff --git a/tfjs-backend-webgl/src/kernels/Conv2DBackpropInput.ts b/tfjs-backend-webgl/src/kernels/Conv2DBackpropInput.ts index 6de51c3ca0b..cfc853db5f1 100644 --- a/tfjs-backend-webgl/src/kernels/Conv2DBackpropInput.ts +++ b/tfjs-backend-webgl/src/kernels/Conv2DBackpropInput.ts @@ -15,10 +15,11 @@ * ============================================================================= */ -import {backend_util, Conv2DBackpropInput, Conv2DBackpropInputAttrs, Conv2DBackpropInputInputs, KernelConfig, KernelFunc} from '@tensorflow/tfjs-core'; +import {backend_util, Conv2DBackpropInput, Conv2DBackpropInputAttrs, Conv2DBackpropInputInputs, env, KernelConfig, KernelFunc} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; import {Conv2DDerInputProgram} from '../conv_backprop_gpu'; +import {Conv2DDerInputPackedProgram} from '../conv_backprop_packed_gpu'; export function conv2DBackpropInput(args: { inputs: Conv2DBackpropInputInputs, @@ -34,8 +35,17 @@ export function conv2DBackpropInput(args: { inputShape, filter.shape as [number, number, number, number], strides, 1 /* dilations */, pad, dimRoundingMode, false, $dataFormat); - const program = new Conv2DDerInputProgram(convInfo); - return backend.runWebGLProgram(program, [dy, filter], 'float32'); + if (env().getBool('WEBGL_PACK') && $dataFormat === 'channelsLast') { + const customValues = [ + [convInfo.strideHeight, convInfo.strideWidth], + ]; + const program = new Conv2DDerInputPackedProgram(convInfo); + return backend.runWebGLProgram( + program, [dy, filter], 'float32', customValues); + } else { + const program = new Conv2DDerInputProgram(convInfo); + return backend.runWebGLProgram(program, [dy, filter], 'float32'); + } } export const conv2DBackpropInputConfig: KernelConfig = {