diff --git a/test/test_torch.py b/test/test_torch.py
index 2bccf650e0b48..52eda30e37562 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -7310,6 +7310,20 @@ def test_ctor_with_numpy_array(self):
         for i in range(len(array)):
             self.assertEqual(tensor[i], array[i])
 
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_ctor_with_numpy_scalar_ctor(self):
+        dtypes = [
+            np.double,
+            np.float,
+            np.float16,
+            np.int64,
+            np.int32,
+            np.int16,
+            np.uint8
+        ]
+        for dtype in dtypes:
+            self.assertEqual(dtype(42), torch.tensor(dtype(42)).item())
+
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_numpy_index(self):
         i = np.int32([0, 1, 2])
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 3a8b4a7bbc159..d03fd55f2accf 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -139,8 +139,14 @@ ScalarType infer_scalar_type(PyObject *obj) {
   }
 #ifdef USE_NUMPY
   if (PyArray_Check(obj)) {
-    auto array = (PyArrayObject*)obj;
-    return numpy_dtype_to_aten(PyArray_TYPE(array));
+    return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)obj));
+  }
+  if (PyArray_CheckScalar(obj)) {
+    // PyArray_FromScalar returns a NEW reference; hold it and release it
+    // after reading the dtype so we don't leak one array per scalar ctor.
+    PyObject* arr = PyArray_FromScalar(obj, nullptr);
+    auto result = numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)arr));
+    Py_DECREF(arr);
+    return result;
   }
 #endif
   if (PySequence_Check(obj)) {