diff --git a/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp b/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp
index 37d6c5a44..a2cb4a866 100644
--- a/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp
+++ b/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp
@@ -164,6 +164,35 @@ void all_to_one(
   static auto intermediate_nodes =
       get_intermediate_node(fbgemm_gpu::get_nvlink_matrix());
+
+  auto copy_fn =
+      [&](Tensor& dst, const Tensor& src, at::cuda::CUDAStream& copy_stream) {
+        if (src.numel() == 0) {
+          return;
+        }
+
+        if (src.dim() == 2u) {
+          AT_CUDA_CHECK(cudaMemcpy2DAsync(
+              dst.data_ptr(),
+              dst.stride(0) * dst.element_size(),
+              src.data_ptr(),
+              src.stride(0) * src.element_size(),
+              src.size(1) * src.element_size(),
+              src.size(0),
+              cudaMemcpyDeviceToDevice,
+              copy_stream));
+        } else {
+          TORCH_CHECK(dst.is_contiguous());
+          TORCH_CHECK(src.is_contiguous());
+          AT_CUDA_CHECK(cudaMemcpyAsync(
+              dst.data_ptr(),
+              src.data_ptr(),
+              src.numel() * src.element_size(),
+              cudaMemcpyDeviceToDevice,
+              copy_stream));
+        }
+      };
+
   for (const auto i : c10::irange(input_tensors.size())) {
     const auto& src = input_tensors.at(i);
     Node src_device_id = src.get_device();
@@ -180,15 +209,7 @@ void all_to_one(
       auto& dst = two_hop_transfers.back().intermediate_tensor;
       at::cuda::CUDAStream copy_stream =
           at::cuda::getCurrentCUDAStream(src_device_id);
-      AT_CUDA_CHECK(cudaMemcpy2DAsync(
-          dst.data_ptr(),
-          dst.stride(0) * dst.element_size(),
-          src.data_ptr(),
-          src.stride(0) * src.element_size(),
-          src.size(1) * src.element_size(),
-          src.size(0),
-          cudaMemcpyDeviceToDevice,
-          copy_stream));
+      copy_fn(dst, src, copy_stream);
       two_hop_transfers.back().transfer_cuda_event->record(copy_stream);
       is_two_hop_transfer.push_back(true);
     } else {
@@ -233,23 +254,18 @@ void all_to_one(
       if (metadata) {
         continue;
       }
       auto& src = input_tensors[i];
+      if (src.numel() == 0) {
+        continue;
+      }
       if (src.device() != src_device) {
         continue;
       }
       auto& dst = output_tensors[i];
       // on source device, launch memcpy.
-      AT_CUDA_CHECK(cudaMemcpy2DAsync(
-          dst.data_ptr(),
-          dst.stride(0) * dst.element_size(),
-          src.data_ptr(),
-          src.stride(0) * src.element_size(),
-          src.size(1) * src.element_size(),
-          src.size(0),
-          cudaMemcpyDeviceToDevice,
-          copy_stream));
+      copy_fn(dst, src, copy_stream);
     }
   }
@@ -261,6 +277,9 @@ void all_to_one(
     if (src_device == target_device) {
       continue;
     }
+    if (src.numel() == 0) {
+      continue;
+    }

     // intermediate rank
     at::cuda::CUDAGuard device_guard(src_device);
@@ -279,15 +298,7 @@ void all_to_one(
     const auto output_index = two_hop_transfer.output_idx;
     auto& dst = output_tensors.at(output_index);
     // on source device, launch memcpy.
-    AT_CUDA_CHECK(cudaMemcpy2DAsync(
-        dst.data_ptr(),
-        dst.stride(0) * dst.element_size(),
-        src.data_ptr(),
-        src.stride(0) * src.element_size(),
-        src.size(1) * src.element_size(),
-        src.size(0),
-        cudaMemcpyDeviceToDevice,
-        copy_stream));
+    copy_fn(dst, src, copy_stream);
   }

   // Do the same-GPU cases.
@@ -299,15 +310,7 @@ void all_to_one(
       // single device memcpy, note that src_device == dst_device.
       at::cuda::CUDAStream copy_stream =
           at::cuda::getCurrentCUDAStream(target_device_index);
-      AT_CUDA_CHECK(cudaMemcpy2DAsync(
-          dst.data_ptr(),
-          dst.stride(0) * dst.element_size(),
-          src.data_ptr(),
-          src.stride(0) * src.element_size(),
-          src.size(1) * src.element_size(),
-          src.size(0),
-          cudaMemcpyDeviceToDevice,
-          copy_stream));
+      copy_fn(dst, src, copy_stream);
     }
   }
 }
@@ -621,7 +624,7 @@ std::vector<Tensor> all_to_one_device(
     TORCH_CHECK(tensor.is_cuda());
     output_tensors.push_back(
         tensor.device() != target_device
-            ? at::empty(tensor.sizes(), tensor.options().device(target_device))
+            ? at::empty_like(tensor, tensor.options().device(target_device))
             : tensor);
   }
   all_to_one(
diff --git a/fbgemm_gpu/test/merge_pooled_embeddings_test.py b/fbgemm_gpu/test/merge_pooled_embeddings_test.py
index 889f9c5e0..c4837a1e8 100644
--- a/fbgemm_gpu/test/merge_pooled_embeddings_test.py
+++ b/fbgemm_gpu/test/merge_pooled_embeddings_test.py
@@ -124,6 +124,7 @@ def ref(pooled_ad_embeddings, batch_indices):
         num_inputs=st.integers(min_value=1, max_value=10),
         num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
         r=st.randoms(use_true_random=False),
+        arbitrary_dim_input=st.booleans(),
     )
     # Can instantiate 8 contexts which takes a long time.
     @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
@@ -135,10 +136,23 @@ def test_all_to_one_device(
         num_gpus,
         # pyre-fixme[2]: Parameter must be annotated.
         r,
+        # pyre-fixme[2]: Parameter must be annotated.
+        arbitrary_dim_input,
     ) -> None:
         dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}")
         with torch.cuda.device(dst_device):
-            inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
+            if arbitrary_dim_input:
+                ranks = torch.randint(0, 5, (num_inputs,))
+                # pyre-ignore
+                dims = torch.randint(0, 10, (ranks.sum().item(),))
+                inputs = []
+                offset = 0
+                for i in range(num_inputs):
+                    rank = ranks[i].item()
+                    inputs.append(torch.randn(dims[offset : offset + rank].tolist()))
+                    offset += rank
+            else:
+                inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
         cuda_inputs = [
             input.to(f"cuda:{i % num_gpus}") for i, input in enumerate(inputs)
         ]
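
Net effect of the patch: the four duplicated cudaMemcpy2DAsync call sites collapse into a single copy_fn lambda that keeps the pitched 2D copy for 2D tensors, falls back to one flat cudaMemcpyAsync for contiguous tensors of any other rank, and returns early for empty tensors; switching at::empty to at::empty_like lets the preallocated outputs preserve arbitrary input shapes on the target device. Below is a minimal usage sketch, not part of the patch, of what this enables at the Python level. It assumes fbgemm_gpu is installed (importing it registers the torch.ops.fbgemm operators) and that at least two CUDA devices are visible:

import torch
import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm.* (assumed OSS install)

dst_device = torch.device("cuda:0")
inputs = [
    torch.randn(10, 20, device="cuda:0"),   # previously supported 2D case
    torch.randn(3, 4, 5, device="cuda:1"),  # arbitrary rank: takes the new flat-copy path
    torch.randn(0, device="cuda:1"),        # empty tensor: copy_fn returns early
]
outputs = torch.ops.fbgemm.all_to_one_device(inputs, dst_device)
assert all(t.device == dst_device for t in outputs)
assert [tuple(t.shape) for t in outputs] == [tuple(t.shape) for t in inputs]

Note the TORCH_CHECKs on contiguity in copy_fn: the non-2D path is a single flat copy, so strided non-2D inputs fail loudly instead of being silently miscopied.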