Change RedundantCopiesRemoval to a pass
dfki-ehna committed Feb 11, 2020
1 parent a402ef1 commit aa2d06d
Showing 4 changed files with 70 additions and 60 deletions.
15 changes: 14 additions & 1 deletion tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
@@ -45,6 +45,19 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}

// CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store
func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: memref<f32>) {
%0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
tensor_store %0, %arg2 : memref<f32>
return
}
// CHECK: (%[[NEW_ARG0:.*]]: memref<f32>, %[[NEW_ARG1:.*]]: memref<f32>, %[[RESULT:.*]]: memref<f32>)
// CHECK-NOT: %[[ALLOC_OPERAND:.*]] = alloc() {temp = true} : memref<f32>
// CHECK: "xla_lhlo.max"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[RESULT]]) : (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK-NOT: "xla_lhlo.copy"(%[[ALLOC_OPERAND]], %[[RESULT]]) : (memref<f32>, memref<f32>) -> ()
// CHECK-NOT: dealloc %[[ALLOC_OPERAND]] : memref<f32>
// CHECK: "xla_lhlo.terminator"() : () -> ()

// CHECK-LABEL: func @fusion
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -208,4 +221,4 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// CHECK-NEXT: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
108 changes: 49 additions & 59 deletions tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
@@ -49,17 +50,6 @@ Operation* FindInsertionPointForCopy(Value value) {
return nullptr;
}

Value GetTensorStore(Value value) {
for (const auto& user : value.getUsers()) {
if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
if (tensor_store.getOperand(0) == value) {
return tensor_store.getOperand(1);
}
}
}
return nullptr;
}

Value InsertAllocAndDealloc(Location loc, Value result,
ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
@@ -85,17 +75,6 @@ Value InsertAllocAndDealloc(Location loc, Value result,
return alloc;
}

/// For every tensor-type value that is produced in the original function,
/// this function returns the buffer that can be used in the converted
/// function to store the values held in the tensor.
Value GetBufferForResultValue(Location loc, Value result,
ConversionPatternRewriter* rewriter) {
if (auto existing_memref = GetTensorStore(result)) {
return existing_memref;
}
return InsertAllocAndDealloc(loc, result, rewriter);
}

template <typename HloOpTy, typename LhloOpTy>
class HloToLhloOpConverter : public ConversionPattern {
public:
@@ -137,7 +116,7 @@ struct HloToLHloReduceOpConverter
const auto& original_results = op.getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
buffer_args.push_back(GetBufferForResultValue(loc, result, &rewriter));
buffer_args.push_back(InsertAllocAndDealloc(loc, result, &rewriter));
}
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
loc, llvm::None, buffer_args, op.getAttrs());
@@ -200,38 +179,6 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern {
}
};

/// Removes Lhlo.CopyOp that copies from an allocated buffer to the block
/// argument. All uses of the buffer are replaced with the block argument.
void RemoveRedundantCopies(ModuleOp module) {
llvm::SmallVector<Operation*, 2> eraseList;
module.walk([&](xla_lhlo::CopyOp copyOp) {
auto arguments = copyOp.getOperation()->getBlock()->getArguments();
if (std::any_of(
arguments.begin(), arguments.end(),
[&](BlockArgument arg) { return copyOp.output() == arg; }) &&
std::none_of(
arguments.begin(), arguments.end(),
[&](BlockArgument arg) { return copyOp.operand() == arg; })) {
Value operand = copyOp.operand();
Value output = copyOp.output();
copyOp.erase();
for (auto op : operand.getUsers()) {
if (!isa<DeallocOp>(op)) {
op->replaceUsesOfWith(operand, output);
}
}
auto allocOp = operand.getDefiningOp();
if (auto deallocOp = dyn_cast<DeallocOp>(*allocOp->getUsers().begin())) {
eraseList.push_back(deallocOp);
eraseList.push_back(allocOp);
}
}
});
for (auto op : eraseList) {
op->erase();
}
}

// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
@@ -321,8 +268,6 @@ struct HloLegalizeToLhlo : public ModulePass<HloLegalizeToLhlo> {
if (failed(applyFullConversion(module, target, patterns, nullptr))) {
signalPassFailure();
}

RemoveRedundantCopies(module);
}
};

@@ -442,12 +387,57 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
// clang-format on
}

/// Removes Lhlo.CopyOp that copies from an allocated buffer to the block
/// argument. All uses of the buffer are replaced with the block argument.
struct RedundantCopiesRemoval : mlir::FunctionPass<RedundantCopiesRemoval> {
void runOnFunction() override {
llvm::SmallVector<mlir::Operation*, 2> eraseList;
getFunction().walk([&](mlir::xla_lhlo::CopyOp copyOp) {
auto arguments = copyOp.getOperation()->getBlock()->getArguments();
if (std::any_of(arguments.begin(), arguments.end(),
[&](mlir::BlockArgument arg) {
return copyOp.output() == arg;
}) &&
std::none_of(arguments.begin(), arguments.end(),
[&](mlir::BlockArgument arg) {
return copyOp.operand() == arg;
})) {
mlir::Value operand = copyOp.operand();
mlir::Value output = copyOp.output();
copyOp.erase();
for (auto op : operand.getUsers()) {
if (!mlir::isa<mlir::DeallocOp>(op)) {
op->replaceUsesOfWith(operand, output);
}
}
auto allocOp = operand.getDefiningOp();
if (auto deallocOp =
mlir::dyn_cast<mlir::DeallocOp>(*allocOp->getUsers().begin())) {
eraseList.push_back(deallocOp);
eraseList.push_back(allocOp);
}
}
});
for (auto op : eraseList) {
op->erase();
}
}
};
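
// The any_of/none_of guard above makes the rewrite fire only when the copy's
// output is a block argument and its operand is not, i.e. the operand is a
// locally allocated buffer whose uses can be forwarded to the argument. A
// copy between two block arguments (a minimal hypothetical example) is left
// untouched, since the pass cannot rewrite or erase function arguments:
//
//   func @copy_between_args(%in: memref<f32>, %out: memref<f32>) {
//     "xla_lhlo.copy"(%in, %out) : (memref<f32>, memref<f32>) -> ()
//     "xla_lhlo.terminator"() : () -> ()
//   }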

std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass() {
return absl::make_unique<HloLegalizeToLhlo>();
}

static PassRegistration<HloLegalizeToLhlo> legalize_pass(
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass() {
return absl::make_unique<RedundantCopiesRemoval>();
}

static PassPipelineRegistration<> legalize_pass(
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect",
[](mlir::OpPassManager& pm) {
pm.addPass(createLegalizeToLhloPass());
pm.addPass(createLhloCopyRemovalPass());
});
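
// A minimal usage sketch, mirroring the registered pipeline above; it assumes
// `module` is a ModuleOp that has already been parsed and that the caller
// reports the failure:
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(createLegalizeToLhloPass());
//   pm.addPass(createLhloCopyRemovalPass());
//   if (failed(pm.run(module))) { /* handle pipeline failure */ }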

} // namespace xla_hlo
} // namespace mlir
5 changes: 5 additions & 0 deletions tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -58,6 +58,11 @@ std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeToLhloPass();
// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeHloToLinalgPass();

// Removes unnecessary LHLO copies that copy from allocated buffers to block
// arguments. These copies are created when TensorStoreOp is replaced with
// LHLO.CopyOp during HLO-to-LHLO lowering.
std::unique_ptr<OpPassBase<FuncOp>> createLhloCopyRemovalPass();

} // namespace xla_hlo

namespace xla_lhlo {
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -276,6 +276,8 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) {
pm.addPass(absl::make_unique<FusionToLhloConverter>());
// Next, we can strip the outer fusion operation.
pm.addPass(absl::make_unique<FusionOpRemover>());
// Remove unnecessary Lhlo copies.
pm.addPass(::mlir::xla_hlo::createLhloCopyRemovalPass());
// Transform lhlo operations to LinAlg.
pm.addPass(::mlir::xla_lhlo::createLegalizeLhloToLinalgPass());
// Fuse linalg operations. This will yield a single tiled loop nest where
