Skip to content

Commit

Permalink
[tfrt:jit] Insert a copy when returning a dynamic broadcast
Browse files Browse the repository at this point in the history
We bufferize dynamic broadcasts into a memref reinterpret cast that yields a
memref with affine map. This clashes with the return type of the function that
doesn't support affine maps. Insert a copy for this special case.

This is still a bit of a hack, but I prefer not to invest too much as a
different representation for dynamic broaddcasts is on the horizon.

PiperOrigin-RevId: 405473962
  • Loading branch information
d0k authored and TensorFlow MLIR Team committed Oct 25, 2021
1 parent 3fd107f commit cb34357
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
8 changes: 6 additions & 2 deletions include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_

#include <functional>
#include <memory>

#include "mlir/IR/MLIRContext.h"
Expand Down Expand Up @@ -64,12 +65,15 @@ void populateHLOToLHLOConversionPattern(MLIRContext *context,

// Collection of rewrite patterns for lowering of HLO to memref dialect.
// These patterns generally assume that the HLO operation are aliasing their
// input memrefs. If enforce_identity_map is set to true, copies will be
// input memrefs. If enforce_identity_map returns true for an op, copies will be
// inserted when the lowering would otherwise lead to a memref with a
// non-identity map.
void populateHLOToMemrefConversionPattern(
BufferizeTypeConverter *converter, RemoveSignTypeConverter *sign_converter,
OwningRewritePatternList *patterns, bool enforce_identity_map = true);
OwningRewritePatternList *patterns,
std::function<bool(Operation *)> enforce_identity_map = [](Operation *) {
return true;
});

// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
void populateHLOToLinalgConversionPattern(MLIRContext *context,
Expand Down
14 changes: 8 additions & 6 deletions lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

// This file implements logic for lowering HLO dialect to LHLO dialect.

#include <functional>
#include <utility>

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
Expand Down Expand Up @@ -159,10 +160,10 @@ class HloToMemrefDynamicBroadcastInDimOpConverter
public:
HloToMemrefDynamicBroadcastInDimOpConverter(
TypeConverter& converter, RemoveSignTypeConverter* sign_converter,
MLIRContext* ctx, bool enforce_identity_maps)
MLIRContext* ctx, std::function<bool(Operation*)> enforce_identity_maps)
: BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter,
sign_converter, ctx),
enforce_identity_maps_(enforce_identity_maps) {}
enforce_identity_maps_(std::move(enforce_identity_maps)) {}

Value signlessRewrite(mhlo::DynamicBroadcastInDimOp op,
ArrayRef<Value> operands, Type op_result_type,
Expand All @@ -171,7 +172,7 @@ class HloToMemrefDynamicBroadcastInDimOpConverter
if (!result_type) return {};
Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);

if (enforce_identity_maps_) {
if (enforce_identity_maps_(op)) {
result = CreateCopy(op, result, &rewriter);
}

Expand Down Expand Up @@ -295,7 +296,7 @@ class HloToMemrefDynamicBroadcastInDimOpConverter
return copy;
}

bool enforce_identity_maps_;
std::function<bool(Operation*)> enforce_identity_maps_;
};

struct HloLegalizeToMemrefPass
Expand Down Expand Up @@ -331,10 +332,11 @@ struct HloLegalizeToMemrefPass

void populateHLOToMemrefConversionPattern(
BufferizeTypeConverter* converter, RemoveSignTypeConverter* sign_converter,
OwningRewritePatternList* patterns, bool enforce_identity_maps) {
OwningRewritePatternList* patterns,
std::function<bool(Operation*)> enforce_identity_maps) {
MLIRContext* context = patterns->getContext();
patterns->insert<HloToMemrefDynamicBroadcastInDimOpConverter>(
*converter, sign_converter, context, enforce_identity_maps);
*converter, sign_converter, context, std::move(enforce_identity_maps));
patterns->insert<HloToMemrefDynamicReshapeConverter,
HloToMemrefReshapeUnrankedConverter>(
*converter, sign_converter, context);
Expand Down

0 comments on commit cb34357

Please sign in to comment.