Skip to content

Commit

Permalink
Address review comments on the previous change.
Browse files Browse the repository at this point in the history
Mostly refactoring and no functional change.

PiperOrigin-RevId: 428604960
Change-Id: I21d40e52d34752d1bcb80a083ea87d413a2546e9
  • Loading branch information
jingpu authored and tensorflower-gardener committed Feb 14, 2022
1 parent dcdf412 commit 071a34e
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ using mhlo::DotDimensionNumbersAttr;

// Replaces `region`'s terminator to TF::Yield.
void ReplaceReturnOp(Region &region, PatternRewriter &rewriter) {
OpBuilder::InsertionGuard guard(rewriter);

for (auto &block : region.getBlocks()) {
Operation *terminator = block.getTerminator();
auto return_op = llvm::dyn_cast_or_null<mhlo::ReturnOp>(terminator);
Expand Down Expand Up @@ -2642,13 +2644,12 @@ class ConvertWhileOp : public OpConversionPattern<mhlo::WhileOp> {
// Creates a TF::WhileRegionOp to replace the mhlo::WhileOp. HLO WhileOp
// currently doesn't support stateless and shape invariant, so these
// parameters are set to the default values.
rewriter.setInsertionPoint(while_op);
auto new_while = rewriter.create<TF::WhileRegionOp>(
while_op.getLoc(), while_op->getResultTypes(), while_op->getOperands(),
/*parallel_iterations=*/10,
/*is_stateless=*/false, /*shape_invariant=*/false);
new_while.cond().takeBody(while_op.getRegion(0));
new_while.body().takeBody(while_op.getRegion(1));
new_while.cond().takeBody(while_op.cond());
new_while.body().takeBody(while_op.body());
ReplaceReturnOp(new_while.cond(), rewriter);
ReplaceReturnOp(new_while.body(), rewriter);
rewriter.replaceOp(while_op, new_while.getResults());
Expand All @@ -2664,13 +2665,12 @@ class ConvertIfOp : public OpConversionPattern<mhlo::IfOp> {
mhlo::IfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// HLO IfOp currently doesn't support stateless
rewriter.setInsertionPoint(op);
auto new_op = rewriter.create<TF::IfRegionOp>(
op.getLoc(), op->getResultTypes(), op.pred(),
/*is_stateless=*/false, /*_then_func_name=*/nullptr,
/*_else_func_name=*/nullptr);
new_op.then_branch().takeBody(op.getRegion(0));
new_op.else_branch().takeBody(op.getRegion(1));
new_op.then_branch().takeBody(op.true_branch());
new_op.else_branch().takeBody(op.false_branch());
ReplaceReturnOp(new_op.then_branch(), rewriter);
ReplaceReturnOp(new_op.else_branch(), rewriter);
rewriter.replaceOp(op, new_op.getResults());
Expand Down Expand Up @@ -2878,15 +2878,14 @@ void LegalizeHloToTf::runOnOperation() {
void PopulateLegalizeHloToTfPatterns(RewritePatternSet *patterns,
MLIRContext *context) {
patterns->insert<
ConvertIfOp, ConvertWhileOp, ConvertSortToTfTopk, ConvertAvgPoolOp,
ConvertConvOp, ConvertNonTrivialConvOp, ConvertDynamicSliceOp,
ConvertDynamicUpdateSliceOp, ConvertGatherOp, ConvertMaxPoolOp,
ConvertScatterAddOp, ConvertScatterMaxOp, ConvertScatterMinOp,
ConvertScatterSubOp, ConvertScatterUpdateOp, ConvertSliceOp,
ConvertReduceOpToTfArgmax, ConvertReduceOpToTfArgmin,
ConvertAvgPoolOp, ConvertConvOp, ConvertNonTrivialConvOp,
ConvertDynamicSliceOp, ConvertDynamicUpdateSliceOp, ConvertGatherOp,
ConvertIfOp, ConvertMaxPoolOp, ConvertScatterAddOp, ConvertScatterMaxOp,
ConvertScatterMinOp, ConvertScatterSubOp, ConvertScatterUpdateOp,
ConvertSliceOp, ConvertReduceOpToTfArgmax, ConvertReduceOpToTfArgmin,
ConvertReduceOpToTfMax, ConvertReduceOpToTfMin, ConvertReduceOpToTfAll,
ConvertReduceOpToTfAny, ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(
context);
ConvertReduceOpToTfAny, ConvertReduceOpToTfSum, ConvertSortToTfTopk,
ConvertIotaOpToTfRange, ConvertWhileOp>(context);
populateWithGenerated(*patterns);
}

Expand Down

0 comments on commit 071a34e

Please sign in to comment.