
Conversation

@eellison
Contributor

@eellison eellison commented Aug 11, 2023

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Aug 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107004

Note: Links to docs will display an error until the docs builds have been completed.

✅ 1 Unrelated Failure

As of commit 80262e7 with merge base ed07821:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eellison added a commit that referenced this pull request Aug 11, 2023
ghstack-source-id: 15e315d
Pull Request resolved: #107004
@eellison eellison marked this pull request as draft August 11, 2023 01:53
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
@eellison eellison changed the title from [test] to Enable Lowering Channels last Conv1x1 when max autotune is set on Aug 15, 2023
@eellison eellison marked this pull request as ready for review August 15, 2023 18:58
@shunting314
Contributor

I have a couple of questions.

  1. Could you explain more about why the gain depends on enabling max-autotune for gemm? Does unconditionally calling aten.mm for 1x1 conv not result in the same gain?
  2. Is the speedup number '2.1x -> 2.5x' mentioned in the summary for training or inference? I assume it's inference, but would like to double check.

This can lead to a large speedup when max autotune is set, e.g. resnet 2.1x -> 2.5x, particularly in combination with freezing. 


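For context, a rough sketch of the setting that number refers to: compiling an inference model with max-autotune and freezing enabled. This is a hypothetical usage example, not code from this PR; the torch._inductor.config.freezing flag and the torch.compile mode string are assumptions about the existing API rather than something introduced here.

import torch

# Assumed flags for the "max autotune + freezing" inference setup this PR targets.
torch._inductor.config.freezing = True

model = torch.nn.Sequential(
    torch.nn.Conv2d(64, 256, kernel_size=1, bias=False),  # 1x1 conv, the case lowered to mm here
    torch.nn.BatchNorm2d(256),                             # folded away by freezing, leaving just the relu
    torch.nn.ReLU(),
).eval().to(memory_format=torch.channels_last)

compiled = torch.compile(model, mode="max-autotune")

with torch.no_grad():
    x = torch.randn(64, 64, 56, 56).to(memory_format=torch.channels_last)
    out = compiled(x)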
@eellison eellison added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 16, 2023
@eellison
Contributor Author

eellison commented Aug 16, 2023

enabling max-autotune for gemm? So unconditionally calling aten.mm for 1x1 conv does not result in the same gain?

The mm shapes you get when you convert a 1x1 conv to mm are pretty unusual, so I think the cublas heuristics are not as well tuned for them. For example, one addmm in ResNet has shape (200704, 64) @ (64, 256). Additionally, with freezing the batchnorm-relu that follows becomes just a relu, and with max_autotune we get to fuse that relu afterward, which is a very cheap activation.

is the speedup number '2.1x -> 2.5x' mentioned in the summary for training or inference? I assume it's inference but would like to double check.

This is for inference. I'm getting training data as well, but the pure mm shapes are faster, so it should help there too.
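To make the shape concrete, here is a minimal sketch in plain PyTorch (not the inductor lowering itself) of a 1x1 conv expressed as an mm, with dimensions chosen to reproduce the ResNet example above:

import torch

# A 1x1 conv over a channels-last input is the same computation as an mm over
# the flattened spatial positions: (N*H*W, C_in) @ (C_in, C_out).
N, C_in, H, W, C_out = 64, 64, 56, 56, 256
x = torch.randn(N, C_in, H, W).to(memory_format=torch.channels_last)
w = torch.randn(C_out, C_in, 1, 1)

conv_out = torch.nn.functional.conv2d(x, w)

x_mat = x.permute(0, 2, 3, 1).reshape(-1, C_in)    # dense view because x is channels-last
mm_out = x_mat @ w.reshape(C_out, C_in).t()        # (200704, 64) @ (64, 256)
mm_out = mm_out.reshape(N, H, W, C_out).permute(0, 3, 1, 2)

print(torch.allclose(conv_out, mm_out, atol=1e-4))  # True, up to float accumulation order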

dilation = pad_listlike(dilation, ndim)
output_padding = pad_listlike(output_padding, ndim)

def channels_last_conv():
Contributor


Any intuitive understanding for why this prefers channels last?

Contributor Author

@eellison eellison Aug 16, 2023


Since we're permuting the input with permute(0, 2, 3, 1), a channels-last input produces a dense tensor. If the input is contiguous (NCHW), the permute produces a bad format for mm.
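A quick illustration of that point in plain PyTorch (not the lowering code):

import torch

x_cl = torch.randn(8, 64, 56, 56).to(memory_format=torch.channels_last)
x_cf = torch.randn(8, 64, 56, 56)  # default contiguous (NCHW)

# permute(0, 2, 3, 1) only reorders dims and strides; no data is moved.
print(x_cl.permute(0, 2, 3, 1).is_contiguous())  # True: already dense in NHWC order, mm-friendly
print(x_cf.permute(0, 2, 3, 1).is_contiguous())  # False: strided view, would need a copy for mm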

Comment on lines +315 to +319
layout = conv_layout(x, weight, None, **kwargs)
req_stride_order = ir.get_stride_order(
    V.graph.sizevars.size_hints(layout.stride)
)
return req_stride_order == ir.NHWC_STRIDE_ORDER
Contributor


just to make sure I understand -- this is checking that the layout of the output is NHWC, and if it is, then x and w are both NHWC?

Contributor Author


When we lower conv1x1 as mm, we are making the output NHWC. This is checking that we would have made it NHWC anyway, by looking at the output layout.

When you call into aten.convolution, if either x or w is channels last, then the other will be made to be channels last as well with the op, then the kernel is invoked.

x and w are both NHWC

So, this is true with respect to the convolution kernel itself, but not necessarily the inputs to aten::convolution. Just one input needs to be NHWC for the output to be NHWC.

Within the conv1x1-as-mm lowering, we force the input to channels-last strides. So we want to avoid:

  • an extra copy to the input making it channels last where the copy wouldn't have happened otherwise
  • accidentally changing the output striding, which might have downstream effects
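For reference, a small sketch of what the NHWC stride order being checked above looks like, using plain PyTorch rather than the ir.get_stride_order / ir.NHWC_STRIDE_ORDER helpers from the diff:

import torch

# A channels-last (NHWC) tensor has channels as the densest dimension.
y = torch.empty(8, 256, 56, 56).to(memory_format=torch.channels_last)
print(y.stride())  # (802816, 1, 14336, 256): stride 1 on C, i.e. NHWC stride order

# The gating above asks whether the layout inductor would have chosen for the conv
# output is already in this NHWC order; only then is the conv1x1-as-mm lowering used.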

eellison added a commit that referenced this pull request Aug 16, 2023
@eellison
Contributor Author

@pytorchbot merge -f "unrelated failure"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here
