diff --git a/OP_LOWERING_GUIDE.md b/OP_LOWERING_GUIDE.md
index 8288e57779ea..7ff427d1dbdc 100644
--- a/OP_LOWERING_GUIDE.md
+++ b/OP_LOWERING_GUIDE.md
@@ -8,6 +8,8 @@
 Here's an example of what you might see from the PyTorch/XLA debugging tool for
 pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests.
 ```
+Furthermore, if possible, we want to lower operations using `full_codegen`; see our [codegen migration guide](https://github.com/pytorch/xla/edit/document_xla_override/CODEGEN_MIGRATION_GUIDE.md) for more instructions.
+
 ## Before you start
 You should follow the instructions in [here](https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md) to install required dependencies and build pytorch and pytorch/XLA from the source. You do not need access to TPU to implement the lowering. It is recommended to experiment on a workstation and configure it to use XLA:CPU. You can configure Pytorch/XLA to use XLA:CPU by running
@@ -89,3 +91,10 @@
 The codegen will automatically generate lowerings for `lerp_.Scalar` and `lerp.Scalar_out`. In general, if there is an operator in pytorch core that has both an out-of-place and an out= variant, it's better to write a lowering for the out-of-place variant, since you'll get a code-generated out= lowering for free.

 For each node we need to pass an `ir::OpKind`. Here is an ([example](https://github.com/pytorch/xla/blob/5ce99bff336325feb41a982dc80299fb53166b29/torch_xla/csrc/ops/var_mean.cpp#L36)). You can find the `OpKind` definition in [interned_strings.h](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/interned_strings.h). If the aten symbol is missing, you can submit a PR like [this](https://github.com/pytorch/pytorch/pull/36851).
+
+## Overriding `XLA` Dispatch Key
+In certain cases, we may need to manually override the `XLA` dispatch key implementation of an operation. Ideally code generation would handle this, but it is useful to know how to handle this edge case.
+
+If you need to override the `XLA` dispatch key, you can do so through macros in the [xla_manual_registration.cpp](https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_manual_registration.cpp) file.
+
+You can use PR https://github.com/pytorch/xla/pull/8801 as a reference for which files to change.
diff --git a/docs/source/contribute/codegen_migration.md b/docs/source/contribute/codegen_migration.md
index a84e5568ba5d..cadb15ec6feb 100644
--- a/docs/source/contribute/codegen_migration.md
+++ b/docs/source/contribute/codegen_migration.md
@@ -133,8 +133,17 @@
 Find the op in `xla/codegen/xla_native_functions.yaml` and move it to
 the full_codegen column and run `python setup.py install` under xla
 directory again. The build will fail (reason explained later in this
-guide) but you can still see the generated file. The code snippets below
-uses `abs` as an example. \#### XLANativeFunctions.cpp
+guide) but you can still see the generated file.
+
+If, while generating the file, you run into an error involving
+[`shape_inference.h`](https://github.com/pytorch/pytorch/blob/main/torch/csrc/lazy/core/shape_inference.h),
+PyTorch may not yet have the shape function needed for the op being
+generated. You can attempt to add the necessary function to
+[`shape_inference.h`](https://github.com/pytorch/pytorch/blob/main/torch/csrc/lazy/core/shape_inference.h)
+to be unblocked.
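+
+Below is a minimal sketch of what such a shape function can look like,
+assuming a hypothetical elementwise op (`compute_shape_my_op` is not a
+real function; check the existing `compute_shape_*` declarations in
+[`shape_inference.h`](https://github.com/pytorch/pytorch/blob/main/torch/csrc/lazy/core/shape_inference.h)
+for the actual signatures and conventions):
+
+``` c++
+// In torch/csrc/lazy/core/shape_inference.h: declare a function that
+// returns one torch::lazy::Shape per output of the op.
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_my_op(
+    const at::Tensor& self);
+
+// In torch/csrc/lazy/core/shape_inference.cpp: an elementwise op
+// produces an output with the same dtype and sizes as its input.
+std::vector<torch::lazy::Shape> compute_shape_my_op(
+    const at::Tensor& self) {
+  return {torch::lazy::Shape(self.scalar_type(), self.sizes().vec())};
+}
+```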
+
+The code snippets below use `abs` as an example.
+
+#### XLANativeFunctions.cpp
 ``` c++
 at::Tensor XLANativeFunctions::abs(const at::Tensor & self) {
diff --git a/docs/source/learn/troubleshoot.md b/docs/source/learn/troubleshoot.md
index fdc97f8a0b8c..22377ecbb5c8 100644
--- a/docs/source/learn/troubleshoot.md
+++ b/docs/source/learn/troubleshoot.md
@@ -380,6 +380,11 @@
 We don't expect users to use tools in this section to debug their
 models. But we might ask for them when you submit a bug report since
 they provide additional information that metrics report doesn't have.
+### Debugging Tensor Operations
+
+The following tools are useful for gathering information on the execution
+of lowered operations.
+
 - `print(torch_xla._XLAC._get_xla_tensors_text([res]))` where `res` is
 the result tensor prints out the IR.
 - `print(torch_xla._XLAC._get_xla_tensors_hlo([res]))` where `res` is