
[RFC] PT2-Friendly Traceable, Functional Collective Communication APIs #93173

Open
wconstab opened this issue Jan 27, 2023 · 43 comments
Labels: feature, module: ProxyTensor, module: pt2-dispatcher, oncall: distributed, oncall: pt2, triaged

wconstab commented Jan 27, 2023

🚀 Traceable Collectives!

Collective APIs (e.g. all_reduce, all_gather, ...) are used in distributed PyTorch programs but do not compose cleanly with compilers.

Specifically, TorchDynamo and the AOTAutograd pipeline for decompositions and functionalization do not work with the existing c10d collective APIs:

  • there are no functional variants of these collectives
  • ProcessGroup and Work objects interfere with graph tracing and pollute the IR with non-tensor objects

XLA also currently has to implement workarounds to marry the XLA collective ops (via lazy tensor tracing) with the existing PyTorch/c10d side: it has to use a custom ProcessGroup implementation and swizzle the PTD ProcessGroup creation functions.

Goals

  1. provide collectives that are traceable with the PT2 stack and XLA stack
  2. provide functional collectives, which are easier for IR transformations to reason about
  3. support eager and compiled flows with the same API
  4. use plain data types in the traced API
  5. allow tracing/compilation without requiring process group init
  6. support different frontends (DTensors, ProcessGroups, etc)
  7. support autograd for collective ops
  8. clean up c10d python bindings and dispatcher registrations

Non-goals

  1. Introduce multiple stream semantics in inductor


New traceable collectives python API

def collective(input: Tensor, *, group: GROUP_TYPE) -> AsyncTensor

GROUP_TYPE is a Union over List, DeviceMesh, ProcessGroup, etc. It allows flexible usage by different frontends.

AsyncTensor is a Tensor subclass that calls wait() automatically when the tensor is used by another op.
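
To make the auto-wait behavior concrete, here is a minimal sketch (not the proposed implementation) of an AsyncTensor-like subclass, assuming it wraps a c10d-style Work handle; a real version would, for example, let metadata calls such as size() through without waiting:

import torch

class AsyncTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, work):
        t = torch.Tensor._make_subclass(cls, data)
        t._work = work  # pending completion handle (e.g. a c10d Work object)
        return t

    def wait(self):
        work = getattr(self, "_work", None)
        if work is not None:
            work.wait()
            self._work = None
        return self

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Complete any pending collective before another op touches the data.
        for a in list(args) + list(kwargs.values()):
            if isinstance(a, AsyncTensor):
                a.wait()
        return super().__torch_function__(func, types, args, kwargs)

In the proposed API, collective(...) would return such a tensor, so code that never calls wait() explicitly still gets correct results as soon as the output feeds any other op.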

New Dispatcher Collectives

aten::collective(Tensor, *, str tag, int[] ranks, int stride) -> Tensor

These are the ops that actually get traced into a graph and can be manipulated by compiler passes.

The collective ops are functional, but compilers may be able to convert them to inplace. They are asynchronous.

These ops support meta device (for traceability), and support backwards via derivatives.yaml.

The semantics of these ops are that they return a real tensor, but you aren't allowed to access its data or storage.
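
As an illustration of the meta-device point above, here is a hedged sketch of defining such an op and giving it a meta kernel via torch.library; the demo_c10d namespace and op signature are hypothetical, not the actual registrations in PyTorch:

import torch
from torch.library import Library, impl

demo_lib = Library("demo_c10d", "DEF")  # hypothetical namespace for illustration
demo_lib.define("all_reduce(Tensor input, str tag, int[] ranks, int group_size) -> Tensor")

@impl(demo_lib, "all_reduce", "Meta")
def _all_reduce_meta(input, tag, ranks, group_size):
    # Shape/dtype propagation only; no communication, no process group required.
    return torch.empty_like(input)

x = torch.randn(8, device="meta")
out = torch.ops.demo_c10d.all_reduce(x, "", [0, 1, 2, 3], 4)
print(out.shape)  # torch.Size([8]), produced without any distributed init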

c10d.wait(Tensor) -> Tensor

wait() must be called on the output of any collective before its underlying data or storage is accessed.

  • It is valid to peek at the size() or stride() (and probably other metadata) of a tensor returned from a collective, but not its data.
  • wait() is the only way to make an output from a collective safe to use by other, non-collective ops (see the sketch below).
  • We are considering whether wait(collective(collective(...))) can be implemented safely, but by default we assume it is not.
  • The semantics of wait are that you must only access the storage of the tensor returned from wait; you should not think of wait as mutating its input tensor and making it safe to use.
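
A short sketch of the intended calling pattern, using this RFC's placeholder names (collective and wait are not real functions here):

out = collective(x, tag="", ranks=[0, 1, 2, 3], stride=4)  # functional, asynchronous

n = out.numel()   # OK: size()/stride()/metadata may be inspected before waiting
out = wait(out)   # only the tensor returned by wait() is safe to read
y = out + 1       # OK: data access after wait()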

Alternatives

The following style of API has also been considered. Its main disadvantage is that it requires a user to first initialize a ProcessGroup; it is also opaque and not easily interchangeable with lists of ranks or DTensors, and it doesn't let us easily represent MPMD collectives.

pg = init_process_group()
pg_id = dist.register_process_group(pg)
collective(tensor, pg_id)

Detailed Proposal

See Traceable Collectives Design

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @soumith @ngimel

@alanwaketan
Collaborator

Thanks, @wconstab. This is awesome! (Leaving a comment to get myself cc'd on the issue.)

@wconstab
Contributor Author

So one thing @ezyang pointed out is that we might not be able to support list[list[]] in our current codegen pipeline. But if we can assume a rectangular shape, we can just do list[] and have an extra stride int that lets us recover the second dimension.

But apart from that- is this api going to help for XLA? And is there anything wrong with it?

@ezyang
Contributor

ezyang commented Jan 27, 2023

Remind me what tag is?

@JackCaoG
Collaborator

@wconstab Can you expand a bit on "do list[] and have an extra stride int that lets us make the second dimension"? Is it something like [0,1,2,3,4,5,6,7], where for stride 2 we get two groups, [0,2,4,6] and [1,3,5,7]? It is a bit confusing TBH lol.

@ezyang
Contributor

ezyang commented Jan 28, 2023

Well, idk what Will has in mind, but I would instead have repped it the other way: [1,2,3,4,5,6,7,8], stride=4 gives me [[1,2,3,4],[5,6,7,8]]
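
A minimal sketch of this flat-ranks encoding (illustrative only; the schema calls the parameter stride, though group_size may be the clearer name, as noted below):

def unflatten_ranks(ranks, stride):
    # Split a flat, rectangular rank list into contiguous groups of `stride`.
    assert len(ranks) % stride == 0
    return [ranks[i:i + stride] for i in range(0, len(ranks), stride)]

unflatten_ranks([1, 2, 3, 4, 5, 6, 7, 8], stride=4)
# -> [[1, 2, 3, 4], [5, 6, 7, 8]]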

@wconstab
Contributor Author

Remind me what tag is?

@ezyang I'm not totally sure we need tag, but @kumpera has offered a few explanations for why it's needed, and XLA has a similar 'channel_id' field that, IIUC, is used for a similar purpose.

The gist is that you may want to allow more than one ProcessGroup (which owns real things like CUDA streams) to exist for the same set of ranks, and let user code or compiler code be explicit about which one it's working with. @kumpera can say more.

@kumpera
Contributor

kumpera commented Jan 30, 2023

Well, idk what Will has in mind, but I would instead have repped it the other way: [1,2,3,4,5,6,7,8], stride=4 gives me [[1,2,3,4],[5,6,7,8]]

stride is probably unhelpful here, maybe group_size would be a better name.

@kumpera
Contributor

kumpera commented Jan 30, 2023

Remind me what tag is?

The tag enables coordinating which PGs to use with the world outside of a compiled function, which is important for things like checkpointing and other eager affairs.

And within a function, it enables the same set of ranks to use different PGs, enabling things like different stream priorities.

@ezyang
Contributor

ezyang commented Jan 30, 2023

So what, we have to parse a string? Would it be better for it to be an int?

@kumpera
Contributor

kumpera commented Jan 30, 2023

so what, we have to parse a string? Would it be better to be an int?

The tag is just a key in a dictionary, so either way works.
The advantage of a string is that it would be easier to read in graph dumps.

@0x6b64

0x6b64 commented Feb 1, 2023

Hi @wconstab, thank you so much for creating an RFC for this work. A few questions for my clarity:

  1. Is it right to say that the ProcessGroup/Backend API will continue to stay the same (especially from the perspective of a 3p vendor implementing custom collective backends)? It looks like there is a parallel development to refactor the ProcessGroup interface here: [Feature] Dispatching PyTorch Distributed Collectives #86225

    The most noticeable part is that ProcessGroup will house a registry of backends, each of which maps to the device for which it will work. At runtime, the actual collective backend is chosen depending on the device on which the input resides.

    Am I understanding correctly that the collective API wraps the call to ProcessGroup->Backend and holds onto the returned work item? I can see this in the implementation of _all_reduce in the WIP PR: https://github.com/pytorch/pytorch/pull/93379/files

    Also, there seems to be a concept of ProcessGroupWrappers in distributed_c10d. Is it right to say that these don't need to be decomposed into chained operations?

  2. I'm inferring that the DDP module will be updated to call the traceable_collectives allreduce instead of dist.allreduce, is this right?

  3. Does the inclusion of the traceable collectives in the DDP graph immediately imply that the DDP strategy used prior to DDPOptimizer will be resurrected, and that the emergent positioning of the collective calls will result in optimal compute/network overlap? Point 3 is phrased loosely, but I hope it communicates the general question I'm asking.

@wconstab
Contributor Author

wconstab commented Feb 1, 2023

For (1), we are thinking of building on top of / alongside the existing ProcessGroup. So while we offer new Python frontend APIs (list of ranks, ProcessGroup, dtensor.mesh), we can 'desugar' these representations to a canonical format based on list+int which we can trace into a graph. Then backends (such as inductor) can translate them back into a real ProcessGroup object and leverage the existing PG backends. A sketch of this desugaring is below.

For (2, 3), we probably won't update the existing DDP codebase, but we may offer some new way of getting a data-parallel model that is compiler friendly. In general, beyond the simple DDP style, we want to make the inductor backend more capable of optimizing comm+compute overlap, memory allocations, etc. Thus it should become simpler to write a fairly naive 'DDP frontend' that then becomes optimal via the compiler backend.
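
A hedged sketch of that desugaring (the helper name desugar_group and the exact canonical triple are illustrative, not the real implementation; it assumes torch.distributed.get_process_group_ranks is available):

from typing import List, Tuple
import torch.distributed as dist

def desugar_group(group) -> Tuple[str, List[int], int]:
    # Normalize a frontend group representation to (tag, ranks, group_size).
    if isinstance(group, list):                 # plain list of ranks
        return "", group, len(group)
    if isinstance(group, dist.ProcessGroup):    # existing PG object
        ranks = dist.get_process_group_ranks(group)
        return "", ranks, len(ranks)
    # DeviceMesh / DTensor meshes would be handled similarly.
    raise NotImplementedError(f"unsupported group type: {type(group)}")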

@anj-s

anj-s commented Feb 2, 2023

Does this API assume we need to know the ranks participating in the collectives?
aten::collective(Tensor, *, str tag, int[] ranks, int stride) -> Tensor
How can we get this information if we have not initialized the PG?

@wconstab
Contributor Author

wconstab commented Feb 2, 2023

My assumption is that it is generally possible to code up your script in terms of one input, WORLD_SIZE, which usually comes from the environment anyway, so you can get it without initializing distributed, and definitely without initializing any PGs (small example below).
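
For illustration, the 'peek at ENV' pattern looks like this (WORLD_SIZE is set by torchrun and similar launchers):

import os

# No init_process_group needed just to build the canonical ranks list.
world_size = int(os.environ.get("WORLD_SIZE", "1"))
ranks = list(range(world_size))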

I do want to discuss a few related points further:

  1. is it bad practice to just 'peek at ENV to get world size'?

  2. should we tell people to init distributed before they get world_size?
    2a) relatedly, should we init distributed (but not init any process groups) before tracing a distributed program? Why?

  3. how important is tracing a size-agnostic program? Should we prioritize a way to trace a program that will work under any, say, power of 2 variant of WORLD_SIZE? If so this seems hard with our canonical form of list for ranks, even if we allowed list instead, since the size of that list is ultimately the thing we'd want to vary.

@anj-s

anj-s commented Feb 2, 2023

My assumption is it is generally possible to code up your script in terms of one input WORLD_SIZE, which usually comes from ENV anyway, so you can get it without initializing distributed. And definitely without initializing any pgs.

I think we can assume that. However, I am just wondering: 1) are the ranks going to be passed to the SPMD API? 2) Why not init distributed?

I do want to discuss a few related points further:

  1. is it bad practice to just 'peek at ENV to get world size'?

I like ENV vars; however, having a self-contained script has its advantages.

  2. should we tell people to init distributed before they get world_size?
    2a) relatedly, should we init distributed (but not init any process groups) before tracing a distributed program? Why?

Yes, same question here. For initializing process groups, we need to know the mesh dimensions over which we will be sharding our model. This might not happen until after the autoparallel profiling. However, initializing a global PG seems fine to me.

  3. how important is tracing a size-agnostic program? Should we prioritize a way to trace a program that will work under any, say, power of 2 variant of WORLD_SIZE? If so this seems hard with our canonical form of list for ranks, even if we allowed list instead, since the size of that list is ultimately the thing we'd want to vary.

I think we need to know the cluster size to shard the model optimally. Given that our sharding decision may be based on cluster size, we might want to retrace. If we assume we don't want to reshard, is the tracing cost so high that we need to skip it?

@wconstab
Contributor Author

wconstab commented Feb 3, 2023

If we don't anticipate it being useful to trace a model (export it) and then later decide what the cluster size is, then we can just ignore that line of thought. However, at least for simple things like DDP/FSDP use cases this seems like a pretty nice thing to be able to do.

I don't think tracing cost is that high relative to other aspects of distributed compilation. It's really just for export and reuse that I think that could be useful.

@awgu
Contributor

awgu commented Feb 3, 2023

The semantics of these ops are that they return a real tensor, but you aren't allowed to access its data or storage.
...

  • It is valid to peek at the size() or stride() (or probably other metadata) of a tensor returned from a collective, but not its data.

What is the difference between the tensor "data" vs. "storage"?

pytorchmergebot pushed a commit that referenced this issue Feb 22, 2024
…20226)

As communicated in #93173 (comment), although we are dropping `(ranks, tag)` as group identifier in funcols, there will be a grace period for migration. This PR adds temporary `(ranks, tag)` support in native funcols. It also helps us decouple the py funcol -> native funcol transition from the API change.

Pull Request resolved: #120226
Approved by: https://github.com/wanchaol, https://github.com/wconstab
ghstack dependencies: #120042, #120043, #120070
pytorchmergebot pushed a commit that referenced this issue Feb 27, 2024
This enables native functional collectives by default. After this PR:
- The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(ranks, tag)` as the process group identifier.
- Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173).
- Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users.

Testing performed:
- We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These tests will continue to run with the old py funcol after this PR, so they are covered until they are removed.
- Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env).

Fallback mechanism:
- Introduced a temporary environment variable `TORCH_DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages.

Pull Request resolved: #120370
Approved by: https://github.com/wconstab, https://github.com/yf225
jhavukainen pushed a commit to kulinseth/pytorch that referenced this issue Mar 15, 2024
This change adds torch.distributed.traceable_collectives.

This experimental API enables collectives to be fully traced by dynamo and FX.

See pytorch#93173 for the RFC

Pull Request resolved: pytorch#93990
Approved by: https://github.com/wconstab, https://github.com/wanchaol, https://github.com/H-Huang
alanwaketan pushed a commit to pytorch/xla that referenced this issue Apr 10, 2024
…op (#6887)

PyTorch has implemented a new set of functional collective ops and is planning to remove the old ops. Migrating all_reduce to use the new op.

See context in pytorch/pytorch#93173 (comment)
@tombousso

tombousso commented May 21, 2024

@wconstab

The collective ops are functional, but compilers may be able to convert them to inplace. They are asynchronous.

This creates a potential issue of large functional collective ops exhausting GPU memory, since they are not inplace.

  1. Is there any support currently in inductor for optimizing the functional collectives to inplace?
  2. If choosing to run a functional collective op in eager mode, is there any way to optimize to inplace?

@wconstab
Contributor Author

@wconstab

The collective ops are functional, but compilers may be able to convert them to inplace. They are asynchronous.

This creates a potential issue of large functional collective ops exhausting GPU memory, since they are not inplace.

  1. Is there any support currently in inductor for optimizing the functional collectives to inplace?

  2. If choosing to run a functional collective op in eager mode, is there any way to optimize to inplace?

For (1), yes, Inductor should reinplace where possible and generally optimize memory usage within the compiled region.

For (2), not literally: if you call a functional collective in eager you get a functional collective, along with its memory usage. (It would be a hit to peak memory, but it could be recouped if the tensor is discarded after some other operation.) As a general pattern, though, we started trying to make TorchDynamo able to trace the existing in-place collective ops and reinterpret them as functional ones for tracing purposes, so if you want to keep using in-place collectives in eager for performance reasons you may still have a path to compilation.
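
For reference, a hedged sketch of the eager functional-collective pattern discussed here; it assumes the torch.distributed._functional_collectives module and an initialized default process group, and exact names may differ across versions:

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def eager_functional_allreduce(x: torch.Tensor) -> torch.Tensor:
    # Functional: returns a new tensor instead of mutating x, so it costs
    # extra memory relative to in-place dist.all_reduce.
    y = funcol.all_reduce(x, "sum", dist.group.WORLD)
    # The result is waited on implicitly the first time another op uses it.
    return y + 0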

@yifuwang or @yf225 may have more up to date information.

@ConnollyLeon

Hi @wconstab, when will you support async all_gather? It's commonly used in Megatron-LM sequence parallelism.

@YangFei1990

Hi @wconstab, when will you support async all_gather? It's commonly used in Megatron-LM sequence parallelism.

Hi @ConnollyLeon, could you share more details on this async all_gather in Megatron-LM SP? Any pointer to the code/docs would be great. Thanks.

@ConnollyLeon

Hi @wconstab, when will you support async all_gather? It's commonly used in Megatron-LM sequence parallelism.

Hi @ConnollyLeon, could you share more details on this async all_gather in Megatron-LM SP? Any pointer to the code/docs would be great. Thanks.

Sure @YangFei1990, please check here for all_gather and here for reduce_scatter.
