Support input sharding changed after first dynamo tracing #5477
Changes from all commits
3412fe8
dbca95e
df2e27b
bc374c3
af494e2
b9ec261
cd29929
ca23546
0f4417d
@@ -14,6 +14,8 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as metrics
+import torch_xla.runtime as xr
+import torch_xla.utils.utils as xu

 debug = os.environ.get("TORCH_XLA_DEBUG") == "1"
@@ -202,7 +204,7 @@ def is_xla_tensor(tensor: torch.Tensor) -> bool:
   return tensor.device.type == "xla"


-def extract_internal(xla_model: torch.fx.GraphModule):
+def extract_graph_helper(xla_model: torch.fx.GraphModule):
   xla_args = xla_model.xla_args
   assert all(
       map(
@@ -238,6 +240,11 @@ def extract_internal(xla_model: torch.fx.GraphModule):
       tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
   }

+  if xr.is_spmd():
+    xla_args_sharding_spec = torch_xla._XLAC._get_xla_sharding_specs(xla_args)
+  else:
+    xla_args_sharding_spec = ()
+
   xla_out = xla_model(*xla_args)
   if not isinstance(xla_out, (tuple, list)):
     xla_out = (xla_out,)
@@ -308,12 +315,55 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   # should be removed to avoid extra computation executed and in place updates op
   # mistakenly update the input tensors.
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
+  return (xla_args_sharding_spec, args_and_out, graph_hash,
+          arg_index_to_need_update_index, none_remover, graph_input_matcher,
+          dumb_return_handler, xla_args_need_update)
+
+
+def extract_internal(xla_model: torch.fx.GraphModule):
+  (xla_args_sharding_spec, args_and_out, graph_hash,
+   arg_index_to_need_update_index, none_remover, graph_input_matcher,
+   dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
+  skip_checking_input_sharding_threashold = xu.getenv_as(
+      'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
Review thread on the new skip_checking_input_sharding_threashold flag:

- Why would we want this flag, and how should it be used properly?
- Getting the sharding spec for all input tensors is not free; it costs roughly 5%-10% of inference speed on a 7B model. The idea of the flag is that if the input sharding has stayed the same for that many consecutive runs, we skip the check from then on.
- The core idea is that the user should not change the input sharding. We are mainly solving the problem of the compiler overwriting the sharding: the compiler usually overwrites the sharding on the first run, and after that the sharding stays the same.
- Then I guess this flag should be off by default? We should ensure program correctness first rather than offer a risky hack for users to tune performance.
- Okay, it makes sense then.
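A minimal usage sketch of the environment variable discussed above; the value shown is illustrative, and the semantics are read off the diff in this PR (the default is 5):

```python
import os

# Must be set before the model is compiled, since the dynamo bridge reads it
# via xu.getenv_as when it builds optimized_mod.
#
#   0 -> never compare input sharding specs (assumes sharding never changes)
#   N -> compare the specs on every call until they have been unchanged for N
#        consecutive calls, then stop checking to avoid the overhead; a
#        detected change retraces the graph and resets the counter to N.
#
# Once the counter reaches zero, later sharding changes go undetected, which
# is why the discussion above assumes users do not reshard inputs themselves.
os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] = '10'
```

The same value can be set from the shell, e.g. `XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD=10 python inference.py` (the script name is a placeholder).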

   def optimized_mod(*args):
+    nonlocal xla_model
+    nonlocal xla_args_sharding_spec
+    nonlocal args_and_out
+    nonlocal graph_hash
+    nonlocal arg_index_to_need_update_index
+    nonlocal none_remover
+    nonlocal graph_input_matcher
+    nonlocal dumb_return_handler
+    nonlocal xla_args_need_update
+    nonlocal skip_checking_input_sharding_threashold

     # mark_step needs to be blocking since we want to access args's XLADatas
     # and they can't be placeholder.
     if any(torch_xla._XLAC._check_tensor_need_materialization(args)):
       xm.mark_step(wait=True)

+    # If the input sharding has changed from the previous program, dynamo
+    # currently cannot detect this. It will mistakenly believe the program is
+    # the same. We need to retrace it here.
+    if xr.is_spmd():
+      # If the input sharding was the same for skip_checking_input_sharding_threashold
+      # consecutive times, we skip checking the input sharding since it can be
+      # expensive.
+      if skip_checking_input_sharding_threashold > 0:
+        if torch_xla._XLAC._get_xla_sharding_specs(
+            args) != xla_args_sharding_spec:
+          # update the xla_args with the input with new sharding and retrace
+          xla_model.xla_args = args
+          (xla_args_sharding_spec, args_and_out_copy, graph_hash,
+           arg_index_to_need_update_index, none_remover, graph_input_matcher,
+           dumb_return_handler,
+           xla_args_need_update) = extract_graph_helper(xla_model)
+          skip_checking_input_sharding_threashold = xu.getenv_as(
+              'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
+        else:
+          skip_checking_input_sharding_threashold -= 1
+
     enter_ts = time.time()
     if len(args_and_out) == 0:
       return ()
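The behavior added above can be illustrated with a hedged, self-contained sketch (not part of this PR): a dynamo-compiled function whose SPMD input sharding changes after the first trace, which is exactly the case where optimized_mod now retraces. The SPMD helpers (xs.Mesh, xs.mark_sharding), xr.use_spmd(), and the 'openxla' backend name are assumptions about the torch_xla version in use.

```python
# Hypothetical scenario sketch (not from this PR). Assumes a torch_xla build
# with the experimental SPMD API and an 'openxla' dynamo backend; names and
# module paths differ across versions.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

xr.use_spmd()  # enable SPMD mode (older builds use the XLA_USE_SPMD env var)
device = xm.xla_device()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('x', 'y'))


def matmul(a, b):
  return a @ b


compiled = torch.compile(matmul, backend='openxla')

a = torch.randn(8, 8, device=device)
b = torch.randn(8, 8, device=device)
compiled(a, b)  # first call: dynamo traces with a's current (default) sharding

# Resharding an input changes its sharding spec but not its shape/dtype, so
# dynamo alone would reuse the stale graph; the check added in optimized_mod
# notices the spec mismatch and calls extract_graph_helper again.
xs.mark_sharding(a, mesh, (0, 1))
compiled(a, b)  # second call: input sharding differs -> the bridge retraces
```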
Review thread:

- Should .cpu() trigger another compilation? Maybe a separate topic, but I'm curious; I assume it shouldn't?
- I am guessing the graph is something similar to ... (didn't check the actual HLO).
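Not part of this PR, but one way to answer the .cpu() question above empirically is to compare torch_xla's compile metric around the call, using the torch_xla.debug.metrics module the bridge already imports (a sketch; the metric name and the metric_data return layout are assumptions about the debug metrics API):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def compile_count():
  # metric_data returns None until the metric has been recorded at least once;
  # its first field is the number of recorded samples.
  data = met.metric_data('CompileTime')
  return data[0] if data else 0


device = xm.xla_device()
t = torch.randn(4, 4, device=device) + 1  # leave a pending op in the lazy graph

before = compile_count()
t_cpu = t.cpu()  # forces the pending graph to compile/execute and copies to host
after = compile_count()
print('compilations triggered by .cpu():', after - before)
```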