diff --git a/ignite/utils.py b/ignite/utils.py
index cd00e8e19cb8..91a9b8f4982c 100644
--- a/ignite/utils.py
+++ b/ignite/utils.py
@@ -56,8 +56,12 @@ def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
     """Convert a tensor of indices of any shape `(N, ...)` to a
     tensor of one-hot indicators of shape `(N, num_classes, ...) and of type uint8. Output's device is equal to the
     input's device`.
+
+    .. versionchanged:: 0.4.3
+        This function is now torchscriptable.
     """
-    onehot = torch.zeros(indices.shape[0], num_classes, *indices.shape[1:], dtype=torch.uint8, device=indices.device)
+    new_shape = (indices.shape[0], num_classes) + indices.shape[1:]
+    onehot = torch.zeros(new_shape, dtype=torch.uint8, device=indices.device)
     return onehot.scatter_(1, indices.unsqueeze(1), 1)
 
 
diff --git a/tests/ignite/test_utils.py b/tests/ignite/test_utils.py
index 781bd2673ee8..1b523e8691d6 100644
--- a/tests/ignite/test_utils.py
+++ b/tests/ignite/test_utils.py
@@ -75,6 +75,29 @@ def test_to_onehot():
     y2 = torch.argmax(y_ohe, dim=1)
     assert y.equal(y2)
 
+    # Test with `TorchScript`
+
+    x = torch.tensor([0, 1, 2, 3])
+
+    # Test the raw `to_onehot` function
+    scripted_to_onehot = torch.jit.script(to_onehot)
+    assert scripted_to_onehot(x, 4).allclose(to_onehot(x, 4))
+
+    # Test inside `torch.nn.Module`
+    class SLP(torch.nn.Module):
+        def __init__(self):
+            super(SLP, self).__init__()
+            self.linear = torch.nn.Linear(4, 1)
+
+        def forward(self, x):
+            x = to_onehot(x, 4)
+            return self.linear(x.to(torch.float))
+
+    eager_model = SLP()
+    scripted_model = torch.jit.script(eager_model)
+
+    assert eager_model(x).allclose(scripted_model(x))
+
 
 def test_dist_setup_logger():
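
For context: TorchScript does not accept star-unpacking such as `*indices.shape[1:]` inside a call, which is presumably why the patch builds `new_shape` as a plain tuple before passing it to `torch.zeros`. Below is a minimal usage sketch of the scripted function, assuming ignite >= 0.4.3 and a PyTorch build with TorchScript available; it is illustrative only and not part of the patch.

    # Minimal sketch: scripting `to_onehot` after the change above.
    import torch
    from ignite.utils import to_onehot

    scripted_to_onehot = torch.jit.script(to_onehot)

    # 1-D indices: scripted and eager outputs should match exactly.
    x = torch.arange(4)
    assert torch.equal(scripted_to_onehot(x, 4), torch.eye(4, dtype=torch.uint8))

    # Indices of any shape `(N, ...)` yield output of shape (N, num_classes, ...).
    y = torch.randint(0, 10, size=(2, 5, 5))
    assert scripted_to_onehot(y, 10).shape == (2, 10, 5, 5)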