all_to_one cuda support non-2d inputs (#2575)
Summary:

Support non-2d tensors in all_to_one.
Sometimes torchrec outputs are empty tensors.

This could be solved on the Python module side by guarding against empty tensors, but it is better not to have shape-dependent logic there for PT2 tracing.

Differential Revision: D57180617
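
A minimal sketch of how the new behavior can be exercised, assuming a built fbgemm_gpu, at least one CUDA device, and that the operator is exposed as torch.ops.fbgemm.all_to_one_device (the op the test below targets):

import torch
import fbgemm_gpu  # noqa: F401  (loads the torch.ops.fbgemm operators)

# Mixed-rank inputs, including an empty tensor such as torchrec sometimes emits.
inputs = [
    torch.randn(10, 20, device="cuda:0"),   # the previously supported 2-d case
    torch.randn(5, device="cuda:0"),        # 1-d
    torch.randn(2, 3, 4, device="cuda:0"),  # 3-d
    torch.empty(0, device="cuda:0"),        # empty
]
outputs = torch.ops.fbgemm.all_to_one_device(inputs, torch.device("cuda:0"))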
Ivan Kobzarev authored and facebook-github-bot committed May 9, 2024
1 parent c216005 commit 2e5438a
Showing 2 changed files with 55 additions and 39 deletions.
@@ -164,6 +164,35 @@ void all_to_one(
 
   static auto intermediate_nodes =
       get_intermediate_node(fbgemm_gpu::get_nvlink_matrix());
+
+  auto copy_fn =
+      [&](Tensor& dst, const Tensor& src, at::cuda::CUDAStream& copy_stream) {
+        if (src.numel() == 0) {
+          return;
+        }
+
+        if (src.dim() == 2u) {
+          AT_CUDA_CHECK(cudaMemcpy2DAsync(
+              dst.data_ptr(),
+              dst.stride(0) * dst.element_size(),
+              src.data_ptr(),
+              src.stride(0) * src.element_size(),
+              src.size(1) * src.element_size(),
+              src.size(0),
+              cudaMemcpyDeviceToDevice,
+              copy_stream));
+        } else {
+          TORCH_CHECK(dst.is_contiguous());
+          TORCH_CHECK(src.is_contiguous());
+          AT_CUDA_CHECK(cudaMemcpyAsync(
+              dst.data_ptr(),
+              src.data_ptr(),
+              src.numel() * src.element_size(),
+              cudaMemcpyDeviceToDevice,
+              copy_stream));
+        }
+      };
+
   for (const auto i : c10::irange(input_tensors.size())) {
     const auto& src = input_tensors.at(i);
     Node src_device_id = src.get_device();
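
A note on the two branches of copy_fn above: cudaMemcpy2DAsync takes source and destination pitches, so the 2-d path can copy matrices whose rows are strided (non-contiguous), while the N-d fallback issues a single flat cudaMemcpyAsync and therefore checks that both tensors are contiguous. A small Python sketch (the helper name is hypothetical) of how copy_fn derives the 2-d call's arguments:

import torch

def memcpy2d_args(dst: torch.Tensor, src: torch.Tensor) -> dict:
    # A pitch is the byte distance between the starts of consecutive rows;
    # width/height describe the dense sub-rectangle actually copied.
    return {
        "dpitch": dst.stride(0) * dst.element_size(),
        "spitch": src.stride(0) * src.element_size(),
        "width_bytes": src.size(1) * src.element_size(),
        "height_rows": src.size(0),
    }

print(memcpy2d_args(torch.empty(10, 20), torch.empty(10, 20)))
# {'dpitch': 80, 'spitch': 80, 'width_bytes': 80, 'height_rows': 10}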
@@ -180,15 +209,7 @@ void all_to_one(
       auto& dst = two_hop_transfers.back().intermediate_tensor;
       at::cuda::CUDAStream copy_stream =
           at::cuda::getCurrentCUDAStream(src_device_id);
-      AT_CUDA_CHECK(cudaMemcpy2DAsync(
-          dst.data_ptr(),
-          dst.stride(0) * dst.element_size(),
-          src.data_ptr(),
-          src.stride(0) * src.element_size(),
-          src.size(1) * src.element_size(),
-          src.size(0),
-          cudaMemcpyDeviceToDevice,
-          copy_stream));
+      copy_fn(dst, src, copy_stream);
       two_hop_transfers.back().transfer_cuda_event->record(copy_stream);
       is_two_hop_transfer.push_back(true);
     } else {
@@ -233,23 +254,17 @@ void all_to_one(
       if (metadata) {
         continue;
       }
+
       auto& src = input_tensors[i];
       if (src.numel() == 0) {
         continue;
       }
       if (src.device() != src_device) {
         continue;
       }
+
       auto& dst = output_tensors[i];
       // on source device, launch memcpy.
-      AT_CUDA_CHECK(cudaMemcpy2DAsync(
-          dst.data_ptr(),
-          dst.stride(0) * dst.element_size(),
-          src.data_ptr(),
-          src.stride(0) * src.element_size(),
-          src.size(1) * src.element_size(),
-          src.size(0),
-          cudaMemcpyDeviceToDevice,
-          copy_stream));
+      copy_fn(dst, src, copy_stream);
     }
   }

@@ -261,6 +276,9 @@ void all_to_one(
     if (src_device == target_device) {
       continue;
     }
+    if (src.numel() == 0) {
+      continue;
+    }
 
     // intermediate rank
     at::cuda::CUDAGuard device_guard(src_device);
@@ -279,15 +297,7 @@ void all_to_one(
     const auto output_index = two_hop_transfer.output_idx;
     auto& dst = output_tensors.at(output_index);
     // on source device, launch memcpy.
-    AT_CUDA_CHECK(cudaMemcpy2DAsync(
-        dst.data_ptr(),
-        dst.stride(0) * dst.element_size(),
-        src.data_ptr(),
-        src.stride(0) * src.element_size(),
-        src.size(1) * src.element_size(),
-        src.size(0),
-        cudaMemcpyDeviceToDevice,
-        copy_stream));
+    copy_fn(dst, src, copy_stream);
   }
 
   // Do the same-GPU cases.
@@ -299,15 +309,7 @@ void all_to_one(
       // single device memcpy, not that src_device == dst_device.
       at::cuda::CUDAStream copy_stream =
           at::cuda::getCurrentCUDAStream(target_device_index);
-      AT_CUDA_CHECK(cudaMemcpy2DAsync(
-          dst.data_ptr(),
-          dst.stride(0) * dst.element_size(),
-          src.data_ptr(),
-          src.stride(0) * src.element_size(),
-          src.size(1) * src.element_size(),
-          src.size(0),
-          cudaMemcpyDeviceToDevice,
-          copy_stream));
+      copy_fn(dst, src, copy_stream);
     }
   }
 }
@@ -621,7 +623,7 @@ std::vector<Tensor> all_to_one_device(
     TORCH_CHECK(tensor.is_cuda());
     output_tensors.push_back(
         tensor.device() != target_device
-            ? at::empty(tensor.sizes(), tensor.options().device(target_device))
+            ? at::empty_like(tensor, tensor.options().device(target_device))
            : tensor);
   }
   all_to_one(
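
The commit does not state why at::empty was swapped for at::empty_like; both allocate a tensor of any rank. One observable difference, sketched below as an assumption rather than a stated motivation, is that empty_like defaults to preserve_format, so a non-default memory format on the source carries over to the allocated output:

import torch

src = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
a = torch.empty(src.shape)  # always allocates row-major strides
b = torch.empty_like(src)   # preserve_format keeps channels_last strides
print(a.stride())  # (60, 20, 5, 1)
print(b.stride())  # (60, 1, 15, 3)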
16 changes: 15 additions & 1 deletion fbgemm_gpu/test/merge_pooled_embeddings_test.py
@@ -124,6 +124,7 @@ def ref(pooled_ad_embeddings, batch_indices):
         num_inputs=st.integers(min_value=1, max_value=10),
         num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()),
         r=st.randoms(use_true_random=False),
+        arbitrary_dim_input=st.randoms(use_true_random=False),
     )
     # Can instantiate 8 contexts which takes a long time.
     @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None)
@@ -135,10 +136,23 @@ def test_all_to_one_device(
         num_gpus,
         # pyre-fixme[2]: Parameter must be annotated.
         r,
+        # pyre-fixme[2]: Parameter must be annotated.
+        arbitrary_dim_input,
     ) -> None:
         dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}")
         with torch.cuda.device(dst_device):
-            inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
+            if arbitrary_dim_input:
+                ranks = torch.randint(0, 5, (num_inputs,))
+                # pyre-ignore
+                dims = torch.randint(0, 10, (ranks.sum().item(),))
+                inputs = []
+                offset = 0
+                for i in range(num_inputs):
+                    rank = ranks[i].item()
+                    inputs.append(torch.randn(dims[offset : offset + rank].tolist()))
+                    offset += rank
+            else:
+                inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
             cuda_inputs = [
                 input.to(f"cuda:{i % num_gpus}") for i, input in enumerate(inputs)
             ]

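For reference, the arbitrary-dim strategy above, reduced to a standalone sketch: a sampled rank of 0 yields a 0-d scalar tensor and any sampled dim of 0 yields an empty tensor, which are exactly the inputs the new copy_fn has to tolerate:

import torch

num_inputs = 4
ranks = torch.randint(0, 5, (num_inputs,))        # rank of each input in [0, 5)
dims = torch.randint(0, 10, (int(ranks.sum()),))  # each dim in [0, 10)
inputs, offset = [], 0
for rank in ranks.tolist():
    inputs.append(torch.randn(dims[offset : offset + rank].tolist()))
    offset += rank
print([tuple(t.shape) for t in inputs])  # e.g. [(3, 7), (), (9, 0, 2), (4,)]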