From 0fbbc7acb4deb068ed383445c2e73cb08f66029a Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Wed, 9 Oct 2019 16:27:09 -0700 Subject: [PATCH] Allow `align_to` to take in partially named tensors (#27308) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27308 Currently, `tensor.align_to(*names)` has the restriction that the `tensor` must be fully named. This doesn't need to be the case, when using Ellipsis, we "expand the ellipsis to all unmentioned dimensions, in the order which they appear in the original tensor". For example, consider `tensor: Tensor[None, None, C]`. `tensor.align_to(C, None, None)` is ambiguous because the user might have wanted to switch the order of the None dimensions and there is no way to specify that using this API. However, `tensor.align_to('C', ...)` isn't ambiguous: we can select the two unnamed dimensions in the order in which they appear. To actually implement this, we write a brand-new `align_to(names, ellipsis_idx)` function in c++ that is separate from the regular `align_to(names)` implementation. Ideally we would support "..." as a special name in c++ and combine the two implementations; we'll need to support "..." in c++ in the future but that requires a bit of extra work. In this PR, Python processees the ellipsis and then calls the correct overload. Test Plan: - run tests Differential Revision: D17745179 Pulled By: zou3519 fbshipit-source-id: 9fed06d224215cfb7efecd8c002604baab3c45e6 --- aten/src/ATen/core/NamedTensor.cpp | 35 ++++--- aten/src/ATen/core/NamedTensor.h | 1 + aten/src/ATen/native/NamedTensor.cpp | 113 +++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 4 + test/test_namedtensor.py | 19 +++- torch/_namedtensor_internals.py | 42 +++----- torch/tensor.py | 8 +- 7 files changed, 170 insertions(+), 52 deletions(-) diff --git a/aten/src/ATen/core/NamedTensor.cpp b/aten/src/ATen/core/NamedTensor.cpp index 1d078c9e5449725..fe27b5ef6837ab9 100644 --- a/aten/src/ATen/core/NamedTensor.cpp +++ b/aten/src/ATen/core/NamedTensor.cpp @@ -42,12 +42,6 @@ DimnameList default_names(size_t len) { return DimnameList(&all_unnamed.front(), len); } -void check_names_valid_for(const Tensor& tensor, DimnameList names) { - return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names); -} - -namespace impl { - static void check_unique_names(DimnameList names) { // Strategy: Compare each element with the ones that come after it. // Although this is O(N^2), in practice N is small (no more than 25). @@ -62,6 +56,24 @@ static void check_unique_names(DimnameList names) { } } +void check_names_valid_for(const Tensor& tensor, DimnameList names) { + return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names); +} + +void check_names_valid_for(int64_t tensor_dim, DimnameList names) { + TORCH_CHECK( + tensor_dim <= kMaxNamedTensorDim, + "Named tensors only support up to ", kMaxNamedTensorDim, " dims: " + "Attempted to create a tensor with dim ", tensor_dim, " with names ", names); + TORCH_CHECK(tensor_dim == names.size(), + "Number of names (", names.size(), ") and " + "number of dimensions in tensor (", tensor_dim, ") ", + "do not match. Attempted to create a tensor with names ", names); + check_unique_names(names); +} + +namespace impl { + static NamedTensorMeta* get_named_tensor_meta(TensorImpl* impl) { if (!NamesMode::is_enabled()) { return nullptr; @@ -77,16 +89,7 @@ static const NamedTensorMeta* get_named_tensor_meta(const TensorImpl* impl) { } void check_names_valid_for(TensorImpl* impl, DimnameList names) { - auto ndim = impl->dim(); - TORCH_CHECK( - ndim <= kMaxNamedTensorDim, - "Named tensors only support up to ", kMaxNamedTensorDim, " dims: " - "Attempted to create a tensor with dim ", ndim, " with names ", names); - TORCH_CHECK(ndim == names.size(), - "Number of names (", names.size(), ") and " - "number of dimensions in tensor (", ndim, ") ", - "do not match. Attempted to create a tensor with names ", names); - check_unique_names(names); + check_names_valid_for(impl->dim(), names); } void internal_set_names_inplace(TensorImpl* impl, optional names) { diff --git a/aten/src/ATen/core/NamedTensor.h b/aten/src/ATen/core/NamedTensor.h index e9d51cdf994c96c..46644e3ecd755a7 100644 --- a/aten/src/ATen/core/NamedTensor.h +++ b/aten/src/ATen/core/NamedTensor.h @@ -75,6 +75,7 @@ struct CAFFE2_API NoNamesGuard { }; void check_names_valid_for(const Tensor& tensor, DimnameList names); +void check_names_valid_for(int64_t tensor_dim, DimnameList names); // Sets the names of `tensor` to be `names`. CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, optional names); diff --git a/aten/src/ATen/native/NamedTensor.cpp b/aten/src/ATen/native/NamedTensor.cpp index 9fa4c1e65aeae52..d76c2526cea69cc 100644 --- a/aten/src/ATen/native/NamedTensor.cpp +++ b/aten/src/ATen/native/NamedTensor.cpp @@ -4,6 +4,8 @@ #include #include +#include + #ifdef BUILD_NAMEDTENSOR namespace at { namespace native { @@ -141,6 +143,117 @@ static Tensor align(const Tensor& tensor, DimnameList names, bool is_aligning_tw return result; } +static int64_t countUnset(std::bitset set, int64_t up_to_idx) { + int64_t result = 0; + for (auto i = 0; i < up_to_idx; ++i) { + if (!set.test(i)) result++; + } + return result; +} + +// Handles `tensor.align_to(*order)` in the case where there is an ellipsis. +// +// Let tensor: Tensor[N, C, H, W]. Consider `tensor.align_to('W', ..., 'N')` +// We expand the `...` to "all unmentioned dimensions, in the order which they +// appear in the original tensor." +// +// `order` is passed in **without** the ellipsis name. This is because ellipsis +// is not a valid name in cpp right now. Future work should be done on making +// ellipsis a valid name. +// +// `ellipsis_idx` is where the ellipsis occurs in the Python call. +// In our example, `tensor.align_to('W', ..., 'N')`, order = ['W', 'N'] and +// ellipsis_idx = 1. +Tensor align_to(const Tensor& tensor, DimnameList order, int64_t ellipsis_idx) { + const auto tensor_names = tensor.names(); + const auto tensor_sizes = tensor.sizes(); + const auto tensor_strides = tensor.strides(); + const auto tensor_dim = tensor.sizes().size(); + constexpr int64_t not_found = -1; + + // General strategy. + // + // Step 1: We compute the following 3 things: + // 1. How many names the ellipsis should expand to + // 2. Which names in `tensor.names` are not mentioned in `order`. + // 3. Where names in `order` occur in tensor, if at all. + // + // Step 2: Compute the new sizes/strides/names. + // First, determine the ndim of the output tensor (this is not obvious) + // by counting the number of names in `tensor` that are not in `order`. + // Next, fill in output sizes/strides/names by using `order` and knowledge + // of which dimensions in `tensor` are unmentioned in `order`. + + std::bitset order_has_tensor_name; + + // tensor_idx_for[i] = j means that the ith name in `order` + // appears in the jth element of tensor. + std::vector tensor_idx_for(order.size(), not_found); + + for (auto order_idx = 0; order_idx < order.size(); ++order_idx) { + const auto name = order[order_idx]; + TORCH_CHECK(name.isBasic(), + "align_to: the desired order of dimensions cannot contain a None name, got ", + order); + auto it = std::find(tensor_names.begin(), tensor_names.end(), name); + if (it == tensor_names.end()) { + continue; + } + auto idx_in_tensor = std::distance(tensor_names.begin(), it); + tensor_idx_for[order_idx] = idx_in_tensor; + order_has_tensor_name.set(idx_in_tensor); + } + + const auto num_ellipsis_names = countUnset(order_has_tensor_name, tensor_dim); + const auto out_dim = num_ellipsis_names + order.size(); + + // Step 2: Now that we know the size of the output tensor, we can use the + // metadata obtained from Step 1 to fill in the new sizes/strides/names + std::vector new_sizes(out_dim, 1); + std::vector new_strides(out_dim, 0); + std::vector new_names(out_dim, Dimname::wildcard()); + + auto setNewSizesStridesNamesFor = [&](int64_t out_dim, int64_t tensor_dim) { + new_sizes[out_dim] = tensor_sizes[tensor_dim]; + new_strides[out_dim] = tensor_strides[tensor_dim]; + new_names[out_dim] = tensor_names[tensor_dim]; + }; + + // Fill in the non-ellipsis dimensions + for (auto order_idx = 0; order_idx < order.size(); ++order_idx) { + auto out_idx = order_idx; + if (order_idx >= ellipsis_idx) { + out_idx = order_idx + num_ellipsis_names; + } + const auto tensor_idx = tensor_idx_for[order_idx]; + if (tensor_idx == not_found) { + // We are adding a new size-one dimension + new_names[out_idx] = order[order_idx]; + continue; + } + setNewSizesStridesNamesFor(out_idx, tensor_idx); + } + + // Fill in the ellipsis dimensions + for (auto tensor_idx = 0; tensor_idx < tensor_dim; ++tensor_idx) { + if (order_has_tensor_name.test(tensor_idx)) { + continue; + } + setNewSizesStridesNamesFor(ellipsis_idx, tensor_idx); + ellipsis_idx++; + } + + check_names_valid_for(out_dim, new_names); + + Tensor result; + { + NoNamesGuard guard; + result = tensor.as_strided(new_sizes, new_strides); + } + internal_set_names_inplace(result, std::move(new_names), /*validate_names=*/false); + return result; +} + Tensor align_to(const Tensor& tensor, DimnameList names) { auto tensor_names = tensor.names(); auto tensor_sizes = tensor.sizes(); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index bb812f3adc86b81..c13f1011581151a 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -69,6 +69,10 @@ variants: method supports_named_tensor: True +- func: align_to(Tensor(a) self, DimnameList order, int ellipsis_idx) -> Tensor(a) + variants: method + supports_named_tensor: True + - func: align_as(Tensor self, Tensor other) -> Tensor use_c10_dispatcher: unboxed_only variants: method diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py index 015ab594fa08ab0..7f4c70ea55711a9 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -1366,6 +1366,11 @@ def test_align_to_ellipsis(self): self.assertEqual(output.names, ['H', 'C', 'W', 'N']) self.assertEqual(output.shape, [3, 2, 5, 7]) + # ... = ['N', 'W'] + output = tensor.align_to('H', 'C', '...') + self.assertEqual(output.names, ['H', 'C', 'N', 'W']) + self.assertEqual(output.shape, [3, 2, 7, 5]) + # ... = ['H', 'C'] output = tensor.align_to('W', '...', 'N') self.assertEqual(output.names, ['W', 'H', 'C', 'N']) @@ -1377,16 +1382,20 @@ def test_align_to_ellipsis(self): self.assertEqual(output.shape, [7, 2, 1, 3, 5]) # Input tensor partially named - partiall_named = create('N:7,None:1') - with self.assertRaisesRegex(RuntimeError, "All input dims must be named"): - partiall_named.align_to('...', 'N') + partially_named = create('None:2,None:3,None:5,C:7') + output = partially_named.align_to('C', '...') + self.assertEqual(output.names, ['C', None, None, None]) + self.assertEqual(output.shape, [7, 2, 3, 5]) + + with self.assertRaisesRegex(RuntimeError, "order of dimensions cannot contain a None"): + partially_named.align_to('C', None, '...') # Input order partially named - with self.assertRaisesRegex(RuntimeError, "desired order must not contain None"): + with self.assertRaisesRegex(RuntimeError, "cannot contain a None name"): tensor.align_to('...', 'N', None) # Input order duplicate names - with self.assertRaisesRegex(RuntimeError, "Duplicate names"): + with self.assertRaisesRegex(RuntimeError, "duplicate names"): tensor.align_to('...', 'N', 'N') def test_align_as(self): diff --git a/torch/_namedtensor_internals.py b/torch/_namedtensor_internals.py index ccca565c4a704d9..124fd9931781890 100644 --- a/torch/_namedtensor_internals.py +++ b/torch/_namedtensor_internals.py @@ -54,6 +54,15 @@ def is_ellipsis(item): else: return item == Ellipsis or item == '...' +def single_ellipsis_index(names, fn_name): + ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)] + if len(ellipsis_indices) >= 2: + raise RuntimeError('{}: More than one Ellipsis (\'...\') found in names (' + '{}). This function supports up to one Ellipsis.' + .format(fn_name, names)) + if len(ellipsis_indices) == 1: + return ellipsis_indices[0] + return None def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names): return names[numel_pre_glob:len(names) - numel_post_glob] @@ -64,39 +73,14 @@ def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names): return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1:] -def replace_ellipsis_with_missing_names(ellipsis_idx, names, tensor_names, fn_name): - if any([dimname is None for dimname in tensor_names]): - raise RuntimeError( - '{}: All input dims must be named, got tensor with dims: {}. ' - 'Please use `tensor.refine_names(*names)` to add names to ' - 'unnamed dims'.format(fn_name, tensor_names)) - if any([dimname is None for dimname in names]): - raise RuntimeError('{}: desired order must not contain None, got: {}.' - .format(fn_name, names)) - desired_ordering_set = set(names) - if len(desired_ordering_set) != len(names): - raise RuntimeError('{}: Duplicate names are not allowed in desired ordering, got: {}.' - .format(fn_name, names)) - missing_names = tuple([name for name in tensor_names if name not in desired_ordering_set]) - return names[:ellipsis_idx] + missing_names + names[ellipsis_idx + 1:] - - -def resolve_ellipsis(names, tensor_names, fn_name, is_positional=True): +def resolve_ellipsis(names, tensor_names, fn_name): """ Expands ... inside `names` to be equal to a list of names from `tensor_names`. """ - ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)] - if len(ellipsis_indices) >= 2: - raise RuntimeError('{}: More than one Ellipsis (\'...\') found in names (' - '{}). This function supports up to one Ellipsis.' - .format(fn_name, names)) - if len(ellipsis_indices) == 0: + ellipsis_idx = single_ellipsis_index(names, fn_name) + if ellipsis_idx is None: return names - ellipsis_idx = ellipsis_indices[0] - if is_positional: - return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names) - else: - return replace_ellipsis_with_missing_names(ellipsis_idx, names, tensor_names, fn_name) + return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names) def update_names_with_list(tensor, names, inplace): diff --git a/torch/tensor.py b/torch/tensor.py index f287d9da290eeb1..618efeb3b875ca4 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -2,7 +2,7 @@ import torch import torch._C as _C from torch._namedtensor_internals import update_names, check_serializing_named_tensor, resolve_ellipsis -from torch._namedtensor_internals import unzip_namedshape +from torch._namedtensor_internals import unzip_namedshape, single_ellipsis_index, is_ellipsis from collections import OrderedDict import torch.utils.hooks as hooks import warnings @@ -599,8 +599,12 @@ def align_to(self, *names): The named tensor API is experimental and subject to change. """ + ellipsis_idx = single_ellipsis_index(names, 'align_to') + if ellipsis_idx is None: + return super(Tensor, self).align_to(names) return super(Tensor, self).align_to( - resolve_ellipsis(names, self.names, 'align_to', is_positional=False)) + [name for name in names if not is_ellipsis(name)], + ellipsis_idx) def unflatten(self, dim, namedshape): r"""Unflattens the named dimension :attr:`dim`, viewing it in the shape