[FX][CodeGen] Make sure fx code is valid in python #113345

Closed
youkaichao wants to merge 6 commits

youkaichao
Collaborator

@youkaichao youkaichao commented Nov 9, 2023

This PR fixes two cases in which FX-generated code is invalid Python (a syntax error):

  1. Multiple type annotations on one line: `var1: annotation1, var2: annotation2 = function_call()`
  2. An empty-subscript type annotation for scalars, such as `var1: f32[] = function_call()`.
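
Both failure modes can be reproduced without PyTorch by asking Python's parser about the offending lines. The sketch below is an illustration, not part of the PR: `is_valid_python` is a hypothetical helper, and the source strings are the two patterns listed above plus the quoted form.

```python
import ast


def is_valid_python(source: str) -> bool:
    """Return True if `source` parses as Python, False on SyntaxError."""
    try:
        ast.parse(source)
    except SyntaxError:
        return False
    return True


# Case 1: several annotated targets in one assignment statement.
multi = "var1: annotation1, var2: annotation2 = function_call()"

# Case 2: an empty subscript, as emitted for scalar (0-dim) tensors.
scalar = "var1: f32[] = function_call()"

# Quoting the annotation turns it into an ordinary string literal.
quoted = 'var1: "f32[]" = function_call()'

for label, src in [("multi", multi), ("scalar", scalar), ("quoted", quoted)]:
    print(label, is_valid_python(src))
```

Only parsing is checked here; names such as `function_call` do not need to exist for the syntax check.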

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng

@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label Nov 9, 2023

pytorch-bot bot commented Nov 9, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113345

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 082709a with merge base 8c704f7 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@youkaichao
Collaborator Author

For those who are triaging PRs, please route this to @jansel.

@ezyang
Contributor

ezyang commented Nov 9, 2023

Can you show us the before and after code?

@ezyang ezyang requested review from ezyang and Chillee November 9, 2023 14:48
@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Nov 9, 2023
@youkaichao
Collaborator Author

Sure.

Before:

The first line, `primals_1: f32[10], primals_2: f32[10], tangents_1: f32[10], = ...`, is a syntax error. The lines annotated `lt: b8[]` and `sum_1: f32[]` are syntax errors as well.

def forward(self, primals, tangents):
    primals_1: f32[10], primals_2: f32[10], tangents_1: f32[10], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    abs_1: f32[10] = torch.ops.aten.abs.default(primals_1)
    add: f32[10] = torch.ops.aten.add.Tensor(abs_1, 1);  abs_1 = None
    div: f32[10] = torch.ops.aten.div.Tensor(primals_1, add)
    
    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:6, code: if b.sum() < 0:
    sum_1: f32[] = torch.ops.aten.sum.default(primals_2);  primals_2 = None
    lt: b8[] = torch.ops.aten.lt.Scalar(sum_1, 0);  sum_1 = None
    
    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    neg: f32[10] = torch.ops.aten.neg.default(tangents_1)
    div_1: f32[10] = torch.ops.aten.div.Tensor(primals_1, add)
    div_2: f32[10] = torch.ops.aten.div.Tensor(div_1, add);  div_1 = None
    mul: f32[10] = torch.ops.aten.mul.Tensor(neg, div_2);  neg = div_2 = None
    div_3: f32[10] = torch.ops.aten.div.Tensor(tangents_1, add);  tangents_1 = add = None
    sign: f32[10] = torch.ops.aten.sign.default(primals_1);  primals_1 = None
    mul_1: f32[10] = torch.ops.aten.mul.Tensor(mul, sign);  mul = sign = None
    
    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    add_1: f32[10] = torch.ops.aten.add.Tensor(div_3, mul_1);  div_3 = mul_1 = None
    return pytree.tree_unflatten([div, lt, add_1, None], self._out_spec)

After:

def forward(self, primals, tangents):
    primals_1: "f32[10]"; primals_2: "f32[10]"; tangents_1: "f32[10]"

    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    abs_1: "f32[10]" = torch.ops.aten.abs.default(primals_1)
    add: "f32[10]" = torch.ops.aten.add.Tensor(abs_1, 1);  abs_1 = None
    div: "f32[10]" = torch.ops.aten.div.Tensor(primals_1, add)
    
    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:6, code: if b.sum() < 0:
    sum_1: "f32[]" = torch.ops.aten.sum.default(primals_2);  primals_2 = None
    lt: "b8[]" = torch.ops.aten.lt.Scalar(sum_1, 0);  sum_1 = None
    
    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    neg: "f32[10]" = torch.ops.aten.neg.default(tangents_1)
    div_1: "f32[10]" = torch.ops.aten.div.Tensor(primals_1, add)
    div_2: "f32[10]" = torch.ops.aten.div.Tensor(div_1, add);  div_1 = None
    mul: "f32[10]" = torch.ops.aten.mul.Tensor(neg, div_2);  neg = div_2 = None
    div_3: "f32[10]" = torch.ops.aten.div.Tensor(tangents_1, add);  tangents_1 = add = None
    sign: "f32[10]" = torch.ops.aten.sign.default(primals_1);  primals_1 = None
    mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(mul, sign);  mul = sign = None
    
    # File: depyf/tests/test_pytorch/test_debug_function_aot.py:5, code: x = a / (torch.abs(a) + 1)
    add_1: "f32[10]" = torch.ops.aten.add.Tensor(div_3, mul_1);  div_3 = mul_1 = None
    return pytree.tree_unflatten([div, lt, add_1, None], self._out_spec)
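
The rewritten prologue relies on two Python rules: a bare annotation such as `x: "f32[10]"` is a valid statement on its own, and annotations on local variables are never evaluated inside a function body, so the string never has to name a real type. A minimal sketch, with plain floats standing in for tensors (the function and its inputs are made up for illustration):

```python
def forward(inputs):
    # Bare annotations, separated by semicolons, replace the invalid
    # `a: T, b: T, = ...` form. They only mark the names as locals.
    a: "f32[]"; b: "f32[]"

    # Plain tuple unpacking; the trailing comma mirrors the generated code.
    a, b, = inputs

    # Quoted scalar annotation: an unquoted `f32[]` would not parse.
    total: "f32[]" = a + b
    return total


result = forward([1.0, 2.0])
print(result)
```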

@ezyang
Contributor

ezyang commented Nov 9, 2023

This is plausible I guess, although strings aren't valid types. I guess these "type" annotations never were intended to be parsed for real in Python.

@youkaichao
Collaborator Author

@ezyang I'm working to make torch.compile debuggable in the https://github.com/thuml/depyf project. I'm trying to link source code to FX graphs, and it would help greatly if users could debug these functions. (By debugging, I mean executing them step by step, say to find which op produces a NaN.) To do that, I need these functions to be valid Python code, as in PR #111635.

Plus, since these are meant to be Python code, I think they should have correct syntax.
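
To illustrate the kind of step-by-step inspection described here, the sketch below traces a function line by line and reports the first local variable that becomes NaN. It is an assumption-laden stand-in, not depyf's implementation: `forward` is a hand-written function using floats instead of tensors, and `first_nan_local` is a hypothetical helper built on the standard `sys.settrace` hook.

```python
import math
import sys


def forward(x):
    # Stand-in for a generated graph function; floats replace tensors.
    a: "f32[]" = x * 2.0
    b: "f32[]" = a - a              # 0.0 for finite x
    c: "f32[]" = float("inf") * b   # inf * 0.0 is NaN
    d: "f32[]" = c + 1.0            # NaN propagates
    return d


def first_nan_local(func, *args):
    """Run func(*args) under a tracer; return the first local seen as NaN."""
    seen = []  # names of NaN locals, in order of first appearance

    def local_trace(frame, event, arg):
        # Called before each line of `func` runs and when it returns.
        for name, value in frame.f_locals.items():
            if isinstance(value, float) and math.isnan(value) and name not in seen:
                seen.append(name)
        return local_trace

    def global_trace(frame, event, arg):
        # Only trace frames that belong to `func` itself.
        return local_trace if frame.f_code is func.__code__ else None

    previous = sys.gettrace()
    sys.settrace(global_trace)
    try:
        func(*args)
    finally:
        sys.settrace(previous)
    return seen[0] if seen else None


culprit = first_nan_local(forward, 1.0)
print(culprit)
```

Tracing like this, or stepping through in a debugger, is only possible once the function's source compiles, which is what the quoted annotations guarantee.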

@ezyang
Contributor

ezyang commented Nov 10, 2023

OK, well, you'll have to figure out the test failures first.

@youkaichao
Collaborator Author

@ezyang I have updated all the failed tests.

Contributor

@ezyang ezyang left a comment

OK

@ezyang
Contributor

ezyang commented Nov 10, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 10, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Pull Request resolved: pytorch#113345
Approved by: https://github.com/ezyang
@Chillee
Contributor

Chillee commented Nov 24, 2023

@youkaichao This is only from the print_readable output - do you need that to be valid python code? I find the additional strings to be somewhat annoying.

@youkaichao
Collaborator Author

@Chillee yes, I use the code here:
https://github.com/thuml/depyf/blob/dbc220a081c6f940071016499f3ac4ef06bc0030/depyf/explain/patched_lazy_format_graph_code.py#L16

The shape annotation greatly helps when users step through the code.
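
As a rough illustration of why the readable output has to compile for this use case, the sketch below writes a source string to a file, compiles it against that filename, and calls the resulting function, so that a debugger could map each executed line back to the file. This is a hypothetical sketch, not the code at the link above: `GENERATED_SOURCE` is hand-written, and `load_from_file` is an assumed helper name.

```python
import os
import tempfile

# Hand-written stand-in for readable graph code; floats replace tensors.
GENERATED_SOURCE = '''\
def forward(x, y):
    total: "f32[]" = x + y
    scaled: "f32[]" = total * 2.0
    return scaled
'''


def load_from_file(source: str, func_name: str):
    """Dump `source` to a .py file and load `func_name` from it."""
    fd, path = tempfile.mkstemp(suffix=".py")
    with os.fdopen(fd, "w") as f:
        f.write(source)
    # compile() raises SyntaxError here if the source is not valid Python.
    code = compile(source, path, "exec")
    namespace = {}
    exec(code, namespace)
    # The file is kept on disk so a debugger can display its lines.
    return namespace[func_name], path


forward, source_path = load_from_file(GENERATED_SOURCE, "forward")
result = forward(1.0, 2.0)
print(result, forward.__code__.co_filename == source_path)
```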

@youkaichao
Collaborator Author

@Chillee which aspect makes it annoying?

@Chillee
Contributor

Chillee commented Nov 27, 2023

@youkaichao It's just additional visual clutter for every line. I would prefer to instead have an option to add it for use cases like yours.

@youkaichao
Collaborator Author

My opinion is that readable code should be valid code, too. For someone who only wants to read the code, I don't think a pair of quotes around each annotation hurts. Human eyes can easily ignore them.

@Chillee
Contributor

Chillee commented Mar 20, 2024

fwiw, I've since changed my mind on this. Making the readable code valid is quite useful :P

Labels
ciflow/trunk, Merged, module: dynamo, open source, release notes: fx, triaged

6 participants