From 90a3049a9a0c8b18f94b658c0e04845281c3d7cf Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 2 Dec 2020 15:42:39 -0800 Subject: [PATCH] [fix] repr(torch.device) (#48655) Summary: Fixes https://github.com/pytorch/pytorch/issues/48585 In the following commit https://github.com/pytorch/pytorch/commit/4c9eb57914fb538e21c46b63cfb5c2e9d5bc2f20, type of `DeviceIndex` was changed from `uint16_t` to `uint8_t`. `uint8_t` is treated as ascii chars by std::cout and other stream operators. Hence the broken `repr` Stackoverflow Reference: https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout Pull Request resolved: https://github.com/pytorch/pytorch/pull/48655 Reviewed By: bdhirsh Differential Revision: D25272289 Pulled By: ezyang fbshipit-source-id: a1549f5f8d417138cf38795e4c373e3a487d3691 --- test/test_torch.py | 11 +++++++++++ torch/csrc/Device.cpp | 5 ++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/test/test_torch.py b/test/test_torch.py index 6c04dd00dc76..9f21efb48b85 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -367,6 +367,17 @@ def test_device(self): device_hash_set.add(hash(torch.device(device))) self.assertEqual(len(device_set), len(device_hash_set)) + def get_expected_device_repr(device): + if device.index is not None: + return "device(type='{type}', index={index})".format( + type=device.type, index=device.index) + + return "device(type='{type}')".format(type=device.type) + + for device in device_set: + dev = torch.device(device) + self.assertEqual(repr(dev), get_expected_device_repr(dev)) + def test_to(self): def test_copy_behavior(t, non_blocking=False): self.assertIs(t, t.to(t, non_blocking=non_blocking)) diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index 5dc382fbfaf7..fef36f56c420 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -30,7 +30,10 @@ PyObject *THPDevice_repr(THPDevice *self) std::ostringstream oss; oss << "device(type=\'" << self->device.type() << "\'"; if (self->device.has_index()) { - oss << ", index=" << self->device.index(); + // `self->device.index()` returns uint8_t which is treated as ascii while printing, + // hence casting it to uint16_t. + // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout + oss << ", index=" << static_cast(self->device.index()); } oss << ")"; return THPUtils_packString(oss.str().c_str());