From 597f02cac4ff857aafde6ff8cc73436d110571e5 Mon Sep 17 00:00:00 2001
From: Lukas Geiger
Date: Sat, 2 Feb 2019 10:49:56 +0100
Subject: [PATCH] Add algorithmic optimizer to convert Log(Softmax(x)) to
 LogSoftmax(x)

This PR adds an algorithmic optimizer stage which converts `Log(Softmax(x))`
to `LogSoftmax(x)`.
[`LogSoftmax`](https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/log-softmax)
is numerically more stable, since it avoids taking the log of a softmax value
that can underflow to zero, and it may also be slightly faster in some cases
(a short illustrative sketch follows the diff below).

This could be expanded in the future to also optimize
`log(softmax(x) * y) = logsoftmax(x) + log(y)` and
`log(softmax(x) / y) = logsoftmax(x) - log(y)`.
---
 tensorflow/core/grappler/op_types.cc          |  2 +
 tensorflow/core/grappler/op_types.h           |  1 +
 .../optimizers/arithmetic_optimizer.cc        | 32 ++++++++++
 .../optimizers/arithmetic_optimizer.h         |  1 +
 .../optimizers/arithmetic_optimizer_test.cc   | 61 +++++++++++++++++++
 .../arithmetic_optimizer_test_utils.h         |  5 ++
 6 files changed, 102 insertions(+)

diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 22c81f350ff76c..7dd15de279089d 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -447,6 +447,8 @@ bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
 
 bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }
 
+bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }
+
 bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; }
 
 bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 6dccdba16189fa..e9df1cc0114361 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -148,6 +148,7 @@ bool IsShapeN(const NodeDef& node);
 bool IsShuffle(const NodeDef& node);
 bool IsSigmoidGrad(const NodeDef& node);
 bool IsSnapshot(const NodeDef& node);
+bool IsSoftmax(const NodeDef& node);
 bool IsSoftplusGrad(const NodeDef& node);
 bool IsSoftsignGrad(const NodeDef& node);
 bool IsSplit(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 442b713bd09fe4..f754cfd67604c9 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1804,6 +1804,36 @@ class FuseSquaredDiffStage : public ArithmeticOptimizerStage {
   }
 };
 
+// Performs the conversion:
+// Log(Softmax(x)) => LogSoftmax(x)
+class LogSoftmaxStage : public ArithmeticOptimizerStage {
+ public:
+  explicit LogSoftmaxStage(const GraphOptimizerContext& ctx,
+                           const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("LogSoftmaxStage", ctx, ctx_ext) {}
+  ~LogSoftmaxStage() override = default;
+
+  bool IsSupported(const NodeDef* node) const override {
+    return IsLog(*node);
+  }
+
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    NodeDef* x;
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
+    // Optimize only if arg is a Softmax whose output is not being consumed
+    // elsewhere.
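+    // Rewriting the pair in place would otherwise change the value seen by
+    // those other consumers (or by fetch nodes in the preserve set).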
+    if (IsSoftmax(*x) && !IsInPreserveSet(*x) &&
+        (NumNonControlOutputs(*x, *ctx().node_map) == 1)) {
+      // Log(Softmax(x)) => LogSoftmax(Identity(x))
+      node->set_op("LogSoftmax");
+      x->set_op("Identity");
+      AddToOptimizationQueue(node);
+      AddToOptimizationQueue(x);
+    }
+    return Status::OK();
+  }
+};
+
 // Bypass redundant reshape nodes:
 //
 //   Reshape                    Reshape  <-+
@@ -3552,6 +3582,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
   if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
   if (options_.convert_log1p) pipeline.AddStage<Log1pStage>(ctx, ctx_ext);
+  if (options_.convert_log_softmax)
+    pipeline.AddStage<LogSoftmaxStage>(ctx, ctx_ext);
   if (options_.optimize_max_or_min_of_monotonic)
     pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
   if (options_.convert_expm1)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 4535b27ae6b2c2..0330480db3ca3d 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -79,6 +79,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
     bool simplify_aggregation = true;
     bool convert_pow = true;
     bool convert_log1p = true;
+    bool convert_log_softmax = true;
     bool convert_expm1 = true;
     bool unary_ops_composition = true;
     bool remove_stack_strided_slice_same_axis = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 56e37aade91b32..11fd91e7588d43 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2556,6 +2556,67 @@ TEST_F(ArithmeticOptimizerTest, DoNotFuseSquaredDiffFetchNode) {
   }
 }
 
+TEST_F(ArithmeticOptimizerTest, ConvertLogSoftmax) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+  Output softmax = ops::Softmax(s.WithOpName("softmax"), x);
+  Output logsoftmax = ops::Log(s.WithOpName("output"), softmax);
+
+  GrapplerItem item;
+  item.fetch = {"output"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(1, tensors_expected.size());
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyLogSoftmax(&optimizer);
+  OptimizeAndPrune(&optimizer, &item, &output);
+  const auto tensors = EvaluateNodes(output, item.fetch);
+  EXPECT_EQ(1, tensors.size());
+
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+  EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
+  for (int i = 0; i < output.node_size(); ++i) {
+    const NodeDef& node = output.node(i);
+    if (node.name() == "output") {
+      EXPECT_EQ("LogSoftmax", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+    }
+  }
+}
+
+TEST_F(ArithmeticOptimizerTest, DoNotConvertLogSoftmaxArgFetchNode) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output floats = ops::Const(s.WithOpName("floats"),
+                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
+  Output softmax = ops::Softmax(s.WithOpName("softmax"), floats);
+  Output final_output = ops::Log(s.WithOpName("final_output"), softmax);
+
+  GrapplerItem item;
+  item.fetch = {"softmax", "final_output"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  ASSERT_EQ(2, tensors_expected.size());
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
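+  // Restrict the optimizer to the new stage so only convert_log_softmax runs.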
+  EnableOnlyLogSoftmax(&optimizer);
+  OptimizeTwice(&optimizer, &item, &output);
+  const auto tensors = EvaluateNodes(output, item.fetch);
+  ASSERT_EQ(2, tensors.size());
+
+  // Should be a NoOp since we are not allowed to change the output of fetch
+  // nodes.
+  VerifyGraphsMatch(item.graph, output, __LINE__);
+
+  for (int i = 0; i < tensors.size(); i++) {
+    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
+    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
+  }
+}
+
 TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
index cd9950e17a2d97..0358d7f5409865 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
@@ -183,6 +183,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
   }
 
+  void EnableOnlyLogSoftmax(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.convert_log_softmax = true;
+  }
+
   void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
     DisableAllStages(optimizer);
     optimizer->options_.convert_pow = true;
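
For reviewers who want to reproduce the numerical-stability claim from the
description: the snippet below is an illustrative sketch only, not part of the
patch, and assumes a TF 1.x-era Python environment (`tf.Session`, `tf.log`).
With a large gap between logits, `Softmax` underflows to exactly zero, so the
naive `Log(Softmax(x))` composition yields `-inf`, while `LogSoftmax` is
effectively evaluated in the log domain (`x - logsumexp(x)`) and stays finite:

```python
# Illustrative sketch only -- not part of this patch. Assumes TF 1.x APIs.
import tensorflow as tf

x = tf.constant([[1000.0, 0.0]])

# Naive composition: softmax(x) underflows to [1, 0], and log(0) is -inf.
naive = tf.log(tf.nn.softmax(x))

# Fused op: stays in the log domain, so the small entry remains finite.
fused = tf.nn.log_softmax(x)

with tf.Session() as sess:
    print(sess.run(naive))  # approx. [[0. -inf]]
    print(sess.run(fused))  # approx. [[0. -1000.]]
```

The same log-domain view underlies the possible follow-ups mentioned above:
once the graph contains `LogSoftmax`, multiplications and divisions by `y`
become additions and subtractions of `log(y)`.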