Skip to content

Commit

Permalink
[MLIR][XLA] Remove redundant LHLO CopyOp
Browse files Browse the repository at this point in the history
  • Loading branch information
dfki-ehna committed Jan 30, 2020
1 parent 534acb2 commit f62aff3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 15 deletions.
10 changes: 2 additions & 8 deletions tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,14 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @func_op
func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
%0 = xla_hlo.max %arg0, %arg1 {name = "maximum.47"} : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.max"(%arg0, %arg1, %[[MAX_RESULT]])
// CHECK-NEXT: "xla_lhlo.copy"(%[[MAX_RESULT]], %arg2)
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: "xla_lhlo.max"(%arg0, %arg1, %arg2)
return %0 : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}

// CHECK-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
Expand All @@ -38,13 +34,11 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%4 = xla_hlo.sub %arg1, %3 {name = "maximum.47"} : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.sub"(%arg1, %[[MIN_RESULT]], %[[SUB_RESULT]])
%5 = xla_hlo.mul %2, %4 {name = "maximum.47"} : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// CHECK-NEXT: "xla_lhlo.mul"(%[[ADD_RESULT]], %[[SUB_RESULT]], %arg2)
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %arg2)
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
return %5 : tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}
Expand Down
52 changes: 45 additions & 7 deletions tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class HloToLhloOpConverter : public ConversionPattern {
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
buffer_args.push_back(
GetBufferForResultValue(op->getLoc(), result, &rewriter));
InsertAllocAndDealloc(op->getLoc(), result, &rewriter));
}
rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
op->getAttrs());
Expand Down Expand Up @@ -194,11 +194,50 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
PatternMatchResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.eraseOp(op);
rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
op, llvm::None, operands.front(), operands.back());
return matchSuccess();
}
};

/// Removes every allocated buffer that is never read or written, i.e. an
/// AllocOp whose result has no users at all or whose single user is a
/// DeallocOp. A buffer with any other user is left untouched.
void RemoveRedundantAllocDealloc(ModuleOp module) {
  // Collect the dead operations first and erase them once the walk is done.
  // Erasing the DeallocOp from inside the callback is unsafe: it is a sibling
  // of the AllocOp being visited and may be the operation the walker is about
  // to visit next.
  llvm::SmallVector<Operation*, 8> deadOps;
  module.walk([&](AllocOp allocOp) {
    Value buffer = allocOp.getResult();
    if (buffer.use_empty()) {
      // Nothing references the buffer, not even a dealloc.
      deadOps.push_back(allocOp.getOperation());
      return;
    }
    // Looking only at the first user is not enough: the buffer may have
    // further users that would be left dangling once the AllocOp is erased.
    if (!buffer.hasOneUse()) return;
    Operation* onlyUser = *buffer.getUsers().begin();
    if (!isa<DeallocOp>(onlyUser)) return;
    // The user is queued before its definition so that the AllocOp has no
    // remaining uses by the time it is erased.
    deadOps.push_back(onlyUser);
    deadOps.push_back(allocOp.getOperation());
  });
  for (Operation* op : deadOps) {
    op->erase();
  }
}

/// Removes every xla_lhlo::CopyOp that copies from a buffer produced by an
/// AllocOp into an argument of the block that contains the copy. All users of
/// the buffer except its DeallocOp are rewritten to use the block argument
/// directly, the copy is erased, and buffers that are left with no users other
/// than their DeallocOp are cleaned up by RemoveRedundantAllocDealloc.
void RemoveRedundantCopyOp(ModuleOp module) {
  module.walk([&](xla_lhlo::CopyOp copyOp) {
    Value operand = copyOp.operand();
    Value output = copyOp.output();

    // Only forward buffers that are really allocations. Any other source
    // (a block argument of this or an enclosing block, or the result of some
    // other operation) must keep its own users; redirecting them to `output`
    // would make them read or write the wrong memory.
    Operation* definingOp = operand.getDefiningOp();
    if (!definingOp || !isa<AllocOp>(definingOp)) return;

    // The destination has to be an argument of the block holding the copy.
    auto arguments = copyOp.getOperation()->getBlock()->getArguments();
    const bool outputIsBlockArgument =
        std::any_of(arguments.begin(), arguments.end(),
                    [&](BlockArgument arg) { return output == arg; });
    if (!outputIsBlockArgument) return;

    // Snapshot the users before rewriting: replaceUsesOfWith mutates the use
    // list that getUsers() iterates over. The DeallocOp keeps referring to the
    // buffer so that the alloc/dealloc pair can be removed together later.
    llvm::SmallVector<Operation*, 4> usersToUpdate;
    for (Operation* user : operand.getUsers()) {
      if (!isa<DeallocOp>(user)) {
        usersToUpdate.push_back(user);
      }
    }
    for (Operation* user : usersToUpdate) {
      user->replaceUsesOfWith(operand, output);
    }
    copyOp.erase();
  });
  RemoveRedundantAllocDealloc(module);
}


// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
Expand Down Expand Up @@ -255,14 +294,11 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
// %arg1: memref<4xf32>,
// %arg2: memref<4xf32>) {
// %0 = alloc() {temp = true} : memref<4xf32>
// %1 = alloc() {temp = true} : memref<4xf32>
// "xla_lhlo.max"(%arg0, %arg1, %1) {name = "maximum.47"} :
// "xla_lhlo.max"(%arg0, %arg1, %0) {name = "maximum.47"} :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.add"(%arg0, %1, %0) {name = "maximum.47"} :
// "xla_lhlo.add"(%arg0, %0, %arg2) {name = "maximum.47"} :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// dealloc %1 : memref<4xf32>
// "xla_lhlo.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
// dealloc %0 : memref<4xf32>
// "xla_lhlo.terminator"() : () -> ()
// }

Expand Down Expand Up @@ -291,6 +327,8 @@ struct HloLegalizeToLhlo : public ModulePass<HloLegalizeToLhlo> {
if (failed(applyFullConversion(module, target, patterns, nullptr))) {
signalPassFailure();
}

RemoveRedundantCopyOp(module);
}
};

Expand Down

0 comments on commit f62aff3

Please sign in to comment.