Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,8 @@ export class WebGPUBackend extends KernelBackend {

const dimensions = [
convInfo.filterHeight, convInfo.filterWidth, ...pad,
convInfo.strideHeight, convInfo.strideWidth
convInfo.strideHeight, convInfo.strideWidth, convInfo.dilationHeight,
convInfo.dilationWidth
];

return this.compileAndRun(program, [x, filter], output, dimensions);
Expand All @@ -698,7 +699,13 @@ export class WebGPUBackend extends KernelBackend {
x: Tensor4D, filter: Tensor4D,
convInfo: backend_util.Conv2DInfo): Tensor4D {
const program = new DepthwiseConv2DProgram(convInfo);
return this.compileAndRun(program, [x, filter]);
const dimensions = [
convInfo.filterHeight, convInfo.filterWidth, convInfo.padInfo.top,
convInfo.padInfo.left, convInfo.strideHeight, convInfo.strideWidth,
convInfo.dilationHeight, convInfo.dilationWidth, convInfo.inHeight,
convInfo.inWidth
];
return this.compileAndRun(program, [x, filter], null, dimensions);
}

private argMinMaxReduce(x: Tensor, axis: number, reduceType: 'min'|'max'):
Expand Down Expand Up @@ -942,7 +949,8 @@ export class WebGPUBackend extends KernelBackend {
if (numChannels != null && numChannels !== 4) {
pixelArray = new Uint8Array(pixels.width * pixels.height * numChannels);

for (let i = 0; i < imageData.length; i++) {
const dataLength = imageData.length;
for (let i = 0; i < dataLength; i++) {
if (i % 4 < numChannels) {
const pixelIndex = Math.floor(i / 4);
pixelArray[pixelIndex * numChannels + i % 4] = imageData[i];
Expand All @@ -953,7 +961,7 @@ export class WebGPUBackend extends KernelBackend {
const output = this.makeOutputArray(outShape, 'int32');

const info = this.tensorMap.get(output.dataId);
info.values = Int32Array.from(pixelArray);
info.values = new Int32Array(pixelArray);
this.maybeReleaseBuffer(output.dataId);

this.uploadToGPU(output.dataId);
Expand Down
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/kernels/argminmax_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {WebGPUProgram} from './webgpu_program';

export class ArgMinMaxProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
userCode: string;
dispatchLayout: {x: number[], y: number[]};
dispatch: [number, number, number];
Expand Down Expand Up @@ -179,5 +180,6 @@ export class ArgMinMaxProgram implements WebGPUProgram {
'setOutput(flatOutputIndex, int(bestIndex));'}
}
`;
this.shaderKey = `ArgMinMax${op}${reduceInSharedMemory}`;
}
}
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/kernels/binary_op_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export const PRELU = `return (a < 0.) ? b * a : a;`;

export class BinaryOpProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
userCode: string;
dispatchLayout: {x: number[]};
dispatch: [number, number, number];
Expand Down Expand Up @@ -80,5 +81,6 @@ export class BinaryOpProgram implements WebGPUProgram {
}
}
`;
this.shaderKey = `binary${op}${type}${size}`;
}
}
7 changes: 5 additions & 2 deletions tfjs-backend-webgpu/src/kernels/clip_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {WebGPUProgram} from './webgpu_program';

export class ClipProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
userCode: string;
variableNames = ['A'];
dispatchLayout: {x: number[]};
Expand All @@ -35,8 +36,8 @@ export class ClipProgram implements WebGPUProgram {
const size = util.sizeFromShape(this.outputShape);
this.dispatchLayout = flatDispatchLayout(this.outputShape);
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape,
this.workGroupSize, [this.workPerThread, 1 ,1]);
this.dispatchLayout, this.outputShape, this.workGroupSize,
[this.workPerThread, 1, 1]);
const type = getCoordsDataType(this.outputShape.length);

this.userCode = `
Expand All @@ -58,5 +59,7 @@ export class ClipProgram implements WebGPUProgram {
}
}
`;

this.shaderKey = `clip${size}${type}`;
}
}
6 changes: 4 additions & 2 deletions tfjs-backend-webgpu/src/kernels/concat_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {WebGPUProgram} from './webgpu_program';

export class ConcatProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
userCode: string;
dispatchLayout: {x: number[]};
dispatch: [number, number, number];
Expand All @@ -36,8 +37,8 @@ export class ConcatProgram implements WebGPUProgram {
const size = util.sizeFromShape(this.outputShape);
this.dispatchLayout = flatDispatchLayout(this.outputShape);
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape,
this.workGroupSize, [this.workPerThread, 1 ,1]);
this.dispatchLayout, this.outputShape, this.workGroupSize,
[this.workPerThread, 1, 1]);

const offsets: number[] = new Array(shapes.length - 1);
offsets[0] = shapes[0][1];
Expand Down Expand Up @@ -76,5 +77,6 @@ export class ConcatProgram implements WebGPUProgram {
}
}
`;
this.shaderKey = `concat${size}${offsets.join(',')}`;
}
}
17 changes: 9 additions & 8 deletions tfjs-backend-webgpu/src/kernels/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ import {WebGPUProgram} from './webgpu_program';

export class Conv2DMMProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
userCode: string;
dispatchLayout: {x: number[], y: number[], z: number[]};
dispatch: [number, number, number];
variableNames = ['x', 'W'];
uniforms = 'ivec2 filterDims, pad, stride;';
uniforms = 'ivec2 filterDims, pad, stride, dilation;';
workGroupSize: [number, number, number];

constructor(convInfo: backend_util.Conv2DInfo, workPerThread: number) {
Expand All @@ -52,9 +53,6 @@ export class Conv2DMMProgram implements WebGPUProgram {
matMulSource = makeMatMulPackedSource(elementsPerThread);
}

const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;

const tileAOuter = this.workGroupSize[1] * elementsPerThread[1];
const tileBOuter = this.workGroupSize[0] * elementsPerThread[0];
const tileInner = tileBOuter;
Expand All @@ -64,10 +62,12 @@ export class Conv2DMMProgram implements WebGPUProgram {
const dimBOuter = this.outputShape[1] * this.outputShape[2];
const dimInner =
convInfo.filterHeight * convInfo.filterWidth * convInfo.inChannels;
const sampleA = tilesFitEvenlyIntoShape(tileSizeA, [dimAOuter, dimInner]) ?
const fitA = tilesFitEvenlyIntoShape(tileSizeA, [dimAOuter, dimInner]);
const sampleA = fitA ?
`W[getFlatIndex(coord, shape)]` :
`coordsInBounds(coord, shape) ? W[getFlatIndex(coord, shape)] : 0`;
const sampleB = tilesFitEvenlyIntoShape(tileSizeB, [dimInner, dimBOuter]) ?
const fitB = tilesFitEvenlyIntoShape(tileSizeB, [dimInner, dimBOuter]);
const sampleB = fitB ?
`x[getFlatIndex(coord, xShape)]` :
`coordsInBounds(coord, xShape) ? x[getFlatIndex(coord, xShape)] : 0`;

Expand Down Expand Up @@ -102,8 +102,8 @@ export class Conv2DMMProgram implements WebGPUProgram {

ivec4 coord = ivec4(
batch,
pad[0] + outRow * stride[0] + ${dilationHeight} * WRow,
pad[1] + outCol * stride[1] + ${dilationWidth} * WCol,
pad[0] + outRow * stride[0] + dilation[0] * WRow,
pad[1] + outCol * stride[1] + dilation[1] * WCol,
r / (filterDims[0] * filterDims[1]));
return ${sampleB};
}
Expand All @@ -128,5 +128,6 @@ export class Conv2DMMProgram implements WebGPUProgram {
mm_matMul(dimAOuter, dimInner, dimBOuter);
}
`;
this.shaderKey = `conv2dmm'${elementsPerThread.join('')}${fitA}${fitB}`;
}
}
10 changes: 5 additions & 5 deletions tfjs-backend-webgpu/src/kernels/conv2d_naive_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,15 @@ import {WebGPUProgram} from './webgpu_program';

export class Conv2DNaiveProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
userCode: string;
dispatchLayout: {x: number[], y: number[], z: number[]};
dispatch: [number, number, number];
variableNames = ['x', 'W'];
uniforms = 'ivec2 filterDims, pad, stride;';
uniforms = 'ivec2 filterDims, pad, stride, dilation;';
workGroupSize: [number, number, number] = [4, 8, 4];

constructor(convInfo: backend_util.Conv2DInfo) {
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
this.outputShape = convInfo.outShape;
this.dispatchLayout = {x: [2], y: [1], z: [0, 3]};
this.dispatch = computeDispatch(
Expand Down Expand Up @@ -73,8 +72,8 @@ export class Conv2DNaiveProgram implements WebGPUProgram {
for (int col = 0; col < filterDims[1]; ++col) {
for (int xChannel = 0; xChannel < xShape[3]; ++xChannel) {
float v = readInp(batch,
pad[0] + coords[1] * stride[0] + ${dilationHeight} * row,
pad[1] + coords[2] * stride[1] + ${dilationWidth} * col,
pad[0] + coords[1] * stride[0] + dilation[0] * row,
pad[1] + coords[2] * stride[1] + dilation[1] * col,
xChannel);
float f = readFilt(row, col, xChannel, outChannel);
acc += v * f;
Expand All @@ -85,5 +84,6 @@ export class Conv2DNaiveProgram implements WebGPUProgram {
writeResult(batch, coords[1], coords[2], outChannel, acc);
}
`;
this.shaderKey = 'conv2dnaive';
}
}
32 changes: 11 additions & 21 deletions tfjs-backend-webgpu/src/kernels/depthwise_conv2d_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,38 +21,26 @@ import {WebGPUProgram} from './webgpu_program';

export class DepthwiseConv2DProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
userCode: string;
dispatchLayout: {x: number[], y: number[], z: number[]};
dispatch: [number, number, number];
variableNames = ['x', 'W'];
uniforms = 'ivec2 filterDims, pad, stride;';
uniforms = 'ivec2 filterDims, pad, stride, dilation, inDims;';
workGroupSize: [number, number, number] = [4, 8, 4];

constructor(convInfo: backend_util.Conv2DInfo) {
this.outputShape = convInfo.outShape;
this.dispatchLayout = {x: [2], y: [1], z: [0, 3]};
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workGroupSize);
const xNumRows = convInfo.inHeight;
const xNumCols = convInfo.inWidth;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const channelMul = convInfo.outChannels / convInfo.inChannels;

util.assert(
convInfo.dataFormat === 'channelsLast',
() => 'TODO: NCHW is unimplemented');

this.userCode = `
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});

void writeResult(int batch, int row, int col, int chan, float value) {
ivec4 coord = ivec4(batch, row, col, chan);
if (coordsInBounds(coord, outShape)) {
Expand All @@ -63,7 +51,7 @@ export class DepthwiseConv2DProgram implements WebGPUProgram {
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
ivec2 xRCCorner = coords.yz * strides - pads;
ivec2 xRCCorner = coords.yz * stride - pad;
int d2 = coords[3];
int d1 = d2 / ${channelMul};
int q = d2 - d1 * ${channelMul};
Expand All @@ -75,17 +63,17 @@ export class DepthwiseConv2DProgram implements WebGPUProgram {
// ? = to be determined. : = across all values in that axis.
float dotProd = 0.0;
// TODO(xing.xu): Flatten the two for loops and vec4 the operations.
for (int wR = 0; wR < ${filterHeight}; wR++) {
int xR = xRCorner + wR * ${dilationHeight};
for (int wR = 0; wR < filterDims[0]; wR++) {
int xR = xRCorner + wR * dilation[0];

if (xR < 0 || xR >= ${xNumRows}) {
if (xR < 0 || xR >= inDims[0]) {
continue;
}

for (int wC = 0; wC < ${filterWidth}; wC++) {
int xC = xCCorner + wC * ${dilationWidth};
for (int wC = 0; wC < filterDims[1]; wC++) {
int xC = xCCorner + wC * dilation[1];

if (xC < 0 || xC >= ${xNumCols}) {
if (xC < 0 || xC >= inDims[1]) {
continue;
}

Expand All @@ -97,5 +85,7 @@ export class DepthwiseConv2DProgram implements WebGPUProgram {
writeResult(batch, coords[1], coords[2], d2, dotProd);
}
`;

this.shaderKey = `depthwise${channelMul}`;
}
}
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/kernels/fill_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {WebGPUProgram} from './webgpu_program';
export class FillProgram implements WebGPUProgram {
variableNames: string[] = [];
outputShape: number[] = [];
shaderKey: string;
userCode: string;
dispatchLayout: {x: number[]};
dispatch: [number, number, number];
Expand All @@ -46,5 +47,6 @@ export class FillProgram implements WebGPUProgram {
}
}
`;
this.shaderKey = `fill${size}${value}`;
}
}
7 changes: 4 additions & 3 deletions tfjs-backend-webgpu/src/kernels/from_pixels_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ import {WebGPUProgram} from './webgpu_program';

export class FromPixelsProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
userCode: string;
variableNames = ['A'];
dispatchLayout: {x: number[]};
dispatch: [number, number, number];

constructor(outputShape: number[]) {
const [height, width, ] = outputShape;
this.outputShape = outputShape;
this.dispatchLayout = flatDispatchLayout(this.outputShape);
this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);
Expand All @@ -38,7 +38,7 @@ export class FromPixelsProgram implements WebGPUProgram {
int texR = coords[0];
int texC = coords[1];
int depth = coords[2];
vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);
vec2 uv = (vec2(texC, texR) + halfCR) / vec2(outShape.yx);

vec4 values = texelFetch(A, uv);
float value;
Expand All @@ -55,5 +55,6 @@ export class FromPixelsProgram implements WebGPUProgram {
setOutput(floor(value * 255.0 + 0.5));
}
`;
this.shaderKey = 'fromPixel';
}
}
}
10 changes: 6 additions & 4 deletions tfjs-backend-webgpu/src/kernels/matmul_packed_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ export function makeMatMulPackedSource(workPerThread: number[]): string {

export class MatMulPackedProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
userCode: string;
dispatchLayout: {x: number[], y: number[], z: number[]};
dispatch: [number, number, number];
Expand All @@ -135,12 +136,13 @@ export class MatMulPackedProgram implements WebGPUProgram {
const tileInner = tileBOuter;
const tileSizeA = [tileAOuter, tileInner];
const tileSizeB = [tileInner, tileBOuter];

const sampleA = tilesFitEvenlyIntoShape(tileSizeA, aShape.slice(1)) ?
const fitA = tilesFitEvenlyIntoShape(tileSizeA, aShape.slice(1));
const sampleA = fitA ?
`A[row * dimInner + col]` :
`coordsInBounds(ivec2(row, col), ivec2(dimAOuter, dimInner)) ?
A[row * dimInner + col] : 0`;
const sampleB = tilesFitEvenlyIntoShape(tileSizeB, bShape.slice(1)) ?
const fitB = tilesFitEvenlyIntoShape(tileSizeB, bShape.slice(1));
const sampleB = fitB ?
`B[row * dimBOuter + col]` :
`coordsInBounds(ivec2(row, col), ivec2(dimInner, dimBOuter)) ?
B[row * dimBOuter + col] : 0`;
Expand All @@ -149,7 +151,6 @@ export class MatMulPackedProgram implements WebGPUProgram {
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workGroupSize,
[workPerThread, workPerThread, 1]);

this.userCode = `
int dimAOuter = aShape[1];
int dimInner = aShape[2];
Expand All @@ -173,5 +174,6 @@ export class MatMulPackedProgram implements WebGPUProgram {
mm_matMul(dimAOuter, dimInner, dimBOuter);
}
`;
this.shaderKey = `matmulpacked${this.workPerThread}${fitA}${fitB}`;
}
}
Loading