Skip to content

Commit

Permalink
Merge pull request #54810 from Intel-tensorflow:gyshi/fiex_remapper_bug
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 432994733
  • Loading branch information
tensorflower-gardener committed Mar 7, 2022
2 parents 3966157 + 42191f0 commit 71151ec
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 8 deletions.
29 changes: 21 additions & 8 deletions tensorflow/core/grappler/optimizers/remapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2762,13 +2762,19 @@ Status AddFusedBatchMatMul(RemapperContext* ctx,
return Status::OK();
}

// This function supports below patterns that require inferred
// shapes:
// 1. Contraction + Add.
// 2. Contraction + Add + Activation.
// 3. Contraction + BiasAdd/BiasSemanticAdd + Add.
// 4. Contraction + BiasAdd/BiasSemanticAdd + Add + Activation.
// Contraction candidate: MatMul, Conv2D, Conv3D, DepthwiseConv2dNative.
bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
const auto* node_view = ctx.graph_view.GetNode(node_index);

// Candidate for Contraction + Add or Contraction + BiasAdd + Add fusion.
// Contraction candidate: MatMul, Conv2D, DepthwiseConv2dNative
auto is_supported_add_input = [](const auto* node_view) -> bool {
if (IsConvOrMatMul(*node_view->node())) return true;
// IsAdd will verify BiasSemanticAdd.
if (IsBiasAdd(*node_view->node()) || IsAdd(*node_view->node())) {
if (node_view->NumRegularFanins() < 2) return false;
const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
Expand All @@ -2791,14 +2797,21 @@ bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
return false;
};

bool ret = false;
for (int i = 0; i < node_view->NumRegularFanins(); i++) {
const auto& fanin_i = node_view->GetRegularFanin(i);
ret = is_supported_add(fanin_i.node_view());
if (ret) break;
// Dealing with the Contraction + Add or Contraction + BiasAdd/BiasSemanticAdd
// + Add patterns.
if (is_supported_add(node_view)) {
return true;
}
// Dealing with the Contraction + Add + Activation or Contraction +
// BiasAdd/BiasSemanticAdd + Add + Activation pattern.
if (IsSupportedActivation(*node_view->node())) {
for (int i = 0; i < node_view->NumRegularFanins(); i++) {
const auto& fanin_i = node_view->GetRegularFanin(i);
if (is_supported_add(fanin_i.node_view())) return true;
}
}

return ret;
return false;
}

// Check if a node is a candidate to one of the patterns that require inferred
Expand Down
98 changes: 98 additions & 0 deletions tensorflow/core/grappler/optimizers/remapper_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1769,6 +1769,104 @@ TEST_F(RemapperTest, FuseConv3DWithBiasAndAddActivation) {
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
}

// Conv2D + Add {6,} + Conv2D + Biasadd fusion.
// Verifies the two-stage fusion: Conv2D + Add({6}) becomes a _FusedConv2D
// with a BiasAdd (the rank-1 Add is treated as a semantic bias), and the
// downstream Conv2D + BiasAdd is fused as well.
TEST_F(RemapperTest, FuseConv2DWithSemanticAdd) {
  if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to MKL.";
  using ::tensorflow::ops::Placeholder;
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Shapes: NHWC input, two 1x1 convolutions, and two rank-1 addends that
  // both match the channel dimension (6).
  const auto shape_input = ops::Placeholder::Shape({8, 32, 32, 3});
  const auto shape_filter = ops::Placeholder::Shape({1, 1, 3, 6});
  const auto shape_filter_1 = ops::Placeholder::Shape({1, 1, 6, 6});
  const auto shape_semanticadd = ops::Placeholder::Shape({6});
  const auto shape_bias = ops::Placeholder::Shape({6});

  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, shape_input);
  auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, shape_filter);
  auto filter_1 =
      Placeholder(s.WithOpName("filter_1"), DT_FLOAT, shape_filter_1);
  auto semanticadd =
      Placeholder(s.WithOpName("semanticadd"), DT_FLOAT, shape_semanticadd);
  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, shape_bias);

  // Graph: conv -> Add(semanticadd, conv) -> conv_1 -> BiasAdd -> Identity.
  const std::vector<int> stride_spec = {1, 1, 1, 1};
  auto conv =
      ops::Conv2D(s.WithOpName("conv"), input, filter, stride_spec, "VALID");
  auto add = ops::Add(s.WithOpName("add"), semanticadd, conv);
  auto conv_1 = ops::Conv2D(s.WithOpName("conv_1"), add, filter_1, stride_spec,
                            "VALID");
  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv_1, bias);
  auto fetch = ops::Identity(s.WithOpName("fetch"), bias_add);

  auto t_input = GenerateRandomTensor<DT_FLOAT>(
      TensorShape(shape_input.shape_.dim_sizes()));
  auto t_filter = GenerateRandomTensor<DT_FLOAT>(
      TensorShape(shape_filter.shape_.dim_sizes()));
  auto t_filter_1 = GenerateRandomTensor<DT_FLOAT>(
      TensorShape(shape_filter_1.shape_.dim_sizes()));
  auto t_semanticadd = GenerateRandomTensor<DT_FLOAT>(
      TensorShape(shape_semanticadd.shape_.dim_sizes()));
  auto t_bias = GenerateRandomTensor<DT_FLOAT>(
      TensorShape(shape_bias.shape_.dim_sizes()));

  GrapplerItem item;
  item.fetch = {"fetch"};
  item.feed = {{"input", t_input},
               {"filter", t_filter},
               {"filter_1", t_filter_1},
               {"semanticadd", t_semanticadd},
               {"bias", t_bias}};
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // Pin every node to CPU so the CPU-only rewrite is applicable.
  for (NodeDef& node : *item.graph.mutable_node()) {
    node.set_device("/device:CPU:0");
  }

  Remapper optimizer(RewriterConfig::AGGRESSIVE);
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

  // Both the BiasAdd and the semantic Add should have been rewritten into
  // _FusedConv2D nodes carrying a single fused BiasAdd op.
  int num_fused = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "bias_add") {
      EXPECT_EQ(node.op(), "_FusedConv2D");
      ASSERT_GE(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "add");
      EXPECT_EQ(node.input(1), "filter_1");

      EXPECT_EQ(node.attr().at("num_args").i(), 1);
      EXPECT_EQ(node.input(2), "bias");

      const auto fused_ops = node.attr().at("fused_ops").list().s();
      ASSERT_EQ(fused_ops.size(), 1);
      EXPECT_EQ(fused_ops[0], "BiasAdd");
      ++num_fused;
    } else if (node.name() == "add") {
      EXPECT_EQ(node.op(), "_FusedConv2D");
      ASSERT_GE(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "input");
      EXPECT_EQ(node.input(1), "filter");

      EXPECT_EQ(node.attr().at("num_args").i(), 1);
      EXPECT_EQ(node.input(2), "semanticadd");

      const auto fused_ops = node.attr().at("fused_ops").list().s();
      ASSERT_EQ(fused_ops.size(), 1);
      EXPECT_EQ(fused_ops[0], "BiasAdd");
      ++num_fused;
    }
  }
  EXPECT_EQ(num_fused, 2);

  // The optimized graph must be numerically equivalent to the original.
  auto expected_out = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected_out.size(), 1);
  auto optimized_out = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(optimized_out.size(), 1);
  test::ExpectTensorNear<float>(optimized_out[0], expected_out[0], 1e-6);
}
#endif

} // namespace grappler
Expand Down

0 comments on commit 71151ec

Please sign in to comment.