diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 38af5bb2c782..c211a9bfe2f8 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -448,6 +448,15 @@ def test_nested_enable_python_mode(self) -> None: with enable_python_mode(LoggingTensor): with enable_python_mode(LoggingTensor): pass - + + def test_tolist_numpy_with_python_mode(self) -> None: + x = LoggingTensor(torch.tensor([2.0, 3.0])) + with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."): + x.tolist() + with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."): + x.numpy() + with self.assertRaises(AssertionError): + self.assertEqual(x, None) + if __name__ == '__main__': run_tests() diff --git a/test/test_torch.py b/test/test_torch.py index 3c47fd87122d..a43f0ac9f249 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -8399,6 +8399,13 @@ def generate_inputs(num_batches): finally: torch.set_num_threads(num_threads) + def test_conj_neg_tolist(self): + x = torch.randn(2, dtype=torch.cfloat) + y1 = x.conj() + y1_expect = x.conj_physical() + y2 = y1.imag + self.assertEqual(y1, y1_expect.tolist()) + self.assertEqual(y2, y1_expect.imag.tolist()) # TODO: these empy classes are temporarily instantiated for XLA compatibility # once XLA updates their test suite it should be removed diff --git a/torch/csrc/utils/tensor_list.cpp b/torch/csrc/utils/tensor_list.cpp index 7948734f1e58..1cde3a196d6b 100644 --- a/torch/csrc/utils/tensor_list.cpp +++ b/torch/csrc/utils/tensor_list.cpp @@ -30,7 +30,8 @@ static PyObject* recursive_to_list( } PyObject* tensor_to_list(const Tensor& tensor) { - Tensor data = tensor; + TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".tolist() is not supported for tensor subclasses."); + Tensor data = tensor.resolve_conj().resolve_neg(); if (!data.device().is_cpu()) { pybind11::gil_scoped_release no_gil; data = data.toBackend(Backend::CPU); diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 433d1e2e6808..e507ffbf448b 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -130,6 +130,8 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) { "Can't call numpy() on Tensor that has negative bit set. " "Use tensor.resolve_neg().numpy() instead."); + TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".numpy() is not supported for tensor subclasses."); + auto dtype = aten_to_numpy_dtype(tensor.scalar_type()); auto sizes = to_numpy_shape(tensor.sizes()); auto strides = to_numpy_shape(tensor.strides()); diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index b32c3a159fc0..b4ef79620d8d 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1781,8 +1781,10 @@ def assertEqual(self, x, y, msg: Optional[str] = None, *, assert (atol is None) == (rtol is None), "If one of atol or rtol is specified, then the other must be too" debug_msg: Optional[str] = None + if x is None or y is None: + self.assertTrue(x is None and y is None) # Tensor x Number and Number x Tensor comparisons - if isinstance(x, torch.Tensor) and isinstance(y, Number): + elif isinstance(x, torch.Tensor) and isinstance(y, Number): self.assertEqual(x.item(), y, atol=atol, rtol=rtol, msg=msg, exact_dtype=exact_dtype, exact_device=exact_device) elif isinstance(y, torch.Tensor) and isinstance(x, Number):