Enable SPMD + dynamo for inference #5002

Merged 5 commits on May 18, 2023
Changes from all commits
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -188,6 +188,7 @@ function run_xla_op_tests {
run_test "$CDIR/pjrt/test_mesh_service.py"
run_test "$CDIR/spmd/test_xla_sharding.py"
run_test "$CDIR/spmd/test_xla_virtual_device.py"
run_test "$CDIR/spmd/test_dynamo_spmd.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
58 changes: 58 additions & 0 deletions test/spmd/test_dynamo_spmd.py
@@ -0,0 +1,58 @@
import os
import sys

import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.debug.metrics as met
import unittest

import test_xla_sharding_base


class SimpleLinear(nn.Module):

def __init__(self):
super(SimpleLinear, self).__init__()
self.fc1 = nn.Linear(128, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 1)
# Add an additional 1x1 layer at the end to ensure the final layer
# is not sharded.
self.fc3 = nn.Linear(1, 1)
Comment on lines +22 to +24

Reviewer (Collaborator): Is this due to the lack of output sharding propagation?

Author (JackCaoG): Yes, in this PR I kept the output replicated. We can expand this after the output sharding PR is ready.


def forward(self, x):
y = self.relu(self.fc1(x))
z = self.fc2(y)
return self.fc3(z)


class DynamoSpmdInferenceTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
os.environ["XLA_USE_SPMD"] = "1"
super().setUpClass()

def test_dynamo_spmd_basic(self):
device = xm.xla_device()
linear = SimpleLinear().to(device)
linear.eval()
xla_x = torch.randn(1, 128, device=device)
xs.mark_sharding(linear.fc2.weight, self._get_mesh((1, self.n_devices)),
(1, 0))
xla_res = linear(xla_x)
xm.mark_step()

dynamo_linear = torch.compile(linear, backend="torchxla_trace_once")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())
# TODO(JackCaoG): add counter checks after ExecuteReplicated also creates
# a ExecuteMetric.


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
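
As the review thread above notes, this PR keeps the output of the compiled graph replicated, which is why only an intermediate layer is sharded in the test. For context, here is a minimal sketch of the user-facing flow the test exercises; the mesh construction via xs.Mesh and the device-count helper are assumptions about the experimental API (the test itself uses the _get_mesh() helper from test_xla_sharding_base instead).

# A minimal sketch of SPMD + dynamo inference, mirroring test_dynamo_spmd.py.
# Assumes the experimental xs.Mesh API and reuses SimpleLinear from the test.
import os
os.environ["XLA_USE_SPMD"] = "1"  # must be set before any XLA computation

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

device = xm.xla_device()
model = SimpleLinear().to(device).eval()
x = torch.randn(1, 128, device=device)

num_devices = len(xm.get_xla_supported_devices())
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))
# Shard only an intermediate weight; the compiled output stays replicated
# in this PR (output sharding propagation is left as a follow-up).
xs.mark_sharding(model.fc2.weight, mesh, (1, 0))

dynamo_model = torch.compile(model, backend="torchxla_trace_once")
out = dynamo_model(x)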
5 changes: 5 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.cpp
@@ -305,6 +305,11 @@ torch::lazy::BackendDevice AtenDeviceToXlaDevice(const c10::Device& device) {
}

c10::Device XlaDeviceToAtenDevice(const torch::lazy::BackendDevice& device) {
// TODO(yeounoh) until we expose SPMD virtual device to the frontend, this
// will just be `XLA:0`.
if (device.type() == (int8_t)XlaDeviceType::SPMD) {
return c10::Device(at::kXLA, (size_t)0);
}
return c10::Device(at::kXLA,
AtenXlaDeviceMapper::Get()->GetDeviceOrdinal(device));
}
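The TODO above means that, on the Python side, tensors living on the SPMD virtual device still report the ordinary xla:0 device; SPMD:0 is not yet surfaced to the frontend. A small sketch of that observable behavior, assuming XLA_USE_SPMD is honored as in the new test:

import os
os.environ["XLA_USE_SPMD"] = "1"

import torch
import torch_xla.core.xla_model as xm

t = torch.ones(2, 2, device=xm.xla_device())
# Even with the SPMD virtual device enabled, the frontend still reports
# xla:0; the SPMD:0 device type only exists on the C++ side for now.
print(t.device)  # xla:0
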
5 changes: 4 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
@@ -1728,7 +1728,10 @@ void InitXlaModuleBindings(py::module m) {
-> std::vector<at::Tensor> {
XLA_CHECK(hash_str.size() == sizeof(torch::lazy::hash_t));
torch::lazy::hash_t hash = *(torch::lazy::hash_t*)(hash_str.c_str());
torch::lazy::BackendDevice device = torch_xla::GetCurrentDevice();
// Device will be Virtual device if SPMD is enabled.
torch::lazy::BackendDevice device =
ShardingUtil::UseVirtualDevice() ? ParseDeviceString("SPMD:0")
: torch_xla::GetCurrentDevice();
Author (JackCaoG): @yeounoh I am not sure if we should just update GetCurrentDevice, any thoughts? We need to sit down and think about how to surface this virtual device to the user soon.

Reviewer (Collaborator): I vote for updating GetCurrentDevice, as there might be other scenarios where the caller also needs to distinguish SPMD:0 from XLA:0.

Author (JackCaoG): GetCurrentDevice is used in over 30 places in our code base now, mostly during tracing, where the caller is trying to figure out the hardware type. I think this is fine as long as SPMD:0 can be resolved to the correct hardware type. I would leave that to a separate PR, since it touches too much code and might introduce noise.

Reviewer (Contributor): Got it.

auto results = XLAGraphExecutor::Get()->ExecuteComputationWithBarrier(
hash, graph_inputs, device);
std::vector<at::Tensor> retlist;
86 changes: 73 additions & 13 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -574,6 +574,9 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
MaybeDumpGraph("dynamo", hash);
auto cachedComputation =
XLAGraphExecutor::Get()->GetComputationCache()->Get(hash);
TF_VLOG(5) << "Cached computation (hash: " << torch::lazy::HashToString(hash)
<< ") is_sharded=" << cachedComputation->is_sharded << std::endl;

// TODO implement a fallback mechanism, or make sure those entries
// never get kicked out
XLA_CHECK(cachedComputation)
Expand All @@ -590,6 +593,16 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
torch::lazy::BackendDataPtr handle =
WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
Reviewer (Contributor): If it's the SPMD virtual device, then we should always use a PjRtShardedData handle.

Author (JackCaoG): Hmm, isn't the logic below that calls WrapDataShards enough? This code path is shared between the SPMD and non-SPMD cases.
device.toString(), std::move(shape)));
// If SPMD is enabled, we assume all output will be sharded or replicated
// and wrapped inside PjRtShardedData handle.
if (ShardingUtil::UseVirtualDevice()) {
Reviewer (Collaborator): Why do we only start adding this for the dynamo path? Don't we need this for the LTC path?

Reviewer (Collaborator): Looks like this patch is dynamo-exclusive... should we hint at that somewhere?

Author (JackCaoG): The lazy code path already has this logic; in fact, I copied it from the lazy code path.

Reviewer (alanwaketan, May 16, 2023): I smell an opportunity to merge the two code paths further, but let's do that in a follow-up.

XLATensor::ShardingSpecPtr sharding =
std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Replicate().ToProto(), shape);
handle = WrapXlaData(xla::ComputationClient::Get()->WrapDataShards(
{UnwrapXlaData(handle)}, GetVirtualDevice().toString(),
sharding->shape.value(), sharding->sharding));
}
placeholders.push_back(handle);
}
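
The replicated HloSharding used for the placeholders above is the C++ counterpart of leaving every tensor dimension unassigned in a Python partition spec. A rough sketch under that assumption (partition-spec semantics taken from the experimental xla_sharding API, and XLA_USE_SPMD=1 as in the earlier sketches):

# Sketch: a partition spec with no mesh axis assigned marks the tensor as
# replicated, roughly matching the HloSharding::Replicate() annotation above.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

num_devices = len(xm.get_xla_supported_devices())
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))
t = torch.zeros(128, 128, device=xm.xla_device())
xs.mark_sharding(t, mesh, (None, None))  # replicated across the mesh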

@@ -608,6 +621,9 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
if (auto xla_tensor_ptr = bridge::TryGetXlaTensor(ivalue.toTensor())) {
dataptr = xla_tensor_ptr->GetXlaData();
} else {
XLA_CHECK(device.type() != (int8_t)XlaDeviceType::SPMD)
Reviewer (Collaborator): What is this XLA_CHECK for?

Author (JackCaoG): Hmm, not sure; I copied this from @yeounoh's diff. @yeounoh, any idea?

yeounoh (Contributor): It's not strictly needed; it's more of a sanity check I added to ensure this doesn't happen. Basically, we want to make sure that SPMD device data is always already on the backend (device data).

<< "SPMD device data should already be on the XLA backend "
"(XLATensor).";
dataptr = torch_xla::TensorToXlaData(ivalue.toTensor(), device);
}

@@ -619,21 +635,65 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
std::shared_ptr<XLAGraphExecutor::Async> async = std::make_shared<Async>(
&coll, std::move(arguments), placeholders, std::move(cachedComputation));

// TODO(yeounoh) supply proper sharding specs for sharded results.
std::vector<XLATensor::ShardingSpecPtr> sharding_specs(placeholders.size());

auto syncfn = [async, hash, sharding_specs]() {
try {
TF_VLOG(3) << "Executing Dynamo IR graph hash "
<< torch::lazy::HashToString(hash) << " on device "
<< async->device << " ...";

std::vector<torch::lazy::BackendDataPtr> results;
if (async->cached_computation->is_sharded) {
std::vector<std::string> devices =
xla::ComputationClient::Get()->GetLocalDevices();
std::vector<std::vector<xla::ComputationClient::DataPtr>>
device_arguments = ShardingUtil::InputHandler(
UnwrapXlaData(async->parameters_data), devices);
xla::ComputationClient::ExecuteReplicatedOptions execute_options;
// OutputHandler creates sharded data for sharded tensor results.
// Both sharded and unsharded results should be "Assign"ed to the
// corresponding data placeholders.
std::vector<xla::ComputationClient::DataPtr> outputs =
ShardingUtil::OutputHandler(
xla::ComputationClient::Get()->ExecuteReplicated(
*async->cached_computation->computation
->client_computation(),
device_arguments, devices, execute_options),
sharding_specs);
results = WrapXlaData(outputs);
TF_VLOG(3) << "Executing Dynamo IR sharded graph hash "
<< torch::lazy::HashToString(hash) << " on devices "
<< absl::StrJoin(devices, ",") << " done!";
} else {
results = torch::lazy::getBackend()->ExecuteComputation(
async->cached_computation->computation, async->parameters_data,
async->device);
TF_VLOG(3) << "Executing Dynamo IR graph hash "
<< torch::lazy::HashToString(hash) << " on device "
<< async->device << " done!";
}

// Updating placeholder with actual output handle.
for (size_t i = 0; i < results.size(); ++i) {
XLA_CHECK(async->tensors_data[i] != nullptr);
async->tensors_data[i]->Assign(*results[i]);
}
} catch (...) {
// There are two paths of discovery of an exception happening on an
// asynchronous task. One happens if the creator of the asynchronous task
// explicitly waits for completion, in which case the exception will be
// thrown from the Wait() API. Re-throwing the exception below makes sure
// this will be captured by the completer function created below, and
// surfaced by the Wait() API. But we also need to surface the exception
// even in case the caller does not wait, and that is accomplished by
// setting the unlockers status. In that case the exception will be
// surfaced when the user tries to acquire the device locks the next time.
for (auto& unlocker : async->unlocker) {
unlocker.SetStatus(std::current_exception());
}
throw;
}
};

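The sharded branch above goes through ExecuteReplicated, which is also why the new test leaves a TODO about counter checks: ExecuteReplicated does not yet record an execution metric. A rough sketch of what such a check could look like once it does; the metric name below is a placeholder, not an existing counter:

# Hypothetical follow-up for the TODO in test_dynamo_spmd.py, once
# ExecuteReplicated records an execution metric.
import torch_xla.debug.metrics as met

met.clear_all()                    # reset counters and metrics before the run
dynamo_res = dynamo_linear(xla_x)  # dynamo_linear / xla_x as in the test above
print(met.metric_names())          # inspect what the runtime actually recorded
# self.assertIn('ExecuteReplicated', met.metric_names())  # name is hypothetical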