TF-TRT Improve test coverage of pool op converters #40184

Merged
9 changes: 7 additions & 2 deletions tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -3508,8 +3508,13 @@ Status ConvertPool(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF,
DataType::DT_INT8};
#else
std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF};
#endif
TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types));
nvinfer1::PoolingType type;
if (node_def.op() == "MaxPool") {
type = nvinfer1::PoolingType::kMAX;
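
The converter-side change above gates the set of data types accepted by ConvertPool on the TensorRT version: when TF-TRT is built against TensorRT 5.1 or newer, DT_INT8 is accepted in addition to DT_FLOAT and DT_HALF. Below is a minimal sketch of that pattern in isolation; the helper name is illustrative and is not part of the patch, while IS_TRT_VERSION_GE, DataType, and the type constants are the existing TF-TRT/TensorFlow symbols used in the hunk.

#include <set>
#include "tensorflow/core/framework/types.pb.h"  // DataType
// IS_TRT_VERSION_GE comes from the TF-TRT convert utilities.

// Illustrative sketch only (not in the patch): the allowed-type set that
// ConvertPool now builds inline, factored into a hypothetical helper.
std::set<DataType> PoolAllowedDataTypes() {
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
  // TensorRT >= 5.1 supports INT8 pooling, so DT_INT8 is accepted as well.
  return {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT8};
#else
  return {DataType::DT_FLOAT, DataType::DT_HALF};
#endif
}
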
290 changes: 140 additions & 150 deletions tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -4602,41 +4602,72 @@ TEST_F(OpConverterTest, ConvertConv3D) {
ElementsAreArray(ok_params[i].expected_output));
}
}
#endif

TEST_F(OpConverterTest, ConvertPool3D) {
// Get nodedef for MaxPool3D and AvgPool3D layers.
auto get_pool3d_nodedef = [](std::vector<int> ksize = {1, 1, 1, 1, 1},
std::vector<int> strides = {1, 1, 1, 1, 1},
string padding = "SAME",
string data_format = "NCDHW",
const bool is_max_pooling = true) -> NodeDef {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);

template <typename T>
NodeDef CreatePoolOp(DataType tf_type, std::vector<int> ksize,
std::vector<int> strides, string padding,
string data_format) {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
typename T::Attrs attrs;
attrs.data_format_ = data_format;
return T(s.WithOpName("my_pool"), input, ksize, strides, padding, attrs)
.operation.node()
->def();
}
TEST_P(OpConverterTest1, ConvertPool) {
// Get nodedef for MaxPool and AvgPool layers (2D or 3D).
auto get_pool_nodedef =
[](DataType tf_type, int nDim, std::vector<int> ksize = {},
std::vector<int> strides = {}, string padding = "SAME",
string data_format = "", const bool is_max_pooling = true) -> NodeDef {
if (ksize.empty()) {
ksize = nDim == 2 ? std::vector<int>{1, 1, 1, 1}
: std::vector<int>{1, 1, 1, 1, 1};
}
if (strides.empty()) {
strides = nDim == 2 ? std::vector<int>{1, 1, 1, 1}
: std::vector<int>{1, 1, 1, 1, 1};
}
if (data_format == "") {
data_format = nDim == 2 ? "NCHW" : "NCDHW";
}
if (is_max_pooling) {
ops::MaxPool3D::Attrs attrs =
ops::MaxPool3D::Attrs().DataFormat(data_format);
auto pool3d = ops::MaxPool3D(s.WithOpName("my_maxpool3d"), input, ksize,
strides, padding, attrs);
return pool3d.operation.node()->def();
if (nDim == 3) {
return CreatePoolOp<ops::MaxPool3D>(tf_type, ksize, strides, padding,
data_format);
} else {
return CreatePoolOp<ops::MaxPool>(tf_type, ksize, strides, padding,
data_format);
}
} else {
ops::AvgPool3D::Attrs attrs =
ops::AvgPool3D::Attrs().DataFormat(data_format);
auto pool3d = ops::AvgPool3D(s.WithOpName("my_avgpool3d"), input, ksize,
strides, padding, attrs);
return pool3d.operation.node()->def();
if (nDim == 3) {
return CreatePoolOp<ops::AvgPool3D>(tf_type, ksize, strides, padding,
data_format);
} else {
return CreatePoolOp<ops::AvgPool>(tf_type, ksize, strides, padding,
data_format);
}
}
};

{
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
std::vector<int> test_nDims{2, 3};
#else
std::vector<int> test_nDims{2};
#endif

for (int nDim : test_nDims) {
// Input is weights, should fail.
Reset();
NodeDef node_def = get_pool3d_nodedef();
NodeDef node_def = get_pool_nodedef(tf_type, nDim);

AddTestWeights<float>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"The input \"input\" for MaxPool3D must be a tensor, at my_maxpool3d");
AddTestWeights<float>("input", {1, 1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
StrCat("The input \"input\" for ", node_def.op(),
" must be a tensor, at my_pool")
.c_str());
}

struct TestParams {
@@ -4646,150 +4677,109 @@ TEST_F(OpConverterTest, ConvertPool3D) {
std::vector<int> strides;
string padding;
string data_format;
bool is_max_pooling;
std::vector<int> expected_output_dims;
std::vector<float> expected_output;
// The expected outputs for the following operations: MaxPool2D, AvgPool2D,
// MaxPool3D, AvgPool3D
std::vector<std::vector<float>> expected_outputs;
};

// Start here
const std::vector<float> common_array{-4, 2, 15, 3, 6, -3, 22, 1, 88,
// We use common_input as the input to test both 2D and 3D pooling operations,
// to simplify TestParams. For 2D operations, only the first 1/3 of the values
// are used from common_input.
const std::vector<float> common_input{-4, 2, 15, 3, 6, -3, 22, 1, 88,
56, 36, 1, 1, 105, 1, 16, -28, 1,
42, 9, 3, 1, 7, 1, 11, 61, 5};
// Output of 2D ops for the case when the op is equivalent to the identity op.
const std::vector<float> common_2d_output{-4, 2, 15, 3, 6, -3, 22, 1, 88};
std::vector<TestParams> ok_params = {
// Basic - just 1x1 max pooling - input = output
TestParams{/*input_dims=*/{1, 3, 3, 3},
/*input=*/common_array,
/*ksize=*/{1, 1, 1, 1, 1},
/*strides=*/{1, 1, 1, 1, 1},
/*padding=*/"VALID",
/*data_format=*/"NCDHW",
/*is_max_pooling=*/true,
/*expected_output_dims=*/{1, 3, 3, 3},
/*expected_output=*/common_array},
// Basic - just 1x1 avg pooling - input = output
TestParams{/*input_dims=*/{1, 3, 3, 3},
/*input=*/common_array,
/*ksize=*/{1, 1, 1, 1, 1},
/*strides=*/{1, 1, 1, 1, 1},
/*padding=*/"VALID",
/*data_format=*/"NCDHW",
/*is_max_pooling=*/false,
/*expected_output_dims=*/{1, 3, 3, 3},
/*expected_output=*/common_array},
TestParams{
/*input_dims=*/{1, 1, 3, 3, 3},
/*input=*/common_input,
/*ksize=*/{1, 1, 1, 1, 1},
/*strides=*/{1, 1, 1, 1, 1},
/*padding=*/"VALID",
/*data_format=*/"NCDHW",
/*expected_output_dims=*/{1, 1, 3, 3, 3},
/*expected_outputs=*/
{common_2d_output, common_2d_output, common_input, common_input}},
// Basic - just 1x1 max pooling - input = output, SAME padding
TestParams{/*input_dims=*/{1, 3, 3, 3},
/*input=*/common_array,
/*ksize=*/{1, 1, 1, 1, 1},
/*strides=*/{1, 1, 1, 1, 1},
/*padding=*/"SAME",
/*data_format=*/"NCDHW",
/*is_max_pooling=*/true,
/*expected_output_dims=*/{1, 3, 3, 3},
/*expected_output=*/common_array},
// Basic - just 1x1 avg pooling - input = output, SAME padding
TestParams{/*input_dims=*/{1, 3, 3, 3},
/*input=*/common_array,
/*ksize=*/{1, 1, 1, 1, 1},
/*strides=*/{1, 1, 1, 1, 1},
/*padding=*/"VALID",
/*data_format=*/"NCDHW",
/*is_max_pooling=*/false,
/*expected_output_dims=*/{1, 3, 3, 3},
/*expected_output=*/common_array},
// 3x3 max pooling
TestParams{/*input_dims=*/{1, 3, 3, 3},
/*input=*/common_array,
/*ksize=*/{1, 1, 3, 3, 3},
/*strides=*/{1, 1, 1, 1, 1},
/*padding=*/"VALID",
/*data_format=*/"NCDHW",
/*is_max_pooling=*/true,
/*expected_output_dims=*/{1, 1, 1, 1},
/*expected_output=*/{105}},
// 3x3 avg pooling
TestParams{/*input_dims=*/{1, 3, 3, 3},
/*input=*/common_array,
TestParams{
/*input_dims=*/{1, 1, 3, 3, 3},
/*input=*/common_input,
/*ksize=*/{1, 1, 1, 1, 1},
/*strides=*/{1, 1, 1, 1, 1},
/*padding=*/"SAME",
/*data_format=*/"NCDHW",
/*expected_output_dims=*/{1, 1, 3, 3, 3},
/*expected_outputs=*/
{common_2d_output, common_2d_output, common_input, common_input}},
// 3x3 pooling NCDHW
TestParams{/*input_dims=*/{1, 1, 3, 3, 3},
/*input=*/common_input,
/*ksize=*/{1, 1, 3, 3, 3},
/*strides=*/{1, 1, 1, 1, 1},
/*padding=*/"VALID",
/*data_format=*/"NCDHW",
/*is_max_pooling=*/false,
/*expected_output_dims=*/{1, 1, 1, 1},
/*expected_output=*/{17}},
// 3x3 max pooling, NDHWC
TestParams{/*input_dims=*/{3, 3, 3, 1},
/*input=*/common_array,
/*expected_output_dims=*/{1, 1, 1, 1, 1},
/*expected_outputs=*/{{88}, {14.444445}, {105}, {17}}},
// 3x3 pooling, NDHWC
TestParams{/*input_dims=*/{1, 3, 3, 3, 1},
/*input=*/common_input,
/*ksize=*/{1, 3, 3, 3, 1},
/*strides=*/{1, 1, 1, 1, 1},
/*padding=*/"VALID",
/*data_format=*/"NDHWC",
/*is_max_pooling=*/true,
/*expected_output_dims=*/{1, 1, 1, 1},
/*expected_output=*/{105}},
// 3x3 avg pooling, NDHWC
TestParams{/*input_dims=*/{3, 3, 3, 1},
/*input=*/common_array,
/*ksize=*/{1, 3, 3, 3, 1},
/*strides=*/{1, 1, 1, 1, 1},
/*expected_output_dims=*/{1, 1, 1, 1, 1},
/*expected_outputs=*/{{88}, {14.444445}, {105}, {17}}},
// Strided
TestParams{/*input_dims=*/{1, 1, 3, 3, 3},
/*input=*/{1, 0, 2, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0,
0, 0, 0, 0, 5, 0, 6, 0, 0, 0, 7, 0, 8},
/*ksize=*/{1, 1, 1, 1, 1},
/*strides=*/{1, 1, 2, 2, 2},
/*padding=*/"VALID",
/*data_format=*/"NDHWC",
/*is_max_pooling=*/false,
/*expected_output_dims=*/{1, 1, 1, 1},
/*expected_output=*/{17}},
// Strided max
TestParams{
/*input_dims=*/{1, 3, 3, 3},
/*input=*/{1, 0, 2, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0,
0, 0, 0, 0, 5, 0, 6, 0, 0, 0, 7, 0, 8},
/*ksize=*/{1, 1, 1, 1, 1},
/*strides=*/{1, 1, 2, 2, 2},
/*padding=*/"VALID",
/*data_format=*/"NCDHW",
/*is_max_pooling=*/true,
/*expected_output_dims=*/{1, 2, 2, 2},
/*expected_output=*/{1, 2, 3, 4, 5, 6, 7, 8} // Should only pick up
// the corners
},
// Strided avg
TestParams{
/*input_dims=*/{1, 3, 3, 3},
/*input=*/{1, 0, 2, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0,
0, 0, 0, 0, 5, 0, 6, 0, 0, 0, 7, 0, 8},
/*ksize=*/{1, 1, 1, 1, 1},
/*strides=*/{1, 1, 2, 2, 2},
/*padding=*/"VALID",
/*data_format=*/"NCDHW",
/*is_max_pooling=*/false,
/*expected_output_dims=*/{1, 2, 2, 2},
/*expected_output=*/{1, 2, 3, 4, 5, 6, 7, 8} // Should only pick up
// the corners
}};

for (int i = 0; i < ok_params.size(); i++) {
Reset();
NodeDef node_def = get_pool3d_nodedef(
ok_params[i].ksize, ok_params[i].strides, ok_params[i].padding,
ok_params[i].data_format, ok_params[i].is_max_pooling);
AddTestTensor("input", ok_params[i].input_dims);
RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
string expected_node_name =
ok_params[i].is_max_pooling ? "my_maxpool3d" : "my_avgpool3d";
TF_EXPECT_OK(GetTensorOrWeights(expected_node_name, &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
output.tensor()->getDimensions());
/*data_format=*/"NCDHW",
/*expected_output_dims=*/{1, 1, 2, 2, 2},
/*expected_outputs=*/
{{1, 2, 3, 4}, // Should only pick up the corners
{1, 2, 3, 4},
{1, 2, 3, 4, 5, 6, 7, 8},
{1, 2, 3, 4, 5, 6, 7, 8}}},
};

const DataVec input_data{{"input", AsTensor<float>(ok_params[i].input)}};
DataVec output_data{
{expected_node_name,
ConstructTensor<float>(ok_params[i].expected_output.size())}};
TF_EXPECT_OK(BuildAndRun(input_data, &output_data));
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
ElementsAreArray(ok_params[i].expected_output));
for (auto p : ok_params) {
int test_counter = 0;
for (int nDim : test_nDims) {
auto input = p.input;
auto input_dims = p.input_dims;
auto ksize = p.ksize;
auto strides = p.strides;
auto expected_output_dims = p.expected_output_dims;
std::string data_format = p.data_format;
if (nDim == 2) {
input.resize(9);
data_format = p.data_format == "NDHWC" ? "NHWC" : "NCHW";
// Remove one of the spatial dimensions
input_dims.erase(input_dims.begin() + 2);
ksize.erase(ksize.begin() + 2);
strides.erase(strides.begin() + 2);
expected_output_dims.erase(expected_output_dims.begin() + 2);
}
for (bool is_max_pooling : {true, false}) {
Reset();
NodeDef node_def =
get_pool_nodedef(tf_type, nDim, ksize, strides, p.padding,
data_format, is_max_pooling);
AddTestTensor("input", input_dims, input);
TestOpConverter("my_pool", node_def, expected_output_dims, Status::OK(),
Status::OK(),
ElementsAreArray(p.expected_outputs.at(test_counter)));
test_counter++;
}
}
}
}
#endif // IS_TRT_VERSION_GE(6, 0, 0, 0)
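
In the rewritten test above, every TestParams entry is written in its 3D (NCDHW/NDHWC) form, and the 2D variant is derived on the fly: the input is truncated to its first nine values and the entry at index 2 is erased from input_dims, ksize, strides, and expected_output_dims. The sketch below shows that dimension-squeezing step in isolation; the helper name is illustrative and does not appear in the patch.

#include <vector>

// Illustrative only: mirrors the erase(begin() + 2) calls in the test loop.
// Dropping index 2 removes one of the spatial/depth entries, e.g. turning the
// NCDHW shape {1, 1, 3, 3, 3} into the NCHW shape {1, 1, 3, 3}.
std::vector<int> DropThirdDim(std::vector<int> v) {
  v.erase(v.begin() + 2);
  return v;
}

With this reduction the same TestParams table drives both 2D and 3D pooling, and expected_outputs lists the per-op results in the order the loops visit them: MaxPool2D, AvgPool2D, MaxPool3D, AvgPool3D.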

TEST_F(OpConverterTest, ConvertTopK) {
// TODO(tmorris): This test isn't setting the input dtype properly. TopK with