Skip to content

Commit

Permalink
[js/webgpu] Support shared memory for transpose 2d (microsoft#19267)
Browse files Browse the repository at this point in the history
For 1024x1024, without shared memoey, 18.7ms. With shared memory 13.2ms.
  • Loading branch information
axinging committed May 22, 2024
1 parent 068bb3d commit f1fef19
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
27 changes: 25 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,30 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
const outputShape = getOutputShape(inputTensor.dims, perm);
const output = outputVariable('output', inputDataType, outputShape.length);
const input = inputVariable('a', inputDataType, inputRank);

const getShaderSource = (shaderHelper: ShaderHelper) => `
let getShaderSource;
if (perm.length === 2 && perm[0] === 1 && perm[1] === 0) {
const wgslType = output.type.value;
const workgroupSize: [number, number, number] = [16, 16, 1];
getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
var<workgroup> tile : array<array<${wgslType}, ${workgroupSize[0] + 1}>, ${workgroupSize[0]}>;
${shaderHelper.mainStart(workgroupSize)}
var x = workgroup_id.x * ${workgroupSize[0]}u + local_id.x;
var y = workgroup_id.y * ${workgroupSize[0]}u + local_id.y;
let width = uniforms.output_shape[0];
let height = uniforms.output_shape[1];
if (x < width && y < height) {
tile[local_id.y][local_id.x] = ${input.getByOffset('y * width + x')};
}
workgroupBarrier();
x = workgroup_id.y * ${workgroupSize[0]}u + local_id.x;
y = workgroup_id.x * ${workgroupSize[0]}u + local_id.y;
if (x < height && y < width) {
${output.setByOffset('y * height + x', 'tile[local_id.x][local_id.y]')}
}
}`;
} else {
getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
${permFunctionBody(perm, inputRank, input, output)}
Expand All @@ -57,6 +79,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
${output.setByOffset('global_idx', input.getByIndices('aIndices'))}
}`;
}
return {
name: 'Transpose',
shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']},
Expand Down
24 changes: 24 additions & 0 deletions js/web/test/data/ops/transpose.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -190,5 +190,29 @@
]
}
]
},
{
"name": "Transpose - perms:[1, 0]",
"operator": "Transpose",
"attributes": [{ "name": "perm", "data": [1, 0], "type": "ints" }],
"cases": [
{
"name": "T[6,4]",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
"dims": [6, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24],
"dims": [4, 6],
"type": "float32"
}
]
}
]
}
]

0 comments on commit f1fef19

Please sign in to comment.