Add more model symbolic tracing tests from torchvision #55744
Conversation
💊 CI failures summary and remediations, as of commit 9ff2c82 (more details on the Dr. CI page).
This is great! I'm really impressed with the quality of this PR. Thanks so much for doing this.
test/test_fx.py (outdated)

```diff
@@ -41,7 +41,7 @@
 from fx.named_tup import MyNamedTup

 try:
-    from torchvision.models import resnet18
+    from torchvision import models
```
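For context, a plausible shape of the surrounding import guard (a sketch: the `HAS_TORCHVISION` flag and skip decorator are assumptions, only the two import lines above are confirmed by the diff):

```python
import unittest

try:
    from torchvision import models  # renamed per the review below
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

# Tests that need torchvision can then be skipped cleanly when it is absent.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
```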
nit: can we name this something more unique like `torchvision_models`?
Actually, it looks like some of the existing tests have broken because of the change from importing `resnet18` to importing `models`. Can you go through and switch those over to call `models.resnet18`?
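For example, under the suggested `torchvision_models` alias, the existing call sites would become something like this (a sketch, not the exact test code):

```python
from torchvision import models as torchvision_models

# Previously: from torchvision.models import resnet18; resnet18()
rn18 = torchvision_models.resnet18()
```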
test/test_fx.py (outdated)

```python
qgraph_script = torch.jit.script(qgraph)
d = out_transform(qgraph(x))
e = out_transform(qgraph_script(x))
torch.testing.assert_allclose(a, d)
```
I think you're going to have to increase the tolerance (the `atol` or `rtol` kwargs) if you're comparing between an unquantized and a quantized model. It might be a better idea to just remove the quantization tests; that stuff should be covered in test_quantize.py anyway.
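If the quantized comparisons are kept, a loosened comparison might look like this (the tolerance values and the synthetic tensors are illustrative assumptions, not values from this PR):

```python
import torch

a = torch.randn(4)
d = a + 0.004  # stand-in for a quantized model's slightly-off output

# Quantization introduces rounding error, so an exact-match comparison
# between float and quantized outputs will generally fail; loosen the
# tolerances instead of using the defaults.
torch.testing.assert_allclose(a, d, rtol=1e-2, atol=1e-2)
```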
@suraj813 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Fixes pytorch#55398

Generates tests that call `symbolic_trace` on torchvision models and verify the parity of outputs from the eager model, the `fx.GraphModule`, and the `jit.ScriptModule`.

Test errors: the GoogLeNet and Inception models throw a type mismatch when scripting the traced `fx.GraphModule`:

```
Return value was annotated as having type __torch__.torchvision.models.googlenet.GoogLeNetOutputs but is actually of type Tensor:
    dropout = self.dropout(flatten);  flatten = None
    fc = self.fc(dropout);  dropout = None
    return fc
    ~~~~~~~~~ <--- HERE
```

Relevant type-inconsistency: https://github.com/pytorch/vision/blob/512ea299d4b2d2bbac3498a75a2d8c0190cfcb39/torchvision/models/googlenet.py#L200

```python
@torch.jit.unused
def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:
    if self.training and self.aux_logits:
        return _GoogLeNetOutputs(x, aux2, aux1)
    else:
        return x  # type: ignore[return-value]
```

Pull Request resolved: pytorch#55744
Reviewed By: albanD
Differential Revision: D27920595
Pulled By: suraj813
fbshipit-source-id: 01f6f2aef7badbde29b5162a7787b5af9398090d
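For reference, a minimal sketch of the parity check these generated tests perform (the input shape, `eval()` call, and default tolerances here are assumptions, not the exact generated code):

```python
import torch
from torch.fx import symbolic_trace
from torchvision import models

def check_parity(model_fn):
    model = model_fn().eval()
    x = torch.randn(1, 3, 224, 224)

    graph_module = symbolic_trace(model)            # fx.GraphModule
    script_module = torch.jit.script(graph_module)  # jit.ScriptModule

    with torch.no_grad():
        eager_out = model(x)
        fx_out = graph_module(x)
        script_out = script_module(x)

    # All three execution paths should produce the same outputs.
    torch.testing.assert_allclose(eager_out, fx_out)
    torch.testing.assert_allclose(eager_out, script_out)

check_parity(models.resnet18)
```

It is the `torch.jit.script(graph_module)` step that fails for GoogLeNet and Inception with the type mismatch above.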