Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 136 additions & 18 deletions tfjs-backend-webgpu/src/conv_backprop_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,157 @@ export class Conv2DDerInputProgram implements WebGPUProgram {
'filterDims : vec2<i32>, pads : vec2<i32>, strides : vec2<i32>, outBackprop : vec4<i32>,';
outputShape: number[];
shaderKey: string;
dispatchLayout: {x: number[]};
dispatchLayout: {x: number[], y?: number[], z?: number[]};
dispatch: [number, number, number];
workgroupSize: [number, number, number] = [64, 1, 1];
isChannelsLast: boolean;
size = true;
size = false;
isVec4 = false;
workPerThread = 1;

/**
 * Configures the program from the convolution description.
 *
 * The output has the shape of the convolution input (`convInfo.inShape`).
 * Two dispatch configurations are supported:
 *  - vec4 path: chosen when the data format is channels-last and both
 *    `outChannels` and `inChannels` are multiples of 4. The shader maps
 *    globalId.x to the channel (in groups of 4), globalId.y to the column
 *    (`workPerThread` columns per invocation) and globalId.z to
 *    batch * row.
 *  - scalar path: a flat 1D dispatch over the output, guarded by
 *    `uniforms.size` in the shader.
 *
 * @param convInfo Convolution metadata; only `inShape`, `dataFormat`,
 *     `inChannels` and `outChannels` are read here.
 */
constructor(convInfo: backend_util.Conv2DInfo) {
  this.outputShape = convInfo.inShape;
  this.isChannelsLast = convInfo.dataFormat === 'channelsLast';
  this.isVec4 = this.isChannelsLast && convInfo.outChannels % 4 === 0 &&
      convInfo.inChannels % 4 === 0;
  // dispatchLayout, dispatch and shaderKey are assigned exactly once below;
  // computing provisional values up front would only be overwritten.
  if (this.isVec4) {
    // TODO: Expand to any value.
    this.workPerThread = 2;
    this.workgroupSize = [4, 4, 4];
    this.dispatchLayout = {x: [3], y: [2], z: [0, 1]};
    // NOTE(review): the last argument presumably gives the number of output
    // elements each invocation covers along x/y/z - confirm against
    // computeDispatch.
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workgroupSize,
        [4, this.workPerThread, 1]);
  } else {
    this.size = true;
    this.workPerThread = 1;
    this.workgroupSize = [64, 1, 1];
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workgroupSize);
  }
  // The key covers every value that changes the generated shader text.
  this.shaderKey = `conv2DDerInput_${this.isChannelsLast}_${this.isVec4}_${
      this.workPerThread}`;
}

getUserCode(): string {
const rowDim = this.isChannelsLast ? 1 : 2;
const colDim = this.isChannelsLast ? 2 : 3;
const channelDim = this.isChannelsLast ? 3 : 1;
return `

const vec4Snippet = `
${main()} {
let batch = i32(globalId.z) / uniforms.outShape[1];
let r = i32(globalId.z) % uniforms.outShape[1];
let c = i32(globalId.y) * ${this.workPerThread};
let d1 = i32(globalId.x) * 4;

let dyCorner = vec2<i32>(r, c) - uniforms.pads;

// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
var dotProd: array<vec4<f32>, ${this.workPerThread}>;
for (var i = 0; i < ${this.workPerThread}; i++) {
dotProd[i] = vec4<f32>(0.0);
}
for (var wR = 0; wR < uniforms.filterDims.x; wR = wR + 1) {
let dyR = f32(dyCorner.x + wR) / f32(uniforms.strides.x);
let wRPerm = uniforms.filterDims.x - 1 - wR;
if (dyR < 0.0 || dyR >= f32(uniforms.outBackprop[1]) ||
fract(dyR) > 0.0) {
continue;
}
let idyR = i32(dyR);

for (var wC = 0; wC < uniforms.filterDims.y; wC = wC + 1) {
let dyC = f32(dyCorner.y + wC) / f32(uniforms.strides.y);
let dyC2 = f32(dyCorner.y + 1 + wC) / f32(uniforms.strides.y);
let wCPerm = uniforms.filterDims.y - 1 - wC;
var bDyCVal = true;
var bDyCVal2 = true;
if (dyC < 0.0 || dyC >= f32(uniforms.outBackprop[2]) ||
fract(dyC) > 0.0) {
bDyCVal = false;
}
if (dyC2 < 0.0 || dyC2 >= f32(uniforms.outBackprop[2]) ||
fract(dyC2) > 0.0) {
bDyCVal2 = false;
}

let idyC = i32(dyC);
let idyC2 = i32(dyC2);
if (bDyCVal && bDyCVal2) {
let d2Length = uniforms.outBackprop[3];
for (var d2 = 0; d2 < d2Length; d2 = d2 + 4) {
let wValue0 = getW(wRPerm, wCPerm, d1, d2);
let wValue1 = getW(wRPerm, wCPerm, d1 + 1, d2);
let wValue2 = getW(wRPerm, wCPerm, d1 + 2, d2);
let wValue3 = getW(wRPerm, wCPerm, d1 + 3, d2);
var xValue = getDy(batch, idyR, idyC, d2);
let tmpval = vec4<f32>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
dotProd[0] = dotProd[0] + tmpval;
xValue = getDy(batch, idyR, idyC2, d2);
dotProd[1] = dotProd[1] + vec4<f32>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
}
} else if (bDyCVal) {
let d2Length = uniforms.outBackprop[3];
for (var d2 = 0; d2 < d2Length; d2 = d2 + 4) {
let wValue0 = getW(wRPerm, wCPerm, d1, d2);
let wValue1 = getW(wRPerm, wCPerm, d1 + 1, d2);
let wValue2 = getW(wRPerm, wCPerm, d1 + 2, d2);
let wValue3 = getW(wRPerm, wCPerm, d1 + 3, d2);
var xValue = getDy(batch, idyR, idyC, d2);
let tmpval = vec4<f32>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
dotProd[0] = dotProd[0] + tmpval;
}
} else if (bDyCVal2) {
let d2Length = uniforms.outBackprop[3];
for (var d2 = 0; d2 < d2Length; d2 = d2 + 4) {
let wValue0 = getW(wRPerm, wCPerm, d1, d2);
let wValue1 = getW(wRPerm, wCPerm, d1 + 1, d2);
let wValue2 = getW(wRPerm, wCPerm, d1 + 2, d2);
let wValue3 = getW(wRPerm, wCPerm, d1 + 3, d2);
var xValue = getDy(batch, idyR, idyC2, d2);
let tmpval = vec4<f32>(dot(xValue, wValue0),
dot(xValue, wValue1),
dot(xValue, wValue2),
dot(xValue, wValue3));
dotProd[1] = dotProd[1] + tmpval;
}
}
}
}

for (var i = 0; i < ${this.workPerThread}; i = i + 1) {
let coords = vec4<i32>(batch, r, c + i, d1);
if (coordsInBounds4D(coords, uniforms.outShape)) {
setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], dotProd[i]);
}
}
}
`;
return this.isVec4 ?
`
${vec4Snippet}
` :
`
${main('index')} {
if(index < uniforms.size) {
let coords = getCoordsFromIndex(index);
let batch = coords[0];
let d1 = coords[${channelDim}];

let dyCorner = vec2<i32>(coords[${rowDim}], coords[${
colDim}]) - uniforms.pads;
colDim}]) - uniforms.pads;
let dyRCorner = dyCorner.x;
let dyCCorner = dyCorner.y;

Expand All @@ -78,16 +201,11 @@ export class Conv2DDerInputProgram implements WebGPUProgram {
let idyC = i32(dyC);

for (var d2 = 0; d2 < uniforms.outBackprop[3]; d2 = d2 + 1) {
if (${this.isChannelsLast}) {
let xValue = getDy(batch, idyR, idyC, d2);
let wValue = getW(wRPerm, wCPerm, d1, d2);
dotProd = dotProd + xValue * wValue;
} else {
let xValue = getDy(batch, d2, idyR, idyC);
let wValue = getW(wRPerm, wCPerm, d1, d2);
dotProd = dotProd + xValue * wValue;
}

let xValue = ${
this.isChannelsLast ? 'getDy(batch, idyR, idyC, d2)' :
'getDy(batch, d2, idyR, idyC)'};
let wValue = getW(wRPerm, wCPerm, d1, d2);
dotProd = dotProd + xValue * wValue;
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/flags_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ ENV.registerFlag('WEBGPU_MATMUL_PROGRAM_TYPE', () => -1);
* Whether to use conv2dTranspose_naive, which directly implements the
* conv2dTranspose logic rather than simulating it with a matmul.
*/
ENV.registerFlag('WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE', () => false);
ENV.registerFlag('WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE', () => true);

/**
* Whether we use low power GPU. Otherwise, a high performance GPU will be
Expand Down
7 changes: 2 additions & 5 deletions tfjs-backend-webgpu/src/kernels/Conv2DBackpropInput.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,8 @@ export function conv2DBackpropInput(args: {
},
];
let program: Conv2DDerInputProgram|Conv2DDerInputMMProgram;
// When filter size is small, Conv2DDerInputProgram is much faster than
// Conv2DDerInputMMProgram.
if (env().getBool('WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE') ||
convInfo.filterHeight <= 2 && convInfo.filterWidth <= 2 &&
convInfo.outChannels <= 16 && convInfo.inChannels === 1) {
// TODO: Experiment when to use Conv2DDerInputMMProgram algorithm.
if (env().getBool('WEBGPU_USE_NAIVE_CONV2D_TRANSPOSE')) {
program = new Conv2DDerInputProgram(convInfo);
} else {
program = new Conv2DDerInputMMProgram(convInfo);
Expand Down