Fix `to_onehot` function #1592
Conversation
@Devanshu24 Thank you for this PR! Could you add a test using torchscript? Thanks!
Sure, I'll do that! :)
@Devanshu24 could you please advance on this PR with higher priority than the one with the deprecated decorator, as we'd like to put it into the next release. Thanks!
Sure, I was learning how to use torchscript. Will try to push a commit later today :)
While writing the tests for `to_onehot`, I tried the following; the eager model and its traced counterpart give the same output:

```python
from ignite.utils import to_onehot
from torch import nn
import torch


class SLP(nn.Module):
    def __init__(self):
        super(SLP, self).__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        x = to_onehot(x, 4)
        return self.linear(x.to(torch.float))


a = torch.tensor([0, 1, 2, 3])
eager_model = SLP()
print(eager_model(a))
script_model = torch.jit.trace(eager_model, a)
print(script_model(a))
torch.eq(eager_model(a), script_model(a))
```

However, if I pass a different object (of the same class) to `torch.jit.trace`, the outputs no longer match:

```python
from ignite.utils import to_onehot
from torch import nn
import torch


class SLP(nn.Module):
    def __init__(self):
        super(SLP, self).__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        x = to_onehot(x, 4)
        return self.linear(x.to(torch.float))


a = torch.tensor([0, 1, 2, 3])
eager_model = SLP()
print(eager_model(a))
eager_model_new = SLP()
script_model = torch.jit.trace(eager_model_new, a)
print(script_model(a))
torch.eq(eager_model(a), script_model(a))
```

I have only started to learn about `torchscript`, so I might be missing something here.
@Devanshu24 maybe it depends on how the fully-connected weights are randomly initialized. Put:

```python
scripted_to_onehot = torch.jit.script(to_onehot)
x = ...
assert scripted_to_onehot(x).allclose(to_onehot(x))
```
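For reference, a runnable version of this check might look as follows; the sample input and `num_classes=4` are illustrative assumptions, not taken from the thread:

```python
import torch
from ignite.utils import to_onehot

# Scripting to_onehot directly involves no model weights, so the eager and
# scripted outputs must match exactly.
scripted_to_onehot = torch.jit.script(to_onehot)
x = torch.tensor([0, 1, 2, 3])  # illustrative input
assert torch.equal(scripted_to_onehot(x, 4), to_onehot(x, 4))
```

This sidesteps the tracing issue above entirely: the earlier mismatch came from comparing two independently initialized models, not from `to_onehot` itself.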
Ohh right! Thank you!!
Okay, I will use this, thanks! :)
* Test to check scripted counterpart of raw `to_onehot` function
* Test to check scripted counterpart of a NeuralNet using `to_onehot` function
I have added the tests.
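A hypothetical sketch of what these two tests might look like (test names and values here are illustrative; the actual tests added in the PR may differ):

```python
import torch
from torch import nn

from ignite.utils import to_onehot


def test_to_onehot_torchscript():
    # Raw function: scripted and eager outputs must be identical.
    scripted_to_onehot = torch.jit.script(to_onehot)
    x = torch.tensor([0, 1, 2, 3])
    assert torch.equal(scripted_to_onehot(x, 4), to_onehot(x, 4))


def test_to_onehot_in_module_torchscript():
    # Model using to_onehot: trace the *same* instance, so the weights
    # of the eager and traced models are guaranteed to match.
    class SLP(nn.Module):
        def __init__(self):
            super(SLP, self).__init__()
            self.linear = nn.Linear(4, 1)

        def forward(self, x):
            return self.linear(to_onehot(x, 4).to(torch.float))

    a = torch.tensor([0, 1, 2, 3])
    eager_model = SLP()
    script_model = torch.jit.trace(eager_model, a)
    assert torch.allclose(eager_model(a), script_model(a))
```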
PR looks good so far, just a minor thing to fix and it's good to go.
@Devanshu24 I updated the test myself to see if everything is ok and we can merge the PR.
Thanks @vfdev-5!
Looks good, just a docs thing:

```rst
.. versionchanged:: 0.4.3
    This function is now torchscriptable.
```

cc: @vfdev-5 for the version number
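For context, the directive goes at the end of the function's docstring, roughly like this (a sketch; the signature matches `ignite.utils.to_onehot`, but the docstring text is paraphrased, not quoted from the repo):

```python
import torch


def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert a tensor of indices of any shape ``(N, ...)`` to a tensor of
    one-hot indicators of shape ``(N, num_classes, ...)``.

    .. versionchanged:: 0.4.3
        This function is now torchscriptable.
    """
    ...
```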
Thanks for the reminder @ydcjeff! This new practice is not completely integrated yet :) Maybe we can update the PR template as a reminder (if anyone actually reads it :). @Devanshu24 could you please add the `versionchanged` docs mark as suggested. Thanks!
Thanks a lot @Devanshu24!
LGTM, thanks @Devanshu24!
On a side note, why is the docs preview action not triggered?
I disabled it to consume fewer build minutes... Will enable it from tomorrow, since it will be a new month with more available credits :)
Oh okay! That's fine, as the CONTRIBUTING.md is pretty detailed, so I could check it locally too :D
Fixes #1584
Check list: