Tune elementwise ops for ROCm (#21754)
Summary:
```
The stride calculation using OffsetCalculator performs poorly with
MAX_DIMS=25. This reduces MAX_DIMS (after coalescing) to 16 on ROCm.
I think it's unlikely that anyone will exceed this limit. If they do,
we can add additional specializations for ROCm with more dimensions.
```

I'm not sure about the underlying cause. With MAX_DIMS=25, the add kernel's parameter
block is ~648 bytes, vs. ~424 bytes with MAX_DIMS=16. The kernel's instruction footprint
is bigger too, even though most of those instructions are never executed and most kernel
parameters are never loaded, because the typical dimensionality is much smaller.
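
A hedged back-of-envelope for those numbers (my arithmetic, not from the PR): with `index_t = uint32_t`, each `IntDivider` in `sizes_` is three 32-bit fields (~12 bytes), and the add kernel presumably instantiates `OffsetCalculator` with NARGS = 3 (two inputs plus the output):

```cpp
#include <cstdio>
#include <initializer_list>

// Approximate per-launch footprint of OffsetCalculator<NARGS, uint32_t>:
//   IntDivider<uint32_t> sizes_[MAX_DIMS];   // divisor, m1, shift: ~12 B each
//   uint32_t strides_[MAX_DIMS][NARGS];
//   int dims;
int main() {
  const int NARGS = 3;  // add kernel: two inputs + one output (assumption)
  for (int max_dims : {25, 16}) {
    int bytes = max_dims * 12 + max_dims * NARGS * 4 + 4;
    std::printf("MAX_DIMS=%d -> ~%d bytes of OffsetCalculator state\n",
                max_dims, bytes);
  }
}
```

That yields ~604 vs. ~388 bytes, roughly matching the ~648 vs. ~424 byte totals once the remaining kernel arguments are added.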

Mini benchmark here:
https://gist.github.com/colesbury/1e917ae6a0ca9d24712121b92fed4c8f

(broadcasting operations, in particular, are much faster)

cc iotamudelta
Pull Request resolved: pytorch/pytorch#21754

Reviewed By: bddppq

Differential Revision: D15811906

Pulled By: colesbury

fbshipit-source-id: 063f92c083d26e2ef2edc98df7ff0400f9432b9d
Authored by colesbury, committed by facebook-github-bot on Jun 13, 2019. Parent: 6384d1c. Commit: b9db1d5.
Showing 3 changed files with 12 additions and 5 deletions.
aten/src/ATen/cuda/detail/OffsetCalculator.cuh (7 additions, 3 deletions)

```diff
@@ -9,15 +9,19 @@
 /// OffsetCalculator calculates the offset in bytes of a linear index for NARGS
 /// operands that share the same shape, but may have different strides.
 
+#ifdef __HIP_PLATFORM_HCC__
+constexpr int MAX_DIMS = 16;
+#else
+constexpr int MAX_DIMS = 25;
+#endif
+
 template <int NARGS, typename index_t = uint32_t>
 struct OffsetCalculator {
-  static constexpr int MAX_DIMS = 25;
-
   // The offset for each argument (in bytes). Wrapper around fixed-size array.
   using offset_type = at::detail::Array<index_t, NARGS>;
 
   OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides) : dims(dims) {
-    TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>25) dims");
+    TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
     for (int i = 0; i < MAX_DIMS; ++i) {
       if (i < dims) {
         sizes_[i] = IntDivider<index_t>(sizes[i]);
```
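
For context, the hot loop that this constant governs is `OffsetCalculator::get()`. The sketch below is paraphrased from the same file and is not part of this diff; the point is that the loop bound is the compile-time `MAX_DIMS`, so a larger bound means more unrolled instructions and larger by-value `sizes_`/`strides_` kernel arguments, even when `dims` is small at runtime.

```cpp
// Paraphrased sketch of OffsetCalculator::get() (not part of this diff).
__host__ __device__ offset_type get(index_t linear_idx) const {
  offset_type offsets;
  #pragma unroll
  for (int arg = 0; arg < NARGS; arg++) {
    offsets[arg] = 0;
  }
  #pragma unroll
  for (int dim = 0; dim < MAX_DIMS; ++dim) {
    if (dim == dims) {
      break;  // typical tensors have far fewer than MAX_DIMS dimensions
    }
    // Magic-number divmod (IntDivider, see THCIntegerDivider.cuh below):
    // peel off the coordinate along this dimension from the linear index.
    auto divmod = sizes_[dim].divmod(linear_idx);
    linear_idx = divmod.div;
    #pragma unroll
    for (int arg = 0; arg < NARGS; arg++) {
      offsets[arg] += divmod.mod * strides_[dim][arg];
    }
  }
  return offsets;
}
```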
aten/src/ATen/native/cuda/Loops.cuh (1 addition, 1 deletion)

```diff
@@ -25,7 +25,7 @@
 #ifdef __HIP_PLATFORM_HCC__
 static constexpr int launch_size_1d = 1024;
 static constexpr int launch_size_nd = 1024;
-static constexpr int launch_bound2 = 16;
+static constexpr int launch_bound2 = 1;
 #else
 static constexpr int launch_size_1d = 512;
 static constexpr int launch_size_nd = 128;
```
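
As I understand the call sites (an assumption; they are outside this diff), `launch_bound2` feeds the second parameter of `__launch_bounds__`, the minimum number of blocks the compiler must be able to keep resident per multiprocessor. A minimal sketch of that pattern:

```cpp
// Minimal sketch of the __launch_bounds__ pattern (hypothetical kernel name;
// the real call site in Loops.cuh is outside this diff).
template <int nt, typename func_t>
__global__ __launch_bounds__(nt, /*minBlocksPerMultiprocessor=*/1)
void elementwise_kernel(int N, func_t f) {
  int idx = nt * blockIdx.x + threadIdx.x;
  if (idx < N) {
    f(idx);
  }
}
```

Requiring 16 resident blocks forces the compiler to budget registers very tightly; relaxing the bound to 1 on ROCm trades that occupancy guarantee for more registers per thread.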
aten/src/THC/THCIntegerDivider.cuh (4 additions, 1 deletion)

```diff
@@ -2,6 +2,9 @@
 #define THC_INTEGER_DIVIDER_INC
 
 #include <assert.h>
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
+#include <cuda_runtime.h>
+#endif
 
 // A utility class to implement integer division by multiplication, given a fixed
 // divisor.
@@ -91,7 +94,7 @@ struct IntDivider<unsigned int> {
   }
 
   __host__ __device__ inline unsigned int div(unsigned int n) const {
-#ifdef __CUDA_ARCH__
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
     // 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
     // 'm1'.
     unsigned int t = __umulhi(n, m1);
```
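
For readers unfamiliar with the trick: `IntDivider` precomputes a magic multiplier `m1` and a `shift` for a fixed divisor, so the per-element division in `OffsetCalculator::get()` becomes a multiply-high plus a shift. A minimal host-side sketch of the idea, paraphrased from this file (the real class also asserts the 31-bit preconditions that make it safe):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Host-side sketch of IntDivider<unsigned int>'s magic-number division:
// precompute (m1, shift) for a fixed divisor d, then n / d becomes
// ((hi32(n * m1) + n) >> shift) -- no hardware divide on the GPU hot path.
struct MagicDiv {
  uint32_t divisor, m1, shift;

  explicit MagicDiv(uint32_t d) : divisor(d) {
    assert(d >= 1 && d <= INT32_MAX);
    for (shift = 0; shift < 32; ++shift)
      if ((1ULL << shift) >= d) break;
    uint64_t magic = ((1ULL << 32) * ((1ULL << shift) - d)) / d + 1;
    m1 = (uint32_t)magic;
    assert(m1 == magic);  // magic multiplier must fit in 32 bits
  }

  // On device this line would use: unsigned int t = __umulhi(n, m1);
  uint32_t div(uint32_t n) const {
    uint32_t t = (uint32_t)(((uint64_t)n * m1) >> 32);
    return (t + n) >> shift;
  }
};

int main() {
  for (uint32_t d : {1u, 3u, 7u, 640u, 12345u}) {
    MagicDiv md(d);
    for (uint32_t n : {0u, 1u, 999u, 123456789u})
      assert(md.div(n) == n / d);
  }
  std::puts("magic division matches hardware division");
}
```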
