Skip to content

Commit

Permalink
Fix cudaSetDevice for CUDA 12 (#370)
Browse files Browse the repository at this point in the history
update
  • Loading branch information
rusty1s committed Apr 15, 2024
1 parent 2d55981 commit 85ace25
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions csrc/cuda/convert_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ __global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,

torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
CHECK_CUDA(ind);
cudaSetDevice(ind.get_device());
c10::cuda::MaybeSetDevice(ind.get_device());

auto out = torch::empty({M + 1}, ind.options());

Expand Down Expand Up @@ -55,7 +55,7 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,

torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
CHECK_CUDA(ptr);
cudaSetDevice(ptr.get_device());
c10::cuda::MaybeSetDevice(ptr.get_device());

auto out = torch::empty({E}, ptr.options());
auto ptr_data = ptr.data_ptr<int64_t>();
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/diag_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
int64_t M, int64_t N, int64_t k) {
CHECK_CUDA(row);
CHECK_CUDA(col);
cudaSetDevice(row.get_device());
c10::cuda::MaybeSetDevice(row.get_device());

auto E = row.size(0);
auto num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
Expand Down
2 changes: 1 addition & 1 deletion csrc/cuda/rw_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(start);
cudaSetDevice(rowptr.get_device());
c10::cuda::MaybeSetDevice(rowptr.get_device());

CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
Expand Down
4 changes: 2 additions & 2 deletions csrc/cuda/spmm_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
if (optional_value.has_value())
CHECK_CUDA(optional_value.value());
CHECK_CUDA(mat);
cudaSetDevice(rowptr.get_device());
c10::cuda::MaybeSetDevice(rowptr.get_device());

CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
Expand Down Expand Up @@ -201,7 +201,7 @@ torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
CHECK_CUDA(col);
CHECK_CUDA(mat);
CHECK_CUDA(grad);
cudaSetDevice(row.get_device());
c10::cuda::MaybeSetDevice(row.get_device());

mat = mat.contiguous();
grad = grad.contiguous();
Expand Down

0 comments on commit 85ace25

Please sign in to comment.