This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Conversation

@alanwaketan

Summary:
This is a PoC that breaks the graph in order to support gradient hooks in the
AOT backend. The end-to-end use cases are composing DDP/FSDP with dynamo.

Test Plan:
WIP.
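
For context on why graph breaks are needed at all: DDP-style gradient hooks are per-parameter backward hooks that autograd fires in eager mode, so a single monolithic AOT-compiled backward would never hit them. A minimal sketch of such a hook (the toy model and the print are placeholders, not the PoC's code):

import torch

model = torch.nn.Linear(10, 10)

# Hypothetical stand-in for what DDP does internally: a hook on each parameter
# that fires as soon as its gradient is produced during backward. In real DDP
# the hook would enqueue an async all_reduce on a gradient bucket.
for name, param in model.named_parameters():
    param.register_hook(lambda grad, name=name: print(f"gradient hook fired: {name}") or grad)

loss = model(torch.randn(2, 10)).sum()
loss.backward()  # the hooks fire here, one per parameter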


@jansel
Contributor

jansel commented Jul 2, 2022

Overall this seems reasonable, it proves you can insert graph breaks inside the backend.

I'm wondering if putting communication primitives in the graph would also be possible.

@alanwaketan
Author

I'm wondering if putting communication primitives in the graph would also be possible.

Do you mean tracing through the gradient hooks or manually inserting communication ops?

@jansel
Contributor

jansel commented Jul 8, 2022

Do you mean tracing through the gradient hooks or manually inserting communication ops?

We could do whichever is easier, but the result would be a graph that contains communication ops.
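
To make the second option concrete, here is a minimal sketch of manually splicing a communication op into an FX graph. The pass and the node selection are hypothetical, chosen only to illustrate the graph-surgery API; a real pass would also need to account for all_reduce mutating its input in place:

import torch.distributed as dist
import torch.fx as fx

def insert_allreduce(gm: fx.GraphModule, node_names):
    # After each named node, splice in a dist.all_reduce call so the
    # communication becomes an ordinary node in the traced graph.
    for node in list(gm.graph.nodes):
        if node.name in node_names:
            with gm.graph.inserting_after(node):
                # all_reduce reduces in place across the default process group;
                # a real pass would re-point downstream users at the reduced value.
                gm.graph.call_function(dist.all_reduce, args=(node,))
    gm.graph.lint()
    gm.recompile()
    return gm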

@alanwaketan
Author

This is the output from dynamo_example2.py, an improved PoC in which the gradient hooks actually fire in between the compiled submodules.

(pytorch39) jwtan@ip-10-200-66-59:/fsx/users/jwtan/work/torchdynamo$ gpurun python dynamo_example2.py 
STAGE:2022-07-14 06:12:26 8511:8511 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
graph_break_compiler() called with FX graph:
opcode         name              target                             args                                        kwargs
-------------  ----------------  ---------------------------------  ------------------------------------------  ------------------
placeholder    x                 x                                  ()                                          {}
placeholder    self_net1_weight  self_net1_weight                   ()                                          {}
placeholder    self_net1_bias    self_net1_bias                     ()                                          {}
placeholder    self_net2_weight  self_net2_weight                   ()                                          {}
placeholder    self_net2_bias    self_net2_bias                     ()                                          {}
placeholder    self_net3_weight  self_net3_weight                   ()                                          {}
placeholder    self_net3_bias    self_net3_bias                     ()                                          {}
placeholder    self_net4_weight  self_net4_weight                   ()                                          {}
placeholder    self_net4_bias    self_net4_bias                     ()                                          {}
call_function  linear            <built-in function linear>         (x, self_net1_weight, self_net1_bias)       {}
call_function  relu              <function relu at 0x7f353b8f8160>  (linear,)                                   {'inplace': False}
call_function  linear_1          <built-in function linear>         (relu, self_net2_weight, self_net2_bias)    {}
call_function  relu_1            <function relu at 0x7f353b8f8160>  (linear_1,)                                 {'inplace': False}
call_function  linear_2          <built-in function linear>         (relu_1, self_net3_weight, self_net3_bias)  {}
call_function  relu_2            <function relu at 0x7f353b8f8160>  (linear_2,)                                 {'inplace': False}
call_function  linear_3          <built-in function linear>         (relu_2, self_net4_weight, self_net4_bias)  {}
output         output            output                             ((linear_3,),)                              {}

graph_break_compiler() called with splitted graphs:
opcode         name              target                      args                                   kwargs
-------------  ----------------  --------------------------  -------------------------------------  --------
placeholder    self_net1_bias    self_net1_bias              ()                                     {}
placeholder    self_net1_weight  self_net1_weight            ()                                     {}
placeholder    x                 x                           ()                                     {}
call_function  linear            <built-in function linear>  (x, self_net1_weight, self_net1_bias)  {}
output         output            output                      ((linear,),)                           {}

opcode         name              target                             args                                      kwargs
-------------  ----------------  ---------------------------------  ----------------------------------------  ------------------
placeholder    linear            linear                             ()                                        {}
placeholder    self_net2_weight  self_net2_weight                   ()                                        {}
placeholder    self_net2_bias    self_net2_bias                     ()                                        {}
call_function  relu              <function relu at 0x7f353b8f8160>  (linear,)                                 {'inplace': False}
call_function  linear_1          <built-in function linear>         (relu, self_net2_weight, self_net2_bias)  {}
call_function  relu_1            <function relu at 0x7f353b8f8160>  (linear_1,)                               {'inplace': False}
output         output            output                             ((relu_1,),)                              {}

opcode         name              target                             args                                        kwargs
-------------  ----------------  ---------------------------------  ------------------------------------------  ------------------
placeholder    relu_1            relu_1                             ()                                          {}
placeholder    self_net3_weight  self_net3_weight                   ()                                          {}
placeholder    self_net3_bias    self_net3_bias                     ()                                          {}
placeholder    self_net4_weight  self_net4_weight                   ()                                          {}
placeholder    self_net4_bias    self_net4_bias                     ()                                          {}
call_function  linear_2          <built-in function linear>         (relu_1, self_net3_weight, self_net3_bias)  {}
call_function  relu_2            <function relu at 0x7f353b8f8160>  (linear_2,)                                 {'inplace': False}
call_function  linear_3          <built-in function linear>         (relu_2, self_net4_weight, self_net4_bias)  {}
output         output            output                             ((linear_3,),)                              {}

AOT compiled all 3 modules

graph_break_compiler() called with stitched graph:
opcode         name              target                                                                        args                                                                          kwargs
-------------  ----------------  ----------------------------------------------------------------------------  ----------------------------------------------------------------------------  --------
placeholder    x                 x                                                                             ()                                                                            {}
placeholder    self_net1_weight  self_net1_weight                                                              ()                                                                            {}
placeholder    self_net1_bias    self_net1_bias                                                                ()                                                                            {}
placeholder    self_net2_weight  self_net2_weight                                                              ()                                                                            {}
placeholder    self_net2_bias    self_net2_bias                                                                ()                                                                            {}
placeholder    self_net3_weight  self_net3_weight                                                              ()                                                                            {}
placeholder    self_net3_bias    self_net3_bias                                                                ()                                                                            {}
placeholder    self_net4_weight  self_net4_weight                                                              ()                                                                            {}
placeholder    self_net4_bias    self_net4_bias                                                                ()                                                                            {}
call_function  forward           <bound method aot_module_simplified.<locals>.AOTModule.forward of AOTModule(  (self_net1_bias, self_net1_weight, x)                                         {}
                                   (orig_module): GraphModule()
                                 )>
call_method    linear            __getitem__                                                                   (forward, 0)                                                                  {}
call_function  forward_1         <bound method aot_module_simplified.<locals>.AOTModule.forward of AOTModule(  (linear, self_net2_weight, self_net2_bias)                                    {}
                                   (orig_module): GraphModule()
                                 )>
call_method    relu_1            __getitem__                                                                   (forward_1, 0)                                                                {}
call_function  forward_2         <bound method aot_module_simplified.<locals>.AOTModule.forward of AOTModule(  (relu_1, self_net3_weight, self_net3_bias, self_net4_weight, self_net4_bias)  {}
                                   (orig_module): GraphModule()
                                 )>
call_method    linear_3          __getitem__                                                                   (forward_2, 0)                                                                {}
STAGE:2022-07-14 06:12:48 8511:8511 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2022-07-14 06:12:48 8511:8511 output_json.cpp:417] Completed Stage: Post Processing
output         output            output                                                                        ((linear_3,),)                                                                {}

gradient hook fired
gradient hook fired
gradient hook fired
gradient hook fired
gradient hook fired
gradient hook fired
gradient hook fired
gradient hook fired
8511: iteration 0, loss 0.8432953953742981
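
The splitting code itself isn't shown in the PR, but for illustration, one way to get submodules like the three above is torch.fx.passes.split_module. The helper and the break points below are assumptions, picked to reproduce the partitioning in this log:

import torch.fx as fx
from torch.fx.passes.split_module import split_module

def split_at(gm: fx.GraphModule, break_after):
    # Assign every node a partition id, bumping the id after each node we want
    # to break on; split_module then turns each partition into its own
    # GraphModule (submod_0, submod_1, ...) that can be AOT-compiled separately.
    part, node_to_part = 0, {}
    for node in gm.graph.nodes:
        node_to_part[node] = part
        if node.name in break_after:
            part += 1
    return split_module(gm, gm, lambda n: node_to_part[n])

# e.g. split_at(gm, {"linear", "relu_1"}) yields the three split graphs above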

@alanwaketan alanwaketan requested a review from wconstab July 14, 2022 18:27
@alanwaketan
Author

alanwaketan commented Jul 15, 2022

Here is the profiler output, which shows that the gradient hooks fire in between each compiled AOT submodule.
[Screenshot: profiler trace (Screen Shot 2022-07-14 at 10.48.38 AM)]

@facebook-github-bot
Contributor

Hi @alanwaketan!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@jansel jansel removed their request for review August 2, 2022 21:02
@jansel
Contributor

jansel commented Oct 15, 2022

We have migrated torchdynamo to torch._dynamo and will use the pytorch/pytorch repo for future development. Please resubmit this PR to https://github.com/pytorch/pytorch/

More details and instructions to port this PR over can be found in #1588

@jansel jansel closed this Oct 15, 2022