Add algorithmic optimizer to convert Log(Softmax(x)) to LogSoftmax(x) #25455

Merged 1 commit on Mar 7, 2019.
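Summary of the change: this PR teaches Grappler's ArithmeticOptimizer to rewrite Log(Softmax(x)) into the single fused LogSoftmax(x) op. It adds an IsSoftmax predicate to op_types, a new LogSoftmaxStage gated behind a convert_log_softmax option (enabled by default), and two unit tests: one covering the rewrite itself and one verifying that no rewrite happens when the Softmax output is a fetch node and must be preserved.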
2 changes: 2 additions & 0 deletions tensorflow/core/grappler/op_types.cc
@@ -447,6 +447,8 @@ bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }

bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }

bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }

bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; }

bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; }
1 change: 1 addition & 0 deletions tensorflow/core/grappler/op_types.h
@@ -148,6 +148,7 @@ bool IsShapeN(const NodeDef& node);
bool IsShuffle(const NodeDef& node);
bool IsSigmoidGrad(const NodeDef& node);
bool IsSnapshot(const NodeDef& node);
bool IsSoftmax(const NodeDef& node);
bool IsSoftplusGrad(const NodeDef& node);
bool IsSoftsignGrad(const NodeDef& node);
bool IsSplit(const NodeDef& node);
32 changes: 32 additions & 0 deletions tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1804,6 +1804,36 @@ class FuseSquaredDiffStage : public ArithmeticOptimizerStage {
}
};

// Performs the conversion:
// Log(Softmax(x)) => LogSoftmax(x)
class LogSoftmaxStage : public ArithmeticOptimizerStage {
public:
explicit LogSoftmaxStage(const GraphOptimizerContext& ctx,
const ArithmeticOptimizerContext& ctx_ext)
: ArithmeticOptimizerStage("LogSoftmaxStage", ctx, ctx_ext) {}
~LogSoftmaxStage() override = default;

bool IsSupported(const NodeDef* node) const override {
return IsLog(*node);
}

Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
NodeDef* x;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
// Optimize only if arg is a Softmax whose output is not being consumed
// elsewhere.
if (IsSoftmax(*x) && !IsInPreserveSet(*x) &&
(NumNonControlOutputs(*x, *ctx().node_map) == 1)) {
// Log(Softmax(x)) => LogSoftmax(Identity(x))
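// Note: the Softmax is rewritten to an Identity rather than removed so
// that the node name and any dependencies on it stay valid; a later
// pruning pass can then strip the Identity (see the node-count check in
// the ConvertLogSoftmax test).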
node->set_op("LogSoftmax");
x->set_op("Identity");
AddToOptimizationQueue(node);
AddToOptimizationQueue(x);
}
return Status::OK();
}
};

// Bypass redundant reshape nodes:
//
// Reshape Reshape <-+
@@ -3552,6 +3582,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
if (options_.convert_log1p)
pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
if (options_.convert_log_softmax)
pipeline.AddStage<LogSoftmaxStage>(ctx, ctx_ext);
if (options_.optimize_max_or_min_of_monotonic)
pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
if (options_.convert_expm1)
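Aside (not part of this PR): beyond saving a kernel launch, the usual motivation for fusing these two ops is numerical stability. Composing Log with Softmax exponentiates the raw logits, which overflows for large values, whereas a fused log-softmax can use the identity log(softmax(x)) = x - max(x) - log(sum(exp(x - max(x)))). A minimal standalone C++ sketch of the difference:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> x = {1000.0, 0.0, -1000.0};

  // Naive: softmax first, then log. exp(1000) overflows to inf, so the
  // normalized probabilities degenerate and the log yields -inf/nan.
  double sum = 0.0;
  for (double v : x) sum += std::exp(v);
  for (double v : x) std::printf("naive: %g\n", std::log(std::exp(v) / sum));

  // Fused: shift by the max logit before exponentiating, so every
  // exponent is <= 0 and nothing overflows. Prints 0, -1000, -2000.
  double mx = *std::max_element(x.begin(), x.end());
  double lse = 0.0;
  for (double v : x) lse += std::exp(v - mx);
  lse = mx + std::log(lse);
  for (double v : x) std::printf("fused: %g\n", v - lse);
  return 0;
}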
1 change: 1 addition & 0 deletions tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -79,6 +79,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool simplify_aggregation = true;
bool convert_pow = true;
bool convert_log1p = true;
bool convert_log_softmax = true;
bool convert_expm1 = true;
bool unary_ops_composition = true;
bool remove_stack_strided_slice_same_axis = true;
61 changes: 61 additions & 0 deletions tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2556,6 +2556,67 @@ TEST_F(ArithmeticOptimizerTest, DoNotFuseSquaredDiffFetchNode) {
}
}

TEST_F(ArithmeticOptimizerTest, ConvertLogSoftmax) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
Output softmax = ops::Softmax(s.WithOpName("softmax"), x);
Output logsoftmax = ops::Log(s.WithOpName("output"), softmax);

GrapplerItem item;
item.fetch = {"output"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());

GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyLogSoftmax(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
const auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());

test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
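// One fewer node than the input graph: the Softmax was rewritten to an
// Identity, which OptimizeAndPrune then strips, leaving LogSoftmax to
// read directly from "x".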
EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
if (node.name() == "output") {
EXPECT_EQ("LogSoftmax", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("x", node.input(0));
}
}
}

TEST_F(ArithmeticOptimizerTest, DoNotConvertLogSoftmaxArgFetchNode) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output floats = ops::Const(s.WithOpName("floats"),
{0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
Output softmax = ops::Softmax(s.WithOpName("softmax"), floats);
Output final_output = ops::Log(s.WithOpName("final_output"), softmax);

GrapplerItem item;
item.fetch = {"softmax", "final_output"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
ASSERT_EQ(2, tensors_expected.size());

GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyLogSoftmax(&optimizer);
OptimizeTwice(&optimizer, &item, &output);
const auto tensors = EvaluateNodes(output, item.fetch);
ASSERT_EQ(2, tensors.size());

// Should be a NoOp since we are not allowed to change the output of fetch
// nodes.
VerifyGraphsMatch(item.graph, output, __LINE__);

for (int i = 0; i < tensors.size(); i++) {
EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
}
}

TEST_F(ArithmeticOptimizerTest, ConvertPow) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
@@ -183,6 +183,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
}

void EnableOnlyLogSoftmax(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.convert_log_softmax = true;
}

void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.convert_pow = true;