Skip to content

Commit

Permalink
New implementation of xla_lhlo.abs/negate without modifying the MLIR standard dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
qqsun8819 committed Apr 10, 2020
1 parent 035e216 commit 846e2f6
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 4 deletions.
33 changes: 33 additions & 0 deletions tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,23 @@ func @abs(%input: memref<2x2xf32>,

// -----

// CHECK-LABEL: func @abs
func @abs(%input: memref<2x2xi32>,
          %result: memref<2x2xi32>) {
  "xla_lhlo.abs"(%input, %result)
      : (memref<2x2xi32>, memref<2x2xi32>) -> ()
  return
}

// Integer abs lowers to: select(x >= 0, x, 0 - x).
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT:   %[[L0:.*]] = constant 0 : i32
// CHECK-NEXT:   %[[L1:.*]] = cmpi "sge", %[[OPERAND_IN]], %[[L0]] : i32
// CHECK-NEXT:   %[[L2:.*]] = subi %[[L0]], %[[OPERAND_IN]] : i32
// CHECK-NEXT:   %[[RESULT:.*]] = select %[[L1]], %[[OPERAND_IN]], %[[L2]] : i32
// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32

// -----

// CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>,
%result: memref<2x2xf32>) {
Expand Down Expand Up @@ -401,6 +418,22 @@ func @neg(%input: memref<2x2xf32>,

// -----

// CHECK-LABEL: func @neg
func @neg(%operand: memref<2x2xi32>,
          %out: memref<2x2xi32>) {
  "xla_lhlo.negate"(%operand, %out)
      : (memref<2x2xi32>, memref<2x2xi32>) -> ()
  return
}

// Integer negate lowers to a subtraction from zero.
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[IN:.*]]: i32, %[[OUT:.*]]):
// CHECK-NEXT:   %[[ZERO:.*]] = constant 0 : i32
// CHECK-NEXT:   %[[NEG:.*]] = subi %[[ZERO]], %[[IN]] : i32
// CHECK-NEXT:   linalg.yield %[[NEG]] : i32

// -----

// CHECK-LABEL: func @rem
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
Expand Down
37 changes: 33 additions & 4 deletions tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,23 @@ template <>
// Lowers xla_lhlo.abs to standard-dialect scalar ops.
// Floats map directly to AbsFOp; integers, which have no std abs op, are
// emitted as: select(x >= 0, x, 0 - x).  Returns nullptr for element types
// that are neither float nor integer, signalling an unsupported lowering.
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  Type element_type = args.front().getType();
  if (element_type.isa<FloatType>()) {
    return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
        loc, result_types, args, b);
  }
  if (element_type.isa<IntegerType>()) {
    // Already checked with isa<> above, so a plain cast<> is safe here.
    auto integer_type = element_type.cast<IntegerType>();
    const auto& lhs = args[0];
    auto zero_intval =
        b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
    // abs(x) = x >= 0 ? x : 0 - x.  NOTE(review): for the minimum signed
    // value this wraps (0 - INT_MIN == INT_MIN), matching two's-complement
    // behavior of subi — confirm this is the intended semantics.
    auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(
        loc, CmpIPredicate::sge, lhs, zero_intval);
    auto neg_val =
        b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
    return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val);
  }
  return nullptr;
}

template <>
Expand Down Expand Up @@ -326,8 +341,22 @@ template <>
// Lowers xla_lhlo.negate to standard-dialect scalar ops.
// Floats map directly to NegFOp; integers, which have no std negate op, are
// emitted as 0 - x.  Returns nullptr for element types that are neither
// float nor integer, signalling an unsupported lowering.
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  Type element_type = args.front().getType();
  if (element_type.isa<FloatType>()) {
    return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
        loc, result_types, args, b);
  }
  if (element_type.isa<IntegerType>()) {
    // Already checked with isa<> above, so a plain cast<> is safe here.
    auto integer_type = element_type.cast<IntegerType>();
    const auto& lhs = args[0];
    auto zero_intval =
        b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
    // neg(x) = 0 - x, with two's-complement wraparound for INT_MIN.
    return b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
  }
  return nullptr;
}

template <>
Expand Down

0 comments on commit 846e2f6

Please sign in to comment.