preserve non-dense or overlapping tensor's layout in *_like functions (#46046)

Summary:
Pull Request resolved: #46046

*_like functions are used in PyTorch to create a new tensor with the same shape as the input tensor, but they do not always preserve the tensor's layout permutation. Currently, the layout permutation is preserved only for a dense and non-overlapping tensor. For example, passing a channels-last contiguous tensor t with shape/stride (2, 4, 3, 2)/(24, 1, 8, 4) to empty_like(t) creates a new tensor with exactly the same shape/stride as t. However, if the input tensor is non-dense or has overlap, we simply create a contiguous tensor based on the input tensor's shape, so the layout permutation is lost.
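To make the shape/stride example concrete, here is a minimal pure-Python sketch that derives both the default contiguous strides and the channels-last strides for shape (2, 4, 3, 2). The helper names `contiguous_strides` and `channels_last_strides` are made up for this illustration and are not PyTorch APIs.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for `shape`."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        # Each stride is the product of all sizes to its right.
        strides[i] = strides[i + 1] * max(shape[i + 1], 1)
    return tuple(strides)


def channels_last_strides(shape):
    """Strides of a tensor indexed as NCHW whose memory is ordered as NHWC."""
    n, c, h, w = shape
    # C is the fastest-moving dimension, then W, then H, then N.
    return (h * w * c, 1, w * c, c)


shape = (2, 4, 3, 2)
print(contiguous_strides(shape))     # (24, 6, 2, 1)
print(channels_last_strides(shape))  # (24, 1, 8, 4)
```

Both stride tuples cover the same 48 elements without gaps or overlap; they differ only in which dimension moves fastest, which is the "layout permutation" the message refers to.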

This PR preserves the layout permutation for non-dense or overlapping tensors. The stride propagation rule used in this PR is exactly the same as the one used in TensorIterator. The behavior changes are listed below:

| code | old | new |
|------|-----|-----|
| #strided tensors<br>a=torch.randn(2,3,8)[:,:,::2].permute(2,0,1)<br>print(a.stride())<br>print(a.exp().stride())<br>print((a+a).stride())<br>out = torch.empty(0)<br>torch.add(a,a,out=out)<br>print(out.stride()) | (2, 24, 8)<br>(6, 3, 1)<br>(1, 12, 4)<br>(6, 3, 1) | (2, 24, 8)<br>(1, 12, 4)<br>(1, 12, 4)<br>(1, 12, 4) |
| #memory dense tensors<br>a=torch.randn(3,1,1).as_strided((3,1,1), (1,3,3))<br>print(a.stride(), (a+torch.randn(1)).stride())<br>a=torch.randn(2,3,4).permute(2,0,1)<br>print(a.stride())<br>print(a.exp().stride())<br>print((a+a).stride())<br>out = torch.empty(0)<br>torch.add(a,a,out=out)<br>print(out.stride()) | (1, 3, 3) (1, 1, 1)<br>(1, 12, 4)<br>(6, 3, 1)<br>(1, 12, 4)<br>(6, 3, 1) | (1, 3, 3) (1, 3, 3)<br>(1, 12, 4)<br>(1, 12, 4)<br>(1, 12, 4)<br>(1, 12, 4) |

This is to solve the non-dense tensor layout problem in #45505

TODO:
- [x] Fix all the BC broken test cases in pytorch
- [ ] Investigate if any fb internal tests are broken

This change will cover all kinds of non-dense tensors.

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D24288970

Pulled By: glaringlee

fbshipit-source-id: 320fd4e0d1a810a12abfb1441472298c983a368d
lixinyu authored and facebook-github-bot committed Oct 21, 2020
1 parent 2181449 commit a651b87
Showing 5 changed files with 166 additions and 0 deletions.
93 changes: 93 additions & 0 deletions aten/src/ATen/ExpandUtils.cpp
@@ -85,4 +85,97 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
      expandedSizes, expandedStrides);
}


// This function returns dense and non-overlapping strides that keep the same layout permutation
// as the input `tensor_strides`, computed based on the input `tensor_sizes`.
// Note:
// 1. This function expects the inputs `tensor_strides` and `tensor_sizes` to be non-dense or overlapping.
//    If the inputs are dense and non-overlapping, the output strides will be the same as `tensor_strides`.
//    However, this function won't check whether the inputs are dense or overlapping, so the whole function
//    will still be executed even if the inputs are already dense and non-overlapping, which is needlessly slow.
//
//    Please verify whether the inputs are non-dense or overlapping before calling this function if possible.
//    If the inputs come from a tensor, you can check this through `is_non_overlapping_and_dense()`.
//
// 2. The stride propagation rule that is used in this function is exactly the same as what is being used in
//    TensorIterator. Please refer to https://github.com/pytorch/pytorch/pull/42922 for more details.

std::vector<int64_t> infer_dense_strides(IntArrayRef tensor_sizes, IntArrayRef tensor_strides) {

  TORCH_CHECK(tensor_sizes.size() == tensor_strides.size(),
    "Input sizes and strides should have same size but got ", tensor_sizes.size(), " and ", tensor_strides.size());

  size_t ndim = tensor_sizes.size();
  if (ndim == 0) {
    return {};
  }
  if (ndim == 1) {
    return {1};
  }

  std::vector<int64_t> perm(ndim);
  // initialize perm with n-1, n-2, ..., 1, 0
  std::iota(perm.rbegin(), perm.rend(), 0);

  // The following sorting algorithm has exactly the same behavior as TensorIterator
  // This is to make sure we have the same stride propagation everywhere.

  // return -1 if dim0 should come before dim1
  // return  1 if dim0 should come after dim1
  // return  0 if comparison is ambiguous
  auto should_swap = [&](size_t dim0, size_t dim1) {
    int64_t stride0 = tensor_strides[dim0];
    int64_t stride1 = tensor_strides[dim1];

    // if any stride is 0, treat it as ambiguous comparison to
    // keep the same behavior as TensorIterator
    if (stride0 == 0 || stride1 == 0) {
      return 0;
    }
    if (stride0 < stride1) {
      return -1;
    }
    if (stride0 > stride1) {
      return 1;
    }
    // for equal strides, the dimension with smaller size goes front
    if (tensor_sizes[dim0] > tensor_sizes[dim1]) {
      return 1;
    }
    return 0;
  };

  // Insertion sort (stable) indices in `perm` based on input tensor's stride and shape,
  // all dimensions with 0 stride won't move. This is the same behavior as TensorIterator.
  // eg. Given tensor with size/stride (6, 5, 4, 3, 2)/(6, 0, 120, 0, 1), the initial `perm`
  //     is (4, 3, 2, 1, 0) and the sorted `perm` will be (4, 3, 0, 1, 2)
  for (int i = 1; i < ndim; ++i) {
    int dim1 = i;
    for (int dim0 = i - 1; dim0 >= 0; --dim0) {
      int comparison = should_swap(perm[dim0], perm[dim1]);
      if (comparison > 0) {
        std::swap(perm[dim0], perm[dim1]);
        dim1 = dim0;
      }
      else if (comparison < 0) {
        break;
      }
    }
  }

  // compute output strides which preserve the input tensor's memory layout
  std::vector<int64_t> out_strides(ndim);
  int64_t curr_stride = 1;
  for (size_t i = 0; i < ndim; ++i) {
    int64_t idx = perm[i];
    out_strides[idx] = curr_stride;
    // Note: for size 0, we simply treat it as 1, it really doesn't matter here
    // since the total number of elements is 0.
    if (tensor_sizes[idx] > 1) {
      curr_stride *= tensor_sizes[idx];
    }
  }
  return out_strides;
}

} // namespace at
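For experimenting with the rule outside of a PyTorch build, the C++ function above can be transcribed into plain Python. The following is an unofficial sketch that mirrors the committed logic step by step; it is not part of this commit, and the Python function name is reused here only for readability.

```python
def infer_dense_strides(sizes, strides):
    """Unofficial Python transcription of the C++ infer_dense_strides above."""
    if len(sizes) != len(strides):
        raise ValueError("sizes and strides must have the same length")
    ndim = len(sizes)
    if ndim == 0:
        return []
    if ndim == 1:
        return [1]

    # perm starts as n-1, n-2, ..., 1, 0 (innermost dimension first).
    perm = list(range(ndim - 1, -1, -1))

    def should_swap(dim0, dim1):
        s0, s1 = strides[dim0], strides[dim1]
        if s0 == 0 or s1 == 0:
            return 0  # ambiguous: zero-stride dims never move
        if s0 < s1:
            return -1
        if s0 > s1:
            return 1
        # equal strides: the dimension with the smaller size goes first
        return 1 if sizes[dim0] > sizes[dim1] else 0

    # Insertion sort that skips over ambiguous comparisons.
    for i in range(1, ndim):
        dim1 = i
        for dim0 in range(i - 1, -1, -1):
            comparison = should_swap(perm[dim0], perm[dim1])
            if comparison > 0:
                perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
                dim1 = dim0
            elif comparison < 0:
                break

    # Assign dense strides following the sorted order.
    out = [0] * ndim
    curr = 1
    for idx in perm:
        out[idx] = curr
        if sizes[idx] > 1:
            curr *= sizes[idx]
    return out


# First table row of the commit message: sliced then permuted tensor.
print(infer_dense_strides((4, 2, 3), (2, 24, 8)))               # [1, 12, 4]
# Example from the code comment: the sorted perm is (4, 3, 0, 1, 2).
print(infer_dense_strides((6, 5, 4, 3, 2), (6, 0, 120, 0, 1)))  # [6, 36, 180, 2, 1]
```

The first call reproduces the (1, 12, 4) entry in the "new" column of the table in the commit message. The second call follows the zero-stride example from the comment, where the two zero-stride dimensions keep their positions in `perm`.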
4 changes: 4 additions & 0 deletions aten/src/ATen/ExpandUtils.h
@@ -16,6 +16,10 @@ inferExpandGeometry(
    IntArrayRef tensor_strides,
    IntArrayRef sizes);

CAFFE2_API std::vector<int64_t> infer_dense_strides(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides);

// True if input shapes are expandable
// NOTE: infer_size did a similar check, please keep them sync if change is needed
inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
6 changes: 6 additions & 0 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -348,6 +348,12 @@ Tensor empty_like(
  if (memory_format == MemoryFormat::Preserve) {
    if (self.is_non_overlapping_and_dense()) {
      result = at::empty_strided(self.sizes(), self.strides(), options.memory_format(c10::nullopt));
    } else if (self.layout() == kStrided) {
      // If the input tensor is strided but not dense and non-overlapping, we infer output strides
      // which keep the layout permutation of the input tensor.
      std::vector<int64_t> strides = infer_dense_strides(self.sizes(), self.strides());
      // See Note [Explicit nullopt MemoryFormat argument]
      result = at::empty_strided(self.sizes(), strides, options.memory_format(c10::nullopt));
    } else {
      // See Note [Explicit nullopt MemoryFormat argument]
      result = at::empty(self.sizes(), options.memory_format(self.suggest_memory_format()), c10::nullopt);
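The three-way branch in the hunk above can be summarized in a small pure-Python sketch. The dense check below is a simplified stand-in for `is_non_overlapping_and_dense()` written for this illustration, and `preserve_format_branch` with its `is_strided_layout` flag is a hypothetical helper, not a PyTorch function.

```python
def is_non_overlapping_and_dense(sizes, strides):
    """Simplified check: do the dims, in some order, tile memory exactly?"""
    # Dimensions of size 0 or 1 place no constraint on their stride.
    dims = sorted((st, sz) for sz, st in zip(sizes, strides) if sz > 1)
    expected = 1
    for st, sz in dims:
        if st != expected:
            return False
        expected *= sz
    return True


def preserve_format_branch(sizes, strides, is_strided_layout=True):
    """Name the branch that empty_like would take under MemoryFormat::Preserve."""
    if is_non_overlapping_and_dense(sizes, strides):
        return "reuse input strides"
    if is_strided_layout:
        return "infer dense strides"  # the branch added by this commit
    return "suggested memory format"


# Channels-last tensor: dense, so its strides are copied as-is.
print(preserve_format_branch((2, 4, 3, 2), (24, 1, 8, 4)))
# Sliced and permuted tensor from the commit message: has gaps.
print(preserve_format_branch((4, 2, 3), (2, 24, 8)))
```

Only the middle branch is new; the first and last branches correspond to the context lines of the hunk and behave as before.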
12 changes: 12 additions & 0 deletions aten/src/ATen/test/extension_backend_test.cpp
@@ -29,8 +29,20 @@ Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
  return a;
}

Tensor empty_strided_override(
  IntArrayRef size,
  IntArrayRef stride,
  c10::optional<c10::ScalarType> dtype,
  c10::optional<c10::Layout> layout,
  c10::optional<c10::Device> device,
  c10::optional<bool> pin_memory) {

  return empty_override(size, at::kMSNPU, c10::nullopt);
}

TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
  m.impl_UNBOXED("aten::empty.memory_format", empty_override);
  m.impl_UNBOXED("aten::empty_strided", empty_strided_override);
  m.impl_UNBOXED("aten::add.Tensor", add_override);
}

51 changes: 51 additions & 0 deletions test/test_torch.py
@@ -19260,6 +19260,57 @@ def test_repeated_dim(self, device):
                with self.assertRaisesRegex(RuntimeError, e_msg):
                    op(x, dim=dim)

    # Note: This test failed on XLA since its test cases are created by empty_strided which
    #       doesn't support overlapping sizes/strides in XLA impl
    @onlyOnCPUAndCUDA
    def test_like_fn_stride_proparation_vs_tensoriterator_unary_op(self, device):
        # Test like functions against a tensoriterator based unary operator (exp) to
        # make sure the tensor returned from a like function follows the same stride propagation
        # rule as what tensoriterator does for a unary operator. The like function's output strides
        # are always computed on the CPU side, no need to test GPU here.

        def compare_helper_(like_fn, t):
            te = torch.exp(t)
            tl = like_fn(t)
            self.assertEqual(te.stride(), tl.stride())
            self.assertEqual(te.size(), tl.size())

        like_fns = [
            lambda t, **kwargs: torch.zeros_like(t, **kwargs),
            lambda t, **kwargs: torch.ones_like(t, **kwargs),
            lambda t, **kwargs: torch.randint_like(t, 10, 100, **kwargs),
            lambda t, **kwargs: torch.randint_like(t, 100, **kwargs),
            lambda t, **kwargs: torch.randn_like(t, **kwargs),
            lambda t, **kwargs: torch.rand_like(t, **kwargs),
            lambda t, **kwargs: torch.full_like(t, 7, **kwargs),
            lambda t, **kwargs: torch.empty_like(t, **kwargs)]

        # dense non-overlapping tensor,
        # non-dense non-overlapping sliced tensor
        # non-dense non-overlapping gapped tensor
        # non-dense non-overlapping 0 strided tensor
        # non-dense overlapping general tensor
        # non-dense overlapping sliced tensor
        # non-dense overlapping gapped tensor
        # non-dense overlapping 0 strided tensor
        # non-dense overlapping equal strides
        tset = (
            torch.randn(4, 3, 2, device=device),
            torch.randn(4, 3, 2, device=device)[:, :, ::2],
            torch.empty_strided((4, 3, 2), (10, 3, 1), device=device).fill_(1.0),
            torch.empty_strided((4, 3, 2), (10, 0, 3), device=device).fill_(1.0),
            torch.empty_strided((4, 3, 2), (10, 1, 2), device=device).fill_(1.0),
            torch.empty_strided((4, 3, 2), (4, 2, 1), device=device)[:, :, ::2].fill_(1.0),
            torch.empty_strided((4, 3, 2), (10, 1, 1), device=device).fill_(1.0),
            torch.empty_strided((4, 1, 1, 2), (10, 0, 0, 2), device=device).fill_(1.0),
            torch.empty_strided((4, 2, 3), (10, 3, 3), device=device).fill_(1.0))

        for like_fn in like_fns:
            for t in tset:
                for p in permutations(range(t.dim())):
                    tp = t.permute(p)
                    compare_helper_(like_fn, tp)

# Tests that compare a device's computation with the (gold-standard) CPU's.
class TestDevicePrecision(TestCase):
    exact_dtype = True
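To get a feel for how many layouts the new test exercises, the permutation loop can be replayed on size/stride metadata alone, without PyTorch. In this sketch the `permute` helper is hypothetical, and the (sizes, strides) pairs for the two sliced tensors are derived by hand on the assumption that slicing with step 2 doubles the stride of the sliced dimension.

```python
from itertools import permutations


def permute(sizes, strides, p):
    """Metadata-only model of Tensor.permute: reorder sizes and strides."""
    return tuple(sizes[d] for d in p), tuple(strides[d] for d in p)


# (sizes, strides) for each entry of `tset` in the test, in the same order.
tset = [
    ((4, 3, 2), (6, 2, 1)),           # torch.randn(4, 3, 2)
    ((4, 3, 1), (6, 2, 2)),           # torch.randn(4, 3, 2)[:, :, ::2]
    ((4, 3, 2), (10, 3, 1)),
    ((4, 3, 2), (10, 0, 3)),
    ((4, 3, 2), (10, 1, 2)),
    ((4, 3, 1), (4, 2, 2)),           # empty_strided(..., (4, 2, 1))[:, :, ::2]
    ((4, 3, 2), (10, 1, 1)),
    ((4, 1, 1, 2), (10, 0, 0, 2)),
    ((4, 2, 3), (10, 3, 3)),
]

cases = [
    permute(sizes, strides, p)
    for sizes, strides in tset
    for p in permutations(range(len(sizes)))
]

print(len(cases))                                    # 72
print(permute((4, 3, 2), (6, 2, 1), (2, 0, 1)))      # ((2, 4, 3), (1, 6, 2))
```

Eight three-dimensional inputs contribute 6 permutations each and the four-dimensional input contributes 24, so each of the eight like functions is compared against `torch.exp` on 72 layouts.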
