diff --git a/docs/source/learn/pytorch-on-xla-devices.md b/docs/source/learn/pytorch-on-xla-devices.md index de3cd3c69409..79c73de4e1b8 100644 --- a/docs/source/learn/pytorch-on-xla-devices.md +++ b/docs/source/learn/pytorch-on-xla-devices.md @@ -346,6 +346,72 @@ device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future. +### Unexpected Tensor Materialization During AOT (ahead of time) Tracing + +While tensor materialization is normal for JIT workflow, it is not expected during traced inference (i.e. [AOT model tracing in AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/programming-guide/inference/trace-vs-xla-lazytensor.html)). +When working with traced inference, developers may encounter tensor materialization, which leads to graphs being compiled based on example input tensor value and unexpected program behavior. +Therefore we need to take advantage of PyTorch/XLA's debugging flags to identify when unexpected tensor materialization happens and make appropriate code changes to avoid tensor materialization. + + +A common issue occurs when tensor values are evaluated during model compilation (traced inference). Consider this example: +```python +def forward(self, tensor): + if tensor[0] == 1: + return tensor + else: + return tensor * 2 +``` + +While this code can compile and run, it may lead to unexpected behavior because: + +* The tensor value is being accessed during tracing (``tensor[0]``). +* The resulting graph becomes fixed based on the tensor value available during tracing +* Developers might incorrectly assume the condition will be evaluated dynamically during inference +* The solution for the code above is to utilize the debugging flags below to catch the issue and modify the code. One example is to feed the flag through model configuration + +See the updated code without tensor materialization: +```python +class TestModel(torch.nn.Module): + def __init__(self, flag=1): + super().__init__() + # the flag should be pre-determined based on the model configuration + # it should not be an input of the model during runtime + self.flag = flag + + def forward(self, tensor): + if self.flag: + return tensor + else: + return tensor * 2 +``` + + +#### Debugging Flags +To help catch tensor materialization issues, PyTorch/XLA provides two useful approaches: + +1. Enable warning messages for tensor materialization: +``` +import os +os.environ['PT_XLA_DEBUG_LEVEL'] = '2' +``` + +2. Disable graph execution to catch issues during development: +``` +import torch_xla +torch_xla._XLAC._set_allow_execution(False) +``` + +#### Recommendations + +Using these flags during development can help identify potential issues early in the development cycle. The recommended approach is to: + +* Use ``PT_XLA_DEBUG_LEVEL=2`` during initial development to identify potential materialization points +* Apply ``_set_allow_execution(False)`` when you want to ensure no tensor materialization occurs during tracing +* When you see warnings or errors related the tensor materialization, look into the code path and make appropriate changes. The example above moved the flag to the `__init__` function which does not depend on the model input during runtime. + +For more detailed debugging information, refer to the [XLA troubleshoot](https://github.com/pytorch/xla/blob/master/docs/source/learn/troubleshoot.md#pytorchxla-debugging-tool). + + ## Compilation Caching The XLA compiler converts the traced HLO into an executable which runs diff --git a/docs/source/learn/troubleshoot.md b/docs/source/learn/troubleshoot.md index fab620a22110..f7683d57e39c 100644 --- a/docs/source/learn/troubleshoot.md +++ b/docs/source/learn/troubleshoot.md @@ -137,19 +137,25 @@ Execution Analysis: ------------------------------------------------------------ Execution Analysis: ================================================================================ ``` -Some common causes of Compilation/Executation are 1. User manually call -`torch_xla.sync()`. 2. [Parallel +Some common causes of compilation/executation are +1. User manually calls +`torch_xla.sync()`. +2. [Parallel loader](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/distributed/parallel_loader.py#L49-L51) -call `torch_xla.sync()` for every x (configurable) batch. 3. Exiting a +cals `torch_xla.sync()` for every x (configurable) batch. +3. Exit a [profiler StepTrace region](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/debug/profiler.py#L165-L171). -4. Dynamo decide to compile/execute the graph. 5. User trying to -access(often due to logging) the value of a tensor before the +4. Dynamo decides to compile/execute the graph. +5. User tries to +access (often due to logging) the value of a tensor before the `torch_xla.sync()`. +6. User tries to access a tensor value before calling `mark_step`. See [PyTorch on XLA Devices](https://github.com/pytorch/xla/blob/master/docs/source/learn/pytorch-on-xla-devices.md) for more details. + +The op executions caused by items 1-4 are expected, and we want to avoid item 5 by +either reducing the frequency of accessing tensor values or manually adding a call to +`torch_xla.sync()` before accessing them. -The execution caused by 1-4 are expected, and we want to avoid 5 by -either reduce the frequency of accessing tensor values or manually add a -`torch_xla.sync()` before accessing. Users should expect to see this `Compilation Cause` + `Executation Cause` pairs for first couple steps. After the model