test/spmd/test_dynamo_spmd.py (86 additions, 0 deletions)
@@ -85,6 +85,92 @@ def test_dynamo_spmd_output_sharding_cache(self):
dynamo_res = dynamo_linear(xla_y)
self.assertEqual(met.counter_value('UncachedOutputSharding'), 1)

def test_dynamo_sharded_input(self):
device = xm.xla_device()
linear = SimpleLinear().to(device)
linear.eval()
xla_x = torch.randn(8, 128, device=device)
xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)), (1, 0))
xla_res = linear(xla_x)
xm.mark_step()

dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())

def test_dynamo_input_sharding_changed(self):
device = xm.xla_device()
linear = SimpleLinear().to(device)
linear.eval()
xla_x = torch.randn(8, 128, device=device)
xla_y = torch.randn(8, 128, device=device)
xm.mark_step()

met.clear_all()
dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
self.assertIn('CompileTime', met.metric_names())
self.assertEqual(met.metric_data('CompileTime')[0], 1)

# Shard the original input
xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)), (1, 0))
dynamo_res_sharded = dynamo_linear(xla_x)
torch.allclose(dynamo_res.cpu(), dynamo_res_sharded.cpu())
# one graph is generated by the .cpu() call above
Collaborator: Should .cpu() trigger another compilation? Maybe a separate topic, but I'm curious. I assume it shouldn't?

Collaborator (Author): I am guessing the graph is something similar to

%1 = device_data()

I didn't check the actual HLO.
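A quick way to check (relying only on the metrics API already used in this test; whether the extra graph is just a trivial device_data() transfer is a guess worth verifying):

compile_count = met.metric_data('CompileTime')[0]
_ = dynamo_res_sharded.cpu()  # copy the result back to the host
# if CompileTime increased, the .cpu() call did compile (and run) a small graph,
# which would account for the extra graph mentioned in the code comment above
print(met.metric_data('CompileTime')[0] - compile_count)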

if self.n_devices > 1:
self.assertEqual(met.metric_data('CompileTime')[0], 3)
else:
# if there is only one device (CPU), the sharding spec will be replicated,
# hence no change.
self.assertEqual(met.metric_data('CompileTime')[0], 1)

# Call the dynamo function with a different input with different sharding
xs.mark_sharding(xla_y, self._get_mesh((1, self.n_devices)), (0, 1))
dynamo_res_sharded_2 = dynamo_linear(xla_y)
if self.n_devices > 1:
self.assertEqual(met.metric_data('CompileTime')[0], 4)
else:
# if there is only one device (CPU), the sharding spec will be replicated,
# hence no change.
self.assertEqual(met.metric_data('CompileTime')[0], 1)
torch.allclose(linear(xla_y).cpu(), dynamo_res_sharded_2.cpu())

@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed to test the mesh change")
def test_dynamo_input_sharding_threashold(self):
device = xm.xla_device()
linear = SimpleLinear().to(device)
linear.eval()
xla_x = torch.randn(8, 128, device=device)
xm.mark_step()

dynamo_linear = torch.compile(linear, backend="openxla")
if 'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD' in os.environ:
saved_var = os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD']
else:
saved_var = None
os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] = '2'

dynamo_res = dynamo_linear(xla_x)
xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)), (1, 0))
dynamo_res = dynamo_linear(xla_x)
xs.clear_sharding(xla_x)
xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)), (0, 1))
# the crash will happen in an async execution thread; we need to grab the lock again to
# surface that exception
dynamo_res = dynamo_linear(xla_x)
try:
print(dynamo_res)
except:
print('catch')
# it is hard to catch the C++ runtime error in Python; instead we can check whether
# dynamo_res is still a placeholder after printing, which means the C++ side crashed.
self.assertTrue(torch_xla._XLAC._is_placecholder(dynamo_res))
if saved_var is not None:
os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] = saved_var
else:
del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD']


if __name__ == '__main__':
test = unittest.main()
torch_xla/core/dynamo_bridge.py (51 additions, 1 deletion)
@@ -14,6 +14,8 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as metrics
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu

debug = os.environ.get("TORCH_XLA_DEBUG") == "1"

@@ -202,7 +204,7 @@ def is_xla_tensor(tensor: torch.Tensor) -> bool:
return tensor.device.type == "xla"


def extract_internal(xla_model: torch.fx.GraphModule):
def extract_graph_helper(xla_model: torch.fx.GraphModule):
xla_args = xla_model.xla_args
assert all(
map(
@@ -238,6 +240,11 @@ def extract_internal(xla_model: torch.fx.GraphModule):
tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
}

if xr.is_spmd():
xla_args_sharding_spec = torch_xla._XLAC._get_xla_sharding_specs(xla_args)
else:
xla_args_sharding_spec = ()

xla_out = xla_model(*xla_args)
if not isinstance(xla_out, (tuple, list)):
xla_out = (xla_out,)
@@ -308,12 +315,55 @@ def extract_internal(xla_model: torch.fx.GraphModule):
# should be removed to avoid extra computation executed and in place updates op
# mistakenly update the input tensors.
torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
return (xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler, xla_args_need_update)


def extract_internal(xla_model: torch.fx.GraphModule):
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
skip_checking_input_sharding_threashold = xu.getenv_as(
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
Collaborator: Why would we want this flag, and how do we use it properly?

Collaborator (Author): Getting the sharding spec for every input tensor is not free; it has some speed impact on inference (roughly 5-10% on a 7B model). The idea behind this flag is that if the input sharding stays the same for XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD consecutive calls, we assume everything is OK and stop checking the input shardings. It is more or less a hack.
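For example (the values below are hypothetical; the default added in this PR is 5), a user could widen the check window before compiling, or set it to 0 to skip the check entirely, following the logic in optimized_mod below:

import os
# verify the input shardings for the first 20 compiled calls instead of the default 5
os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] = '20'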

Collaborator (Author): The core idea is that the user should not change the input sharding. We are mainly solving the problem of the compiler overwriting the sharding. The compiler usually overwrites the sharding on the first run, and after that the sharding stays the same.
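A rough sketch of how one could observe this from Python, reusing xla_x and dynamo_linear from the tests above; it assumes the existing per-tensor _get_xla_sharding_spec binding returns the current spec as a string, like the batched _get_xla_sharding_specs added in this PR:

spec_before = torch_xla._XLAC._get_xla_sharding_spec(xla_x)
dynamo_res = dynamo_linear(xla_x)
xm.mark_step()
spec_after = torch_xla._XLAC._get_xla_sharding_spec(xla_x)
# if the compiler propagated a sharding onto the input during the first run,
# the two strings will differ; later runs should leave it unchanged
print(spec_before == spec_after, spec_after)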

Collaborator: Then I guess this flag should be off by default? We should ensure program correctness first rather than providing risky hacks for users to tune performance.

Collaborator (replying to the explanation above): Okay, it makes sense then.


def optimized_mod(*args):
nonlocal xla_model
nonlocal xla_args_sharding_spec
nonlocal args_and_out
nonlocal graph_hash
nonlocal arg_index_to_need_update_index
nonlocal none_remover
nonlocal graph_input_matcher
nonlocal dumb_return_handler
nonlocal xla_args_need_update
nonlocal skip_checking_input_sharding_threashold

# mark_step needs to be blocking since we want to access the args' XLAData
# and they can't be placeholders.
if any(torch_xla._XLAC._check_tensor_need_materialization(args)):
xm.mark_step(wait=True)

# If the input sharding has changed from the previous program, dynamo currently
# cannot detect this and will mistakenly believe the program is the same. We need
# to retrace it here.
if xr.is_spmd():
# if the input sharding was the same for skip_checking_input_sharding_threashold times,
# we will skip checking the input sharding since it can be expensive.
if skip_checking_input_sharding_threashold > 0:
if torch_xla._XLAC._get_xla_sharding_specs(
args) != xla_args_sharding_spec:
# update the xla_args with the input with new sharding and retrace
xla_model.xla_args = args
(xla_args_sharding_spec, args_and_ou_copy, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
xla_args_need_update) = extract_graph_helper(xla_model)
skip_checking_input_sharding_threashold = xu.getenv_as(
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
else:
skip_checking_input_sharding_threashold -= 1

enter_ts = time.time()
if len(args_and_out) == 0:
return ()
torch_xla/csrc/init_python_bindings.cpp (26 additions, 0 deletions)
@@ -1471,6 +1471,27 @@ void InitXlaModuleBindings(py::module m) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
return GetXLAShardingSpec(xtensor);
});
m.def("_get_xla_sharding_specs",
[](const std::vector<at::Tensor>& tensors) -> std::vector<std::string> {
tsl::profiler::TraceMe activity("_get_xla_sharding_specs",
tsl::profiler::TraceMeLevel::kInfo);
TORCH_LAZY_TIMED("_get_xla_sharding_specs");
std::vector<std::string> sharding_specs;
sharding_specs.reserve(tensors.size());
for (const at::Tensor& tensor : tensors) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
XLATensor::ShardingSpecPtr sharding_spec =
xtensor ? xtensor->sharding_spec() : nullptr;
if (sharding_spec != nullptr) {
sharding_specs.push_back(
xla::HloSharding::FromProto(sharding_spec->sharding)
->ToString());
} else {
sharding_specs.push_back("");
}
}
return sharding_specs;
});
m.def("_get_xla_sharding_type",
[](const at::Tensor& input) -> std::optional<int> {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
@@ -1616,6 +1637,11 @@ void InitXlaModuleBindings(py::module m) {
xla::HloModule::CreateFromProto(module_proto, config).value());
return module->ToString();
});
m.def("_is_placecholder", [](at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
return xtensor->CurrentDataHandle() &&
!xtensor->CurrentDataHandle()->HasValue();
});
m.def("_init_xla_lazy_backend", []() {
MapXlaEnvVarsToLazy();
InitXlaBackend();