Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
csrc/cuda/spmm_cuda.cu (6 changes: 3 additions & 3 deletions)
@@ -63,9 +63,9 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
#pragma unroll
for (int i = 0; i < 32; i++) {
  // Communication between all threads in a warp.
-  mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
+  mat_rows[i] = SHFL_SYNC(FULL_MASK, mat_row, i);
  if (HAS_VALUE)
-    vals[i] = __shfl_sync(FULL_MASK, val, i);
+    vals[i] = SHFL_SYNC(FULL_MASK, val, i);
}

#pragma unroll
@@ -179,7 +179,7 @@ spmm_value_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,

#pragma unroll
for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
-  val += __shfl_down_sync(FULL_MASK, val, i);
+  val += SHFL_DOWN_SYNC(FULL_MASK, val, i);
}

if (lane_idx == 0) {
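The two hunks above are the standard warp-shuffle patterns in these kernels: a broadcast of each lane's value to every thread in the warp, and a tree reduction that folds lanes together with a halving offset. Below is a minimal standalone sketch of the reduction pattern for reference; the kernel name, inputs, and launch assumptions are made up for illustration and are not part of this PR, and it calls __shfl_down_sync directly where the PR's kernels now go through the portable SHFL_DOWN_SYNC macro from utils.cuh.

#define FULL_MASK 0xffffffff

// Illustrative toy kernel (not from the PR): each warp sums its 32 lane values.
// Assumes blockDim.x is a multiple of 32.
__global__ void warp_sum_demo(const float *in, float *out) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane = threadIdx.x % 32;
  float val = in[idx];

#pragma unroll
  for (int i = 32 / 2; i > 0; i /= 2) {
    // Each step adds the value held i lanes higher, halving the active range.
    val += __shfl_down_sync(FULL_MASK, val, i);
  }

  if (lane == 0) // Lane 0 now holds the sum of all 32 lanes in its warp.
    out[idx / 32] = val;
}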
csrc/cuda/utils.cuh (22 changes: 12 additions & 10 deletions)
@@ -17,13 +17,15 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
  return __shfl_down_sync(mask, var.operator __half(), delta);
}

-#ifdef USE_ROCM
-__device__ __inline__ at::Half __ldg(const at::Half* ptr) {
-  return __ldg(reinterpret_cast<const __half*>(ptr));
-}
-#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
-#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
-#else
-#define SHFL_UP_SYNC __shfl_up_sync
-#define SHFL_DOWN_SYNC __shfl_down_sync
-#endif
+#ifdef USE_ROCM
+__device__ __inline__ at::Half __ldg(const at::Half* ptr) {
+  return __ldg(reinterpret_cast<const __half*>(ptr));
+}
+#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
+#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
+#define SHFL_SYNC(mask, var, delta) __shfl(var, delta)
+#else
+#define SHFL_UP_SYNC __shfl_up_sync
+#define SHFL_DOWN_SYNC __shfl_down_sync
+#define SHFL_SYNC __shfl_sync
+#endif
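The new SHFL_SYNC macro follows the same pattern as the existing SHFL_UP_SYNC and SHFL_DOWN_SYNC ones: HIP/ROCm exposes the warp shuffles as __shfl, __shfl_up and __shfl_down without a mask argument, while CUDA's __shfl_sync family requires one, so the macros give the kernels a single spelling and simply drop the mask when building for ROCm. A rough usage sketch, assuming this header is included and FULL_MASK is defined as in the kernels above; the kernel itself is hypothetical and not part of the PR:

// Hypothetical kernel: broadcast lane 0's value to the whole warp.
// On CUDA, SHFL_SYNC expands to __shfl_sync(FULL_MASK, v, 0);
// on ROCm it expands to __shfl(v, 0), with the mask discarded.
__global__ void broadcast_demo(float *data) {
  float v = data[threadIdx.x];
  v = SHFL_SYNC(FULL_MASK, v, 0);
  data[threadIdx.x] = v;
}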
test/test_matmul.py (13 changes: 6 additions & 7 deletions)
@@ -43,14 +43,13 @@ def test_spmm(dtype, device, reduce):
    out = matmul(src, other, reduce)
    out.backward(grad_out)

+    atol = 1e-7
    if dtype == torch.float16 or dtype == torch.bfloat16:
-        assert torch.allclose(expected, out, atol=1e-1)
-        assert torch.allclose(expected_grad_value, value.grad, atol=1e-1)
-        assert torch.allclose(expected_grad_other, other.grad, atol=1e-1)
-    else:
-        assert torch.allclose(expected, out)
-        assert torch.allclose(expected_grad_value, value.grad)
-        assert torch.allclose(expected_grad_other, other.grad)
+        atol = 1e-1
+
+    assert torch.allclose(expected, out, atol=atol)
+    assert torch.allclose(expected_grad_value, value.grad, atol=atol)
+    assert torch.allclose(expected_grad_other, other.grad, atol=atol)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))