diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py
index 204b95b3d705..22cd29804137 100644
--- a/test/spmd/test_dynamo_spmd.py
+++ b/test/spmd/test_dynamo_spmd.py
@@ -85,6 +85,92 @@ def test_dynamo_spmd_output_sharding_cache(self):
     dynamo_res = dynamo_linear(xla_y)
     self.assertEqual(met.counter_value('UncachedOutputSharding'), 1)
 
+  def test_dynamo_sharded_input(self):
+    device = xm.xla_device()
+    linear = SimpleLinear().to(device)
+    linear.eval()
+    xla_x = torch.randn(8, 128, device=device)
+    xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)), (1, 0))
+    xla_res = linear(xla_x)
+    xm.mark_step()
+
+    dynamo_linear = torch.compile(linear, backend="openxla")
+    dynamo_res = dynamo_linear(xla_x)
+    self.assertTrue(torch.allclose(xla_res.cpu(), dynamo_res.cpu()))
+
+  def test_dynamo_input_sharding_changed(self):
+    device = xm.xla_device()
+    linear = SimpleLinear().to(device)
+    linear.eval()
+    xla_x = torch.randn(8, 128, device=device)
+    xla_y = torch.randn(8, 128, device=device)
+    xm.mark_step()
+
+    met.clear_all()
+    dynamo_linear = torch.compile(linear, backend="openxla")
+    dynamo_res = dynamo_linear(xla_x)
+    self.assertIn('CompileTime', met.metric_names())
+    self.assertEqual(met.metric_data('CompileTime')[0], 1)
+
+    # Shard the original input.
+    xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)), (1, 0))
+    dynamo_res_sharded = dynamo_linear(xla_x)
+    self.assertTrue(torch.allclose(dynamo_res.cpu(), dynamo_res_sharded.cpu()))
+    # One additional graph is generated by the .cpu() call above.
+    if self.n_devices > 1:
+      self.assertEqual(met.metric_data('CompileTime')[0], 3)
+    else:
+      # If there is only one device (CPU), the sharding spec is replicated,
+      # hence no change and no recompilation.
+      self.assertEqual(met.metric_data('CompileTime')[0], 1)
+
+    # Call the dynamo function with a different input that has a different sharding.
+    xs.mark_sharding(xla_y, self._get_mesh((1, self.n_devices)), (0, 1))
+    dynamo_res_sharded_2 = dynamo_linear(xla_y)
+    if self.n_devices > 1:
+      self.assertEqual(met.metric_data('CompileTime')[0], 4)
+    else:
+      # If there is only one device (CPU), the sharding spec is replicated,
+      # hence no change and no recompilation.
+      self.assertEqual(met.metric_data('CompileTime')[0], 1)
+    self.assertTrue(torch.allclose(linear(xla_y).cpu(), dynamo_res_sharded_2.cpu()))
+
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed to test the mesh change")
+  def test_dynamo_input_sharding_threashold(self):
+    device = xm.xla_device()
+    linear = SimpleLinear().to(device)
+    linear.eval()
+    xla_x = torch.randn(8, 128, device=device)
+    xm.mark_step()
+
+    dynamo_linear = torch.compile(linear, backend="openxla")
+    if 'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD' in os.environ:
+      saved_var = os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD']
+    else:
+      saved_var = None
+    os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] = '2'
+
+    dynamo_res = dynamo_linear(xla_x)
+    xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)), (1, 0))
+    dynamo_res = dynamo_linear(xla_x)
+    xs.clear_sharding(xla_x)
+    xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)), (0, 1))
+    # The crash will happen in an async execution thread; we need to grab the
+    # device lock again to surface that exception.
+    dynamo_res = dynamo_linear(xla_x)
+    try:
+      print(dynamo_res)
+    except:
+      print('catch')
+    # It is hard to catch the C++ runtime error in Python. Instead, check whether
+    # dynamo_res is still a placeholder after printing; if so, the C++ side crashed.
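+    # (_is_placecholder returns True when the tensor still holds a data handle
+    # that was never populated with a device buffer, i.e. the async execution
+    # never produced a value.)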
+    self.assertTrue(torch_xla._XLAC._is_placecholder(dynamo_res))
+    if saved_var is not None:
+      os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] = saved_var
+    else:
+      del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD']
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index 2f4b8bdc866f..ff6679cc303b 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -14,6 +14,8 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as metrics
+import torch_xla.runtime as xr
+import torch_xla.utils.utils as xu
 
 debug = os.environ.get("TORCH_XLA_DEBUG") == "1"
 
@@ -202,7 +204,7 @@ def is_xla_tensor(tensor: torch.Tensor) -> bool:
   return tensor.device.type == "xla"
 
 
-def extract_internal(xla_model: torch.fx.GraphModule):
+def extract_graph_helper(xla_model: torch.fx.GraphModule):
   xla_args = xla_model.xla_args
   assert all(
       map(
@@ -238,6 +240,11 @@ def extract_internal(xla_model: torch.fx.GraphModule):
       tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
   }
 
+  if xr.is_spmd():
+    xla_args_sharding_spec = torch_xla._XLAC._get_xla_sharding_specs(xla_args)
+  else:
+    xla_args_sharding_spec = ()
+
   xla_out = xla_model(*xla_args)
   if not isinstance(xla_out, (tuple, list)):
     xla_out = (xla_out,)
@@ -308,12 +315,55 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   # should be removed to avoid extra computation executed and in place updates op
   # mistakenlly update the input tensors.
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
+  return (xla_args_sharding_spec, args_and_out, graph_hash,
+          arg_index_to_need_update_index, none_remover, graph_input_matcher,
+          dumb_return_handler, xla_args_need_update)
+
+
+def extract_internal(xla_model: torch.fx.GraphModule):
+  (xla_args_sharding_spec, args_and_out, graph_hash,
+   arg_index_to_need_update_index, none_remover, graph_input_matcher,
+   dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
+  skip_checking_input_sharding_threashold = xu.getenv_as(
+      'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
 
   def optimized_mod(*args):
+    nonlocal xla_model
+    nonlocal xla_args_sharding_spec
+    nonlocal args_and_out
+    nonlocal graph_hash
+    nonlocal arg_index_to_need_update_index
+    nonlocal none_remover
+    nonlocal graph_input_matcher
+    nonlocal dumb_return_handler
+    nonlocal xla_args_need_update
+    nonlocal skip_checking_input_sharding_threashold
+
     # mark_step needs to be blocking since we want to access args's XLADatas
     # and they can't be placeholder.
     if any(torch_xla._XLAC._check_tensor_need_materialization(args)):
       xm.mark_step(wait=True)
+
+    # If the input sharding has changed from the previous program, dynamo currently
+    # cannot detect this; it will mistakenly believe the program is the same. We
+    # need to retrace it here.
+    if xr.is_spmd():
+      # If the input sharding stayed the same for skip_checking_input_sharding_threashold
+      # consecutive calls, skip checking the input sharding since it can be expensive.
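+      # Note: once this counter reaches zero the comparison is skipped for the
+      # rest of this compiled function's lifetime, so later sharding changes no
+      # longer trigger a retrace; a retrace resets the counter from the env var.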
+      if skip_checking_input_sharding_threashold > 0:
+        if torch_xla._XLAC._get_xla_sharding_specs(
+            args) != xla_args_sharding_spec:
+          # Update xla_args with the newly sharded inputs and retrace.
+          xla_model.xla_args = args
+          (xla_args_sharding_spec, args_and_out_copy, graph_hash,
+           arg_index_to_need_update_index, none_remover, graph_input_matcher,
+           dumb_return_handler,
+           xla_args_need_update) = extract_graph_helper(xla_model)
+          skip_checking_input_sharding_threashold = xu.getenv_as(
+              'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
+        else:
+          skip_checking_input_sharding_threashold -= 1
+
     enter_ts = time.time()
     if len(args_and_out) == 0:
       return ()
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index a59f4c2d3fbd..17a2bba7f4cb 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1471,6 +1471,27 @@ void InitXlaModuleBindings(py::module m) {
         XLATensorPtr xtensor = bridge::GetXlaTensor(input);
         return GetXLAShardingSpec(xtensor);
       });
+  m.def("_get_xla_sharding_specs",
+        [](const std::vector<at::Tensor>& tensors) -> std::vector<std::string> {
+          tsl::profiler::TraceMe activity("_get_xla_sharding_specs",
+                                          tsl::profiler::TraceMeLevel::kInfo);
+          TORCH_LAZY_TIMED("_get_xla_sharding_specs");
+          std::vector<std::string> sharding_specs;
+          sharding_specs.reserve(tensors.size());
+          for (const at::Tensor& tensor : tensors) {
+            XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
+            XLATensor::ShardingSpecPtr sharding_spec =
+                xtensor ? xtensor->sharding_spec() : nullptr;
+            if (sharding_spec != nullptr) {
+              sharding_specs.push_back(
+                  xla::HloSharding::FromProto(sharding_spec->sharding)
+                      ->ToString());
+            } else {
+              sharding_specs.push_back("");
+            }
+          }
+          return sharding_specs;
+        });
   m.def("_get_xla_sharding_type",
         [](const at::Tensor& input) -> std::optional<int> {
           XLATensorPtr xtensor = bridge::GetXlaTensor(input);
@@ -1616,6 +1637,11 @@
         xla::HloModule::CreateFromProto(module_proto, config).value());
     return module->ToString();
   });
+  m.def("_is_placecholder", [](at::Tensor& input) {
+    XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+    return xtensor->CurrentDataHandle() &&
+           !xtensor->CurrentDataHandle()->HasValue();
+  });
   m.def("_init_xla_lazy_backend", []() {
     MapXlaEnvVarsToLazy();
     InitXlaBackend();
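
For reference, a minimal usage sketch (not part of the patch) of how the new input-sharding re-check behaves from user code, mirroring the tests above. Assumptions: a multi-device SPMD setup, `xs` being the sharding module the test file imports, `xr.use_spmd()` being the way SPMD mode is enabled on this release, and a plain `nn.Linear` standing in for the test's `SimpleLinear`.

import os
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs  # assumed module path for `xs`

# How many consecutive unchanged-sharding calls the bridge tolerates before it
# stops comparing sharding specs; 5 is the default wired into dynamo_bridge.py.
os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] = '5'

xr.use_spmd()  # make xr.is_spmd() true so optimized_mod takes the re-check branch

device = xm.xla_device()
model = nn.Linear(128, 64).to(device).eval()  # stand-in for the test's SimpleLinear
compiled = torch.compile(model, backend="openxla")

x = torch.randn(8, 128, device=device)
out = compiled(x)  # first call: trace, compile, and record the input sharding specs

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (1, num_devices))
xs.mark_sharding(x, mesh, (1, 0))  # change the input sharding in place

out = compiled(x)  # the bridge detects the new sharding spec and retraces
print(met.metric_data('CompileTime')[0])  # compile count grows on multi-device setups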