[xla] hlo_computation: compact instructions' vector on Cleanup()
tl;dr: this gives a 1.26x compilation time speedup for a large, dense
model in XLA:GPU.

The largest perf leaf seen in profiles of a large, dense model
is related to computing the post order. Surprisingly, it is not
the DFS itself that is most expensive; rather, most of the time is
spent scanning through HloComputation::Instructions() to identify
DFS roots.

The reason this scan becomes expensive as instructions are removed
is that the vector holding HloInstructionInfo (introduced in
cl/600130708 || openxla/xla@247280ab727)
is not shrunk as it flows through the pipeline, forcing us to walk
through many deleted "tombstone" entries. Here is the histogram of
the number of tombstones encountered during post-order computations
for this model (an illustrative sketch follows the histogram):

```
[        1 - 1,536,345) ****************************** (1,300,248)
[1,536,345 - 3,072,690)  (2)
[3,072,690 - 4,609,034)  (364)
[4,609,034 - 6,145,378)  (10,443)
```
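
For intuition, here is a minimal, hypothetical sketch of the kind of
root scan described above; Instruction, Info, and the null-pointer
tombstone are simplified stand-ins, not the actual XLA types:

```
#include <vector>

struct Instruction {
  int user_count = 0;  // Instructions with no users seed the DFS.
};

struct Info {
  Instruction* inst = nullptr;  // nullptr stands in for a tombstone.
};

// Every tombstone slot is still visited, so the scan's cost tracks the
// vector's historical size rather than the number of live instructions.
std::vector<Instruction*> FindDfsRoots(const std::vector<Info>& infos) {
  std::vector<Instruction*> roots;
  for (const Info& info : infos) {
    if (info.inst == nullptr) continue;  // Tombstone: pure wasted work.
    if (info.inst->user_count == 0) roots.push_back(info.inst);
  }
  return roots;
}
```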

To ameliorate this, this CL shrinks the vector periodically,
so far only between passes. This is done by running compaction
on the vector during HloComputation::Cleanup(), which is called
after every pass. The cost of compaction is made proportional to
the number of deleted entries by swapping, if needed, each tombstone
with the rightmost non-deleted entry in the vector.
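
Below is a minimal sketch of that compaction strategy. It assumes
the deleted slots' indices are tracked on the side (which is what
makes the cost proportional to the number of deletions); it is
illustrative only, not the actual Cleanup() implementation:

```
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

struct Info {
  int payload = 0;
  bool deleted = false;  // Tombstone marker.
};

// Swaps each tombstone with the rightmost live entry, then trims the
// now-dead tail. Work is proportional to the number of tombstones
// (plus sorting them), not to v.size().
void Compact(std::vector<Info>& v, std::vector<std::size_t> tombstones) {
  std::sort(tombstones.begin(), tombstones.end());
  std::size_t live_end = v.size();    // One past the last candidate slot.
  std::size_t t = tombstones.size();  // Tombstones not yet in the tail.
  for (std::size_t i : tombstones) {
    // Retreat live_end past any tombstones already sitting at the tail.
    while (t > 0 && tombstones[t - 1] == live_end - 1) {
      --t;
      --live_end;
    }
    if (i >= live_end) break;  // Remaining tombstones are all in the tail.
    std::swap(v[i], v[--live_end]);
  }
  v.resize(live_end);
}
```

Trimming trailing tombstones before each swap ensures the entry moved
into a hole is always live, so a single pass suffices.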

This brings the number of tombstones seen down significantly:

```
[        1 -   327,699) ****************************** (937,541)
[  327,699 -   655,396)  (308)
[  655,396 -   983,094)  (0)
[  983,094 - 1,310,792)  (1)
```

Note: we could further improve compaction by calling Cleanup()
from some passes, instead of just between passes. However, that
would not yield a significant gain; at least for this model,
scanning the instructions' vector now takes ~1% of total time
(vs. ~17% before).
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde
PiperOrigin-RevId: 619057964
cota authored and tensorflower-gardener committed Apr 4, 2024
1 parent a4d9df4 commit c99f3c4
Showing 17 changed files with 2,216 additions and 33 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/quantization/stablehlo/BUILD
@@ -68,6 +68,7 @@ cc_library(
        "passes/restore_function_name.cc",
        "passes/unfuse_mhlo_batch_norm.cc",
        "passes/unwrap_xla_call_module_op.cc",
        "passes/xla_call_module_to_call.cc",
    ],
    hdrs = [
        "passes/passes.h",
@@ -130,6 +130,13 @@ def PostQuantizePass : Pass<"stablehlo-post-quantize", "mlir::func::FuncOp"> {
  ];
}

def XlaCallModuleToCallPass : Pass<"stablehlo-xla-call-module-to-call", "ModuleOp"> {
  let summary = "Convert XlaCallModuleOp to func.call op";
  let dependentDialects = [
    "TF::TensorFlowDialect",
  ];
}

def UnwrapXlaCallModuleOpPass : Pass<"stablehlo-unwrap-xla-call-module-op", "ModuleOp"> {
  let summary = "Unwrap XlaCallModuleOps into inline functions if not used for quantizing fused patterns.";
  let dependentDialects = ["TF::TensorFlowDialect"];
@@ -98,6 +98,11 @@ void QuantizeCompositeFunctionsPass::runOnOperation() {
  pm.addPass(createQuantizePass(quantize_options));
  pm.addNestedPass<func::FuncOp>(createPostQuantizePass());

  // Convert XlaCallModuleOps lifted but not quantized to func.call op.
  // The reasons these ops are not quantized may be:
  // 1. Disabled due to selective quantization.
  // 2. Not supported, e.g. add op for server.
  pm.addPass(createXlaCallModuleToCallPass());
  ModuleOp module_op = getOperation();
  if (const absl::Status pm_run_status =
          RunPassesOnModuleOp(mlir_dump_file_name_, pm, module_op);
@@ -0,0 +1,83 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Support/TypeID.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir::quant::stablehlo {

#define GEN_PASS_DEF_XLACALLMODULETOCALLPASS
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc"

namespace {

// Converts XlaCallModuleOps to func.call.
class XlaCallModuleToCallPass
    : public impl::XlaCallModuleToCallPassBase<XlaCallModuleToCallPass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XlaCallModuleToCallPass)

  explicit XlaCallModuleToCallPass() = default;

 private:
  void runOnOperation() override;
};

// Converts XlaCallModuleOps to func.call.
class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
 public:
  explicit XlaCallModuleOpToCallOp(MLIRContext* context)
      : OpRewritePattern<TF::XlaCallModuleOp>(context) {}

  LogicalResult matchAndRewrite(TF::XlaCallModuleOp op,
                                PatternRewriter& rewriter) const override {
    auto module_op = op->getParentOfType<ModuleOp>();
    SymbolTable symbol_table(module_op);

    auto entry_func_op = dyn_cast_or_null<func::FuncOp>(
        symbol_table.lookup(GetEntryFunctionName(op)));
    if (!entry_func_op) return failure();

    // Replace the XlaCallModuleOp with a new CallOp.
    rewriter.replaceOpWithNewOp<func::CallOp>(op, entry_func_op, op.getArgs());
    return success();
  }
};

void XlaCallModuleToCallPass::runOnOperation() {
  ModuleOp module_op = getOperation();
  MLIRContext* ctx = module_op.getContext();
  RewritePatternSet patterns(&getContext());
  patterns.add<XlaCallModuleOpToCallOp>(ctx);
  if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) {
    signalPassFailure();
  }
}

} // namespace
} // namespace mlir::quant::stablehlo
@@ -1179,6 +1179,126 @@ def test_conv_weight_only_model(
        0.35,
    )

  @parameterized.parameters(
      testing.parameter_combinations([{
          'shape_dynamic': (
              False,
              True,
          ),
      }])
  )
  @test_util.run_in_graph_and_eager_modes
  def test_add_ptq_model(
      self,
      shape_dynamic: bool,
  ):
    input_shape = (None, 3, 4, 3) if shape_dynamic else (2, 3, 4, 3)
    self._create_add_model(
        input_shape,
        self._input_saved_model_path,
    )

    # Generate model input data.
    rng = np.random.default_rng(seed=42)
    static_input_shape = [dim if dim is not None else 2 for dim in input_shape]

    def data_gen() -> repr_dataset.RepresentativeDataset:
      for _ in range(100):
        yield {
            'input_tensor': rng.uniform(
                low=0.0, high=1.0, size=static_input_shape
            ).astype(np.float32)
        }

    dataset_path = self.create_tempfile('tfrecord').full_path
    path_map = {'serving_default': dataset_path}
    repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save(
        {'serving_default': data_gen()}
    )

    config = qc.QuantizationConfig(
        static_range_ptq_preset=qc.StaticRangePtqPreset(
            representative_datasets=[
                qc.RepresentativeDatasetConfig(
                    tf_record=qc.TfRecordFile(path=dataset_path)
                )
            ],
        ),
        tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
    )
    quantization.quantize_saved_model(
        self._input_saved_model_path,
        self._output_saved_model_path,
        config,
    )

    self.assertEqual(
        self._get_num_xla_call_module_op(self._output_saved_model_path), 1
    )
    module_str = self._extract_first_xla_call_module_op(
        self._output_saved_model_path
    )

    # Check add is not quantized.
    self.assertTrue(re.search(r'stablehlo.add.*f32>', module_str))

  @parameterized.parameters(
      testing.parameter_combinations([{
          'shape_dynamic': (
              False,
              True,
          ),
      }])
  )
  @test_util.run_in_graph_and_eager_modes
  def test_add_weight_only_model(
      self,
      shape_dynamic: bool,
  ):
    input_shape = (None, 3, 4, 3) if shape_dynamic else (2, 3, 4, 3)
    self._create_add_model(
        input_shape,
        self._input_saved_model_path,
    )

    # Generate model input data.
    rng = np.random.default_rng(seed=42)
    static_input_shape = [dim if dim is not None else 2 for dim in input_shape]

    def data_gen() -> repr_dataset.RepresentativeDataset:
      for _ in range(100):
        yield {
            'input_tensor': rng.uniform(
                low=0.0, high=1.0, size=static_input_shape
            ).astype(np.float32)
        }

    dataset_path = self.create_tempfile('tfrecord').full_path
    path_map = {'serving_default': dataset_path}
    repr_dataset.TfRecordRepresentativeDatasetSaver(path_map).save(
        {'serving_default': data_gen()}
    )

    config = qc.QuantizationConfig(
        weight_only_ptq_preset=qc.WeightOnlyPtqPreset(),
        tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]),
    )
    quantization.quantize_saved_model(
        self._input_saved_model_path,
        self._output_saved_model_path,
        config,
    )

    self.assertEqual(
        self._get_num_xla_call_module_op(self._output_saved_model_path), 1
    )
    module_str = self._extract_first_xla_call_module_op(
        self._output_saved_model_path
    )

    # Check add is not quantized.
    self.assertTrue(re.search(r'stablehlo.add.*f32>', module_str), module_str)


if __name__ == '__main__':
  test.main()
@@ -73,6 +73,20 @@ def _extract_first_xla_call_module_op(
          return str(stablehlo_module)
    raise ValueError('No XlaCallModule found in saved model.')

  def _get_num_xla_call_module_op(self, output_saved_model_path: str) -> int:
    """Gets the number of XlaCallModule ops in the output saved model."""
    root = load.load(output_saved_model_path)
    tf_graph_def = root.signatures['serving_default'].graph.as_graph_def()
    count = 0
    for node_def in tf_graph_def.node:
      if node_def.op == 'XlaCallModule':
        count += 1
    for function in tf_graph_def.library.function:
      for node_def in function.node_def:
        if node_def.op == 'XlaCallModule':
          count += 1
    return count

  def _create_matmul_model(
      self,
      input_shape: Sequence[int],
@@ -339,6 +353,42 @@ def __call__(

    return GatherModel(use_variable)

  def _create_add_model(
      self,
      shape: Sequence[int],
      saved_model_path: str,
  ) -> module.Module:
    class AddModel(module.Module):
      """A simple model with a single add."""

      def __init__(self):
        pass

      @def_function.function
      def add(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]:
        """Performs an add operation.

        Args:
          input_tensor: Input tensor to perform add on.

        Returns:
          A map of: output key -> output result.
        """
        out = math_ops.add(input_tensor, input_tensor)
        return {'output': out}

    model = AddModel()
    saved_model_save.save(
        model,
        saved_model_path,
        signatures=model.add.get_concrete_function(
            tensor_spec.TensorSpec(
                shape=shape, dtype=dtypes.float32, name='input_tensor'
            )
        ),
    )
    return model

  # Prepares sample einsum input data shapes.
  # This function returns:
  # 1. Shape for input 1
@@ -69,11 +69,6 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1:
}
// CHECK-LABEL: func.func @main
// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x1024xf32>) -> tensor<1x3xf32>

// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant dense<{{.*}}> : tensor<1024x3xf32>
// CHECK: "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]])

// CHECK: func.func private @composite_dot_general_fn_1
// CHECK-SAME: attributes {_from_xla_call_module}
// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general
// CHECK-SAME: contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
// CHECK: stablehlo.dot_general %[[ARG_0]], %[[CONST_0]]
// CHECK-NOT: tf.XlaCallModule
@@ -715,7 +715,7 @@ module attributes {tf_saved_model.semantics} {

// -----

// Tests that XlaCallModule op is not quantized without the quantfork.stats ops.
// Tests that XlaCallModule op is not quantized and converted to func.call without the quantfork.stats ops.

module attributes {tf_saved_model.semantics} {
func.func private @not_quantized_without_stats_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
@@ -728,8 +728,8 @@ module attributes {tf_saved_model.semantics} {

// CHECK: func.func private @not_quantized_without_stats_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"}
// CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32>
// CHECK: %[[XLA_CALL_MODULE_0:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[CONST_0]]) <{{{.*}}}> {{{.*_entry_function = @composite_dot_general_fn.*}}} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: return %[[XLA_CALL_MODULE_0]]
// CHECK: %[[CALL:.+]] = call @composite_dot_general_fn(%[[ARG_0]], %[[CONST_0]]) : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: return %[[CALL]]

func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
@@ -0,0 +1,23 @@
// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-xla-call-module-to-call | FileCheck %s

// -----

// Tests composite tf.XlaCallModule is converted to func.call.

module {
  // CHECK-LABEL: func.func @main
  func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> {
    // CHECK: call @composite_dot_general_fn_1
    // CHECK-SAME: (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
    // CHECK-NOT: tf.XlaCallModule
    %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32>
    %2 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
    return %2 : tensor<1x3xf32>
  }
  // CHECK-LABEL: func.func private @composite_dot_general_fn_1
  // CHECK-SAME: -> tensor<1x3xf32>
  func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
    return %0 : tensor<1x3xf32>
  }
}
