Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

webgpu: conv2d_mm: Fix dispatch dimensions #1745

Merged
merged 2 commits into from May 7, 2019

Conversation

kainino0x
Copy link
Contributor

@kainino0x kainino0x commented May 7, 2019

And some other minor tweaks.

To see the logs from the Cloud Build CI, please join either
our discussion
or announcement mailing list.


This change is Reviewable

@kainino0x kainino0x force-pushed the fix-conv-mm branch 2 times, most recently from 091f2af to 11d984f Compare May 7, 2019 15:49
Copy link
Contributor Author

@kainino0x kainino0x left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 approvals obtained


src/backends/webgpu/src/kernels/conv2d_mm_webgpu.ts, line 40 at r2 (raw file):

  constructor(convInfo: Conv2DInfo, workPerThread: number) {
    this.outputShape = convInfo.outShape;

I think you were talking about moving the computeDispatch out of the individual ops.
Does the exported outputShape need be matMulOutShape below?

Will that break the actual tensor output shape? (Do we need both an outputShape and a shapeForDispatch?) I'd be in favor of just keeping the computeDispatch in the op.

@annxingyuan annxingyuan self-requested a review May 7, 2019 15:53
Copy link
Collaborator

@annxingyuan annxingyuan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

And some other minor tweaks.
@annxingyuan annxingyuan merged commit 678d88b into tensorflow:master May 7, 2019
@kainino0x kainino0x deleted the fix-conv-mm branch May 22, 2019 16:37
axinging added a commit to qjia7/tfjs-webgl2compute that referenced this pull request Jul 9, 2019
This CL gurantees below test case run correctly. But leads to significant performance regression on some case.
We will fix this later.

Merge from this:
tensorflow/tfjs-core#1745

Test case is from:
https://github.com/tensorflow/tfjs-core/pull/1745/files#diff-15289bfca30229483003cbaea418596aR191
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
2 participants