diff --git a/test/test_torch.py b/test/test_torch.py index 768aad7e37717..ff6f8fed9686a 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1245,26 +1245,32 @@ def test_constructor_dtypes(self): default_type = torch.Tensor().type() self.assertIs(torch.Tensor().dtype, torch.Tensor.dtype) - torch.set_default_tensor_type('torch.IntTensor') - self.assertIs(torch.int32, torch.Tensor.dtype) - self.assertIs(torch.int32, torch.IntTensor.dtype) - self.assertEqual(torch.IntStorage, torch.Storage) + torch.set_default_tensor_type('torch.FloatTensor') + self.assertIs(torch.float32, torch.Tensor.dtype) + self.assertIs(torch.float32, torch.FloatTensor.dtype) + self.assertEqual(torch.FloatStorage, torch.Storage) - torch.set_default_tensor_type(torch.int64) - self.assertIs(torch.int64, torch.Tensor.dtype) - self.assertIs(torch.int64, torch.LongTensor.dtype) - self.assertEqual(torch.LongStorage, torch.Storage) + torch.set_default_tensor_type(torch.float64) + self.assertIs(torch.float64, torch.Tensor.dtype) + self.assertIs(torch.float64, torch.DoubleTensor.dtype) + self.assertEqual(torch.DoubleStorage, torch.Storage) torch.set_default_tensor_type('torch.Tensor') - self.assertIs(torch.int64, torch.Tensor.dtype) - self.assertIs(torch.int64, torch.LongTensor.dtype) - self.assertEqual(torch.LongStorage, torch.Storage) + self.assertIs(torch.float64, torch.Tensor.dtype) + self.assertIs(torch.float64, torch.DoubleTensor.dtype) + self.assertEqual(torch.DoubleStorage, torch.Storage) if torch.cuda.is_available(): - torch.set_default_tensor_type(torch.cuda.float64) - self.assertIs(torch.cuda.float64, torch.Tensor.dtype) - self.assertIs(torch.cuda.float64, torch.cuda.DoubleTensor.dtype) - self.assertEqual(torch.cuda.DoubleStorage, torch.Storage) + torch.set_default_tensor_type(torch.cuda.float32) + self.assertIs(torch.cuda.float32, torch.Tensor.dtype) + self.assertIs(torch.cuda.float32, torch.cuda.FloatTensor.dtype) + self.assertEqual(torch.cuda.FloatStorage, torch.Storage) + + # don't support integral or sparse default types. + self.assertRaises(TypeError, lambda: torch.set_default_tensor_type('torch.IntTensor')) + self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.int64)) + self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.sparse.int64)) + self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.sparse.float64)) torch.set_default_tensor_type(default_type) diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index b93b407bb263b..ffbed54e7fb54 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -335,6 +335,14 @@ void py_set_default_tensor_type(PyObject* obj) { throw unavailable_type(*type); } + if (!at::isFloatingType(type->aten_type->scalarType())) { + throw TypeError("only floating-point types are supported as the default type"); + } + + if (type->aten_type->is_sparse()) { + throw TypeError("only dense types are supported as the default type"); + } + // get the storage first, so if it doesn't exist we don't change the default tensor type THPObjectPtr storage = get_storage_obj(*type); set_default_tensor_type(*type->aten_type);