Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move TF-TensorRT util functions #35233

Merged
merged 6 commits into from Dec 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
89 changes: 5 additions & 84 deletions tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
Expand Up @@ -200,18 +200,6 @@ int64 TFAttrs::get<int64>(const string& key) const {
return this->at(key)->i();
}

// Converts a TensorShape-like object into TensorRT dimensions. When
// `ignore_first_dim` is true the leading (batch) dimension is dropped.
template <typename TensorShapeType>
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
                                           bool ignore_first_dim) {
  nvinfer1::Dims trt_dims;
  const int skip = ignore_first_dim ? 1 : 0;
  trt_dims.nbDims = shape.dims() - skip;
  for (int d = 0; d < trt_dims.nbDims; ++d) {
    trt_dims.d[d] = shape.dim_size(d + skip);
  }
  return trt_dims;
}

template <typename Container>
Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out,
bool ignore_first_dim = false) {
Expand Down Expand Up @@ -314,66 +302,6 @@ Status ValidateTensorProperties(const string& producer_node_type,
return Status::OK();
}

// Returns the symbolic enum name of a TensorRT dimension type, or
// "<int>=unknown" for values outside the known enumerators.
string DebugString(const nvinfer1::DimensionType type) {
  if (type == nvinfer1::DimensionType::kSPATIAL) return "kSPATIAL";
  if (type == nvinfer1::DimensionType::kCHANNEL) return "kCHANNEL";
  if (type == nvinfer1::DimensionType::kINDEX) return "kINDEX";
  if (type == nvinfer1::DimensionType::kSEQUENCE) return "kSEQUENCE";
  return StrCat(static_cast<int>(type), "=unknown");
}

// Returns the symbolic enum name of a TensorRT tensor data type, or a fixed
// error string for unrecognized values.
string DebugString(const nvinfer1::DataType trt_dtype) {
  if (trt_dtype == nvinfer1::DataType::kFLOAT) return "kFLOAT";
  if (trt_dtype == nvinfer1::DataType::kHALF) return "kHALF";
  if (trt_dtype == nvinfer1::DataType::kINT8) return "kINT8";
  if (trt_dtype == nvinfer1::DataType::kINT32) return "kINT32";
  return "Invalid TRT data type";
}

// Renders `dims` as "nvinfer1::Dims(nbDims=N, d=d0,d1,...,)". At VLOG level 2
// each dimension is additionally annotated with its DimensionType.
string DebugString(const nvinfer1::Dims& dims) {
  const bool verbose = VLOG_IS_ON(2);
  string result = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
  for (int i = 0; i < dims.nbDims; ++i) {
    StrAppend(&result, dims.d[i]);
    if (verbose) {
      StrAppend(&result, "[", DebugString(dims.type[i]), "],");
    } else {
      StrAppend(&result, ",");
    }
  }
  StrAppend(&result, ")");
  return result;
}

// Renders the first `len` entries of `permutation`. The caller supplies the
// length because nvinfer1::Permutation does not carry one.
string DebugString(const nvinfer1::Permutation& permutation, int len) {
  string result = "nvinfer1::Permutation(";
  for (int idx = 0; idx < len; ++idx) {
    StrAppend(&result, permutation.order[idx], ",");
  }
  StrAppend(&result, ")");
  return result;
}

// Renders an ITensor's address, name, dtype, and dimensions for logging.
string DebugString(const nvinfer1::ITensor& tensor) {
  string result =
      StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor));
  StrAppend(&result, ", name=", tensor.getName());
  StrAppend(&result, ", dtype=", DebugString(tensor.getType()));
  StrAppend(&result, ", dims=", DebugString(tensor.getDimensions()), ")");
  return result;
}

Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r,
const bool check_feasibility,
Expand Down Expand Up @@ -581,14 +509,6 @@ inline nvinfer1::Dims GetTrtDimsForTensor(const Tensor& tensor) {
return dims;
}

// Returns true iff `dims` contains no unknown (negative) entries, i.e. the
// shape is fully static. A negative nbDims also counts as non-static.
inline bool HasStaticShape(const nvinfer1::Dims& dims) {
  if (dims.nbDims < 0) return false;
  bool all_known = true;
  for (int i = 0; i < dims.nbDims && all_known; ++i) {
    all_known = dims.d[i] >= 0;
  }
  return all_known;
}

int64_t Prod(const nvinfer1::Dims& dims) {
int64_t count = 1;
for (int d = 0; d < dims.nbDims; ++d) {
Expand Down Expand Up @@ -732,9 +652,10 @@ size_t TRT_ShapedWeights::size_bytes() const {
}

string TRT_ShapedWeights::DebugString() const {
return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_),
", type=", convert::DebugString(type_),
", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
return StrCat(
"TRT_ShapedWeights(shape=", tensorflow::tensorrt::DebugString(shape_),
", type=", tensorflow::tensorrt::DebugString(type_),
", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
}

// A fake ITensor implementation used to check whether the TF-TRT converter can
Expand Down Expand Up @@ -858,7 +779,7 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
string TRT_TensorOrWeights::DebugString() const {
string output = "TRT_TensorOrWeights(type=";
if (is_tensor()) {
StrAppend(&output, "tensor=", convert::DebugString(*tensor()),
StrAppend(&output, "tensor=", tensorflow::tensorrt::DebugString(*tensor()),
", batch_size=", batch_size_);
} else {
StrAppend(&output, "weights=", weights_.DebugString());
Expand Down
13 changes: 0 additions & 13 deletions tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
Expand Up @@ -42,14 +42,6 @@ namespace tensorrt {
namespace convert {
using ::stream_executor::port::StatusOr;

// Preprocessor check: evaluates to true iff the TensorRT headers being
// compiled against are version >= major.minor.patch.build (lexicographic
// comparison of the four NV_TENSORRT_* version components).
#define IS_TRT_VERSION_GE(major, minor, patch, build) \
((NV_TENSORRT_MAJOR > major) || \
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
NV_TENSORRT_PATCH > patch) || \
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))

struct EngineConnection {
// Constructs a non-control edge.
EngineConnection(const string& outside, int out_id, int out_port,
Expand Down Expand Up @@ -164,11 +156,6 @@ class OutputEdgeValidator {
bool operator()(const Edge* out_edge) const;
};

string DebugString(const nvinfer1::DimensionType type);
string DebugString(const nvinfer1::DataType trt_dtype);
string DebugString(const nvinfer1::Dims& dims);
string DebugString(const nvinfer1::Permutation& permutation, int len);
string DebugString(const nvinfer1::ITensor& tensor);
int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims);
int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);

Expand Down
68 changes: 68 additions & 0 deletions tensorflow/compiler/tf2tensorrt/convert/utils.cc
Expand Up @@ -17,6 +17,8 @@ limitations under the License.

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace tensorrt {
Expand Down Expand Up @@ -51,5 +53,71 @@ Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) {
return Status::OK();
}

using absl::StrAppend;
using absl::StrCat;

#if GOOGLE_CUDA && GOOGLE_TENSORRT
// Returns the symbolic enum name of a TensorRT dimension type (e.g.
// "kSPATIAL"), or "<int>=unknown" for values outside the known enumerators.
string DebugString(const nvinfer1::DimensionType type) {
  switch (type) {
    case nvinfer1::DimensionType::kSPATIAL:
      return "kSPATIAL";
    case nvinfer1::DimensionType::kCHANNEL:
      return "kCHANNEL";
    case nvinfer1::DimensionType::kINDEX:
      return "kINDEX";
    case nvinfer1::DimensionType::kSEQUENCE:
      return "kSEQUENCE";
    default:
      return StrCat(static_cast<int>(type), "=unknown");
  }
}

// Renders `dims` as "nvinfer1::Dims(nbDims=N, d=d0,d1,...,)". When VLOG level
// 2 is enabled, each dimension is annotated with its DimensionType via the
// DebugString(DimensionType) overload.
string DebugString(const nvinfer1::Dims& dims) {
  string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
  for (int i = 0; i < dims.nbDims; ++i) {
    StrAppend(&out, dims.d[i]);
    if (VLOG_IS_ON(2)) {
      StrAppend(&out, "[", DebugString(dims.type[i]), "],");
    } else {
      StrAppend(&out, ",");
    }
  }
  StrAppend(&out, ")");
  return out;
}

// Returns the symbolic enum name of a TensorRT tensor data type, or a fixed
// error string for unrecognized values.
string DebugString(const nvinfer1::DataType trt_dtype) {
  switch (trt_dtype) {
    case nvinfer1::DataType::kFLOAT:
      return "kFLOAT";
    case nvinfer1::DataType::kHALF:
      return "kHALF";
    case nvinfer1::DataType::kINT8:
      return "kINT8";
    case nvinfer1::DataType::kINT32:
      return "kINT32";
    default:
      return "Invalid TRT data type";
  }
}

// Renders the first `len` entries of `permutation`. The caller must supply the
// length because nvinfer1::Permutation does not carry one.
string DebugString(const nvinfer1::Permutation& permutation, int len) {
  string out = "nvinfer1::Permutation(";
  for (int i = 0; i < len; ++i) {
    StrAppend(&out, permutation.order[i], ",");
  }
  StrAppend(&out, ")");
  return out;
}

// Renders an ITensor's address, name, dtype, and dimensions for logging,
// reusing the DataType and Dims DebugString overloads above.
string DebugString(const nvinfer1::ITensor& tensor) {
  return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
                ", name=", tensor.getName(),
                ", dtype=", DebugString(tensor.getType()),
                ", dims=", DebugString(tensor.getDimensions()), ")");
}

#endif

} // namespace tensorrt
} // namespace tensorflow
51 changes: 51 additions & 0 deletions tensorflow/compiler/tf2tensorrt/convert/utils.h
Expand Up @@ -17,9 +17,15 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_

#include <memory>
#include <vector>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {

Expand All @@ -45,6 +51,51 @@ Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name);

Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode);

// Hash functor for std::vector<TensorShape>, needed because shape lists are
// used as keys in the engine cache.
struct VectorTensorShapeHasher {
  // Serializes the shape list to its string form and hashes that string.
  std::size_t operator()(const std::vector<TensorShape>& shapes) const {
    const std::string serialized = TensorShapeUtils::ShapeListString(shapes);
    return std::hash<std::string>()(serialized);
  }
};

#if GOOGLE_CUDA && GOOGLE_TENSORRT

// Preprocessor check: evaluates to true iff the TensorRT headers being
// compiled against are version >= major.minor.patch.build (lexicographic
// comparison of the four NV_TENSORRT_* version components).
#define IS_TRT_VERSION_GE(major, minor, patch, build) \
((NV_TENSORRT_MAJOR > major) || \
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
NV_TENSORRT_PATCH > patch) || \
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))

string DebugString(const nvinfer1::DimensionType type);
string DebugString(const nvinfer1::Dims& dims);
string DebugString(const nvinfer1::DataType trt_dtype);
string DebugString(const nvinfer1::Permutation& permutation, int len);
string DebugString(const nvinfer1::ITensor& tensor);

// Returns true iff `dims` contains no unknown (negative) entries, i.e. the
// shape is fully static. A negative nbDims also counts as non-static.
inline bool HasStaticShape(const nvinfer1::Dims& dims) {
  if (dims.nbDims < 0) return false;
  for (int d = 0; d < dims.nbDims; ++d) {
    if (dims.d[d] < 0) return false;
  }
  return true;
}

// Converts a TensorShape-like object into TensorRT dimensions. When
// `ignore_first_dim` is true the leading (batch) dimension is dropped.
// NOTE(review): there is no check that shape.dims() - offset fits within
// nvinfer1::Dims::MAX_DIMS; presumably callers validate rank beforehand —
// confirm.
template <typename TensorShapeType>
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
                                           bool ignore_first_dim) {
  nvinfer1::Dims trt_dims;
  const int offset = (ignore_first_dim ? 1 : 0);
  for (int i = offset; i < shape.dims(); i++) {
    trt_dims.d[i - offset] = shape.dim_size(i);
  }
  trt_dims.nbDims = shape.dims() - offset;
  return trt_dims;
}
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT

} // namespace tensorrt
} // namespace tensorflow

Expand Down
7 changes: 0 additions & 7 deletions tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
Expand Up @@ -114,13 +114,6 @@ class LRUCache {
}
};

// Define a hash function for vector<TensorShape> because it is used as the key
// for the engine cache.
struct VectorTensorShapeHasher {
  // Hashes the shape list by serializing it to its string form and hashing
  // that string with std::hash.
  std::size_t operator()(const std::vector<TensorShape>& key) const {
    return std::hash<std::string>()(TensorShapeUtils::ShapeListString(key));
  }
};

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
Expand Down