Skip to content

Commit

Permalink
Add names= argument to torch.tensor ctor
Browse files Browse the repository at this point in the history
Test Plan
- new tests [namedtensor ci]

ghstack-source-id: 054b84daebfb4fe7c130a9a6d0ccb327a66a71eb
Pull Request resolved: #25424
  • Loading branch information
zou3519 committed Sep 9, 2019
1 parent 74fa539 commit 9159278
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 3 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/NamedTensorUtils.h
Expand Up @@ -38,7 +38,7 @@ namespace namedinference {
// 2) If result has names, then `names` must be equal to result.names
void propagate_names(Tensor& result, optional<DimnameList> names);
void propagate_names(Tensor& result, std::vector<Dimname>&& names, bool validate_names);
void propagate_names(Tensor& result, optional<std::vector<Dimname>>&& maybe_names, bool validate_names);
CAFFE2_API void propagate_names(Tensor& result, optional<std::vector<Dimname>>&& maybe_names, bool validate_names);
void propagate_names(TensorImpl* result, optional<DimnameList> names);
void propagate_names(TensorImpl* result, std::vector<Dimname>&& names, bool validate_names);
void propagate_names(TensorImpl* result, optional<std::vector<Dimname>>&& maybe_names, bool validate_names);
Expand Down
44 changes: 43 additions & 1 deletion test/test_namedtensor.py
@@ -1,5 +1,5 @@
import unittest
from common_utils import TestCase, run_tests
from common_utils import TestCase, run_tests, TEST_NUMPY
from common_cuda import TEST_CUDA
from collections import namedtuple
import itertools
Expand Down Expand Up @@ -343,6 +343,48 @@ def _test(factory, device):
expected = torch.full([1, 2, 3], 2, device=device).names_(*names)
self.assertTensorDataAndNamesEqual(result, expected)

def test_tensor_from_lists(self):
names = ('N', 'C')
tensor = torch.tensor([[1]], names=names)
self.assertEqual(tensor.names, names)

names = ('N',)
tensor = torch.tensor([1], names=names)
self.assertEqual(tensor.names, names)

with self.assertRaisesRegex(RuntimeError, 'Number of names'):
names = ('N', 'C')
tensor = torch.tensor([1], names=names)

@unittest.skipIf(not TEST_NUMPY, "no numpy")
def test_tensor_from_numpy(self):
import numpy as np
arr = np.array([[1]])
names = ('N', 'C')
tensor = torch.tensor([[1]], names=names)
self.assertEqual(tensor.names, names)

def test_tensor_from_tensor(self):
x = torch.randn(1, 1)
names = ('N', 'C')
tensor = torch.tensor(x, names=names)
self.assertEqual(tensor.names, names)

def test_tensor_from_named_tensor(self):
x = torch.randn(1, 1, names=('N', 'D'))
tensor = torch.tensor(x)
self.assertEqual(tensor.names, ('N', 'D'))

# there's no way to distinguish between names=None and not passing in names.
# If the user passes in names=None they are asking for trouble.
x = torch.randn(1, 1, names=('N', 'D'))
tensor = torch.tensor(x, names=None)
self.assertEqual(tensor.names, ('N', 'D'))

x = torch.randn(1, 1, names=('N', 'D'))
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
tensor = torch.tensor(x, names=('N', 'C'))

def test_size(self):
t = torch.empty(2, 3, 5, names=('N', None, 'C'))
self.assertEqual(t.size('N'), 2)
Expand Down
18 changes: 17 additions & 1 deletion torch/csrc/utils/tensor_new.cpp
Expand Up @@ -17,6 +17,7 @@

#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/NamedTensorUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

Expand Down Expand Up @@ -556,10 +557,19 @@ Tensor sparse_coo_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_t

Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
#ifdef BUILD_NAMEDTENSOR
"tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)",
#else
"tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
#endif
});

ParsedArgs<5> parsed_args;
#ifdef BUILD_NAMEDTENSOR
constexpr int ctor_num_args = 6;
#else
constexpr int ctor_num_args = 5;
#endif
ParsedArgs<ctor_num_args> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
PyObject* data = r.pyobject(0);
Expand All @@ -581,6 +591,12 @@ Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObje
true,
type_inference,
pin_memory);
#ifdef BUILD_NAMEDTENSOR
auto names = r.toDimnameListOptional(5);
if (names) {
at::namedinference::propagate_names(new_tensor, std::move(names), /*validate_names=*/true);
}
#endif
new_tensor.detach_(); // ensure new_tensor a leaf node
new_tensor.set_requires_grad(args_requires_grad);
return new_tensor;
Expand Down

0 comments on commit 9159278

Please sign in to comment.