diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 7ae801222b87..8496173b1e8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -43,8 +43,30 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu const outputShape = getOutputShape(inputTensor.dims, perm); const output = outputVariable('output', inputDataType, outputShape.length); const input = inputVariable('a', inputDataType, inputRank); - - const getShaderSource = (shaderHelper: ShaderHelper) => ` + let getShaderSource; + if (perm.length === 2 && perm[0] === 1 && perm[1] === 0) { + const wgslType = output.type.value; + const workgroupSize: [number, number, number] = [16, 16, 1]; + getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} + var tile : array, ${workgroupSize[0]}>; + ${shaderHelper.mainStart(workgroupSize)} + var x = workgroup_id.x * ${workgroupSize[0]}u + local_id.x; + var y = workgroup_id.y * ${workgroupSize[0]}u + local_id.y; + let width = uniforms.output_shape[0]; + let height = uniforms.output_shape[1]; + if (x < width && y < height) { + tile[local_id.y][local_id.x] = ${input.getByOffset('y * width + x')}; + } + workgroupBarrier(); + x = workgroup_id.y * ${workgroupSize[0]}u + local_id.x; + y = workgroup_id.x * ${workgroupSize[0]}u + local_id.y; + if (x < height && y < width) { + ${output.setByOffset('y * height + x', 'tile[local_id.x][local_id.y]')} + } + }`; + } else { + getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${permFunctionBody(perm, inputRank, input, output)} @@ -57,6 +79,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ${output.setByOffset('global_idx', input.getByIndices('aIndices'))} }`; + } return { name: 'Transpose', shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, diff --git a/js/web/test/data/ops/transpose.jsonc b/js/web/test/data/ops/transpose.jsonc index e1edfa7e4151..2b01475522ac 100644 --- a/js/web/test/data/ops/transpose.jsonc +++ b/js/web/test/data/ops/transpose.jsonc @@ -190,5 +190,29 @@ ] } ] + }, + { + "name": "Transpose - perms:[1, 0]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [1, 0], "type": "ints" }], + "cases": [ + { + "name": "T[6,4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [6, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24], + "dims": [4, 6], + "type": "float32" + } + ] + } + ] } ]