[PyTorch] Avoid extra Tensor refcounting in _cat_out_cpu
We had a local `Tensor` when we only needed a `const Tensor&`.

Differential Revision: [D25544731](https://our.internmc.facebook.com/intern/diff/D25544731/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25544731/)!

ghstack-source-id: 118559137
Pull Request resolved: #49364
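
For context on the summary line above: `at::Tensor` is a reference-counted handle, so materializing one in a local variable pays for an atomic refcount increment on copy and a decrement on destruction, while borrowing through a `const Tensor&` costs nothing. Below is a minimal sketch of that difference, using `std::shared_ptr` purely as a stand-in for a refcounted handle (this is not ATen code):

```cpp
#include <memory>
#include <vector>

// Stand-in for at::Tensor: a refcounted handle whose copy constructor
// bumps an atomic reference count (shared_ptr used purely as an analogy).
using Handle = std::shared_ptr<int>;

// Copying the handle into a local: one atomic increment on entry,
// one atomic decrement when `local` goes out of scope.
int use_by_value(Handle h) {
  Handle local = h;  // refcount bump
  return *local;
}

// Borrowing through a const reference: no refcount traffic at all.
int use_by_reference(const Handle& h) {
  return *h;
}

int main() {
  std::vector<Handle> handles{std::make_shared<int>(1), std::make_shared<int>(2)};
  int sum = 0;
  for (const auto& h : handles) {
    sum += use_by_value(h);      // pays for a copy
    sum += use_by_reference(h);  // does not
  }
  return sum == 6 ? 0 : 1;
}
```

In a hot path like `cat`, which inspects every input tensor, those atomic operations add up; the diff below removes them by never holding a `Tensor` by value.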
swolchok committed Dec 14, 2020
1 parent 4404b67 commit 82248f9
Showing 1 changed file with 17 additions and 12 deletions.
aten/src/ATen/native/TensorShape.cpp: 17 additions & 12 deletions
@@ -101,6 +101,10 @@ static inline void check_cat_shape_except_dim(const Tensor & first, const Tensor
   }
 }
 
+static bool should_skip(const Tensor& t) {
+  return t.numel() == 0 && t.dim() == 1;
+}
+
 Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
   // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
   // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
@@ -109,7 +113,6 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
   // FIXME: warn if this is the case
   bool allSkipped = true;
   bool allContiguous = true;
-  Tensor notSkippedTensor;
 
   // Inputs cannot alias the output tensor
   for (int64_t i = 0; i < tensors.size(); i++) {
@@ -121,19 +124,21 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
   }
   at::assert_no_internal_overlap(result);
 
-  auto should_skip = [](const Tensor& t) { return t.numel() == 0 && t.dim() == 1; };
-  for (auto const &tensor : tensors) {
-    if (should_skip(tensor)) {
-      continue;
+  const Tensor* pnotSkippedTensor = [](TensorList tensors) -> const Tensor* {
+    for (auto const &tensor : tensors) {
+      if (should_skip(tensor)) {
+        continue;
+      }
+      // we've found a non-empty tensor
+      return &tensor;
     }
-    // we've found a non-empty tensor
-    allSkipped = false;
-    notSkippedTensor = tensor;
-    break;
-  }
-  if (allSkipped) {
+    return nullptr;
+  }(tensors);
+
+  if (!pnotSkippedTensor) {
     return result;
   }
+  const Tensor& notSkippedTensor = *pnotSkippedTensor;
 
   TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors");
   TORCH_CHECK(dim <= notSkippedTensor.dim(), "dimension ", dim, "out of range");
@@ -191,7 +196,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
   if (reuse_iterator &&
       result.is_contiguous(first_tensor_mem_format) &&
       no_type_promotion) {
-    auto source_slice = notSkippedTensor;
+    const auto& source_slice = notSkippedTensor;
     auto slice_dim_size = source_slice.size(dim);
     auto result_slice = result.narrow(dim, 0, slice_dim_size);
     auto result_slice_data = result_slice.data_ptr();
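The shape of the change, as a self-contained sketch: an immediately invoked lambda returns a pointer to the first element that passes a filter, the caller early-returns on `nullptr`, and otherwise binds a `const` reference. This avoids default-constructing a handle and copy-assigning into it. The `Item`/`items` names here are illustrative, not part of the PyTorch change:

```cpp
#include <iostream>
#include <string>
#include <vector>

using Item = std::string;

// Same role as should_skip in the diff: decide which elements to ignore.
static bool should_skip(const Item& s) {
  return s.empty();
}

int main() {
  std::vector<Item> items{"", "", "first non-empty", "second"};

  // Immediately invoked lambda: find the first non-skipped element and
  // hand back a pointer to it, without ever copying an element.
  const Item* pfirst = [](const std::vector<Item>& items) -> const Item* {
    for (const auto& item : items) {
      if (should_skip(item)) {
        continue;
      }
      return &item;  // found a non-skipped element
    }
    return nullptr;  // everything was skipped
  }(items);

  if (!pfirst) {
    return 0;  // mirrors the early `return result;` in _cat_out_cpu
  }
  const Item& first = *pfirst;  // borrow through a const reference, no copy
  std::cout << first << '\n';
  return 0;
}
```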
