Conversation

youkaichao
Collaborator

@youkaichao youkaichao commented Aug 29, 2023

With a new tool, depyf, that decompiles bytecode into human-readable source code, understanding dynamo becomes much easier.

cc @svekars @carljparker @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov

@pytorch-bot

pytorch-bot bot commented Aug 29, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108147

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 02b2036 with merge base 1b3dc05 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@youkaichao
Collaborator Author

@jansel Maybe we can also refactor the structure of the doc. We can first show readers how it works, and then introduce various types of guards, as guards will occur in the decompiled source code. However, this is a moderate change, and I would like to hear your opinions before moving forward.

Anyway, finally, with the tools I created (torch._dynamo.eval_frame._debug_get_cache_entry_list and depyf), it seems dynamo can be much easier to understand!

BTW:

@msaroufim is considering integrating the depyf package into dynamo's debugging output. Can't wait to support that!

Comment on lines 294 to 303
source code of __compiled_fn_0:
def ignore_this_function_name(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
    l_a_ = L_a_
    l_b_ = L_b_
    abs_1 = torch.abs(l_a_)
    add = abs_1 + 1; abs_1 = None
    truediv = l_a_ / add; l_a_ = add = None
    sum_1 = l_b_.sum(); l_b_ = None
    lt = sum_1 < 0; sum_1 = None
    return (truediv, lt)
Contributor

This is actually an FX graph (since the compiler above returned the original graph), so if you just call print() on it it might be cleaner than getsource().

Collaborator Author

This is not an FX graph, actually. It is the forward function of that graph. Calling print on it gives <function forward at 0x17c7212d0>

Contributor

Ah, you could print(fn.__self__) then. Or change the code above to not return .forward.

Collaborator Author

Okay, somewhat complicated, but print(__compiled_fn_0._torchdynamo_orig_callable.__self__) does work.
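The bound-method behavior being discussed can be sketched in plain Python, without torch. `GraphModule` here is a hypothetical stand-in for the compiled graph module, not torch.fx's class:

```python
# Illustrative sketch: printing a bound method shows the underlying
# function, while printing its __self__ shows the object that owns it.
class GraphModule:
    def forward(self, x):
        return x + 1

gm = GraphModule()
bound = gm.forward           # a bound method, like the compiled graph's forward
print(bound.__func__)        # the plain function, e.g. <function GraphModule.forward ...>
print(bound.__self__)        # the owning GraphModule instance
assert bound.__self__ is gm
```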

Comment on lines 305 to 317
source code of __resume_at_30_1:
def ignore_this_function_name(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

source code of __resume_at_38_2:
def ignore_this_function_name(a, b):
    x = a / (torch.abs(a) + 1)
    lt = b.sum() < 0
    return x, lt
Contributor

This source code is wrong. It is just showing the original code, without any bytecode-level changes done by dynamo.

If you print out the bytecode you will see:

  1. different args to the function, not in this source code
  2. JUMP_ABSOLUTE as the first instruction, without any corresponding line in the source code. This causes these functions to start in the middle rather than at the top.

In reality, it is more like:

def __resume_at_30_1(b, x):
    JUMP_ABSOLUTE <target>
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        <target>
        b = b * -1
    return x * b

There is actually no way to represent this precisely with Python source code, since Python doesn't have a goto statement.
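The bytecode-level view described above can be inspected with the stdlib `dis` module, which lists each instruction with its offset. The toy function below is illustrative, not dynamo's actual resume function:

```python
# Print the instruction stream of a small function, the kind of view
# (offsets, opnames, jump targets) that the comment above refers to.
import dis

def resume_like(b, x):
    if b < 0:
        b = b * -1
    return x * b

opnames = [ins.opname for ins in dis.get_instructions(resume_like)]
print(opnames)
```

The exact opnames vary across Python versions, which is part of why decompiling dynamo's output is hard in general.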

Collaborator Author

@youkaichao youkaichao Aug 30, 2023

No worries, this can be easily fixed by removing unreachable bytecode. I will use my depyf to decompile it.

Collaborator Author

I fixed it in 4678346. Previously, inspect.getsource did not work for __resume_at_xxx functions. Now I use depyf.decompile to decompile their source code.

Meanwhile, I noticed that the function names of __resume_at_xxx are not valid Python function names. Maybe this should be fixed in pytorch master? WDYT?

Contributor

What makes it invalid?

>>> def __resume_at_30_1(b, x):
...   pass
... 
>>> 

Collaborator Author

@youkaichao youkaichao Aug 30, 2023

__resume_at_30_1.__code__.co_name is actually <resume in toy_example>. The variable name of this function, __resume_at_30_1, does not match the "code name" stored in co_name.

Contributor

ah, that is for stack traces

Collaborator Author

@youkaichao youkaichao Aug 30, 2023

Possible fix: change https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/resume_execution.py#L368 this line to a valid variable name.

A better fix: change it to be the same as the function variable name, i.e. __resume_at_30_1 stuff.

The best fix, I think: combine both names, and change both the variable name and co_name to something like __resume_at_30_1_in_toy_example. This way, maybe we can even remove the unique id and just use __resume_at_{offset}_in_{funcname}.

@youkaichao
Collaborator Author

@jansel I made a picture today, which I found pretty illustrative for new users. Do you want to include it in the doc?

[flowchart image]

@colesbury colesbury added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and module: dynamo labels Aug 30, 2023
@jansel
Contributor

jansel commented Aug 30, 2023

Yes we can include that.

@youkaichao
Collaborator Author

Okay, the flowchart is added into the documentation. Regarding this PR, @jansel do you have any further suggestions?

Separate from this PR, we have two remaining issues:

  • Can we unify the variable name and co_name of resume functions to be __resume_at_{offset}_in_{funcname}_{uid}?
  • Do we want to integrate the depyf package in the logging of dynamo, so that it prints human readable source code rather than hard-to-understand bytecodes?

@jansel
Contributor

jansel commented Aug 31, 2023

* Can we unify the variable name and `co_name` of resume functions to be `__resume_at_{offset}_in_{funcname}_{uid}`?

The main way these are used is showing up in stacktraces when errors happen. So I think a more human readable name is ideal. Offsets/uids are not human readable, and people won't know what __resume_at_ is. So while this change will make the output on this page better, it will make the common case worse.

* Do we want to integrate the `depyf` package in the logging of dynamo, so that it prints human readable source code rather than hard-to-understand bytecodes?

I don't think we should install it by default, since it is a debugging tool. I also worry about how reliable it will be for more complex bytecodes (there are a lot of them) and across Python versions.

@youkaichao
Collaborator Author

I also worry about how reliable it will be for more complex bytecodes (there are a lot of them) and across Python versions.

This is also my concern. How can I test it across the many complex bytecodes generated by dynamo? I would be happy to add them as test cases.

One problem, though, is how to automatically test the correctness of decompiled code. My tests at https://github.com/youkaichao/depyf/blob/master/tests/test.py are simple programs whose output I know. For dynamo bytecode, it is more difficult to test.
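The round-trip testing idea being discussed can be sketched in stdlib-only Python. `fake_decompile` is a placeholder that just returns the known source; a real test would call a decompiler such as depyf on the code object:

```python
# Compile -> "decompile" -> recompile, then check that both functions
# behave identically on sample inputs. This tests behavior, not text.
SRC = "def f(x):\n    return x * 2\n"

def fake_decompile(code_obj):
    # stand-in for a real decompiler; returns source for the code object
    return SRC

ns1 = {}
exec(compile(SRC, "<roundtrip>", "exec"), ns1)
recovered = fake_decompile(ns1["f"].__code__)
ns2 = {}
exec(compile(recovered, "<roundtrip>", "exec"), ns2)
assert ns1["f"](21) == ns2["f"](21) == 42
```

Comparing behavior on inputs sidesteps the problem that decompiled source rarely matches the original text exactly.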

@jansel
Contributor

jansel commented Aug 31, 2023

For the resume_at functions, any bytecode is possible since we copy user bytecode into the output. Bytecodes also change from Python version to Python version.

Not sure how much you want to invest in depyf. If you want to turn it into a full-fledged decompiler, then running CPython unit tests in a mode where you compile->decompile->compile would be a good starting place. I'd expect that to become a pretty big project.

One other thing to mention is we are considering changing TorchDynamo guards to be implemented in C++ for performance reasons. So I don't want that to come as a surprise.

@youkaichao
Collaborator Author

Not sure how much you want to invest in depyf.

I want to limit it to the understanding of torchdynamo. The main use case might be understanding the guarded code, whose bytecode is not so complicated.

Regarding the resume_at functions, I'm thinking of an alternative approach: we have the source code of the original function, and maybe we can take advantage of that to prune the AST and get the code, rather than decompiling all the bytecode.

Guards in C++ are okay, and we can understand what they are checking by inspecting the closure variables they capture. Is there any ongoing discussion on this topic?

@jansel
Contributor

jansel commented Aug 31, 2023

For guards specifically, we actually generate them by creating Python code and calling exec() to generate the bytecode. So it might be cleaner to just keep the source code we used to generate them around.

You can print out the source code by setting TORCHDYNAMO_PRINT_GUARDS=1.

It gets compiled to bytecode here:

exec(pycode, global_builder.scope, out)

We could add something like:

guard_fn.source_code = guard_body

So it would be easier to inspect the generated guards. Maybe we could even wire things up so inspect.getsource() works on them.
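The pattern described above can be sketched in stdlib-only Python: guard source is generated as a string, compiled via exec(), and the source is stashed on the resulting function for later inspection. The guard body and names here are illustrative, not dynamo's actual generated code:

```python
# Generate a guard function from source text and keep the text around,
# mirroring the guard_fn.source_code idea in the comment above.
guard_body = (
    "def guard(L):\n"
    "    return isinstance(L['a'], int) and L['a'] > 0\n"
)
scope = {}
exec(guard_body, scope)
guard_fn = scope["guard"]
guard_fn.source_code = guard_body  # attach source for later inspection
print(guard_fn({"a": 3}))   # True
print(guard_fn({"a": -1}))  # False
```

Keeping the source string alongside the function avoids any need to decompile the guard bytecode at all.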

@youkaichao
Collaborator Author

guard_fn.source_code = guard_body

This is great for guards.

And for the resume functions, I think starting with the original function's AST might be a good idea. The original function already provides a lot of information. However, linking the AST with the bytecode requires a close look at how dynamo generates the bytecode.

Is it possible to get me into the dev slack https://bit.ly/ptslack ? We can have a dedicated discussion there, which is more convenient than "chatting" on GitHub, I think :)

@jansel
Contributor

jansel commented Aug 31, 2023

I sent you a slack invite.

@youkaichao
Collaborator Author

I sent you a slack invite.

How can I accept the invitation? I didn't receive any email.

@youkaichao
Collaborator Author

The one failing check is not caused by this PR, I think :)

@youkaichao
Collaborator Author

Joined the slack channel. Thank you!

@jansel jansel added the module: docs Related to our documentation, both in docs/ and docblocks label Aug 31, 2023
@jansel jansel added the topic: not user facing topic category label Aug 31, 2023
@jansel
Contributor

jansel commented Aug 31, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 31, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


Contributor

@svekars svekars left a comment

Thank you, just a couple of editorial suggestions!

The following diagram demonstrates how ``torch.compile`` transforms and optimizes user-written code:

Note that we pass a simple `my_compiler` function as the backend compiler; therefore, the subgraph code `__resume_at_38_2`, `__resume_at_30_1`, and `__compiled_fn_0._torchdynamo_orig_callable` remains Python code. However, if we use other backends like the built-in `inductor`, the subgraph code will be compiled into CUDA kernels for GPU or C++ code for CPU.
.. image:: _static/img/dynamo/flowchart.jpg
Contributor

Is it possible to add a paragraph that describes what is going on in the diagram?

Collaborator Author

Sure, fixed in 454676e.

youkaichao and others added 2 commits September 1, 2023 00:03
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
@pytorchmergebot
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team Raised by workflow job

youkaichao and others added 5 commits September 1, 2023 00:04
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
@youkaichao
Collaborator Author

The API _debug_get_cache_entry_list has changed since #108335. I will update the documentation accordingly today.

@youkaichao
Collaborator Author

@svekars Hi, do you have any additional comments?

@youkaichao
Collaborator Author

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@youkaichao
Collaborator Author

That's strange. I unexpectedly have access to call the pytorch mergebot to merge :)

@jansel @svekars are there any rules to follow for calling the mergebot?

@youkaichao youkaichao deleted the doc_dynamo_deepdive branch September 4, 2023 00:59
@jansel
Contributor

jansel commented Sep 9, 2023

I think it let you because I approved the PR. Seems fine to me.
