[FX][CodeGen] Make sure fx code is valid in python #113345
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113345. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 082709a with merge base 8c704f7. This comment was automatically generated by Dr. CI and updates every 15 minutes.
For those who are triaging PRs, please route this to @jansel.
Can you show us the before and after code?
Sure. Before (note the first line of the body):

```python
def forward(self, primals, tangents):
    primals_1: f32[10], primals_2: f32[10], tangents_1: f32[10], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)

    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    abs_1: f32[10] = torch.ops.aten.abs.default(primals_1)
    add: f32[10] = torch.ops.aten.add.Tensor(abs_1, 1); abs_1 = None
    div: f32[10] = torch.ops.aten.div.Tensor(primals_1, add)

    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:6, code: if b.sum() < 0:
    sum_1: f32[] = torch.ops.aten.sum.default(primals_2); primals_2 = None
    lt: b8[] = torch.ops.aten.lt.Scalar(sum_1, 0); sum_1 = None

    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    neg: f32[10] = torch.ops.aten.neg.default(tangents_1)
    div_1: f32[10] = torch.ops.aten.div.Tensor(primals_1, add)
    div_2: f32[10] = torch.ops.aten.div.Tensor(div_1, add); div_1 = None
    mul: f32[10] = torch.ops.aten.mul.Tensor(neg, div_2); neg = div_2 = None
    div_3: f32[10] = torch.ops.aten.div.Tensor(tangents_1, add); tangents_1 = add = None
    sign: f32[10] = torch.ops.aten.sign.default(primals_1); primals_1 = None
    mul_1: f32[10] = torch.ops.aten.mul.Tensor(mul, sign); mul = sign = None

    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    add_1: f32[10] = torch.ops.aten.add.Tensor(div_3, mul_1); div_3 = mul_1 = None
    return pytree.tree_unflatten([div, lt, add_1, None], self._out_spec)
```

After:

```python
def forward(self, primals, tangents):
    primals_1: "f32[10]"; primals_2: "f32[10]"; tangents_1: "f32[10]"
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)

    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    abs_1: "f32[10]" = torch.ops.aten.abs.default(primals_1)
    add: "f32[10]" = torch.ops.aten.add.Tensor(abs_1, 1); abs_1 = None
    div: "f32[10]" = torch.ops.aten.div.Tensor(primals_1, add)

    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:6, code: if b.sum() < 0:
    sum_1: "f32[]" = torch.ops.aten.sum.default(primals_2); primals_2 = None
    lt: "b8[]" = torch.ops.aten.lt.Scalar(sum_1, 0); sum_1 = None

    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    neg: "f32[10]" = torch.ops.aten.neg.default(tangents_1)
    div_1: "f32[10]" = torch.ops.aten.div.Tensor(primals_1, add)
    div_2: "f32[10]" = torch.ops.aten.div.Tensor(div_1, add); div_1 = None
    mul: "f32[10]" = torch.ops.aten.mul.Tensor(neg, div_2); neg = div_2 = None
    div_3: "f32[10]" = torch.ops.aten.div.Tensor(tangents_1, add); tangents_1 = add = None
    sign: "f32[10]" = torch.ops.aten.sign.default(primals_1); primals_1 = None
    mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(mul, sign); mul = sign = None

    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    add_1: "f32[10]" = torch.ops.aten.add.Tensor(div_3, mul_1); div_3 = mul_1 = None
    return pytree.tree_unflatten([div, lt, add_1, None], self._out_spec)
```
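As a quick sanity check (a minimal sketch, not part of the PR): the first body line of the old output is a `SyntaxError`, while the new quoted form parses cleanly. `ast.parse` only checks syntax, so undefined names like `f32` or `fx_pytree` don't matter here:

```python
import ast

# Old codegen: multiple annotated targets in a single assignment.
# Annotated assignment allows exactly one target, so this is a SyntaxError.
before = (
    "primals_1: f32[10], primals_2: f32[10], tangents_1: f32[10], = "
    "fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)"
)

# New codegen: annotations on their own line, as quoted strings,
# followed by ordinary tuple unpacking.
after = (
    'primals_1: "f32[10]"; primals_2: "f32[10]"; tangents_1: "f32[10]"\n'
    "primals_1, primals_2, tangents_1, = "
    "fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)"
)

for label, src in [("before", before), ("after", after)]:
    try:
        ast.parse(src)
        print(f"{label}: parses")
    except SyntaxError as exc:
        print(f"{label}: SyntaxError: {exc.msg}")
# before: SyntaxError (only a single target can be annotated)
# after: parses
```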
This is plausible I guess, although strings aren't valid types. I guess these "type" annotations never were intended to be parsed for real in Python.
@ezyang I'm working on tooling that lets users step through this generated code. Plus, since these are meant to be Python code, they'd better have correct syntax, I think.
OK, well, you'll have to figure out the test failures first.
@ezyang I have updated all the failing tests.
OK
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR fixes two cases where FX-generated code is invalid Python (a syntax error):
1. multiple type annotations in one line: `var1: annotation1, var2: annotation2 = function_call()`
2. invalid type annotations for scalars, like `var1: f32[] = function_call()`

Pull Request resolved: pytorch#113345
Approved by: https://github.com/ezyang
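For reference, both failure modes can be reproduced with `ast.parse`; in this sketch `T`, `f32`, and `function_call` are placeholders, not real names:

```python
import ast

# Case 1: annotating multiple targets of one assignment is a SyntaxError.
try:
    ast.parse("var1: T, var2: T = function_call()")
except SyntaxError as exc:
    print("case 1:", exc.msg)

# Case 2: an empty subscript in the annotation is a SyntaxError too.
try:
    ast.parse("var1: f32[] = function_call()")
except SyntaxError as exc:
    print("case 2:", exc.msg)

# Quoting the annotation sidesteps both: a string literal is a
# syntactically valid annotation (a forward reference), even if it
# never resolves to a real type.
ast.parse('var1: "f32[]" = function_call()')  # parses fine
```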
@youkaichao This is only from the readable printing, right?
@Chillee yes, I use the code here. The shape annotations greatly help when users step through the code.
@Chillee which aspect makes it annoying?
@youkaichao It's just additional visual clutter for every line. I would prefer to instead have an option to add it for use cases like yours.
My opinion is that readable code should be valid code, too. If someone wants to only read the code, I suppose only adding two quotes around each annotation doesn't hurt readability much.
fwiw, I've since changed my mind on this. Making the readable code valid is quite useful :P
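For context, the readable printing discussed here is what `GraphModule.print_readable()` emits. A minimal sketch of how to see it (the traced function is made up to mirror the test above; plain `symbolic_trace` may not carry shape metadata, so annotations can be sparser than in the AOTAutograd trace shown earlier):

```python
import torch
import torch.fx

def f(a):
    # Hypothetical example, mirroring the test function above.
    return a / (torch.abs(a) + 1)

gm = torch.fx.symbolic_trace(f)
# Prints the generated forward() in the annotated, human-readable form;
# after this PR, that output is also syntactically valid Python.
gm.print_readable()
```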