
Tune elementwise ops for ROCm #21754

Closed · wants to merge 1 commit

Conversation

colesbury (Member)

The stride calculation using OffsetCalculator performs poorly with
MAX_DIMS=25. This reduces MAX_DIMS (after coalescing) to 16 on ROCm.
I think it's unlikely that anyone will exceed this limit. If they do,
we can add additional specializations for ROCm with more dimensions.
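
For illustration, the change is roughly of the following shape; this is a sketch only, and the actual constant, guard macro, and file in the tree may differ:

```cpp
// Sketch: cap the number of post-coalescing dimensions OffsetCalculator handles.
// The names and location here are assumptions, not the actual patch.
#ifdef __HIP_PLATFORM_HCC__
// ROCm: a smaller cap keeps the kernel argument block small.
constexpr int MAX_DIMS = 16;
#else
constexpr int MAX_DIMS = 25;
#endif
```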

I'm not sure about the underlying cause. With MAX_DIMS=25, the add kernel's argument block
is ~648 bytes vs. ~424 bytes with MAX_DIMS=16. The kernel's instruction footprint is
bigger too, but most of those instructions are never executed and most kernel parameters
are never loaded, because the typical dimensionality is much smaller.
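
As a back-of-the-envelope check on how the argument block scales with MAX_DIMS, here is a rough sizing sketch. The layout below (one magic-number divider per dimension plus one 32-bit stride per dimension per operand) is an assumption modeled on OffsetCalculator rather than the exact struct, so the totals only approximate the ~648/~424 figures above:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Assumed divider layout (divisor + magic + shift): roughly 12 bytes.
struct ApproxIntDivider { uint32_t divisor, magic; int shift; };

// Approximate bytes contributed by an OffsetCalculator-like kernel argument:
// one divider per dimension plus one 32-bit stride per dimension per operand.
constexpr std::size_t approx_offset_calc_bytes(int max_dims, int nargs) {
  return sizeof(int)                                   // dims
       + max_dims * sizeof(ApproxIntDivider)           // sizes_
       + max_dims * nargs * sizeof(uint32_t);          // strides_
}

int main() {
  // Elementwise add has three operands (out, a, b). Data pointers, numel,
  // etc. add a few dozen more bytes on top of these totals.
  std::printf("MAX_DIMS=25: ~%zu bytes\n", approx_offset_calc_bytes(25, 3)); // ~604
  std::printf("MAX_DIMS=16: ~%zu bytes\n", approx_offset_calc_bytes(16, 3)); // ~388
  return 0;
}
```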

Mini benchmark here:
https://gist.github.com/colesbury/1e917ae6a0ca9d24712121b92fed4c8f

(broadcasting operations, in particular, are much faster with this change)

cc @iotamudelta

pytorchbot added the module: cuda (Related to torch.cuda, and CUDA support in general) and module: operators labels Jun 13, 2019
colesbury requested a review from gchanan Jun 13, 2019 19:06
colesbury (Member, Author)

This seems to be an issue with the kernel argument size and not the kernel instruction size. I've noticed that performance drops dramatically once kernarg_segment_byte_size hits 512 bytes, even when there are no instructions that load from the extra kernel arguments.
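
A hypothetical way to probe that cliff (not the experiment above): launch an otherwise identical kernel while padding its by-value argument struct so the kernarg segment crosses 512 bytes. Everything below (names, pad sizes) is illustrative only:

```cpp
#include <hip/hip_runtime.h>

// Pad the kernel's by-value argument to inflate kernarg_segment_byte_size.
// The pad bytes are never read, mirroring the "unused parameters" case above.
template <int PAD_BYTES>
struct Payload {
  float* out;
  const float* in;
  int n;
  char pad[PAD_BYTES];
};

template <int PAD_BYTES>
__global__ void copy_kernel(Payload<PAD_BYTES> p) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < p.n) p.out[i] = p.in[i];
}

int main() {
  const int n = 1 << 20;
  float *in = nullptr, *out = nullptr;
  hipMalloc(reinterpret_cast<void**>(&in), n * sizeof(float));
  hipMalloc(reinterpret_cast<void**>(&out), n * sizeof(float));

  // Vary PAD_BYTES across instantiations to sweep the argument size;
  // timing (e.g. with hipEvent_t) is omitted here for brevity.
  Payload<512> p{out, in, n, {}};
  hipLaunchKernelGGL(copy_kernel<512>, dim3((n + 255) / 256), dim3(256), 0, 0, p);
  hipDeviceSynchronize();

  hipFree(in);
  hipFree(out);
  return 0;
}
```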

Review thread on aten/src/THC/THCIntegerDivider.cuh (resolved)
facebook-github-bot (Contributor) left a comment:

@colesbury has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

bddppq added the module: rocm (AMD GPU support for Pytorch) label Jun 13, 2019
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 14, 2019
Pull Request resolved: pytorch/pytorch#21754

Reviewed By: bddppq

Differential Revision: D15811906

Pulled By: colesbury

fbshipit-source-id: 063f92c083d26e2ef2edc98df7ff0400f9432b9d
facebook-github-bot (Contributor)

@colesbury merged this pull request in cfd8c58.

Labels: Merged · module: cuda (Related to torch.cuda, and CUDA support in general) · module: rocm (AMD GPU support for Pytorch)