From b3eabf31a27122d75e6bf9d5749522277f3922f2 Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Wed, 20 Mar 2024 12:13:11 -0700 Subject: [PATCH] Add double buffer removal pass FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/10638 from jaro-sevcik:host-allocation-custom-call 0f78007475921031549c8555ad56ed8903efe6cb PiperOrigin-RevId: 617588813 --- .../compiler/mlir/lite/tests/optimize.mlir | 11 + .../mlir/lite/transforms/legalize_patterns.td | 3 - .../mlir/lite/transforms/optimize_patterns.td | 14 + tensorflow/compiler/mlir/lite/utils/utils.td | 3 + third_party/stablehlo/temporary.patch | 75 ++- .../xla/third_party/stablehlo/temporary.patch | 75 ++- third_party/xla/xla/pjrt/pjrt_future.h | 66 +- third_party/xla/xla/pjrt/pjrt_future_test.cc | 19 + third_party/xla/xla/service/BUILD | 44 ++ third_party/xla/xla/service/gpu/BUILD | 4 +- .../xla/service/gpu/cudnn_fusion_compiler.cc | 62 +- .../xla/service/gpu/cudnn_support_utils.cc | 7 + .../xla/xla/service/gpu/cudnn_support_utils.h | 6 + third_party/xla/xla/service/gpu/fusions/BUILD | 4 + .../xla/xla/service/gpu/fusions/cudnn_test.cc | 46 +- .../xla/service/gpu/fusions/reduction_base.cc | 6 +- .../gpu/fusions/reduction_base_test.cc | 18 +- .../xla/service/gpu/ir_emitter_unnested.cc | 3 + .../service/gpu/model/coalescing_analysis.cc | 12 +- third_party/xla/xla/service/gpu/tests/BUILD | 12 + .../service/gpu/tests/nop_custom_call_test.cc | 51 ++ .../xla/service/gpu/triton_fusion_analysis.cc | 6 + .../service/while_double_buffer_removal.cc | 346 +++++++++++ .../xla/service/while_double_buffer_removal.h | 52 ++ .../while_double_buffer_removal_test.cc | 582 ++++++++++++++++++ .../xla/xla/service/while_loop_simplifier.h | 2 + .../xla/xla/stream_executor/cuda/cuda_dnn.cc | 13 +- 27 files changed, 1408 insertions(+), 134 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc create mode 100644 third_party/xla/xla/service/while_double_buffer_removal.cc create mode 100644 third_party/xla/xla/service/while_double_buffer_removal.h create mode 100644 third_party/xla/xla/service/while_double_buffer_removal_test.cc diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 75c1a791eeca73..b367183d8ecf61 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -2652,6 +2652,17 @@ func.func @FuseAddWithFullyConnectedWithQuantizedWeight(%arg: tensor<2x512xf32>) // CHECK: tfl.add } +// CHECK-LABEL: @FuseBatchMatMulAndTransposeWithQuantizedWeight +func.func @FuseBatchMatMulAndTransposeWithQuantizedWeight(%arg: tensor<1x2xf32>) -> tensor<1x3xf32> { + %cst_3 = arith.constant dense<[1, 0]> : tensor<2xi32> + %79 = "tfl.pseudo_qconst"() {qtype = tensor<3x2x!quant.uniform:f32:0, {2.378620e-03,2.848260e-03,2.545190e-03}>>, value = dense<10> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform:f32:0, {2.378620e-03,2.848260e-03,2.545190e-03}>> + %80 = "tfl.transpose"(%79, %cst_3) : (tensor<3x2x!quant.uniform:f32:0, {2.378620e-03,2.848260e-03,2.545190e-03}>>, tensor<2xi32>) -> tensor<2x3x!quant.uniform:f32:1, {2.378620e-03,2.848260e-03,2.545190e-03}>> + %81 = "tfl.batch_matmul"(%arg, %80) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32:1, {2.378620e-03,2.848260e-03,2.545190e-03}>>) -> tensor<1x3xf32> + func.return %81 : tensor<1x3xf32> + + // CHECK: tfl.fully_connected +} + // CHECK-LABEL: @FuseAddWithFullyConnectedNoBias // Note: Currently not fused. func.func @FuseAddWithFullyConnectedNoBias(%arg: tensor<2x512xf32>) -> tensor<2x1024xf32> { diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index cfe9bc754d8077..240773a82a9657 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -73,9 +73,6 @@ def CreateTFCastToInt32Op : NativeCodeCall< def CreateInt32ConstOrCast : NativeCodeCall< "CreateInt32ConstOrCast($0, $_loc, $_builder)">; -def CreateNoneValue : NativeCodeCall< - "$_builder.create($0.getLoc(), $_builder.getUnitAttr())">; - // Creates an int32 constant op from an integer attribute $0. def CreateInt32ConstOpFromIntAttr : NativeCodeCall<"$_builder.create($_loc, DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getI32Type()), {static_cast($0.cast().getInt())}))">; diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 0b068972c8fd30..cad25e828f1ac5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -1562,6 +1562,20 @@ def FuseTransposeAfterBatchMatmul : Pat< ), [(AreLastTwoDimsTransposed $perm_value)]>; +// Fuse redundant RHS TFL_TransposeOp into TFL_BatchMatMulOp if rhs is constant +// tensor of rank-2. +def FuseTransposeIntoBatchMatMulRHS: Pat< + (TFL_BatchMatMulOp $lhs, + (TFL_TransposeOp (TFL_QConstOp:$input $_, $_), (Arith_ConstantOp:$perm_value $p0)), + $adj_x, $adj_y, $asymmetric_quantize_inputs), + (TFL_FullyConnectedOp + $lhs, + $input, (CreateNoneValue $lhs), TFL_AF_None, TFL_FCWO_Default, + ConstBoolAttrTrue, $asymmetric_quantize_inputs), + [(HasRank<2> $input), + (AreLastTwoDimsTransposed $perm_value), + (IsBoolAttrEqual<"false"> $adj_y)]>; + // Replace conv-->transpose-->add with conv-->add-->transpose // The bias needs only reshape (i.e. ReshapeNCHWBiasToNHWC) and not transpose // because the bias's shape simply changes from NxCx1x1 to Nx1x1xC. diff --git a/tensorflow/compiler/mlir/lite/utils/utils.td b/tensorflow/compiler/mlir/lite/utils/utils.td index 42af8c67b2a7ce..77f971339acae1 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.td +++ b/tensorflow/compiler/mlir/lite/utils/utils.td @@ -19,6 +19,9 @@ include "mlir/IR/OpBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/IR/PatternBase.td" +def CreateNoneValue : NativeCodeCall< + "$_builder.create($0.getLoc(), $_builder.getUnitAttr())">; + // Returns shape of a ranked tensor. // if called without a ranked tensor it will fail. def GetShape: NativeCodeCall<"GetShape($0)">; diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 4b79cdb46dcdf9..e0447d44f59cb9 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -446,7 +446,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/CMakeLists.txt b/stablehlo diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp --- stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp +++ stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp -@@ -0,0 +1,506 @@ +@@ -0,0 +1,505 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); @@ -521,7 +521,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + // reduce_window_i1 + SmallVector inputTypes; + for (auto [index, input] : llvm::enumerate(inputs)) { -+ auto inputType = input.getType().dyn_cast(); ++ auto inputType = dyn_cast(input.getType()); + inputTypes.push_back(inputType); + if (!inputType) + return op_.emitError() @@ -531,7 +531,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + // reduce_window_i2 + SmallVector initValueTypes; + for (auto [index, initValue] : llvm::enumerate(initValues)) { -+ auto initValueType = initValue.getType().dyn_cast(); ++ auto initValueType = dyn_cast(initValue.getType()); + initValueTypes.push_back(initValueType); + if (!initValueType || !initValueType.hasRank() || + initValueType.getRank() != 0) @@ -543,7 +543,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + // reduce_window_i3...reduce_window_i7 + auto checkRank = [&](StringRef name, int64_t index, Value dynamicAttr, + int64_t expectedRank) -> LogicalResult { -+ auto type = dynamicAttr.getType().dyn_cast(); ++ auto type = dyn_cast(dynamicAttr.getType()); + if (!type || !type.hasRank() || type.getRank() != expectedRank || + !type.getElementType().isIntOrIndex()) { + if (index < 0) index += op_->getNumOperands(); @@ -562,7 +562,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + return failure(); + + // reduce_window_i7 -+ auto paddingType = getPadding().getType().dyn_cast(); ++ auto paddingType = dyn_cast(getPadding().getType()); + if (!paddingType || !paddingType.hasRank() || paddingType.getRank() != 2 || + paddingType.getDimSize(1) != 2 || + !paddingType.getElementType().isIntOrIndex()) @@ -598,7 +598,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + // verify them in that case, that seems like too much at this point. + auto checkShape = [&](StringRef name, int64_t index, Value dynamicAttr, + ArrayRef expectedShape) -> LogicalResult { -+ auto type = dynamicAttr.getType().cast(); ++ auto type = cast(dynamicAttr.getType()); + if (type.getShape() != expectedShape) { + if (index < 0) index += op_->getNumOperands(); + return op_.emitError() @@ -622,7 +622,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + // reduce_window_c13 + if (op_.getCalledComputations().size() != 1) + return op_.emitError() << "expects called_computations to have 1 element"; -+ auto bodyAttr = op_.getCalledComputations()[0].cast(); ++ auto bodyAttr = cast(op_.getCalledComputations()[0]); + auto bodyFunc = + op_->getParentOfType().lookupSymbol(bodyAttr); + if (!bodyFunc) @@ -644,7 +644,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + SmallVector resultTypes; + std::optional> resultShape; + for (auto result : results) { -+ auto resultType = result.getType().dyn_cast(); ++ auto resultType = dyn_cast(result.getType()); + resultTypes.push_back(resultType); + if (!resultType) return op_.emitError() << "expects results to be tensors"; + @@ -683,32 +683,32 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh +} + +TypedValue DynamicReduceWindowOpAdaptor::getWindowDimensions() { -+ return op_.getInputs()[op_.getInputs().size() - 5] -+ .cast>(); ++ return cast>( ++ op_.getInputs()[op_.getInputs().size() - 5]); +} + +TypedValue DynamicReduceWindowOpAdaptor::getWindowStrides() { -+ return op_.getInputs()[op_.getInputs().size() - 4] -+ .cast>(); ++ return cast>( ++ op_.getInputs()[op_.getInputs().size() - 4]); +} + +TypedValue DynamicReduceWindowOpAdaptor::getBaseDilations() { -+ return op_.getInputs()[op_.getInputs().size() - 3] -+ .cast>(); ++ return cast>( ++ op_.getInputs()[op_.getInputs().size() - 3]); +} + +TypedValue DynamicReduceWindowOpAdaptor::getWindowDilations() { -+ return op_.getInputs()[op_.getInputs().size() - 2] -+ .cast>(); ++ return cast>( ++ op_.getInputs()[op_.getInputs().size() - 2]); +} + +TypedValue DynamicReduceWindowOpAdaptor::getPadding() { -+ return op_.getInputs()[op_.getInputs().size() - 1] -+ .cast>(); ++ return cast>( ++ op_.getInputs()[op_.getInputs().size() - 1]); +} + +Region& DynamicReduceWindowOpAdaptor::getBody() { -+ auto bodyAttr = op_.getCalledComputations()[0].cast(); ++ auto bodyAttr = cast(op_.getCalledComputations()[0]); + auto bodyFunc = + op_->getParentOfType().lookupSymbol(bodyAttr); + return bodyFunc.getBody(); @@ -758,20 +758,20 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + auto output = op_.getResults()[1]; + + // dynamic_rng_bit_generator_i1 -+ if (!rngAlgorithmAttr.isa()) ++ if (!isa(rngAlgorithmAttr)) + return op_.emitError() + << "expects a #stablehlo rng_algorithm"; + + // dynamic_rng_bit_generator_i2 + // TODO(#643): Clarify supported types for RngBitGeneratorOp. -+ auto initialStateType = initialState.getType().dyn_cast(); ++ auto initialStateType = dyn_cast(initialState.getType()); + if (!initialStateType || !initialStateType.getElementType().isIntOrFloat()) + return op_.emitError() + << "expects initial_state (operand #0) " + << "to be a tensor of integer or floating-point type"; + + // dynamic_rng_bit_generator_i3 -+ auto outputShapeType = outputShape.getType().dyn_cast(); ++ auto outputShapeType = dyn_cast(outputShape.getType()); + if (!outputShapeType || !outputShapeType.hasRank() || + outputShapeType.getRank() != 1 || + !outputShapeType.getElementType().isIntOrIndex()) @@ -781,14 +781,14 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + + // dynamic_rng_bit_generator_o1 + // TODO(#643): Clarify supported types for RngBitGeneratorOp. -+ auto outputStateType = outputState.getType().dyn_cast(); ++ auto outputStateType = dyn_cast(outputState.getType()); + if (!outputStateType || !outputStateType.getElementType().isIntOrFloat()) + return op_.emitError() + << "expects output_state (result #0) " + << "to be a tensor of integer or floating-point type"; + + // dynamic_rng_bit_generator_o2 -+ auto outputType = output.getType().dyn_cast(); ++ auto outputType = dyn_cast(output.getType()); + if (!outputType || !outputType.getElementType().isIntOrFloat()) + return op_.emitError() + << "expects output (result #1) " @@ -812,25 +812,24 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh +} + +RngAlgorithm DynamicRngBitGeneratorOpAdaptor::getRngAlgorithm() { -+ return op_->getDiscardableAttr("rng_algorithm") -+ .cast() ++ return cast(op_->getDiscardableAttr("rng_algorithm")) + .getValue(); +} + +TypedValue DynamicRngBitGeneratorOpAdaptor::getInitialState() { -+ return op_.getInputs()[0].cast>(); ++ return cast>(op_.getInputs()[0]); +} + +TypedValue DynamicRngBitGeneratorOpAdaptor::getOutputShape() { -+ return op_.getInputs()[1].cast>(); ++ return cast>(op_.getInputs()[1]); +} + +TypedValue DynamicRngBitGeneratorOpAdaptor::getOutputState() { -+ return op_.getResults()[0].cast>(); ++ return cast>(op_.getResults()[0]); +} + +TypedValue DynamicRngBitGeneratorOpAdaptor::getOutput() { -+ return op_.getResults()[1].cast>(); ++ return cast>(op_.getResults()[1]); +} + +std::optional getDynamicRngBitGeneratorOp( @@ -864,7 +863,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + auto indices = op_.getResults()[1]; + + // dynamic_top_k_i1 -+ auto operandType = operand.getType().dyn_cast(); ++ auto operandType = dyn_cast(operand.getType()); + if (!operandType || !operandType.hasRank() || operandType.getRank() < 1 || + !operandType.getElementType().isIntOrFloat()) + return op_.emitError() @@ -873,7 +872,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + << "of rank at least 1"; + + // dynamic_top_k_i2 -+ auto kType = k.getType().dyn_cast(); ++ auto kType = dyn_cast(k.getType()); + if (!kType || !kType.hasRank() || kType.getRank() != 0 || + !kType.getElementType().isIntOrIndex()) + return op_.emitError() @@ -881,7 +880,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + << "to be a 0-dimensional tensor of integer or index type"; + + // dynamic_top_k_o1 -+ auto valuesType = values.getType().dyn_cast(); ++ auto valuesType = dyn_cast(values.getType()); + if (!valuesType || !valuesType.hasRank() || valuesType.getRank() < 1 || + !valuesType.getElementType().isIntOrFloat()) + return op_.emitError() @@ -890,7 +889,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + << "of rank at least 1"; + + // dynamic_top_k_o2 -+ auto indicesType = indices.getType().dyn_cast(); ++ auto indicesType = dyn_cast(indices.getType()); + if (!indicesType || !indicesType.hasRank() || indicesType.getRank() < 1 || + !indicesType.getElementType().isSignlessInteger(32)) + return op_.emitError() << "expects indices (result #1) " @@ -930,19 +929,19 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh +} + +TypedValue DynamicTopKOpAdaptor::getOperand() { -+ return op_.getInputs()[0].cast>(); ++ return cast>(op_.getInputs()[0]); +} + +TypedValue DynamicTopKOpAdaptor::getK() { -+ return op_.getInputs()[1].cast>(); ++ return cast>(op_.getInputs()[1]); +} + +TypedValue DynamicTopKOpAdaptor::getValues() { -+ return op_.getResults()[0].cast>(); ++ return cast>(op_.getResults()[0]); +} + +TypedValue DynamicTopKOpAdaptor::getIndices() { -+ return op_.getResults()[1].cast>(); ++ return cast>(op_.getResults()[1]); +} + +std::optional getDynamicTopKOp(CustomCallOp op) { diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 4b79cdb46dcdf9..e0447d44f59cb9 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -446,7 +446,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/CMakeLists.txt b/stablehlo diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp --- stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp +++ stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp -@@ -0,0 +1,506 @@ +@@ -0,0 +1,505 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); @@ -521,7 +521,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + // reduce_window_i1 + SmallVector inputTypes; + for (auto [index, input] : llvm::enumerate(inputs)) { -+ auto inputType = input.getType().dyn_cast(); ++ auto inputType = dyn_cast(input.getType()); + inputTypes.push_back(inputType); + if (!inputType) + return op_.emitError() @@ -531,7 +531,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + // reduce_window_i2 + SmallVector initValueTypes; + for (auto [index, initValue] : llvm::enumerate(initValues)) { -+ auto initValueType = initValue.getType().dyn_cast(); ++ auto initValueType = dyn_cast(initValue.getType()); + initValueTypes.push_back(initValueType); + if (!initValueType || !initValueType.hasRank() || + initValueType.getRank() != 0) @@ -543,7 +543,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + // reduce_window_i3...reduce_window_i7 + auto checkRank = [&](StringRef name, int64_t index, Value dynamicAttr, + int64_t expectedRank) -> LogicalResult { -+ auto type = dynamicAttr.getType().dyn_cast(); ++ auto type = dyn_cast(dynamicAttr.getType()); + if (!type || !type.hasRank() || type.getRank() != expectedRank || + !type.getElementType().isIntOrIndex()) { + if (index < 0) index += op_->getNumOperands(); @@ -562,7 +562,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + return failure(); + + // reduce_window_i7 -+ auto paddingType = getPadding().getType().dyn_cast(); ++ auto paddingType = dyn_cast(getPadding().getType()); + if (!paddingType || !paddingType.hasRank() || paddingType.getRank() != 2 || + paddingType.getDimSize(1) != 2 || + !paddingType.getElementType().isIntOrIndex()) @@ -598,7 +598,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + // verify them in that case, that seems like too much at this point. + auto checkShape = [&](StringRef name, int64_t index, Value dynamicAttr, + ArrayRef expectedShape) -> LogicalResult { -+ auto type = dynamicAttr.getType().cast(); ++ auto type = cast(dynamicAttr.getType()); + if (type.getShape() != expectedShape) { + if (index < 0) index += op_->getNumOperands(); + return op_.emitError() @@ -622,7 +622,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + // reduce_window_c13 + if (op_.getCalledComputations().size() != 1) + return op_.emitError() << "expects called_computations to have 1 element"; -+ auto bodyAttr = op_.getCalledComputations()[0].cast(); ++ auto bodyAttr = cast(op_.getCalledComputations()[0]); + auto bodyFunc = + op_->getParentOfType().lookupSymbol(bodyAttr); + if (!bodyFunc) @@ -644,7 +644,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + SmallVector resultTypes; + std::optional> resultShape; + for (auto result : results) { -+ auto resultType = result.getType().dyn_cast(); ++ auto resultType = dyn_cast(result.getType()); + resultTypes.push_back(resultType); + if (!resultType) return op_.emitError() << "expects results to be tensors"; + @@ -683,32 +683,32 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh +} + +TypedValue DynamicReduceWindowOpAdaptor::getWindowDimensions() { -+ return op_.getInputs()[op_.getInputs().size() - 5] -+ .cast>(); ++ return cast>( ++ op_.getInputs()[op_.getInputs().size() - 5]); +} + +TypedValue DynamicReduceWindowOpAdaptor::getWindowStrides() { -+ return op_.getInputs()[op_.getInputs().size() - 4] -+ .cast>(); ++ return cast>( ++ op_.getInputs()[op_.getInputs().size() - 4]); +} + +TypedValue DynamicReduceWindowOpAdaptor::getBaseDilations() { -+ return op_.getInputs()[op_.getInputs().size() - 3] -+ .cast>(); ++ return cast>( ++ op_.getInputs()[op_.getInputs().size() - 3]); +} + +TypedValue DynamicReduceWindowOpAdaptor::getWindowDilations() { -+ return op_.getInputs()[op_.getInputs().size() - 2] -+ .cast>(); ++ return cast>( ++ op_.getInputs()[op_.getInputs().size() - 2]); +} + +TypedValue DynamicReduceWindowOpAdaptor::getPadding() { -+ return op_.getInputs()[op_.getInputs().size() - 1] -+ .cast>(); ++ return cast>( ++ op_.getInputs()[op_.getInputs().size() - 1]); +} + +Region& DynamicReduceWindowOpAdaptor::getBody() { -+ auto bodyAttr = op_.getCalledComputations()[0].cast(); ++ auto bodyAttr = cast(op_.getCalledComputations()[0]); + auto bodyFunc = + op_->getParentOfType().lookupSymbol(bodyAttr); + return bodyFunc.getBody(); @@ -758,20 +758,20 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + auto output = op_.getResults()[1]; + + // dynamic_rng_bit_generator_i1 -+ if (!rngAlgorithmAttr.isa()) ++ if (!isa(rngAlgorithmAttr)) + return op_.emitError() + << "expects a #stablehlo rng_algorithm"; + + // dynamic_rng_bit_generator_i2 + // TODO(#643): Clarify supported types for RngBitGeneratorOp. -+ auto initialStateType = initialState.getType().dyn_cast(); ++ auto initialStateType = dyn_cast(initialState.getType()); + if (!initialStateType || !initialStateType.getElementType().isIntOrFloat()) + return op_.emitError() + << "expects initial_state (operand #0) " + << "to be a tensor of integer or floating-point type"; + + // dynamic_rng_bit_generator_i3 -+ auto outputShapeType = outputShape.getType().dyn_cast(); ++ auto outputShapeType = dyn_cast(outputShape.getType()); + if (!outputShapeType || !outputShapeType.hasRank() || + outputShapeType.getRank() != 1 || + !outputShapeType.getElementType().isIntOrIndex()) @@ -781,14 +781,14 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + + // dynamic_rng_bit_generator_o1 + // TODO(#643): Clarify supported types for RngBitGeneratorOp. -+ auto outputStateType = outputState.getType().dyn_cast(); ++ auto outputStateType = dyn_cast(outputState.getType()); + if (!outputStateType || !outputStateType.getElementType().isIntOrFloat()) + return op_.emitError() + << "expects output_state (result #0) " + << "to be a tensor of integer or floating-point type"; + + // dynamic_rng_bit_generator_o2 -+ auto outputType = output.getType().dyn_cast(); ++ auto outputType = dyn_cast(output.getType()); + if (!outputType || !outputType.getElementType().isIntOrFloat()) + return op_.emitError() + << "expects output (result #1) " @@ -812,25 +812,24 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh +} + +RngAlgorithm DynamicRngBitGeneratorOpAdaptor::getRngAlgorithm() { -+ return op_->getDiscardableAttr("rng_algorithm") -+ .cast() ++ return cast(op_->getDiscardableAttr("rng_algorithm")) + .getValue(); +} + +TypedValue DynamicRngBitGeneratorOpAdaptor::getInitialState() { -+ return op_.getInputs()[0].cast>(); ++ return cast>(op_.getInputs()[0]); +} + +TypedValue DynamicRngBitGeneratorOpAdaptor::getOutputShape() { -+ return op_.getInputs()[1].cast>(); ++ return cast>(op_.getInputs()[1]); +} + +TypedValue DynamicRngBitGeneratorOpAdaptor::getOutputState() { -+ return op_.getResults()[0].cast>(); ++ return cast>(op_.getResults()[0]); +} + +TypedValue DynamicRngBitGeneratorOpAdaptor::getOutput() { -+ return op_.getResults()[1].cast>(); ++ return cast>(op_.getResults()[1]); +} + +std::optional getDynamicRngBitGeneratorOp( @@ -864,7 +863,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + auto indices = op_.getResults()[1]; + + // dynamic_top_k_i1 -+ auto operandType = operand.getType().dyn_cast(); ++ auto operandType = dyn_cast(operand.getType()); + if (!operandType || !operandType.hasRank() || operandType.getRank() < 1 || + !operandType.getElementType().isIntOrFloat()) + return op_.emitError() @@ -873,7 +872,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + << "of rank at least 1"; + + // dynamic_top_k_i2 -+ auto kType = k.getType().dyn_cast(); ++ auto kType = dyn_cast(k.getType()); + if (!kType || !kType.hasRank() || kType.getRank() != 0 || + !kType.getElementType().isIntOrIndex()) + return op_.emitError() @@ -881,7 +880,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + << "to be a 0-dimensional tensor of integer or index type"; + + // dynamic_top_k_o1 -+ auto valuesType = values.getType().dyn_cast(); ++ auto valuesType = dyn_cast(values.getType()); + if (!valuesType || !valuesType.hasRank() || valuesType.getRank() < 1 || + !valuesType.getElementType().isIntOrFloat()) + return op_.emitError() @@ -890,7 +889,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh + << "of rank at least 1"; + + // dynamic_top_k_o2 -+ auto indicesType = indices.getType().dyn_cast(); ++ auto indicesType = dyn_cast(indices.getType()); + if (!indicesType || !indicesType.hasRank() || indicesType.getRank() < 1 || + !indicesType.getElementType().isSignlessInteger(32)) + return op_.emitError() << "expects indices (result #1) " @@ -930,19 +929,19 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh +} + +TypedValue DynamicTopKOpAdaptor::getOperand() { -+ return op_.getInputs()[0].cast>(); ++ return cast>(op_.getInputs()[0]); +} + +TypedValue DynamicTopKOpAdaptor::getK() { -+ return op_.getInputs()[1].cast>(); ++ return cast>(op_.getInputs()[1]); +} + +TypedValue DynamicTopKOpAdaptor::getValues() { -+ return op_.getResults()[0].cast>(); ++ return cast>(op_.getResults()[0]); +} + +TypedValue DynamicTopKOpAdaptor::getIndices() { -+ return op_.getResults()[1].cast>(); ++ return cast>(op_.getResults()[1]); +} + +std::optional getDynamicTopKOp(CustomCallOp op) { diff --git a/third_party/xla/xla/pjrt/pjrt_future.h b/third_party/xla/xla/pjrt/pjrt_future.h index 787710669a3491..66825f88c47c2a 100644 --- a/third_party/xla/xla/pjrt/pjrt_future.h +++ b/third_party/xla/xla/pjrt/pjrt_future.h @@ -281,26 +281,29 @@ class PjRtFutureBase : public PjRtFutureMoveControl< } // namespace internal // PjRtFuture is a simple future that is returned by PjRt APIs that -// enqueue asynchronous work, reporting a value of type T (frequently T=Status) -// when the work is complete. +// enqueue asynchronous work, reporting a value of type T when the work is +// complete. // // PjRtFuture can be used by the client to wait for work to complete, either via // a blocking call or a callback. // // The implementation wraps a tsl::AsyncValueRef, but we prefer to -// encapsulate the AVR rather than returning it directly for two reasons. +// encapsulate the AVR rather than returning it directly for three reasons. // -// First, we want to retain portability in case a future implementation moves +// First, in contrast to AsyncValueRef which has a smart-pointer semantics, +// future has more of a value semantics, i.e. future of a move-only type also +// is a move-only type. You can think of a move-only (unique) future as a box to +// pass a value of type T between asynchronous producer/consumer: you can open +// the box once to put the value into it and you can open the box only once to +// take the value out of it. For copyable types PjRtFuture is a copyable +// type, although all copies share the same underlying value. +// +// Second, we want to retain portability in case a future implementation moves // away from AsyncValueRef ---- we don't want clients to call arbitrary // AsyncValueRef APIs. // -// Second, we want to export different semantics, for example we support +// Third, we want to export different semantics, for example we support // integration between blocking and profiling (e.g., TraceMe). -// -// There are two ways to construct a PjRtFuture, one used by clients that -// natively use TSL concurrency library, which already have import APIs for -// constructing AsyncValueRefs; and another that avoids exposing TSL APIs and -// can be used by non-TSL clients. template class PjRtFuture : public internal::PjRtFutureBase { using Base = internal::PjRtFutureBase; @@ -391,22 +394,41 @@ class PjRtFuture : public internal::PjRtFutureBase { // The client should avoid any potentially re-entrant API calls within the // callback, for example by using the callback to enqueue work on a // client-owned threadpool. - void OnReady(absl::AnyInvocable callback) { + template && + !Base::is_unique()>* = nullptr> + void OnReady(F&& f) & { + CHECK(Base::IsValid()); + Base::promise().AndThen( + [promise = Base::promise(), f = std::forward(f)]() mutable { + DCHECK(promise.IsConcrete()); + f(*promise); + }); + } + + // Registers callback to be called once the promise is ready, with the final + // value. + // + // callback may be called on an internal system thread or the calling thread. + // The client should avoid any potentially re-entrant API calls within the + // callback, for example by using the callback to enqueue work on a + // client-owned threadpool. + template + : std::is_invocable_v>* = nullptr> + void OnReady(F&& f) && { CHECK(Base::IsValid()); Base::promise().AndThen( - [promise = Base::promise(), callback = std::move(callback)]() mutable { + [promise = Base::promise(), f = std::forward(f)]() mutable { DCHECK(promise.IsConcrete()); - if constexpr (std::is_copy_constructible_v) { - std::move(callback)(*promise); - return; + if constexpr (Base::is_unique()) { + f(std::move(*promise)); + } else { + // We can't move from the promise to the caller because for + // non-unique futures we can have multiple copies of the PjRtFuture + // sharing the same underlying promise object. + f(*promise); } - // For non-copyable types, we have no ways to check the number of - // waiters but we have to move the data into the consumer callback. - // Registering two callbacks will lead to double-move of the data. It - // is users' responsibility to make sure only one waiter is - // registered. - // TODO(yunlongl): Implement `PjRtUniqueFuture`. - std::move(callback)(std::move(*promise)); }); } }; diff --git a/third_party/xla/xla/pjrt/pjrt_future_test.cc b/third_party/xla/xla/pjrt/pjrt_future_test.cc index e4a67236248a15..b4b961cf3098f6 100644 --- a/third_party/xla/xla/pjrt/pjrt_future_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_future_test.cc @@ -86,6 +86,25 @@ TEST(PjRtFutureTest, AwaitMoveOnlyFuture) { EXPECT_EQ(*std::move(future).Await(), 42); } +TEST(PjRtFutureTest, OnReadyRvalueFuture) { + auto promise = PjRtFuture::CreatePromise(); + PjRtFuture future(promise); + + promise.Set(42); + + std::move(future).OnReady([](int32_t value) { EXPECT_EQ(value, 42); }); +} + +TEST(PjRtFutureTest, OnReadyMoveOnlyFuture) { + auto promise = PjRtFuture>::CreatePromise(); + PjRtFuture> future(promise); + + promise.Set(std::make_unique(42)); + + std::move(future).OnReady( + [](std::unique_ptr value) { EXPECT_EQ(*value, 42); }); +} + TEST(PjRtFutureTest, StatelessError) { auto promise = PjRtFuture<>::CreatePromise(); PjRtFuture<> future(promise); diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index d6f660ecf06113..e85761d45c8fa0 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -3181,6 +3181,50 @@ xla_cc_test( ], ) +cc_library( + name = "while_double_buffer_removal", + srcs = ["while_double_buffer_removal.cc"], + hdrs = ["while_double_buffer_removal.h"], + deps = [ + ":hlo_alias_analysis", + ":hlo_pass", + ":pattern_matcher", + ":tuple_simplifier", + ":while_loop_analysis", + ":while_loop_simplifier", + ":while_loop_unroller", + "//xla:literal_util", + "//xla:shape_util", + "//xla:statusor", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "while_double_buffer_removal_test", + srcs = ["while_double_buffer_removal_test.cc"], + deps = [ + ":while_double_buffer_removal", + "//xla:literal", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:test_utils", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "while_loop_unroller", srcs = ["while_loop_unroller.cc"], diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index d790a876786484..d951ae8c385e44 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1301,6 +1301,7 @@ cc_library( srcs = ["triton_fusion_analysis.cc"], hdrs = ["triton_fusion_analysis.h"], deps = [ + ":cudnn_support_utils", ":matmul_utils", ":triton_tiling_propagation", "//xla:autotuning_proto_cc", @@ -2947,9 +2948,11 @@ cc_library( hdrs = if_cuda_is_configured(["cudnn_fusion_compiler.h"]), deps = if_cuda_is_configured([ ":backend_configs_cc", + ":cudnn_support_utils", ":ir_emission_utils", ":kernel_reuse_cache", ":matmul_utils", + ":triton_fusion_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -2962,7 +2965,6 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", - "//xla/service/gpu:triton_fusion_analysis", "//xla/service:hlo_pass", "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor/cuda:cudnn_frontend_helpers", diff --git a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc index 2c3f4ec7221b30..9911a3fc3ed25b 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cudnn_support_utils.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/matmul_utils.h" @@ -371,6 +372,21 @@ absl::StatusOr> HloFusionToCuDnnGraph( if (hlo->opcode() == HloOpcode::kParameter) { CHECK(hlo_to_cudnn.contains(hlo)); continue; + } else if (hlo->opcode() == HloOpcode::kCustomCall) { + if (hlo->user_count() != 1 || + !IsWorkspaceAllocationRoot(*hlo->users()[0])) { + VLOG(3) << "Custom calls are only expected to be used for workspace " + "allocation."; + return std::nullopt; + } + continue; + } else if (hlo->opcode() == HloOpcode::kTuple) { + if (!IsWorkspaceAllocationRoot(*hlo)) { + VLOG(3) << "Tuples are only expected at outputs for workspace " + "allocation."; + return std::nullopt; + } + continue; } else if (hlo->opcode() == HloOpcode::kReshape || hlo->opcode() == HloOpcode::kBitcast || hlo->opcode() == HloOpcode::kTranspose || @@ -466,6 +482,7 @@ absl::StatusOr PrepareGraph( if (!graph.has_value()) { return absl::InternalError("Construction of cuDNN graph failed."); } + VLOG(6) << graph->Graph().print(); TF_ASSIGN_OR_RETURN(bool supported, graph->Prepare(dnn_support)); if (!supported) { return absl::InternalError("cuDNN graph is not supported."); @@ -473,6 +490,28 @@ absl::StatusOr PrepareGraph( return *graph; } +StatusOr AddWorkspace(HloInstruction& fusion, + const int64_t workspace_size) { + if (workspace_size == 0 || fusion.shape().IsTuple()) { + return &fusion; + } + HloComputation* computation = fusion.fused_instructions_computation(); + HloInstruction* custom_call = + computation->AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeShape(S8, {workspace_size}), {}, + kWorkspaceAllocationCustomCallTarget)); + HloInstruction* output_tuple = + computation->AddInstruction(HloInstruction::CreateTuple( + {computation->root_instruction(), custom_call})); + computation->set_root_instruction(output_tuple, true); + HloInstruction* new_fusion = fusion.parent()->AddInstruction( + fusion.CloneWithNewShape(output_tuple->shape())); + TF_RETURN_IF_ERROR(fusion.ReplaceAllUsesWith(fusion.parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(new_fusion, 0)))); + TF_RETURN_IF_ERROR(fusion.parent()->RemoveInstruction(&fusion)); + return new_fusion; +} + class CuDnnFusionVisitor : public DfsHloRewriteVisitor { public: explicit CuDnnFusionVisitor( @@ -495,10 +534,11 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { VLOG(4) << "Processing " << hlo->ToString(); VLOG(4) << "Plan ID: " << plan_id; - const std::string cache_key = + const std::string fingerprint_without_workspace = GetComputationFingerprint(hlo->fused_instructions_computation(), {}); - std::string& cache_entry = compilation_results_[cache_key]; - if (cache_entry.empty()) { + auto workspace_size_it = + workspace_sizes_.find(fingerprint_without_workspace); + if (workspace_size_it == workspace_sizes_.cend()) { TF_ASSIGN_OR_RETURN( se::gpu::CudnnGraph graph, PrepareGraph(dnn_support_, *DynCast(hlo))); @@ -524,19 +564,22 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { return absl::InternalError("No cuDNN plans can be built."); } } - - if (graph.Graph().get_workspace_size() != 0) { - return absl::UnimplementedError( - "Support of workspace allocation is not added yet."); - } + const int64_t workspace_size = graph.Graph().get_workspace_size(); + workspace_sizes_.insert(workspace_size_it, + {fingerprint_without_workspace, workspace_size}); + TF_ASSIGN_OR_RETURN(hlo, AddWorkspace(*hlo, workspace_size)); std::vector serialized_graph; RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph)); - cache_entry = + // Compute a new fingerprint with a potential workspace for the + // compilation results to match a fingerprint computed by the emitter. + compilation_results_[GetComputationFingerprint( + hlo->fused_instructions_computation(), {})] = std::string(reinterpret_cast(serialized_graph.data()), serialized_graph.size()); } else { VLOG(4) << "Cache hit."; + TF_ASSIGN_OR_RETURN(hlo, AddWorkspace(*hlo, workspace_size_it->second)); } auto cudnn_config = gpu_config.mutable_fusion_backend_config() ->mutable_cudnn_fusion_config(); @@ -551,6 +594,7 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { se::dnn::DnnSupport& dnn_support_; // . CuDnnFusionCompiler::BinaryMap& compilation_results_; + absl::flat_hash_map workspace_sizes_; }; } // namespace diff --git a/third_party/xla/xla/service/gpu/cudnn_support_utils.cc b/third_party/xla/xla/service/gpu/cudnn_support_utils.cc index 864943884a56f9..30ac372a86ff9e 100644 --- a/third_party/xla/xla/service/gpu/cudnn_support_utils.cc +++ b/third_party/xla/xla/service/gpu/cudnn_support_utils.cc @@ -212,5 +212,12 @@ CudnnInferTransposeForBiasReordering(const Shape& shape) { return CudnnReorderTransposeConfig{split_shape, shape, permutation}; } +bool IsWorkspaceAllocationRoot(const HloInstruction& root) { + return root.IsRoot() && root.opcode() == HloOpcode::kTuple && + root.operand_count() == 2 && + root.operand(1)->IsCustomCall(kWorkspaceAllocationCustomCallTarget) && + root.operand(1)->operand_count() == 0; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/cudnn_support_utils.h b/third_party/xla/xla/service/gpu/cudnn_support_utils.h index 780e0593f9b366..4a1f362b16e168 100644 --- a/third_party/xla/xla/service/gpu/cudnn_support_utils.h +++ b/third_party/xla/xla/service/gpu/cudnn_support_utils.h @@ -70,6 +70,12 @@ CudnnInferTransposeForFilterReordering( absl::StatusOr CudnnInferTransposeForBiasReordering(const Shape& shape); +inline constexpr absl::string_view kWorkspaceAllocationCustomCallTarget = + "__nop"; + +// Detects `ROOT tuple(..., custom-call())` used to allocate workspace buffers. +bool IsWorkspaceAllocationRoot(const HloInstruction& root); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 66adb638abd213..ba0fbf21cbd411 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -677,7 +677,11 @@ xla_test( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:executable", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:cudnn_fusion_compiler", "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/runtime:thunk", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:stream_executor_headers", "//xla/tests:filecheck", diff --git a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc index 05deee92233dff..cf746848b8d8d1 100644 --- a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc @@ -26,8 +26,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" #include "xla/service/executable.h" +#include "xla/service/gpu/cudnn_fusion_compiler.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor_pimpl.h" #include "xla/tests/filecheck.h" @@ -75,6 +79,44 @@ class CuDnnFusionTest : public GpuCodegenTest { using CuDnnFusionExecutionTest = CuDnnFusionTest; +namespace m = ::xla::match; + +TEST_F(CuDnnFusionExecutionTest, WorkspaceAllocationWorks) { + if (!IsAtLeastCuDnn91()) { + GTEST_SKIP() << "This test case requests a workspace only with cuDNN 9.1+."; + } + const std::string kHloText = R"( +fusion1 { + p0 = f32[32,96] parameter(0) + p1 = f32[96,64] parameter(1) + ROOT r = f32[32,64] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[32,96] parameter(0) + p1 = f32[96,64] parameter(1) + ROOT _ = f32[32,64] fusion(p0, p1), kind=kCustom, calls=fusion1, + backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + Thunk::BinaryMap dnn_compiled_graphs; + CuDnnFusionCompiler cudnn_compiler(*backend().default_stream_executor(), + dnn_compiled_graphs); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cudnn_compiler.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::GetTupleElement(m::Fusion()))); + EXPECT_THAT(module->entry_computation() + ->root_instruction() + ->operand(0) + ->fused_instructions_computation() + ->root_instruction(), + GmockMatch(m::Tuple(m::Dot(), m::CustomCall()))); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + TEST_F(CuDnnFusionExecutionTest, NoTritonConfigIsAssignedAtZeroAutotuningLevel) { EXPECT_EQ(GetDebugOptionsForTest().xla_gpu_autotune_level(), 0); @@ -347,8 +389,8 @@ ENTRY e { ; CHECK: ENTRY ; CHECK-NEXT: parameter ; CHECK-NEXT: parameter -; CHECK-NEXT: ROOT -; CHECK-SAME: command_buffer +; CHECK: command_buffer +; CHECK-NOT: fusion )"); TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(filecheck_result.value()); diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc index dc146b235313c6..cf4cc1b3d44ee3 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc @@ -80,9 +80,9 @@ int GetVectorSize(const HloFusionAnalysis& analysis, return 1; } - // Enabling vectorization if number of threads is <= warpsize leads to half or - // more of the threads not doing any work. - if (num_threads <= WarpSize()) { + // Enabling vectorization if (number_threads * vector_size) is <= + // minor_reduced_dimension otherwise exist threads not doing any work. + if (num_threads * 2 > reduction_dimensions.dimensions[kRowMinorReduced]) { return 1; } diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base_test.cc index f0e2d914a3e8cd..98bbdd115b1638 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base_test.cc @@ -83,10 +83,10 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { EXPECT_THAT( fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( (d3 * 8 + d0 floordiv 32) floordiv 64, (d3 * 8 + d0 floordiv 32) mod 64, - d0 mod 32 + s2 * 32 + (d0 mod 32 + s2 * 32) * 2 + s3 ) domain: d0 in [0, 255] @@ -97,8 +97,9 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { d5 in [0, 0] s0 in [0, 0] s1 in [0, 0] - s2 in [0, 15] - d0 mod 32 + s2 * 32 in [0, 511] + s2 in [0, 7] + s3 in [0, 1] + d0 mod 32 + s2 * 32 in [0, 255] d3 * 8 + d0 floordiv 32 in [0, 6399] )")); EXPECT_THAT( @@ -319,10 +320,10 @@ TEST_F(ReductionTest, ThreadIndexingSideOutput) { mlir::MLIRContext mlir_context; constexpr char kExpectedIndexing[] = R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( (d3 * 8 + d0 floordiv 32) floordiv 64, (d3 * 8 + d0 floordiv 32) mod 64, - d0 mod 32 + s2 * 32 + (d0 mod 32 + s2 * 32) * 2 + s3 ) domain: d0 in [0, 255] @@ -333,8 +334,9 @@ TEST_F(ReductionTest, ThreadIndexingSideOutput) { d5 in [0, 0] s0 in [0, 0] s1 in [0, 0] - s2 in [0, 15] - d0 mod 32 + s2 * 32 in [0, 511] + s2 in [0, 7] + s3 in [0, 1] + d0 mod 32 + s2 * 32 in [0, 255] d3 * 8 + d0 floordiv 32 in [0, 6399] )"; EXPECT_THAT( diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 4588038782f623..21ab2e706b33ff 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -2921,6 +2921,9 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( if (instr->custom_call_target() == "__gpu$xla.gpu.triton") { return EmitTritonCustomCall(custom_call); } + if (instr->custom_call_target() == kNopCustomCallTarget) { + return absl::OkStatus(); + } return EmitCustomCallThunk(custom_call); } case HloOpcode::kFusion: { diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc index 8f84b2c3f33f4a..282c5146af68e9 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc @@ -262,19 +262,19 @@ void AssignValuesToOuterLoopIVs(IndexingMap* indexing_map) { } MLIRContext* mlir_context = indexing_map->GetMLIRContext(); llvm::SmallVector symbol_replacements; - for (const RangeVar& range_var : indexing_map->GetRangeVars()) { - symbol_replacements.push_back( - getAffineConstantExpr(range_var.range.lower, mlir_context)); + for (int64_t symbol_id = 0; symbol_id < indexing_map->GetRangeVarsCount() - 1; + ++symbol_id) { + symbol_replacements.push_back(getAffineConstantExpr( + indexing_map->GetRangeVar(symbol_id).range.lower, mlir_context)); } - symbol_replacements.push_back(mlir::getAffineSymbolExpr( - indexing_map->GetRangeVarsCount() - 1, mlir_context)); + symbol_replacements.push_back(mlir::getAffineSymbolExpr(0, mlir_context)); AffineMap thread_x_to_input_no_dim_symbols = indexing_map->GetAffineMap().replaceDimsAndSymbols( {}, symbol_replacements, indexing_map->GetDimVarsCount(), 1); *indexing_map = IndexingMap{thread_x_to_input_no_dim_symbols, indexing_map->GetDimVars(), - indexing_map->GetRangeVars(), + {indexing_map->GetRangeVars().back()}, {}}; indexing_map->Simplify(GetIndexingMapForInstruction); indexing_map->RemoveUnusedSymbols(); diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 84907258603ca7..81f0e2e6d82a29 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -1080,3 +1080,15 @@ xla_cc_test( "@local_tsl//tsl/platform:test_main", ], ) + +xla_cc_test( + name = "nop_custom_call_test", + srcs = ["nop_custom_call_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + "//xla:xla_proto_cc", + "//xla/service:gpu_plugin", + "//xla/tests:hlo_test_base", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc b/third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc new file mode 100644 index 00000000000000..0da317ef3d1398 --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/nop_custom_call_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { +namespace { + +class NopCustomCallTest : public HloTestBase {}; + +TEST_F(NopCustomCallTest, RunAllocateBufferAndUpdate) { + // The test uses a custom call with the AllocateBuffer target (also known as + // kNopCustomCallTarget) to allocate an output buffer. Then it verifies + // we can successfully modify the buffer. + const char* hlo_text = R"( + HloModule AllocateBuffer, is_scheduled=true + + overwrite_one { + p0 = s32[1] parameter(0) + c0 = s32[] constant(0) + c1 = s32[1] constant({1}) + ROOT dus0 = s32[1] dynamic-update-slice(p0, c1, c0) + } + + ENTRY main { + buffer = s32[1] custom-call(), custom_call_target="AllocateBuffer" + ROOT fusion = s32[1] fusion(buffer), kind=kLoop, calls=overwrite_one + })"; + auto module = ParseAndReturnVerifiedModule(hlo_text).value(); + + Literal result = ExecuteNoHloPasses(std::move(module), {}); + Literal expected = LiteralUtil::CreateR1({1}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc index e18c037139a945..d63a5c186e05c1 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/cudnn_support_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/triton_tiling_propagation.h" #include "xla/service/instruction_fusion.h" @@ -298,6 +299,11 @@ absl::Status TritonFusionAnalysis::ExecuteForDotFusion( while (!output->IsRoot()) { TF_RET_CHECK(output->user_count() == 1); const HloInstruction* input = output; + // Tuple with a custom call can be added at root to allocate a workspace + // buffer. These do not need to participate in propagation of dimensions. + if (IsWorkspaceAllocationRoot(*output->users()[0])) { + break; + } output = output->users()[0]; DimOrdersAndReqsOrError result = GetPropagatedDimOrdersAndRequirements( *output, context.dim_orders().at(input), diff --git a/third_party/xla/xla/service/while_double_buffer_removal.cc b/third_party/xla/xla/service/while_double_buffer_removal.cc new file mode 100644 index 00000000000000..d6db5481a07c9c --- /dev/null +++ b/third_party/xla/xla/service/while_double_buffer_removal.cc @@ -0,0 +1,346 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/while_double_buffer_removal.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/tuple_simplifier.h" +#include "xla/service/while_loop_analysis.h" +#include "xla/service/while_loop_simplifier.h" +#include "xla/service/while_loop_unroller.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +struct LoopInfo { + // Tuple index into the loop's parameter tuple of the induction variable. + int64_t indvar_index; + + // Loop trip count. + int64_t trip_count; +}; + +// To guarantee that the entire shape is written to, all indices must be +// zero except for one, which must be the loop induction variable. +bool MatchDynamicSliceInDim(HloInstruction* ds, const LoopInfo& loop_info) { + // Check that the DUS is a DynamicUpdateSlice. + HloInstruction* to_be_sliced; + if (!Match(ds, + match::DynamicSlice().WithOperand(0, match::Op(&to_be_sliced)))) { + std::cout << ds->name() << " here4" << std::endl; + return false; + } + + if (!Match(to_be_sliced, match::GetTupleElement())) { + std::cout << ds->name() << " here5" << std::endl; + return false; + } + + int64_t ds_dim = -1; + for (int64_t operand_index = 1; operand_index < ds->operand_count(); + ++operand_index) { + HloInstruction* operand = ds->mutable_operand(operand_index); + // All constants must be zero in order to write the entire shape. + if (Match(operand, match::ConstantScalar())) { + std::optional offset = + LiteralUtil::LiteralAsScalarInt64(operand->literal()); + if (offset.value() != 0) { + ds_dim = -1; + break; + } + } + + HloInstruction* slice_offset = operand; + // Check that the update offset is the loop induction variable. + if (Match(slice_offset, match::GetTupleElement(match::Parameter(), + loop_info.indvar_index))) { + ds_dim = operand_index - 1; + } + } + + if (ds_dim == -1) { + std::cout << ds->name() << " here6" << std::endl; + return false; + } + + // The shape's broadcast_dim must be exactly equal to the loop trip count. + if (to_be_sliced->shape().dimensions(ds_dim) != loop_info.trip_count) { + std::cout << ds->name() << " here7" << std::endl; + return false; + } + + return true; +} + +// To guarantee that the entire shape is written to, all indices must be +// zero except for one, which must be the loop induction variable. +bool MatchDynamicUpdateSliceInDim(HloInstruction* dus, HloInstruction* user, + const LoopInfo& loop_info) { + // Check that the DUS is a DynamicUpdateSlice. + HloInstruction* to_be_updated; + if (!Match(dus, match::DynamicUpdateSlice().WithOperand( + 0, match::Op(&to_be_updated)))) { + return false; + } + if (to_be_updated != user) { + return false; + } + + int64_t dus_dim = -1; + for (int64_t operand_index = 2; operand_index < dus->operand_count(); + ++operand_index) { + HloInstruction* operand = dus->mutable_operand(operand_index); + // All constants must be zero in order to write the entire shape. + if (Match(operand, match::ConstantScalar())) { + std::optional offset = + LiteralUtil::LiteralAsScalarInt64(operand->literal()); + if (offset.value() != 0) { + dus_dim = -1; + break; + } + } + // Check that the update offset is the loop induction variable. + if (Match(operand, match::GetTupleElement(match::Parameter(), + loop_info.indvar_index))) { + dus_dim = operand_index - 2; + } + } + + if (dus_dim == -1) { + return false; + } + + // The shape's broadcast_dim must be exactly equal to the loop trip count. + if (user->shape().dimensions(dus_dim) != loop_info.trip_count) { + return false; + } + + return true; +} + +bool LoopIndexIsReadOnly(const HloAliasAnalysis& alias_analysis, + HloInstruction* while_instr, int64_t idx) { + const HloDataflowAnalysis& dataflow_analysis = + alias_analysis.dataflow_analysis(); + return !( + dataflow_analysis.GetValueSet(while_instr->while_init(), {idx}) + .values() + .size() > 1 || + dataflow_analysis.GetValueSet(while_instr, {idx}).values().size() > 1 || + dataflow_analysis.GetUniqueValueAt(while_instr, {idx}) != + dataflow_analysis.GetUniqueValueAt(while_instr->while_init(), {idx})); +} + +std::vector> LoopHasDoubleBuffer( + const HloAliasAnalysis& alias_analysis, HloInstruction* while_instr, + const LoopInfo& loop_info) { + HloComputation* computation = while_instr->while_body(); + HloInstruction* body_param = computation->parameter_instruction(0); + + // Finding the buffer indices + std::vector possible_buffers; + for (int64_t param_idx = 0; + param_idx < while_instr->while_init()->operand_count(); ++param_idx) { + for (HloInstruction* gte : body_param->users()) { + if (!Match(gte, match::GetTupleElement().WithTupleIndex(param_idx))) { + continue; + } + if (gte->operand(0) != body_param) { + continue; + } + // The buffer should only be updated. + if (gte->user_count() > 1) { + continue; + } + for (HloInstruction* gte_user : gte->users()) { + if (MatchDynamicUpdateSliceInDim(gte_user, gte, loop_info)) { + // The buffer should be written at the same index + if (computation->root_instruction()->mutable_operand(param_idx) == + gte_user) { + possible_buffers.push_back(gte); + std::cout << "buffer index: " << param_idx + << ", shape = " << gte->shape().ToString() << gte->name() + << ", update_value = " + << gte_user->mutable_operand(1)->name() << std::endl; + } + } + } + } + } + + // Finding the input indices + std::vector possible_inputs; + for (int64_t param_idx = 0; + param_idx < while_instr->while_init()->operand_count(); ++param_idx) { + for (HloInstruction* gte : body_param->users()) { + if (!Match(gte, match::GetTupleElement().WithTupleIndex(param_idx))) { + continue; + } + if (gte->operand(0) != body_param) { + continue; + } + + // The input should only be sliced and passed to the next iteration. + if (gte->user_count() > 2) { + continue; + } + + for (HloInstruction* gte_user : gte->users()) { + std::cout << "checking: " << gte->name() << std::endl; + if (MatchDynamicSliceInDim(gte_user, loop_info)) { + // The input should be read-only + if (LoopIndexIsReadOnly(alias_analysis, while_instr, + gte->tuple_index())) { + possible_inputs.push_back(gte); + std::cout << "input" << " index: " << param_idx + << ", shape = " << gte->shape().ToString() << gte->name() + << std::endl; + } + } + } + } + } + + std::vector> out; + std::vector unique_inputs; + for (HloInstruction* buffer : possible_buffers) { + for (HloInstruction* input : possible_inputs) { + if (ShapeUtil::Equal(input->shape(), buffer->shape())) { + // Make sure all the inputs are unique, if we encounter a used input, we + // move over to the next candidate. + if (absl::c_find(unique_inputs, input) != unique_inputs.end()) { + continue; + } + unique_inputs.push_back(input); + out.emplace_back(buffer, input); + std::cout << buffer->GetModule()->unique_id() + << " matching: " << buffer->name() << " with " + << input->name() << std::endl; + break; + } + } + } + return out; +} + +absl::StatusOr RemoveDoubleBuffers(HloModule* module, int64_t id) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + // if (id != module->unique_id()) { + // return false; + // } + std::cout << "Removing double buffer for " << id << std::endl; + std::vector while_instrs; + for (auto* comp : module->computations()) { + absl::c_copy_if( + comp->instructions(), std::back_inserter(while_instrs), + [&](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile && + // We only remove double buffers in nested loops for now. + instr->parent()->IsWhileBodyComputation(); + }); + } + bool changed = false; + for (HloInstruction* while_instr : while_instrs) { + bool removed = false; + std::optional indvar_index = + GetLoopInductionVarTupleIdx(while_instr); + std::optional trip_count = + ComputeWhileLoopTripCount(while_instr, /*max_brute_force_iters=*/0); + if (indvar_index.has_value() && trip_count.has_value()) { + std::cout << "loop: " << while_instr->name() << " -> " + << trip_count.value() << std::endl; + LoopInfo loop_info{*indvar_index, *trip_count}; + auto out = LoopHasDoubleBuffer(*alias_analysis, while_instr, loop_info); + for (const auto& [buffer, input] : out) { + // We only consider buffers that are allocated inside the loop. + // Therefore, we skip buffers that are passed as the loop input. + if (Match(while_instr->while_init()->mutable_operand( + buffer->tuple_index()), + match::GetTupleElement(match::Parameter()))) { + continue; + } + std::cout << while_instr->name() << " -> " + << "name() << ", " + << "input: " << input->name() << ">" << std::endl; + TF_RETURN_IF_ERROR(input->ReplaceAllUsesWith(buffer)); + TF_RETURN_IF_ERROR(while_instr->while_init()->ReplaceOperandWith( + buffer->tuple_index(), + while_instr->while_init()->mutable_operand(input->tuple_index()))); + if (input->user_count() == 0) { + TF_RETURN_IF_ERROR( + while_instr->while_body()->RemoveInstruction(input)); + removed = true; + } + } + if (removed) { + TF_RETURN_IF_ERROR(TryRemoveDeadWhileParams(while_instr).status()); + changed = true; + } + std::cout << "======================================" << std::endl; + } + } + if (changed) { + TF_RETURN_IF_ERROR(TupleSimplifier{}.Run(module).status()); + TF_RETURN_IF_ERROR(module->RemoveUnusedComputations()); + } + return changed; +} + +} // namespace + +absl::StatusOr WhileDoubleBufferRemoval::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + VLOG(2) << "HLO module before WhileDoubleBufferRemoval:"; + XLA_VLOG_LINES(2, module->ToString()); + + // TODO: we might want to simplify compare instructions before this. It helps + // us identify more inputs and buffer + TF_ASSIGN_OR_RETURN(bool changed, RemoveDoubleBuffers(module, module_id_)); + + if (changed) { + VLOG(2) << "HLO module after WhileDoubleBufferRemoval:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileDoubleBufferRemoval"; + } + + return changed; +} + +} // namespace xla diff --git a/third_party/xla/xla/service/while_double_buffer_removal.h b/third_party/xla/xla/service/while_double_buffer_removal.h new file mode 100644 index 00000000000000..78e34f5a477b73 --- /dev/null +++ b/third_party/xla/xla/service/while_double_buffer_removal.h @@ -0,0 +1,52 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_WHILE_DOUBLE_BUFFER_REMOVAL_H_ +#define XLA_SERVICE_WHILE_DOUBLE_BUFFER_REMOVAL_H_ + +#include +#include + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/statusor.h" + +namespace xla { + +// This pass recognizes the following patterns in nested loops: +// TODO: write the pattern +// The buffer must be allocated inside the outer loop in order to be replaced. +class WhileDoubleBufferRemoval : public HloModulePass { + public: + ~WhileDoubleBufferRemoval() override = default; + + // Default unroll_factor of -1 indicates full unrolling + explicit WhileDoubleBufferRemoval(int64_t module_id) + : module_id_(module_id) {} + + absl::string_view name() const override { + return "while-double-buffer-removal"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + const int64_t module_id_; +}; + +} // namespace xla + +#endif // XLA_SERVICE_WHILE_DOUBLE_BUFFER_REMOVAL_H_ diff --git a/third_party/xla/xla/service/while_double_buffer_removal_test.cc b/third_party/xla/xla/service/while_double_buffer_removal_test.cc new file mode 100644 index 00000000000000..d8a3c0a4fce4fa --- /dev/null +++ b/third_party/xla/xla/service/while_double_buffer_removal_test.cc @@ -0,0 +1,582 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/while_double_buffer_removal.h" + +#include + +#include +#include + +#include +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/test_utils.h" +#include "xla/tests/verified_hlo_module.h" + +namespace xla { +namespace { + +using WhileDoubleBufferRemovalTest = HloTestBase; + +TEST_F(WhileDoubleBufferRemovalTest, RemoveDoubleBuffer) { + [[maybe_unused]] constexpr char kModule[] = R"( + HloModule jit_scan + + wide.region_0.7 { + wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[8]) parameter(0) + get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 + get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 + get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 + get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 + + dynamic-slice.0 = s32[1] dynamic-slice(get-tuple-element.54, get-tuple-element.46), dynamic_slice_sizes={1} + reshape.2 = s32[] reshape(dynamic-slice.0) + add.1 = s32[] add(get-tuple-element.47, reshape.2) + + reshape.3 = s32[1] reshape(add.1) + dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.3, get-tuple-element.46) + + const = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.46, const) + ROOT tuple.10 = (s32[], s32[], s32[8], s32[8]) tuple(add.0, add.1, dynamic-update-slice.0, get-tuple-element.54) + } // wide.region_0.7 + + wide.region_1.29 { + constant.5 = s32[] constant(8) + wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8]) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 + ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT + } + + outer_body { + wide.arg_tuple.8 = (s32[], s32[], s32[8]) parameter(0) + get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 + get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 + get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 + + constant.3 = s32[] constant(0) + broadcast = s32[8] broadcast(constant.3), dimensions={} + + tuple.8 = (s32[], s32[], s32[8], s32[8]) tuple(constant.3, get-tuple-element.47, broadcast, get-tuple-element.48) + while = (s32[], s32[], s32[8], s32[8]) while(tuple.8), condition=wide.region_1.29, body=wide.region_0.7 + get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 + + const = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.46, const) + ROOT out = (s32[], s32[], s32[8]) tuple(add.0, get-tuple-element.47, get-tuple-element.40) + } + + outer_cond { + constant.5 = s32[] constant(8) + wide.arg_tuple.30 = (s32[], s32[], s32[8]) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 + ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT + } + + main.43 { + constant.3 = s32[] constant(0) + init = s32[] constant(0) + array = s32[8] constant({1,2,3,4,5,6,7,8}) + tuple.8 = (s32[], s32[], s32[8]) tuple(constant.3, init, array) + while = (s32[], s32[], s32[8]) while(tuple.8), condition=outer_cond, body=outer_body + ROOT get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 + } // main.43 + + )"; + + auto module = ParseAndReturnVerifiedModule(kModule).value(); + + VLOG(3) << "before:\n" << module->ToString(); + + Literal reference = ExecuteAndTransfer(module->Clone(), {}); + std::cout << "before:\n" << reference.ToString() << std::endl; + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileDoubleBufferRemoval().Run(module.get())); + EXPECT_TRUE(simplified_loop); + + VLOG(3) << "after:\n" << module->ToString(); + + // Index 2 and 3 of the while are replaced with the input arrays. + for (HloInstruction* instr : + module->entry_computation()->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kWhile) { + EXPECT_EQ(instr->while_init()->operand(2)->opcode(), + HloOpcode::kConstant); + } + } + Literal actual = ExecuteAndTransfer(module->Clone(), {}); + + ASSERT_TRUE(LiteralTestUtil::NearOrEqual(/*expected=*/reference, + /*actual=*/actual, std::nullopt)); +} + +TEST_F(WhileDoubleBufferRemovalTest, RemoveDoubleBuffer2) { + [[maybe_unused]] constexpr char kModule[] = R"( + HloModule jit_scan + + wide.region_0.7 { + wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) parameter(0) + get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 + get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 + get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 + get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 + get-tuple-element.55 = s32[8] get-tuple-element(wide.arg_tuple.8), index=4 + get-tuple-element.56 = s32[8] get-tuple-element(wide.arg_tuple.8), index=5 + + dynamic-slice.0 = s32[1] dynamic-slice(get-tuple-element.54, get-tuple-element.46), dynamic_slice_sizes={1} + reshape.2 = s32[] reshape(dynamic-slice.0) + add.1 = s32[] add(get-tuple-element.47, reshape.2) + + reshape.3 = s32[1] reshape(add.1) + dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.3, get-tuple-element.46) + + dynamic-slice.1 = s32[1] dynamic-slice(get-tuple-element.56, get-tuple-element.46), dynamic_slice_sizes={1} + reshape.4 = s32[] reshape(dynamic-slice.1) + add.2 = s32[] multiply(get-tuple-element.47, reshape.4) + + reshape.5 = s32[1] reshape(add.2) + dynamic-update-slice.1 = s32[8] dynamic-update-slice(get-tuple-element.55, reshape.5, get-tuple-element.46) + + const = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.46, const) + ROOT tuple.10 = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) tuple(add.0, add.1, dynamic-update-slice.0, get-tuple-element.54, dynamic-update-slice.1, get-tuple-element.56) + } // wide.region_0.7 + + wide.region_1.29 { + constant.5 = s32[] constant(8) + wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 + ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT + } + + outer_body { + wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[8]) parameter(0) + get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 + get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 + get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 + get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 + + constant.3 = s32[] constant(0) + broadcast = s32[8] broadcast(constant.3), dimensions={} + broadcast2 = s32[8] broadcast(constant.3), dimensions={} + + tuple.8 = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) tuple(constant.3, get-tuple-element.47, broadcast, get-tuple-element.48, broadcast2, get-tuple-element.54) + while = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) while(tuple.8), condition=wide.region_1.29, body=wide.region_0.7 + get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 + get-tuple-element.41 = s32[8] get-tuple-element(while), index=4 + + const = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.46, const) + ROOT out = (s32[], s32[], s32[8], s32[8]) tuple(add.0, get-tuple-element.47, get-tuple-element.40, get-tuple-element.41) + } + + outer_cond { + constant.5 = s32[] constant(8) + wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8]) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 + ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT + } + + main.43 { + constant.3 = s32[] constant(0) + init = s32[] constant(0) + array = s32[8] constant({1,2,3,4,5,6,7,8}) + array2 = s32[8] constant({10,20,30,40,50,60,70,80}) + tuple.8 = (s32[], s32[], s32[8], s32[8]) tuple(constant.3, init, array, array2) + while = (s32[], s32[], s32[8], s32[8]) while(tuple.8), condition=outer_cond, body=outer_body + get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 + get-tuple-element.41 = s32[8] get-tuple-element(while), index=3 + ROOT out = (s32[8],s32[8]) tuple(get-tuple-element.40, get-tuple-element.41) + } // main.43 + + )"; + + auto module = ParseAndReturnVerifiedModule(kModule).value(); + + VLOG(3) << "before:\n" << module->ToString(); + + Literal reference = ExecuteAndTransfer(module->Clone(), {}); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileDoubleBufferRemoval().Run(module.get())); + EXPECT_TRUE(simplified_loop); + + VLOG(3) << "after:\n" << module->ToString(); + + // Index 2 and 3 of the while are replaced with the input arrays. + for (HloInstruction* instr : + module->entry_computation()->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kWhile) { + EXPECT_EQ(instr->while_init()->operand(2)->opcode(), + HloOpcode::kConstant); + EXPECT_EQ(instr->while_init()->operand(3)->opcode(), + HloOpcode::kConstant); + } + } + Literal actual = ExecuteAndTransfer(module->Clone(), {}); + + ASSERT_TRUE(LiteralTestUtil::NearOrEqual(/*expected=*/reference, + /*actual=*/actual, std::nullopt)); +} + +TEST_F(WhileDoubleBufferRemovalTest, BufferAllocateOutside) { + [[maybe_unused]] constexpr char kModule[] = R"( + HloModule jit_scan + + wide.region_0.7 { + wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[8]) parameter(0) + get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 + get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 + get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 + get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 + + dynamic-slice.0 = s32[1] dynamic-slice(get-tuple-element.54, get-tuple-element.46), dynamic_slice_sizes={1} + reshape.2 = s32[] reshape(dynamic-slice.0) + add.1 = s32[] add(get-tuple-element.47, reshape.2) + + reshape.3 = s32[1] reshape(add.1) + dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.3, get-tuple-element.46) + + const = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.46, const) + ROOT tuple.10 = (s32[], s32[], s32[8], s32[8]) tuple(add.0, add.1, dynamic-update-slice.0, get-tuple-element.54) + } // wide.region_0.7 + + wide.region_1.29 { + constant.5 = s32[] constant(8) + wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8]) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 + ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT + } + + outer_body { + wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[8]) parameter(0) + get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 + get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 + get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 + get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 + + constant.3 = s32[] constant(0) + tuple.8 = (s32[], s32[], s32[8], s32[8]) tuple(constant.3, get-tuple-element.47, get-tuple-element.54, get-tuple-element.48) + while = (s32[], s32[], s32[8], s32[8]) while(tuple.8), condition=wide.region_1.29, body=wide.region_0.7 + get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 + + const = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.46, const) + ROOT out = (s32[], s32[], s32[8], s32[8]) tuple(add.0, get-tuple-element.47, get-tuple-element.48, get-tuple-element.40) + } + + outer_cond { + constant.5 = s32[] constant(8) + wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8]) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 + ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT + } + + main.43 { + constant.3 = s32[] constant(0) + init = s32[] constant(0) + array = s32[8] constant({1,2,3,4,5,6,7,8}) + buffer = s32[8] broadcast(constant.3), dimensions={} + tuple.8 = (s32[], s32[], s32[8], s32[8]) tuple(constant.3, init, array, buffer) + while = (s32[], s32[], s32[8], s32[8]) while(tuple.8), condition=outer_cond, body=outer_body + ROOT get-tuple-element.40 = s32[8] get-tuple-element(while), index=3 + } // main.43 + + )"; + + auto module = ParseAndReturnVerifiedModule(kModule).value(); + + VLOG(3) << "before:\n" << module->ToString(); + + Literal reference = ExecuteAndTransfer(module->Clone(), {}); + std::cout << "before:\n" << reference.ToString() << std::endl; + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileDoubleBufferRemoval().Run(module.get())); + // Buffer is not replaced with input since it is allocated outside the outer + // loop. + EXPECT_FALSE(simplified_loop); +} + +TEST_F(WhileDoubleBufferRemovalTest, InputDifferentShape) { + [[maybe_unused]] constexpr char kModule[] = R"( + HloModule jit_scan + + wide.region_0.7 { + wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[8,10]) parameter(0) + get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 + get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 + get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 + get-tuple-element.54 = s32[8,10] get-tuple-element(wide.arg_tuple.8), index=3 + + zero = s32[] constant(0) + dynamic-slice.0 = s32[1,10] dynamic-slice(get-tuple-element.54, get-tuple-element.46, zero), dynamic_slice_sizes={1,10} + reshape.2 = s32[10] reshape(dynamic-slice.0) + + dynamic-slice.1 = s32[1] dynamic-slice(reshape.2, get-tuple-element.46), dynamic_slice_sizes={1} + reshape.3 = s32[] reshape(dynamic-slice.1) + + add.1 = s32[] add(get-tuple-element.47, reshape.3) + + reshape.4 = s32[1] reshape(add.1) + dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.4, get-tuple-element.46) + + const = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.46, const) + ROOT tuple.10 = (s32[], s32[], s32[8], s32[8,10]) tuple(add.0, add.1, dynamic-update-slice.0, get-tuple-element.54) + } // wide.region_0.7 + + wide.region_1.29 { + constant.5 = s32[] constant(8) + wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8,10]) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 + ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT + } + + ENTRY main.43 { + constant.3 = s32[] constant(0) + init = s32[] constant(0) + array = s32[8,10] parameter(0) + broadcast.5 = s32[8] broadcast(constant.3), dimensions={} + + tuple.8 = (s32[], s32[], s32[8], s32[8,10]) tuple(constant.3, init, broadcast.5, array) + while = (s32[], s32[], s32[8], s32[8,10]) while(tuple.8), condition=wide.region_1.29, body=wide.region_0.7 + get-tuple-element.39 = s32[] get-tuple-element(while), index=1 + ROOT get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 + } // main.43 + + )"; + + auto module = ParseAndReturnVerifiedModule(kModule).value(); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileDoubleBufferRemoval().Run(module.get())); + EXPECT_FALSE(simplified_loop); +} + +TEST_F(WhileDoubleBufferRemovalTest, MultipleUsersInput) { + [[maybe_unused]] constexpr char kModule[] = R"( + HloModule jit_scan + + wide.region_0.7 { + wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) parameter(0) + get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 + get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 + // buffer + get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 + // input with multiple users + get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 + // buffer + get-tuple-element.55 = s32[8] get-tuple-element(wide.arg_tuple.8), index=4 + // input + get-tuple-element.56 = s32[8] get-tuple-element(wide.arg_tuple.8), index=5 + + // this is here only to have another user for gte.54 + mult = s32[8] multiply(get-tuple-element.54, get-tuple-element.54) + + dynamic-slice.0 = s32[1] dynamic-slice(get-tuple-element.54, get-tuple-element.46), dynamic_slice_sizes={1} + reshape.2 = s32[] reshape(dynamic-slice.0) + add.1 = s32[] add(get-tuple-element.47, reshape.2) + + reshape.3 = s32[1] reshape(add.1) + dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.3, get-tuple-element.46) + + dynamic-slice.1 = s32[1] dynamic-slice(get-tuple-element.56, get-tuple-element.46), dynamic_slice_sizes={1} + reshape.4 = s32[] reshape(dynamic-slice.1) + add.2 = s32[] multiply(get-tuple-element.47, reshape.4) + + reshape.5 = s32[1] reshape(add.2) + dynamic-update-slice.1 = s32[8] dynamic-update-slice(get-tuple-element.55, reshape.5, get-tuple-element.46) + + const = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.46, const) + ROOT tuple.10 = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) tuple(add.0, add.1, dynamic-update-slice.0, get-tuple-element.54, dynamic-update-slice.1, get-tuple-element.56) + } // wide.region_0.7 + + wide.region_1.29 { + constant.5 = s32[] constant(8) + wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 + ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT + } + + outer_body { + wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[8]) parameter(0) + get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 + get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 + get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 + get-tuple-element.56 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 + + constant.3 = s32[] constant(0) + broadcast = s32[8] broadcast(constant.3), dimensions={} + broadcast2 = s32[8] broadcast(constant.3), dimensions={} + + tuple.8 = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) tuple(constant.3, get-tuple-element.47, broadcast, get-tuple-element.54, broadcast2, get-tuple-element.56) + while = (s32[], s32[], s32[8], s32[8], s32[8], s32[8]) while(tuple.8), condition=wide.region_1.29, body=wide.region_0.7 + get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 + get-tuple-element.41 = s32[8] get-tuple-element(while), index=4 + + const = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.46, const) + ROOT out = (s32[], s32[], s32[8], s32[8]) tuple(add.0, get-tuple-element.47, get-tuple-element.40, get-tuple-element.41) + } + + outer_cond { + constant.5 = s32[] constant(8) + wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8]) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 + ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT + } + + ENTRY main.43 { + constant.3 = s32[] constant(0) + init = s32[] constant(0) + array = s32[8] constant({1,2,3,4,5,6,7,8}) + array2 = s32[8] constant({10,20,30,40,50,60,70,80}) + tuple.8 = (s32[], s32[], s32[8], s32[8]) tuple(constant.3, init, array, array2) + while = (s32[], s32[], s32[8], s32[8]) while(tuple.8), condition=outer_cond, body=outer_body + get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 + get-tuple-element.41 = s32[8] get-tuple-element(while), index=3 + ROOT out = (s32[8],s32[8]) tuple(get-tuple-element.40, get-tuple-element.41) + } // main.43 + + )"; + + auto module = ParseAndReturnVerifiedModule(kModule).value(); + + VLOG(3) << "before:\n" << module->ToString(); + + Literal reference = ExecuteAndTransfer(module->Clone(), {}); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileDoubleBufferRemoval().Run(module.get())); + EXPECT_TRUE(simplified_loop); + + VLOG(3) << "after:\n" << module->ToString(); + + // Only index 2 is replaced with the array. + for (HloInstruction* instr : + module->entry_computation()->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kWhile) { + EXPECT_EQ(instr->while_init()->operand(2)->opcode(), + HloOpcode::kConstant); + } + } + Literal actual = ExecuteAndTransfer(module->Clone(), {}); + + ASSERT_TRUE(LiteralTestUtil::NearOrEqual(/*expected=*/reference, + /*actual=*/actual, std::nullopt)); +} + +TEST_F(WhileDoubleBufferRemovalTest, RemoveDoubleBufferMoreInputs) { + [[maybe_unused]] constexpr char kModule[] = R"( + HloModule jit_scan + + wide.region_0.7 { + wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[8], s32[10]) parameter(0) + get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 + get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 + get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 + get-tuple-element.54 = s32[8] get-tuple-element(wide.arg_tuple.8), index=3 + get-tuple-element.55 = s32[10] get-tuple-element(wide.arg_tuple.8), index=4 + dynamic-slice.0 = s32[1] dynamic-slice(get-tuple-element.54, get-tuple-element.46), dynamic_slice_sizes={1} + reshape.2 = s32[] reshape(dynamic-slice.0) + dynamic-slice.1 = s32[1] dynamic-slice(get-tuple-element.55, get-tuple-element.46), dynamic_slice_sizes={1} + reshape.3 = s32[] reshape(dynamic-slice.1) + add.1 = s32[] add(reshape.3, reshape.2) + add.2 = s32[] add(add.1, get-tuple-element.47) + + reshape.4 = s32[1] reshape(add.2) + dynamic-update-slice.0 = s32[8] dynamic-update-slice(get-tuple-element.48, reshape.4, get-tuple-element.46) + const = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.46, const) + ROOT tuple.10 = (s32[], s32[], s32[8], s32[8], s32[10]) tuple(add.0, add.1, dynamic-update-slice.0, get-tuple-element.54, get-tuple-element.55) + } // wide.region_0.7 + + wide.region_1.29 { + constant.5 = s32[] constant(8) + wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[8], s32[10]) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 + ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT + } + + outer_body { + wide.arg_tuple.8 = (s32[], s32[], s32[8], s32[10]) parameter(0) + get-tuple-element.46 = s32[] get-tuple-element(wide.arg_tuple.8), index=0 + get-tuple-element.47 = s32[] get-tuple-element(wide.arg_tuple.8), index=1 + get-tuple-element.48 = s32[8] get-tuple-element(wide.arg_tuple.8), index=2 + get-tuple-element.55 = s32[10] get-tuple-element(wide.arg_tuple.8), index=3 + + constant.3 = s32[] constant(0) + broadcast = s32[8] broadcast(constant.3), dimensions={} + + tuple.8 = (s32[], s32[], s32[8], s32[8], s32[10]) tuple(constant.3, get-tuple-element.47, broadcast, get-tuple-element.48, get-tuple-element.55) + while = (s32[], s32[], s32[8], s32[8], s32[10]) while(tuple.8), condition=wide.region_1.29, body=wide.region_0.7 + get-tuple-element.40 = s32[8] get-tuple-element(while), index=2 + + const = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.46, const) + ROOT out = (s32[], s32[], s32[8], s32[10]) tuple(add.0, get-tuple-element.47, get-tuple-element.40, get-tuple-element.55) + } + + outer_cond { + constant.5 = s32[] constant(8) + wide.arg_tuple.30 = (s32[], s32[], s32[8], s32[10]) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(wide.arg_tuple.30), index=0 + ROOT compare.0 = pred[] compare(get-tuple-element.16, constant.5), direction=LT + } + + ENTRY main.43 { + constant.3 = s32[] constant(0) + init = s32[] constant(0) + array = s32[8] constant({1,2,3,4,5,6,7,8}) + other_input = s32[10] constant({10,20,30,40,50,60,70,80,90,100}) + tuple.8 = (s32[], s32[], s32[8], s32[10]) tuple(constant.3, init, array, other_input) + while = (s32[], s32[], s32[8], s32[10]) while(tuple.8), condition=outer_cond, body=outer_body + get-tuple-element.39 = s32[8] get-tuple-element(while), index=2 + get-tuple-element.40 = s32[10] get-tuple-element(while), index=3 + ROOT out = (s32[8],s32[10]) tuple(get-tuple-element.39, get-tuple-element.40) + } // main.43 + + )"; + + auto module = ParseAndReturnVerifiedModule(kModule).value(); + + VLOG(3) << "before:\n" << module->ToString(); + + Literal reference = ExecuteAndTransfer(module->Clone(), {}); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileDoubleBufferRemoval().Run(module.get())); + EXPECT_TRUE(simplified_loop); + + VLOG(3) << "after:\n" << module->ToString(); + + // Index 2 of the while is replaced with the input array. + for (HloInstruction* instr : + module->entry_computation()->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kWhile) { + EXPECT_EQ(instr->while_init()->operand(2)->opcode(), + HloOpcode::kConstant); + } + } + Literal actual = ExecuteAndTransfer(module->Clone(), {}); + + ASSERT_TRUE(LiteralTestUtil::NearOrEqual(/*expected=*/reference, + /*actual=*/actual, std::nullopt)); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/while_loop_simplifier.h b/third_party/xla/xla/service/while_loop_simplifier.h index 3aacd3b0c70efc..f715919c575674 100644 --- a/third_party/xla/xla/service/while_loop_simplifier.h +++ b/third_party/xla/xla/service/while_loop_simplifier.h @@ -24,6 +24,8 @@ limitations under the License. namespace xla { +StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op); + // HLO pass that makes the following transformations on while loops: // // - A while loop with static trip count of 0 is deleted. diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index a3b75a278bee8c..1c8258bbd14ee9 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -9777,10 +9777,15 @@ absl::Status CudnnGraph::Execute(Stream& stream, std::unordered_map, void*> tensor_to_ptr_map; + absl::Span operands_without_workspace = operands; + DeviceMemoryBase workspace; + if (graph_.get_workspace_size() != 0) { + workspace = operands.back(); + CHECK_EQ(graph_.get_workspace_size(), workspace.size()); + operands_without_workspace = operands.first(operands.size() - 1); + } int operand_number = 0; - - CHECK_EQ(graph_.get_workspace_size(), 0); - for (DeviceMemoryBase operand : operands) { + for (DeviceMemoryBase operand : operands_without_workspace) { const cudnn_frontend::graph::Tensor_attributes attr = cudnn_frontend::graph::Tensor_attributes().set_uid( CuDnnTensorUID(operand_number)); @@ -9795,7 +9800,7 @@ absl::Status CudnnGraph::Execute(Stream& stream, dnn_support.cudnn_ ->GetHandle(ExtractGpuExecutor(stream.parent()), &stream) .handle(), - tensor_to_ptr_map, /*workspace=*/nullptr)); + tensor_to_ptr_map, workspace.opaque())); return absl::OkStatus(); }