From 665690860b087498cdd8c2d99e97263fac508204 Mon Sep 17 00:00:00 2001 From: Xiaodong Wang Date: Mon, 26 Oct 2020 18:56:17 -0700 Subject: [PATCH] [pytorch/ops] Concat fast path w/ zero tensor (#46805) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46805 The current implementation goes with slow path if there is zero tensor in the list. This is inefficient. Use the fast path for torch.cat even if there are empty tensors. This wastes one thread block for the empty tensor, but still much better than the slow path. Test Plan: CI + sandcastle Differential Revision: D24524441 fbshipit-source-id: 522dea42628207bd77a8dfba39476b1dc3c1de45 --- aten/src/ATen/native/cuda/Shape.cu | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu index 5da007507905..64af6cb268a2 100644 --- a/aten/src/ATen/native/cuda/Shape.cu +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -237,7 +237,12 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, batchCounter < CAT_ARRAY_BATCH_SIZE && (i+batchCounter) < inputs.size(); ++batchCounter) { - int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension); + int64_t dimSize = 0; + // There is a legacy case where a 1-D empty tensor can be concat with + // high-dimensional tensor + if (inputs[i+batchCounter].numel() > 0) { + dimSize = at::native::size(inputs[i+batchCounter], dimension); + } stackInputs[batchCounter].input = inputs[i+batchCounter].data_ptr(); @@ -338,7 +343,12 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, batchCounter < CAT_ARRAY_BATCH_SIZE && (i+batchCounter) < inputs.size(); ++batchCounter) { - int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension); + int64_t dimSize = 0; + // There is a legacy case where a 1-D empty tensor can be concat with + // high-dimensional tensor + if (inputs[i+batchCounter].numel() > 0) { + dimSize = at::native::size(inputs[i+batchCounter], dimension); + } catMetaData.input[batchCounter] = inputs[i+batchCounter].data_ptr(); catMetaData.offset[batchCounter] = offset; catMetaData.dimSize[batchCounter] = dimSize; @@ -431,7 +441,6 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { auto should_skip = [](const Tensor &t) { return t.dim() == 1 && at::native::size(t, 0) == 0; }; - bool hasSkippedInput = false; const Tensor *notSkippedTensor = NULL; // non-owning reference int nDims = 0; @@ -452,10 +461,8 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { } at::assert_no_internal_overlap(out); - for (int i = 0; i < inputs.size(); i++) - { + for (int i = 0; i < inputs.size(); i++) { if (should_skip(inputs[i])) { - hasSkippedInput = true; continue; } nDims = inputs[i].dim(); @@ -501,11 +508,10 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { // We parallelize the copy if all 6 conditions pass: // // 1. There is more than one input tensor - // 2. No empty inputs - // 3. The out tensor is 32-bit indexable - // 4. The number of dimensions is <= 4 - // 5. All input tensors are contiguous (output tensor may be non-contig) - // 6. All input tensors can use 32-bit indexing + // 2. The out tensor is 32-bit indexable + // 3. The number of dimensions is <= 4 + // 4. All input tensors are contiguous (output tensor may be non-contig) + // 5. All input tensors can use 32-bit indexing const bool all32BitIndexable = std::all_of(inputs.begin(), inputs.end(), [] (const Tensor& t) { @@ -522,7 +528,6 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { }); allSameType = allSameType && (out.scalar_type() == firstType); if (inputs.size() > 1 && - !hasSkippedInput && out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && at::cuda::detail::canUse32BitIndexMath(out) && allContiguous &&