[MLIR][XLA] Remove redundant LHLO CopyOp #36335

Merged
35 changes: 22 additions & 13 deletions tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
@@ -13,42 +13,51 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @func_op
func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
- // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
+ // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
%0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.max"(%arg0, %arg1, %[[MAX_RESULT]])
// CHECK-NEXT: "xla_lhlo.copy"(%[[MAX_RESULT]], %arg2)
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]])
return %0 : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}

// CHECK-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
- // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
+ // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
%1 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.max"(%arg0, %arg1, %[[MAX_RESULT]])
// CHECK-NEXT: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
%2 = xla_hlo.add %arg0, %1 {name = "maximum.47"} : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.add"(%arg0, %[[MAX_RESULT]], %[[ADD_RESULT]])
// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
%3 = xla_hlo.min %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.min"(%arg0, %arg1, %[[MIN_RESULT]])
// CHECK-NEXT: "xla_lhlo.min"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
%4 = xla_hlo.sub %arg1, %3 {name = "maximum.47"} : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.sub"(%arg1, %[[MIN_RESULT]], %[[SUB_RESULT]])
// CHECK-NEXT: "xla_lhlo.sub"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
%5 = xla_hlo.mul %2, %4 {name = "maximum.47"} : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[RESULT]])
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %arg2)
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
return %5 : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}

+ // CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store
+ func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: memref<f32>) {
+   %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+   tensor_store %0, %arg2 : memref<f32>
+   return
+ }
+ // CHECK: (%[[NEW_ARG0:.*]]: memref<f32>, %[[NEW_ARG1:.*]]: memref<f32>, %[[RESULT:.*]]: memref<f32>)
+ // CHECK-NOT: %[[ALLOC_OPERAND:.*]] = alloc() {temp = true} : memref<f32>
+ // CHECK: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]]) : (memref<f32>, memref<f32>, memref<f32>) -> ()
+ // CHECK-NOT: "xla_lhlo.copy"(%[[ALLOC_OPERAND]], %[[RESULT]]) : (memref<f32>, memref<f32>) -> ()
+ // CHECK-NOT: dealloc %[[ALLOC_OPERAND]] : memref<f32>
+ // CHECK: "xla_lhlo.terminator"() : () -> ()

// CHECK-LABEL: func @fusion
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -212,4 +221,4 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// CHECK-NEXT: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
- }
\ No newline at end of file
+ }
86 changes: 54 additions & 32 deletions tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
@@ -49,17 +50,6 @@ Operation* FindInsertionPointForCopy(Value value) {
return nullptr;
}

- Value GetTensorStore(Value value) {
-   for (const auto& user : value.getUsers()) {
-     if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
-       if (tensor_store.getOperand(0) == value) {
-         return tensor_store.getOperand(1);
-       }
-     }
-   }
-   return nullptr;
- }

Value InsertAllocAndDealloc(Location loc, Value result,
ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
@@ -85,17 +75,6 @@ Value InsertAllocAndDealloc(Location loc, Value result,
return alloc;
}

- /// For every tensor-type value that is produced in the original function,
- /// this function returns the buffer that can be used in the converted
- /// function to store that values held in the tensor.
- Value GetBufferForResultValue(Location loc, Value result,
-                               ConversionPatternRewriter* rewriter) {
-   if (auto existing_memref = GetTensorStore(result)) {
-     return existing_memref;
-   }
-   return InsertAllocAndDealloc(loc, result, rewriter);
- }

template <typename HloOpTy, typename LhloOpTy>
class HloToLhloOpConverter : public ConversionPattern {
public:
@@ -109,7 +88,7 @@ class HloToLhloOpConverter : public ConversionPattern {
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
buffer_args.push_back(
-     GetBufferForResultValue(op->getLoc(), result, &rewriter));
+     InsertAllocAndDealloc(op->getLoc(), result, &rewriter));
}
rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
op->getAttrs());
@@ -137,7 +116,7 @@ struct HloToLHloReduceOpConverter
const auto& original_results = op.getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
- buffer_args.push_back(GetBufferForResultValue(loc, result, &rewriter));
+ buffer_args.push_back(InsertAllocAndDealloc(loc, result, &rewriter));
}
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
loc, llvm::None, buffer_args, op.getAttrs());
@@ -194,7 +173,8 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
PatternMatchResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
- rewriter.eraseOp(op);
+ rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
+     op, llvm::None, operands.front(), operands.back());
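+ // A `tensor_store %tensor, %memref` thus becomes
+ // "xla_lhlo.copy"(<buffer of %tensor>, %memref) instead of being erased;
+ // when the copy targets a block argument, the copy-removal pass below
+ // folds it away (see remove_lhlo_copy_op_created_from_tensor_store in
+ // the tests).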
return matchSuccess();
}
};
@@ -255,14 +235,11 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
// %arg1: memref<4xf32>,
// %arg2: memref<4xf32>) {
// %0 = alloc() {temp = true} : memref<4xf32>
- //   %1 = alloc() {temp = true} : memref<4xf32>
- //   "xla_lhlo.max"(%arg0, %arg1, %1) {name = "maximum.47"} :
+ //   "xla_lhlo.max"(%arg0, %arg1, %0) {name = "maximum.47"} :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.add"(%arg0, %1, %0) {name = "maximum.47"} :
// "xla_lhlo.add"(%arg0, %0, %arg2) {name = "maximum.47"} :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
- //   dealloc %1 : memref<4xf32>
- //   "xla_lhlo.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
// dealloc %0 : memref<4xf32>
// "xla_lhlo.terminator"() : () -> ()
// }

@@ -410,12 +387,57 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
// clang-format on
}

+ /// Removes Lhlo.CopyOp that copies from an allocated buffer to the block
+ /// argument. All uses of the buffer are replaced with the block argument.
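+ ///
+ /// A sketch of the intended folding (mirroring the func_op test in
+ /// hlo-legalize-to-lhlo.mlir):
+ ///
+ ///   %0 = alloc() {temp = true} : memref<4xf32>
+ ///   "xla_lhlo.max"(%arg0, %arg1, %0)
+ ///       : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
+ ///   "xla_lhlo.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
+ ///   dealloc %0 : memref<4xf32>
+ ///
+ /// becomes
+ ///
+ ///   "xla_lhlo.max"(%arg0, %arg1, %arg2)
+ ///       : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()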
+ struct RedundantCopiesRemoval : mlir::FunctionPass<RedundantCopiesRemoval> {
+   void runOnFunction() override {
+     llvm::SmallVector<mlir::Operation*, 2> eraseList;
+     getFunction().walk([&](mlir::xla_lhlo::CopyOp copyOp) {
+       auto arguments = copyOp.getOperation()->getBlock()->getArguments();
+       if (std::any_of(arguments.begin(), arguments.end(),
+                       [&](mlir::BlockArgument arg) {
+                         return copyOp.output() == arg;
+                       }) &&
+           std::none_of(arguments.begin(), arguments.end(),
+                        [&](mlir::BlockArgument arg) {
+                          return copyOp.operand() == arg;
+                        })) {
+         mlir::Value operand = copyOp.operand();
+         mlir::Value output = copyOp.output();
+         copyOp.erase();
+         for (auto op : operand.getUsers()) {
+           if (!mlir::isa<mlir::DeallocOp>(op)) {
+             op->replaceUsesOfWith(operand, output);
+           }
+         }
+         auto allocOp = operand.getDefiningOp();
+         if (auto deallocOp =
+                 mlir::dyn_cast<mlir::DeallocOp>(*allocOp->getUsers().begin())) {
Review comment (Contributor):

(transmitting @joker-eph): You are dereferencing the user of the allocOp
while you have already deleted the copyOp, which means there might be no
user left.

Here is a minimal reproducer that shows it:

func @Fusion(%arg0: memref<2x2xf32>) -> memref<2x2xf32> {
  %0 = alloc() {temp = true} : memref<2x2xf32>
  "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
  return %0 : memref<2x2xf32>
}

This comes back to my other point about testing: with dedicated testing for
this pass, I would have expected this IR to be the first basic test to
validate the pass.

Review reply (Contributor Author):

Thanks for the hint. You are right that this snippet would break the current
version. Although such code could, in theory, occur as an input program, the
current implementation of the HLO-to-LHLO legalization pass ensures that
there will always be a dealloc. The constraint is in fact even stronger: the
Std.Return-to-LHLO converter expects a final dealloc operation in order to
place a proper CopyOp. The following snippet is taken from
hlo_legalize_to_lhlo.cc (starting at line 425):

if (dealloc == nullptr) {
  returnOp.emitOpError()
      << "Missing dealloc for operand " << operand.index();
  return matchFailure();
}

That check also fires for the following tiny test case:

func @TestFunc(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  return %arg0 : tensor<2xf32>
}

+           eraseList.push_back(deallocOp);
+           eraseList.push_back(allocOp);
+         }
+       }
+     });
+     for (auto op : eraseList) {
+       op->erase();
+     }
+   };
+ };

std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass() {
return absl::make_unique<HloLegalizeToLhlo>();
}

- static PassRegistration<HloLegalizeToLhlo> legalize_pass(
-     "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
+ std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass() {
+   return absl::make_unique<RedundantCopiesRemoval>();
+ }

+ static PassPipelineRegistration<> legalize_pass(
+     "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect",
+     [](mlir::OpPassManager& pm) {
+       pm.addPass(createLegalizeToLhloPass());
+       pm.addPass(createLhloCopyRemovalPass());
+     });

} // namespace xla_hlo
} // namespace mlir
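With this registration, the copy-removal pass runs as part of the existing
"hlo-legalize-to-lhlo" pipeline. A minimal way to exercise it end to end (a
usage sketch: tf-opt is the driver used by the RUN lines of the lowering
tests, and input.mlir is a placeholder):

  tf-opt -hlo-legalize-to-lhlo input.mlir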
5 changes: 5 additions & 0 deletions tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -58,6 +58,11 @@ std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass();
// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeHloToLinalgPass();

+ // Removes unnecessary LHLO copies which copy from the allocated buffers to the
+ // block arguments. These copies have been created by replacing TensorStoreOp
+ // with LHLO.CopyOp in HLO to LHLO lowering.
+ std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass();

} // namespace xla_hlo

namespace xla_lhlo {
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -276,6 +276,8 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) {
pm.addPass(absl::make_unique<FusionToLhloConverter>());
// Next, we can strip the outer fusion operation.
pm.addPass(absl::make_unique<FusionOpRemover>());
+ // Remove unnecessary LHLO copies.
+ pm.addPass(::mlir::xla_hlo::createLhloCopyRemovalPass());
// Transform lhlo operations to LinAlg.
pm.addPass(::mlir::xla_lhlo::createLegalizeLhloToLinalgPass());
// Fuse linalg operations. This will yield a single tiled loop nest where