Skip to content

Commit

Permalink
[fix] repr(torch.device) (#48655)
Browse files Browse the repository at this point in the history
Summary:
Fixes #48585

In the following commit 4c9eb57, type of `DeviceIndex` was changed from `uint16_t` to `uint8_t`.
`uint8_t` is treated as ascii chars by std::cout and other stream operators. Hence the broken `repr`

Stackoverflow Reference: https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout

Pull Request resolved: #48655

Reviewed By: bdhirsh

Differential Revision: D25272289

Pulled By: ezyang

fbshipit-source-id: a1549f5f8d417138cf38795e4c373e3a487d3691
  • Loading branch information
kshitij12345 authored and facebook-github-bot committed Dec 2, 2020
1 parent b006c7a commit 90a3049
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
11 changes: 11 additions & 0 deletions test/test_torch.py
Expand Up @@ -367,6 +367,17 @@ def test_device(self):
device_hash_set.add(hash(torch.device(device)))
self.assertEqual(len(device_set), len(device_hash_set))

def get_expected_device_repr(device):
if device.index is not None:
return "device(type='{type}', index={index})".format(
type=device.type, index=device.index)

return "device(type='{type}')".format(type=device.type)

for device in device_set:
dev = torch.device(device)
self.assertEqual(repr(dev), get_expected_device_repr(dev))

def test_to(self):
def test_copy_behavior(t, non_blocking=False):
self.assertIs(t, t.to(t, non_blocking=non_blocking))
Expand Down
5 changes: 4 additions & 1 deletion torch/csrc/Device.cpp
Expand Up @@ -30,7 +30,10 @@ PyObject *THPDevice_repr(THPDevice *self)
std::ostringstream oss;
oss << "device(type=\'" << self->device.type() << "\'";
if (self->device.has_index()) {
oss << ", index=" << self->device.index();
// `self->device.index()` returns uint8_t which is treated as ascii while printing,
// hence casting it to uint16_t.
// https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout
oss << ", index=" << static_cast<uint16_t>(self->device.index());
}
oss << ")";
return THPUtils_packString(oss.str().c_str());
Expand Down

0 comments on commit 90a3049

Please sign in to comment.