From 4a2aa0f5f1572d8a9caca9496669f8f52cfa1522 Mon Sep 17 00:00:00 2001
From: anjali411
Date: Wed, 27 Jan 2021 09:09:09 -0800
Subject: [PATCH] index_put_ for complex tensors on CUDA (#51148)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51148

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D26102025

Pulled By: anjali411

fbshipit-source-id: b1b6fd12fda03c4520a3c3200226edf352496188
---
 aten/src/ATen/native/cuda/Indexing.cu | 2 +-
 test/test_indexing.py                 | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index 035dc188c81c..6b3304cff421 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -230,7 +230,7 @@ void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool unsafe) {
            std::min(std::max<int>(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2]));
       dim3 block(C10_WARP_SIZE, indices_per_block);
 
-      AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
+      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
       value_.scalar_type(), "indexing_backward", [&] {
         indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
           sorted_indices.data_ptr<int64_t>(),
diff --git a/test/test_indexing.py b/test/test_indexing.py
index b92fd94e8cbd..10e4a9bafe95 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -762,9 +762,9 @@ def test_int_indices(self, device):
         self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
         self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
 
-    @dtypes(torch.float, torch.bfloat16, torch.long, torch.bool)
-    @dtypesIfCPU(torch.float, torch.long, torch.bool, torch.bfloat16)
-    @dtypesIfCUDA(torch.half, torch.long, torch.bool, torch.bfloat16)
+    @dtypes(torch.cfloat, torch.cdouble, torch.float, torch.bfloat16, torch.long, torch.bool)
+    @dtypesIfCPU(torch.cfloat, torch.cdouble, torch.float, torch.long, torch.bool, torch.bfloat16)
+    @dtypesIfCUDA(torch.cfloat, torch.cdouble, torch.half, torch.long, torch.bool, torch.bfloat16)
     def test_index_put_src_datatype(self, device, dtype):
         src = torch.ones(3, 2, 4, device=device, dtype=dtype)
         vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
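
A minimal usage sketch of what this patch enables, mirroring test_index_put_src_datatype above: index_put_ with accumulate=True on complex CUDA tensors, which previously failed to dispatch because AT_DISPATCH_ALL_TYPES_AND3 excludes complex dtypes. The CPU fallback below is only so the snippet runs on CUDA-less machines; the path this patch changes is the CUDA one.

import torch

# Run on CUDA when available (the code path widened by this patch);
# fall back to CPU just so the sketch executes anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"

for dtype in (torch.cfloat, torch.cdouble):
    src = torch.ones(3, 2, 4, device=device, dtype=dtype)
    vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
    indices = (torch.tensor([0, 2, 1], device=device),)
    # accumulate=True routes to index_put_accum_kernel on CUDA, the
    # dispatch site changed in Indexing.cu above.
    res = src.index_put_(indices, vals, accumulate=True)
    # Each row is selected once, so every element accumulates 1 + 1.
    assert res.dtype == dtype and torch.all(res == 2)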