[mhlo] PadOp::fold: use RTTI to dispatch DenseElementAttr creation
Fixes wrong output caused by incorrect `mhlo.pad` canonicalization for complex numbers.

PiperOrigin-RevId: 449066943
atondwal authored and tensorflower-gardener committed May 16, 2022
1 parent 0ea23da commit 39ad008
Showing 2 changed files with 71 additions and 27 deletions.
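In short: the old fold built the result as generic Attributes (an llvm::SmallVector<Attribute> filled from getValues<Attribute>()), which produced wrong constants for complex element types. The new code templates the folding helper on the element value type and dispatches on the result's element type: APInt for integers, APFloat for floats, and std::complex<APFloat> for complex-of-float. A minimal illustration, mirroring the @pad_complex_fold test added below:

%0 = mhlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<1xcomplex<f32>>
%1 = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
%2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<1xcomplex<f32>>, tensor<complex<f32>>) -> tensor<2xcomplex<f32>>
// With this change, %2 folds to:
// mhlo.constant dense<[(2.000000e+00,0.000000e+00), (1.000000e+00,0.000000e+00)]> : tensor<2xcomplex<f32>>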
79 changes: 52 additions & 27 deletions tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -5191,31 +5191,15 @@ LogicalResult PadOp::inferReturnTypeComponents(
   return success();
 }
 
-OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
-  // If all padding is zero then it is an identity pad.
-  auto is_zero = [](const APInt& i) { return i == 0; };
-  if (llvm::all_of(edge_padding_low().getValues<APInt>(), is_zero) &&
-      llvm::all_of(edge_padding_high().getValues<APInt>(), is_zero) &&
-      llvm::all_of(interior_padding().getValues<APInt>(), is_zero))
-    return operand();
-
-  // If any padding is negative then it isn't supported by the folder (yet).
-  auto is_negative = [](const APInt& i) { return i.slt(0); };
-  if (llvm::any_of(edge_padding_low().getValues<APInt>(), is_negative) ||
-      llvm::any_of(edge_padding_high().getValues<APInt>(), is_negative) ||
-      llvm::any_of(interior_padding().getValues<APInt>(), is_negative))
-    return {};
-
-  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
-  DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
-  RankedTensorType return_type = getType().dyn_cast_or_null<RankedTensorType>();
-  if (!input || !input.getType().hasRank() || !padding || !return_type ||
-      !return_type.hasStaticShape())
-    return {};
-
+template <typename T>
+OpFoldResult PadOpFoldHelper(DenseElementsAttr input, DenseElementsAttr padding,
+                             RankedTensorType return_type,
+                             DenseIntElementsAttr edge_padding_low,
+                             DenseIntElementsAttr edge_padding_high,
+                             DenseIntElementsAttr interior_padding) {
   // Fill the full result tensor with the padding value.
-  llvm::SmallVector<Attribute, 4> result(return_type.getNumElements(),
-                                         padding.getValues<Attribute>()[0]);
+  llvm::SmallVector<T, 4> result(return_type.getNumElements(),
+                                 padding.getValues<T>()[0]);
 
   auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
                        llvm::ArrayRef<int64_t> shape) {
@@ -5235,17 +5219,58 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
     uint64_t idx_multiplyer = 1;
     for (int64_t i = index.size() - 1; i >= 0; --i) {
       result_idx +=
-          (edge_padding_low().getValues<int64_t>()[i] +
-           index[i] * (interior_padding().getValues<int64_t>()[i] + 1)) *
+          (edge_padding_low.getValues<int64_t>()[i] +
+           index[i] * (interior_padding.getValues<int64_t>()[i] + 1)) *
           idx_multiplyer;
       idx_multiplyer *= return_type.getDimSize(i);
     }
-    result[result_idx] = input.getValues<Attribute>()[index];
+    result[result_idx] = input.getValues<T>()[index];
     next_index(index, input.getType().getShape());
   }
   return DenseElementsAttr::get(return_type, result);
 }
 
+OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
+  // If all padding is zero then it is an identity pad.
+  auto is_zero = [](const APInt& i) { return i == 0; };
+  if (llvm::all_of(edge_padding_low().getValues<APInt>(), is_zero) &&
+      llvm::all_of(edge_padding_high().getValues<APInt>(), is_zero) &&
+      llvm::all_of(interior_padding().getValues<APInt>(), is_zero))
+    return operand();
+
+  // If any padding is negative then it isn't supported by the folder (yet).
+  auto is_negative = [](const APInt& i) { return i.slt(0); };
+  if (llvm::any_of(edge_padding_low().getValues<APInt>(), is_negative) ||
+      llvm::any_of(edge_padding_high().getValues<APInt>(), is_negative) ||
+      llvm::any_of(interior_padding().getValues<APInt>(), is_negative))
+    return {};
+
+  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  RankedTensorType return_type = getType().dyn_cast_or_null<RankedTensorType>();
+  if (!input || !input.getType().hasRank() || !padding || !return_type ||
+      !return_type.hasStaticShape())
+    return {};
+
+  if (return_type.getElementType().isa<IntegerType>())
+    return PadOpFoldHelper<APInt>(input, padding, return_type,
+                                  edge_padding_low(), edge_padding_high(),
+                                  interior_padding());
+  if (return_type.getElementType().isa<FloatType>())
+    return PadOpFoldHelper<APFloat>(input, padding, return_type,
+                                    edge_padding_low(), edge_padding_high(),
+                                    interior_padding());
+  if (ComplexType complex =
+          return_type.getElementType().dyn_cast_or_null<ComplexType>()) {
+    // TODO(atondwal): Allow int types in HLO_complex
+    if (complex.getElementType().isa<FloatType>())
+      return PadOpFoldHelper<std::complex<APFloat>>(
+          input, padding, return_type, edge_padding_low(), edge_padding_high(),
+          interior_padding());
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // DynamicPadOp
 //===----------------------------------------------------------------------===//
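The second changed file (19 additions and no deletions, per the totals above) extends the MHLO canonicalization FileCheck tests with float and complex pad-fold coverage. As a sketch only — the file's actual RUN line is not shown in this diff — such tests are typically driven by something like:

// RUN: mlir-hlo-opt %s -pass-pipeline='func.func(canonicalize)' | FileCheck %s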
@@ -2026,6 +2026,7 @@ func.func @pad_negative_fold() -> tensor<4x4xi32> {
// CHECK: "mhlo.pad"
}

// CHECK-LABEL: @pad_fold_zero_elements
func.func @pad_fold_zero_elements() -> tensor<3xi32> {
%0 = mhlo.constant dense<> : tensor<0xi32>
%1 = mhlo.constant dense<7> : tensor<i32>
@@ -2034,6 +2035,24 @@ func.func @pad_fold_zero_elements() -> tensor<3xi32> {
   // CHECK: mhlo.constant dense<7> : tensor<3xi32>
 }
 
+// CHECK-LABEL: @pad_float_fold
+func.func @pad_float_fold() -> tensor<2xf32> {
+  %0 = mhlo.constant dense<2.000000e+00> : tensor<1xf32>
+  %1 = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  %2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<1xf32>, tensor<f32>) -> tensor<2xf32>
+  return %2 : tensor<2xf32>
+  // CHECK: mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32>
+}
+
+// CHECK-LABEL: @pad_complex_fold
+func.func @pad_complex_fold() -> tensor<2xcomplex<f32>> {
+  %0 = mhlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<1xcomplex<f32>>
+  %1 = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
+  %2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<1xcomplex<f32>>, tensor<complex<f32>>) -> tensor<2xcomplex<f32>>
+  return %2 : tensor<2xcomplex<f32>>
+  // CHECK: mhlo.constant dense<[(2.000000e+00,0.000000e+00), (1.000000e+00,0.000000e+00)]> : tensor<2xcomplex<f32>>
+}
+
 // CHECK-LABEL: @identity_broadcast_reshape
 func.func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
   %0 = "mhlo.broadcast"(%arg0) {
