From 9e5f62d4824f81574000152f28e1125a7eb5a98b Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 1 Dec 2020 11:43:56 -0600 Subject: [PATCH 1/2] fix broken repr --- torch/csrc/Device.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index 5dc382fbfaf7..fef36f56c420 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -30,7 +30,10 @@ PyObject *THPDevice_repr(THPDevice *self) std::ostringstream oss; oss << "device(type=\'" << self->device.type() << "\'"; if (self->device.has_index()) { - oss << ", index=" << self->device.index(); + // `self->device.index()` returns uint8_t which is treated as ascii while printing, + // hence casting it to uint16_t. + // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout + oss << ", index=" << static_cast(self->device.index()); } oss << ")"; return THPUtils_packString(oss.str().c_str()); From 535c782824a0fa7c8a05a06d4bcca19daaff459e Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 1 Dec 2020 11:47:46 -0600 Subject: [PATCH 2/2] add relevant test --- test/test_torch.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_torch.py b/test/test_torch.py index 84a540ae3937..4652d8d65f01 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -722,6 +722,17 @@ def test_device(self): device_hash_set.add(hash(torch.device(device))) self.assertEqual(len(device_set), len(device_hash_set)) + def get_expected_device_repr(device): + if device.index is not None: + return "device(type='{type}', index={index})".format( + type=device.type, index=device.index) + + return "device(type='{type}')".format(type=device.type) + + for device in device_set: + dev = torch.device(device) + self.assertEqual(repr(dev), get_expected_device_repr(dev)) + def test_to(self): def test_copy_behavior(t, non_blocking=False): self.assertIs(t, t.to(t, non_blocking=non_blocking))