Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reopen: Generalize MinMax monotonic optimizer #25584

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow/core/grappler/op_types.cc
Expand Up @@ -47,16 +47,30 @@ bool IsAnyDiv(const NodeDef& node) {
node.op() == "FloorDiv" || node.op() == "TruncateDiv";
}

bool IsAnyMax(const NodeDef& node) {
const auto& op = node.op();
return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
}

bool IsAnyMaxPool(const NodeDef& node) {
const auto& op = node.op();
return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" ||
op == "MaxPoolWithArgmax" || op == "FractionalMaxPool";
}

bool IsAnyMin(const NodeDef& node) {
const auto& op = node.op();
return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
}

bool IsApproximateEqual(const NodeDef& node) {
return node.op() == "ApproximateEqual";
}

// Returns true iff `node` is an ArgMax op.
bool IsArgMax(const NodeDef& node) {
  return node.op() == "ArgMax";
}

// Returns true iff `node` is an ArgMin op.
bool IsArgMin(const NodeDef& node) {
  return node.op() == "ArgMin";
}

// Returns true iff `node` is an AvgPoolGrad op.
bool IsAvgPoolGrad(const NodeDef& node) {
  return node.op() == "AvgPoolGrad";
}

bool IsAssign(const NodeDef& node) {
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/grappler/op_types.h
Expand Up @@ -28,8 +28,12 @@ bool IsAll(const NodeDef& node);
// Predicates classifying a NodeDef by its op name. Each returns true iff
// node.op() matches the named op or op family; definitions in op_types.cc.
bool IsAngle(const NodeDef& node);
bool IsAny(const NodeDef& node);
// Division-family ops (includes FloorDiv and TruncateDiv, among others).
bool IsAnyDiv(const NodeDef& node);
// Max, SegmentMax, or UnsortedSegmentMax.
bool IsAnyMax(const NodeDef& node);
// Any max-pooling op: MaxPool, MaxPoolV2, MaxPool3D, MaxPoolWithArgmax,
// or FractionalMaxPool.
bool IsAnyMaxPool(const NodeDef& node);
// Min, SegmentMin, or UnsortedSegmentMin.
bool IsAnyMin(const NodeDef& node);
bool IsApproximateEqual(const NodeDef& node);
bool IsArgMax(const NodeDef& node);
bool IsArgMin(const NodeDef& node);
bool IsAvgPoolGrad(const NodeDef& node);
bool IsAssert(const NodeDef& node);
bool IsAssign(const NodeDef& node);
Expand Down
22 changes: 20 additions & 2 deletions tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
Expand Up @@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/core/util/strided_slice_op.h"

using tensorflow::strings::StrCat;
using tensorflow::str_util::StringReplace;

namespace tensorflow {
namespace grappler {
Expand Down Expand Up @@ -2721,7 +2722,8 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
~OptimizeMaxOrMinOfMonotonicStage() override = default;

bool IsSupported(const NodeDef* node) const override {
return IsMax(*node) || IsMin(*node) || IsAnyMaxPool(*node);
return IsAnyMax(*node) || IsAnyMin(*node) || IsAnyMaxPool(*node) ||
IsArgMax(*node) || IsArgMin(*node);
}

Status TrySimplify(NodeDef* reduction_node,
Expand Down Expand Up @@ -2755,9 +2757,15 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
if (!is_non_decreasing) {
// Flip Min<->Max if the function is non-increasing, e.g.
// Max(Neg(x)) = Neg(Min(x)).
const string opposite = IsMax(*reduction_node) ? "Min" : "Max";
const string opposite = FlipMinMax(*reduction_node);
reduction_node->set_op(opposite);
}

if (IsArgMax(*reduction_node) || IsArgMin(*reduction_node)) {
// ArgMax(Sqrt(x)) = ArgMax(x)
inner_function->set_op("Identity");
}

AddToOptimizationQueue(reduction_node);
AddToOptimizationQueue(inner_function);
AddToOptimizationQueue(inner_input);
Expand All @@ -2778,6 +2786,16 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
AddToOptimizationQueue(consumer);
}
}

private:
// Returns `node`'s op name with its Min/Max direction flipped, e.g.
// "Max" -> "Min", "ArgMin" -> "ArgMax", "SegmentMax" -> "SegmentMin".
// Used when a monotonically non-increasing inner function is hoisted out
// of an extremum op, since e.g. Max(Neg(x)) == Neg(Min(x)).
// NOTE(review): a pooling op name such as "MaxPool" would be mapped to a
// nonexistent "MinPool" op — this assumes callers never flip pooling
// nodes; confirm the guard in TrySimplify.
// Marked static: the function reads no instance state.
static string FlipMinMax(const NodeDef& node) {
  const string& op = node.op();
  // Replace only the first occurrence (replace_all == false); every
  // supported op name contains the substring exactly once. StringReplace
  // is unqualified via the file's using-declaration for it.
  if (IsAnyMax(node) || IsArgMax(node)) {
    return StringReplace(op, "Max", "Min", /*replace_all=*/false);
  }
  return StringReplace(op, "Min", "Max", /*replace_all=*/false);
}
};

// Replace a chain of type&shape preserving unary ops with a
Expand Down
41 changes: 41 additions & 0 deletions tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
Expand Up @@ -3466,6 +3466,47 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
EXPECT_EQ(2, required_node_count);
}

TEST_F(ArithmeticOptimizerTest, OptimizeArgMaxOrArgMinOfMonotonicElementWise) {
  // Sqrt is monotonically non-decreasing, so ArgMax(Sqrt(x)) == ArgMax(x).
  // The optimizer rewrites the inner Sqrt to Identity; pruning then rewires
  // ArgMax to read `x` directly and removes one node from the graph.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const auto input = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_node = ops::Sqrt(scope.WithOpName("sqrt"), input);
  Output arg_max_node = ops::ArgMax(scope.WithOpName("arg_max"), sqrt_node, 1);
  Output fetch_node = ops::Identity(scope.WithOpName("final_out"), arg_max_node);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(1, tensors_expected.size());

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(1, tensors.size());

  // Numerically identical result with exactly one node pruned (the Sqrt).
  test::ExpectTensorEqual<int64>(tensors_expected[0], tensors[0]);
  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());

  // Check if the inputs are switched: arg_max must now consume `x`.
  int required_node_count = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "final_out") {
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("arg_max", node.input(0));
      ++required_node_count;
    } else if (node.name() == "arg_max") {
      EXPECT_EQ("ArgMax", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("x", node.input(0));
      ++required_node_count;
    }
  }
  EXPECT_EQ(2, required_node_count);
}

TEST_F(ArithmeticOptimizerTest,
OptimizeMaxOrMinOfMonotonicElementWise_DoNotChangeFetchNode) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Expand Down