diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index f568b947959e7b..441a0f5eda2256 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -4641,7 +4641,6 @@ static void RegisterValidatableOpConverters( (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; (*registration)["ExpandDims"] = ConvertExpandDims; (*registration)["GatherV2"] = ConvertGather; - (*registration)["Identity"] = ConvertIdentity; // Identity should be removed (*registration)["LeakyRelu"] = ConvertLeakyRelu; (*registration)["MatMul"] = ConvertMatMul; (*registration)["Pack"] = ConvertPack; @@ -4650,7 +4649,6 @@ static void RegisterValidatableOpConverters( (*registration)["Reshape"] = ConvertReshape; (*registration)["Rsqrt"] = ConvertRsqrt; (*registration)["Slice"] = ConvertSlice; - (*registration)["Snapshot"] = ConvertIdentity; // Snapshot should be removed (*registration)["Softmax"] = ConvertSoftmax; (*registration)["SpaceToDepth"] = ConvertDepthSpaceShuffle; (*registration)["Split"] = ConvertSplit; @@ -4688,6 +4686,11 @@ static void RegisterValidatableOpConverters( for (auto arg_minmax_type : {"ArgMin", "ArgMax"}) { (*registration)[arg_minmax_type] = ConvertArgMinMax; } + // The following are no-ops during inference and will not be mapped to any TRT + // layer. + for (auto identity_op_type : {"Identity", "Snapshot", "StopGradient"}) { + (*registration)[identity_op_type] = ConvertIdentity; + } } void TrtNodeValidator::RegisterOpValidators() {