[pytorch/cuda] Concat fast path w/ zero tensor #46805

Closed
wants to merge 1 commit
29 changes: 17 additions & 12 deletions aten/src/ATen/native/cuda/Shape.cu
@@ -237,7 +237,12 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
          batchCounter < CAT_ARRAY_BATCH_SIZE &&
            (i+batchCounter) < inputs.size();
          ++batchCounter) {
-      int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension);
+      int64_t dimSize = 0;
+      // There is a legacy case where a 1-D empty tensor can be concatenated
+      // with a higher-dimensional tensor.
+      if (inputs[i+batchCounter].numel() > 0) {
+        dimSize = at::native::size(inputs[i+batchCounter], dimension);
+      }
 
       stackInputs[batchCounter].input =
         inputs[i+batchCounter].data_ptr<scalar_t>();
@@ -338,7 +343,12 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
          batchCounter < CAT_ARRAY_BATCH_SIZE &&
            (i+batchCounter) < inputs.size();
          ++batchCounter) {
-      int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension);
+      int64_t dimSize = 0;
+      // There is a legacy case where a 1-D empty tensor can be concatenated
+      // with a higher-dimensional tensor.
+      if (inputs[i+batchCounter].numel() > 0) {
+        dimSize = at::native::size(inputs[i+batchCounter], dimension);
+      }
       catMetaData.input[batchCounter] = inputs[i+batchCounter].data_ptr<scalar_t>();
       catMetaData.offset[batchCounter] = offset;
       catMetaData.dimSize[batchCounter] = dimSize;
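
For context on the "legacy case" named in the two hunks above: a 1-D tensor with numel() == 0 is allowed to appear in a concat alongside higher-dimensional inputs, and it should contribute nothing to the output. A minimal sketch of that behavior using the public libtorch API (illustrative only, not code from this PR):

// Minimal sketch of the legacy zero-tensor concat case. Assumes the public
// libtorch API; this is not an excerpt from Shape.cu or from this PR.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto a = torch::randn({2, 3}, torch::kCUDA);
  auto empty = torch::empty({0}, torch::kCUDA);  // 1-D tensor with numel() == 0
  auto b = torch::randn({4, 3}, torch::kCUDA);

  // The empty 1-D input is skipped; its slot in the batched metadata now gets
  // dimSize = 0, and the result equals torch::cat({a, b}, 0).
  auto out = torch::cat({a, empty, b}, /*dim=*/0);
  std::cout << out.sizes() << std::endl;  // prints [6, 3]
  return 0;
}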
@@ -431,7 +441,6 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
   auto should_skip = [](const Tensor &t) {
     return t.dim() == 1 && at::native::size(t, 0) == 0;
   };
-  bool hasSkippedInput = false;
 
   const Tensor *notSkippedTensor = NULL; // non-owning reference
   int nDims = 0;
@@ -452,10 +461,8 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
   }
   at::assert_no_internal_overlap(out);
 
-  for (int i = 0; i < inputs.size(); i++)
-  {
+  for (int i = 0; i < inputs.size(); i++) {
     if (should_skip(inputs[i])) {
-      hasSkippedInput = true;
       continue;
     }
     nDims = inputs[i].dim();
@@ -501,11 +508,10 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
   // We parallelize the copy if all 6 conditions pass:
   //
   // 1. There is more than one input tensor
-  // 2. No empty inputs
-  // 3. The out tensor is 32-bit indexable
-  // 4. The number of dimensions is <= 4
-  // 5. All input tensors are contiguous (output tensor may be non-contig)
-  // 6. All input tensors can use 32-bit indexing
+  // 2. The out tensor is 32-bit indexable
+  // 3. The number of dimensions is <= 4
+  // 4. All input tensors are contiguous (output tensor may be non-contig)
+  // 5. All input tensors can use 32-bit indexing
 
   const bool all32BitIndexable = std::all_of(inputs.begin(), inputs.end(),
     [] (const Tensor& t) {
@@ -522,7 +528,6 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
     });
   allSameType = allSameType && (out.scalar_type() == firstType);
   if (inputs.size() > 1 &&
-      !hasSkippedInput &&
       out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
       at::cuda::detail::canUse32BitIndexMath(out) &&
       allContiguous &&
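
Putting the last two hunks together, the gate on the batched fast path now reads roughly as sketched below. This is an illustrative paraphrase rather than a buildable excerpt: the identifiers are the ones computed in cat_out_cuda, and the final two terms are not visible in the truncated hunk above and are inferred from the 5-condition list.

// Illustrative paraphrase of the fast-path gate after this change (the
// !hasSkippedInput term is gone). allContiguous, all32BitIndexable and
// allSameType are the flags computed just above in cat_out_cuda; the last
// two terms are inferred from the condition list, not shown in the hunk.
if (inputs.size() > 1 &&
    out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&        // at most 4 dims
    at::cuda::detail::canUse32BitIndexMath(out) &&  // out is 32-bit indexable
    allContiguous &&                                // every input is contiguous
    all32BitIndexable &&                            // every input is 32-bit indexable
    allSameType) {
  // ... dispatch to parallel_cat / hip_parallel_cat ...
}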