From 098ebdce3b87935010f62b00095130a12a62803a Mon Sep 17 00:00:00 2001 From: Devanshu24 Date: Fri, 29 Jan 2021 22:08:25 +0530 Subject: [PATCH 1/4] Fix `to_onehot` function --- ignite/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ignite/utils.py b/ignite/utils.py index cd00e8e19cb8..5e880e0a3b8e 100644 --- a/ignite/utils.py +++ b/ignite/utils.py @@ -57,7 +57,8 @@ def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: tensor of one-hot indicators of shape `(N, num_classes, ...) and of type uint8. Output's device is equal to the input's device`. """ - onehot = torch.zeros(indices.shape[0], num_classes, *indices.shape[1:], dtype=torch.uint8, device=indices.device) + new_shape = (indices.shape[0], num_classes) + indices.shape[1:] + onehot = torch.zeros(new_shape, dtype=torch.uint8, device=indices.device) return onehot.scatter_(1, indices.unsqueeze(1), 1) From 050ff5cb6a08466fb1cd1b734b50f368a373c8d2 Mon Sep 17 00:00:00 2001 From: Devanshu24 Date: Tue, 2 Feb 2021 19:21:38 +0530 Subject: [PATCH 2/4] Add TorchScript tests for `to_onehot` function * Test to check scripted counterpart of raw `to_onehot` function * Test to check scripted counterpart of a NeuralNet using `to_onehot` function --- tests/ignite/test_utils.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/ignite/test_utils.py b/tests/ignite/test_utils.py index 781bd2673ee8..d68927811b09 100644 --- a/tests/ignite/test_utils.py +++ b/tests/ignite/test_utils.py @@ -75,6 +75,29 @@ def test_to_onehot(): y2 = torch.argmax(y_ohe, dim=1) assert y.equal(y2) + # Test with `TorchScript` + + x = torch.tensor([0, 1, 2, 3]) + + # Test the raw `to_onehot` function + scripted_to_onehot = torch.jit.script(to_onehot) + assert scripted_to_onehot(x, 4).allclose(to_onehot(x, 4)) + + # Test inside `torch.nn.Module` + class SLP(torch.nn.Module): + def __init__(self): + super(SLP, self).__init__() + self.linear = torch.nn.Linear(4, 1) + + def forward(self, x): + x = to_onehot(x, 4) + return self.linear(x.to(torch.float)) + + eager_model = SLP() + scripted_model = torch.jit.trace(eager_model, x) + + assert eager_model(x).allclose(scripted_model(x)) + def test_dist_setup_logger(): From 7711e04ae5d6349e9e8f6279718abe3b0c05499e Mon Sep 17 00:00:00 2001 From: vfdev Date: Tue, 2 Feb 2021 16:29:31 +0100 Subject: [PATCH 3/4] Update test_utils.py --- tests/ignite/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ignite/test_utils.py b/tests/ignite/test_utils.py index d68927811b09..1b523e8691d6 100644 --- a/tests/ignite/test_utils.py +++ b/tests/ignite/test_utils.py @@ -94,7 +94,7 @@ def forward(self, x): return self.linear(x.to(torch.float)) eager_model = SLP() - scripted_model = torch.jit.trace(eager_model, x) + scripted_model = torch.jit.script(eager_model) assert eager_model(x).allclose(scripted_model(x)) From 2478c786f006adbd61de1fecd42c8fe6986754af Mon Sep 17 00:00:00 2001 From: Devanshu24 Date: Tue, 2 Feb 2021 21:07:28 +0530 Subject: [PATCH 4/4] Add `versionchanged` docs --- ignite/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ignite/utils.py b/ignite/utils.py index 5e880e0a3b8e..91a9b8f4982c 100644 --- a/ignite/utils.py +++ b/ignite/utils.py @@ -56,6 +56,9 @@ def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: """Convert a tensor of indices of any shape `(N, ...)` to a tensor of one-hot indicators of shape `(N, num_classes, ...) and of type uint8. Output's device is equal to the input's device`. + + .. versionchanged:: 0.4.3 + This functions is now torchscriptable. """ new_shape = (indices.shape[0], num_classes) + indices.shape[1:] onehot = torch.zeros(new_shape, dtype=torch.uint8, device=indices.device)