[MLIR] Add constant folder for xla_hlo.broadcast_in_dim op #40745

Closed
26 changes: 17 additions & 9 deletions tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -534,17 +534,25 @@ static LogicalResult Verify(BroadcastInDimOp op) {
   return success();
 }
 
-OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute>) {
+OpFoldResult BroadcastInDimOp::fold(ArrayRef<Attribute> attrs) {
   auto type = getType().cast<RankedTensorType>();
-  if (type != getOperand().getType()) {
-    return nullptr;
-  }
-  auto broadcast_values = broadcast_dimensions().getValues<int64_t>();
-  if (!std::equal(broadcast_values.begin(), broadcast_values.end(),
-                  llvm::seq<int64_t>(0, type.getRank()).begin())) {
-    return nullptr;
+  if (type == getOperand().getType()) {
+    auto broadcast_values = broadcast_dimensions().getValues<int64_t>();
+    if (!std::equal(broadcast_values.begin(), broadcast_values.end(),
+                    llvm::seq<int64_t>(0, type.getRank()).begin())) {
+      return {};
+    }
+    return getOperand();
   }
-  return getOperand();
+
+  // Constant fold when an operand is a splat tensor attribute.
+  if (!attrs[0] || !type.hasStaticShape()) return {};
+  auto splatOperandAttr = attrs[0].dyn_cast<SplatElementsAttr>();
+  if (!splatOperandAttr) return {};
+  // MLIR core bug (https://bugs.llvm.org/show_bug.cgi?id=46588): dense element
+  // attribute iterator not implemented for complex element types.
+  if (type.getElementType().isa<ComplexType>()) return {};
+  return SplatElementsAttr::get(type, splatOperandAttr.getSplatValue());
 }
 
 //===----------------------------------------------------------------------===//
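For reference, a minimal IR-level sketch of what the new splat path does (the @scalar_broadcast_fold function below is illustrative only and mirrors the tests added to canonicalize.mlir further down): broadcasting a splat constant of static shape folds into a splat constant of the result type, so the broadcast op disappears once the canonicalizer runs.

func @scalar_broadcast_fold() -> tensor<1x64x224x224xf32> {
  // Splat (here 0-d scalar) constant broadcast to a large static shape.
  %cst = mhlo.constant dense<0.000000e+00> : tensor<f32>
  %b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x64x224x224xf32>
  return %b : tensor<1x64x224x224xf32>
}
// After canonicalization the body reduces to:
//   %0 = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32>
//   return %0 : tensor<1x64x224x224xf32>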
19 changes: 18 additions & 1 deletion tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
@@ -356,7 +356,6 @@ func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tenso
   return %0 : tensor<2x2xf32>
 }
 
-
 // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
 func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
   // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
@@ -365,6 +364,24 @@ func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %ar
   return %0 : tensor<5x4xf32>
 }
 
+// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d
+func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> {
+  %cst = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x64x224x224xf32>
+  return %b : tensor<1x64x224x224xf32>
+}
+// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32>
+// CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xf32>
+
+// CHECK-LABEL: func @broadcast_in_dim_constant_fold
+func @broadcast_in_dim_constant_fold() -> tensor<1x64x4x4xf32> {
+  %cst = mhlo.constant dense<0.000000e+00> : tensor<4x4xf32>
+  %b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32>
+  return %b : tensor<1x64x4x4xf32>
+}
+// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x4x4xf32>
+// CHECK-NEXT: return %[[CST]] : tensor<1x64x4x4xf32>
+
 // CHECK-LABEL: @complex_expand_fold
 func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
   %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)

Review discussion on the new @broadcast_in_dim_constant_fold_0d test:

Reviewer (@stellaraccident):

This one particularly is more than a splat: it is a splat into a 0d scalar tensor type (the splat is somewhat incidental?). By having this as a folding rule, I lose the ability to represent a scalar-constant broadcast in HLO (since it will always be folded to fully materialized). Having that structure (not unconditionally materializing scalar broadcasts) is fairly important from an analysis on tensors perspective as having this knowledge enables various size-sensitive optimizations.

I agree that this is a valid folding from a correctness perspective, but it limits what we can express in mhlo (since this is the current idiom for representing a scalar broadcast and is used quite extensively).

I'm struggling with an "HLO principles" reason why this folding should be disallowed and the following allowed, though. This may actually be pointing at a missing op (i.e. mhlo.broadcast_scalar) if this case is important to represent at this level.

@bondhugula (contributor, author) commented on Jul 8, 2020:

> This one particularly is more than a splat: it is a splat into a 0d scalar tensor type (the splat is somewhat incidental?). By having this as a folding rule, I lose the ability to represent a scalar-constant broadcast in

@stellaraccident: this one is a splat! :-) (a trivial splat). Please see lines 550-551 in hlo_ops.cc where I'm explicitly checking for splat attributes. For the rest that you mention below, please see my post below on why having one form is important.

> HLO (since it will always be folded to fully materialized). Having that structure (not unconditionally materializing scalar broadcasts) is fairly important from an analysis on tensors perspective as having this knowledge enables various size-sensitive optimizations.
>
> I agree that this is a valid folding from a correctness perspective, but it limits what we can express in mhlo (since this is the current idiom for representing a scalar broadcast and is used quite extensively).
>
> I'm struggling with an "HLO principles" reason why this folding should be disallowed and the following allowed, though. This may actually be pointing at a missing op (i.e. mhlo.broadcast_scalar) if this case is important to represent at this level.

Reviewer (@stellaraccident):

Thanks for bearing with me (and side note: I'm sorry to have triggered another round of review on such a simple PR -- I know how frustrating that can be).

Now re-reading this, I realize I was missing the syntactic distinction of the splat being preserved in this transformation (while more verbose, I did appreciate that we used to actually spell out "splat"). I had gotten used to keying off of the 0d-ness of the tensor as the distinguishing characteristic while treating the originating SplatElementsAttr as incidental. So you are right: this preserves the same information in the attribute type of the constant, and as long as the splat is handled correctly all the way down, this can be fine (modulo bugs).

Having found this challenging in the past, the IREE team is a bit wired to see unnecessary constant expansions as problematic, and that bias certainly kept me from seeing this alternative simplification.
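To make the preserved-splat point above concrete, here is a small hedged sketch (the @splat_vs_dense function is illustrative, not part of the PR): the folded constant is still backed by a splat attribute, so the IR records a single scalar value regardless of the result tensor's element count, whereas an ordinary non-splat constant has to enumerate every element.

func @splat_vs_dense() -> (tensor<1x64x224x224xf32>, tensor<2xf32>) {
  // Result of the new fold: still a splat; one scalar stored and printed.
  %folded = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32>
  // A genuinely materialized (non-splat) constant spells out its elements.
  %dense = mhlo.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>
  return %folded, %dense : tensor<1x64x224x224xf32>, tensor<2xf32>
}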
7 changes: 3 additions & 4 deletions tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir
@@ -10,8 +10,7 @@ func @batchNormInference_2D_inner_features(
     %x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
     %mean: tensor<256xf32>, %variance: tensor<256xf32>)
     -> (tensor<4x256xf32>) {
-  // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor<f32>
-  // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
+  // CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32>
   // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
   // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
   // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
@@ -51,7 +50,7 @@ func @batchNormInference_4D_middle_features(
 // -----
 // CHECK-LABEL: @batchNormInference_f64
 // Validate that epsilon is properly promoted to f64
-// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
+// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf64>
 func @batchNormInference_f64(
     %x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>,
     %mean: tensor<256xf64>, %variance: tensor<256xf64>)
@@ -66,7 +65,7 @@ func @batchNormInference_f64(
 // -----
 // CHECK-LABEL: @batchNormInference_f16
 // Validate that epsilon is properly promoted to f64
-// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f16>
+// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf16>
 func @batchNormInference_f16(
     %x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
     %mean: tensor<256xf16>, %variance: tensor<256xf16>)