diff --git a/csrc/cuda/spmm_cuda.cu b/csrc/cuda/spmm_cuda.cu
index 2a98ebe2..c58e8f84 100644
--- a/csrc/cuda/spmm_cuda.cu
+++ b/csrc/cuda/spmm_cuda.cu
@@ -63,9 +63,9 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
 #pragma unroll
       for (int i = 0; i < 32; i++) {
         // Communication between all threads in a warp.
-        mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
+        mat_rows[i] = SHFL_SYNC(FULL_MASK, mat_row, i);
         if (HAS_VALUE)
-          vals[i] = __shfl_sync(FULL_MASK, val, i);
+          vals[i] = SHFL_SYNC(FULL_MASK, val, i);
       }
 
 #pragma unroll
@@ -179,7 +179,7 @@ spmm_value_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
 #pragma unroll
     for (int i = 32 / 2; i > 0; i /= 2) {
       // Parallel reduction inside a warp.
-      val += __shfl_down_sync(FULL_MASK, val, i);
+      val += SHFL_DOWN_SYNC(FULL_MASK, val, i);
     }
 
     if (lane_idx == 0) {
diff --git a/csrc/cuda/utils.cuh b/csrc/cuda/utils.cuh
index 0cc55d13..ba4f3a11 100644
--- a/csrc/cuda/utils.cuh
+++ b/csrc/cuda/utils.cuh
@@ -17,13 +17,15 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
   return __shfl_down_sync(mask, var.operator __half(), delta);
 }
 
-  #ifdef USE_ROCM
-  __device__ __inline__ at::Half __ldg(const at::Half* ptr) {
-    return __ldg(reinterpret_cast<const __half*>(ptr));
-  }
-  #define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
-  #define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
-  #else
-  #define SHFL_UP_SYNC __shfl_up_sync
-  #define SHFL_DOWN_SYNC __shfl_down_sync
-  #endif
+#ifdef USE_ROCM
+__device__ __inline__ at::Half __ldg(const at::Half* ptr) {
+  return __ldg(reinterpret_cast<const __half*>(ptr));
+}
+#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
+#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
+#define SHFL_SYNC(mask, var, delta) __shfl(var, delta)
+#else
+#define SHFL_UP_SYNC __shfl_up_sync
+#define SHFL_DOWN_SYNC __shfl_down_sync
+#define SHFL_SYNC __shfl_sync
+#endif
diff --git a/test/test_matmul.py b/test/test_matmul.py
index 3ec14356..bf7ad129 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -43,14 +43,13 @@ def test_spmm(dtype, device, reduce):
     out = matmul(src, other, reduce)
     out.backward(grad_out)
 
+    atol = 1e-7
     if dtype == torch.float16 or dtype == torch.bfloat16:
-        assert torch.allclose(expected, out, atol=1e-1)
-        assert torch.allclose(expected_grad_value, value.grad, atol=1e-1)
-        assert torch.allclose(expected_grad_other, other.grad, atol=1e-1)
-    else:
-        assert torch.allclose(expected, out)
-        assert torch.allclose(expected_grad_value, value.grad)
-        assert torch.allclose(expected_grad_other, other.grad)
+        atol = 1e-1
+
+    assert torch.allclose(expected, out, atol=atol)
+    assert torch.allclose(expected_grad_value, value.grad, atol=atol)
+    assert torch.allclose(expected_grad_other, other.grad, atol=atol)
 
 
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
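
For context on what the new macros buy, here is a rough standalone sketch, not part of this patch: on CUDA, SHFL_SYNC / SHFL_DOWN_SYNC expand to the mask-taking __shfl_sync / __shfl_down_sync intrinsics, while under USE_ROCM they fall back to the mask-less __shfl / __shfl_down, so the same kernel source builds with both nvcc and hipcc. The file below is illustrative only; the FULL_MASK definition and the warp_sum_kernel name are assumptions, with the reduction loop mirroring spmm_value_bw_kernel and the broadcast mirroring the SHFL_SYNC calls in spmm_kernel.

// Standalone sketch (assumed names), not part of the patch above.
#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff // assumption: the same all-lanes mask the kernels use

#ifdef USE_ROCM
// ROCm warp shuffles take no mask argument, so the macros drop it.
#define SHFL_SYNC(mask, var, delta) __shfl(var, delta)
#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
#else
#define SHFL_SYNC __shfl_sync
#define SHFL_DOWN_SYNC __shfl_down_sync
#endif

// One warp sums 32 floats with a tree reduction, then broadcasts the total
// from lane 0 to every lane.
__global__ void warp_sum_kernel(const float *in, float *out) {
  int lane_idx = threadIdx.x & (32 - 1);
  float val = in[lane_idx];

#pragma unroll
  for (int i = 32 / 2; i > 0; i /= 2) // Parallel reduction inside a warp.
    val += SHFL_DOWN_SYNC(FULL_MASK, val, i);

  val = SHFL_SYNC(FULL_MASK, val, 0); // Broadcast lane 0's total.
  out[lane_idx] = val;
}

int main() {
  float h_in[32], h_out[32], *d_in, *d_out;
  for (int i = 0; i < 32; i++)
    h_in[i] = (float)i; // expected total: 0 + 1 + ... + 31 = 496

  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  warp_sum_kernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

  printf("lane 0: %.1f, lane 31: %.1f\n", h_out[0], h_out[31]); // both 496.0

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}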