Lower ReluGrad to HLO.
PiperOrigin-RevId: 273772952
tensorflower-gardener committed Oct 9, 2019
1 parent a68cb21 commit 720a3a1
Showing 5 changed files with 74 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -3311,6 +3311,24 @@ def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> {
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_ReluGradOp : TF_Op<"ReluGrad", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Computes rectified linear gradients for a Relu operation.";

  let description = [{
  }];

  let arguments = (ins
    TF_IntOrFpTensor:$gradients,
    TF_IntOrFpTensor:$features
  );

  let results = (outs
    TF_IntOrFpTensor:$backprops
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
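For intuition, ReluGrad passes each incoming gradient through wherever the corresponding forward-pass input was positive and zeroes it elsewhere. A minimal elementwise sketch in C++ (illustrative only, not the actual TensorFlow kernel; ReluGradRef is a hypothetical name):

#include <vector>

// backprops[i] = gradients[i] where features[i] > 0, and 0 otherwise.
// The real op accepts any integer or floating-point tensor (TF_IntOrFpTensor).
std::vector<float> ReluGradRef(const std::vector<float>& gradients,
                               const std::vector<float>& features) {
  std::vector<float> backprops(gradients.size());
  for (size_t i = 0; i < gradients.size(); ++i) {
    backprops[i] = features[i] > 0.0f ? gradients[i] : 0.0f;
  }
  return backprops;
}

For example, gradients = {2, -3, 5} with features = {1, -1, 0} yields backprops = {2, 0, 0}; a feature of exactly 0 does not pass its gradient through.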

def TF_ReshapeOp : TF_Op<"Reshape", [NoSideEffect]> {
let summary = "Reshapes a tensor.";

26 changes: 26 additions & 0 deletions tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -702,5 +702,31 @@ static LogicalResult Verify(TransposeOp op) {
  return success();
}

//===----------------------------------------------------------------------===//
// CompareOp
//===----------------------------------------------------------------------===//

void CompareOp::build(Builder* builder, OperationState& result, Value* lhs,
                      Value* rhs, DenseIntElementsAttr broadcast_dimensions,
                      StringAttr comparison_direction) {
  build(builder, result,
        InferOutputTypes(builder, lhs, rhs, broadcast_dimensions,
                         comparison_direction),
        lhs, rhs, broadcast_dimensions, comparison_direction);
}

Type CompareOp::InferOutputTypes(Builder* builder, Value* lhs, Value* rhs,
                                 DenseIntElementsAttr broadcast_dimensions,
                                 StringAttr comparison_direction) {
  if (!lhs->getType().isa<ShapedType>() || !rhs->getType().isa<ShapedType>())
    return builder->getTensorType(builder->getI1Type());
  // TODO(parkers): When binary ops support broadcasting shape inference, reuse
  // that logic.
  auto lhs_type = lhs->getType().cast<ShapedType>();
  auto rhs_type = rhs->getType().cast<ShapedType>();
  if (lhs_type != rhs_type) return builder->getTensorType(builder->getI1Type());
  return builder->getTensorType(lhs_type.getShape(), builder->getI1Type());
}
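With this builder, callers no longer spell out the i1 result type by hand. A hypothetical usage sketch, assuming an OpBuilder b, a Location loc, and two operands of identical ranked tensor type (EmitGreaterThan is an illustrative helper, not part of this commit):

// Emits an xla_hlo.compare whose tensor<...xi1> result type is inferred
// by the new build() overload via InferOutputTypes.
static Value* EmitGreaterThan(OpBuilder& b, Location loc, Value* lhs,
                              Value* rhs) {
  auto cmp = b.create<xla_hlo::CompareOp>(
      loc, lhs, rhs,
      /*broadcast_dimensions=*/DenseIntElementsAttr(),
      /*comparison_direction=*/b.getStringAttr("GT"));
  return cmp.getResult();
}

Passing a null DenseIntElementsAttr mirrors the (NullDenseIntElementsAttr) operand used by the ReluGrad pattern below.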

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.cc.inc"
12 changes: 12 additions & 0 deletions tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -278,6 +278,18 @@ def HLO_CompareOp: HLO_Op<"compare",
    HLO_ComparisonDirectionAttr:$comparison_direction
  );
  let results = (outs HLO_PredTensor);

  let builders = [OpBuilder<
    "Builder *builder, OperationState &result, Value *lhs, Value *rhs, "
    "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
  >];

  let extraClassDeclaration = [{
    // Infers output type for given operand and attributes.
    static Type InferOutputTypes(Builder *builder, Value *lhs, Value *rhs,
                                 DenseIntElementsAttr broadcast_dimensions,
                                 StringAttr comparison_direction);
  }];
}

//===----------------------------------------------------------------------===//
11 changes: 11 additions & 0 deletions tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -496,6 +496,17 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @relu_grad
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<4x8xf32>)
func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
  // CHECK: %[[PRED:.*]] = "xla_hlo.compare"(%[[FEATURES]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xi1>
  // CHECK: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<4x8xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
  // CHECK: return %[[RESULT]] : tensor<4x8xf32>
  %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
  return %2 : tensor<4x8xf32>
}
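The CHECK lines trace exactly the ops the rewrite pattern below produces: a splatted zero constant, a GT compare of the features against zero, and a select that passes each gradient through where the predicate holds.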

//===----------------------------------------------------------------------===//
// Select op legalizations.
//===----------------------------------------------------------------------===//
7 changes: 7 additions & 0 deletions tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -190,6 +190,13 @@ def : Pat<(TF_Relu6Op AnyStaticShapeTensor:$input),
          (HLO_ClampOp (HLO_ConstOp (ConstantSplat<"0"> $input)), $input,
                       (HLO_ConstOp (ConstantSplat<"6"> $input)))>;

// ReluGrad(gradients, features) = gradients * (features > 0)
def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyStaticShapeTensor:$features),
          (HLO_SelectOp
            (HLO_CompareOp $features, (HLO_ConstOp:$zero (ConstantSplat<"0"> $features)),
                           (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT),
            $gradients, $zero)>;
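Note the HLO_ConstOp:$zero binding: the zero splat is materialized once and reused both as the compare's right-hand operand and as the select's false value, matching the IR checked in legalize-tf.mlir above.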

//===----------------------------------------------------------------------===//
// Relu op patterns.
//===----------------------------------------------------------------------===//
