Allow align_to to take in partially named tensors (#27308)
Summary:
Pull Request resolved: #27308

Currently, `tensor.align_to(*names)` has the restriction that the
`tensor` must be fully named. This doesn't need to be the case: when
using Ellipsis, we "expand the ellipsis to all unmentioned dimensions,
in the order in which they appear in the original tensor".

For example, consider `tensor: Tensor[None, None, C]`.

`tensor.align_to('C', None, None)` is ambiguous because the user might
have wanted to switch the order of the None dimensions, and there is no
way to specify that using this API. However, `tensor.align_to('C', ...)`
isn't ambiguous: we can select the two unnamed dimensions in the order
in which they appear.
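
For instance, with this change a call like the following goes through end to
end (a minimal sketch; the tensor construction and sizes here are chosen for
illustration and are not taken from the PR's test suite):

```python
import torch

# A 3-D tensor where only the last dimension is named.
x = torch.randn(2, 3, 5, names=(None, None, 'C'))

# 'C' moves to the front; '...' expands to the unnamed dims in the
# order they appear in the original tensor, so the call is unambiguous.
y = x.align_to('C', '...')
print(y.names)   # ('C', None, None)
print(y.shape)   # torch.Size([5, 2, 3])
```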

To actually implement this, we write a brand-new `align_to(names,
ellipsis_idx)` function in C++ that is separate from the regular
`align_to(names)` implementation. Ideally we would support "..." as a
special name in C++ and combine the two implementations; we'll need to
support "..." in C++ in the future, but that requires a bit of extra work.
In this PR, Python processes the ellipsis and then calls the correct
overload.
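
Roughly, the Python side does the following before dispatching (a sketch of
the preprocessing, not the literal implementation in `torch/tensor.py`):

```python
# Sketch: given the names passed to tensor.align_to(*names), e.g.
names = ('C', '...')
ellipsis_idx = names.index('...')              # position of the ellipsis: 1
order = [n for n in names if n != '...']       # ['C'], ellipsis stripped out
# With an ellipsis, the new C++ overload align_to(order, ellipsis_idx) is
# called; without one, the original align_to(names) overload is used.
```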

Test Plan: - run tests

Differential Revision: D17745179

Pulled By: zou3519

fbshipit-source-id: 9fed06d224215cfb7efecd8c002604baab3c45e6
zou3519 authored and facebook-github-bot committed Oct 9, 2019
1 parent 7591010 commit 0fbbc7a
Showing 7 changed files with 170 additions and 52 deletions.
35 changes: 19 additions & 16 deletions aten/src/ATen/core/NamedTensor.cpp
@@ -42,12 +42,6 @@ DimnameList default_names(size_t len) {
return DimnameList(&all_unnamed.front(), len);
}

void check_names_valid_for(const Tensor& tensor, DimnameList names) {
return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names);
}

namespace impl {

static void check_unique_names(DimnameList names) {
// Strategy: Compare each element with the ones that come after it.
// Although this is O(N^2), in practice N is small (no more than 25).
@@ -62,6 +56,24 @@ static void check_unique_names(DimnameList names) {
}
}

void check_names_valid_for(const Tensor& tensor, DimnameList names) {
return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names);
}

void check_names_valid_for(int64_t tensor_dim, DimnameList names) {
TORCH_CHECK(
tensor_dim <= kMaxNamedTensorDim,
"Named tensors only support up to ", kMaxNamedTensorDim, " dims: "
"Attempted to create a tensor with dim ", tensor_dim, " with names ", names);
TORCH_CHECK(tensor_dim == names.size(),
"Number of names (", names.size(), ") and "
"number of dimensions in tensor (", tensor_dim, ") ",
"do not match. Attempted to create a tensor with names ", names);
check_unique_names(names);
}

namespace impl {

static NamedTensorMeta* get_named_tensor_meta(TensorImpl* impl) {
if (!NamesMode::is_enabled()) {
return nullptr;
@@ -77,16 +89,7 @@ static const NamedTensorMeta* get_named_tensor_meta(const TensorImpl* impl) {
}

void check_names_valid_for(TensorImpl* impl, DimnameList names) {
auto ndim = impl->dim();
TORCH_CHECK(
ndim <= kMaxNamedTensorDim,
"Named tensors only support up to ", kMaxNamedTensorDim, " dims: "
"Attempted to create a tensor with dim ", ndim, " with names ", names);
TORCH_CHECK(ndim == names.size(),
"Number of names (", names.size(), ") and "
"number of dimensions in tensor (", ndim, ") ",
"do not match. Attempted to create a tensor with names ", names);
check_unique_names(names);
check_names_valid_for(impl->dim(), names);
}

void internal_set_names_inplace(TensorImpl* impl, optional<DimnameList> names) {
1 change: 1 addition & 0 deletions aten/src/ATen/core/NamedTensor.h
@@ -75,6 +75,7 @@ struct CAFFE2_API NoNamesGuard {
};

void check_names_valid_for(const Tensor& tensor, DimnameList names);
void check_names_valid_for(int64_t tensor_dim, DimnameList names);

// Sets the names of `tensor` to be `names`.
CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, optional<DimnameList> names);
113 changes: 113 additions & 0 deletions aten/src/ATen/native/NamedTensor.cpp
@@ -4,6 +4,8 @@
#include <ATen/NamedTensorUtils.h>
#include <ATen/core/EnableNamedTensor.h>

#include <bitset>

#ifdef BUILD_NAMEDTENSOR
namespace at { namespace native {

@@ -141,6 +143,117 @@ static Tensor align(const Tensor& tensor, DimnameList names, bool is_aligning_tw
return result;
}

static int64_t countUnset(std::bitset<kMaxNamedTensorDim> set, int64_t up_to_idx) {
int64_t result = 0;
for (auto i = 0; i < up_to_idx; ++i) {
if (!set.test(i)) result++;
}
return result;
}

// Handles `tensor.align_to(*order)` in the case where there is an ellipsis.
//
// Let tensor: Tensor[N, C, H, W]. Consider `tensor.align_to('W', ..., 'N')`
// We expand the `...` to "all unmentioned dimensions, in the order which they
// appear in the original tensor."
//
// `order` is passed in **without** the ellipsis name. This is because ellipsis
// is not a valid name in cpp right now. Future work should be done on making
// ellipsis a valid name.
//
// `ellipsis_idx` is where the ellipsis occurs in the Python call.
// In our example, `tensor.align_to('W', ..., 'N')`, order = ['W', 'N'] and
// ellipsis_idx = 1.
Tensor align_to(const Tensor& tensor, DimnameList order, int64_t ellipsis_idx) {
const auto tensor_names = tensor.names();
const auto tensor_sizes = tensor.sizes();
const auto tensor_strides = tensor.strides();
const auto tensor_dim = tensor.sizes().size();
constexpr int64_t not_found = -1;

// General strategy.
//
// Step 1: We compute the following 3 things:
// 1. How many names the ellipsis should expand to
// 2. Which names in `tensor.names` are not mentioned in `order`.
// 3. Where names in `order` occur in tensor, if at all.
//
// Step 2: Compute the new sizes/strides/names.
// First, determine the ndim of the output tensor (this is not obvious)
// by counting the number of names in `tensor` that are not in `order`.
// Next, fill in output sizes/strides/names by using `order` and knowledge
// of which dimensions in `tensor` are unmentioned in `order`.

std::bitset<kMaxNamedTensorDim> order_has_tensor_name;

// tensor_idx_for[i] = j means that the ith name in `order`
// appears in the jth element of tensor.
std::vector<int64_t> tensor_idx_for(order.size(), not_found);

for (auto order_idx = 0; order_idx < order.size(); ++order_idx) {
const auto name = order[order_idx];
TORCH_CHECK(name.isBasic(),
"align_to: the desired order of dimensions cannot contain a None name, got ",
order);
auto it = std::find(tensor_names.begin(), tensor_names.end(), name);
if (it == tensor_names.end()) {
continue;
}
auto idx_in_tensor = std::distance(tensor_names.begin(), it);
tensor_idx_for[order_idx] = idx_in_tensor;
order_has_tensor_name.set(idx_in_tensor);
}

const auto num_ellipsis_names = countUnset(order_has_tensor_name, tensor_dim);
const auto out_dim = num_ellipsis_names + order.size();

// Step 2: Now that we know the size of the output tensor, we can use the
// metadata obtained from Step 1 to fill in the new sizes/strides/names
std::vector<int64_t> new_sizes(out_dim, 1);
std::vector<int64_t> new_strides(out_dim, 0);
std::vector<Dimname> new_names(out_dim, Dimname::wildcard());

auto setNewSizesStridesNamesFor = [&](int64_t out_dim, int64_t tensor_dim) {
new_sizes[out_dim] = tensor_sizes[tensor_dim];
new_strides[out_dim] = tensor_strides[tensor_dim];
new_names[out_dim] = tensor_names[tensor_dim];
};

// Fill in the non-ellipsis dimensions
for (auto order_idx = 0; order_idx < order.size(); ++order_idx) {
auto out_idx = order_idx;
if (order_idx >= ellipsis_idx) {
out_idx = order_idx + num_ellipsis_names;
}
const auto tensor_idx = tensor_idx_for[order_idx];
if (tensor_idx == not_found) {
// We are adding a new size-one dimension
new_names[out_idx] = order[order_idx];
continue;
}
setNewSizesStridesNamesFor(out_idx, tensor_idx);
}

// Fill in the ellipsis dimensions
for (auto tensor_idx = 0; tensor_idx < tensor_dim; ++tensor_idx) {
if (order_has_tensor_name.test(tensor_idx)) {
continue;
}
setNewSizesStridesNamesFor(ellipsis_idx, tensor_idx);
ellipsis_idx++;
}

check_names_valid_for(out_dim, new_names);

Tensor result;
{
NoNamesGuard guard;
result = tensor.as_strided(new_sizes, new_strides);
}
internal_set_names_inplace(result, std::move(new_names), /*validate_names=*/false);
return result;
}

Tensor align_to(const Tensor& tensor, DimnameList names) {
auto tensor_names = tensor.names();
auto tensor_sizes = tensor.sizes();
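
For readers who prefer Python, the ellipsis-handling strategy above amounts to
roughly the following sketch (function and variable names are illustrative
only; the real implementation also carries sizes and strides and validates the
resulting names):

```python
def align_names_with_ellipsis(tensor_names, order, ellipsis_idx):
    """Compute the output dimension names of align_to(*order), where the
    ellipsis has been stripped from `order` and its position recorded in
    `ellipsis_idx`. Names-only sketch; entries may be None for unnamed dims."""
    # Step 1: record which tensor dims are mentioned by `order`.
    mentioned = [False] * len(tensor_names)
    for name in order:
        if name is None:
            raise RuntimeError("the desired order of dimensions cannot contain a None name")
        if name in tensor_names:
            mentioned[tensor_names.index(name)] = True

    # The ellipsis expands to every unmentioned tensor dim, in original order.
    ellipsis_names = [n for i, n in enumerate(tensor_names) if not mentioned[i]]
    out_dim = len(order) + len(ellipsis_names)

    # Step 2: place each name from `order` on either side of the ellipsis
    # block, then fill in the ellipsis block itself. Names in `order` that
    # do not occur in the tensor become new size-one dimensions.
    new_names = [None] * out_dim
    for order_idx, name in enumerate(order):
        out_idx = order_idx if order_idx < ellipsis_idx else order_idx + len(ellipsis_names)
        new_names[out_idx] = name
    new_names[ellipsis_idx:ellipsis_idx + len(ellipsis_names)] = ellipsis_names
    return new_names

# Mirrors the example in the comment above: tensor named [N, C, H, W],
# called as tensor.align_to('W', ..., 'N') -> order=['W', 'N'], ellipsis_idx=1
print(align_names_with_ellipsis(['N', 'C', 'H', 'W'], ['W', 'N'], 1))
# ['W', 'C', 'H', 'N']
```
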
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -69,6 +69,10 @@
variants: method
supports_named_tensor: True

- func: align_to(Tensor(a) self, DimnameList order, int ellipsis_idx) -> Tensor(a)
variants: method
supports_named_tensor: True

- func: align_as(Tensor self, Tensor other) -> Tensor
use_c10_dispatcher: unboxed_only
variants: method
19 changes: 14 additions & 5 deletions test/test_namedtensor.py
@@ -1366,6 +1366,11 @@ def test_align_to_ellipsis(self):
self.assertEqual(output.names, ['H', 'C', 'W', 'N'])
self.assertEqual(output.shape, [3, 2, 5, 7])

# ... = ['N', 'W']
output = tensor.align_to('H', 'C', '...')
self.assertEqual(output.names, ['H', 'C', 'N', 'W'])
self.assertEqual(output.shape, [3, 2, 7, 5])

# ... = ['H', 'C']
output = tensor.align_to('W', '...', 'N')
self.assertEqual(output.names, ['W', 'H', 'C', 'N'])
@@ -1377,16 +1382,20 @@ def test_align_to_ellipsis(self):
self.assertEqual(output.shape, [7, 2, 1, 3, 5])

# Input tensor partially named
partiall_named = create('N:7,None:1')
with self.assertRaisesRegex(RuntimeError, "All input dims must be named"):
partiall_named.align_to('...', 'N')
partially_named = create('None:2,None:3,None:5,C:7')
output = partially_named.align_to('C', '...')
self.assertEqual(output.names, ['C', None, None, None])
self.assertEqual(output.shape, [7, 2, 3, 5])

with self.assertRaisesRegex(RuntimeError, "order of dimensions cannot contain a None"):
partially_named.align_to('C', None, '...')

# Input order partially named
with self.assertRaisesRegex(RuntimeError, "desired order must not contain None"):
with self.assertRaisesRegex(RuntimeError, "cannot contain a None name"):
tensor.align_to('...', 'N', None)

# Input order duplicate names
with self.assertRaisesRegex(RuntimeError, "Duplicate names"):
with self.assertRaisesRegex(RuntimeError, "duplicate names"):
tensor.align_to('...', 'N', 'N')

def test_align_as(self):
42 changes: 13 additions & 29 deletions torch/_namedtensor_internals.py
@@ -54,6 +54,15 @@ def is_ellipsis(item):
else:
return item == Ellipsis or item == '...'

def single_ellipsis_index(names, fn_name):
ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)]
if len(ellipsis_indices) >= 2:
raise RuntimeError('{}: More than one Ellipsis (\'...\') found in names ('
'{}). This function supports up to one Ellipsis.'
.format(fn_name, names))
if len(ellipsis_indices) == 1:
return ellipsis_indices[0]
return None

def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names):
return names[numel_pre_glob:len(names) - numel_post_glob]
@@ -64,39 +73,14 @@ def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names):
return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1:]


def replace_ellipsis_with_missing_names(ellipsis_idx, names, tensor_names, fn_name):
if any([dimname is None for dimname in tensor_names]):
raise RuntimeError(
'{}: All input dims must be named, got tensor with dims: {}. '
'Please use `tensor.refine_names(*names)` to add names to '
'unnamed dims'.format(fn_name, tensor_names))
if any([dimname is None for dimname in names]):
raise RuntimeError('{}: desired order must not contain None, got: {}.'
.format(fn_name, names))
desired_ordering_set = set(names)
if len(desired_ordering_set) != len(names):
raise RuntimeError('{}: Duplicate names are not allowed in desired ordering, got: {}.'
.format(fn_name, names))
missing_names = tuple([name for name in tensor_names if name not in desired_ordering_set])
return names[:ellipsis_idx] + missing_names + names[ellipsis_idx + 1:]


def resolve_ellipsis(names, tensor_names, fn_name, is_positional=True):
def resolve_ellipsis(names, tensor_names, fn_name):
"""
Expands ... inside `names` to be equal to a list of names from `tensor_names`.
"""
ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)]
if len(ellipsis_indices) >= 2:
raise RuntimeError('{}: More than one Ellipsis (\'...\') found in names ('
'{}). This function supports up to one Ellipsis.'
.format(fn_name, names))
if len(ellipsis_indices) == 0:
ellipsis_idx = single_ellipsis_index(names, fn_name)
if ellipsis_idx is None:
return names
ellipsis_idx = ellipsis_indices[0]
if is_positional:
return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names)
else:
return replace_ellipsis_with_missing_names(ellipsis_idx, names, tensor_names, fn_name)
return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names)


def update_names_with_list(tensor, names, inplace):
8 changes: 6 additions & 2 deletions torch/tensor.py
@@ -2,7 +2,7 @@
import torch
import torch._C as _C
from torch._namedtensor_internals import update_names, check_serializing_named_tensor, resolve_ellipsis
from torch._namedtensor_internals import unzip_namedshape
from torch._namedtensor_internals import unzip_namedshape, single_ellipsis_index, is_ellipsis
from collections import OrderedDict
import torch.utils.hooks as hooks
import warnings
@@ -599,8 +599,12 @@ def align_to(self, *names):
The named tensor API is experimental and subject to change.
"""
ellipsis_idx = single_ellipsis_index(names, 'align_to')
if ellipsis_idx is None:
return super(Tensor, self).align_to(names)
return super(Tensor, self).align_to(
resolve_ellipsis(names, self.names, 'align_to', is_positional=False))
[name for name in names if not is_ellipsis(name)],
ellipsis_idx)

def unflatten(self, dim, namedshape):
r"""Unflattens the named dimension :attr:`dim`, viewing it in the shape