
Conversation

@voznesenskym (Collaborator) commented Oct 9, 2023:

Stack from ghstack (oldest at bottom):

`is_allowed` is a tricky bit of functionality: it sits early in the builder and is used to drive the creation of TorchVariable (more notes here, Meta-only: https://fb.workplace.com/groups/pytorch.dev/permalink/1393563781222098/).

If we are tracing distributed in full, we want to route certain calls in distributed to NOT pass is_allowed (confusingly, this does not mean they are not allowed; it means we don't want them to become TorchVariable). Others we are fine with preserving.
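
A minimal sketch of the routing idea (the helper name and shape below are illustrative assumptions, not the PR's actual implementation):

def bypasses_is_allowed(obj, trace_distributed: bool) -> bool:
    # When tracing distributed in full, most torch.distributed callables
    # should fail is_allowed, so they are traced through (inlined) rather
    # than wrapped as TorchVariable.
    if not trace_distributed:
        return False
    module = getattr(obj, "__module__", "") or ""
    return module.startswith("torch.distributed")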

cc @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng

@pytorch-bot (bot) commented Oct 9, 2023:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110894

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 50bfd09 with merge base 1e7947b:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

voznesenskym added a commit that referenced this pull request Oct 9, 2023
ghstack-source-id: b820faf
Pull Request resolved: #110894
@voznesenskym changed the title from "distributed skipfile and allow check" to "Dynamo - config gated is_allowed routing, skipfiles for distributed" on Oct 9, 2023.
_module_dir(torch) + "_export/wrappers.py",
}

if torch.distributed.is_available():
voznesenskym (Collaborator, Author) commented:

config this too, I guess.

voznesenskym (Collaborator, Author) commented:

Actually, I am not sure this matters; let's see if tests fail. The inline vs. allow refactor has made this a little more confusing.
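
A rough sketch of what config-gating this block could look like (hypothetical; assumes the trace_distributed flag this PR introduces, with stand-ins for the skipfiles helpers):

import os

import torch
from torch._dynamo import config

def _module_dir(m):
    # Stand-in for the helper used in torch/_dynamo/skipfiles.py.
    return os.path.dirname(m.__file__) + "/"

FILENAME_INLINELIST = set()  # stand-in for the real inlinelist
if torch.distributed.is_available() and getattr(config, "trace_distributed", False):
    # Only register distributed files for inlining when the gate is on.
    FILENAME_INLINELIST.add(
        _module_dir(torch) + "distributed/_functional_collectives.py"
    )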

@albanD albanD removed their request for review October 9, 2023 22:03
voznesenskym added a commit that referenced this pull request Oct 9, 2023
ghstack-source-id: 9d25c33
Pull Request resolved: #110894

voznesenskym added a commit that referenced this pull request Oct 10, 2023
ghstack-source-id: 2b9902b
Pull Request resolved: #110894

@ezyang ezyang requested a review from yanboliang October 10, 2023 20:33
@ezyang (Contributor) commented Oct 10, 2023:

I'm signing off in terms of functional correctness for the hack.

However, @yanboliang should have the final say here, since this increases his workload for the refactor.

FILENAME_INLINELIST |= set(
glob.glob(_module_dir(torch) + "distributed/**/*.py", recursive=True),
)

A Contributor commented:

Can you add torch.distributed into the SUBMODULE_INLINELIST? That would be an easier way to force-inline all files under a submodule.
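
For reference, the suggestion would look roughly like this (assuming SUBMODULE_INLINELIST is the set of module-name strings in torch/_dynamo/skipfiles.py; the surrounding entries here are illustrative):

SUBMODULE_INLINELIST = {
    "torch._refs",
    "torch._prims",
    # ... existing entries ...
    "torch.distributed",  # force-inline every file under torch.distributed
}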

voznesenskym (Collaborator, Author) replied:

No, this fails for whatever reason.

voznesenskym (Collaborator, Author) commented Oct 11, 2023:

If you want a crack at debugging:

  1. voz/fsdp_autograd3
  2. build
  3. Patch this in
  4. torchrun --standalone --nproc_per_node=2 fsdp.py

A Contributor commented:

I think this would break trunk, since you land after #110835. Do you mind sending a forward fix? Or I can help forward-fix it.


# A subcheck of is_allowed; we use this to patch is_allowed around distributed.
# We do this because we want these to be traced (and hence covered in skipfiles), but we do not want them to
# become TorchVariable.
A Contributor commented:

If you don't want to make them TorchVariable, we should not have is_allowed return True for them after my refactor. If you just want to inline these functions, can you treat them as regular Python functions and go through the regular inline rules?

voznesenskym (Collaborator, Author) replied:

We need some to go one way and some the other; the logic here is correct, and they shouldn't go through inline rules.

A Contributor replied:

Do you expect them to be FX graph nodes? If yes, they should be wrapped as TorchVariable and is_allowed should return True. But that is not what the comments describe.

voznesenskym (Collaborator, Author) replied:

The problem is that they are under torch.*, so they get routed to TorchVariable. For some, like:

 if obj in [
            torch.distributed._functional_collectives_impl._all_gather_into_tensor,
            torch.distributed._functional_collectives_impl._all_reduce,
            torch.distributed._functional_collectives_impl._reduce_scatter_tensor,
            torch.distributed._functional_collectives_impl._all_reduce_coalesced,
            torch.distributed._functional_collectives_impl._all_gather_into_tensor_coalesced,
            torch.distributed._functional_collectives_impl._reduce_scatter_tensor_coalesced,
        ]:

it is correct for them to become TorchVariable.

For the others (the rest of the checks), we need to make sure they do not become TorchVariable. If we don't have this function, the wrong types will become TorchVariable instead of passing through checks like CollectiveFunctionRewriteVariable.can_rewrite(value).
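
A minimal sketch of that split, reusing the list quoted above (the helper name is hypothetical):

import torch.distributed._functional_collectives_impl as fc_impl

_IMPL_COLLECTIVES = {
    fc_impl._all_gather_into_tensor,
    fc_impl._all_reduce,
    fc_impl._reduce_scatter_tensor,
    fc_impl._all_reduce_coalesced,
    fc_impl._all_gather_into_tensor_coalesced,
    fc_impl._reduce_scatter_tensor_coalesced,
}

def distributed_is_allowed(obj) -> bool:
    # The low-level impl collectives should become TorchVariable (FX graph
    # nodes); everything else under torch.distributed must not, so checks
    # like CollectiveFunctionRewriteVariable.can_rewrite(obj) get a chance.
    return obj in _IMPL_COLLECTIVES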

A Contributor commented:

I don't quite understand Voz's explanation, but I certainly agree with Voz that there is something funny going on here.

To give an alternate example, on main, I'd like to inline into functions in _functional_collectives. I applied this patch:

diff --git a/torch/_dynamo/allowed_functions.py b/torch/_dynamo/allowed_functions.py
index 8beca2b4502..d155023ef54 100644
--- a/torch/_dynamo/allowed_functions.py
+++ b/torch/_dynamo/allowed_functions.py
@@ -175,6 +175,7 @@ def _allowed_function_ids():
             # issues observed in
             # https://github.com/pytorch/pytorch/issues/108269
             "torch.distributed.algorithms.",
+            "torch.distributed._functional_collectives.",
         )
         allowed_modules_dot = tuple([x + "." for x in allowed_modules])
         module = inspect.getmodule(obj)

But it does not work; for some reason I appear to still be placing things like all_to_all_single directly into the graph. Only with Voz's PR, my patch, and trace_distributed turned on do I get the expected inlining behavior.
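
For context, the gate discussed here is just a Dynamo config flag, so turning it on looks roughly like this (assuming it lands as torch._dynamo.config.trace_distributed, as the config excerpt further below suggests):

import torch._dynamo.config as dynamo_config

# Route torch.distributed through skipfiles/inlining instead of is_allowed.
dynamo_config.trace_distributed = True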

wconstab (Contributor) commented:

More of a meta question, but do we have a design doc for the new skipfiles/allowed_functions design? I am a little wary of all this complexity around distributed; hopefully we can do it more cleanly in a redesign.

A Contributor replied:

@wconstab I have this doc to track the new skipfiles/allowed_function thing: https://docs.google.com/document/d/15gk0B-aLGfQTdffTcFbPzA3DLR1ZwrnJawmt_4kflOY/edit?userstoinvite=jansel@meta.com&sharingaction=manageaccess&role=writer . Feel free to comment and leave feedback.

ezyang pushed a commit to ezyang/pytorch that referenced this pull request Oct 11, 2023
ghstack-source-id: 2b9902b
Pull Request resolved: pytorch#110894

[name for name, _ in inspect.getmembers(torch.Tensor) if re.match(r"^is_.*", name)]
)

trace_distributed = True
A Contributor commented:

mistake

voznesenskym added a commit that referenced this pull request Oct 11, 2023
ghstack-source-id: e99f571
Pull Request resolved: #110894

@voznesenskym changed the title from "Dynamo - config gated is_allowed routing, skipfiles for distributed" to "Dynamo - config gated torch.distributed allow, exclusion for special leaf funcs" on Oct 11, 2023.
_find_torch_objects(torch)
_find_torch_objects(math)

if config.trace_distributed:
A Contributor commented:

Do you have to import torch.distributed._fun... here?

voznesenskym (Collaborator, Author) replied:

It seems like we do not.

@voznesenskym (Collaborator, Author) commented:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 12, 2023
@voznesenskym (Collaborator, Author) commented:

@yanboliang, revert if it's not up to snuff, nbd nbd :)

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job.

@voznesenskym (Collaborator, Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/voznesenskym/239/head branch October 15, 2023 14:25