From 732fce5fe2b620255b8b07c9a224e9a454272651 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Tue, 12 Nov 2019 16:33:48 +0800 Subject: [PATCH 1/2] Add shaderKey as the program member This change is to reduce the overhead of shaderKey comparison. Previously, we used a very long string 'userCode' as part of the shader key which is time consuming when search the shader key. --- tfjs-backend-webgpu/src/backend_webgpu.ts | 16 +++++++--- .../src/kernels/argminmax_webgpu.ts | 2 ++ .../src/kernels/binary_op_webgpu.ts | 2 ++ .../src/kernels/clip_webgpu.ts | 7 ++-- .../src/kernels/concat_webgpu.ts | 6 ++-- .../src/kernels/conv2d_mm_webgpu.ts | 17 +++++----- .../src/kernels/conv2d_naive_webgpu.ts | 10 +++--- .../src/kernels/depthwise_conv2d_webgpu.ts | 32 +++++++------------ .../src/kernels/fill_webgpu.ts | 2 ++ .../src/kernels/from_pixels_webgpu.ts | 7 ++-- .../src/kernels/matmul_packed_webgpu.ts | 10 +++--- .../src/kernels/matmul_webgpu.ts | 13 +++++--- .../src/kernels/maxpool_webgpu.ts | 2 ++ tfjs-backend-webgpu/src/kernels/pad_webgpu.ts | 3 ++ .../src/kernels/resize_bilinear_webgpu.ts | 2 ++ .../src/kernels/select_webgpu.ts | 6 +++- .../src/kernels/slice_webgpu.ts | 5 ++- .../src/kernels/transpose_webgpu.ts | 6 ++-- .../src/kernels/unary_op_webgpu.ts | 3 +- .../src/kernels/webgpu_program.ts | 5 ++- 20 files changed, 94 insertions(+), 62 deletions(-) diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index e3844ff278d..d9eadbf7290 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -688,7 +688,8 @@ export class WebGPUBackend extends KernelBackend { const dimensions = [ convInfo.filterHeight, convInfo.filterWidth, ...pad, - convInfo.strideHeight, convInfo.strideWidth + convInfo.strideHeight, convInfo.strideWidth, convInfo.dilationHeight, + convInfo.dilationWidth ]; return this.compileAndRun(program, [x, filter], output, dimensions); @@ -698,7 +699,13 @@ export class WebGPUBackend extends KernelBackend { x: Tensor4D, filter: Tensor4D, convInfo: backend_util.Conv2DInfo): Tensor4D { const program = new DepthwiseConv2DProgram(convInfo); - return this.compileAndRun(program, [x, filter]); + const dimensions = [ + convInfo.filterHeight, convInfo.filterWidth, convInfo.padInfo.top, + convInfo.padInfo.left, convInfo.strideHeight, convInfo.strideWidth, + convInfo.dilationHeight, convInfo.dilationWidth, convInfo.inHeight, + convInfo.inWidth + ]; + return this.compileAndRun(program, [x, filter], null, dimensions); } private argMinMaxReduce(x: Tensor, axis: number, reduceType: 'min'|'max'): @@ -942,7 +949,8 @@ export class WebGPUBackend extends KernelBackend { if (numChannels != null && numChannels !== 4) { pixelArray = new Uint8Array(pixels.width * pixels.height * numChannels); - for (let i = 0; i < imageData.length; i++) { + const dataLength = imageData.length; + for (let i = 0; i < dataLength; i++) { if (i % 4 < numChannels) { const pixelIndex = Math.floor(i / 4); pixelArray[pixelIndex * numChannels + i % 4] = imageData[i]; @@ -953,7 +961,7 @@ export class WebGPUBackend extends KernelBackend { const output = this.makeOutputArray(outShape, 'int32'); const info = this.tensorMap.get(output.dataId); - info.values = Int32Array.from(pixelArray); + info.values = new Int32Array(pixelArray); this.maybeReleaseBuffer(output.dataId); this.uploadToGPU(output.dataId); diff --git a/tfjs-backend-webgpu/src/kernels/argminmax_webgpu.ts b/tfjs-backend-webgpu/src/kernels/argminmax_webgpu.ts index d7ec0187076..91bf19a19fc 100644 --- a/tfjs-backend-webgpu/src/kernels/argminmax_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/argminmax_webgpu.ts @@ -24,6 +24,7 @@ import {WebGPUProgram} from './webgpu_program'; export class ArgMinMaxProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[], y: number[]}; dispatch: [number, number, number]; @@ -179,5 +180,6 @@ export class ArgMinMaxProgram implements WebGPUProgram { 'setOutput(flatOutputIndex, int(bestIndex));'} } `; + this.shaderKey = `ArgMinMax${op}${reduceInSharedMemory}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/kernels/binary_op_webgpu.ts index 8850704143c..1a744422064 100644 --- a/tfjs-backend-webgpu/src/kernels/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/binary_op_webgpu.ts @@ -42,6 +42,7 @@ export const PRELU = `return (a < 0.) ? b * a : a;`; export class BinaryOpProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[]}; dispatch: [number, number, number]; @@ -80,5 +81,6 @@ export class BinaryOpProgram implements WebGPUProgram { } } `; + this.shaderKey = `binary${op}${type}${size}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/clip_webgpu.ts b/tfjs-backend-webgpu/src/kernels/clip_webgpu.ts index 22c94572e9b..d7c8f64f279 100644 --- a/tfjs-backend-webgpu/src/kernels/clip_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/clip_webgpu.ts @@ -23,6 +23,7 @@ import {WebGPUProgram} from './webgpu_program'; export class ClipProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; variableNames = ['A']; dispatchLayout: {x: number[]}; @@ -35,8 +36,8 @@ export class ClipProgram implements WebGPUProgram { const size = util.sizeFromShape(this.outputShape); this.dispatchLayout = flatDispatchLayout(this.outputShape); this.dispatch = computeDispatch( - this.dispatchLayout, this.outputShape, - this.workGroupSize, [this.workPerThread, 1 ,1]); + this.dispatchLayout, this.outputShape, this.workGroupSize, + [this.workPerThread, 1, 1]); const type = getCoordsDataType(this.outputShape.length); this.userCode = ` @@ -58,5 +59,7 @@ export class ClipProgram implements WebGPUProgram { } } `; + + this.shaderKey = `clip${size}${type}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/concat_webgpu.ts b/tfjs-backend-webgpu/src/kernels/concat_webgpu.ts index 7dc1067ec1e..6e39ef5ed69 100644 --- a/tfjs-backend-webgpu/src/kernels/concat_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/concat_webgpu.ts @@ -22,6 +22,7 @@ import {WebGPUProgram} from './webgpu_program'; export class ConcatProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[]}; dispatch: [number, number, number]; @@ -36,8 +37,8 @@ export class ConcatProgram implements WebGPUProgram { const size = util.sizeFromShape(this.outputShape); this.dispatchLayout = flatDispatchLayout(this.outputShape); this.dispatch = computeDispatch( - this.dispatchLayout, this.outputShape, - this.workGroupSize, [this.workPerThread, 1 ,1]); + this.dispatchLayout, this.outputShape, this.workGroupSize, + [this.workPerThread, 1, 1]); const offsets: number[] = new Array(shapes.length - 1); offsets[0] = shapes[0][1]; @@ -76,5 +77,6 @@ export class ConcatProgram implements WebGPUProgram { } } `; + this.shaderKey = `concat${size}${offsets.join(',')}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts b/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts index d48fbf86524..673a337d532 100644 --- a/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts @@ -25,11 +25,12 @@ import {WebGPUProgram} from './webgpu_program'; export class Conv2DMMProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[], y: number[], z: number[]}; dispatch: [number, number, number]; variableNames = ['x', 'W']; - uniforms = 'ivec2 filterDims, pad, stride;'; + uniforms = 'ivec2 filterDims, pad, stride, dilation;'; workGroupSize: [number, number, number]; constructor(convInfo: backend_util.Conv2DInfo, workPerThread: number) { @@ -52,9 +53,6 @@ export class Conv2DMMProgram implements WebGPUProgram { matMulSource = makeMatMulPackedSource(elementsPerThread); } - const dilationHeight = convInfo.dilationHeight; - const dilationWidth = convInfo.dilationWidth; - const tileAOuter = this.workGroupSize[1] * elementsPerThread[1]; const tileBOuter = this.workGroupSize[0] * elementsPerThread[0]; const tileInner = tileBOuter; @@ -64,10 +62,12 @@ export class Conv2DMMProgram implements WebGPUProgram { const dimBOuter = this.outputShape[1] * this.outputShape[2]; const dimInner = convInfo.filterHeight * convInfo.filterWidth * convInfo.inChannels; - const sampleA = tilesFitEvenlyIntoShape(tileSizeA, [dimAOuter, dimInner]) ? + const fitA = tilesFitEvenlyIntoShape(tileSizeA, [dimAOuter, dimInner]); + const sampleA = fitA ? `W[getFlatIndex(coord, shape)]` : `coordsInBounds(coord, shape) ? W[getFlatIndex(coord, shape)] : 0`; - const sampleB = tilesFitEvenlyIntoShape(tileSizeB, [dimInner, dimBOuter]) ? + const fitB = tilesFitEvenlyIntoShape(tileSizeB, [dimInner, dimBOuter]); + const sampleB = fitB ? `x[getFlatIndex(coord, xShape)]` : `coordsInBounds(coord, xShape) ? x[getFlatIndex(coord, xShape)] : 0`; @@ -102,8 +102,8 @@ export class Conv2DMMProgram implements WebGPUProgram { ivec4 coord = ivec4( batch, - pad[0] + outRow * stride[0] + ${dilationHeight} * WRow, - pad[1] + outCol * stride[1] + ${dilationWidth} * WCol, + pad[0] + outRow * stride[0] + dilation[0] * WRow, + pad[1] + outCol * stride[1] + dilation[1] * WCol, r / (filterDims[0] * filterDims[1])); return ${sampleB}; } @@ -128,5 +128,6 @@ export class Conv2DMMProgram implements WebGPUProgram { mm_matMul(dimAOuter, dimInner, dimBOuter); } `; + this.shaderKey = `conv2dmm'${elementsPerThread.join('')}${fitA}${fitB}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts b/tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts index 725d293177d..0d8899ae73f 100644 --- a/tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts @@ -23,16 +23,15 @@ import {WebGPUProgram} from './webgpu_program'; export class Conv2DNaiveProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[], y: number[], z: number[]}; dispatch: [number, number, number]; variableNames = ['x', 'W']; - uniforms = 'ivec2 filterDims, pad, stride;'; + uniforms = 'ivec2 filterDims, pad, stride, dilation;'; workGroupSize: [number, number, number] = [4, 8, 4]; constructor(convInfo: backend_util.Conv2DInfo) { - const dilationHeight = convInfo.dilationHeight; - const dilationWidth = convInfo.dilationWidth; this.outputShape = convInfo.outShape; this.dispatchLayout = {x: [2], y: [1], z: [0, 3]}; this.dispatch = computeDispatch( @@ -73,8 +72,8 @@ export class Conv2DNaiveProgram implements WebGPUProgram { for (int col = 0; col < filterDims[1]; ++col) { for (int xChannel = 0; xChannel < xShape[3]; ++xChannel) { float v = readInp(batch, - pad[0] + coords[1] * stride[0] + ${dilationHeight} * row, - pad[1] + coords[2] * stride[1] + ${dilationWidth} * col, + pad[0] + coords[1] * stride[0] + dilation[0] * row, + pad[1] + coords[2] * stride[1] + dilation[1] * col, xChannel); float f = readFilt(row, col, xChannel, outChannel); acc += v * f; @@ -85,5 +84,6 @@ export class Conv2DNaiveProgram implements WebGPUProgram { writeResult(batch, coords[1], coords[2], outChannel, acc); } `; + this.shaderKey = 'conv2dnaive'; } } diff --git a/tfjs-backend-webgpu/src/kernels/depthwise_conv2d_webgpu.ts b/tfjs-backend-webgpu/src/kernels/depthwise_conv2d_webgpu.ts index 289da7e3446..3d91165eccb 100755 --- a/tfjs-backend-webgpu/src/kernels/depthwise_conv2d_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/depthwise_conv2d_webgpu.ts @@ -21,11 +21,12 @@ import {WebGPUProgram} from './webgpu_program'; export class DepthwiseConv2DProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[], y: number[], z: number[]}; dispatch: [number, number, number]; variableNames = ['x', 'W']; - uniforms = 'ivec2 filterDims, pad, stride;'; + uniforms = 'ivec2 filterDims, pad, stride, dilation, inDims;'; workGroupSize: [number, number, number] = [4, 8, 4]; constructor(convInfo: backend_util.Conv2DInfo) { @@ -33,16 +34,6 @@ export class DepthwiseConv2DProgram implements WebGPUProgram { this.dispatchLayout = {x: [2], y: [1], z: [0, 3]}; this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize); - const xNumRows = convInfo.inHeight; - const xNumCols = convInfo.inWidth; - const padTop = convInfo.padInfo.top; - const padLeft = convInfo.padInfo.left; - const strideHeight = convInfo.strideHeight; - const strideWidth = convInfo.strideWidth; - const dilationHeight = convInfo.dilationHeight; - const dilationWidth = convInfo.dilationWidth; - const filterHeight = convInfo.filterHeight; - const filterWidth = convInfo.filterWidth; const channelMul = convInfo.outChannels / convInfo.inChannels; util.assert( @@ -50,9 +41,6 @@ export class DepthwiseConv2DProgram implements WebGPUProgram { () => 'TODO: NCHW is unimplemented'); this.userCode = ` - const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); - const ivec2 pads = ivec2(${padTop}, ${padLeft}); - void writeResult(int batch, int row, int col, int chan, float value) { ivec4 coord = ivec4(batch, row, col, chan); if (coordsInBounds(coord, outShape)) { @@ -63,7 +51,7 @@ export class DepthwiseConv2DProgram implements WebGPUProgram { void main() { ivec4 coords = getOutputCoords(); int batch = coords[0]; - ivec2 xRCCorner = coords.yz * strides - pads; + ivec2 xRCCorner = coords.yz * stride - pad; int d2 = coords[3]; int d1 = d2 / ${channelMul}; int q = d2 - d1 * ${channelMul}; @@ -75,17 +63,17 @@ export class DepthwiseConv2DProgram implements WebGPUProgram { // ? = to be determined. : = across all values in that axis. float dotProd = 0.0; // TODO(xing.xu): Flatten the two for loops and vec4 the operations. - for (int wR = 0; wR < ${filterHeight}; wR++) { - int xR = xRCorner + wR * ${dilationHeight}; + for (int wR = 0; wR < filterDims[0]; wR++) { + int xR = xRCorner + wR * dilation[0]; - if (xR < 0 || xR >= ${xNumRows}) { + if (xR < 0 || xR >= inDims[0]) { continue; } - for (int wC = 0; wC < ${filterWidth}; wC++) { - int xC = xCCorner + wC * ${dilationWidth}; + for (int wC = 0; wC < filterDims[1]; wC++) { + int xC = xCCorner + wC * dilation[1]; - if (xC < 0 || xC >= ${xNumCols}) { + if (xC < 0 || xC >= inDims[1]) { continue; } @@ -97,5 +85,7 @@ export class DepthwiseConv2DProgram implements WebGPUProgram { writeResult(batch, coords[1], coords[2], d2, dotProd); } `; + + this.shaderKey = `depthwise${channelMul}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/fill_webgpu.ts b/tfjs-backend-webgpu/src/kernels/fill_webgpu.ts index 6bbf7b27692..1a8d2048245 100644 --- a/tfjs-backend-webgpu/src/kernels/fill_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/fill_webgpu.ts @@ -21,6 +21,7 @@ import {WebGPUProgram} from './webgpu_program'; export class FillProgram implements WebGPUProgram { variableNames: string[] = []; outputShape: number[] = []; + shaderKey: string; userCode: string; dispatchLayout: {x: number[]}; dispatch: [number, number, number]; @@ -46,5 +47,6 @@ export class FillProgram implements WebGPUProgram { } } `; + this.shaderKey = `fill${size}${value}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/from_pixels_webgpu.ts b/tfjs-backend-webgpu/src/kernels/from_pixels_webgpu.ts index 606b382c713..0f0c73aa73f 100644 --- a/tfjs-backend-webgpu/src/kernels/from_pixels_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/from_pixels_webgpu.ts @@ -21,13 +21,13 @@ import {WebGPUProgram} from './webgpu_program'; export class FromPixelsProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; variableNames = ['A']; dispatchLayout: {x: number[]}; dispatch: [number, number, number]; constructor(outputShape: number[]) { - const [height, width, ] = outputShape; this.outputShape = outputShape; this.dispatchLayout = flatDispatchLayout(this.outputShape); this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape); @@ -38,7 +38,7 @@ export class FromPixelsProgram implements WebGPUProgram { int texR = coords[0]; int texC = coords[1]; int depth = coords[2]; - vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0); + vec2 uv = (vec2(texC, texR) + halfCR) / vec2(outShape.yx); vec4 values = texelFetch(A, uv); float value; @@ -55,5 +55,6 @@ export class FromPixelsProgram implements WebGPUProgram { setOutput(floor(value * 255.0 + 0.5)); } `; + this.shaderKey = 'fromPixel'; } -} \ No newline at end of file +} diff --git a/tfjs-backend-webgpu/src/kernels/matmul_packed_webgpu.ts b/tfjs-backend-webgpu/src/kernels/matmul_packed_webgpu.ts index c8d46a4bcbc..07206043c1b 100644 --- a/tfjs-backend-webgpu/src/kernels/matmul_packed_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/matmul_packed_webgpu.ts @@ -117,6 +117,7 @@ export function makeMatMulPackedSource(workPerThread: number[]): string { export class MatMulPackedProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[], y: number[], z: number[]}; dispatch: [number, number, number]; @@ -135,12 +136,13 @@ export class MatMulPackedProgram implements WebGPUProgram { const tileInner = tileBOuter; const tileSizeA = [tileAOuter, tileInner]; const tileSizeB = [tileInner, tileBOuter]; - - const sampleA = tilesFitEvenlyIntoShape(tileSizeA, aShape.slice(1)) ? + const fitA = tilesFitEvenlyIntoShape(tileSizeA, aShape.slice(1)); + const sampleA = fitA ? `A[row * dimInner + col]` : `coordsInBounds(ivec2(row, col), ivec2(dimAOuter, dimInner)) ? A[row * dimInner + col] : 0`; - const sampleB = tilesFitEvenlyIntoShape(tileSizeB, bShape.slice(1)) ? + const fitB = tilesFitEvenlyIntoShape(tileSizeB, bShape.slice(1)); + const sampleB = fitB ? `B[row * dimBOuter + col]` : `coordsInBounds(ivec2(row, col), ivec2(dimInner, dimBOuter)) ? B[row * dimBOuter + col] : 0`; @@ -149,7 +151,6 @@ export class MatMulPackedProgram implements WebGPUProgram { this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize, [workPerThread, workPerThread, 1]); - this.userCode = ` int dimAOuter = aShape[1]; int dimInner = aShape[2]; @@ -173,5 +174,6 @@ export class MatMulPackedProgram implements WebGPUProgram { mm_matMul(dimAOuter, dimInner, dimBOuter); } `; + this.shaderKey = `matmulpacked${this.workPerThread}${fitA}${fitB}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/matmul_webgpu.ts b/tfjs-backend-webgpu/src/kernels/matmul_webgpu.ts index dcd2b8c6735..5ee436290e1 100644 --- a/tfjs-backend-webgpu/src/kernels/matmul_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/matmul_webgpu.ts @@ -70,6 +70,7 @@ export function makeMatMulSource(): string { export class MatMulProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[], y: number[], z: number[]}; dispatch: [number, number, number]; @@ -84,14 +85,15 @@ export class MatMulProgram implements WebGPUProgram { this.dispatchLayout = {x: [2], y: [1], z: [0]}; this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize); - - const sampleA = tilesFitEvenlyIntoShape( - this.workGroupSize.slice(0, 2), aShape.slice(1)) ? + const fitA = tilesFitEvenlyIntoShape( + this.workGroupSize.slice(0, 2), aShape.slice(1)); + const sampleA = fitA ? `A[row * dimInner + col]` : `coordsInBounds(ivec2(row, col), ivec2(dimAOuter, dimInner)) ? A[row * dimInner + col] : 0`; - const sampleB = tilesFitEvenlyIntoShape( - this.workGroupSize.slice(0, 2), bShape.slice(1)) ? + const fitB = tilesFitEvenlyIntoShape( + this.workGroupSize.slice(0, 2), bShape.slice(1)); + const sampleB = fitB ? `B[row * dimBOuter + col]` : `coordsInBounds(ivec2(row, col), ivec2(dimInner, dimBOuter)) ? B[row * dimBOuter + col] : 0`; @@ -119,5 +121,6 @@ export class MatMulProgram implements WebGPUProgram { mm_matMul(dimAOuter, dimInner, dimBOuter); } `; + this.shaderKey = `matmul${fitA}${fitB}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/maxpool_webgpu.ts b/tfjs-backend-webgpu/src/kernels/maxpool_webgpu.ts index 88bcfa32a46..0be069993d1 100644 --- a/tfjs-backend-webgpu/src/kernels/maxpool_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/maxpool_webgpu.ts @@ -23,6 +23,7 @@ import {WebGPUProgram} from './webgpu_program'; export class MaxPoolProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[], y: number[], z: number[]}; dispatch: [number, number, number]; @@ -76,5 +77,6 @@ export class MaxPoolProgram implements WebGPUProgram { } } `; + this.shaderKey = 'maxpool'; } } diff --git a/tfjs-backend-webgpu/src/kernels/pad_webgpu.ts b/tfjs-backend-webgpu/src/kernels/pad_webgpu.ts index 6cfa7142c15..82d1a2098b1 100644 --- a/tfjs-backend-webgpu/src/kernels/pad_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/pad_webgpu.ts @@ -24,6 +24,7 @@ import {WebGPUProgram} from './webgpu_program'; export class PadProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[]}; dispatch: [number, number, number]; @@ -81,5 +82,7 @@ export class PadProgram implements WebGPUProgram { } } `; + this.shaderKey = + `pad${startValue}${endValue}${rank}${size}${type}${constantValue}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/resize_bilinear_webgpu.ts b/tfjs-backend-webgpu/src/kernels/resize_bilinear_webgpu.ts index 51b8bed4db3..8180fb1a3fb 100644 --- a/tfjs-backend-webgpu/src/kernels/resize_bilinear_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/resize_bilinear_webgpu.ts @@ -21,6 +21,7 @@ import {WebGPUProgram} from './webgpu_program'; export class ResizeBilinearProgram implements WebGPUProgram { outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[], y: number[], z: number[]}; dispatch: [number, number, number]; @@ -83,5 +84,6 @@ export class ResizeBilinearProgram implements WebGPUProgram { } } `; + this.shaderKey = `resizeblilinear${adjustHeight}${adjustWidth}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/select_webgpu.ts b/tfjs-backend-webgpu/src/kernels/select_webgpu.ts index 74ce707b1c5..1b85aee9050 100644 --- a/tfjs-backend-webgpu/src/kernels/select_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/select_webgpu.ts @@ -15,13 +15,16 @@ * ============================================================================= */ import {util} from '@tensorflow/tfjs-core'; + +import {getCoordsDataType} from '../shader_preprocessor'; import {computeDispatch, flatDispatchLayout} from '../webgpu_util'; + import {WebGPUProgram} from './webgpu_program'; -import {getCoordsDataType} from '../shader_preprocessor'; export class SelectProgram implements WebGPUProgram { variableNames = ['c', 'a', 'b']; outputShape: number[]; + shaderKey: string; userCode: string; dispatchLayout: {x: number[]}; dispatch: [number, number, number]; @@ -80,5 +83,6 @@ export class SelectProgram implements WebGPUProgram { } } `; + this.shaderKey = `select${size}${rank}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/slice_webgpu.ts b/tfjs-backend-webgpu/src/kernels/slice_webgpu.ts index 35496e4df1e..a6a9a37e553 100644 --- a/tfjs-backend-webgpu/src/kernels/slice_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/slice_webgpu.ts @@ -15,13 +15,15 @@ * ============================================================================= */ +import {getCoordsDataType} from '../shader_preprocessor'; import {computeDispatch, flatDispatchLayout} from '../webgpu_util'; + import {WebGPUProgram} from './webgpu_program'; -import {getCoordsDataType} from '../shader_preprocessor'; export class SliceProgram implements WebGPUProgram { variableNames = ['source']; outputShape: number[]; + shaderKey: string; userCode: string; rank: number; dispatchLayout: {x: number[]}; @@ -53,6 +55,7 @@ export class SliceProgram implements WebGPUProgram { setOutput(index, getSource(${sourceCoords})); } `; + this.shaderKey = `slice${this.rank}${start.join(',')}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/transpose_webgpu.ts b/tfjs-backend-webgpu/src/kernels/transpose_webgpu.ts index 992a9cf5e3f..70bfb9667cc 100644 --- a/tfjs-backend-webgpu/src/kernels/transpose_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/transpose_webgpu.ts @@ -23,6 +23,7 @@ import {WebGPUProgram} from './webgpu_program'; export class TransposeProgram implements WebGPUProgram { variableNames = ['A']; + shaderKey: string; outputShape: number[]; userCode: string; dispatchLayout: {x: number[]}; @@ -42,8 +43,8 @@ export class TransposeProgram implements WebGPUProgram { const size = util.sizeFromShape(this.outputShape); this.dispatchLayout = flatDispatchLayout(this.outputShape); this.dispatch = computeDispatch( - this.dispatchLayout, this.outputShape, - this.workGroupSize, [this.workPerThread, 1 ,1]); + this.dispatchLayout, this.outputShape, this.workGroupSize, + [this.workPerThread, 1, 1]); const switched = getSwitchedCoords(newDim); @@ -61,6 +62,7 @@ export class TransposeProgram implements WebGPUProgram { } } `; + this.shaderKey = `tranpose${size}${dtype}${newDim.join(',')}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts index 6371658ff33..32197797cf7 100644 --- a/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/kernels/unary_op_webgpu.ts @@ -29,6 +29,7 @@ export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`; export class UnaryOpProgram implements WebGPUProgram { outputShape: number[]; userCode: string; + shaderKey: string; dispatchLayout: {x: number[]}; dispatch: [number, number, number]; variableNames = ['A']; @@ -44,7 +45,6 @@ export class UnaryOpProgram implements WebGPUProgram { this.dispatchLayout, this.outputShape, this.workGroupSize, [this.workPerThread, 1, 1]); const type = getCoordsDataType(this.outputShape.length); - this.userCode = ` float unaryOperation(float a) { ${op} @@ -65,5 +65,6 @@ export class UnaryOpProgram implements WebGPUProgram { } } `; + this.shaderKey = `unary${op}${type}${size}`; } } diff --git a/tfjs-backend-webgpu/src/kernels/webgpu_program.ts b/tfjs-backend-webgpu/src/kernels/webgpu_program.ts index b11cd4bad03..6324d60c284 100644 --- a/tfjs-backend-webgpu/src/kernels/webgpu_program.ts +++ b/tfjs-backend-webgpu/src/kernels/webgpu_program.ts @@ -21,6 +21,7 @@ import * as shaderc from '@webgpu/shaderc'; import * as shader_preprocessor from '../shader_preprocessor'; export interface WebGPUProgram { + shaderKey: string; userCode: string; outputShape: number[]; // dispatchLayout enumerates how tensor dimensions are distributed among @@ -115,12 +116,10 @@ export const compileProgram = return {bindGroupLayout, pipeline}; }; -// TODO: Consider allowing each program to specify its own shader key. E.g. some -// kernels account for different work group sizes, but some don't. // TODO: Consider uploading shape info as vec4s regardless of rank to reduce // recompilation. export function makeShaderKey(program: WebGPUProgram, ranks: number[]): string { const key = (program.workGroupSize ? program.workGroupSize.join(',') : '') + - ranks.join(',') + program.userCode; + ranks.join(',') + program.shaderKey; return key; } From 54f06b9aa9c846e822938a38a221ebf49a6e4a28 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 3 Dec 2019 17:30:12 +0800 Subject: [PATCH 2/2] Make shaderKey optional --- tfjs-backend-webgpu/src/kernels/webgpu_program.ts | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-webgpu/src/kernels/webgpu_program.ts b/tfjs-backend-webgpu/src/kernels/webgpu_program.ts index 6324d60c284..e1d97a1dfd0 100644 --- a/tfjs-backend-webgpu/src/kernels/webgpu_program.ts +++ b/tfjs-backend-webgpu/src/kernels/webgpu_program.ts @@ -21,7 +21,9 @@ import * as shaderc from '@webgpu/shaderc'; import * as shader_preprocessor from '../shader_preprocessor'; export interface WebGPUProgram { - shaderKey: string; + // The unique key to distinguish different shader source code. If shaderKey is + // not specified, use userCode to replace. + shaderKey?: string; userCode: string; outputShape: number[]; // dispatchLayout enumerates how tensor dimensions are distributed among @@ -120,6 +122,7 @@ export const compileProgram = // recompilation. export function makeShaderKey(program: WebGPUProgram, ranks: number[]): string { const key = (program.workGroupSize ? program.workGroupSize.join(',') : '') + - ranks.join(',') + program.shaderKey; + ranks.join(',') + + (program.shaderKey ? program.shaderKey : program.userCode); return key; }