Skip to content

Commit

Permalink
Add names= argument to torch.tensor ctor
Browse files Browse the repository at this point in the history
Test Plan
- new tests [namedtensor ci]

ghstack-source-id: 9ec59c5fd3cd8ed355e21e1cdae4992a3f78bc85
Pull Request resolved: #25424
  • Loading branch information
zou3519 committed Aug 29, 2019
1 parent 1ea1d7f commit 969bfef
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 3 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/NamedTensorUtils.h
Expand Up @@ -36,7 +36,7 @@ namespace namedinference {
// Names get propagated via the following rules:
// 1) If result does not have names, then `names` get propagated.
// 2) If result has names, then `names` must be equal to result.names
void propagate_names(Tensor& result, optional<DimnameList> names);
CAFFE2_API void propagate_names(Tensor& result, optional<DimnameList> names);
void propagate_names(Tensor& result, std::vector<Dimname>&& names, bool validate_names);
void propagate_names(Tensor& result, optional<std::vector<Dimname>>&& maybe_names, bool validate_names);
void propagate_names(TensorImpl* result, optional<DimnameList> names);
Expand Down
44 changes: 43 additions & 1 deletion test/test_namedtensor.py
@@ -1,5 +1,5 @@
import unittest
from common_utils import TestCase, run_tests
from common_utils import TestCase, run_tests, TEST_NUMPY
from common_cuda import TEST_CUDA
from collections import namedtuple
import itertools
Expand Down Expand Up @@ -314,6 +314,48 @@ def _test(factory, device):
expected = torch.full([1, 2, 3], 2, device=device).names_(*names)
self.assertTensorDataAndNamesEqual(result, expected)

def test_tensor_from_lists(self):
names = ('N', 'C')
tensor = torch.tensor([[1]], names=names)
self.assertEqual(tensor.names, names)

names = ('N',)
tensor = torch.tensor([1], names=names)
self.assertEqual(tensor.names, names)

with self.assertRaisesRegex(RuntimeError, 'Number of names'):
names = ('N', 'C')
tensor = torch.tensor([1], names=names)

@unittest.skipIf(not TEST_NUMPY, "no numpy")
def test_tensor_from_numpy(self):
import numpy as np
arr = np.array([[1]])
names = ('N', 'C')
tensor = torch.tensor([[1]], names=names)
self.assertEqual(tensor.names, names)

def test_tensor_from_tensor(self):
x = torch.randn(1, 1)
names = ('N', 'C')
tensor = torch.tensor(x, names=names)
self.assertEqual(tensor.names, names)

def test_tensor_from_named_tensor(self):
x = torch.randn(1, 1, names=('N', 'D'))
tensor = torch.tensor(x)
self.assertEqual(tensor.names, ('N', 'D'))

# there's no way to distinguish between names=None and not passing in names.
# If the user passes in names=None they are asking for trouble.
x = torch.randn(1, 1, names=('N', 'D'))
tensor = torch.tensor(x, names=None)
self.assertEqual(tensor.names, ('N', 'D'))

x = torch.randn(1, 1, names=('N', 'D'))
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
tensor = torch.tensor(x, names=('N', 'C'))

def test_size(self):
t = torch.empty(2, 3, 5, names=('N', None, 'C'))
self.assertEqual(t.size('N'), 2)
Expand Down
44 changes: 43 additions & 1 deletion torch/csrc/utils/tensor_new.cpp
Expand Up @@ -17,6 +17,7 @@

#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/NamedTensorUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

Expand Down Expand Up @@ -229,7 +230,11 @@ Tensor internal_new_from_data(
bool copy_variables,
bool copy_numpy,
bool type_inference,
bool pin_memory = false) {
bool pin_memory = false
#ifdef BUILD_NAMEDTENSOR
, optional<DimnameList> names = at::nullopt
#endif
) {

if (THPUtils_checkString(data)) {
throw TypeError("new(): invalid data type '%s'", Py_TYPE(data)->tp_name);
Expand All @@ -241,6 +246,11 @@ Tensor internal_new_from_data(
if (copy_variables) {
var = var.detach();
}
#ifdef BUILD_NAMEDTENSOR
if (names) {
at::namedinference::propagate_names(var, names);
}
#endif
// infer the scalar type and device type; it's not expected to infer the layout since these constructors
// are defined per-layout-type (e.g. tensor vs sparse_coo_tensor).
const auto& inferred_scalar_type = type_inference ? var.scalar_type() : scalar_type;
Expand All @@ -258,6 +268,9 @@ Tensor internal_new_from_data(
auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(type_id));
AutoNoGIL no_gil;
maybe_initialize_cuda(device);
#ifdef BUILD_NAMEDTENSOR
at::namedinference::propagate_names(tensor, names);
#endif
return tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy);
}

Expand All @@ -268,6 +281,9 @@ Tensor internal_new_from_data(
auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(type_id));
AutoNoGIL no_gil;
maybe_initialize_cuda(device);
#ifdef BUILD_NAMEDTENSOR
at::namedinference::propagate_names(tensor, names);
#endif
return tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy);
}
#endif
Expand All @@ -281,6 +297,9 @@ Tensor internal_new_from_data(
auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(type_id));
AutoNoGIL no_gil;
maybe_initialize_cuda(device);
#ifdef BUILD_NAMEDTENSOR
at::namedinference::propagate_names(tensor, names);
#endif
return tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/false);
}

Expand Down Expand Up @@ -556,10 +575,18 @@ Tensor sparse_coo_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_t

Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
#ifdef BUILD_NAMEDTENSOR
"tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)",
#else
"tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
#endif
});

#ifdef BUILD_NAMEDTENSOR
ParsedArgs<6> parsed_args;
#else
ParsedArgs<5> parsed_args;
#endif
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
PyObject* data = r.pyobject(0);
Expand All @@ -572,6 +599,20 @@ Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObje
bool type_inference = r.isNone(1);
bool pin_memory = r.toBool(3);
bool args_requires_grad = r.toBool(4);
#ifdef BUILD_NAMEDTENSOR
auto __names = r.toDimnameListOptional(5);
c10::optional<DimnameList> names = __names ? c10::make_optional(DimnameList(__names.value())) : c10::nullopt;
auto new_tensor = internal_new_from_data(
typeIdWithDefault(r, 2, type_id),
r.scalartypeWithDefault(1, scalar_type),
r.deviceOptional(2),
data,
true,
true,
type_inference,
pin_memory,
names);
#else
auto new_tensor = internal_new_from_data(
typeIdWithDefault(r, 2, type_id),
r.scalartypeWithDefault(1, scalar_type),
Expand All @@ -581,6 +622,7 @@ Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObje
true,
type_inference,
pin_memory);
#endif
new_tensor.detach_(); // ensure new_tensor a leaf node
new_tensor.set_requires_grad(args_requires_grad);
return new_tensor;
Expand Down

0 comments on commit 969bfef

Please sign in to comment.