Rewrite tfl.batch_matmul with batch_1 constants of ones into reduce_sum, to avoid materializing the constant.

PiperOrigin-RevId: 614834108
tensorflower-gardener committed Mar 11, 2024
1 parent 897493e commit cf21b73
Showing 2 changed files with 91 additions and 1 deletion.
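As background for the diffs below, here is a minimal standalone sketch (hand-written C++ for illustration, not part of this commit) of the identity the rewrite relies on: left-multiplying by a 1xK row of ones is exactly a sum over the contraction axis, so the all-ones constant never needs to be materialized.

#include <array>
#include <cassert>

int main() {
  constexpr int K = 4, N = 3;  // contraction dim and output dim
  std::array<std::array<float, N>, K> y = {{
      {1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}};

  // batch_matmul with an all-ones 1xK lhs: out[n] = sum_k 1 * y[k][n].
  std::array<float, N> via_matmul{};
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n) via_matmul[n] += 1.0f * y[k][n];

  // The rewrite: a plain sum over the contraction axis (tfl.sum with
  // keep_dims = true additionally retains the unit row dim, preserving
  // the 1x1xN output shape).
  std::array<float, N> via_sum{};
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n) via_sum[n] += y[k][n];

  for (int n = 0; n < N; ++n) assert(via_matmul[n] == via_sum[n]);
  return 0;
}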
20 changes: 20 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir
@@ -116,3 +116,23 @@ func.func @Batchmatmul2FullyconnectedQDQ(%arg0: tensor<4x128x2xf32>, %arg1: tens
// CHECK-SAME: {fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32>
// CHECK-NEXT: return %[[FC_RES]]
}

// CHECK-LABEL: BatchmatmulToReduceSumI32
// CHECK-NOT: "tfl.batch_matmul"
func.func @BatchmatmulToReduceSumI32(%arg0: tensor<1x16384x257xi32>) -> (tensor<1x1x257xi32>) {
%0 = arith.constant dense<1> : tensor<1x1x16384xi32>
%1 = "tfl.batch_matmul"(%0, %arg0) {adj_x = false, adj_y = false} : (tensor<1x1x16384xi32>, tensor<1x16384x257xi32>) -> tensor<1x1x257xi32>
func.return %1 : tensor<1x1x257xi32>
// CHECK: %[[CONST_DIM:.*]] = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[RED:.*]] = "tfl.sum"(%arg0, %[[CONST_DIM]]) {keep_dims = true} : (tensor<1x16384x257xi32>, tensor<1xi32>) -> tensor<1x1x257xi32>
}

// CHECK-LABEL: BatchmatmulToReduceSumF32
// CHECK-NOT: "tfl.batch_matmul"
func.func @BatchmatmulToReduceSumF32(%arg0: tensor<1x16384x257xf32>) -> (tensor<1x1x257xf32>) {
%0 = arith.constant dense<1.0> : tensor<1x1x16384xf32>
%1 = "tfl.batch_matmul"(%0, %arg0) {adj_x = false, adj_y = false} : (tensor<1x1x16384xf32>, tensor<1x16384x257xf32>) -> tensor<1x1x257xf32>
func.return %1 : tensor<1x1x257xf32>
// CHECK: %[[CONST_DIM:.*]] = "tfl.pseudo_const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: %[[RED:.*]] = "tfl.sum"(%arg0, %[[CONST_DIM]]) {keep_dims = true} : (tensor<1x16384x257xf32>, tensor<1xi32>) -> tensor<1x1x257xf32>
}
71 changes: 71 additions & 1 deletion tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Support/TypeID.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
@@ -152,14 +153,83 @@ struct ConvertBatchMatMulOp2FullyConnectedOp
};
};

// Converts a batch_matmul operation whose LHS is an all-ones constant into a
// reduce_sum, so the ones tensor never has to be materialized.
struct ConvertBatchMatMulOpToReduceSum
: public OpRewritePattern<TFL::BatchMatMulOp> {
using OpRewritePattern<TFL::BatchMatMulOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::BatchMatMulOp bmm_op,
PatternRewriter& rewriter) const override {
// For simplicity, check whether the first operand is an all-ones constant,
// i.e. `ones_like`. This assumes canonicalization ordered the operands this
// way.
SplatElementsAttr constant;
if (!matchPattern(bmm_op.getX(), m_Constant(&constant))) {
return failure();
}

if (!SplatValueEquals(constant, 1.0)) {
return failure();
}

// The input tensors x and y are 2-D or higher with shapes:
//   [..., r_x == 1, c_x] and [..., c_y, r_y].
// The positions of r_* and c_* are determined by the polarity of the
// corresponding adj(X|Y) attribute: adjX == true indicates
// [..., c_x, r_x == 1].
llvm::ArrayRef<int64_t> lhs_shape =
bmm_op.getX().getType().cast<RankedTensorType>().getShape();
int rX = lhs_shape.size() - 2;
int cX = lhs_shape.size() - 1;
if (bmm_op.getAdjX()) {
rX = lhs_shape.size() - 1;
cX = lhs_shape.size() - 2;
}

if (lhs_shape[rX] != 1) {
return failure();
}
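// E.g. for the tests' lhs `tensor<1x1x16384xi32>` with adj_x == false:
// rX == 1 (the unit row dim) and cX == 2 (the 16384-long contraction dim),
// so the check above passes.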

llvm::ArrayRef<int64_t> rhs_shape =
bmm_op.getY().getType().cast<RankedTensorType>().getShape();
int rY = rhs_shape.size() - 1;
int cY = rhs_shape.size() - 2;
if (bmm_op.getAdjY()) {
rY = rhs_shape.size() - 2;
cY = rhs_shape.size() - 1;
}

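// cY indexes the contraction axis of y; for the rank-3 tests above this is
// axis 1, which becomes the `dense<1> : tensor<1xi32>` reduction-indices
// constant below.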
auto reduce_dim_op = rewriter.create<TFL::ConstOp>(
bmm_op->getLoc(),
DenseIntElementsAttr::get(
RankedTensorType::get({1}, rewriter.getI32Type()), {cY}));
auto sum_op = rewriter.create<TFL::SumOp>(
bmm_op->getLoc(), bmm_op.getType(), bmm_op.getY(), reduce_dim_op,
/*keep_dims=*/rewriter.getBoolAttr(true));
rewriter.replaceOp(bmm_op, sum_op);
return success();
};

private:
bool SplatValueEquals(SplatElementsAttr float_or_int, double rhs) const {
if (float_or_int.isa<DenseFPElementsAttr>()) {
return float_or_int.cast<DenseFPElementsAttr>()
.getSplatValue<APFloat>()
.isExactlyValue(rhs);
} else if (float_or_int.isa<DenseIntElementsAttr>()) {
return float_or_int.getSplatValue<APInt>() == static_cast<int>(rhs);
}
return false;
}
};

#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize_batch_matmul.inc"

void OptimizeBatchMatmulPass::runOnOperation() {
auto func = getOperation();
auto* ctx = &getContext();

RewritePatternSet patterns(ctx);
patterns.add<ConvertBatchMatMulOp2FullyConnectedOp>(ctx);
patterns.add<ConvertBatchMatMulOp2FullyConnectedOp,
ConvertBatchMatMulOpToReduceSum>(ctx);
TFL::populateWithGenerated(patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
