Patch summary (text before the first "diff --git" header is ignored by patch tools):

  * conv_backprop_webgpu.ts: Conv2DDerInputProgram gains a vec4 code path, chosen
    when the data format is channelsLast and both outChannels and inChannels are
    multiples of 4. That path uses workgroupSize [4, 4, 4], dispatch layout
    {x: [3], y: [2], z: [0, 1]} and workPerThread = 2. The vec4 shader hard-codes
    two adjacent output columns (dyC / dyC2, dotProd[0] / dotProd[1]), so
    workPerThread cannot be changed without rewriting the shader (see the TODO).
    shaderKey now also encodes isVec4 and workPerThread. In the scalar path the
    channelsLast choice of getDy(...) arguments is made while building the shader
    string instead of being emitted as an if/else inside the shader.
  * flags_webgpu.ts: the default of WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE changes from
    false to true.
  * kernels/Conv2DBackpropInput.ts: the small-filter heuristic is removed; the
    flag alone selects Conv2DDerInputProgram over Conv2DDerInputMMProgram.

NOTE(review): this patch was reconstructed from a copy whose newlines and
indentation had been collapsed and whose angle-bracket type parameters (e.g.
vec2<i32>, array<vec4<f32>, N>) had been stripped. Line structure was checked
against every hunk header count. Leading indentation, including the
whitespace-only "colDim}])" change, is a best guess, so apply with
"git apply --ignore-whitespace" and confirm against the upstream commit.

diff --git a/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts b/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts
index dd820152493..e87ce221cab 100644
--- a/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts
+++ b/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts
@@ -25,26 +25,149 @@ export class Conv2DDerInputProgram implements WebGPUProgram {
       'filterDims : vec2<i32>, pads : vec2<i32>, strides : vec2<i32>, outBackprop : vec4<i32>,';
   outputShape: number[];
   shaderKey: string;
-  dispatchLayout: {x: number[]};
+  dispatchLayout: {x: number[], y?: number[], z?: number[]};
   dispatch: [number, number, number];
   workgroupSize: [number, number, number] = [64, 1, 1];
   isChannelsLast: boolean;
-  size = true;
+  size = false;
+  isVec4 = false;
+  workPerThread = 1;
 
   constructor(convInfo: backend_util.Conv2DInfo) {
     this.outputShape = convInfo.inShape;
-    this.dispatchLayout = flatDispatchLayout(this.outputShape);
-    this.dispatch = computeDispatch(
-        this.dispatchLayout, this.outputShape, this.workgroupSize);
     this.isChannelsLast = convInfo.dataFormat === 'channelsLast';
-    this.shaderKey = `conv2DDerInput_${this.isChannelsLast}`;
+    this.isVec4 = this.isChannelsLast && convInfo.outChannels % 4 === 0 &&
+        convInfo.inChannels % 4 === 0;
+    if (this.isVec4) {
+      // TODO: Expand to any value.
+      this.workPerThread = 2;
+      this.workgroupSize = [4, 4, 4];
+      this.dispatchLayout = {x: [3], y: [2], z: [0, 1]};
+      this.dispatch = computeDispatch(
+          this.dispatchLayout, this.outputShape, this.workgroupSize,
+          [4, this.workPerThread, 1]);
+    } else {
+      this.size = true;
+      this.workPerThread = 1;
+      this.workgroupSize = [64, 1, 1];
+      this.dispatchLayout = flatDispatchLayout(this.outputShape);
+      this.dispatch = computeDispatch(
+          this.dispatchLayout, this.outputShape, this.workgroupSize);
+    }
+    this.shaderKey = `conv2DDerInput_${this.isChannelsLast}_${this.isVec4}_${
+        this.workPerThread}`;
   }
 
   getUserCode(): string {
     const rowDim = this.isChannelsLast ? 1 : 2;
     const colDim = this.isChannelsLast ? 2 : 3;
     const channelDim = this.isChannelsLast ? 3 : 1;
-    return `
+
+    const vec4Snippet = `
+    ${main()} {
+      let batch = i32(globalId.z) / uniforms.outShape[1];
+      let r = i32(globalId.z) % uniforms.outShape[1];
+      let c = i32(globalId.y) * ${this.workPerThread};
+      let d1 = i32(globalId.x) * 4;
+
+      let dyCorner = vec2<i32>(r, c) - uniforms.pads;
+
+      // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
+      // ? = to be determined. : = across all values in that axis.
+      var dotProd: array<vec4<f32>, ${this.workPerThread}>;
+      for (var i = 0; i < ${this.workPerThread}; i++) {
+        dotProd[i] = vec4<f32>(0.0);
+      }
+      for (var wR = 0; wR < uniforms.filterDims.x; wR = wR + 1) {
+        let dyR = f32(dyCorner.x + wR) / f32(uniforms.strides.x);
+        let wRPerm = uniforms.filterDims.x - 1 - wR;
+        if (dyR < 0.0 || dyR >= f32(uniforms.outBackprop[1]) ||
+            fract(dyR) > 0.0) {
+          continue;
+        }
+        let idyR = i32(dyR);
+
+        for (var wC = 0; wC < uniforms.filterDims.y; wC = wC + 1) {
+          let dyC = f32(dyCorner.y + wC) / f32(uniforms.strides.y);
+          let dyC2 = f32(dyCorner.y + 1 + wC) / f32(uniforms.strides.y);
+          let wCPerm = uniforms.filterDims.y - 1 - wC;
+          var bDyCVal = true;
+          var bDyCVal2 = true;
+          if (dyC < 0.0 || dyC >= f32(uniforms.outBackprop[2]) ||
+              fract(dyC) > 0.0) {
+            bDyCVal = false;
+          }
+          if (dyC2 < 0.0 || dyC2 >= f32(uniforms.outBackprop[2]) ||
+              fract(dyC2) > 0.0) {
+            bDyCVal2 = false;
+          }
+
+          let idyC = i32(dyC);
+          let idyC2 = i32(dyC2);
+          if (bDyCVal && bDyCVal2) {
+            let d2Length = uniforms.outBackprop[3];
+            for (var d2 = 0; d2 < d2Length; d2 = d2 + 4) {
+              let wValue0 = getW(wRPerm, wCPerm, d1, d2);
+              let wValue1 = getW(wRPerm, wCPerm, d1 + 1, d2);
+              let wValue2 = getW(wRPerm, wCPerm, d1 + 2, d2);
+              let wValue3 = getW(wRPerm, wCPerm, d1 + 3, d2);
+              var xValue = getDy(batch, idyR, idyC, d2);
+              let tmpval = vec4<f32>(dot(xValue, wValue0),
+                                     dot(xValue, wValue1),
+                                     dot(xValue, wValue2),
+                                     dot(xValue, wValue3));
+              dotProd[0] = dotProd[0] + tmpval;
+              xValue = getDy(batch, idyR, idyC2, d2);
+              dotProd[1] = dotProd[1] + vec4<f32>(dot(xValue, wValue0),
+                                                  dot(xValue, wValue1),
+                                                  dot(xValue, wValue2),
+                                                  dot(xValue, wValue3));
+            }
+          } else if (bDyCVal) {
+            let d2Length = uniforms.outBackprop[3];
+            for (var d2 = 0; d2 < d2Length; d2 = d2 + 4) {
+              let wValue0 = getW(wRPerm, wCPerm, d1, d2);
+              let wValue1 = getW(wRPerm, wCPerm, d1 + 1, d2);
+              let wValue2 = getW(wRPerm, wCPerm, d1 + 2, d2);
+              let wValue3 = getW(wRPerm, wCPerm, d1 + 3, d2);
+              var xValue = getDy(batch, idyR, idyC, d2);
+              let tmpval = vec4<f32>(dot(xValue, wValue0),
+                                     dot(xValue, wValue1),
+                                     dot(xValue, wValue2),
+                                     dot(xValue, wValue3));
+              dotProd[0] = dotProd[0] + tmpval;
+            }
+          } else if (bDyCVal2) {
+            let d2Length = uniforms.outBackprop[3];
+            for (var d2 = 0; d2 < d2Length; d2 = d2 + 4) {
+              let wValue0 = getW(wRPerm, wCPerm, d1, d2);
+              let wValue1 = getW(wRPerm, wCPerm, d1 + 1, d2);
+              let wValue2 = getW(wRPerm, wCPerm, d1 + 2, d2);
+              let wValue3 = getW(wRPerm, wCPerm, d1 + 3, d2);
+              var xValue = getDy(batch, idyR, idyC2, d2);
+              let tmpval = vec4<f32>(dot(xValue, wValue0),
+                                     dot(xValue, wValue1),
+                                     dot(xValue, wValue2),
+                                     dot(xValue, wValue3));
+              dotProd[1] = dotProd[1] + tmpval;
+            }
+          }
+        }
+      }
+
+      for (var i = 0; i < ${this.workPerThread}; i = i + 1) {
+        let coords = vec4<i32>(batch, r, c + i, d1);
+        if (coordsInBounds4D(coords, uniforms.outShape)) {
+          setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], dotProd[i]);
+        }
+      }
+    }
+    `;
+    return this.isVec4 ?
+        `
+    ${vec4Snippet}
+    ` :
+        `
     ${main('index')} {
       if(index < uniforms.size) {
         let coords = getCoordsFromIndex(index);
@@ -52,7 +175,7 @@ export class Conv2DDerInputProgram implements WebGPUProgram {
         let d1 = coords[${channelDim}];
 
         let dyCorner = vec2<i32>(coords[${rowDim}], coords[${
-        colDim}]) - uniforms.pads;
+                          colDim}]) - uniforms.pads;
         let dyRCorner = dyCorner.x;
         let dyCCorner = dyCorner.y;
 
@@ -78,16 +201,11 @@ export class Conv2DDerInputProgram implements WebGPUProgram {
               let idyC = i32(dyC);
 
               for (var d2 = 0; d2 < uniforms.outBackprop[3]; d2 = d2 + 1) {
-                if (${this.isChannelsLast}) {
-                  let xValue = getDy(batch, idyR, idyC, d2);
-                  let wValue = getW(wRPerm, wCPerm, d1, d2);
-                  dotProd = dotProd + xValue * wValue;
-                } else {
-                  let xValue = getDy(batch, d2, idyR, idyC);
-                  let wValue = getW(wRPerm, wCPerm, d1, d2);
-                  dotProd = dotProd + xValue * wValue;
-                }
-
+                let xValue = ${
+        this.isChannelsLast ? 'getDy(batch, idyR, idyC, d2)' :
+                              'getDy(batch, d2, idyR, idyC)'};
+                let wValue = getW(wRPerm, wCPerm, d1, d2);
+                dotProd = dotProd + xValue * wValue;
               }
             }
           }
diff --git a/tfjs-backend-webgpu/src/flags_webgpu.ts b/tfjs-backend-webgpu/src/flags_webgpu.ts
index 49ab70ffbac..7768a6a25c6 100644
--- a/tfjs-backend-webgpu/src/flags_webgpu.ts
+++ b/tfjs-backend-webgpu/src/flags_webgpu.ts
@@ -39,7 +39,7 @@ ENV.registerFlag('WEBGPU_MATMUL_PROGRAM_TYPE', () => -1);
  * Whether to use conv2dTranspose_naive which directly implement the
  * conv2dTranspose logic rather than using a matmul to simulate.
  */
-ENV.registerFlag('WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE', () => false);
+ENV.registerFlag('WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE', () => true);
 
 /**
  * Whether we use low power GPU. Otherwise, a high performance GPU will be
diff --git a/tfjs-backend-webgpu/src/kernels/Conv2DBackpropInput.ts b/tfjs-backend-webgpu/src/kernels/Conv2DBackpropInput.ts
index 7d968b6e85f..f104a23f6f6 100644
--- a/tfjs-backend-webgpu/src/kernels/Conv2DBackpropInput.ts
+++ b/tfjs-backend-webgpu/src/kernels/Conv2DBackpropInput.ts
@@ -54,11 +54,8 @@ export function conv2DBackpropInput(args: {
     },
   ];
   let program: Conv2DDerInputProgram|Conv2DDerInputMMProgram;
-  // When filter size is small, Conv2DDerInputProgram is much faster than
-  // Conv2DDerInputMMProgram.
-  if (env().getBool('WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE') ||
-      convInfo.filterHeight <= 2 && convInfo.filterWidth <= 2 &&
-          convInfo.outChannels <= 16 && convInfo.inChannels === 1) {
+  // TODO: Experiment when to use Conv2DDerInputMMProgram algorithm.
+  if (env().getBool('WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE')) {
     program = new Conv2DDerInputProgram(convInfo);
   } else {
     program = new Conv2DDerInputMMProgram(convInfo);