diff --git a/aten/src/ATen/core/NamedTensor.cpp b/aten/src/ATen/core/NamedTensor.cpp
index 1d078c9e5449725..fe27b5ef6837ab9 100644
--- a/aten/src/ATen/core/NamedTensor.cpp
+++ b/aten/src/ATen/core/NamedTensor.cpp
@@ -42,12 +42,6 @@ DimnameList default_names(size_t len) {
   return DimnameList(&all_unnamed.front(), len);
 }
 
-void check_names_valid_for(const Tensor& tensor, DimnameList names) {
-  return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names);
-}
-
-namespace impl {
-
 static void check_unique_names(DimnameList names) {
   // Strategy: Compare each element with the ones that come after it.
   // Although this is O(N^2), in practice N is small (no more than 25).
@@ -62,6 +56,24 @@ static void check_unique_names(DimnameList names) {
   }
 }
 
+void check_names_valid_for(const Tensor& tensor, DimnameList names) {
+  return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names);
+}
+
+void check_names_valid_for(int64_t tensor_dim, DimnameList names) {
+  TORCH_CHECK(
+      tensor_dim <= kMaxNamedTensorDim,
+      "Named tensors only support up to ", kMaxNamedTensorDim, " dims: "
+      "Attempted to create a tensor with dim ", tensor_dim, " with names ", names);
+  TORCH_CHECK(tensor_dim == names.size(),
+      "Number of names (", names.size(), ") and "
+      "number of dimensions in tensor (", tensor_dim, ") ",
+      "do not match. Attempted to create a tensor with names ", names);
+  check_unique_names(names);
+}
+
+namespace impl {
+
 static NamedTensorMeta* get_named_tensor_meta(TensorImpl* impl) {
   if (!NamesMode::is_enabled()) {
     return nullptr;
@@ -77,16 +89,7 @@ static const NamedTensorMeta* get_named_tensor_meta(const TensorImpl* impl) {
 }
 
 void check_names_valid_for(TensorImpl* impl, DimnameList names) {
-  auto ndim = impl->dim();
-  TORCH_CHECK(
-      ndim <= kMaxNamedTensorDim,
-      "Named tensors only support up to ", kMaxNamedTensorDim, " dims: "
-      "Attempted to create a tensor with dim ", ndim, " with names ", names);
-  TORCH_CHECK(ndim == names.size(),
-      "Number of names (", names.size(), ") and "
-      "number of dimensions in tensor (", ndim, ") ",
-      "do not match. Attempted to create a tensor with names ", names);
-  check_unique_names(names);
+  check_names_valid_for(impl->dim(), names);
 }
 
 void internal_set_names_inplace(TensorImpl* impl, optional<DimnameList> names) {
diff --git a/aten/src/ATen/core/NamedTensor.h b/aten/src/ATen/core/NamedTensor.h
index e9d51cdf994c96c..46644e3ecd755a7 100644
--- a/aten/src/ATen/core/NamedTensor.h
+++ b/aten/src/ATen/core/NamedTensor.h
@@ -75,6 +75,7 @@ struct CAFFE2_API NoNamesGuard {
 };
 
 void check_names_valid_for(const Tensor& tensor, DimnameList names);
+void check_names_valid_for(int64_t tensor_dim, DimnameList names);
 
 // Sets the names of `tensor` to be `names`.
 CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, optional<DimnameList> names);
diff --git a/aten/src/ATen/native/NamedTensor.cpp b/aten/src/ATen/native/NamedTensor.cpp
index 9fa4c1e65aeae52..d76c2526cea69cc 100644
--- a/aten/src/ATen/native/NamedTensor.cpp
+++ b/aten/src/ATen/native/NamedTensor.cpp
@@ -4,6 +4,8 @@
 #include
 #include
 
+#include <bitset>
+
 #ifdef BUILD_NAMEDTENSOR
 namespace at { namespace native {
 
@@ -141,6 +143,117 @@ static Tensor align(const Tensor& tensor, DimnameList names, bool is_aligning_two_tensors) {
   return result;
 }
 
+static int64_t countUnset(std::bitset<kMaxNamedTensorDim> set, int64_t up_to_idx) {
+  int64_t result = 0;
+  for (auto i = 0; i < up_to_idx; ++i) {
+    if (!set.test(i)) result++;
+  }
+  return result;
+}
+
+// Handles `tensor.align_to(*order)` in the case where there is an ellipsis.
+//
+// Let tensor: Tensor[N, C, H, W]. Consider `tensor.align_to('W', ..., 'N')`.
+// We expand the `...` to "all unmentioned dimensions, in the order in which
+// they appear in the original tensor."
+//
+// `order` is passed in **without** the ellipsis name. This is because ellipsis
+// is not a valid name in cpp right now. Future work should be done on making
+// ellipsis a valid name.
+//
+// `ellipsis_idx` is where the ellipsis occurs in the Python call.
+// In our example, `tensor.align_to('W', ..., 'N')`, order = ['W', 'N'] and
+// ellipsis_idx = 1.
+Tensor align_to(const Tensor& tensor, DimnameList order, int64_t ellipsis_idx) {
+  const auto tensor_names = tensor.names();
+  const auto tensor_sizes = tensor.sizes();
+  const auto tensor_strides = tensor.strides();
+  const auto tensor_dim = tensor.sizes().size();
+  constexpr int64_t not_found = -1;
+
+  // General strategy.
+  //
+  // Step 1: We compute the following 3 things:
+  //   1. How many names the ellipsis should expand to.
+  //   2. Which names in `tensor.names` are not mentioned in `order`.
+  //   3. Where the names in `order` occur in `tensor`, if at all.
+  //
+  // Step 2: Compute the new sizes/strides/names.
+  //   First, determine the ndim of the output tensor (this is not obvious)
+  //   by counting the number of names in `tensor` that are not in `order`.
+  //   Next, fill in output sizes/strides/names by using `order` and knowledge
+  //   of which dimensions in `tensor` are unmentioned in `order`.
+
+  std::bitset<kMaxNamedTensorDim> order_has_tensor_name;
+
+  // tensor_idx_for[i] = j means that the i-th name in `order`
+  // appears in the j-th element of tensor.
+  std::vector<int64_t> tensor_idx_for(order.size(), not_found);
+
+  for (auto order_idx = 0; order_idx < order.size(); ++order_idx) {
+    const auto name = order[order_idx];
+    TORCH_CHECK(name.isBasic(),
+        "align_to: the desired order of dimensions cannot contain a None name, got ",
+        order);
+    auto it = std::find(tensor_names.begin(), tensor_names.end(), name);
+    if (it == tensor_names.end()) {
+      continue;
+    }
+    auto idx_in_tensor = std::distance(tensor_names.begin(), it);
+    tensor_idx_for[order_idx] = idx_in_tensor;
+    order_has_tensor_name.set(idx_in_tensor);
+  }
+
+  const auto num_ellipsis_names = countUnset(order_has_tensor_name, tensor_dim);
+  const auto out_dim = num_ellipsis_names + order.size();
+
+  // Step 2: Now that we know the size of the output tensor, we can use the
+  // metadata obtained from Step 1 to fill in the new sizes/strides/names.
+  std::vector<int64_t> new_sizes(out_dim, 1);
+  std::vector<int64_t> new_strides(out_dim, 0);
+  std::vector<Dimname> new_names(out_dim, Dimname::wildcard());
+
+  auto setNewSizesStridesNamesFor = [&](int64_t out_dim, int64_t tensor_dim) {
+    new_sizes[out_dim] = tensor_sizes[tensor_dim];
+    new_strides[out_dim] = tensor_strides[tensor_dim];
+    new_names[out_dim] = tensor_names[tensor_dim];
+  };
+
+  // Fill in the non-ellipsis dimensions.
+  for (auto order_idx = 0; order_idx < order.size(); ++order_idx) {
+    auto out_idx = order_idx;
+    if (order_idx >= ellipsis_idx) {
+      out_idx = order_idx + num_ellipsis_names;
+    }
+    const auto tensor_idx = tensor_idx_for[order_idx];
+    if (tensor_idx == not_found) {
+      // We are adding a new size-one dimension.
+      new_names[out_idx] = order[order_idx];
+      continue;
+    }
+    setNewSizesStridesNamesFor(out_idx, tensor_idx);
+  }
+
+  // Fill in the ellipsis dimensions.
+  for (auto tensor_idx = 0; tensor_idx < tensor_dim; ++tensor_idx) {
+    if (order_has_tensor_name.test(tensor_idx)) {
+      continue;
+    }
+    setNewSizesStridesNamesFor(ellipsis_idx, tensor_idx);
+    ellipsis_idx++;
+  }
+
+  check_names_valid_for(out_dim, new_names);
+
+  Tensor result;
+  {
+    NoNamesGuard guard;
+    result = tensor.as_strided(new_sizes, new_strides);
+  }
+  internal_set_names_inplace(result, std::move(new_names), /*validate_names=*/false);
+  return result;
+}
+
 Tensor align_to(const Tensor& tensor, DimnameList names) {
   auto tensor_names = tensor.names();
   auto tensor_sizes = tensor.sizes();
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index bb812f3adc86b81..c13f1011581151a 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -69,6 +69,10 @@
   variants: method
   supports_named_tensor: True
 
+- func: align_to(Tensor(a) self, DimnameList order, int ellipsis_idx) -> Tensor(a)
+  variants: method
+  supports_named_tensor: True
+
 - func: align_as(Tensor self, Tensor other) -> Tensor
   use_c10_dispatcher: unboxed_only
   variants: method
diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py
index 015ab594fa08ab0..7f4c70ea55711a9 100644
--- a/test/test_namedtensor.py
+++ b/test/test_namedtensor.py
@@ -1366,6 +1366,11 @@ def test_align_to_ellipsis(self):
         self.assertEqual(output.names, ['H', 'C', 'W', 'N'])
         self.assertEqual(output.shape, [3, 2, 5, 7])
 
+        # ... = ['N', 'W']
+        output = tensor.align_to('H', 'C', '...')
+        self.assertEqual(output.names, ['H', 'C', 'N', 'W'])
+        self.assertEqual(output.shape, [3, 2, 7, 5])
+
         # ... = ['H', 'C']
         output = tensor.align_to('W', '...', 'N')
         self.assertEqual(output.names, ['W', 'H', 'C', 'N'])
@@ -1377,16 +1382,20 @@
         self.assertEqual(output.shape, [7, 2, 1, 3, 5])
 
         # Input tensor partially named
-        partiall_named = create('N:7,None:1')
-        with self.assertRaisesRegex(RuntimeError, "All input dims must be named"):
-            partiall_named.align_to('...', 'N')
+        partially_named = create('None:2,None:3,None:5,C:7')
+        output = partially_named.align_to('C', '...')
+        self.assertEqual(output.names, ['C', None, None, None])
+        self.assertEqual(output.shape, [7, 2, 3, 5])
+
+        with self.assertRaisesRegex(RuntimeError, "order of dimensions cannot contain a None"):
+            partially_named.align_to('C', None, '...')
 
         # Input order partially named
-        with self.assertRaisesRegex(RuntimeError, "desired order must not contain None"):
+        with self.assertRaisesRegex(RuntimeError, "cannot contain a None name"):
             tensor.align_to('...', 'N', None)
 
         # Input order duplicate names
-        with self.assertRaisesRegex(RuntimeError, "Duplicate names"):
+        with self.assertRaisesRegex(RuntimeError, "duplicate names"):
             tensor.align_to('...', 'N', 'N')
 
     def test_align_as(self):
diff --git a/torch/_namedtensor_internals.py b/torch/_namedtensor_internals.py
index ccca565c4a704d9..124fd9931781890 100644
--- a/torch/_namedtensor_internals.py
+++ b/torch/_namedtensor_internals.py
@@ -54,6 +54,15 @@ def is_ellipsis(item):
     else:
         return item == Ellipsis or item == '...'
 
+def single_ellipsis_index(names, fn_name):
+    ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)]
+    if len(ellipsis_indices) >= 2:
+        raise RuntimeError('{}: More than one Ellipsis (\'...\') found in names ('
+                           '{}). This function supports up to one Ellipsis.'
+                           .format(fn_name, names))
+    if len(ellipsis_indices) == 1:
+        return ellipsis_indices[0]
+    return None
 
 def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names):
     return names[numel_pre_glob:len(names) - numel_post_glob]
@@ -64,39 +73,14 @@ def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names):
     return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1:]
 
 
-def replace_ellipsis_with_missing_names(ellipsis_idx, names, tensor_names, fn_name):
-    if any([dimname is None for dimname in tensor_names]):
-        raise RuntimeError(
-            '{}: All input dims must be named, got tensor with dims: {}. '
-            'Please use `tensor.refine_names(*names)` to add names to '
-            'unnamed dims'.format(fn_name, tensor_names))
-    if any([dimname is None for dimname in names]):
-        raise RuntimeError('{}: desired order must not contain None, got: {}.'
-                           .format(fn_name, names))
-    desired_ordering_set = set(names)
-    if len(desired_ordering_set) != len(names):
-        raise RuntimeError('{}: Duplicate names are not allowed in desired ordering, got: {}.'
-                           .format(fn_name, names))
-    missing_names = tuple([name for name in tensor_names if name not in desired_ordering_set])
-    return names[:ellipsis_idx] + missing_names + names[ellipsis_idx + 1:]
-
-
-def resolve_ellipsis(names, tensor_names, fn_name, is_positional=True):
+def resolve_ellipsis(names, tensor_names, fn_name):
     """
    Expands ... inside `names` to be equal to a list of names from `tensor_names`.
     """
-    ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)]
-    if len(ellipsis_indices) >= 2:
-        raise RuntimeError('{}: More than one Ellipsis (\'...\') found in names ('
-                           '{}). This function supports up to one Ellipsis.'
-                           .format(fn_name, names))
-    if len(ellipsis_indices) == 0:
+    ellipsis_idx = single_ellipsis_index(names, fn_name)
+    if ellipsis_idx is None:
         return names
-    ellipsis_idx = ellipsis_indices[0]
-    if is_positional:
-        return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names)
-    else:
-        return replace_ellipsis_with_missing_names(ellipsis_idx, names, tensor_names, fn_name)
+    return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names)
 
 
 def update_names_with_list(tensor, names, inplace):
diff --git a/torch/tensor.py b/torch/tensor.py
index f287d9da290eeb1..618efeb3b875ca4 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -2,7 +2,7 @@
 import torch
 import torch._C as _C
 from torch._namedtensor_internals import update_names, check_serializing_named_tensor, resolve_ellipsis
-from torch._namedtensor_internals import unzip_namedshape
+from torch._namedtensor_internals import unzip_namedshape, single_ellipsis_index, is_ellipsis
 from collections import OrderedDict
 import torch.utils.hooks as hooks
 import warnings
@@ -599,8 +599,12 @@ def align_to(self, *names):
 
         The named tensor API is experimental and subject to change.
         """
+        ellipsis_idx = single_ellipsis_index(names, 'align_to')
+        if ellipsis_idx is None:
+            return super(Tensor, self).align_to(names)
         return super(Tensor, self).align_to(
-            resolve_ellipsis(names, self.names, 'align_to', is_positional=False))
+            [name for name in names if not is_ellipsis(name)],
+            ellipsis_idx)
 
     def unflatten(self, dim, namedshape):
         r"""Unflattens the named dimension :attr:`dim`, viewing it in the shape
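Usage sketch: the snippet below illustrates the ellipsis behavior this diff adds,
mirroring the cases in `test_align_to_ellipsis`. The tensor (its names and sizes)
is hypothetical and chosen only for illustration; only the ordering semantics are
taken from the tests above.

    import torch

    # Hypothetical 4-D named tensor (sizes are arbitrary).
    t = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))

    # '...' expands to every dimension not mentioned in the call, in the order
    # those dimensions appear in `t` (here: C, then H).
    out = t.align_to('W', '...', 'N')
    print(out.names)   # ('W', 'C', 'H', 'N')
    print(out.shape)   # torch.Size([7, 3, 5, 2])

    # Naming a dimension that `t` does not have inserts it as a new size-1 dim.
    out = t.align_to('N', 'D', '...')
    print(out.names)   # ('N', 'D', 'C', 'H', 'W')
    print(out.shape)   # torch.Size([2, 1, 3, 5, 7])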