Skip to content

Commit

Permalink
let torch.device be overrideable by TorchFunctionMode
Browse files Browse the repository at this point in the history
  • Loading branch information
dilililiwhy authored and pytorchmergebot committed Aug 4, 2023
1 parent aaa989c commit 1831a71
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
20 changes: 20 additions & 0 deletions test/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,26 @@ def get_full_matrices(t):
s.add(a)
s.add(DiagTensor(d))

def test_custom_device_type(self):
class CustomDeviceContext(TorchFunctionMode):

def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if func == torch.device:
if args and isinstance(args[0], int):
args = ("xla", args[0])
elif isinstance(kwargs.get('device'), int):
kwargs['device'] = f"xla:{kwargs.get('device')}"
return func(*args, **kwargs)

with CustomDeviceContext():
d_args = torch.device(0)
self.assertEqual(d_args.type, "xla")
self.assertEqual(d_args.index, 0)
d_kwargs = torch.device(device=0)
self.assertEqual(d_kwargs.type, "xla")
self.assertEqual(d_kwargs.index, 0)


if __name__ == '__main__':
run_tests()
1 change: 1 addition & 0 deletions torch/_prims_common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def getnvFuserDtype(dtype: Union[torch.dtype, NumberTypeType]):


torch_function_passthrough = {
torch.device,
torch.Tensor.dim,
torch.Tensor.ndim.__get__, # type: ignore[attr-defined]
torch.Tensor.numel,
Expand Down
12 changes: 10 additions & 2 deletions torch/csrc/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
#include <limits>
#include <sstream>

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject* THPUpperModuleOfDevice = nullptr;

PyObject* THPDevice_New(const at::Device& device) {
auto type = (PyTypeObject*)&THPDeviceType;
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
Expand Down Expand Up @@ -50,10 +53,14 @@ PyObject* THPDevice_pynew(
PyObject* kwargs) {
HANDLE_TH_ERRORS
static torch::PythonArgParser parser(
{"Device(Device device)",
"Device(c10::string_view type, int64_t? index=-1)"});
{"device(Device device)",
"device(c10::string_view type, int64_t? index=-1)"});
torch::ParsedArgs<2> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.has_torch_function()) {
return handle_torch_function(
r, nullptr, args, kwargs, THPUpperModuleOfDevice, "torch");
}
if (r.idx == 0) {
auto device = r.device(0);
return THPDevice_New(device);
Expand Down Expand Up @@ -267,6 +274,7 @@ void THPDevice_init(PyObject* module) {
throw python_error();
}
Py_INCREF(&THPDeviceType);
THPUpperModuleOfDevice = module;
if (PyModule_AddObject(module, "device", (PyObject*)&THPDeviceType) != 0) {
throw python_error();
}
Expand Down

0 comments on commit 1831a71

Please sign in to comment.