[PyTorch][Tensor] Introduce tensor.dim_order (#106835)
Summary:
Pull Request resolved: #106835

This is a stride-based attribute for a tensor, available in Python.

This can help inspect tensors generated using `torch.empty_permuted(.., physical_layout, ...)`, where `physical_layout` should match the dim_order returned here. `empty_permuted` will be renamed to use `dim_order` as the param name in the future. This will also help the ExecuTorch export pipeline implement dim_order-based tensors.

The import is done inside the function to avoid a circular dependency on torch.Tensor.
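
A minimal sketch of the round trip described above, assuming the current `physical_layout` parameter name of `torch.empty_permuted`:

>>> import torch
>>> layout = (0, 2, 3, 1)  # an NCHW tensor laid out channels-last in memory
>>> t = torch.empty_permuted((2, 3, 5, 7), layout)
>>> t.dim_order() == layout
True
>>> t.stride()  # dim 1 is innermost (stride 1), matching the layout above
(105, 1, 21, 3)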

Differential Revision: D48134476

fbshipit-source-id: 9bf7b4fdbcabefd4602dd3820703891f9257e6cb
digantdesai authored and facebook-github-bot committed Aug 24, 2023
1 parent 2fcda65 commit 55ca74c
Showing 5 changed files with 50 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/tensors.rst
@@ -339,6 +339,7 @@ Tensor class reference
Tensor.digamma
Tensor.digamma_
Tensor.dim
Tensor.dim_order
Tensor.dist
Tensor.div
Tensor.div_
20 changes: 20 additions & 0 deletions test/test_torch.py
@@ -7735,6 +7735,26 @@ def test_helper(dim1, dim2, memory_format):
        test_helper((3, 3), (3, 3, 3, 3), torch.channels_last)
        test_helper((3, 3, 3), (3, 3, 3, 3, 3), torch.channels_last_3d)

    def test_dim_order(self):
        shape = (2, 3, 5, 7)

        t = torch.empty(shape)
        self.assertSequenceEqual(t.dim_order(), (0, 1, 2, 3), seq_type=tuple)
        # transpose doesn't really change the underlying physical memory,
        # so we expect the dim_order to change to reflect that (like strides)
        self.assertSequenceEqual(t.transpose(0, 1).dim_order(), (1, 0, 2, 3))

        t = torch.empty(shape, memory_format=torch.channels_last)
        self.assertSequenceEqual(t.dim_order(), (0, 2, 3, 1))

        t = torch.empty((2, 3, 5, 7, 8), memory_format=torch.channels_last_3d)
        self.assertSequenceEqual(t.dim_order(), (0, 2, 3, 4, 1))

        for dim_order in itertools.permutations(range(4)):
            self.assertSequenceEqual(
                dim_order, torch.empty_permuted(shape, dim_order).dim_order()
            )

    def test_subclass_tensors(self):
        # raise an error when trying to subclass FloatTensor
        with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"):
26 changes: 26 additions & 0 deletions torch/_tensor.py
@@ -1319,6 +1319,32 @@ def to_sparse_coo(self):
"""
return self.to_sparse()

    def dim_order(self):
        """

        dim_order() -> tuple

        Returns a tuple of int describing the dim order or physical layout of :attr:`self`.

        Args:
            None

        Dim order represents how dimensions are laid out in memory,
        starting from the outermost to the innermost dimension.

        Example::
            >>> torch.empty((2, 3, 5, 7)).dim_order()
            (0, 1, 2, 3)
            >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order()
            (0, 2, 3, 1)

        .. warning::
            The dim_order tensor API is experimental and subject to change.

        """
        import torch._prims_common as utils

        return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))

    def _update_names(self, names, inplace):
        if has_torch_function_unary(self):
            return handle_torch_function(
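Conceptually, `dim_order` is the permutation that sorts a tensor's dimensions from largest to smallest stride, i.e. from outermost to innermost. The sketch below is only an illustrative approximation of the helper used above; the real `compute_elementwise_output_logical_to_physical_perm` also resolves ambiguities that a plain stride sort cannot, such as size-1 dimensions and equal strides:

def naive_dim_order(t):
    # Sort logical dim indices by decreasing stride. Python's sort is
    # stable, so tied strides (e.g. from size-1 dims) keep their logical
    # order; this only approximates the real tie-breaking rules.
    return tuple(sorted(range(t.dim()), key=t.stride, reverse=True))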
2 changes: 2 additions & 0 deletions torch/_torch_docs.py
@@ -12437,6 +12437,8 @@ def merge_dicts(*dicts):
(105, 1, 21, 3)
>>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).stride()
(105, 1, 21, 3)
>>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).dim_order()
(0, 2, 3, 1)
""".format(
**factory_common_args
),
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -322,6 +322,7 @@ def get_ignored_functions() -> Set[Callable]:
    Tensor._is_any_true,
    Tensor._addmm_activation,
    Tensor.to_padded_tensor,
    Tensor.dim_order,
}

