Description
🚀 Feature
Add support for HLO minification in PyTorch/XLA similar to the PyTorch minifier.
Motivation
Currently (as of August 15th), three models produce precision failures in the Dynamo benchmarks with the torch_xla_trace_once backend, all of them BART-family ConditionalGeneration models. These failures are difficult to debug, because there is no easy way to narrow down the operator(s) responsible.
BartForConditionalGeneration: [2023-08-04 00:10:21,049] torch._dynamo.utils: [ERROR] RMSE (res-fp64): 0.40076, (ref-fp64): 0.00000 and shape=torch.Size([1, 1024, 50265])
MBartForConditionalGeneration: [2023-08-04 00:57:55,593] torch._dynamo.utils: [ERROR] RMSE (res-fp64): 0.54682, (ref-fp64): 0.00000 and shape=torch.Size([1, 1024, 50265])
PLBartForConditionalGeneration: [2023-08-04 01:26:40,250] torch._dynamo.utils: [ERROR] RMSE (res-fp64): 0.71820, (ref-fp64): 0.00000 and shape=torch.Size([1, 1024, 50005])
These three models pass with the Inductor backend and in eager mode. Debugging model evaluation failures is easier in those two modes because the PyTorch minifier can be used: it produces a minified subgraph that narrows down the operator(s) causing the failure. The PT minifier currently only works at the FX IR level; thus, we need a minifier at the HLO level.
We are also interested in an HLO minifier because it would help AWS Neuron customers. It is usually not easy for customers to share an entire model when they encounter compilation or precision errors. If they can run the minifier on their model, they can share the minified subgraph with us while protecting their IP. We will also be able to debug issues more quickly, improving the customer experience.
Background on Existing PyTorch Minifier
Usages
The PyTorch minifier is a TorchDynamo debugging tool that narrows down an FX graph to a minimal subgraph that reproduces a given precision or compilation error. The forward and backward graphs are traced by TorchDynamo and AOTAutograd, and these are used as inputs to the minifier. The output is a minimal subgraph that may contain as little as one layer, greatly simplifying the debugging process.
In the case of a compilation failure, the minifier is invoked right after the FX graph is traced. For a precision failure, the minifier is run during the eval loop when an instruction produces a root-mean-square error greater than the tolerance. The minifier is triggered by setting two environment variables: TORCHDYNAMO_REPRO_AFTER, which specifies whether to run the minifier on the forward or backward graphs when compilation fails, and TORCHDYNAMO_REPRO_LEVEL, which specifies the granularity of the minifier (1 targets a compilation error; 4 targets a precision error).
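For reference, enabling the FX minifier for a precision failure looks roughly like the following (a minimal sketch; the values follow the description above):

import os

# Run the minifier on the graphs traced by AOTAutograd; "dynamo" would
# instead target the TorchDynamo-level forward graph.
os.environ["TORCHDYNAMO_REPRO_AFTER"] = "aot"
# Level 4 minifies precision errors; level 1 targets compilation errors.
os.environ["TORCHDYNAMO_REPRO_LEVEL"] = "4"

# ... then import torch and run the compiled model as usual; on failure,
# a repro.py reproducer is written automatically.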
Upon finding the minimal possible subgraph, the minifier saves it as a Repro module in a reproducer script, repro.py. This script invokes TorchDynamo via torch._dynamo.optimize, passing the minified subgraph and the name of the compilation backend (e.g. aot_eager). Finally, a forward pass (and a backward pass, when AOTAutograd is involved) is run on the minimal subgraph, reproducing the error.
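The emitted reproducer roughly takes the following shape (a simplified sketch, not the exact file the minifier writes; the Repro body stands in for the minified subgraph):

import torch
import torch._dynamo as dynamo

class Repro(torch.nn.Module):
    def forward(self, x):
        # The minified subgraph is inlined here by the minifier
        return torch.ops.aten.div.Tensor(x, x)

mod = Repro()
opt_mod = dynamo.optimize("aot_eager")(mod)
# Running the forward pass reproduces the original failure
opt_mod(torch.randn(4))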
Current Implementation
The PyTorch minifier minimizes an FX graph with a given set of inputs. It attempts four strategies to minimize the graph: suffix truncation, delta debugging, dead-code elimination, and removal of unused inputs. It applies these strategies until only the minimal subgraph that reproduces the compilation or precision failure remains.
Suffix Truncation
The primary method used to minify the subgraph is suffix truncation, which tries to remove the latter parts of the graph using a binary-search-like algorithm. It stops once the failure can no longer be reproduced and returns the smallest failing subgraph. A step-by-step example is given in the Appendix.
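In pseudocode, the truncation loop might look like the sketch below; the node list and the fails() oracle are illustrative stand-ins, not the minifier's actual API.

def truncate_suffix(nodes, fails):
    # Binary-search for the shortest prefix of `nodes` that still fails
    keep = len(nodes)
    step = len(nodes) // 2
    while step >= 1:
        if fails(nodes[:keep - step]):
            keep -= step  # the failure survives the cut, so commit it
        step //= 2        # otherwise retry with a smaller cut
    return nodes[:keep]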
Delta Debugging
If suffix truncation alone is not sufficient, for example when the graph has many prefix layers, delta debugging is used as a secondary minification method. This strategy removes prefix layers and promotes their outputs to input arguments of the forward function. An example of delta debugging is given in the Appendix.
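A rough sketch of the promotion step, again with an illustrative op representation in which each removed op's output becomes a new graph input:

def delta_debug(ops, graph_inputs, fails):
    i = 0
    while i < len(ops):
        candidate = ops[:i] + ops[i + 1:]
        # Promote the removed op's output to a graph input so that
        # downstream ops still have a producer for that value
        candidate_inputs = graph_inputs + [ops[i].output]
        if fails(candidate, candidate_inputs):
            ops, graph_inputs = candidate, candidate_inputs
        else:
            i += 1  # removing this op loses the failure; keep it
    return ops, graph_inputs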
Pitch
We propose an HLO minifier that mimics the behavior of the existing FX minifier. It takes an HLO graph as input and returns a minified HLO subgraph as output. The tool can also generate a script containing the subgraph, enabling easier error reproduction and fix validation. It will initially target PT/XLA, but the wider StableHLO community could also benefit from it later. The proposed tool should integrate easily with other workflows such as TorchDynamo; specifically, we will use TorchDynamo as the entry point for testing the POC.
Implementation
Our intention is for this tool to be usable from PyTorch, TensorFlow, JAX, and other frameworks. Thus, most of the new code we propose will be implemented at the C++ level, including the HLO parsing/interpretation and the minification strategies. The code will be exposed to Python via bindings, similar to existing functions such as _get_xla_tensors_hlo. The final minified subgraph can be passed to a framework-specific codegen that creates a short Python script, which can then be run to reproduce the compilation or precision failure.
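For reference, the existing _get_xla_tensors_hlo binding is used as follows to dump the HLO behind a set of XLA tensors; the new minifier bindings would follow the same pattern:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(4, 4, device=xm.xla_device())
u = t @ t
# Returns the HLO text for the pending computation that produces `u`
print(torch_xla._XLAC._get_xla_tensors_hlo([u]))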
To implement the HLO minifier, the following code changes will need to be made to the PyTorch/XLA repo.
1. In the torch_xla/debug directory, create an hlo-minifier script. This will be imported by the user from a trainer script, and will support both the compilation and precision minification use-cases.
2. Leverage the CreateModuleFromProto, ParseHloModule, and ParseOperands functions to pull the list of modules from the HLO protobuf objects and parse the modules to get the list of ops/operands.
   - This will be done at the C++ level, with the code stored in the torch_xla/csrc directory.
3. Create functions that perform the suffix truncation and delta debugging minification strategies, as well as auxiliary functions that eliminate unused inputs/dead code.
   - Suffix truncation and delta debugging will be implemented at the C++ level and imported into Python through bindings. These strategies will modify the HLO graph. The new functions will be defined in the init_python_bindings.cpp file.
   - The function that cleans up unused inputs and dead code will be implemented at the Python level, as it does not require any testing or analysis of the HLO graph.
   - These minification functions will then be called from a runner named run_hlo_minifier.
4. Build and run the minified graph to try to reproduce the error. This would be done at the Python level, and would differ based on the compiler used.
5. Repeat steps 3 and 4, iteratively applying the minification strategies until the smallest subgraph that reproduces the error is produced.
6. Create a Python function that generates the repro.py script if a model fails compilation or a precision check.
We will add the HLO minifier code under torch_xla/debug. This Python module will be imported from PyTorch/XLA to invoke the HLO minifier, and the run_hlo_minifier function will be called from a trainer script defined by the user.
API
A rough example is shown below to demonstrate the API and call flow. Assume that torch.cat is causing a compilation error when concatenating two tensors with different shapes along a common dimension. The example uses torch.compile and is thus dependent on TorchDynamo, but our intention is for the HLO minifier to work with other compiler frontends as well. Please note that this code is high-level and its only purpose is to open up a discussion.
Trainer script
The trainer script is provided by the user, who imports the HLO minification runner function from torch_xla.debug. The user can specify "compilation" or "precision" to force the HLO minifier to test for a specific error type; otherwise, it defaults to checking for precision errors.
import torch
from torch.nn import *
from torch_xla.debug import run_hlo_minifier

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_mlp_0 = Linear(in_features=64, out_features=256, bias=True)

    def forward(self, cat):
        self_mlp_0 = self.self_mlp_0(cat)
        x = torch.add(torch.randn(4), torch.ones(4))
        y = torch.mul(torch.randn(8, 4), torch.randn(8, 4))
        cond = torch.randn(3, 4)
        cat_1 = torch.cat((y, cond), dim=0)  # assume that cat is causing an error
        return (cat_1, self_mlp_0)

inputs = [torch.randn(3), torch.randn(4)]
model = Model()
try:
    trained_model = torch.compile(model, backend="torchxla_trace_once")
    run_trained_model = trained_model(inputs)
except Exception as e:
    print(f"Encountered {type(e)}: {e}")
    run_hlo_minifier(trained_model, inputs, "compilation")

baseline_model = torch.compile(model, backend="eager")  # eager-mode baseline
baseline_trained_model = baseline_model(inputs)

# Run the minifier if a delta is observed vs. the baseline
if not torch.allclose(run_trained_model[0], baseline_trained_model[0]):
    run_hlo_minifier(trained_model, inputs, "precision")
else:
    print("Model passes evaluation")
Output: Repro script
The repro script includes the minified subgraph and the user-defined inputs. If the error is precision-related, a comparison against a golden output generated in eager mode is added to the repro script. Our example is a compilation error, so we skip this.
import torch
from torch.nn import *

class MinifiedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y, cond):
        return torch.cat((y, cond), dim=0)

inputs = [torch.randn(3), torch.randn(4)]
model = MinifiedModel()
optim_model = torch.compile(model, backend="torchxla_trace_once")
optim_model(*inputs)
HLO minifier script
The HLO minifier runner function, minifier strategy functions, and other helper function prototypes are displayed below. This code will be saved in the torch_xla/debug directory.
import torch_xla

def run_hlo_minifier(model, inputs, error_type="precision"):
    # Iterate through all HLO modules, cycling through the minification
    # strategies, then create the reproducer
    modules = _suffix_truncation(model.modules(), inputs, error_type)
    modules, input_args = _delta_debugging(modules, inputs, error_type)
    input_args = clean_up_unused_args(modules, input_args)
    create_repro_script(modules, input_args, inputs)

def _suffix_truncation(modules, inputs, error_type="precision"):
    # Try to truncate the suffix using pre-set seeds, mimicking the
    # behavior of the FX minifier
    iteration = 0
    # Attempt to truncate the HLO modules
    modules = torch_xla._XLAC._suffix_truncation(modules, inputs, iteration, error_type)
    # Verify we can reproduce the error; call the suffix truncation binding
    # again as needed, incrementing the iteration counter until we cycle
    # through all seeds or the graph is down to a single node
    ...
    return modules

def _delta_debugging(modules, inputs, error_type="precision"):
    # Try to remove the prefix, promoting variables to input arguments of the graph
    iteration = 0
    mod_modules, inputs = torch_xla._XLAC._delta_debugging(modules, inputs, iteration, error_type)
    # Verify we can reproduce the error; call the delta debugging binding
    # again as needed, incrementing the iteration counter until we reach
    # the end of the HLO modules
    ...
    return (mod_modules, inputs)

def clean_up_unused_args(modules, input_args):
    # Try to clean up dead code by looking for unused input args
    ...
    return mod_input_args

def create_repro_script(minified_modules, input_args, inputs):
    with open("repro.py", "w") as repro_script:
        # Write a repro script that calls the minified graph to reproduce the model error
        ...
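As an illustration, create_repro_script could fill in a template along these lines (REPRO_TEMPLATE, model_source, inputs_repr, and render_repro are hypothetical names, assuming a codegen that renders the minified modules back into a torch module):

REPRO_TEMPLATE = """\
import torch
from torch.nn import *

{model_source}

inputs = {inputs_repr}
model = MinifiedModel()
optim_model = torch.compile(model, backend="torchxla_trace_once")
optim_model(*inputs)
"""

def render_repro(model_source, inputs_repr):
    # Fill in the generated module source and a literal for the inputs
    return REPRO_TEMPLATE.format(model_source=model_source, inputs_repr=inputs_repr)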
Additional context
After the above is implemented, we are interested in extending the minifier to improve its granularity, using graph pruning and model compression techniques. This would apply to both the FX and HLO minifiers, with the goal of narrowing the minified graph down to a single instruction more reliably. We also want to build on existing open minifier issues regarding accuracy minification and help with issues seen in quantization and dynamic-shape use-cases.
Appendix
Suffix Truncation Example (FX minifier)
Suppose our backend fails on division ops, and the forward function looks something like this:
def forward(a_in, b_in):
    b_in = torch.ops.aten.div.Tensor(a_in, b_in)
    a_in = torch.ops.aten.add.Tensor(a_in, b_in)
    b_in = torch.ops.aten.mul.Tensor(a_in, b_in)
    a_in = torch.ops.aten.sqrt(b_in)
    a_in = torch.ops.aten.add.Tensor(b_in, b_in)
    b_in = torch.ops.aten.sigmoid(a_in)
    b_in = torch.ops.aten.maximum(a_in, b_in)
    a_in = torch.ops.aten.pow.Tensor_Tensor(b_in, a_in)
We first cut the last half of the graph.
def forward(a_in, b_in):
    b_in = torch.ops.aten.div.Tensor(a_in, b_in)
    a_in = torch.ops.aten.add.Tensor(a_in, b_in)
    b_in = torch.ops.aten.mul.Tensor(a_in, b_in)
    a_in = torch.ops.aten.sqrt(b_in)
We can still reproduce the error, so we keep bisecting until only the first op remains.
def forward(a_in, b_in):
    b_in = torch.ops.aten.div.Tensor(a_in, b_in)
We can still reproduce the error, and we’re left with a minified subgraph.
Delta Debugging + Clean-Up Example (FX minifier)
Now suppose our backend fails on multiplication ops, and the forward function looks like this:
def forward(a_in):
    b_in = a_in / 10
    c_in = b_in - 5
    d_in = c_in * 5
    return d_in
Since a multiplication operation causes the graph to fail, the middle node can be removed, with c_in promoted to an input.
def forward(a_in, c_in):
    b_in = a_in / 10
    d_in = c_in * 5
    return d_in
Similarly, we can do this for b_in.
def forward(a_in, c_in, b_in):
    d_in = c_in * 5
    return d_in
This is close, but the graph isn't fully minified yet, because unused inputs remain. Two minor, trivial strategies finish the job: clearing unused inputs and removing dead code. We don't need a_in or b_in to reproduce the failure in our one-layer minified graph, so we clear them. The function also becomes a one-liner by moving d_in into the return statement.
def forward(c_in):
    return c_in * 5
Now the graph is fully minified, using delta debugging together with clean-up of unused inputs and dead code.
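For completeness, this clean-up pass can be sketched in a few lines (the op representation is hypothetical; each op records the argument names it reads):

def clean_up_unused_inputs(ops, input_names):
    used = set()
    for op in ops:
        used.update(op.inputs)
    # Keep only the graph inputs that some remaining op actually reads
    return [name for name in input_names if name in used]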
cc @alanwaketan, @JackCaoG, @qihqi