Skip to content

Commit

Permalink
add conv + biassemanticadd + conv2d + biasadd pattern test
Browse files Browse the repository at this point in the history
  • Loading branch information
gyshi committed Mar 2, 2022
1 parent 63a1543 commit 42191f0
Showing 1 changed file with 98 additions and 0 deletions.
98 changes: 98 additions & 0 deletions tensorflow/core/grappler/optimizers/remapper_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1769,6 +1769,104 @@ TEST_F(RemapperTest, FuseConv3DWithBiasAndAddActivation) {
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
}

// Conv2D + Add {6,} + Conv2D + Biasadd fusion.
TEST_F(RemapperTest, FuseConv2DWithSemanticAdd) {
if (!IsMKLEnabled()) GTEST_SKIP() << "Test only applicable to MKL.";
using ::tensorflow::ops::Placeholder;
tensorflow::Scope s = tensorflow::Scope::NewRootScope();

auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3});
auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 6});
auto filter_shape_1 = ops::Placeholder::Shape({1, 1, 6, 6});
auto semanticadd_shape = ops::Placeholder::Shape({6});
auto bias_shape = ops::Placeholder::Shape({6});

auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
auto filter_1 =
Placeholder(s.WithOpName("filter_1"), DT_FLOAT, filter_shape_1);
auto semanticadd =
Placeholder(s.WithOpName("semanticadd"), DT_FLOAT, semanticadd_shape);
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);

std::vector<int> strides = {1, 1, 1, 1};
auto conv =
ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "VALID");
auto add = ops::Add(s.WithOpName("add"), semanticadd, conv);
auto conv_1 =
ops::Conv2D(s.WithOpName("conv_1"), add, filter_1, strides, "VALID");
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv_1, bias);
auto fetch = ops::Identity(s.WithOpName("fetch"), bias_add);

auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(input_shape.shape_.dim_sizes()));
auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(filter_shape.shape_.dim_sizes()));
auto filter_tensor_1 = GenerateRandomTensor<DT_FLOAT>(
TensorShape(filter_shape_1.shape_.dim_sizes()));
auto semanticadd_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(semanticadd_shape.shape_.dim_sizes()));
auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(bias_shape.shape_.dim_sizes()));

GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_tensor},
{"filter", filter_tensor},
{"filter_1", filter_tensor_1},
{"semanticadd", semanticadd_tensor},
{"bias", bias_tensor}};
TF_ASSERT_OK(s.ToGraphDef(&item.graph));

// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}

Remapper optimizer(RewriterConfig::AGGRESSIVE);
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "bias_add") {
EXPECT_EQ(node.op(), "_FusedConv2D");
ASSERT_GE(node.input_size(), 3);
EXPECT_EQ(node.input(0), "add");
EXPECT_EQ(node.input(1), "filter_1");

EXPECT_EQ(node.attr().at("num_args").i(), 1);
EXPECT_EQ(node.input(2), "bias");

const auto fused_ops = node.attr().at("fused_ops").list().s();
ASSERT_EQ(fused_ops.size(), 1);
EXPECT_EQ(fused_ops[0], "BiasAdd");
found++;
}
if (node.name() == "add") {
EXPECT_EQ(node.op(), "_FusedConv2D");
ASSERT_GE(node.input_size(), 3);
EXPECT_EQ(node.input(0), "input");
EXPECT_EQ(node.input(1), "filter");

EXPECT_EQ(node.attr().at("num_args").i(), 1);
EXPECT_EQ(node.input(2), "semanticadd");

const auto fused_ops = node.attr().at("fused_ops").list().s();
ASSERT_EQ(fused_ops.size(), 1);
EXPECT_EQ(fused_ops[0], "BiasAdd");
found++;
}
}
EXPECT_EQ(found, 2);

auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
ASSERT_EQ(tensors_expected.size(), 1);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
ASSERT_EQ(tensors.size(), 1);
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
#endif

} // namespace grappler
Expand Down

0 comments on commit 42191f0

Please sign in to comment.