Skip to content

Commit

Permalink
Disable .numpy() and .tolist() for tensor subclasses subclasses and f…
Browse files Browse the repository at this point in the history
…ix .tolist() for conjugated and negated tensors (#66082) (#66576)

Summary:
Pull Request resolved: #66082

Fixes #66024 #65779

cc ezyang anjali411 dylanbespalko mruberry Lezcano nikitaved albanD

Test Plan: Imported from OSS

Reviewed By: Gamrix, albanD

Differential Revision: D31615588

Pulled By: anjali411

fbshipit-source-id: c3e65ef0fe301630eb76732ccd7819683c09aa19
  • Loading branch information
anjali411 committed Oct 14, 2021
1 parent 4a514dd commit 3c134b8
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 3 deletions.
11 changes: 10 additions & 1 deletion test/test_python_dispatch.py
Expand Up @@ -448,6 +448,15 @@ def test_nested_enable_python_mode(self) -> None:
with enable_python_mode(LoggingTensor):
with enable_python_mode(LoggingTensor):
pass


def test_tolist_numpy_with_python_mode(self) -> None:
x = LoggingTensor(torch.tensor([2.0, 3.0]))
with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
x.tolist()
with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
x.numpy()
with self.assertRaises(AssertionError):
self.assertEqual(x, None)

if __name__ == '__main__':
run_tests()
7 changes: 7 additions & 0 deletions test/test_torch.py
Expand Up @@ -8399,6 +8399,13 @@ def generate_inputs(num_batches):
finally:
torch.set_num_threads(num_threads)

def test_conj_neg_tolist(self):
x = torch.randn(2, dtype=torch.cfloat)
y1 = x.conj()
y1_expect = x.conj_physical()
y2 = y1.imag
self.assertEqual(y1, y1_expect.tolist())
self.assertEqual(y2, y1_expect.imag.tolist())

# TODO: these empy classes are temporarily instantiated for XLA compatibility
# once XLA updates their test suite it should be removed
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/utils/tensor_list.cpp
Expand Up @@ -30,7 +30,8 @@ static PyObject* recursive_to_list(
}

PyObject* tensor_to_list(const Tensor& tensor) {
Tensor data = tensor;
TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".tolist() is not supported for tensor subclasses.");
Tensor data = tensor.resolve_conj().resolve_neg();
if (!data.device().is_cpu()) {
pybind11::gil_scoped_release no_gil;
data = data.toBackend(Backend::CPU);
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/utils/tensor_numpy.cpp
Expand Up @@ -130,6 +130,8 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
"Can't call numpy() on Tensor that has negative bit set. "
"Use tensor.resolve_neg().numpy() instead.");

TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".numpy() is not supported for tensor subclasses.");

auto dtype = aten_to_numpy_dtype(tensor.scalar_type());
auto sizes = to_numpy_shape(tensor.sizes());
auto strides = to_numpy_shape(tensor.strides());
Expand Down
4 changes: 3 additions & 1 deletion torch/testing/_internal/common_utils.py
Expand Up @@ -1781,8 +1781,10 @@ def assertEqual(self, x, y, msg: Optional[str] = None, *,
assert (atol is None) == (rtol is None), "If one of atol or rtol is specified, then the other must be too"
debug_msg: Optional[str] = None

if x is None or y is None:
self.assertTrue(x is None and y is None)
# Tensor x Number and Number x Tensor comparisons
if isinstance(x, torch.Tensor) and isinstance(y, Number):
elif isinstance(x, torch.Tensor) and isinstance(y, Number):
self.assertEqual(x.item(), y, atol=atol, rtol=rtol, msg=msg,
exact_dtype=exact_dtype, exact_device=exact_device)
elif isinstance(y, torch.Tensor) and isinstance(x, Number):
Expand Down

0 comments on commit 3c134b8

Please sign in to comment.