Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions aten/src/ATen/DeviceGuard.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <ATen/core/ScalarType.h>
#include <ATen/Tensor.h>
#include <ATen/core/Error.h>
#include <ATen/core/optional.h>
#include <ATen/detail/CUDAHooksInterface.h>

#include <cstddef>
Expand All @@ -28,6 +29,12 @@ struct DeviceGuard {
}
}

/// Activates the guard only for an engaged CUDA device; a disengaged
/// optional or a non-CUDA device leaves the guard as a no-op.
explicit DeviceGuard(optional<Device> device_opt) {
  if (device_opt && device_opt->is_cuda()) {
    set_index(device_opt->index());
  }
}

/// Calls `set_index` with the given index.
explicit DeviceGuard(int32_t index) {
set_index(index);
Expand Down
29 changes: 29 additions & 0 deletions test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,35 @@ def test_factory_copy(self):
self.assertNotEqual(indices.data_ptr(), sparse_tensor._indices().data_ptr())
self.assertNotEqual(values.data_ptr(), sparse_tensor._values().data_ptr())

@cpu_only  # just run once, we test both cpu and cuda
def test_constructor_device_legacy(self):
    # Legacy sparse constructors are tied to a device type; a mismatched
    # `device=` kwarg must raise rather than be silently ignored.
    indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
    values = torch.tensor([3., 4., 5.])
    shape = torch.Size([2, 3])

    def assert_all_signatures_raise(ctor, device):
        # Exercise every legacy call pattern: empty, (i, v), (i, v, size), (size,).
        self.assertRaises(RuntimeError, lambda: ctor(device=device))
        self.assertRaises(RuntimeError, lambda: ctor(indices, values, device=device))
        self.assertRaises(RuntimeError, lambda: ctor(indices, values, shape, device=device))
        self.assertRaises(RuntimeError, lambda: ctor(torch.Size([2, 3, 4]), device=device))

    # CPU-typed constructors must reject device='cuda'.
    assert_all_signatures_raise(torch.sparse.FloatTensor, 'cuda')
    cpu_tensor = torch.sparse_coo_tensor(indices, values, shape, device='cpu')
    assert_all_signatures_raise(cpu_tensor.new, 'cuda')

    if torch.cuda.is_available():
        # CUDA-typed constructors must reject device='cpu'.
        assert_all_signatures_raise(torch.cuda.sparse.FloatTensor, 'cpu')
        cuda_tensor = torch.sparse_coo_tensor(indices, values, shape, device='cuda')
        assert_all_signatures_raise(cuda_tensor.new, 'cpu')

@cpu_only # not really, but we only really want to run this once
def test_dtypes(self):
all_sparse_dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16]
Expand Down
32 changes: 32 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2329,6 +2329,38 @@ def test_constructor_dtypes(self):

torch.set_default_tensor_type(default_type)

def test_constructor_device_legacy(self):
    # Legacy constructors (torch.FloatTensor, torch.Tensor, Tensor.new) are
    # bound to a fixed device type; a mismatched `device=` kwarg must raise
    # rather than be silently ignored.
    self.assertRaises(RuntimeError, lambda: torch.FloatTensor(device='cuda'))
    self.assertRaises(RuntimeError, lambda: torch.FloatTensor(torch.Size([2, 3, 4]), device='cuda'))
    self.assertRaises(RuntimeError, lambda: torch.FloatTensor((2.0, 3.0), device='cuda'))

    self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cuda'))
    self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cuda'))
    self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cuda'))

    x = torch.randn((3,), device='cpu')
    self.assertRaises(RuntimeError, lambda: x.new(device='cuda'))
    self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cuda'))
    self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cuda'))

    if torch.cuda.is_available():
        self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(device='cpu'))
        self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(torch.Size([2, 3, 4]), device='cpu'))
        self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor((2.0, 3.0), device='cpu'))

        # Switch the default tensor type so torch.Tensor binds to CUDA, and
        # always restore it — set_default_tensor_type is process-global, so
        # leaking it on an assertion failure would poison later tests.
        default_type = torch.Tensor().type()
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        try:
            self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu'))
            self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cpu'))
            self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu'))
        finally:
            # (The original also re-set torch.cuda.FloatTensor here a second
            # time, redundantly, immediately before the restore.)
            torch.set_default_tensor_type(default_type)

        x = torch.randn((3,), device='cuda')
        self.assertRaises(RuntimeError, lambda: x.new(device='cpu'))
        self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu'))
        self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cpu'))

def test_type(self):
x = torch.randn(3, 3).double()
self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32)
Expand Down
63 changes: 50 additions & 13 deletions torch/csrc/utils/tensor_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,15 @@ Tensor legacy_new_from_sequence(const Type & type, at::optional<Device> device,
return legacy_new_from_data(type, device, data);
}

// Validates that an (optional) user-supplied device matches the device type
// the legacy constructor's Type is bound to; raises via AT_CHECK on mismatch.
// A disengaged optional means "no device requested" and always passes.
void check_legacy_ctor_device(const Type& type, at::optional<Device> device) {
  if (!device.has_value()) {
    return;
  }
  AT_CHECK(type.device_type() == device->type(),
           "legacy constructor for device type: ", type.device_type(),
           " was passed device type: ", device->type(),
           ", but device type must be: ", type.device_type());
}

Tensor legacy_sparse_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
"new(*, Device? device=None)",
Expand All @@ -261,22 +270,30 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, PyObject* args, PyObject* kwa
ParsedArgs<4> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
auto deviceOptional = r.deviceOptional(0);
check_legacy_ctor_device(type, deviceOptional);
return at::empty({0}, type.options(r.device(0).index()));
} else if (r.idx == 1) {
auto cdata = reinterpret_cast<void*>(r.toInt64(0));
return type.unsafeTensorFromTH(cdata, true);
} else if (r.idx == 2) {
at::DeviceGuard device_guard(r.device(2));
auto deviceOptional = r.deviceOptional(2);
check_legacy_ctor_device(type, deviceOptional);
at::DeviceGuard device_guard(deviceOptional);
return type.sparse_coo_tensor(r.tensor(0), r.tensor(1));
} else if (r.idx == 3) {
at::DeviceGuard device_guard(r.device(3));
auto deviceOptional = r.deviceOptional(3);
check_legacy_ctor_device(type, deviceOptional);
at::DeviceGuard device_guard(deviceOptional);
return type.sparse_coo_tensor(r.tensor(0), r.tensor(1), r.intlist(2));
} else if (r.idx == 4) {
PyObject* arg = r.pyobject(0);
auto deviceOptional = r.deviceOptional(1);
check_legacy_ctor_device(type, deviceOptional);
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
// new(sequence) binds to this signature but should be treated differently
// unless the sequences is a torch.Size
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
}
return new_with_sizes(type, r.device(1).index(), r.intlist(0));
}
Expand All @@ -294,27 +311,35 @@ Tensor legacy_sparse_tensor_new(const Type& type, PyObject* args, PyObject* kwar
ParsedArgs<5> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
at::DeviceGuard device_guard(r.device(0));
auto deviceOptional = r.deviceOptional(0);
check_legacy_ctor_device(type, deviceOptional);
at::DeviceGuard device_guard(deviceOptional);
return type.tensor();
} else if (r.idx == 1) {
auto cdata = reinterpret_cast<void*>(r.device(0).index());
auto cdata = reinterpret_cast<void*>(r.toInt64(0));

This comment was marked as off-topic.

return type.unsafeTensorFromTH(cdata, true);
} else if (r.idx == 2) {
// Note: this signature doesn't have a dtype, even though it has a device; it probably shouldn't
// have a device (we should infer it).
at::DeviceGuard device_guard(r.device(2));
auto deviceOptional = r.deviceOptional(2);
check_legacy_ctor_device(type, deviceOptional);
at::DeviceGuard device_guard(deviceOptional);
return type.sparse_coo_tensor(r.tensor(0), r.tensor(1));
} else if (r.idx == 3) {
// Note: this signature doesn't have a dtype, even though it has a device; it probably shouldn't
// have a device (we should infer it).
at::DeviceGuard device_guard(r.device(3));
auto deviceOptional = r.deviceOptional(3);
check_legacy_ctor_device(type, deviceOptional);
at::DeviceGuard device_guard(deviceOptional);
return type.sparse_coo_tensor(r.tensor(0), r.tensor(1), r.intlist(2));
} else if (r.idx == 4) {
PyObject* arg = r.pyobject(0);
auto deviceOptional = r.deviceOptional(1);
check_legacy_ctor_device(type, deviceOptional);
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
// new(sequence) binds to this signature but should be treated differently
// unless the sequences is a torch.Size
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
}
return new_with_sizes(type, r.device(1).index(), r.intlist(0));
}
Expand Down Expand Up @@ -346,7 +371,9 @@ Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
ParsedArgs<2> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
at::DeviceGuard device_guard(r.device(0));
auto deviceOptional = r.deviceOptional(0);
check_legacy_ctor_device(type, deviceOptional);
at::DeviceGuard device_guard(deviceOptional);
return type.tensor();
} else if (r.idx == 1) {
return new_with_storage(type, r.storage(0));
Expand All @@ -357,14 +384,18 @@ Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
return new_with_tensor(type, r.tensor(0));
} else if (r.idx == 4) {
PyObject* arg = r.pyobject(0);
auto deviceOptional = r.deviceOptional(1);
check_legacy_ctor_device(type, deviceOptional);
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
// new(sequence) binds to this signature but should be treated differently
// unless the sequences is a torch.Size
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
}
return new_with_sizes(type, r.device(1).index(), r.intlist(0));
} else if (r.idx == 5) {
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
auto deviceOptional = r.deviceOptional(1);
check_legacy_ctor_device(type, deviceOptional);
return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
}
throw std::runtime_error("new(): invalid arguments");
}
Expand All @@ -386,7 +417,9 @@ Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
ParsedArgs<3> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
at::DeviceGuard device_guard(r.device(0));
auto deviceOptional = r.deviceOptional(0);
check_legacy_ctor_device(type, deviceOptional);
at::DeviceGuard device_guard(deviceOptional);
return type.tensor();
} else if (r.idx == 1) {
return new_with_storage(type, r.storage(0));
Expand All @@ -397,13 +430,17 @@ Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
return new_with_tensor(type, r.tensor(0));
} else if (r.idx == 4) {
PyObject* arg = r.pyobject(0);
auto deviceOptional = r.deviceOptional(1);
check_legacy_ctor_device(type, deviceOptional);
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
// new(sequence) binds to this signature but should be treated differently
// unless the sequences is a torch.Size
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
}
return new_with_sizes(type, r.device(1).index(), r.intlist(0));
} else if (r.idx == 5) {
auto deviceOptional = r.deviceOptional(1);
check_legacy_ctor_device(type, deviceOptional);
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
}
throw std::runtime_error("new(): invalid arguments");
Expand Down