[RFC] PT2-Friendly Traceable, Functional Collective Communication APIs #93173
Comments
Thanks, @wconstab. This is awesome! (Leaving a comment to have myself cc'd on the issue.)
So one thing @ezyang pointed out is that we might not be able to support … But apart from that: is this API going to help for XLA? And is there anything wrong with it?
Remind me what tag is?
@wconstab Can you expand a bit on …?
Well, idk what Will has in mind, but I would instead have represented it the other way: [1,2,3,4,5,6,7,8] with stride=4 gives me [[1,2,3,4],[5,6,7,8]].
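For concreteness, here is a small runnable sketch of the flat-list-plus-stride decoding described above (the helper name is hypothetical; the RFC had not fixed this encoding):

```python
# Hypothetical helper (not from the RFC): decode a flat rank list
# plus a stride into groups of ranks, matching the example above.
from typing import List

def unflatten_ranks(ranks: List[int], stride: int) -> List[List[int]]:
    assert len(ranks) % stride == 0, "rank list must divide evenly"
    return [ranks[i:i + stride] for i in range(0, len(ranks), stride)]

# [1,2,3,4,5,6,7,8] with stride=4 -> [[1,2,3,4],[5,6,7,8]]
assert unflatten_ranks([1, 2, 3, 4, 5, 6, 7, 8], stride=4) == [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
]
```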
@ezyang I'm not totally sure we need tag, but @kumpera has offered a few explanations for why it's needed, and XLA has a similar 'channel_id' field that iiuc is used for a similar purpose. The gist is that you may want to allow more than one ProcessGroup (which owns real resources like CUDA streams) to exist for the same set of ranks, and let user code or compiler code be explicit about which one it's working with. @kumpera can say more.
stride is probably unhelpful here, maybe …
Tag enables coordination of which PGs to use with the world outside of a compilation function, which is important for things like checkpointing or other eager affairs. And within a function, it enables the same set of ranks to use different PGs and enables things like different stream priorities.
So what, we have to parse a string? Would it be better for it to be an int?
The tag is just a key in a dictionary, so either way works.
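A minimal sketch of that dictionary-style resolution, with hypothetical names rather than the actual c10d internals:

```python
# Hypothetical sketch of tag-based group resolution (not the actual
# c10d code): (tag, ranks) keys a cache of groups, so the same rank
# set can own several groups (e.g. different stream priorities)
# distinguished only by tag. A str or an int tag works equally well.
from typing import Dict, List, Tuple

class DummyGroup:
    def __init__(self, ranks: List[int], tag: str) -> None:
        self.ranks, self.tag = ranks, tag

_group_cache: Dict[Tuple[str, Tuple[int, ...]], DummyGroup] = {}

def resolve_group(ranks: List[int], tag: str = "") -> DummyGroup:
    key = (tag, tuple(sorted(ranks)))
    if key not in _group_cache:
        _group_cache[key] = DummyGroup(ranks, tag)
    return _group_cache[key]

# Same ranks, different tags -> distinct groups.
assert resolve_group([0, 1], "hi_prio") is not resolve_group([0, 1], "lo_prio")
assert resolve_group([0, 1], "hi_prio") is resolve_group([1, 0], "hi_prio")
```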
Hi @wconstab, thank you so much for creating an RFC for this work. A few questions for my clarity: …
For (1), we are thinking of building on top of / alongside the existing ProcessGroup. So while we offer new Python frontend APIs (list of ranks, ProcessGroup, DTensor mesh), we can 'desugar' these representations to a canonical format based on list+int which we can trace into a graph. Then backends (such as Inductor) can translate them back into a real ProcessGroup object and leverage the existing PG backends; a rough sketch of this desugaring follows. For (2, 3), we probably won't update the existing DDP codebase, but we may offer some new way of getting a data-parallel model that is compiler friendly. In general, beyond the simple DDP style, we want to make the Inductor backend more capable of optimizing comm+compute overlap, memory allocations, etc. Thus it should become simpler to write a pretty naive 'DDP frontend' that then becomes optimal via the compiler backend.
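A hypothetical desugaring helper, for illustration only (not the actual frontend code):

```python
# Hypothetical desugaring helper: normalize the flexible group
# argument into a canonical form built only from lists and ints,
# which is easy to trace into a graph.
from typing import Any, List, Tuple

def desugar_group(group: Any) -> Tuple[List[int], int]:
    """Return (participating ranks, group size)."""
    if isinstance(group, list):                # plain list of ranks
        ranks = group
    elif hasattr(group, "mesh"):               # DeviceMesh-like object
        ranks = group.mesh.flatten().tolist()
    elif hasattr(group, "size"):               # ProcessGroup-like object
        ranks = list(range(group.size()))      # assumes dense rank mapping
    else:
        raise TypeError(f"unsupported group type: {type(group)}")
    return ranks, len(ranks)

assert desugar_group([0, 1, 2, 3]) == ([0, 1, 2, 3], 4)
```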
Does this API assume we need to know the ranks participating in the collectives?
My assumption is that it is generally possible to code up your script in terms of one input, WORLD_SIZE, which usually comes from the ENV anyway, so you can get it without initializing distributed, and definitely without initializing any PGs. I do want to discuss a few related points further: …
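A minimal sketch of the ENV-based pattern (WORLD_SIZE is the standard variable exported by launchers such as torchrun):

```python
# Launchers such as torchrun export WORLD_SIZE, so a script can size
# its sharding plan from the environment without calling
# torch.distributed.init_process_group first.
import os

world_size = int(os.environ.get("WORLD_SIZE", "1"))
ranks = list(range(world_size))
```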
I think we can assume that. However, I am just wondering: 1) are the ranks going to be passed to the SPMD API? 2) Why not init distributed?
I like ENV vars; however, having a self-contained script has its advantages.
Yes, same question here. For initializing process groups, we need to know the mesh dimensions over which we will be sharding our model. This might not happen until after the autoparallel profiling. However, initializing a global PG seems fine to me.
I think we need to know the cluster size to shard the model optimally. Given that our sharding decision may be based on cluster size, we might want to retrace. If we assume we don't want to reshard, is the tracing cost so high that we need to skip it?
If we don't anticipate it being useful to trace a model (export it) and then later decide what the cluster size is, then we can just ignore that line of thought. However, at least for simple things like DDP/FSDP use cases this seems like a pretty nice thing to be able to do. I don't think tracing cost is that high relative to other aspects of distributed compilation. It's really just for export and reuse that I think that could be useful.
What is the difference between the tensor "data" vs. "storage"?
…20226) As communicated in #93173 (comment), although we are dropping `(ranks, tag)` as the group identifier in funcols, there will be a grace period for migration. This PR adds temporary `(ranks, tag)` support in native funcols. It also helps us decouple the py funcol -> native funcol transition from the API change. Pull Request resolved: #120226 Approved by: https://github.com/wanchaol, https://github.com/wconstab ghstack dependencies: #120042, #120043, #120070
…default" This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…default" This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…default" This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…default" This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
This enables native functional collectives by default. After this PR:
- The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(ranks, tag)` as the process group identifier.
- Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla`, which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173).
- Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users.

Testing performed:
- We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These tests will continue to run with the old py funcol after this PR, so they are covered until they are removed.
- Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env).

Fallback mechanism:
- Introduced a temporary environment variable `TORCH_DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages.

Pull Request resolved: #120370 Approved by: https://github.com/wconstab, https://github.com/yf225
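For illustration, a hedged sketch of the identifier migration described above (signatures assumed from the `torch.distributed._functional_collectives` prototype; verify against your torch version):

```python
# Hedged sketch of the identifier migration; assumes the prototype
# signature all_reduce(tensor, reduceOp, group, tag="") and an
# already-initialized default process group.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

t = torch.ones(4)

# Deprecated: identify the group by ranks (+ tag).
out_old = funcol.all_reduce(t, "sum", group=[0, 1, 2, 3], tag="")

# Preferred: pass the ProcessGroup (or DeviceMesh) itself.
out_new = funcol.all_reduce(t, "sum", group=dist.group.WORLD)
```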
…default" This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `TORCH_DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `TORCH_DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…default" This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `TORCH_DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `TORCH_DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
This enables native functional collectives by default. After this PR: - The Python APIs remain backward compatible. Users will receive a deprecation warning if they use `(rank, tags)` as process group identifier. - Collectives will be captured as `_c10d_functional` ops in post-grad fx graphs. The change will not affect end-users, but it will impact `torch-xla` which has implemented an all-reduce backend based on the existing `c10d_functional` IR. This excludes the migration for `torch-xla` use cases, which will be coordinated separately (see communications in #93173). - Collectives will be lowered to and codegen'd by new Inductor collective IRs (`ir._CollectiveKernel` and `ir._WaitKernel`). This change will not affect end-users. Testing performed: - We have been running a set of representative unit tests with both the new native funcol and the old py funcol in CI. These test will continue to run with the old py funcol after this PR, so they are covered until they are removed. - Manually verified with e2e llama model training with DTensor + functional collectives (https://github.com/fairinternal/xlformers/tree/pt2_llm/pt2d#create-your-local-development-env). Fallback mechansim: - Introduced a temporary environment variable `TORCH_DISABLE_NATIVE_FUNCOL` that allows users to fall back to the previous implementation. We don't expect the migration to break anything; the mechanism is a safety measure to reduce potential disruption in case the PR causes unforeseen breakages. Pull Request resolved: #120370 Approved by: https://github.com/wconstab, https://github.com/yf225
This change adds torch.distributed.traceable_collectives. This experimental API enables collectives to be fully traced by dynamo and FX. See pytorch#93173 for the RFC. Pull Request resolved: pytorch#93990 Approved by: https://github.com/wconstab, https://github.com/wanchaol, https://github.com/H-Huang
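As a hedged sketch of what "fully traced" means in practice (assumes an initialized default process group; module path per the funcol prototype):

```python
# A compiled function containing a functional collective, which
# dynamo/FX can capture as a graph node instead of graph-breaking.
import torch
import torch.distributed._functional_collectives as funcol

@torch.compile
def step(x: torch.Tensor) -> torch.Tensor:
    y = x * 2
    return funcol.all_reduce(y, "sum", group=[0, 1])
```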
…op (#6887) PyTorch has implemented a new set of functional collective ops and is planning to remove the old ops. Migrating all_reduce to use the new op. See context in pytorch/pytorch#93173 (comment)
This creates a potential issue of large functional collective ops exhausting GPU memory, since they are not in-place. …
For (1), yes, Inductor should re-inplace where possible and generally optimize memory usage within the compiled region. For (2), not literally: if you call a functional collective in eager, you get a functional collective along with its memory usage. (It would be a hit to peak memory, but it could be recouped if the tensor is discarded after some other operation.) But as a general pattern, we started trying to make TorchDynamo able to trace the existing in-place collective ops and reinterpret them as functional ones for tracing purposes, so if you want to keep using in-place collectives in eager for performance reasons, you may still have a path to compilation.
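To make the memory tradeoff concrete, a hedged sketch contrasting the two styles (assumes an initialized default process group and the prototype funcol module path):

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

t = torch.ones(1024)

# In-place eager collective: mutates t, no extra allocation.
dist.all_reduce(t)

# Functional collective: returns a fresh tensor, so t and out are
# both live until t is released; inside a compiled region Inductor
# may re-inplace this and recover the extra memory.
out = funcol.all_reduce(t, "sum", group=dist.group.WORLD)
```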
Hi @wconstab, when will you support async all_gather? It's commonly used in Megatron-LM sequence parallelism.
Hi @ConnollyLeon, could you share more details on this async all_gather in Megatron-LM SP? Any pointer to the code/doc would be great. Thanks.
Sure @YangFei1990, please check here for all_gather and here for reduce_scatter.
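For readers following along, a generic PyTorch sketch of the async-collective overlap pattern being discussed (not the Megatron-LM code; assumes an initialized default group):

```python
import torch
import torch.distributed as dist

inp = torch.randn(1024)
out = torch.empty(1024 * dist.get_world_size())

# Kick off the all-gather asynchronously, overlap independent
# compute, then wait before consuming the gathered result.
work = dist.all_gather_into_tensor(out, inp, async_op=True)
local = inp * 2            # overlaps with the communication
work.wait()
result = out.sum() + local.sum()
```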
🚀 Traceable Collectives!
Collective APIs (e.g. all_reduce, all_gather, ...) are used in distributed PyTorch programs, but do not compose cleanly with compilers.
Specifically, TorchDynamo and the AOTAutograd pipeline for decompositions and functionalization do not work with the existing c10d collective APIs.
XLA also currently has to implement some workarounds to marry the XLA collective ops (via lazy tensor tracing) with the existing PyTorch / c10d side. They have to use a custom ProcessGroup implementation and swizzle PTD PG-creation functions.
Goals
Non-goals
New traceable collectives python API
`GROUP_TYPE` is a Union over List, DeviceMesh, ProcessGroup, etc. It allows flexible usage by different frontends.
`AsyncTensor` is a Tensor subclass that calls `wait()` automatically when the tensor is used by another op.
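A hedged sketch of the proposed frontend behavior (assumes an initialized default group and the `torch.distributed._functional_collectives` prototype):

```python
import torch
import torch.distributed._functional_collectives as funcol

t = torch.ones(4)
# group accepts several GROUP_TYPE forms; a ranks list is used here.
out = funcol.all_reduce(t, "sum", group=[0, 1, 2, 3])

# out behaves like an AsyncTensor: using it in another op triggers
# wait() implicitly, so no explicit synchronization is needed.
res = out + 1
```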
New Dispatcher Collectives
These are the ops that actually get traced into a graph and can be manipulated by compiler passes.
The collective ops are functional, but compilers may be able to convert them to in-place variants. They are asynchronous.
These ops support meta device (for traceability), and support backwards via derivatives.yaml.
The semantics of these ops are that they return a real tensor, but you aren't allowed to access its data or storage.
`wait()` must be called on the output of any collective before its underlying data or storage is accessed. The semantics of `wait()` are that you must only access the storage of the tensor returned from `wait()`; you can't think of `wait()` as mutating its input tensor and making it safe to use.
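A hedged sketch of this dispatcher-level contract; the op namespace and the (tag, ranks, group_size) signature follow the original c10d_functional ops described in this RFC and may differ in newer releases:

```python
import torch

t = torch.ones(4)
pending = torch.ops.c10d_functional.all_reduce(t, "sum", "", [0, 1], 2)
# `pending` is a real tensor, but its data/storage must not be read
# until wait_tensor has been called on it.
out = torch.ops.c10d_functional.wait_tensor(pending)
```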
Alternatives
The following style of API has also been considered. Its main disadvantage is that it requires a user to first initialize a ProcessGroup; it is also opaque and not easily interchangeable with lists of ranks or DTensors, and it doesn't allow us to easily represent MPMD collectives.
Detailed Proposal
See Traceable Collectives Design
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @soumith @ngimel