[PyTorch] Avoid extra Tensor refcounting in _cat_out_cpu #49364

Closed

wants to merge 2 commits into from
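For background on the title: at::Tensor is a reference-counted handle (it wraps a c10::intrusive_ptr to a TensorImpl), so every by-value copy of a Tensor costs an atomic increment plus a later atomic decrement. The diff below removes such copies on the _cat_out_cpu hot path by working through a pointer and references instead. A minimal sketch of the cost difference, using std::shared_ptr as a stand-in for the Tensor handle (illustration only, not PyTorch code):

```cpp
#include <memory>

// Stand-in for at::Tensor: a cheap-to-alias, reference-counted handle.
using Handle = std::shared_ptr<int>;

// By-value parameter: the copy bumps the shared count on entry and
// drops it again on exit, two atomic operations per call.
int use_by_value(Handle h) { return *h; }

// By-reference parameter: no refcount traffic at all.
int use_by_ref(const Handle& h) { return *h; }

int main() {
  Handle h = std::make_shared<int>(42);
  use_by_value(h);  // atomic increment + decrement
  use_by_ref(h);    // neither
}
```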
29 changes: 17 additions & 12 deletions aten/src/ATen/native/TensorShape.cpp
@@ -101,6 +101,10 @@ static inline void check_cat_shape_except_dim(const Tensor & first, const Tensor
   }
 }
 
+static bool should_skip(const Tensor& t) {
+  return t.numel() == 0 && t.dim() == 1;
+}
+
 Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
   // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
   // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
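The hoisted should_skip helper keeps the legacy rule described in the comment above: a 1-dimensional tensor with zero elements is ignored when concatenating. A small usage sketch of that behavior through the public ATen API (assuming a standard ATen build):

```cpp
#include <ATen/ATen.h>

int main() {
  // at::empty({0}) is 1-D with numel() == 0, so cat skips it for
  // backwards compatibility even though its shape does not match
  // the other input along dim 0.
  auto out = at::cat({at::ones({2, 3}), at::empty({0})}, /*dim=*/0);
  // out.sizes() == [2, 3]; the empty tensor contributed nothing.
}
```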
@@ -109,7 +113,6 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
   // FIXME: warn if this is the case
   bool allSkipped = true;
   bool allContiguous = true;
-  Tensor notSkippedTensor;
 
   // Inputs cannot alias the output tensor
   for (int64_t i = 0; i < tensors.size(); i++) {
@@ -121,19 +124,21 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
   }
   at::assert_no_internal_overlap(result);
 
-  auto should_skip = [](const Tensor& t) { return t.numel() == 0 && t.dim() == 1; };
-  for (auto const &tensor : tensors) {
-    if (should_skip(tensor)) {
-      continue;
+  const Tensor* pnotSkippedTensor = [](TensorList tensors) -> const Tensor* {
+    for (auto const &tensor : tensors) {
+      if (should_skip(tensor)) {
+        continue;
+      }
+      // we've found a non-empty tensor
+      return &tensor;
     }
-    // we've found a non-empty tensor
-    allSkipped = false;
-    notSkippedTensor = tensor;
-    break;
-  }
-  if (allSkipped) {
+    return nullptr;
+  }(tensors);
+
+  if (!pnotSkippedTensor) {
     return result;
   }
+  const Tensor& notSkippedTensor = *pnotSkippedTensor;
 
   TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors");
   TORCH_CHECK(dim <= notSkippedTensor.dim(), "dimension ", dim, "out of range");
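The rewritten search is an immediately-invoked lambda returning const Tensor*, with nullptr meaning "everything was skipped". The old version default-constructed a Tensor and copy-assigned the match into it, and each such copy of a Tensor handle is an atomic refcount bump; returning a pointer into the input list touches no refcount. A minimal sketch of the before/after pattern, again with a generic refcounted handle standing in for Tensor:

```cpp
#include <memory>
#include <vector>

using Handle = std::shared_ptr<int>;  // stand-in for at::Tensor

// Old shape of the search: default-construct a handle, then
// copy-assign the match into it. The assignment is an atomic
// increment, plus a decrement when the copy is destroyed.
Handle find_copying(const std::vector<Handle>& v) {
  Handle found;
  for (const auto& h : v) {
    if (h) { found = h; break; }  // refcount bump happens here
  }
  return found;
}

// New shape: return a pointer to the element; nullptr encodes
// "nothing found". No refcount is touched at any point.
const Handle* find_by_pointer(const std::vector<Handle>& v) {
  for (const auto& h : v) {
    if (h) return &h;
  }
  return nullptr;
}

int main() {
  std::vector<Handle> v{nullptr, std::make_shared<int>(1)};
  (void)find_copying(v);
  const Handle* p = find_by_pointer(v);
  return (p && *p) ? 0 : 1;  // pointer stays valid while v lives
}
```

The same trick is safe here because the lambda's pointer refers into the caller's TensorList, which outlives the function body.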
@@ -191,7 +196,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
   if (reuse_iterator &&
       result.is_contiguous(first_tensor_mem_format) &&
       no_type_promotion) {
-    auto source_slice = notSkippedTensor;
+    const auto& source_slice = notSkippedTensor;
     auto slice_dim_size = source_slice.size(dim);
     auto result_slice = result.narrow(dim, 0, slice_dim_size);
     auto result_slice_data = result_slice.data_ptr();
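The last hunk applies the same idea to a local: auto source_slice = notSkippedTensor; deduces a by-value at::Tensor (one more atomic incref/decref pair), while const auto& merely aliases the existing handle. A tiny sketch of the deduction difference (shared_ptr as the stand-in once more):

```cpp
#include <memory>

int main() {
  auto p = std::make_shared<int>(7);  // stand-in for notSkippedTensor

  auto copy = p;          // deduces shared_ptr<int>: atomic increment
  const auto& alias = p;  // deduces a reference: refcount unchanged

  // The copy is a second owner; the alias is not.
  return (p.use_count() == 2 && *copy == 7 && *alias == 7) ? 0 : 1;
}
```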