
Conversation

Devanshu24
Contributor

@Devanshu24 Devanshu24 commented Jan 29, 2021

Fixes #1584

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@sdesrozis
Contributor

@Devanshu24 Thank you for this PR !

Could you add a test using torchscript ? Thanks !

@Devanshu24
Contributor Author

Could you add a test using torchscript ? Thanks !

Sure, I'll do that! :)

@Devanshu24 Devanshu24 marked this pull request as draft January 31, 2021 15:09
@vfdev-5
Collaborator

vfdev-5 commented Feb 1, 2021

@Devanshu24 could you please prioritize this PR over the one with the deprecated decorator, as we'd like to include it in the next release. Thanks !

@Devanshu24
Contributor Author

Sure, I was learning how to use torchscript. Will try to push a commit today itself :)

@Devanshu24
Contributor Author

While writing the tests for TorchScript I noticed something: the following code gives True for all elements of the tensor.

from ignite.utils import to_onehot
from torch import nn
import torch


class SLP(nn.Module):
    def __init__(self):
        super(SLP, self).__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        x = to_onehot(x, 4)
        return self.linear(x.to(torch.float))

a = torch.tensor([0, 1, 2, 3])
eager_model = SLP()
print(eager_model(a))
script_model = torch.jit.trace(eager_model, a)
print(script_model(a))

torch.eq(eager_model(a), script_model(a))

However, if I pass a different object (of the same class) to torch.jit.trace, the outputs of the eager_model and script_model are totally different (not even close).
See the code below; the output is False for all elements of the tensor.

from ignite.utils import to_onehot
from torch import nn
import torch


class SLP(nn.Module):
    def __init__(self):
        super(SLP, self).__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        x = to_onehot(x, 4)
        return self.linear(x.to(torch.float))

a = torch.tensor([0, 1, 2, 3])
eager_model = SLP()
print(eager_model(a))
eager_model_new = SLP()
script_model = torch.jit.trace(eager_model_new, a)
print(script_model(a))

torch.eq(eager_model(a), script_model(a))

I have only just started learning about TorchScript, so I do not know whether this is expected behaviour. It'd be great if someone could guide me on this or point me to the right resource.

@vfdev-5
Collaborator

vfdev-5 commented Feb 2, 2021

@Devanshu24 maybe it depends on how the fully-connected weights are randomly initialized. Put torch.manual_seed(12) before each model creation and it should be good. By the way, let's use torch.jit.script for scripting ops.
So, a basic test without inserting to_onehot into an nn.Module could be:

scripted_to_onehot = torch.jit.script(to_onehot)
x = ...
assert scripted_to_onehot(x).allclose(to_onehot(x))
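(Aside: the test above can be fleshed out into a self-contained, runnable sketch. Since ignite may not be installed everywhere, the snippet below defines a local stand-in with the same semantics as ignite.utils.to_onehot; the stand-in and the chosen input x are illustrative assumptions, not the library code.)

```python
import torch


def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Stand-in for ignite.utils.to_onehot (1-D case): a zero matrix of
    # shape (batch, num_classes) with a 1 scattered into each row at the
    # position given by the index.
    onehot = torch.zeros(indices.shape[0], num_classes, dtype=torch.uint8)
    return onehot.scatter_(1, indices.unsqueeze(1), 1)


# torch.jit.script compiles the annotated function directly,
# without needing an example input as torch.jit.trace would.
scripted_to_onehot = torch.jit.script(to_onehot)

x = torch.tensor([0, 2, 1, 3])
assert torch.equal(scripted_to_onehot(x, 4), to_onehot(x, 4))
```

Because scripting compiles the function itself rather than recording one traced execution, the scripted and eager versions agree for any input, which is exactly what the test asserts.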

@Devanshu24
Contributor Author

@Devanshu24 maybe it depends on how the fully-connected weights are randomly initialized. Put torch.manual_seed(12) before each model creation and it should be good.

Ohh right! Thank you!!
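(For the record: each SLP() call draws fresh random weights for its nn.Linear, which is why two separately constructed models disagree. Reseeding before each construction makes the weights, and hence the outputs, identical. A minimal sketch of just that point, not code from the PR:)

```python
import torch
from torch import nn

# Two separately constructed layers normally get different random weights.
a = nn.Linear(4, 1)
b = nn.Linear(4, 1)
assert not torch.equal(a.weight, b.weight)

# Reseeding before each construction makes them identical, so an eager
# model and a model built for tracing would agree on every input.
torch.manual_seed(12)
c = nn.Linear(4, 1)
torch.manual_seed(12)
d = nn.Linear(4, 1)
assert torch.equal(c.weight, d.weight)

x = torch.randn(2, 4)
assert torch.equal(c(x), d(x))
```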

So, a basic test without inserting to_onehot into nn.Module can be

scripted_to_onehot = torch.jit.script(to_onehot)
x = ...
assert scripted_to_onehot(x).allclose(to_onehot(x))

Okay, I will use this, thanks! :)

* Test to check scripted counterpart of raw `to_onehot` function
* Test to check scripted counterpart of a NeuralNet using `to_onehot` function
@Devanshu24
Contributor Author

I have added TorchScript based tests for the to_onehot function. Please have a look at your convenience :)

@Devanshu24 Devanshu24 marked this pull request as ready for review February 2, 2021 13:57
Collaborator

@vfdev-5 vfdev-5 left a comment


PR looks good as far, just a minor thing to fix and it's good to go.

@vfdev-5
Collaborator

vfdev-5 commented Feb 2, 2021

@Devanshu24 I updated the test myself to see if everything is ok and we can merge the PR

@Devanshu24
Contributor Author

Thanks @vfdev-5 !
Sorry, that .trace call slipped past me; thanks for catching it! :)

@ydcjeff
Contributor

ydcjeff commented Feb 2, 2021

Looks good, just a docs thing:

.. versionchanged:: 0.4.3
    This function is now torchscriptable.

cc: @vfdev-5 for version number

@vfdev-5
Collaborator

vfdev-5 commented Feb 2, 2021

Thanks for the reminder @ydcjeff ! This new practice is not completely integrated yet :) Maybe we can update the PR template as a reminder (if anyone actually reads it :)

@Devanshu24 could you please add versionchanged docs mark as suggested. Thanks

Collaborator

@vfdev-5 vfdev-5 left a comment


Thanks a lot @Devanshu24 !

Contributor

@ydcjeff ydcjeff left a comment


LGTM Thanks @Devanshu24

@Devanshu24
Contributor Author

On a side note, why is the docs preview action not triggered?

@vfdev-5
Collaborator

vfdev-5 commented Feb 2, 2021

On a side note, why is the docs preview action not triggered?

I disabled it to consume fewer build minutes... Will enable it from tomorrow, as it will be a new month with more available credits :)

@vfdev-5 vfdev-5 merged commit 5487076 into pytorch:master Feb 2, 2021
@Devanshu24
Contributor Author

Devanshu24 commented Feb 2, 2021

On a side note, why is the docs preview action not triggered?

I disabled it to consume fewer build minutes... Will enable it from tomorrow, as it will be a new month with more available credits :)

Oh okay! That's fine, as CONTRIBUTING.md is pretty detailed and I could check the docs locally too :D

@Devanshu24 Devanshu24 deleted the fix-to_onehot branch February 2, 2021 16:11
Successfully merging this pull request may close these issues.

to_onehot can't be torchscripted
4 participants