
[RFC] Add HLO Minification API #5461

@awskila


🚀Feature

Add support for HLO minification in PyTorch/XLA similar to the PyTorch minifier.

Motivation

Currently (as of August 15th), three models fail with precision errors in the Dynamo benchmarks when using the torchxla_trace_once backend. All three are BartForConditionalGeneration variants. These failures are hard to debug because it is difficult to narrow down which operator(s) cause them.

BartForConditionalGeneration: [2023-08-04 00:10:21,049] torch._dynamo.utils: [ERROR] RMSE (res-fp64): 0.40076, (ref-fp64): 0.00000 and shape=torch.Size([1, 1024, 50265])

MBartForConditionalGeneration: [2023-08-04 00:57:55,593] torch._dynamo.utils: [ERROR] RMSE (res-fp64): 0.54682, (ref-fp64): 0.00000 and shape=torch.Size([1, 1024, 50265])

PLBartForConditionalGeneration: [2023-08-04 01:26:40,250] torch._dynamo.utils: [ERROR] RMSE (res-fp64): 0.71820, (ref-fp64): 0.00000 and shape=torch.Size([1, 1024, 50005])

These three models pass with the Inductor backend and in eager mode. Failures in those two modes are easier to debug because the PyTorch minifier can be used: it produces a minified subgraph that narrows down the operator(s) causing the failure. However, the PyTorch minifier currently only works on FX IR, so we need a minifier at the HLO level.

We are also interested in an HLO minifier because it would help AWS Neuron customers. Customers often cannot easily share an entire model when they hit compilation or precision errors. If they can run the minifier on their model, they can share only the minified subgraph with us while protecting their IP. This would also let us debug issues more quickly, improving the customer experience.

Background on Existing PyTorch Minifier

Usages

The PyTorch minifier is a TorchDynamo debugging tool that narrows down an FX graph to the minimal subgraph that reproduces a given precision or compilation error. The forward and backward graphs traced by TorchDynamo and AOTAutograd are the minifier's inputs. The output is a minimal subgraph that may contain as little as one layer, which greatly simplifies debugging.

In the case of a compilation failure, the minifier is called right after the FX graph is traced. For a precision failure, the minifier runs during the eval loop when the compiled output's root-mean-square error exceeds the tolerance. The minifier is enabled through two environment variables: TORCHDYNAMO_REPRO_AFTER, which selects whether it runs on the graph captured by TorchDynamo ("dynamo") or on the forward and backward graphs produced by AOTAutograd ("aot"), and TORCHDYNAMO_REPRO_LEVEL, which selects the failure type: levels 1 to 3 target compilation failures, and level 4 targets precision failures.
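As an illustration of these two variables, the sketch below builds an environment for launching a training script under the existing minifier. The helper name minifier_env is hypothetical; the values ("dynamo"/"aot", and levels 2 and 4) reflect our reading of torch._dynamo's config and should be checked against the installed PyTorch version.

```python
import os
from typing import Dict


def minifier_env(error_type: str, after: str = "aot") -> Dict[str, str]:
    """Return a copy of the current environment with the PyTorch minifier enabled.

    error_type: "compilation" or "precision".
    after: "dynamo" to minify the graph captured by TorchDynamo, or "aot" to
           minify the forward/backward graphs produced by AOTAutograd.
    """
    if error_type not in ("compilation", "precision"):
        raise ValueError(f"unknown error_type: {error_type!r}")
    if after not in ("dynamo", "aot"):
        raise ValueError(f"unknown repro stage: {after!r}")
    env = dict(os.environ)
    env["TORCHDYNAMO_REPRO_AFTER"] = after
    # Level 4 minifies accuracy (precision) failures; level 2 writes a
    # minifier launcher script when compilation fails.
    env["TORCHDYNAMO_REPRO_LEVEL"] = "4" if error_type == "precision" else "2"
    return env
```

The returned dictionary can be passed as the env argument of subprocess.run when launching a (hypothetical) train.py, so the minifier settings do not leak into the parent process.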

Once the minimal subgraph is found, the minifier writes it to a reproducer script (usually repro.py) as the forward method of a Repro class. This script invokes TorchDynamo using torch._dynamo.optimize, passing the minified subgraph and the name of the compilation backend (e.g. aot_eager). Finally, a forward pass (and a backward pass, if AOTAutograd is used) is run on the minimal subgraph, reproducing the error.

Current Implementation

The PyTorch minifier minimizes an FX graph with a given set of inputs. It uses four strategies: suffix truncation, delta debugging, dead code elimination, and removal of unused inputs. It applies these strategies until only the minimal subgraph that reproduces the compilation or precision failure remains.

Suffix Truncation

The primary method used to minify the subgraph is suffix truncation. It repeatedly removes trailing parts of the graph using a binary-search-like algorithm, stops when the failure no longer reproduces, and returns the smallest subgraph that still failed. A step-by-step example of how it works is in the Appendix section.
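A minimal sketch of the binary-search idea, assuming the graph is a flat list of ops and a hypothetical fails() callback that builds and runs a prefix and reports whether the error reproduces. This is not the FX minifier's actual code; a real implementation must also turn the results of the last kept node into graph outputs.

```python
from typing import Callable, List, Sequence, TypeVar

Op = TypeVar("Op")


def suffix_truncate(ops: Sequence[Op], fails: Callable[[Sequence[Op]], bool]) -> List[Op]:
    """Return the shortest prefix of `ops` that still reproduces the failure.

    Assumes the failure is monotone in prefix length: once the culprit op is
    included, every longer prefix also fails.
    """
    if not fails(ops):
        raise ValueError("the full graph does not reproduce the failure")
    lo, hi = 1, len(ops)  # invariant: ops[:hi] fails, prefixes shorter than lo do not
    while lo < hi:
        mid = (lo + hi) // 2
        if fails(ops[:mid]):
            hi = mid  # failure is within the first `mid` ops: drop the suffix
        else:
            lo = mid + 1  # failure needs something past `mid`
    return list(ops[:hi])
```

On the eight-op graph from the Appendix with a backend that rejects div, this halves the graph from 8 ops to 4, 2, and finally 1.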

Delta Debugging

If suffix truncation is not sufficient, for example when the graph has many prefix layers, delta debugging is used as the secondary minification method. This strategy promotes the results of prefix nodes to input arguments of the forward function, so the nodes that computed them can be removed. An example of delta debugging is in the Appendix section.
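The promotion step can be sketched as follows. This is a simplified, ddmin-style variant (chunks of nodes are promoted, with the chunk size halved down to one), not the FX minifier's exact algorithm; Node, promote, delta_debug, and the fails() callback are hypothetical names, and the graph is assumed to be in SSA form with nodes in topological order.

```python
from typing import Callable, List, NamedTuple, Set, Tuple


class Node(NamedTuple):
    name: str              # value defined by this node (SSA form)
    op: str                # e.g. "div", "sub", "mul"
    args: Tuple[str, ...]  # names of values this node reads


# A graph is (input names, nodes in topological order)
Graph = Tuple[List[str], List[Node]]


def promote(graph: Graph, drop: Set[str]) -> Graph:
    """Remove the nodes named in `drop`, turning their results into graph inputs."""
    inputs, nodes = graph
    kept = [n for n in nodes if n.name not in drop]
    promoted = [n.name for n in nodes if n.name in drop]
    return inputs + promoted, kept


def delta_debug(graph: Graph, fails: Callable[[Graph], bool]) -> Graph:
    """Promote chunks of nodes to inputs while the failure still reproduces.

    Chunks start at half the node count and are halved down to single nodes.
    """
    chunk = max(1, len(graph[1]) // 2)
    while True:
        nodes = graph[1]
        for start in range(0, len(nodes), chunk):
            drop = {n.name for n in nodes[start:start + chunk]}
            candidate = promote(graph, drop)
            if candidate[1] and fails(candidate):  # always keep at least one node
                graph = candidate
                break  # retry at the same granularity on the smaller graph
        else:
            if chunk == 1:
                return graph
            chunk //= 2
```

Promoted values that end up unused (such as a promoted input nothing reads) are left in place here; removing them is the job of the separate clean-up pass.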

Pitch

We propose an HLO minifier that mimics the behavior of the existing FX minifier. It takes an HLO graph as input and returns a minified HLO subgraph as output. The tool can also generate a script containing the subgraph, which makes it easier to reproduce errors and validate fixes. It will be available for PyTorch/XLA first, but the wider StableHLO community could also benefit from it later. The tool should integrate easily with other workflows such as TorchDynamo; specifically, we will use TorchDynamo as the entry point for testing the proof of concept (POC).

Implementation

We intend this tool to be usable from PyTorch, TensorFlow, JAX, and other frameworks. Thus, most of the proposed code, including HLO parsing/interpretation and the minification strategies, will be implemented in C++. It will be exposed to Python via bindings, similar to existing functions such as _get_xla_tensors_hlo. The final minified subgraph can then be passed to framework-specific codegen, which creates a short Python script that reproduces the compilation or precision failure when run.
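To make the codegen step concrete, here is a hedged sketch of a PyTorch/XLA-specific generator. The op representation, the generate_repro name, and the use of random inputs are assumptions for illustration; a precision repro would need the actual input values saved by the minifier.

```python
from typing import List, Sequence, Tuple

# (input name, shape) for each graph input
InputSpec = Tuple[str, Tuple[int, ...]]
# (output name, torch function name, argument source text)
OpSpec = Tuple[str, str, str]


def generate_repro(inputs: Sequence[InputSpec], ops: Sequence[OpSpec], output: str,
                   backend: str = "torchxla_trace_once") -> str:
    """Return the source of a standalone PyTorch/XLA repro script."""
    arg_names = ", ".join(name for name, _ in inputs)
    lines: List[str] = [
        "import torch",
        "import torch_xla.core.xla_model as xm",
        "",
        "class Repro(torch.nn.Module):",
        f"    def forward(self, {arg_names}):",
    ]
    for name, fn, args_src in ops:
        lines.append(f"        {name} = torch.{fn}({args_src})")
    lines.append(f"        return {output}")
    lines += ["", "device = xm.xla_device()", "inputs = ["]
    for _, shape in inputs:
        dims = ", ".join(str(d) for d in shape)
        size = f"{dims}, " if dims else ""
        lines.append(f"    torch.randn({size}device=device),")
    lines += [
        "]",
        f"model = torch.compile(Repro(), backend={backend!r})",
        "model(*inputs)",
    ]
    return "\n".join(lines) + "\n"
```

The generator only emits text, so it can be unit-tested without an XLA device; writing the result to repro.py is left to the caller.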

To implement the HLO minifier, the following code changes will need to be made to the PyTorch/XLA repo.

  1. In the torch_xla/debug directory, create an hlo-minifier script. This will be imported by the user from a trainer script, and will support both compilation and precision minification use-cases.

  2. Leverage CreateModuleFromProto, ParseHloModule, and ParseOperands functions to pull the list of modules from the HLO protobuf objects and parse the modules to get the list of ops/operands.

    • This will be done in C++, with the code stored in the torch_xla/csrc directory.
  3. Create functions that perform suffix truncation and delta debugging minification strategies, and also auxiliary functions that eliminate unused inputs/dead code.

    • Suffix truncation and delta debugging will be implemented in C++ and exposed to Python through bindings. These strategies will modify the HLO graph. The new bindings will be defined in the init_python_bindings.cpp file.
    • The function that takes care of cleaning unused inputs and dead code will be implemented at the Python level, as it does not require any testing or analysis of the HLO graph.
    • Then, call these minification functions in a runner called run_hlo_minifier.
  4. Build and run the minified graph to try and reproduce the error. This would be done at the Python level, and would differ based on the compiler used.

  5. Repeat steps 3 and 4. Iteratively apply the minification strategies until the smallest possible subgraph that can reproduce the error is created.

  6. Create a Python function that generates the repro.py script when a model fails to compile or fails the precision check.
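Steps 3 to 5 amount to a fixed-point loop. The framework-agnostic sketch below assumes each strategy is a function that proposes a smaller graph and that fails() wraps the framework-specific build-and-run step; minify and its parameters are hypothetical names, not part of PyTorch/XLA.

```python
from typing import Any, Callable, Sequence

Strategy = Callable[[Any, Callable[[Any], bool]], Any]


def minify(graph: Any, strategies: Sequence[Strategy],
           fails: Callable[[Any], bool], size: Callable[[Any], int]) -> Any:
    """Apply the strategies repeatedly until a full pass no longer shrinks the graph.

    `fails` builds and runs a candidate graph (step 4) and reports whether
    the original error still reproduces.
    """
    if not fails(graph):
        raise ValueError("the input graph does not reproduce the error")
    while True:
        before = size(graph)
        for strategy in strategies:
            candidate = strategy(graph, fails)
            # Only accept smaller candidates that still reproduce the error
            if size(candidate) < size(graph) and fails(candidate):
                graph = candidate
        if size(graph) == before:  # fixed point reached (step 5)
            return graph
```

For example, with two toy strategies that drop the last or first op of a list, and a failure that reproduces whenever "cat" is present, minify(["mm", "add", "mul", "cat", "relu"], ...) converges to ["cat"].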

We will add the HLO minifier code under torch_xla/debug. This Python module will be imported from PyTorch/XLA in order to invoke the HLO minifier from a trainer script. Finally, the run_hlo_minifier function would be called from a trainer script defined by the user.

API

A rough example is shown below to demonstrate the API and call flow. Let's assume that torch.cat causes a compilation error when concatenating two tensors whose sizes differ along the concatenation dimension. The example uses torch.compile and thus depends on TorchDynamo, but we intend the HLO minifier to work with other compiler frontends. This code is high-level; its only purpose is to start a discussion.

Trainer script

The trainer script is provided by the user, who imports the HLO minification runner function from torch_xla.debug. The user can pass “compilation” or “precision” to select which error the HLO minifier tests for; if omitted, it defaults to checking for precision errors.

import copy

import torch
import torch_xla.core.xla_model as xm
from torch.nn import Linear
from torch_xla.debug import run_hlo_minifier  # proposed API, not yet implemented

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_mlp_0 = Linear(in_features=64, out_features=256, bias=True)

    def forward(self, cat):
        self_mlp_0 = self.self_mlp_0(cat)
        # Deterministic constants on the input's device, so XLA and eager outputs are comparable
        x = torch.add(cat.new_ones(4), cat.new_ones(4))  # unused: dead code for the minifier to remove
        y = torch.mul(cat.new_ones(8, 4), cat.new_full((8, 4), 2.0))
        cond = cat.new_zeros(3, 4)
        cat_1 = torch.cat((y, cond), dim=0)  # assume that cat is causing an error
        return (cat_1, self_mlp_0)

device = xm.xla_device()
inputs = (torch.randn(1, 64),)
model = Model()

# The XLA copy shares the eager model's weights
xla_model = copy.deepcopy(model).to(device)
xla_inputs = tuple(t.to(device) for t in inputs)
compiled_model = torch.compile(xla_model, backend="torchxla_trace_once")

try:
    xla_outputs = compiled_model(*xla_inputs)
except Exception as e:
    print(f"Encountered {type(e)}: {e}")
    run_hlo_minifier(compiled_model, xla_inputs, "compilation")
else:
    eager_outputs = model(*inputs)  # eager-mode baseline on CPU
    # Run the minifier if a delta is observed vs. the baseline
    if not all(torch.allclose(out.cpu(), ref) for out, ref in zip(xla_outputs, eager_outputs)):
        run_hlo_minifier(compiled_model, xla_inputs, "precision")
    else:
        print("Model passes evaluation")

Output: Repro script

The repro script includes the minified subgraph and the inputs needed to run it. If the error is precision-related, a comparison against a golden output generated in eager mode is added to the repro script. Our example is a compilation error, so this is skipped.

import torch
import torch_xla.core.xla_model as xm

class MinifiedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y, cond):
        return torch.cat((y, cond), dim=0)

device = xm.xla_device()
# Shapes match the promoted inputs y (8, 4) and cond (3, 4)
inputs = [torch.randn(8, 4, device=device), torch.randn(3, 4, device=device)]
model = MinifiedModel()
optim_model = torch.compile(model, backend="torchxla_trace_once")
optim_model(*inputs)

HLO minifier script

The HLO minifier runner function, minifier strategy functions, and other helper function prototypes are displayed below. This code will be saved in the torch_xla/debug directory.

import torch_xla

# torch_xla._XLAC._suffix_truncation and torch_xla._XLAC._delta_debugging are
# proposed C++ bindings; they do not exist yet.

def run_hlo_minifier(model, inputs, error_type="precision"):
    # Iterate through all HLO modules, cycling through the minification strategies, then create a reproducer
    modules = _get_hlo_modules(model, inputs)
    modules = _suffix_truncation(modules, inputs, error_type)
    modules, input_args = _delta_debugging(modules, inputs, error_type)
    input_args = clean_up_unused_args(modules, input_args)
    create_repro_script(modules, input_args, inputs)

def _get_hlo_modules(model, inputs):
    # Hypothetical helper: lower the model to its HLO modules
    # (model.modules() would return nn.Module submodules, not HLO)
    ...

def _suffix_truncation(modules, inputs, error_type="precision"):
    # Try to truncate the suffix using preset seeds, mimicking the behavior of the FX minifier
    step = 0
    # Attempt to truncate the HLO modules
    modules = torch_xla._XLAC._suffix_truncation(modules, inputs, step, error_type)

    # Verify that the error still reproduces; call the suffix truncation function as needed.
    # Increment step until all seeds are exhausted, or until the graph has one node
    ...
    return modules

def _delta_debugging(modules, inputs, error_type="precision"):
    # Try to remove the prefix, promoting intermediate values to input arguments of the graph
    step = 0
    mod_modules, inputs = torch_xla._XLAC._delta_debugging(modules, inputs, step, error_type)

    # Verify that the error still reproduces; call the delta debugging function as needed.
    # Increment step until the end of the HLO modules is reached
    ...
    return (mod_modules, inputs)

def clean_up_unused_args(modules, input_args):
    # Try to clean up dead code by looking for unused input args
    mod_input_args = input_args
    ...
    return mod_input_args

def create_repro_script(minified_modules, input_args, inputs):
    with open("repro.py", "w") as repro_script:
        # Write a repro script that calls the minified graph to reproduce the model error
        ...
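Step 2 of the implementation plan parses HLO modules into ops and operands in C++. For prototyping at the Python level, a rough approximation is to scan HLO text (for example, the text returned by _get_xla_tensors_hlo). The sketch below is an assumption-laden stand-in for XLA's real parser: it assumes one instruction per line and %-prefixed instruction names, and it ignores attributes such as to_apply.

```python
import re
from typing import List, NamedTuple, Tuple


class HloInstr(NamedTuple):
    name: str
    opcode: str
    operands: Tuple[str, ...]
    is_root: bool


# "[ROOT] %name = <shape> opcode(" at the start of an instruction line
_INSTR = re.compile(r"^\s*(ROOT\s+)?%([\w.\-]+)\s*=\s*(.+?)\s+([\w\-]+)\(")
_OPERAND = re.compile(r"%([\w.\-]+)")


def parse_instructions(hlo_text: str) -> List[HloInstr]:
    """Extract (name, opcode, operand names) from each instruction line of HLO text."""
    instrs: List[HloInstr] = []
    for line in hlo_text.splitlines():
        m = _INSTR.match(line)
        if m is None:
            continue  # module header, computation signature, braces, blank lines
        # Scan forward to the parenthesis that closes the operand list
        depth, end = 1, m.end()
        while depth and end < len(line):
            if line[end] == "(":
                depth += 1
            elif line[end] == ")":
                depth -= 1
            end += 1
        operand_text = line[m.end():end - 1]
        operands = tuple(_OPERAND.findall(operand_text))
        instrs.append(HloInstr(m.group(2), m.group(4), operands, m.group(1) is not None))
    return instrs
```

For a line such as `%concatenate.3 = f32[11,4]{1,0} concatenate(f32[8,4]{1,0} %p0.1, f32[3,4]{1,0} %p1.2), dimensions={0}`, this yields opcode "concatenate" with operands ("p0.1", "p1.2"), which is enough to build the def-use graph the strategies operate on.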

Additional context

Once the above is implemented, we are interested in improving the minifier's granularity, for example by using graph pruning and model compression techniques. This would apply to both the FX and HLO minifiers, with the goal of narrowing the minified graph down to a single instruction more reliably. We want to build on existing open minifier issues regarding accuracy minification and help address issues with quantization and dynamic-shape use cases.

Appendix

Suffix Truncation Example (FX minifier)

Suppose our backend fails on division ops, and the forward function looks something like this:

def forward(a_in, b_in):
    b_in = torch.ops.aten.div.Tensor(a_in, b_in)
    a_in = torch.ops.aten.add.Tensor(a_in, b_in)
    b_in = torch.ops.aten.mul.Tensor(a_in, b_in)
    a_in = torch.ops.aten.sqrt(b_in)
    a_in = torch.ops.aten.add.Tensor(b_in, b_in)
    b_in = torch.ops.aten.sigmoid(a_in)
    b_in = torch.ops.aten.maximum(a_in, b_in)
    a_in = torch.ops.aten.pow.Tensor_Tensor(b_in, a_in)
    return a_in

We first cut the last half of the graph.

def forward(a_in, b_in):
    b_in = torch.ops.aten.div.Tensor(a_in, b_in)
    a_in = torch.ops.aten.add.Tensor(a_in, b_in)
    b_in = torch.ops.aten.mul.Tensor(a_in, b_in)
    a_in = torch.ops.aten.sqrt(b_in)
    return a_in

We can still reproduce the error, so we keep halving (4 ops to 2, then 2 to 1) until only the division remains.

def forward(a_in, b_in):
    b_in = torch.ops.aten.div.Tensor(a_in, b_in)
    return b_in

We can still reproduce the error, and we’re left with a minified subgraph.

Delta Debugging + Clean-Up Example (FX minifier)

Suppose our backend fails on multiplication ops, and our forward function looks like this:

def forward(a_in):
    b_in = a_in / 10
    c_in = b_in - 5
    d_in = c_in * 5
    return d_in

Since the multiplication causes the failure, the middle node can be removed, with c_in promoted to an input.

def forward(a_in, c_in):
    b_in = a_in / 10
    d_in = c_in * 5
    return d_in

Similarly, we can remove the node that computes b_in by promoting b_in to an input.

def forward(a_in, c_in, b_in):
    d_in = c_in * 5
    return d_in

This is close, but the graph isn't fully minified yet because it still has unused inputs. Two simple remaining strategies handle this: removing unused inputs and eliminating dead code. We don't need a_in or b_in to reproduce the failure in our one-layer graph, so we remove them. The function can also be reduced to one line by moving the multiplication into the return statement.

def forward(c_in):
    return c_in * 5

Now the graph is fully minified using delta debugging, and cleaning up unused inputs and dead code.
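The two clean-up strategies can be sketched as a single backward liveness pass over an SSA-style node list. Node and clean_up are hypothetical names; the sketch does not perform the final inlining of d_in into the return statement.

```python
from typing import List, NamedTuple, Sequence, Tuple


class Node(NamedTuple):
    name: str              # value defined by this node (SSA form)
    op: str
    args: Tuple[str, ...]  # names of values this node reads


def clean_up(inputs: Sequence[str], nodes: Sequence[Node],
             outputs: Sequence[str]) -> Tuple[List[str], List[Node]]:
    """Drop nodes whose results never reach an output, then drop unused inputs."""
    live = set(outputs)
    kept: List[Node] = []
    for node in reversed(nodes):  # walk backwards so uses are seen before definitions
        if node.name in live:
            kept.append(node)
            live.update(node.args)
    kept.reverse()
    used_inputs = [name for name in inputs if name in live]
    return used_inputs, kept
```

Applied to the three-input graph above, clean_up(["a_in", "c_in", "b_in"], [Node("d_in", "mul", ("c_in",))], ["d_in"]) keeps only c_in as an input and the single mul node.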

cc @alanwaketan, @JackCaoG, @qihqi
