Using CC ops with the mark_sharding API throws an error #6647
I think what you are trying to do is
I think this won't work out of the box because
The problem here is that I don't think there is an easy way to change the PJRT device config on the fly. @will-cromar @yeounoh in case you guys have better suggestions. @baoleai I remember you mentioned something about SPMD + PP; wondering if you have some insight as well.
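(For context on why changing it on the fly is hard: the PJRT configuration is read once, when the client is created. A minimal sketch; `PJRT_DEVICE` is the real torch_xla setting, the rest is illustrative:)

```python
import os

# PJRT reads its configuration when the client is first created,
# typically from environment variables set before any XLA work:
os.environ['PJRT_DEVICE'] = 'TPU'  # or 'CUDA' / 'CPU'

import torch_xla.core.xla_model as xm

device = xm.xla_device()  # the PJRT client is instantiated here; after
                          # this point the runtime config is effectively fixed
```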
Currently, SPMD cannot support communication operators at the Python layer. When combining SPMD-TP and PP, we made numerous changes to xla and the openxla spmd pass to support send/recv @yitongh. Supporting the all-reduce communication operator might be more complicated.
Based on previous experience, a number of changes are needed on the GPU side. Even with that handling, the all-reduce operator currently cannot handle sharded inputs well and can only function as a replicated operation. Similar handling may be required in the TPU environment. Overall, there doesn't seem to be a particularly elegant way to support Python-side communication in the SPMD environment at the moment. Perhaps, as JackCaoG suggested, changing the configuration of the PJRT device might be a good approach.
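To make "function as a replicated operation" concrete, a minimal sketch (assuming the usual torch_xla APIs; untested):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

xr.use_spmd()  # enable SPMD execution mode

# As long as the input tensor stays replicated (never passed to
# mark_sharding), the Python-level all-reduce can still be traced:
t = torch.randn(8, 8).to(xm.xla_device())
out = xm.all_reduce(xm.REDUCE_SUM, t)
xm.mark_step()
```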
@baoleai @yitongh is the send/recv using XLA Send/Recv? Is there any way to skip the sharding_propagation pass? The collective can live in an isolated graph (cut off using mark_step, see the sketch below), and we could use a custom_call or an attribute to skip "peeking" into the all-reduce.
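(Roughly the isolation I mean; whether the pass can actually be skipped for the second graph is exactly the open question:)

```python
import torch
import torch_xla.core.xla_model as xm

x = torch.randn(8, 8).to(xm.xla_device())

# Graph 1: the sharded SPMD computation.
y = x @ x
xm.mark_step()  # cuts the trace; graph 1 compiles and runs on its own

# Graph 2: only the collective, traced separately. A custom_call or an
# attribute would then only need to hide this small graph from the
# sharding_propagation pass.
y = xm.all_reduce(xm.REDUCE_SUM, y)
xm.mark_step()
```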
@JackCaoG For the CC ops set-up, why do we need to set up PjRT in a different way? All we need is the graph with the correct replica groups, correct? (These can be borrowed from the mesh during SPMD set-up, roughly as sketched below.) The PjRT runtime would just execute this on all "threads" (we don't need these to be different processes), and the all-reduce would look like any other all-reduce produced by an SPMD partitioner pass.
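(A sketch of what "borrowed from the mesh" could look like; the `groups` argument to `xm.all_reduce` is real, deriving it from the SPMD mesh is the part being proposed:)

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm

# Suppose the SPMD set-up used a (2, 4) mesh with axes ('data', 'model').
device_ids = np.arange(8).reshape(2, 4)

# Replica groups for an all-reduce over the 'model' axis: one group per
# row of the mesh, i.e. [[0, 1, 2, 3], [4, 5, 6, 7]].
groups = device_ids.tolist()

t = torch.ones(4, 4).to(xm.xla_device())
reduced = xm.all_reduce(xm.REDUCE_SUM, t, groups=groups)
xm.mark_step()
```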
🐛 Describe the bug
The crash seen is the following:
```
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
F0000 00:00:1709131242.311197 36940 hlo_sharding.cc:1034] Check failed: IsTuple()
*** Check failure stack trace: ***
@ 0x7f1e46d752d9 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7f1e40ed9700 xla::HloSharding::GetSubSharding()
@ 0x7f1e41cadd35 xla::ShardingPropagation::InferShardingFromOperands()
@ 0x7f1e41cb1cec xla::ShardingPropagation::Run()::{lambda()#3}::operator()()
@ 0x7f1e41cb5d43 xla::ShardingPropagation::Run()
@ 0x7f1e41c98355 xla::HloPassPipeline::RunHelper()
@ 0x7f1e41c9933a xla::HloPassPipeline::RunPassesInternal<>()
@ 0x7f1e41c99fa4 xla::HloPassPipeline::Run()
@ 0x7f1e41100d49 neuron::HloOptimization()
@ 0x7f1e410a3ab9 neuron::Optimize()
@ 0x7f1e4109f07e neuron::PJRT_Client_Compile()
@ 0x7f1e410a0638 neuron::Decorator<>::wrapper()
@ 0x7f1e51d966c5 xla::InitializeArgsAndCompile()
@ 0x7f1e51d969e0 xla::PjRtCApiClient::Compile()
@ 0x7f1e4d3411e6 torch_xla::runtime::PjRtComputationClient::Compile()
@ 0x7f1e4d14853e torch_xla::XLAGraphExecutor::Compile()
@ 0x7f1e4d149f49 torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal()
@ 0x7f1e4d14a58b torch_xla::XLAGraphExecutor::SyncTensorsGraph()
@ 0x7f1e4d14a9b8 torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph()
@ 0x7f1e4cf1928a torch_xla::(anonymous namespace)::StepMarker()
@ 0x7f1e4cf196c6 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
@ 0x7f1e4cef6ed0 pybind11::cpp_function::dispatcher()
@ 0x5d5499 PyCFunction_Call
Aborted (core dumped)
```
A simple example to reproduce the bug is attached below:
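The sketch below is not the original attachment, only a minimal reconstruction of the pattern described above (mark_sharding plus a Python-level CC op), assuming current torch_xla SPMD APIs:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (n,), ('data',))

t = torch.randn(16, 16).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('data', None))  # shard dim 0 across the mesh

# Mixing a Python-level CC op with an SPMD-sharded tensor: sharding
# propagation then hits the all-reduce during compilation and crashes.
out = xm.all_reduce(xm.REDUCE_SUM, t)
xm.mark_step()
```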