Skip to content

Commit

Permalink
Move TF-TRT C++ implementations to compiler/tf2tensorrt.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 230407626
  • Loading branch information
aaroey authored and tensorflower-gardener committed Jan 22, 2019
1 parent a4d4392 commit ea1d491
Show file tree
Hide file tree
Showing 55 changed files with 630 additions and 579 deletions.
440 changes: 440 additions & 0 deletions tensorflow/compiler/tf2tensorrt/BUILD

Large diffs are not rendered by default.

Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"

#include <fstream>
#include <list>
Expand All @@ -24,13 +24,14 @@ limitations under the License.
#include <utility>
#include <vector>

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
#include "tensorflow/contrib/tensorrt/test/utils.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
Expand Down Expand Up @@ -63,8 +64,8 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
namespace convert {
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
using absl::StrAppend;
using absl::StrCat;

// Returns compiled TRT version information {Maj, Min, Patch}
std::vector<int> GetLinkedTensorRTVersion() {
Expand Down Expand Up @@ -151,7 +152,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) {
if (precision_mode_ == INT8MODE && quantize_ops.count(node->type_string())) {
is_supported_op_type = true;
}
// LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc)
// LINT.ThenChange(//tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc)
if (!is_supported_op_type) {
return errors::Unimplemented("Op type ", node->type_string(),
" is not supported");
Expand Down
Expand Up @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_

#include <vector>

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
Expand Down Expand Up @@ -123,4 +123,4 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_
Expand Up @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/tensor_shape.h"
Expand Down
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"

#include <algorithm>
#include <cstring>
Expand All @@ -24,11 +24,12 @@ limitations under the License.
#include <utility>
#include <vector>

#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
Expand All @@ -43,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

Expand Down Expand Up @@ -81,9 +83,9 @@ const char* const kInputPHName = "TensorRTInputPH_";
const char* const kOutputPHName = "TensorRTOutputPH_";

namespace convert {
using absl::StrAppend;
using absl::StrCat;
using ::tensorflow::str_util::Split;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
nvinfer1::DataType* trt_dtype) {
Expand Down Expand Up @@ -3692,7 +3694,7 @@ tensorflow::Status ConvertGraphDefToEngine(
if (tensorflow::str_util::StartsWith(node_name, kInputPHName) &&
(node_def.op() == "Placeholder")) {
int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32(
if (!tensorflow::strings::safe_strto32( // non-absl ok
node_name.c_str() + strlen(kInputPHName), &slot_number)) {
return tensorflow::errors::InvalidArgument(
"Failed to parse slot number from ", node_name);
Expand Down Expand Up @@ -3721,7 +3723,7 @@ tensorflow::Status ConvertGraphDefToEngine(
} else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) &&
(node_def.op() == "Identity")) {
int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32(
if (!tensorflow::strings::safe_strto32( // non-absl ok
node_name.c_str() + strlen(kOutputPHName), &slot_number)) {
return tensorflow::errors::InvalidArgument(
"Failed to parse slot number from ", node_name);
Expand Down
Expand Up @@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_

#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
Expand Down Expand Up @@ -555,4 +555,4 @@ class Converter {
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_
Expand Up @@ -13,19 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"

#include <memory>
#include <unordered_map>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
Expand All @@ -51,7 +52,7 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {

using ::tensorflow::strings::StrCat;
using absl::StrCat;
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;

Expand Down Expand Up @@ -2316,7 +2317,7 @@ TEST_F(OpConverterTest, ConvertSqueeze) {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
ops::Squeeze::Attrs squeeze_attrs;
squeeze_attrs.axis_ = gtl::ArraySlice<int>(axis);
squeeze_attrs.axis_ = gtl::ArraySlice<int>(axis); // non-absl ok
auto squeeze =
ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs);
return squeeze.operation.node()->def();
Expand Down
Expand Up @@ -12,9 +12,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h"
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
Expand All @@ -30,9 +32,9 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {
// TODO(sami): Remove VLOG messages once the code matures
using absl::StrAppend;
using absl::StrCat;
using tensorflow::str_util::Uppercase;
using tensorflow::strings::StrAppend;
using tensorflow::strings::StrCat;

tensorflow::Status TRTOptimizationPass::Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
Expand Down Expand Up @@ -243,7 +245,7 @@ tensorflow::Status TRTOptimizationPass::Optimize(
// If the last token is not an integer, it must be part of the name.
// Otherwise it is port number.
if (tokens.size() > 1 &&
!strings::safe_strto32(tokens.back(), &dumm_port)) {
!strings::safe_strto32(tokens.back(), &dumm_port)) { // non-absl ok
StrAppend(&s, ":", tokens.back());
}
nodes_to_preserve.push_back(s);
Expand Down
Expand Up @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_

#include <string>

Expand Down Expand Up @@ -77,4 +77,4 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {

#endif // GOOGLE_CUDA
#endif // GOOGLE_TENSORRT
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_
#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
Expand Down
Expand Up @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_

#include <memory>

Expand Down Expand Up @@ -47,4 +47,4 @@ Status GetPrecisionMode(const string& name, int* precision_mode);
} // namespace tensorrt
} // namespace tensorflow

#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_
#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
Expand Up @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_
#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_

#include <memory>
#include <vector>

#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
Expand Down Expand Up @@ -70,4 +70,4 @@ REGISTER_KERNEL_BUILDER(Name("GetSerializedResourceOp").Device(DEVICE_GPU),

#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA
#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_
#endif // TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_
Expand Up @@ -18,7 +18,7 @@ limitations under the License.
#include <fstream>
#include <vector>

#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
Expand Down
Expand Up @@ -12,18 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
#include "tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.h"

#include <algorithm>

#include "absl/memory/memory.h"
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/contrib/tensorrt/test/utils.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
Expand All @@ -39,9 +40,9 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
static Logger logger;
using absl::StrAppend;
using absl::StrCat;
using ::nvinfer1::IRuntime;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

// A helper class to call done() when destructed for asynchronous execution.
// Helps simultaneous execution of native and TRT engines.
Expand Down

0 comments on commit ea1d491

Please sign in to comment.