From fc39b1b3ca4c310b44ac388b31a1b36f952ae52b Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 10 Apr 2019 13:39:07 -0700 Subject: [PATCH 1/7] support constant node for gather support constant params input for gather fix tests --- .../tf2tensorrt/convert/convert_nodes.cc | 58 ++++++++++++++----- .../tf2tensorrt/convert/convert_nodes_test.cc | 54 ++++++++++++----- 2 files changed, 86 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 7017c9a170661e..9ce86ab6d44a9b 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -525,8 +525,12 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value, // indices. // TODO(tmorris): Use this method in more ops. Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name, - int* trt_axis) { - const int tf_nb_dims = trt_nb_dims + 1; + int* trt_axis, bool is_weights=false) { + int tf_nb_dims = trt_nb_dims; + if (!is_weights) { + tf_nb_dims = trt_nb_dims + 1; + } + // Check bounds. if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) { return errors::InvalidArgument( @@ -536,13 +540,16 @@ Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name, // Make negative axis positive. if (tf_axis < 0) tf_axis += tf_nb_dims; // Don't allow axis to be the batch dimension. - if (tf_axis == 0) { + if (tf_axis == 0 && !is_weights) { return errors::Unimplemented( "TensorRT does not allow manipulation of the batch dimension, at ", node_name); } // Remove batch dimension. - *trt_axis = tf_axis - 1; + *trt_axis = tf_axis; + if (!is_weights){ + *trt_axis = tf_axis - 1; + } return Status::OK(); } @@ -3800,20 +3807,37 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { Status ConvertGather(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - TF_RETURN_IF_ERROR(CheckInputsWeights( - *params, {{"params", false}, {"indices", false}, {"axis", true}})); + + if (inputs.size() != 3) { + return errors::InvalidArgument("GatherV2 got ", inputs.size(), + " inputs but expected 3, at ", node_def.name()); + } + if (!inputs.at(2).is_weights()) { + return errors::Unimplemented("The input \"axis\" for GatherV2", + " must be a constant, at ", node_def.name()); + } + if (!inputs.at(1).is_tensor()) { + return errors::Unimplemented( + "The input \"indecies\" for GatherV2," + " must be a tensor, at ", node_def.name()); + } + TF_RETURN_IF_ERROR(AllowDataTypes( *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}, /*dtype_attr_name=*/"Tparams")); + absl::Span axis = inputs.at(2).weights().GetSpan(); if (axis.size() != 1) { return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ", node_def.name()); } + if (inputs.at(0).is_weights() && axis[0] != 0){ + return errors::Unimplemented("The input axis must be a zero," + " in case of params is a weights"); + } int trt_axis = 0; TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims, - node_def.name(), &trt_axis)); - const TRT_TensorOrWeights& params_tensor = inputs.at(0); + node_def.name(), &trt_axis, inputs.at(0).is_weights())); const TRT_TensorOrWeights& indices_tensor = inputs.at(1); if (indices_tensor.batch_size() != 1) { return errors::InvalidArgument("Only indices with batch 1 are supported."); @@ -3821,15 +3845,24 @@ Status ConvertGather(OpConverterParams* params) { // Both input 
are tensors, and the TF gather result will have rank: // (params.nbDims + 1) + (indices.nbDims + 1) - 1, // where "+ 1" adds the batch dim. - const int tf_gather_output_rank = params_tensor.GetTrtDims().nbDims + - indices_tensor.GetTrtDims().nbDims + 1; + const int tf_gather_output_rank = + inputs.at(0).GetTrtDims().nbDims + indices_tensor.GetTrtDims().nbDims + 1; if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) { return errors::InvalidArgument( "Result of gather has dimension greater than ", nvinfer1::Dims::MAX_DIMS + 1); } if (params->validation_only) return Status::OK(); + nvinfer1::ITensor* params_input; + // convert weights to tensor + if (inputs.at(0).is_tensor()) { + params_input = inputs.at(0).tensor(); + } else { + params_input = params->converter->CreateConstantLayer( + inputs.at(0).weights(), inputs.at(0).GetTrtDims()); + } + const TRT_TensorOrWeights params_tensor(params_input); // Note on how IGatherLayer works: if both the data and indices tensors have // a batch size dimension of size N, it performs: // for batchid in xrange(N): @@ -3847,8 +3880,8 @@ Status ConvertGather(OpConverterParams* params) { if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) { return errors::Internal( "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ", - tf_gather_output_rank - 2, - ", actual nbDims: ", trt_gather_output_dims.nbDims); + tf_gather_output_rank - 2, ", actual nbDims: ", + trt_gather_output_dims.nbDims); } // Reshape the output so after adding the implicit batch dim it'll match the // output shape of TF GatherV2. @@ -3862,7 +3895,6 @@ Status ConvertGather(OpConverterParams* params) { TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(gather_output), trt_gather_output_dims, /*validation_only=*/false, &output_tensor)); - params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 09b7a60c0839c9..51a913d6220c7a 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -3845,35 +3845,47 @@ void TestConvertGather(OpConverterTest* test) { int axis; std::vector expected_output_dims; std::vector expected_output; + bool params_is_tensor; }; // Input is the same {1, 2, 3, 4, 5, 6} for all cases. - const int kGatherOKCases = 7; + const int kGatherOKCases = 10; const std::vector params_input = {CType(1), CType(2), CType(3), CType(4), CType(5), CType(6)}; TestParams ok_params[kGatherOKCases] = { // Vector indices, and output rank is rank(params). - TestParams{{1, 2, 3}, {}, {0}, 3, {1, 2, 1}, {1, 4}}, - TestParams{{1, 2, 3}, {}, {1}, 2, {1, 1, 3}, {4, 5, 6}}, + TestParams{{1, 2, 3}, {}, {0}, 3, {1, 2, 1}, {1, 4}, true}, + TestParams{{1, 2, 3}, {}, {1}, 2, {1, 1, 3}, {4, 5, 6}, true}, // Indices with rank>1, and output rank is rank(params)+rank(indices)-1. 
- TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1, 1}, {1, 4}}, - TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1, 1}, {2, 5}}, - TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1, 1}, {3, 6}}, + TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1, 1}, {1, 4}, true}, + TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1, 1}, {2, 5}, true}, + TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1, 1}, {3, 6}, true}, TestParams{ - {1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 1, 3}, {3, 1, 2, 6, 4, 5}}, + {1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 1, 3}, {3, 1, 2, 6, 4, 5}, true}, TestParams{{3, 2}, {2, 2}, {0, 0, 1, 0}, 2, {3, 1, 2, 2}, - {1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}}, + {1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}, + true}, + TestParams{{1, 2, 3}, {}, {0}, 0, {1, 2, 3}, {1, 2, 3, 4, 5, 6}, false}, + TestParams{{3, 2}, {2}, {0, 1}, 0, {1, 2, 2}, {1, 2, 3, 4}, false}, + TestParams{{2, 3}, {1, 2}, {0, 1}, 0, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, false} }; // Ok. for (int i = 0; i < kGatherOKCases; i++) { + //for (const bool params_is_tensor : {true, false}) { test->Reset(); - test->AddTestTensor("params", ok_params[i].params_dims, 1, - TfDataTypeToTrt(dtype)); + if (ok_params[i].params_is_tensor) { + test->AddTestTensor("params", ok_params[i].params_dims, 1, + TfDataTypeToTrt(dtype)); + } else { + test->AddTestWeights("params", ok_params[i].params_dims, + params_input); + } + test->AddTestTensor("indices", ok_params[i].indices_dims, 1, nvinfer1::DataType::kINT32); test->AddTestWeights("axis", {1}, {ok_params[i].axis}); @@ -3889,9 +3901,13 @@ void TestConvertGather(OpConverterTest* test) { ok_params[i].expected_output.begin(), ok_params[i].expected_output.end()); - const DataVec input_data{ - {"params", test::AsTensor(params_input)}, - {"indices", test::AsTensor(ok_params[i].indices)}}; + DataVec input_data; + if (ok_params[i].params_is_tensor) { + input_data = {{"params", test::AsTensor(params_input)}, + {"indices", test::AsTensor(ok_params[i].indices)}}; + } else { + input_data = {{"indices", test::AsTensor(ok_params[i].indices)}}; + } DataVec output_data{ {"my_gather", ConstructTensor(ok_params[i].expected_output.size())}}; @@ -3947,6 +3963,18 @@ TEST_F(OpConverterTest, ConvertGather) { "TensorRT does not allow manipulation of the " "batch dimension, at my_gather"); } + { + // Axis is not equal zero and params is a weights + Reset(); + AddTestWeights("params", {3}, {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input axis must be a zero," + " in case of params is a weights"); + } + + Reset(); TestConvertGather(this); From 06868ab9ff745bb4b568796a2b75d5c096700995 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 7 Jun 2019 10:08:01 -0700 Subject: [PATCH 2/7] Change ConvertAxis is_weights to use_implicit_batch and move output to end of args list --- .../tf2tensorrt/convert/convert_nodes.cc | 55 ++++++++----------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 9ce86ab6d44a9b..37fb2bde180481 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -521,16 +521,11 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value, } // Convert an axis from TF format to TRT format while validating. TF format -// includes the batch dimension, while TRT does not. TF can also use negative -// indices. 
-// TODO(tmorris): Use this method in more ops. +// includes the batch dimension, while TRT does not if implicit batching is used +// (i.e. for tensors). TF can also use negative indices. Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name, - int* trt_axis, bool is_weights=false) { - int tf_nb_dims = trt_nb_dims; - if (!is_weights) { - tf_nb_dims = trt_nb_dims + 1; - } - + bool use_implicit_batch, int* trt_axis) { + const int tf_nb_dims = trt_nb_dims + (use_implicit_batch ? 1 : 0); // Check bounds. if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) { return errors::InvalidArgument( @@ -540,16 +535,13 @@ Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name, // Make negative axis positive. if (tf_axis < 0) tf_axis += tf_nb_dims; // Don't allow axis to be the batch dimension. - if (tf_axis == 0 && !is_weights) { + if (use_implicit_batch && tf_axis == 0) { return errors::Unimplemented( "TensorRT does not allow manipulation of the batch dimension, at ", node_name); } - // Remove batch dimension. - *trt_axis = tf_axis; - if (!is_weights){ - *trt_axis = tf_axis - 1; - } + // Remove batch dimension if it is implicit. + *trt_axis = use_implicit_batch ? tf_axis - 1 : tf_axis; return Status::OK(); } @@ -2069,8 +2061,8 @@ Status ConvertExpandDims(OpConverterParams* params) { // Use rank = nbDims + 1 for ConvertAxis's bounds checking to account for // ExpandDim's ability to add an axis at end of the shape. int trt_axis; - TF_RETURN_IF_ERROR( - ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); if (params->validation_only) return Status::OK(); // ExpandDims: Insert new dim of size 1. @@ -2105,8 +2097,8 @@ Status ConvertSqueeze(OpConverterParams* params) { for (int tf_axis : squeeze_dims) { // Make sure axis is valid. int trt_axis; - TF_RETURN_IF_ERROR( - ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); // Make sure target dimension is size 1. if (input_dims[trt_axis] != 1) { return errors::InvalidArgument( @@ -3301,9 +3293,9 @@ Status ConvertReduce(OpConverterParams* params) { } for (int i = 0; i < tf_axes_list.size(); i++) { int trt_axis; - TF_RETURN_IF_ERROR(ConvertAxis(tf_axes_list[i], - tensor->getDimensions().nbDims, - node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR( + ConvertAxis(tf_axes_list[i], tensor->getDimensions().nbDims, + node_def.name(), /*use_implicit_batch=*/true, &trt_axis)); axes |= (1 << trt_axis); } @@ -3370,8 +3362,8 @@ Status ConvertPack(OpConverterParams* params) { const nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); const int64 tf_axis = attrs.get("axis"); int trt_axis; - TF_RETURN_IF_ERROR( - ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); // Compute expanded dimensions and then reshape input tensors. std::vector tensor_dims(dims.d, dims.d + dims.nbDims); @@ -3518,8 +3510,8 @@ Status ConvertSplitHelper(OpConverterParams* params, const nvinfer1::Dims dims = input.GetTrtDims(); // Convert axis. 
int trt_axis; - TF_RETURN_IF_ERROR( - ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); // Dimension must equal num_splits for Unstack (when squeeze_after is true) if (squeeze_after && dims.d[trt_axis] != num_splits) { return errors::InvalidArgument( @@ -3647,8 +3639,8 @@ Status ConvertConcat(OpConverterParams* params) { } int trt_axis = 0; const auto dim = inputs.at(0).GetTrtDims(); - TF_RETURN_IF_ERROR( - ConvertAxis(axis[0], dim.nbDims, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dim.nbDims, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); // Check that dimensions match on non-concatenate axis. TF_RETURN_IF_ERROR(VerifyShapesMatch( absl::Span(inputs).first(num_inputs), trt_axis, @@ -3837,7 +3829,8 @@ Status ConvertGather(OpConverterParams* params) { } int trt_axis = 0; TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims, - node_def.name(), &trt_axis, inputs.at(0).is_weights())); + node_def.name(), inputs.at(0).is_tensor(), + &trt_axis)); const TRT_TensorOrWeights& indices_tensor = inputs.at(1); if (indices_tensor.batch_size() != 1) { return errors::InvalidArgument("Only indices with batch 1 are supported."); @@ -4153,8 +4146,8 @@ Status ConvertArgMinMax(OpConverterParams* params) { int tf_axis = inputs.at(1).weights().GetSpan()[0]; int trt_axis; nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); - TF_RETURN_IF_ERROR( - ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); nvinfer1::TopKOperation topk_op; if (node_def.op() == "ArgMin") { topk_op = nvinfer1::TopKOperation::kMIN; From 1b054b7fef2357bece9c10400aa724df7a7a7119 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 7 Jun 2019 14:47:14 -0700 Subject: [PATCH 3/7] Fixes for weights param input --- .../tf2tensorrt/convert/convert_nodes.cc | 83 ++++++++++--------- .../tf2tensorrt/convert/convert_nodes_test.cc | 6 +- 2 files changed, 49 insertions(+), 40 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 37fb2bde180481..40d68f8f7074b2 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -3799,77 +3799,84 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { Status ConvertGather(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - + const auto& params_input = inputs.at(0); + const auto& indices_input = input.at(1); + const auto& axis_input = input.at(2); + // TODO(tmorris): Use CheckInputsWeights by changing bool to enum with an + // option for an input to be either tensor or weight. 
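+  // Until then, the three inputs are validated by hand below: "params" may be
+  // either a tensor or weights, while "indices" must be a tensor and "axis"
+  // must be a constant.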
if (inputs.size() != 3) { return errors::InvalidArgument("GatherV2 got ", inputs.size(), " inputs but expected 3, at ", node_def.name()); } - if (!inputs.at(2).is_weights()) { - return errors::Unimplemented("The input \"axis\" for GatherV2", - " must be a constant, at ", node_def.name()); + if (!axis_input.is_weights()) { + return errors::Unimplemented("The input \"axis\" for GatherV2 must be a constant, at ", node_def.name()); } - if (!inputs.at(1).is_tensor()) { + if (!indices_tensor.is_tensor()) { return errors::Unimplemented( - "The input \"indecies\" for GatherV2," - " must be a tensor, at ", node_def.name()); + "The input \"indices\" for GatherV2 must be a tensor, at ", node_def.name()); } TF_RETURN_IF_ERROR(AllowDataTypes( *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}, /*dtype_attr_name=*/"Tparams")); + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_INT32}, /*dtype_attr_name=*/"Tindices")); - absl::Span axis = inputs.at(2).weights().GetSpan(); + absl::Span axis = axis_input.weights().GetSpan(); if (axis.size() != 1) { return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ", node_def.name()); } - if (inputs.at(0).is_weights() && axis[0] != 0){ - return errors::Unimplemented("The input axis must be a zero," - " in case of params is a weights"); - } int trt_axis = 0; - TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims, - node_def.name(), inputs.at(0).is_tensor(), + TF_RETURN_IF_ERROR(ConvertAxis(axis[0], params_input.GetTrtDims().nbDims, + node_def.name(), params_input.is_tensor(), &trt_axis)); - const TRT_TensorOrWeights& indices_tensor = inputs.at(1); - if (indices_tensor.batch_size() != 1) { - return errors::InvalidArgument("Only indices with batch 1 are supported."); + + if (params_input.is_weights() && trt_axis != 0){ + return errors::Unimplemented( + "The input axis must be zero when params is a weight."); + } + if (params_input.is_tensor() && indices_tensor.batch_size() != 1) { + return errors::InvalidArgument( + "Only indices with batch 1 are supported when params is a tensor."); } // Both input are tensors, and the TF gather result will have rank: // (params.nbDims + 1) + (indices.nbDims + 1) - 1, // where "+ 1" adds the batch dim. - const int tf_gather_output_rank = - inputs.at(0).GetTrtDims().nbDims + indices_tensor.GetTrtDims().nbDims + 1; + const int params_tf_rank = params_input.GetTrtDims().nbDims + (params_input.is_tensor() ? 1 : 0); + const int indices_tf_rank = indices_tensor.GetTrtDims().nbDims + 1; + const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1; if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) { return errors::InvalidArgument( "Result of gather has dimension greater than ", nvinfer1::Dims::MAX_DIMS + 1); } if (params->validation_only) return Status::OK(); - nvinfer1::ITensor* params_input; - // convert weights to tensor - if (inputs.at(0).is_tensor()) { - params_input = inputs.at(0).tensor(); + + // Convert params to tensor is it is a weight. 
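+  // (CreateConstantLayer materializes the weights as a TRT constant, so the
+  // same IGatherLayer path below handles both tensor and weight params.)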
+ nvinfer1::ITensor* params_tensor = nullptr; + if (params_input.is_weights()) { + params_tensor = TRT_TensorOrWeights(params->converter->CreateConstantLayer( + params_input.weights(), params_input.GetTrtDims())); } else { - params_input = params->converter->CreateConstantLayer( - inputs.at(0).weights(), inputs.at(0).GetTrtDims()); + params_tensor = params_input.tensor(); } - const TRT_TensorOrWeights params_tensor(params_input); // Note on how IGatherLayer works: if both the data and indices tensors have // a batch size dimension of size N, it performs: // for batchid in xrange(N): // output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = ( // data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn]) nvinfer1::IGatherLayer* layer = params->converter->network()->addGather( - *params_tensor.tensor(), *indices_tensor.tensor(), trt_axis); + *params_tensor, *indices_tensor.tensor(), trt_axis); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - nvinfer1::ITensor* gather_output = layer->getOutput(0); - nvinfer1::Dims trt_gather_output_dims = gather_output->getDimensions(); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + nvinfer1::Dims trt_gather_output_dims = output_tensor->getDimensions(); // Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT, // and the other is for the output dimension that is squeezed by IGatherLayer // because of the implicit batch dim in the indices (see the above note). + // const int expected_trt_output_rank = ; if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) { return errors::Internal( "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ", @@ -3878,16 +3885,18 @@ Status ConvertGather(OpConverterParams* params) { } // Reshape the output so after adding the implicit batch dim it'll match the // output shape of TF GatherV2. 
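+  // E.g. (illustrative) params dims {1,2,3}, indices dims {1}, trt_axis 2:
+  // IGatherLayer produces {1,2,1}; inserting a 1 at axis 2 gives {1,2,1,1},
+  // which matches the TF GatherV2 shape once the implicit batch dim is added.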
- for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) { - trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1]; + if (params_input.is_tensor()) { + for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) { + trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1]; + } + trt_gather_output_dims.d[trt_axis] = 1; + ++trt_gather_output_dims.nbDims; + + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(output_tensor), trt_gather_output_dims, + /*validation_only=*/false, &output_tensor)); } - trt_gather_output_dims.d[trt_axis] = 1; - ++trt_gather_output_dims.nbDims; - nvinfer1::ITensor* output_tensor = nullptr; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(gather_output), trt_gather_output_dims, - /*validation_only=*/false, &output_tensor)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 51a913d6220c7a..fab9da87e604ee 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -3869,9 +3869,9 @@ void TestConvertGather(OpConverterTest* test) { {3, 1, 2, 2}, {1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}, true}, - TestParams{{1, 2, 3}, {}, {0}, 0, {1, 2, 3}, {1, 2, 3, 4, 5, 6}, false}, - TestParams{{3, 2}, {2}, {0, 1}, 0, {1, 2, 2}, {1, 2, 3, 4}, false}, - TestParams{{2, 3}, {1, 2}, {0, 1}, 0, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, false} + TestParams{{1, 2, 3}, {}, {0}, 0, {2, 3}, {1, 2, 3, 4, 5, 6}, false}, + TestParams{{3, 2}, {2}, {0, 1}, 0, {2, 2}, {1, 2, 3, 4}, false}, + TestParams{{2, 3}, {1, 2}, {0, 1}, 0, {1, 2, 3}, {1, 2, 3, 4, 5, 6}, false} }; // Ok. From f1bafd331d854c54d4c0eea86051d7872bad3833 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Mon, 10 Jun 2019 11:03:51 -0700 Subject: [PATCH 4/7] Fix formatting and compilation errors --- .../tf2tensorrt/convert/convert_nodes.cc | 43 +++++++++++-------- .../tf2tensorrt/convert/convert_nodes_test.cc | 16 +++---- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 40d68f8f7074b2..5f069a6ee44817 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -3799,28 +3799,32 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { Status ConvertGather(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - const auto& params_input = inputs.at(0); - const auto& indices_input = input.at(1); - const auto& axis_input = input.at(2); // TODO(tmorris): Use CheckInputsWeights by changing bool to enum with an // option for an input to be either tensor or weight. 
if (inputs.size() != 3) { return errors::InvalidArgument("GatherV2 got ", inputs.size(), - " inputs but expected 3, at ", node_def.name()); + " inputs but expected 3, at ", + node_def.name()); } + const auto& params_input = inputs.at(0); + const auto& indices_input = inputs.at(1); + const auto& axis_input = inputs.at(2); if (!axis_input.is_weights()) { - return errors::Unimplemented("The input \"axis\" for GatherV2 must be a constant, at ", node_def.name()); + return errors::Unimplemented( + "The input \"axis\" for GatherV2 must be a constant, at ", + node_def.name()); } - if (!indices_tensor.is_tensor()) { + if (!indices_input.is_tensor()) { return errors::Unimplemented( - "The input \"indices\" for GatherV2 must be a tensor, at ", node_def.name()); + "The input \"indices\" for GatherV2 must be a tensor, at ", + node_def.name()); } TF_RETURN_IF_ERROR(AllowDataTypes( *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}, /*dtype_attr_name=*/"Tparams")); - TF_RETURN_IF_ERROR(AllowDataTypes( - *params, {DataType::DT_INT32}, /*dtype_attr_name=*/"Tindices")); + TF_RETURN_IF_ERROR(AllowDataTypes(*params, {DataType::DT_INT32}, + /*dtype_attr_name=*/"Tindices")); absl::Span axis = axis_input.weights().GetSpan(); if (axis.size() != 1) { @@ -3831,20 +3835,20 @@ Status ConvertGather(OpConverterParams* params) { TF_RETURN_IF_ERROR(ConvertAxis(axis[0], params_input.GetTrtDims().nbDims, node_def.name(), params_input.is_tensor(), &trt_axis)); - - if (params_input.is_weights() && trt_axis != 0){ + if (params_input.is_weights() && trt_axis != 0) { return errors::Unimplemented( "The input axis must be zero when params is a weight."); } - if (params_input.is_tensor() && indices_tensor.batch_size() != 1) { + if (params_input.is_tensor() && indices_input.batch_size() != 1) { return errors::InvalidArgument( "Only indices with batch 1 are supported when params is a tensor."); } // Both input are tensors, and the TF gather result will have rank: // (params.nbDims + 1) + (indices.nbDims + 1) - 1, // where "+ 1" adds the batch dim. - const int params_tf_rank = params_input.GetTrtDims().nbDims + (params_input.is_tensor() ? 1 : 0); - const int indices_tf_rank = indices_tensor.GetTrtDims().nbDims + 1; + const int params_tf_rank = + params_input.GetTrtDims().nbDims + (params_input.is_tensor() ? 1 : 0); + const int indices_tf_rank = indices_input.GetTrtDims().nbDims + 1; const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1; if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) { return errors::InvalidArgument( @@ -3856,8 +3860,8 @@ Status ConvertGather(OpConverterParams* params) { // Convert params to tensor is it is a weight. 
nvinfer1::ITensor* params_tensor = nullptr; if (params_input.is_weights()) { - params_tensor = TRT_TensorOrWeights(params->converter->CreateConstantLayer( - params_input.weights(), params_input.GetTrtDims())); + params_tensor = params->converter->CreateConstantLayer( + params_input.weights(), params_input.GetTrtDims()); } else { params_tensor = params_input.tensor(); } @@ -3868,7 +3872,7 @@ Status ConvertGather(OpConverterParams* params) { // output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = ( // data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn]) nvinfer1::IGatherLayer* layer = params->converter->network()->addGather( - *params_tensor, *indices_tensor.tensor(), trt_axis); + *params_tensor, *indices_input.tensor(), trt_axis); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -3876,8 +3880,9 @@ Status ConvertGather(OpConverterParams* params) { // Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT, // and the other is for the output dimension that is squeezed by IGatherLayer // because of the implicit batch dim in the indices (see the above note). - // const int expected_trt_output_rank = ; - if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) { + const int expected_trt_output_rank = + tf_gather_output_rank - (params_input.is_tensor() ? 2 : 1); + if (trt_gather_output_dims.nbDims != expected_trt_output_rank) { return errors::Internal( "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ", tf_gather_output_rank - 2, ", actual nbDims: ", diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index fab9da87e604ee..4837030515cf41 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -3871,12 +3871,12 @@ void TestConvertGather(OpConverterTest* test) { true}, TestParams{{1, 2, 3}, {}, {0}, 0, {2, 3}, {1, 2, 3, 4, 5, 6}, false}, TestParams{{3, 2}, {2}, {0, 1}, 0, {2, 2}, {1, 2, 3, 4}, false}, - TestParams{{2, 3}, {1, 2}, {0, 1}, 0, {1, 2, 3}, {1, 2, 3, 4, 5, 6}, false} + TestParams{ + {2, 3}, {1, 2}, {0, 1}, 0, {1, 2, 3}, {1, 2, 3, 4, 5, 6}, false}, }; // Ok. 
for (int i = 0; i < kGatherOKCases; i++) { - //for (const bool params_is_tensor : {true, false}) { test->Reset(); if (ok_params[i].params_is_tensor) { test->AddTestTensor("params", ok_params[i].params_dims, 1, @@ -3966,15 +3966,13 @@ TEST_F(OpConverterTest, ConvertGather) { { // Axis is not equal zero and params is a weights Reset(); - AddTestWeights("params", {3}, {1, 2, 3}); + AddTestWeights("params", {1, 3}, {1, 2, 3}); AddTestTensor("indices", {2}); AddTestWeights("axis", {1}, {1}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "The input axis must be a zero," - " in case of params is a weights"); - } - - + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input axis must be zero when params is a weight."); + } Reset(); TestConvertGather(this); From dc1cb8f1ef0c20062e77df5244102fe32bdfdf09 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Mon, 10 Jun 2019 11:12:11 -0700 Subject: [PATCH 5/7] Add test case for when indices has a batch size != 1 and params is a tensor --- .../compiler/tf2tensorrt/convert/convert_nodes.cc | 4 ++-- .../tf2tensorrt/convert/convert_nodes_test.cc | 12 +++++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 5f069a6ee44817..06d3df01a536fd 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -3840,8 +3840,8 @@ Status ConvertGather(OpConverterParams* params) { "The input axis must be zero when params is a weight."); } if (params_input.is_tensor() && indices_input.batch_size() != 1) { - return errors::InvalidArgument( - "Only indices with batch 1 are supported when params is a tensor."); + return errors::Unimplemented( + "Indices must have a batch size of 1 when params is a tensor."); } // Both input are tensors, and the TF gather result will have rank: // (params.nbDims + 1) + (indices.nbDims + 1) - 1, diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 4837030515cf41..f197096a423ed5 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -3964,7 +3964,7 @@ TEST_F(OpConverterTest, ConvertGather) { "batch dimension, at my_gather"); } { - // Axis is not equal zero and params is a weights + // Axis is not zero when params is a weight, should fail. Reset(); AddTestWeights("params", {1, 3}, {1, 2, 3}); AddTestTensor("indices", {2}); @@ -3973,6 +3973,16 @@ TEST_F(OpConverterTest, ConvertGather) { node_def, error::UNIMPLEMENTED, "The input axis must be zero when params is a weight."); } + { + // Batch size of indices is not 1 when params is a tensor. 
+ Reset(); + AddTestTensor("params", {1, 2, 3}, /*batch_size=*/2); + AddTestTensor("indices", {2}, /*batch_size=*/2); + AddTestWeights("axis", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Indices must have a batch size of 1 when params is a tensor."); + } Reset(); TestConvertGather(this); From daef23bbc8a3ca5780c4e188940460ba1dbf6c63 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 12 Jun 2019 11:23:03 -0700 Subject: [PATCH 6/7] Update comment and use expected_trt_output_rank in error message --- tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 06d3df01a536fd..13dba293646170 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -3845,7 +3845,8 @@ Status ConvertGather(OpConverterParams* params) { } // Both input are tensors, and the TF gather result will have rank: // (params.nbDims + 1) + (indices.nbDims + 1) - 1, - // where "+ 1" adds the batch dim. + // where "+ 1" adds the batch dim. If params is a weight, the TRT rank matches + // the TF rank so we don't have to add + 1. const int params_tf_rank = params_input.GetTrtDims().nbDims + (params_input.is_tensor() ? 1 : 0); const int indices_tf_rank = indices_input.GetTrtDims().nbDims + 1; @@ -3885,7 +3886,7 @@ Status ConvertGather(OpConverterParams* params) { if (trt_gather_output_dims.nbDims != expected_trt_output_rank) { return errors::Internal( "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ", - tf_gather_output_rank - 2, ", actual nbDims: ", + expected_trt_output_rank, ", actual nbDims: ", trt_gather_output_dims.nbDims); } // Reshape the output so after adding the implicit batch dim it'll match the From 185225b535e0527cba1b7b6748ae96e67fac8e28 Mon Sep 17 00:00:00 2001 From: Guangda Lai <31743510+aaroey@users.noreply.github.com> Date: Thu, 13 Jun 2019 14:40:12 -0700 Subject: [PATCH 7/7] Refactor GatherV2 test so it's easier to read/understand, and add test for the case where indices batch dim > 1 --- .../tf2tensorrt/convert/convert_nodes_test.cc | 161 ++++++++++++++---- 1 file changed, 129 insertions(+), 32 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index f197096a423ed5..66c6f9b800f087 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -3839,67 +3839,162 @@ void TestConvertGather(OpConverterTest* test) { const NodeDef& node_def = gather.operation.node()->def(); struct TestParams { - std::vector params_dims; - std::vector indices_dims; + // TF shape of the input 'params' (including batch dimension). + std::vector params_shape; + // TF shape of the input 'indices' (including batch dimension). + std::vector indices_shape; std::vector indices; int axis; - std::vector expected_output_dims; + // Expected TF shape of the output (including batch dimension). + std::vector expected_output_shape; std::vector expected_output; bool params_is_tensor; }; // Input is the same {1, 2, 3, 4, 5, 6} for all cases. 
- const int kGatherOKCases = 10; + const int kGatherOKCases = 11; const std::vector params_input = {CType(1), CType(2), CType(3), CType(4), CType(5), CType(6)}; TestParams ok_params[kGatherOKCases] = { // Vector indices, and output rank is rank(params). - TestParams{{1, 2, 3}, {}, {0}, 3, {1, 2, 1}, {1, 4}, true}, - TestParams{{1, 2, 3}, {}, {1}, 2, {1, 1, 3}, {4, 5, 6}, true}, + TestParams{ + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1}, + /*indices=*/{0}, + /*axis=*/3, + /*expected_output_shape=*/{1, 1, 2, 1}, + /*expected_output=*/{1, 4}, + /*params_is_tensor=*/true, + }, + TestParams{ + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1}, + /*indices=*/{1}, + /*axis=*/2, + /*expected_output_shape=*/{1, 1, 1, 3}, + /*expected_output=*/{4, 5, 6}, + /*params_is_tensor=*/true, + }, // Indices with rank>1, and output rank is rank(params)+rank(indices)-1. - TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1, 1}, {1, 4}, true}, - TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1, 1}, {2, 5}, true}, - TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1, 1}, {3, 6}, true}, TestParams{ - {1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 1, 3}, {3, 1, 2, 6, 4, 5}, true}, - TestParams{{3, 2}, - {2, 2}, - {0, 0, 1, 0}, - 2, - {3, 1, 2, 2}, - {1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}, - true}, - TestParams{{1, 2, 3}, {}, {0}, 0, {2, 3}, {1, 2, 3, 4, 5, 6}, false}, - TestParams{{3, 2}, {2}, {0, 1}, 0, {2, 2}, {1, 2, 3, 4}, false}, + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1, 1}, + /*indices=*/{0}, + /*axis=*/3, + /*expected_output_shape=*/{1, 1, 2, 1, 1}, + /*expected_output=*/{1, 4}, + /*params_is_tensor=*/true, + }, + TestParams{ + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1, 1}, + /*indices=*/{1}, + /*axis=*/3, + /*expected_output_shape=*/{1, 1, 2, 1, 1}, + /*expected_output=*/{2, 5}, + /*params_is_tensor=*/true, + }, + TestParams{ + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1, 1}, + /*indices=*/{2}, + /*axis=*/-1, + /*expected_output_shape=*/{1, 1, 2, 1, 1}, + /*expected_output=*/{3, 6}, + /*params_is_tensor=*/true, + }, + TestParams{ + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1, 3}, + /*indices=*/{2, 0, 1}, + /*axis=*/3, + /*expected_output_shape=*/{1, 1, 2, 1, 3}, + /*expected_output=*/{3, 1, 2, 6, 4, 5}, + /*params_is_tensor=*/true, + }, TestParams{ - {2, 3}, {1, 2}, {0, 1}, 0, {1, 2, 3}, {1, 2, 3, 4, 5, 6}, false}, + /*params_shape=*/{1, 3, 2}, + /*indices_shape=*/{1, 2, 2}, + /*indices=*/{0, 0, 1, 0}, + /*axis=*/2, + /*expected_output_shape=*/{1, 3, 1, 2, 2}, + /*expected_output=*/{1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}, + /*params_is_tensor=*/true, + }, + TestParams{ + /*params_shape=*/{1, 2, 3}, + /*indices_shape=*/{1}, + /*indices=*/{0}, + /*axis=*/0, + /*expected_output_shape=*/{1, 2, 3}, + /*expected_output=*/{1, 2, 3, 4, 5, 6}, + /*params_is_tensor=*/false, + }, + TestParams{ + /*params_shape=*/{3, 2}, + /*indices_shape=*/{1, 2}, + /*indices=*/{0, 1}, + /*axis=*/0, + /*expected_output_shape=*/{1, 2, 2}, + /*expected_output=*/{1, 2, 3, 4}, + /*params_is_tensor=*/false, + }, + TestParams{ + /*params_shape=*/{2, 3}, + /*indices_shape=*/{1, 1, 2}, + /*indices=*/{0, 1}, + /*axis=*/0, + /*expected_output_shape=*/{1, 1, 2, 3}, + /*expected_output=*/{1, 2, 3, 4, 5, 6}, + /*params_is_tensor=*/false, + }, + TestParams{ + /*params_shape=*/{3, 2}, + /*indices_shape=*/{2, 2}, + /*indices=*/{0, 2, 1, 0}, + /*axis=*/0, + /*expected_output_shape=*/{2, 2, 2}, + /*expected_output=*/{1, 2, 5, 6, 3, 4, 1, 2}, + /*params_is_tensor=*/false, + }, }; // Ok. 
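+  // For each case, build the network with "params" as either a tensor or
+  // weights (per params_is_tensor), run conversion, and compare the output
+  // dims and values against expected_output_shape / expected_output.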
for (int i = 0; i < kGatherOKCases; i++) { test->Reset(); + const auto& params_shape = ok_params[i].params_shape; if (ok_params[i].params_is_tensor) { - test->AddTestTensor("params", ok_params[i].params_dims, 1, + std::vector params_dims(params_shape.begin() + 1, + params_shape.end()); + test->AddTestTensor("params", params_dims, params_shape[0], TfDataTypeToTrt(dtype)); } else { - test->AddTestWeights("params", ok_params[i].params_dims, - params_input); + test->AddTestWeights("params", params_shape, params_input); } - test->AddTestTensor("indices", ok_params[i].indices_dims, 1, - nvinfer1::DataType::kINT32); + const auto& indices_shape = ok_params[i].indices_shape; + test->AddTestTensor( + "indices", + std::vector(indices_shape.begin() + 1, indices_shape.end()), + indices_shape[0], nvinfer1::DataType::kINT32); test->AddTestWeights("axis", {1}, {ok_params[i].axis}); test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output)); ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + + const auto& expected_output_shape = ok_params[i].expected_output_shape; + const auto& expected_output = ok_params[i].expected_output; + ASSERT_EQ(expected_output.size(), + TrtWeightDimsNumElements(GetTestDims(expected_output_shape))); + const std::vector expected_output_dims( + expected_output_shape.begin() + 1, expected_output_shape.end()); + ExpectTrtDimsEqualsArray(expected_output_dims, output.tensor()->getDimensions()); // Create input in CType and convert expected output to CType. - std::vector converted_expected_output( - ok_params[i].expected_output.begin(), - ok_params[i].expected_output.end()); + std::vector converted_expected_output(expected_output.begin(), + expected_output.end()); DataVec input_data; if (ok_params[i].params_is_tensor) { @@ -3909,9 +4004,11 @@ void TestConvertGather(OpConverterTest* test) { input_data = {{"indices", test::AsTensor(ok_params[i].indices)}}; } DataVec output_data{ - {"my_gather", - ConstructTensor(ok_params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + {"my_gather", ConstructTensor(expected_output.size())}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32, + /*batch_size=*/expected_output_shape[0]); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(converted_expected_output)); }