Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
webgpu: conv2d_mm: Fix dispatch dimensions
Browse files Browse the repository at this point in the history
And some other minor tweaks.
  • Loading branch information
kainino0x committed May 7, 2019
1 parent 0940cf0 commit 091f2af
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 13 deletions.
91 changes: 90 additions & 1 deletion src/backends/webgpu/src/conv2d_test.ts
Expand Up @@ -112,7 +112,8 @@ describe('WebGPU backend - convolution tests', () => {
expectArraysClose(resultData, new Float32Array([20]));
});

it('x=[2,2,2,1] f=[1,1,1,1] s=1 d=1 p=0', async () => {
// TODO: implement batching for conv2d
xit('x=[2,2,2,1] f=[1,1,1,1] s=1 d=1 p=0', async () => {
const inputDepth = 1;
const inShape: [number, number, number, number] = [2, 2, 2, inputDepth];
const outputDepth = 1;
Expand Down Expand Up @@ -172,4 +173,92 @@ describe('WebGPU backend - convolution tests', () => {
expectArraysClose(
resultData, new Float32Array([57, 71, 85, 36, 30, 39, 48, 52]));
});

// TODO: implement batching for conv2d
xit('x=[2,2,2,2] f=[2,2,2,2] s=1 d=1 p=same', async () => {
const x = tf.tensor4d(
[1, 2, 4, 8, 2, 4, 5, 8, 9, 4, 1, 2, 6, 4, 5, 5], [2, 2, 2, 2]);
const w = tf.tensor4d(
[3, 5, 7, 9, 1, 3, 8, 3, 2, 3, 4, 5, 8, 9, 1, 2], [2, 2, 2, 2]);
const result = tf.conv2d(x, w, 1, 'same');

const resultData = await result.data();
expect(result.shape).toEqual([2, 2, 2, 2]);
expectArraysClose(
resultData, new Float32Array([
153, 146, 110, 147, 103, 85, 71, 97, 145, 183, 47, 63, 91, 96, 50, 70
]));
});

it('x=[1,32,32,4] f=[2,2,4,1] s=1 d=1 p=same', async () => {
let i = 1;
const xData = Array.from({length: 1 * 32 * 32 * 4}, () => (i++) % 3);
const wData = Array.from({length: 2 * 2 * 4 * 1}, () => (i++) % 5);
const exp = [
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 13, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 13, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 19, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 13, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 13, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 19, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 13, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 13, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 19,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 13, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 13, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 19, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 13, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 13, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 19, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 13, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 13, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 19,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 13, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 13, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 19, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 13, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 13, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 19, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 13, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 13, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 19,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 13, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 13, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 19, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30, 32, 34, 30,
32, 13, 20, 21, 16, 20, 21, 16, 20, 21, 16, 20, 21, 16, 20, 21, 16, 20,
21, 16, 20, 21, 16, 20, 21, 16, 20, 21, 16, 20, 21, 16, 20, 8
];

const x = tf.tensor4d(xData, [1, 32, 32, 4]);
const w = tf.tensor4d(wData, [2, 2, 4, 1]);

const result = tf.conv2d(x, w, 1, 'same');
const resultData = await result.data();
expectArraysClose(resultData, new Float32Array(exp));
});
});
34 changes: 22 additions & 12 deletions src/backends/webgpu/src/kernels/conv2d_mm_webgpu.ts
Expand Up @@ -38,8 +38,9 @@ export class Conv2DMMProgram implements WebGPUProgram {

constructor(convInfo: Conv2DInfo, workPerThread: number) {
this.outputShape = convInfo.outShape;
const dispatchLayout = {x: [1], y: [2], z: [0, 3]};

tf.util.assert(
convInfo.batchSize === 1, () => 'TODO: batching is unimplemented');
tf.util.assert(
convInfo.dataFormat === 'channelsLast',
() => 'TODO: NCHW is unimplemented');
Expand All @@ -56,8 +57,15 @@ export class Conv2DMMProgram implements WebGPUProgram {
elementsPerThread = [workPerThread, workPerThread, 1];
matMulSource = makeMatMulPackedSource(workPerThread);
}

const dispatchLayout = {x: [1], y: [2], z: [0]};
const matMulOutShape = [
convInfo.outShape[0],
convInfo.outShape[1] * convInfo.outShape[2],
convInfo.outShape[3]
];
this.dispatch = computeDispatch(
dispatchLayout, this.outputShape, this.workGroupSize,
dispatchLayout, matMulOutShape, this.workGroupSize,
elementsPerThread);

this.userCode = `
Expand All @@ -68,33 +76,35 @@ export class Conv2DMMProgram implements WebGPUProgram {
all(lessThan(coord, shape));
}
${generateGetOutputCoords(dispatchLayout, this.outputShape.length)}
${generateGetOutputCoords(dispatchLayout, matMulOutShape.length)}
int batch;
float mm_readA(uint row, uint col) {
int r = int(row), c = int(col);
ivec4 coord = ivec4(
(col / WShape[1]) % WShape[0],
col % WShape[1],
col / (WShape[1] * WShape[0]),
row);
(c / WShape[1]) % WShape[0],
c % WShape[1],
c / (WShape[0] * WShape[1]),
r);
ivec4 shape = ivec4(WShape, xShape[3], outShape[3]);
return coordIsValid(coord, shape) ? W[getFlatIndex(coord, shape)] : 0;
}
float mm_readB(uint row, uint col) {
int outRow = int(col) / outShape[2];
int outCol = int(col) % outShape[2];
int r = int(row), c = int(col);
int outRow = c / outShape[2];
int outCol = c % outShape[2];
int WRow = (int(row) / WShape[1]) % WShape[0];
int WCol = int(row) % WShape[1];
int WRow = (r / WShape[1]) % WShape[0];
int WCol = r % WShape[1];
ivec4 coord = ivec4(
batch,
pad[0] + outRow * stride[0] + WRow,
pad[1] + outCol * stride[1] + WCol,
row / (WShape[1] * WShape[0]));
r / (WShape[0] * WShape[1]));
return coordIsValid(coord, xShape) ?
x[getFlatIndex(coord, xShape)] : 0;
}
Expand Down

0 comments on commit 091f2af

Please sign in to comment.