Enable SPMD + dynamo for inference #5002
Conversation
OK There are 2 issues
Thanks @JackCaoG, I am going to merge an output param sharding patch, which might change the code path a bit. Let's chat offline; I can explain further.
…aceholder if SPMD is enabled
// Device will be Virtual device if SPMD is enabled.
torch::lazy::BackendDevice device =
    ShardingUtil::UseVirtualDevice() ? ParseDeviceString("SPMD:0")
                                     : torch_xla::GetCurrentDevice();
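The device-selection rule in the hunk above can be sketched in Python. This is an illustrative stand-in, not the actual torch_xla API; the function name and device strings are hypothetical, mirroring the C++ ternary: when SPMD is enabled, tracing targets the virtual `SPMD:0` device instead of the real current device.

```python
# Hypothetical sketch of the device-selection rule above (not torch_xla code).
def pick_tracing_device(use_virtual_device: bool, current_device: str) -> str:
    """Return the device the graph executor should trace against.

    Under SPMD, tracing happens against the single virtual device "SPMD:0"
    rather than any one physical device.
    """
    return "SPMD:0" if use_virtual_device else current_device

print(pick_tracing_device(True, "TPU:0"))   # SPMD:0
print(pick_tracing_device(False, "TPU:0"))  # TPU:0
```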
@yeounoh I am not sure if we should just update `GetCurrentDevice`, any thoughts? We need to sit down and think about how to surface this virtual device to users soon.
I vote for updating `GetCurrentDevice`, as there might be other scenarios where the caller will also need to distinguish `SPMD:0` from `XLA:0`.
`GetCurrentDevice` is being used in over 30 places in our code base now, mostly during tracing where the caller is trying to figure out the hardware type. I think it should be fine as long as `SPMD:0` can be resolved into the correct hardware type. I would leave that to a separate pr since it touches too much code and might introduce noise.
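What "resolving `SPMD:0` into the correct hardware type" could look like can be sketched as follows. This is an illustrative stand-in with hypothetical names, not torch_xla code: the virtual SPMD device simply defers to the real backend device for hardware-type queries.

```python
# Illustrative sketch (hypothetical names, not the torch_xla API).
def parse_device(device: str):
    """Split a device string like 'TPU:0' into (hw_type, ordinal)."""
    hw, _, ordinal = device.partition(":")
    return hw, int(ordinal)

def resolve_hw_type(device: str, backend_hw: str) -> str:
    """Map the virtual SPMD device onto the underlying hardware type."""
    hw, _ = parse_device(device)
    return backend_hw if hw == "SPMD" else hw

print(resolve_hw_type("SPMD:0", backend_hw="TPU"))  # TPU
print(resolve_hw_type("CPU:0", backend_hw="TPU"))   # CPU
```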
Got it.
I think this one is ready for review. I will add more test cases (input data sharding, which I am not sure works yet) and features in the next pr.
Generally, LGTM.
@@ -590,9 +593,21 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
      torch::lazy::BackendDataPtr handle =
          WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
              device.toString(), std::move(shape)));
      // if SPMD is enabled, we assume all output will be replicated
      if (ShardingUtil::UseVirtualDevice()) {
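The "all outputs are replicated under SPMD" assumption in the hunk above can be sketched in Python. This is a minimal illustration with hypothetical class and function names, not the actual placeholder machinery: when the virtual device is in use, each output placeholder is wrapped into a sharded handle whose sharding spec is REPLICATED, with one shard per addressable device.

```python
# Minimal sketch of the replicated-output assumption (hypothetical names).
from dataclasses import dataclass

@dataclass
class DataHandle:
    device: str
    shape: tuple

@dataclass
class ShardedDataHandle:
    shards: list
    sharding: str  # e.g. "REPLICATED"

def create_output_placeholder(shape, device, use_virtual_device, num_devices=1):
    if use_virtual_device:
        # Under SPMD, assume every output is replicated across all devices.
        shards = [DataHandle(f"TPU:{i}", shape) for i in range(num_devices)]
        return ShardedDataHandle(shards, sharding="REPLICATED")
    return DataHandle(device, shape)

out = create_output_placeholder((2, 2), "SPMD:0", True, num_devices=4)
print(type(out).__name__, out.sharding, len(out.shards))
```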
Why do we now start adding this for the dynamo path? We don't need this for the LTC path?
Looks like this patch is dynamo-exclusive... Should we note this somewhere?
the lazy code path already has this logic; in fact I copied this logic from the lazy code path lol
I smell an opportunity to merge the two code paths further. But let's do it in a follow-up.
@@ -608,6 +623,9 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
    if (auto xla_tensor_ptr = bridge::TryGetXlaTensor(ivalue.toTensor())) {
      dataptr = xla_tensor_ptr->GetXlaData();
    } else {
      XLA_CHECK(device.type() != (int8_t)XlaDeviceType::SPMD)
What's this XLA_CHECK for?
It's not needed, but more for a sanity check I probably added to ensure that this doesn't happen. Basically, we want to make sure that the SPMD device type is always on the backend (device data).
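The intent of that sanity check can be sketched in Python. This is illustrative only, with hypothetical names: an input that is not already backed by an XLA tensor must never carry the virtual SPMD device type, because SPMD device data should only ever exist on the backend.

```python
# Sketch of the sanity check's intent (hypothetical names, not torch_xla).
SPMD = "SPMD"

def make_cpu_fallback_data(device_type: str, value):
    """Upload a plain tensor to the backend; SPMD must never reach here."""
    if device_type == SPMD:
        raise ValueError(
            "unexpected device type in upload tensor: SPMD device data "
            "should already exist on the backend")
    return {"device": device_type, "value": value}

d = make_cpu_fallback_data("TPU", 42)
print(d["device"])  # TPU
```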
LGTM, thanks Jack
# Add an additional 1x1 layer at the end to ensure the final layer
# is not sharded.
self.fc3 = nn.Linear(1, 1)
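The test setup above can be sketched as a per-layer sharding plan. This is a pure-Python stand-in with hypothetical names, not the actual test code: every layer gets a mesh-axis sharding spec except the trailing 1x1 layer, which stays unsharded so the model output is replicated.

```python
# Illustrative sharding plan: last layer left unsharded (hypothetical names).
def build_layer_shardings(layer_names):
    """Shard each layer on the 'x' mesh axis, except the final layer,
    which stays unsharded so the output is replicated."""
    specs = {name: ("x", None) for name in layer_names[:-1]}
    specs[layer_names[-1]] = None  # no sharding: output replicated
    return specs

plan = build_layer_shardings(["fc1", "fc2", "fc3"])
print(plan["fc3"])  # None
```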
Is this due to the lack of output sharding propagation?
yea, in this pr I tried to keep the output replicated. We can expand this after the output sharding pr is ready.
Input sharding should (used to) work if the sharded input is used for the torch compilation. Let me know. Will take a pass on the changes now as well, thanks.
@@ -590,6 +593,15 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
      torch::lazy::BackendDataPtr handle =
          WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
If it's SPMD virtual device, then we should always use PjRtShardedData handle.
hmm, is the logic below that calls `WrapDataShards` not enough? This code path is shared between the spmd and non-spmd code paths.
This work was done by @yeounoh and I am trying to land this pr on his behalf. The last attempt was made by @steventk-g in #4862.
Currently the test fails with a `Check failed: handle->HasValue()`, so this is still WIP.