Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tune elementwise ops for ROCm #21754

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 7 additions & 3 deletions aten/src/ATen/cuda/detail/OffsetCalculator.cuh
Expand Up @@ -9,15 +9,19 @@
/// OffsetCalculator calculates the offset in bytes of a linear index for NARGS
/// operands that share the same shape, but may have different strides.

#ifdef __HIP_PLATFORM_HCC__
constexpr int MAX_DIMS = 16;
#else
constexpr int MAX_DIMS = 25;
#endif

template <int NARGS, typename index_t = uint32_t>
struct OffsetCalculator {
static constexpr int MAX_DIMS = 25;

// The offset for each argument (in bytes). Wrapper around fixed-size array.
using offset_type = at::detail::Array<index_t, NARGS>;

OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides) : dims(dims) {
TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>25) dims");
TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
for (int i = 0; i < MAX_DIMS; ++i) {
if (i < dims) {
sizes_[i] = IntDivider<index_t>(sizes[i]);
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Loops.cuh
Expand Up @@ -25,7 +25,7 @@
#ifdef __HIP_PLATFORM_HCC__
static constexpr int launch_size_1d = 1024;
static constexpr int launch_size_nd = 1024;
static constexpr int launch_bound2 = 16;
static constexpr int launch_bound2 = 1;
#else
static constexpr int launch_size_1d = 512;
static constexpr int launch_size_nd = 128;
Expand Down
5 changes: 4 additions & 1 deletion aten/src/THC/THCIntegerDivider.cuh
Expand Up @@ -2,6 +2,9 @@
#define THC_INTEGER_DIVIDER_INC

#include <assert.h>
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
bddppq marked this conversation as resolved.
Show resolved Hide resolved
#include <cuda_runtime.h>
#endif

// A utility class to implement integer division by multiplication, given a fixed
// divisor.
Expand Down Expand Up @@ -91,7 +94,7 @@ struct IntDivider<unsigned int> {
}

__host__ __device__ inline unsigned int div(unsigned int n) const {
#ifdef __CUDA_ARCH__
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
// 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
// 'm1'.
unsigned int t = __umulhi(n, m1);
Expand Down