diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 7017c9a170661e..13dba293646170 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -521,12 +521,11 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value, } // Convert an axis from TF format to TRT format while validating. TF format -// includes the batch dimension, while TRT does not. TF can also use negative -// indices. -// TODO(tmorris): Use this method in more ops. +// includes the batch dimension, while TRT does not if implicit batching is used +// (i.e. for tensors). TF can also use negative indices. Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name, - int* trt_axis) { - const int tf_nb_dims = trt_nb_dims + 1; + bool use_implicit_batch, int* trt_axis) { + const int tf_nb_dims = trt_nb_dims + (use_implicit_batch ? 1 : 0); // Check bounds. if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) { return errors::InvalidArgument( @@ -536,13 +535,13 @@ Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name, // Make negative axis positive. if (tf_axis < 0) tf_axis += tf_nb_dims; // Don't allow axis to be the batch dimension. - if (tf_axis == 0) { + if (use_implicit_batch && tf_axis == 0) { return errors::Unimplemented( "TensorRT does not allow manipulation of the batch dimension, at ", node_name); } - // Remove batch dimension. - *trt_axis = tf_axis - 1; + // Remove batch dimension if it is implicit. + *trt_axis = use_implicit_batch ? tf_axis - 1 : tf_axis; return Status::OK(); } @@ -2062,8 +2061,8 @@ Status ConvertExpandDims(OpConverterParams* params) { // Use rank = nbDims + 1 for ConvertAxis's bounds checking to account for // ExpandDim's ability to add an axis at end of the shape. int trt_axis; - TF_RETURN_IF_ERROR( - ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); if (params->validation_only) return Status::OK(); // ExpandDims: Insert new dim of size 1. @@ -2098,8 +2097,8 @@ Status ConvertSqueeze(OpConverterParams* params) { for (int tf_axis : squeeze_dims) { // Make sure axis is valid. int trt_axis; - TF_RETURN_IF_ERROR( - ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); // Make sure target dimension is size 1. if (input_dims[trt_axis] != 1) { return errors::InvalidArgument( @@ -3294,9 +3293,9 @@ Status ConvertReduce(OpConverterParams* params) { } for (int i = 0; i < tf_axes_list.size(); i++) { int trt_axis; - TF_RETURN_IF_ERROR(ConvertAxis(tf_axes_list[i], - tensor->getDimensions().nbDims, - node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR( + ConvertAxis(tf_axes_list[i], tensor->getDimensions().nbDims, + node_def.name(), /*use_implicit_batch=*/true, &trt_axis)); axes |= (1 << trt_axis); } @@ -3363,8 +3362,8 @@ Status ConvertPack(OpConverterParams* params) { const nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); const int64 tf_axis = attrs.get("axis"); int trt_axis; - TF_RETURN_IF_ERROR( - ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); // Compute expanded dimensions and then reshape input tensors. std::vector tensor_dims(dims.d, dims.d + dims.nbDims); @@ -3511,8 +3510,8 @@ Status ConvertSplitHelper(OpConverterParams* params, const nvinfer1::Dims dims = input.GetTrtDims(); // Convert axis. int trt_axis; - TF_RETURN_IF_ERROR( - ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); // Dimension must equal num_splits for Unstack (when squeeze_after is true) if (squeeze_after && dims.d[trt_axis] != num_splits) { return errors::InvalidArgument( @@ -3640,8 +3639,8 @@ Status ConvertConcat(OpConverterParams* params) { } int trt_axis = 0; const auto dim = inputs.at(0).GetTrtDims(); - TF_RETURN_IF_ERROR( - ConvertAxis(axis[0], dim.nbDims, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dim.nbDims, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); // Check that dimensions match on non-concatenate axis. TF_RETURN_IF_ERROR(VerifyShapesMatch( absl::Span(inputs).first(num_inputs), trt_axis, @@ -3800,29 +3799,58 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { Status ConvertGather(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - TF_RETURN_IF_ERROR(CheckInputsWeights( - *params, {{"params", false}, {"indices", false}, {"axis", true}})); + // TODO(tmorris): Use CheckInputsWeights by changing bool to enum with an + // option for an input to be either tensor or weight. + if (inputs.size() != 3) { + return errors::InvalidArgument("GatherV2 got ", inputs.size(), + " inputs but expected 3, at ", + node_def.name()); + } + const auto& params_input = inputs.at(0); + const auto& indices_input = inputs.at(1); + const auto& axis_input = inputs.at(2); + if (!axis_input.is_weights()) { + return errors::Unimplemented( + "The input \"axis\" for GatherV2 must be a constant, at ", + node_def.name()); + } + if (!indices_input.is_tensor()) { + return errors::Unimplemented( + "The input \"indices\" for GatherV2 must be a tensor, at ", + node_def.name()); + } + TF_RETURN_IF_ERROR(AllowDataTypes( *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}, /*dtype_attr_name=*/"Tparams")); - absl::Span axis = inputs.at(2).weights().GetSpan(); + TF_RETURN_IF_ERROR(AllowDataTypes(*params, {DataType::DT_INT32}, + /*dtype_attr_name=*/"Tindices")); + + absl::Span axis = axis_input.weights().GetSpan(); if (axis.size() != 1) { return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ", node_def.name()); } int trt_axis = 0; - TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims, - node_def.name(), &trt_axis)); - const TRT_TensorOrWeights& params_tensor = inputs.at(0); - const TRT_TensorOrWeights& indices_tensor = inputs.at(1); - if (indices_tensor.batch_size() != 1) { - return errors::InvalidArgument("Only indices with batch 1 are supported."); + TF_RETURN_IF_ERROR(ConvertAxis(axis[0], params_input.GetTrtDims().nbDims, + node_def.name(), params_input.is_tensor(), + &trt_axis)); + if (params_input.is_weights() && trt_axis != 0) { + return errors::Unimplemented( + "The input axis must be zero when params is a weight."); + } + if (params_input.is_tensor() && indices_input.batch_size() != 1) { + return errors::Unimplemented( + "Indices must have a batch size of 1 when params is a tensor."); } // Both input are tensors, and the TF gather result will have rank: // (params.nbDims + 1) + (indices.nbDims + 1) - 1, - // where "+ 1" adds the batch dim. - const int tf_gather_output_rank = params_tensor.GetTrtDims().nbDims + - indices_tensor.GetTrtDims().nbDims + 1; + // where "+ 1" adds the batch dim. If params is a weight, the TRT rank matches + // the TF rank so we don't have to add + 1. + const int params_tf_rank = + params_input.GetTrtDims().nbDims + (params_input.is_tensor() ? 1 : 0); + const int indices_tf_rank = indices_input.GetTrtDims().nbDims + 1; + const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1; if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) { return errors::InvalidArgument( "Result of gather has dimension greater than ", @@ -3830,38 +3858,50 @@ Status ConvertGather(OpConverterParams* params) { } if (params->validation_only) return Status::OK(); + // Convert params to tensor is it is a weight. + nvinfer1::ITensor* params_tensor = nullptr; + if (params_input.is_weights()) { + params_tensor = params->converter->CreateConstantLayer( + params_input.weights(), params_input.GetTrtDims()); + } else { + params_tensor = params_input.tensor(); + } + // Note on how IGatherLayer works: if both the data and indices tensors have // a batch size dimension of size N, it performs: // for batchid in xrange(N): // output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = ( // data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn]) nvinfer1::IGatherLayer* layer = params->converter->network()->addGather( - *params_tensor.tensor(), *indices_tensor.tensor(), trt_axis); + *params_tensor, *indices_input.tensor(), trt_axis); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - nvinfer1::ITensor* gather_output = layer->getOutput(0); - nvinfer1::Dims trt_gather_output_dims = gather_output->getDimensions(); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + nvinfer1::Dims trt_gather_output_dims = output_tensor->getDimensions(); // Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT, // and the other is for the output dimension that is squeezed by IGatherLayer // because of the implicit batch dim in the indices (see the above note). - if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) { + const int expected_trt_output_rank = + tf_gather_output_rank - (params_input.is_tensor() ? 2 : 1); + if (trt_gather_output_dims.nbDims != expected_trt_output_rank) { return errors::Internal( "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ", - tf_gather_output_rank - 2, - ", actual nbDims: ", trt_gather_output_dims.nbDims); + expected_trt_output_rank, ", actual nbDims: ", + trt_gather_output_dims.nbDims); } // Reshape the output so after adding the implicit batch dim it'll match the // output shape of TF GatherV2. - for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) { - trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1]; - } - trt_gather_output_dims.d[trt_axis] = 1; - ++trt_gather_output_dims.nbDims; + if (params_input.is_tensor()) { + for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) { + trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1]; + } + trt_gather_output_dims.d[trt_axis] = 1; + ++trt_gather_output_dims.nbDims; - nvinfer1::ITensor* output_tensor = nullptr; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - TRT_TensorOrWeights(gather_output), trt_gather_output_dims, - /*validation_only=*/false, &output_tensor)); + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(output_tensor), trt_gather_output_dims, + /*validation_only=*/false, &output_tensor)); + } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -4121,8 +4161,8 @@ Status ConvertArgMinMax(OpConverterParams* params) { int tf_axis = inputs.at(1).weights().GetSpan()[0]; int trt_axis; nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); - TF_RETURN_IF_ERROR( - ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis)); + TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), + /*use_implicit_batch=*/true, &trt_axis)); nvinfer1::TopKOperation topk_op; if (node_def.op() == "ArgMin") { topk_op = nvinfer1::TopKOperation::kMIN; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 09b7a60c0839c9..66c6f9b800f087 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -3839,63 +3839,176 @@ void TestConvertGather(OpConverterTest* test) { const NodeDef& node_def = gather.operation.node()->def(); struct TestParams { - std::vector params_dims; - std::vector indices_dims; + // TF shape of the input 'params' (including batch dimension). + std::vector params_shape; + // TF shape of the input 'indices' (including batch dimension). + std::vector indices_shape; std::vector indices; int axis; - std::vector expected_output_dims; + // Expected TF shape of the output (including batch dimension). + std::vector expected_output_shape; std::vector expected_output; + bool params_is_tensor; }; // Input is the same {1, 2, 3, 4, 5, 6} for all cases. - const int kGatherOKCases = 7; + const int kGatherOKCases = 11; const std::vector params_input = {CType(1), CType(2), CType(3), CType(4), CType(5), CType(6)}; TestParams ok_params[kGatherOKCases] = { // Vector indices, and output rank is rank(params). - TestParams{{1, 2, 3}, {}, {0}, 3, {1, 2, 1}, {1, 4}}, - TestParams{{1, 2, 3}, {}, {1}, 2, {1, 1, 3}, {4, 5, 6}}, + TestParams{ + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1}, + /*indices=*/{0}, + /*axis=*/3, + /*expected_output_shape=*/{1, 1, 2, 1}, + /*expected_output=*/{1, 4}, + /*params_is_tensor=*/true, + }, + TestParams{ + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1}, + /*indices=*/{1}, + /*axis=*/2, + /*expected_output_shape=*/{1, 1, 1, 3}, + /*expected_output=*/{4, 5, 6}, + /*params_is_tensor=*/true, + }, // Indices with rank>1, and output rank is rank(params)+rank(indices)-1. - TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1, 1}, {1, 4}}, - TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1, 1}, {2, 5}}, - TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1, 1}, {3, 6}}, TestParams{ - {1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 1, 3}, {3, 1, 2, 6, 4, 5}}, - TestParams{{3, 2}, - {2, 2}, - {0, 0, 1, 0}, - 2, - {3, 1, 2, 2}, - {1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}}, + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1, 1}, + /*indices=*/{0}, + /*axis=*/3, + /*expected_output_shape=*/{1, 1, 2, 1, 1}, + /*expected_output=*/{1, 4}, + /*params_is_tensor=*/true, + }, + TestParams{ + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1, 1}, + /*indices=*/{1}, + /*axis=*/3, + /*expected_output_shape=*/{1, 1, 2, 1, 1}, + /*expected_output=*/{2, 5}, + /*params_is_tensor=*/true, + }, + TestParams{ + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1, 1}, + /*indices=*/{2}, + /*axis=*/-1, + /*expected_output_shape=*/{1, 1, 2, 1, 1}, + /*expected_output=*/{3, 6}, + /*params_is_tensor=*/true, + }, + TestParams{ + /*params_shape=*/{1, 1, 2, 3}, + /*indices_shape=*/{1, 3}, + /*indices=*/{2, 0, 1}, + /*axis=*/3, + /*expected_output_shape=*/{1, 1, 2, 1, 3}, + /*expected_output=*/{3, 1, 2, 6, 4, 5}, + /*params_is_tensor=*/true, + }, + TestParams{ + /*params_shape=*/{1, 3, 2}, + /*indices_shape=*/{1, 2, 2}, + /*indices=*/{0, 0, 1, 0}, + /*axis=*/2, + /*expected_output_shape=*/{1, 3, 1, 2, 2}, + /*expected_output=*/{1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}, + /*params_is_tensor=*/true, + }, + TestParams{ + /*params_shape=*/{1, 2, 3}, + /*indices_shape=*/{1}, + /*indices=*/{0}, + /*axis=*/0, + /*expected_output_shape=*/{1, 2, 3}, + /*expected_output=*/{1, 2, 3, 4, 5, 6}, + /*params_is_tensor=*/false, + }, + TestParams{ + /*params_shape=*/{3, 2}, + /*indices_shape=*/{1, 2}, + /*indices=*/{0, 1}, + /*axis=*/0, + /*expected_output_shape=*/{1, 2, 2}, + /*expected_output=*/{1, 2, 3, 4}, + /*params_is_tensor=*/false, + }, + TestParams{ + /*params_shape=*/{2, 3}, + /*indices_shape=*/{1, 1, 2}, + /*indices=*/{0, 1}, + /*axis=*/0, + /*expected_output_shape=*/{1, 1, 2, 3}, + /*expected_output=*/{1, 2, 3, 4, 5, 6}, + /*params_is_tensor=*/false, + }, + TestParams{ + /*params_shape=*/{3, 2}, + /*indices_shape=*/{2, 2}, + /*indices=*/{0, 2, 1, 0}, + /*axis=*/0, + /*expected_output_shape=*/{2, 2, 2}, + /*expected_output=*/{1, 2, 5, 6, 3, 4, 1, 2}, + /*params_is_tensor=*/false, + }, }; // Ok. for (int i = 0; i < kGatherOKCases; i++) { test->Reset(); - test->AddTestTensor("params", ok_params[i].params_dims, 1, - TfDataTypeToTrt(dtype)); - test->AddTestTensor("indices", ok_params[i].indices_dims, 1, - nvinfer1::DataType::kINT32); + const auto& params_shape = ok_params[i].params_shape; + if (ok_params[i].params_is_tensor) { + std::vector params_dims(params_shape.begin() + 1, + params_shape.end()); + test->AddTestTensor("params", params_dims, params_shape[0], + TfDataTypeToTrt(dtype)); + } else { + test->AddTestWeights("params", params_shape, params_input); + } + + const auto& indices_shape = ok_params[i].indices_shape; + test->AddTestTensor( + "indices", + std::vector(indices_shape.begin() + 1, indices_shape.end()), + indices_shape[0], nvinfer1::DataType::kINT32); test->AddTestWeights("axis", {1}, {ok_params[i].axis}); test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output)); ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + + const auto& expected_output_shape = ok_params[i].expected_output_shape; + const auto& expected_output = ok_params[i].expected_output; + ASSERT_EQ(expected_output.size(), + TrtWeightDimsNumElements(GetTestDims(expected_output_shape))); + const std::vector expected_output_dims( + expected_output_shape.begin() + 1, expected_output_shape.end()); + ExpectTrtDimsEqualsArray(expected_output_dims, output.tensor()->getDimensions()); // Create input in CType and convert expected output to CType. - std::vector converted_expected_output( - ok_params[i].expected_output.begin(), - ok_params[i].expected_output.end()); + std::vector converted_expected_output(expected_output.begin(), + expected_output.end()); - const DataVec input_data{ - {"params", test::AsTensor(params_input)}, - {"indices", test::AsTensor(ok_params[i].indices)}}; + DataVec input_data; + if (ok_params[i].params_is_tensor) { + input_data = {{"params", test::AsTensor(params_input)}, + {"indices", test::AsTensor(ok_params[i].indices)}}; + } else { + input_data = {{"indices", test::AsTensor(ok_params[i].indices)}}; + } DataVec output_data{ - {"my_gather", - ConstructTensor(ok_params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + {"my_gather", ConstructTensor(expected_output.size())}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32, + /*batch_size=*/expected_output_shape[0]); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(converted_expected_output)); } @@ -3947,6 +4060,26 @@ TEST_F(OpConverterTest, ConvertGather) { "TensorRT does not allow manipulation of the " "batch dimension, at my_gather"); } + { + // Axis is not zero when params is a weight, should fail. + Reset(); + AddTestWeights("params", {1, 3}, {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input axis must be zero when params is a weight."); + } + { + // Batch size of indices is not 1 when params is a tensor. + Reset(); + AddTestTensor("params", {1, 2, 3}, /*batch_size=*/2); + AddTestTensor("indices", {2}, /*batch_size=*/2); + AddTestWeights("axis", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Indices must have a batch size of 1 when params is a tensor."); + } Reset(); TestConvertGather(this);