13 changes: 13 additions & 0 deletions test/test_numpy_interop.py
@@ -6,6 +6,7 @@
 import numpy as np
 
 from itertools import product
+import sys
 
 from torch.testing._internal.common_utils import \
     (skipIfTorchDynamo, TestCase, run_tests)
@@ -257,6 +258,18 @@ def test_from_numpy(self, device) -> None:
         x.strides = (3,)
         self.assertRaises(ValueError, lambda: torch.from_numpy(x))
 
+    @skipIfTorchDynamo("No need to test invalid dtypes that should fail by design.")
+    def test_from_numpy_no_leak_on_invalid_dtype(self):
+        # This used to leak memory as the `from_numpy` call raised an exception and
+        # didn't decref the temporary object. See https://github.com/pytorch/pytorch/issues/121138
+        x = np.array("value".encode('ascii'))
+        for _ in range(1000):
+            try:
+                torch.from_numpy(x)
+            except TypeError:
+                pass
+        self.assertTrue(sys.getrefcount(x) == 2)
+
     @skipMeta
     def test_from_list_of_ndarray_warning(self, device):
         warning_msg = r"Creating a tensor from a list of numpy.ndarrays is extremely slow"
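The test above leans on CPython reference counting: if an error path takes a reference it never releases, each failing call leaves the count one higher, and `sys.getrefcount` exposes the drift. A minimal, self-contained sketch of that detection technique (the `leaky_convert` helper and `_leaked_refs` list are hypothetical, for illustration only):

import sys

# Hypothetical, deliberately leaky converter: it grabs a reference to its
# argument before validating it, then raises without releasing that reference.
_leaked_refs = []

def leaky_convert(obj):
    _leaked_refs.append(obj)              # analogous to Py_INCREF
    raise TypeError("unsupported dtype")  # error path never drops the reference

x = object()
baseline = sys.getrefcount(x)
for _ in range(3):
    try:
        leaky_convert(x)
    except TypeError:
        pass
# Each failed call left one stray reference behind, so the count grew by 3.
print(sys.getrefcount(x) - baseline)  # prints 3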
5 changes: 4 additions & 1 deletion torch/csrc/utils/tensor_numpy.cpp
@@ -258,6 +258,9 @@ at::Tensor tensor_from_numpy(
       PyArray_EquivByteorders(PyArray_DESCR(array)->byteorder, NPY_NATIVE),
       "given numpy array has byte order different from the native byte order. "
       "Conversion between byte orders is currently not supported.");
+  // This has to go before the INCREF in case the dtype mapping doesn't
+  // exist and an exception is thrown
+  auto torch_dtype = numpy_dtype_to_aten(PyArray_TYPE(array));
   Py_INCREF(obj);
   return at::lift_fresh(at::from_blob(
       data_ptr,
@@ -267,7 +270,7 @@
         pybind11::gil_scoped_acquire gil;
         Py_DECREF(obj);
       },
-      at::device(kCPU).dtype(numpy_dtype_to_aten(PyArray_TYPE(array)))));
+      at::device(kCPU).dtype(torch_dtype)));
 }
 
 int aten_to_numpy_dtype(const ScalarType scalar_type) {
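On the success path, the `from_blob` deleter above is what balances the `Py_INCREF`: the tensor keeps the numpy array alive and drops its reference when the tensor is destroyed. A rough usage sketch of that contract (assuming `torch` and `numpy` are installed; the exact counts can differ if anything else holds references to the array):

import sys

import numpy as np
import torch

a = np.arange(4, dtype=np.float32)
base = sys.getrefcount(a)

t = torch.from_numpy(a)                # INCREFs the array; the tensor shares its memory
assert sys.getrefcount(a) == base + 1  # one extra reference while the tensor is alive

del t                                  # the from_blob deleter DECREFs the array under the GIL
assert sys.getrefcount(a) == base      # back to the original count, nothing leaked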