[TF-TRT] Support constant params input for gather #29766

Merged
140 changes: 90 additions & 50 deletions tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -521,12 +521,11 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
}

// Convert an axis from TF format to TRT format while validating. TF format
- // includes the batch dimension, while TRT does not. TF can also use negative
- // indices.
- // TODO(tmorris): Use this method in more ops.
+ // includes the batch dimension, while TRT does not if implicit batching is used
+ // (i.e. for tensors). TF can also use negative indices.
Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name,
- int* trt_axis) {
- const int tf_nb_dims = trt_nb_dims + 1;
+ bool use_implicit_batch, int* trt_axis) {
+ const int tf_nb_dims = trt_nb_dims + (use_implicit_batch ? 1 : 0);
// Check bounds.
if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) {
return errors::InvalidArgument(
@@ -536,13 +535,13 @@ Status ConvertAxis(int tf_axis, int trt_nb_dims, absl::string_view node_name,
// Make negative axis positive.
if (tf_axis < 0) tf_axis += tf_nb_dims;
// Don't allow axis to be the batch dimension.
- if (tf_axis == 0) {
+ if (use_implicit_batch && tf_axis == 0) {
return errors::Unimplemented(
"TensorRT does not allow manipulation of the batch dimension, at ",
node_name);
}
- // Remove batch dimension.
- *trt_axis = tf_axis - 1;
+ // Remove batch dimension if it is implicit.
+ *trt_axis = use_implicit_batch ? tf_axis - 1 : tf_axis;
return Status::OK();
}

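For reference, the mapping implemented by the updated ConvertAxis can be checked in isolation with a small standalone sketch. This is illustrative only (ConvertAxisSketch is a made-up name and nothing here uses the TF-TRT headers); it returns false where the real converter returns a non-OK Status.

```cpp
#include <cassert>
#include <iostream>

// Mirrors the axis mapping above: TF ranks include the batch dimension only
// when implicit batching is in effect, and the batch axis itself may not be
// manipulated in that mode.
bool ConvertAxisSketch(int tf_axis, int trt_nb_dims, bool use_implicit_batch,
                       int* trt_axis) {
  const int tf_nb_dims = trt_nb_dims + (use_implicit_batch ? 1 : 0);
  if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) return false;  // out of bounds
  if (tf_axis < 0) tf_axis += tf_nb_dims;                // make negative axis positive
  if (use_implicit_batch && tf_axis == 0) return false;  // batch dim is off limits
  *trt_axis = use_implicit_batch ? tf_axis - 1 : tf_axis;
  return true;
}

int main() {
  int trt_axis = -1;
  // TF shape [N, H, W, C] becomes TRT dims [H, W, C] under implicit batching,
  // so TF axis -1 (the channel dim) maps to TRT axis 2.
  bool ok = ConvertAxisSketch(-1, /*trt_nb_dims=*/3,
                              /*use_implicit_batch=*/true, &trt_axis);
  assert(ok && trt_axis == 2);
  // The batch axis is rejected while the batch is implicit.
  const bool batch_axis_ok = ConvertAxisSketch(0, 3, true, &trt_axis);
  assert(!batch_axis_ok);
  // Without an implicit batch (e.g. constant weights), TF and TRT ranks
  // already match, so axis 0 is legal and the index passes through unchanged.
  ok = ConvertAxisSketch(0, 3, false, &trt_axis);
  assert(ok && trt_axis == 0);
  std::cout << "axis mapping checks passed\n";
  return 0;
}
```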
@@ -2062,8 +2061,8 @@ Status ConvertExpandDims(OpConverterParams* params) {
// Use rank = nbDims + 1 for ConvertAxis's bounds checking to account for
// ExpandDim's ability to add an axis at end of the shape.
int trt_axis;
- TF_RETURN_IF_ERROR(
- ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(), &trt_axis));
+ TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(),
+ /*use_implicit_batch=*/true, &trt_axis));
if (params->validation_only) return Status::OK();

// ExpandDims: Insert new dim of size 1.
@@ -2098,8 +2097,8 @@ Status ConvertSqueeze(OpConverterParams* params) {
for (int tf_axis : squeeze_dims) {
// Make sure axis is valid.
int trt_axis;
- TF_RETURN_IF_ERROR(
- ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis));
+ TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
+ /*use_implicit_batch=*/true, &trt_axis));
// Make sure target dimension is size 1.
if (input_dims[trt_axis] != 1) {
return errors::InvalidArgument(
@@ -3294,9 +3293,9 @@ Status ConvertReduce(OpConverterParams* params) {
}
for (int i = 0; i < tf_axes_list.size(); i++) {
int trt_axis;
- TF_RETURN_IF_ERROR(ConvertAxis(tf_axes_list[i],
- tensor->getDimensions().nbDims,
- node_def.name(), &trt_axis));
+ TF_RETURN_IF_ERROR(
+ ConvertAxis(tf_axes_list[i], tensor->getDimensions().nbDims,
+ node_def.name(), /*use_implicit_batch=*/true, &trt_axis));
axes |= (1 << trt_axis);
}

@@ -3363,8 +3362,8 @@ Status ConvertPack(OpConverterParams* params) {
const nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
const int64 tf_axis = attrs.get<int64>("axis");
int trt_axis;
- TF_RETURN_IF_ERROR(
- ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(), &trt_axis));
+ TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(),
+ /*use_implicit_batch=*/true, &trt_axis));

// Compute expanded dimensions and then reshape input tensors.
std::vector<int> tensor_dims(dims.d, dims.d + dims.nbDims);
@@ -3511,8 +3510,8 @@ Status ConvertSplitHelper(OpConverterParams* params,
const nvinfer1::Dims dims = input.GetTrtDims();
// Convert axis.
int trt_axis;
- TF_RETURN_IF_ERROR(
- ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis));
+ TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
+ /*use_implicit_batch=*/true, &trt_axis));
// Dimension must equal num_splits for Unstack (when squeeze_after is true)
if (squeeze_after && dims.d[trt_axis] != num_splits) {
return errors::InvalidArgument(
@@ -3640,8 +3639,8 @@ Status ConvertConcat(OpConverterParams* params) {
}
int trt_axis = 0;
const auto dim = inputs.at(0).GetTrtDims();
- TF_RETURN_IF_ERROR(
- ConvertAxis(axis[0], dim.nbDims, node_def.name(), &trt_axis));
+ TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dim.nbDims, node_def.name(),
+ /*use_implicit_batch=*/true, &trt_axis));
// Check that dimensions match on non-concatenate axis.
TF_RETURN_IF_ERROR(VerifyShapesMatch(
absl::Span<const TRT_TensorOrWeights>(inputs).first(num_inputs), trt_axis,
@@ -3800,68 +3799,109 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
Status ConvertGather(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- TF_RETURN_IF_ERROR(CheckInputsWeights(
- *params, {{"params", false}, {"indices", false}, {"axis", true}}));
+ // TODO(tmorris): Use CheckInputsWeights by changing bool to enum with an
+ // option for an input to be either tensor or weight.
+ if (inputs.size() != 3) {
+ return errors::InvalidArgument("GatherV2 got ", inputs.size(),
+ " inputs but expected 3, at ",
+ node_def.name());
+ }
+ const auto& params_input = inputs.at(0);
+ const auto& indices_input = inputs.at(1);
+ const auto& axis_input = inputs.at(2);
+ if (!axis_input.is_weights()) {
+ return errors::Unimplemented(
+ "The input \"axis\" for GatherV2 must be a constant, at ",
+ node_def.name());
+ }
+ if (!indices_input.is_tensor()) {
+ return errors::Unimplemented(
+ "The input \"indices\" for GatherV2 must be a tensor, at ",
+ node_def.name());
+ }
+
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
/*dtype_attr_name=*/"Tparams"));
- absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
+ TF_RETURN_IF_ERROR(AllowDataTypes(*params, {DataType::DT_INT32},
+ /*dtype_attr_name=*/"Tindices"));
+
+ absl::Span<const int> axis = axis_input.weights().GetSpan<int>();
if (axis.size() != 1) {
return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ",
node_def.name());
}
int trt_axis = 0;
- TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims,
- node_def.name(), &trt_axis));
- const TRT_TensorOrWeights& params_tensor = inputs.at(0);
- const TRT_TensorOrWeights& indices_tensor = inputs.at(1);
- if (indices_tensor.batch_size() != 1) {
- return errors::InvalidArgument("Only indices with batch 1 are supported.");
+ TF_RETURN_IF_ERROR(ConvertAxis(axis[0], params_input.GetTrtDims().nbDims,
+ node_def.name(), params_input.is_tensor(),
+ &trt_axis));
+ if (params_input.is_weights() && trt_axis != 0) {
+ return errors::Unimplemented(
+ "The input axis must be zero when params is a weight.");
+ }
+ if (params_input.is_tensor() && indices_input.batch_size() != 1) {
+ return errors::Unimplemented(
+ "Indices must have a batch size of 1 when params is a tensor.");
}
// Both input are tensors, and the TF gather result will have rank:
// (params.nbDims + 1) + (indices.nbDims + 1) - 1,
// where "+ 1" adds the batch dim.
const int tf_gather_output_rank = params_tensor.GetTrtDims().nbDims +
indices_tensor.GetTrtDims().nbDims + 1;
// where "+ 1" adds the batch dim. If params is a weight, the TRT rank matches
// the TF rank so we don't have to add + 1.
const int params_tf_rank =
params_input.GetTrtDims().nbDims + (params_input.is_tensor() ? 1 : 0);
const int indices_tf_rank = indices_input.GetTrtDims().nbDims + 1;
const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1;
if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) {
return errors::InvalidArgument(
"Result of gather has dimension greater than ",
nvinfer1::Dims::MAX_DIMS + 1);
}
if (params->validation_only) return Status::OK();

+ // Convert params to a tensor if it is a weight.
+ nvinfer1::ITensor* params_tensor = nullptr;
+ if (params_input.is_weights()) {
+ params_tensor = params->converter->CreateConstantLayer(
+ params_input.weights(), params_input.GetTrtDims());
+ } else {
+ params_tensor = params_input.tensor();
+ }
+
// Note on how IGatherLayer works: if both the data and indices tensors have
// a batch size dimension of size N, it performs:
// for batchid in xrange(N):
// output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = (
// data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn])
nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
- *params_tensor.tensor(), *indices_tensor.tensor(), trt_axis);
+ *params_tensor, *indices_input.tensor(), trt_axis);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());

- nvinfer1::ITensor* gather_output = layer->getOutput(0);
- nvinfer1::Dims trt_gather_output_dims = gather_output->getDimensions();
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ nvinfer1::Dims trt_gather_output_dims = output_tensor->getDimensions();
// Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT,
// and the other is for the output dimension that is squeezed by IGatherLayer
// because of the implicit batch dim in the indices (see the above note).
- if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) {
+ const int expected_trt_output_rank =
+ tf_gather_output_rank - (params_input.is_tensor() ? 2 : 1);
+ if (trt_gather_output_dims.nbDims != expected_trt_output_rank) {
return errors::Internal(
"Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
- tf_gather_output_rank - 2,
- ", actual nbDims: ", trt_gather_output_dims.nbDims);
+ expected_trt_output_rank, ", actual nbDims: ",
+ trt_gather_output_dims.nbDims);
}
// Reshape the output so after adding the implicit batch dim it'll match the
// output shape of TF GatherV2.
- for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
- trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
- }
- trt_gather_output_dims.d[trt_axis] = 1;
- ++trt_gather_output_dims.nbDims;
+ if (params_input.is_tensor()) {
+ for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
+ trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
+ }
+ trt_gather_output_dims.d[trt_axis] = 1;
+ ++trt_gather_output_dims.nbDims;

- nvinfer1::ITensor* output_tensor = nullptr;
- TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
- TRT_TensorOrWeights(gather_output), trt_gather_output_dims,
- /*validation_only=*/false, &output_tensor));
+ TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
+ TRT_TensorOrWeights(output_tensor), trt_gather_output_dims,
+ /*validation_only=*/false, &output_tensor));
+ }

params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
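To make the rank bookkeeping in this hunk easier to follow, here is a small worked example as a standalone sketch (illustrative only; ExpectedTrtGatherRank is a made-up helper and no TensorRT calls are involved). It reproduces the params_tf_rank / indices_tf_rank arithmetic and the "- 2" vs. "- 1" adjustment for tensor vs. constant params.

```cpp
#include <cassert>
#include <iostream>

// Reproduces the rank arithmetic used by the gather converter: TF ranks are
// reconstructed from the TRT ranks, and the rank IGatherLayer is expected to
// produce depends on whether params arrived as a tensor (implicit batch dim)
// or as a constant weight (no batch dim).
int ExpectedTrtGatherRank(int params_trt_rank, int indices_trt_rank,
                          bool params_is_tensor) {
  const int params_tf_rank = params_trt_rank + (params_is_tensor ? 1 : 0);
  const int indices_tf_rank = indices_trt_rank + 1;  // indices are always tensors
  const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1;
  // "- 2" for tensor params: one for the TF-TRT implicit batch dim, one for
  // the dim IGatherLayer squeezes because of the batch dim in the indices.
  // "- 1" for weight params: only the indices' batch dim is involved.
  return tf_gather_output_rank - (params_is_tensor ? 2 : 1);
}

int main() {
  // Tensor params: TF [N, 5, 3] -> TRT [5, 3] (rank 2); TF indices [1, 4] ->
  // TRT [4] (rank 1). TF output rank is 3 + 2 - 1 = 4, so IGatherLayer should
  // yield rank 2, and the converter then re-inserts a size-1 dim at the
  // gather axis before the implicit batch dim is added back.
  assert(ExpectedTrtGatherRank(2, 1, /*params_is_tensor=*/true) == 2);
  // Constant params: weight of shape [5, 3] (rank 2, no batch dim), same
  // indices. TF output rank is 2 + 2 - 1 = 3 and IGatherLayer should already
  // yield rank 2, so no reshape is needed.
  assert(ExpectedTrtGatherRank(2, 1, /*params_is_tensor=*/false) == 2);
  std::cout << "gather rank checks passed\n";
  return 0;
}
```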
@@ -4121,8 +4161,8 @@ Status ConvertArgMinMax(OpConverterParams* params) {
int tf_axis = inputs.at(1).weights().GetSpan<int>()[0];
int trt_axis;
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
- TF_RETURN_IF_ERROR(
- ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis));
+ TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(),
+ /*use_implicit_batch=*/true, &trt_axis));
nvinfer1::TopKOperation topk_op;
if (node_def.op() == "ArgMin") {
topk_op = nvinfer1::TopKOperation::kMIN;