Enable SPMD + dynamo for inference #5002

Merged (5 commits, May 18, 2023)
Conversation

JackCaoG (Collaborator)

This work was done by @yeounoh and I am trying to land this PR on his behalf. The last attempt was made by @steventk-g in #4862.

Currently the test fails with a Check failed: handle->HasValue() error, so this is still WIP.

JackCaoG (Collaborator, Author)

OK, there are two issues:

  1. The dynamo async function currently fails silently if an exception happens; we need to add rethrow logic.
  2. Dynamo's PjRtComputationClient::PjRtData::Assign(const Data& data) fails with a bad_cast error:
void PjRtComputationClient::PjRtData::Assign(const Data& data) {
  TF_VLOG(3) << "enter assign\n";
  // Reference-form dynamic_cast: throws std::bad_cast if `data` is not
  // actually a PjRtData (e.g. if it is a PjRtShardedData).
  const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(data);
  if (&pjrt_data != this) {
    buffer = pjrt_data.buffer;
  }
  TF_VLOG(3) << "left assign\n";
}
2023-05-13 00:26:38.471263: I third_party/xla_client/pjrt_computation_client.cc:135] enter assign

E
RuntimeError: std::bad_cast

JackCaoG (Collaborator, Author)

Ah OK, I think I know what the problem is: the result of the dynamo graph is a PjRtShardedData, and we tried to cast it to PjRtData. This might have to do with @jonb377's recent PR that makes most things implicitly replicated. This should be an easy fix; I can work on it next week.

FYI @yeounoh
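
For context, here is a minimal standalone C++ sketch of the failure mode, using hypothetical stand-in classes (FakeData, FakePjRtData, FakePjRtShardedData) rather than the real PjRt types: the reference form of dynamic_cast throws std::bad_cast when the dynamic type does not match, while the pointer form returns nullptr and lets the caller branch on the actual type.

#include <iostream>
#include <typeinfo>

// Hypothetical stand-ins for ComputationClient::Data, PjRtData and
// PjRtShardedData; not the real torch_xla classes.
struct FakeData { virtual ~FakeData() = default; };
struct FakePjRtData : FakeData {};
struct FakePjRtShardedData : FakeData {};

void Assign(const FakeData& data) {
  // Reference form: throws std::bad_cast if `data` is not a FakePjRtData.
  const FakePjRtData& pjrt_data = dynamic_cast<const FakePjRtData&>(data);
  (void)pjrt_data;
  std::cout << "assigned from a FakePjRtData\n";
}

int main() {
  FakePjRtData plain;
  FakePjRtShardedData sharded;
  Assign(plain);  // fine
  try {
    Assign(sharded);  // a sharded result passed where PjRtData is expected
  } catch (const std::bad_cast& e) {
    std::cout << "caught: " << e.what() << "\n";
  }
  // The pointer form returns nullptr instead of throwing, which is one way
  // to detect the actual dynamic type before assigning.
  const FakeData& d = sharded;
  if (dynamic_cast<const FakePjRtData*>(&d) == nullptr) {
    std::cout << "not a FakePjRtData; handle the sharded case separately\n";
  }
  return 0;
}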

@yeounoh yeounoh self-requested a review May 15, 2023 17:31
yeounoh (Collaborator) commented May 15, 2023

> Ah OK, I think I know what the problem is: the result of the dynamo graph is a PjRtShardedData, and we tried to cast it to PjRtData. This might have to do with @jonb377's recent PR that makes most things implicitly replicated. This should be an easy fix; I can work on it next week.
>
> FYI @yeounoh

Thanks @JackCaoG, I am going to merge an output param sharding patch, which might change the code path a bit. Let's chat offline; I can explain further.

// Device will be Virtual device if SPMD is enabled.
torch::lazy::BackendDevice device =
    ShardingUtil::UseVirtualDevice() ? ParseDeviceString("SPMD:0")
                                     : torch_xla::GetCurrentDevice();
JackCaoG (Collaborator, Author)

@yeounoh I am not sure if we should just update GetCurrentDevice, any thoughts? We need to sit down and think about how to surface this virtual device to users soon.

Collaborator

I voted for GetCurrentDevice, as there might be other scenarios where the caller will also need to distinguish SPMD:0 from XLA:0.

JackCaoG (Collaborator, Author)

GetCurrentDevice is used in over 30 places in our code base now, mostly during tracing when the caller is trying to figure out the hardware type. I think it should be fine as long as SPMD:0 can be resolved into the correct hardware type. I would leave that to a separate PR since it touches too much code and might introduce noise.

Collaborator

Got it.
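
To illustrate the "resolve SPMD:0 into the correct hardware type" idea from the thread above, here is a minimal standalone sketch; HwType, ResolveHardwareType and FAKE_PJRT_DEVICE are hypothetical stand-ins, not the actual torch_xla GetCurrentDevice implementation.

#include <cstdlib>
#include <iostream>
#include <string>

// Hypothetical stand-in for a hardware device kind.
enum class HwType { CPU, TPU, GPU };

// Under SPMD the current device is the virtual "SPMD:0", but callers that
// only care about the hardware type still need the underlying device kind.
HwType ResolveHardwareType(const std::string& device_str) {
  if (device_str.rfind("SPMD", 0) == 0) {
    // Resolve the virtual device to whatever the backend actually runs on;
    // here faked via an environment variable purely for illustration.
    const char* hw = std::getenv("FAKE_PJRT_DEVICE");
    return (hw && std::string(hw) == "TPU") ? HwType::TPU : HwType::CPU;
  }
  if (device_str.rfind("TPU", 0) == 0) return HwType::TPU;
  if (device_str.rfind("GPU", 0) == 0) return HwType::GPU;
  return HwType::CPU;
}

int main() {
  std::cout << static_cast<int>(ResolveHardwareType("SPMD:0")) << "\n";
  std::cout << static_cast<int>(ResolveHardwareType("TPU:0")) << "\n";
  return 0;
}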

@JackCaoG JackCaoG marked this pull request as ready for review May 16, 2023 18:20
@JackCaoG JackCaoG changed the title [WIP] Enable SPMD + dynamo for inference Enable SPMD + dynamo for inference May 16, 2023
JackCaoG (Collaborator, Author)

I think this one is ready for review. I will add more test cases (input data sharding, which I am not sure works yet) and features in the next PR.

alanwaketan (Collaborator) left a comment

Generally, LGTM.

@@ -590,9 +593,21 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
    torch::lazy::BackendDataPtr handle =
        WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
            device.toString(), std::move(shape)));
    // if SPMD is enabled, we assume all output will be replicated
    if (ShardingUtil::UseVirtualDevice()) {
Collaborator

Why do we now start adding this for the dynamo path? Don't we need this for the LTC path?

Collaborator

Looks like this patch is dynamo exclusive... Should we hint this somewhere?

JackCaoG (Collaborator, Author)

The lazy code path already has this logic; in fact, I copied this logic from the lazy code path lol

alanwaketan (Collaborator) commented May 16, 2023

I smell an opportunity to merge the two code paths further, but let's do it in a follow-up.
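
To illustrate the replicate-by-default assumption discussed above, here is a minimal standalone sketch with hypothetical stand-in types (Placeholder, CreateOutputPlaceholders); it is not the real placeholder/WrapDataShards code, just the shape of the branch: under the virtual device each output placeholder is treated as replicated sharded data on SPMD:0, otherwise a plain per-device placeholder is used.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for an output data placeholder; not a torch_xla type.
struct Placeholder {
  std::string device;
  bool replicated_sharded = false;
};
using PlaceholderPtr = std::shared_ptr<Placeholder>;

std::vector<PlaceholderPtr> CreateOutputPlaceholders(size_t num_outputs,
                                                     bool use_virtual_device) {
  std::vector<PlaceholderPtr> out;
  out.reserve(num_outputs);
  for (size_t i = 0; i < num_outputs; ++i) {
    auto p = std::make_shared<Placeholder>();
    if (use_virtual_device) {
      // SPMD path: the PR assumes every output is replicated, so the
      // placeholder is wrapped as (replicated) sharded data on SPMD:0.
      p->device = "SPMD:0";
      p->replicated_sharded = true;
    } else {
      // Non-SPMD path: a plain per-device data placeholder.
      p->device = "TPU:0";
    }
    out.push_back(p);
  }
  return out;
}

int main() {
  for (const auto& p :
       CreateOutputPlaceholders(2, /*use_virtual_device=*/true)) {
    std::cout << p->device << " replicated=" << p->replicated_sharded << "\n";
  }
  return 0;
}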

torch_xla/csrc/xla_graph_executor.cpp (outdated review thread; resolved)
@@ -608,6 +623,9 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
      if (auto xla_tensor_ptr = bridge::TryGetXlaTensor(ivalue.toTensor())) {
        dataptr = xla_tensor_ptr->GetXlaData();
      } else {
        XLA_CHECK(device.type() != (int8_t)XlaDeviceType::SPMD)
Collaborator

What's this XLA_CHECK for?

JackCaoG (Collaborator, Author)

Hmm, not sure, I copied this from @yeounoh's diff. @yeounoh any idea?

yeounoh (Collaborator)

It's not needed; it's more of a sanity check I probably added to ensure that this doesn't happen. Basically, we want to make sure that the SPMD device type is always on the backend (device data).
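
For illustration, a minimal standalone sketch of this kind of sanity check, using hypothetical stand-in names (XlaDeviceTypeLike, CheckNotSpmd) rather than the real XLA_CHECK macro: a fresh, non-XLA input should never carry the virtual SPMD device type.

#include <cstdint>
#include <iostream>
#include <stdexcept>

// Hypothetical stand-in for the device-type enum; not the real XlaDeviceType.
enum class XlaDeviceTypeLike : int8_t { CPU = 0, TPU = 1, SPMD = 2 };

// When falling through to creating backend data from a plain (non-XLA)
// tensor, the device about to be used must not be the virtual SPMD device,
// since SPMD should only appear on existing backend device data.
void CheckNotSpmd(XlaDeviceTypeLike device_type) {
  if (device_type == XlaDeviceTypeLike::SPMD) {
    throw std::runtime_error(
        "Unexpected virtual device (SPMD) for a fresh (non-XLA) input; "
        "SPMD is only valid on backend device data");
  }
}

int main() {
  CheckNotSpmd(XlaDeviceTypeLike::TPU);  // fine
  try {
    CheckNotSpmd(XlaDeviceTypeLike::SPMD);  // sanity check fires
  } catch (const std::runtime_error& e) {
    std::cout << "check failed: " << e.what() << "\n";
  }
  return 0;
}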

@JackCaoG JackCaoG requested a review from alanwaketan May 17, 2023 00:05
jonb377 (Collaborator) left a comment

LGTM, thanks Jack

Comment on lines +22 to +24
# Add an additional 1x1 layer at the end to ensure the final layer
# is not sharded.
self.fc3 = nn.Linear(1, 1)
Collaborator

Is this due to the lack of output sharding propagation?

JackCaoG (Collaborator, Author)

Yeah, in this PR I tried to keep the outputs replicated. We can expand this after the output sharding PR is ready.

yeounoh (Collaborator) commented May 17, 2023

> I think this one is ready for review. I will add more test cases (input data sharding, which I am not sure works yet) and features in the next PR.

Input sharding should (used to) work if the sharded input is used for the torch compilation. Let me know. Will take a pass on the changes now as well, thanks.

@@ -590,6 +593,15 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
    torch::lazy::BackendDataPtr handle =
        WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
Collaborator

If it's the SPMD virtual device, then we should always use a PjRtShardedData handle.

JackCaoG (Collaborator, Author)

Hmm, isn't the logic below that calls WrapDataShards enough? This code path is shared between the SPMD and non-SPMD cases.
