From d5e9edd87f58a27924f1c8a319a0d9157cabfabc Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Sun, 20 May 2018 07:26:51 +0000 Subject: [PATCH 1/4] serialization for torch.device --- test/test_torch.py | 17 +++++++++++++++++ torch/csrc/Device.cpp | 40 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index 93771f18d7f5b..cf3c557db7132 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -36,6 +36,13 @@ break +class NetwithDevice(torch.nn.Module): + def __init__(self): + super(NetwithDevice, self).__init__() + self.param = torch.nn.Parameter(torch.Tensor(3, 5)) + self.device = torch.device('cpu:0') + + class FilelikeMock(object): def __init__(self, data, has_fileno=True, has_readinto=False): if has_readinto: @@ -6251,6 +6258,16 @@ def test_half_tensor(self): xh2 = torch.load(f) self.assertEqual(xh.float(), xh2.float()) + def test_save_net_with_device(self): + net = NetwithDevice() + with tempfile.NamedTemporaryFile() as f: + torch.save(net, f) + f.seek(0) + net2 = torch.load(f) + self.assertEqual(type(net), type(net2)) + self.assertEqual(net.state_dict(), net2.state_dict()) + self.assertEqual(net.device, net2.device) + @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') def test_half_tensor_cuda(self): x = torch.randn(5, 5).half() diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index bc9d9e0081cb7..f708c4114b8db 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -43,7 +43,7 @@ PyObject *THPDevice_repr(THPDevice *self) return THPUtils_packString(oss.str().c_str()); } -PyObject *THPDevice_str(THPDevice*self) +PyObject *THPDevice_str(THPDevice *self) { std::ostringstream oss; if (!self->device.is_default) { @@ -137,6 +137,37 @@ PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) { END_HANDLE_TH_ERRORS } +PyObject *THPDevice_reduce(THPDevice *self) +{ + PyObject *ret, *mod, *obj; + ret = PyTuple_New(2); + + if (ret == NULL) + return NULL; + + mod = PyImport_ImportModule("torch"); + if (mod == NULL) { + Py_DECREF(ret); + return NULL; + } + + obj = PyObject_GetAttrString(mod, "device"); + Py_DECREF(mod); + if (obj == NULL) { + Py_DECREF(ret); + return NULL; + } + + PyTuple_SET_ITEM(ret, 0, obj); + if (self->device.is_default) { + PyTuple_SET_ITEM(ret, 1, Py_BuildValue("(s)", deviceTypeString(self->device.type))); + } else { + PyTuple_SET_ITEM(ret, 1, Py_BuildValue("(si)", deviceTypeString(self->device.type), self->device.index)); + } + + return ret; +} + typedef PyObject *(*getter)(PyObject *, void *); static struct PyGetSetDef THPDevice_properties[] = { @@ -145,6 +176,11 @@ static struct PyGetSetDef THPDevice_properties[] = { {nullptr} }; +static PyMethodDef THPDevice_methods[] = { + {"__reduce__", (PyCFunction)THPDevice_reduce, METH_NOARGS, nullptr}, + {NULL} /* Sentinel */ +}; + PyTypeObject THPDeviceType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch.device", /* tp_name */ @@ -173,7 +209,7 @@ PyTypeObject THPDeviceType = { 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + THPDevice_methods, /* tp_methods */ 0, /* tp_members */ THPDevice_properties, /* tp_getset */ 0, /* tp_base */ From 3b1ecb0c11df9af61af0ee1e6ebbc100ed955ac5 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Sun, 20 May 2018 14:03:34 +0000 Subject: [PATCH 2/4] address comments --- test/test_torch.py | 22 ++++++---------------- torch/csrc/Device.cpp | 23 +++++------------------ 2 files changed, 11 insertions(+), 34 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index cf3c557db7132..9072aa2ba8c8b 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -36,13 +36,6 @@ break -class NetwithDevice(torch.nn.Module): - def __init__(self): - super(NetwithDevice, self).__init__() - self.param = torch.nn.Parameter(torch.Tensor(3, 5)) - self.device = torch.device('cpu:0') - - class FilelikeMock(object): def __init__(self, data, has_fileno=True, has_readinto=False): if has_readinto: @@ -6258,15 +6251,12 @@ def test_half_tensor(self): xh2 = torch.load(f) self.assertEqual(xh.float(), xh2.float()) - def test_save_net_with_device(self): - net = NetwithDevice() - with tempfile.NamedTemporaryFile() as f: - torch.save(net, f) - f.seek(0) - net2 = torch.load(f) - self.assertEqual(type(net), type(net2)) - self.assertEqual(net.state_dict(), net2.state_dict()) - self.assertEqual(net.device, net2.device) + def test_serialize_device(self): + device_str = ['cpu', 'cpu:0', 'cuda', 'cuda:0'] + device_obj = [torch.device(d) for d in device_str] + for device in device_obj: + device_copied = copy.deepcopy(device) + self.assertEqual(device, device_copied) @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') def test_half_tensor_cuda(self): diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index f708c4114b8db..dc7f50a1538a3 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -139,26 +139,13 @@ PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) { PyObject *THPDevice_reduce(THPDevice *self) { - PyObject *ret, *mod, *obj; - ret = PyTuple_New(2); + PyObject *ret = PyTuple_New(2); + if (!ret) return NULL; - if (ret == NULL) - return NULL; + py::object torch_module = py::module::import("torch"); + py::object torch_device = torch_module.attr("device"); + PyTuple_SET_ITEM(ret, 0, torch_device.release().ptr()); - mod = PyImport_ImportModule("torch"); - if (mod == NULL) { - Py_DECREF(ret); - return NULL; - } - - obj = PyObject_GetAttrString(mod, "device"); - Py_DECREF(mod); - if (obj == NULL) { - Py_DECREF(ret); - return NULL; - } - - PyTuple_SET_ITEM(ret, 0, obj); if (self->device.is_default) { PyTuple_SET_ITEM(ret, 1, Py_BuildValue("(s)", deviceTypeString(self->device.type))); } else { From 43ba14ad03fce443b39a57a0ae456c78888a3482 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Sun, 20 May 2018 14:09:59 +0000 Subject: [PATCH 3/4] use THPObjectPtr instead of PyObject* --- torch/csrc/Device.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index dc7f50a1538a3..9976d9008b249 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -139,20 +139,20 @@ PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) { PyObject *THPDevice_reduce(THPDevice *self) { - PyObject *ret = PyTuple_New(2); - if (!ret) return NULL; + auto ret = THPObjectPtr{PyTuple_New(2)}; + if (!ret) throw python_error(); py::object torch_module = py::module::import("torch"); py::object torch_device = torch_module.attr("device"); - PyTuple_SET_ITEM(ret, 0, torch_device.release().ptr()); + PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr()); if (self->device.is_default) { - PyTuple_SET_ITEM(ret, 1, Py_BuildValue("(s)", deviceTypeString(self->device.type))); + PyTuple_SET_ITEM(ret.get(), 1, Py_BuildValue("(s)", deviceTypeString(self->device.type))); } else { - PyTuple_SET_ITEM(ret, 1, Py_BuildValue("(si)", deviceTypeString(self->device.type), self->device.index)); + PyTuple_SET_ITEM(ret.get(), 1, Py_BuildValue("(si)", deviceTypeString(self->device.type), self->device.index)); } - return ret; + return ret.release(); } typedef PyObject *(*getter)(PyObject *, void *); From 61c3c508969c3c10fc3586be23da2fb833c83370 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Mon, 21 May 2018 01:28:00 +0000 Subject: [PATCH 4/4] handle errors --- torch/csrc/Device.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index 9976d9008b249..f11add881ac6a 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -139,6 +139,7 @@ PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) { PyObject *THPDevice_reduce(THPDevice *self) { + HANDLE_TH_ERRORS auto ret = THPObjectPtr{PyTuple_New(2)}; if (!ret) throw python_error(); @@ -146,13 +147,17 @@ PyObject *THPDevice_reduce(THPDevice *self) py::object torch_device = torch_module.attr("device"); PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr()); + THPObjectPtr args; if (self->device.is_default) { - PyTuple_SET_ITEM(ret.get(), 1, Py_BuildValue("(s)", deviceTypeString(self->device.type))); + args = THPObjectPtr{Py_BuildValue("(s)", deviceTypeString(self->device.type))}; } else { - PyTuple_SET_ITEM(ret.get(), 1, Py_BuildValue("(si)", deviceTypeString(self->device.type), self->device.index)); + args = THPObjectPtr{Py_BuildValue("(si)", deviceTypeString(self->device.type), self->device.index)}; } + if (!args) throw python_error(); + PyTuple_SET_ITEM(ret.get(), 1, args.release()); return ret.release(); + END_HANDLE_TH_ERRORS } typedef PyObject *(*getter)(PyObject *, void *);