[ONNX] Optimize export_onnx api to reduce string and model proto exchange (#44332)

Summary:
Optimize the export_onnx API to reduce string and model proto exchange in export.cpp: export_onnx now returns the ONNX ModelProto directly (as a std::shared_ptr) instead of a serialized string that callers immediately re-parse, and serialization is deferred to a new serialize_model_proto_to_string helper.

Pull Request resolved: #44332

Reviewed By: bwasti, eellison

Differential Revision: D23880129

Pulled By: bzinodev

fbshipit-source-id: 1d216d8f710f356cbba2334fb21ea15a89dd16fa
shubhambhokare1 authored and facebook-github-bot committed Sep 27, 2020
1 parent 4005afe commit 5b839bc
Showing 4 changed files with 47 additions and 32 deletions.
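
At a glance, the patch removes a serialize/parse round trip: export_onnx used to hand back the model as a serialized protobuf string, which callers such as ConvertGraphToONNXProto immediately parsed back into a ModelProto; it now returns a std::shared_ptr to the proto itself, and a new serialize_model_proto_to_string helper produces the string form only where bytes are actually needed. A condensed caller-side sketch (argument lists elided; not verbatim from the patch):

    // Before: the model is serialized inside export_onnx and re-parsed here.
    std::string model_str;
    RawDataExportMap export_map;
    std::tie(model_str, export_map) = export_onnx(graph, /* ... */);
    onnx::ModelProto model_proto;
    model_proto.ParseFromString(model_str);  // parses bytes that were just serialized

    // After: the proto is shared directly; serialization happens on demand.
    std::shared_ptr<onnx::ModelProto> model_proto_ptr;
    std::tie(model_proto_ptr, export_map) = export_onnx(graph, /* ... */);
    std::string bytes = serialize_model_proto_to_string(model_proto_ptr);  // only when needed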
torch/csrc/jit/passes/onnx/shape_type_inference.cpp (20 changes: 9 additions & 11 deletions)
@@ -233,11 +233,10 @@ bool IsGraphValidForInference(std::shared_ptr<Graph> graph) {
 
 void ConvertGraphToONNXProto(
     std::shared_ptr<Graph> graph,
-    onnx::ModelProto& model_proto,
+    std::shared_ptr<onnx::ModelProto>& model_proto,
     int opset_version) {
-  std::string model_str;
   RawDataExportMap export_map;
-  std::tie(model_str, export_map) = export_onnx(
+  std::tie(model_proto, export_map) = export_onnx(
       graph,
       {},
       opset_version,
@@ -250,9 +249,8 @@ void ConvertGraphToONNXProto(
       true,
       false,
       std::string());
-  model_proto.ParseFromString(model_str);
-  for (int i = 0; i < model_proto.graph().output_size(); ++i) {
-    model_proto.mutable_graph()->mutable_output(i)->clear_type();
+  for (int i = 0; i < model_proto->graph().output_size(); ++i) {
+    model_proto->mutable_graph()->mutable_output(i)->clear_type();
   }
 }

@@ -330,15 +328,15 @@ void ONNXShapeTypeInference(Node* n, int opset_version) {
   // TODO: Some ops have conversion happen at Peephole pass.
   // The conversion here is incomplete for these ops.
   // e.g: ListConstruct, ListUnpack, etc.
-  onnx::ModelProto model_proto;
+  std::shared_ptr<onnx::ModelProto> model_proto;
   ConvertGraphToONNXProto(n_graph, model_proto, opset_version);
-  GRAPH_DEBUG("ONNX graph to run shape inference: ", prettyPrint(model_proto));
+  GRAPH_DEBUG("ONNX graph to run shape inference: ", prettyPrint(*model_proto));
 
   // infer shape
-  onnx::shape_inference::InferShapes(model_proto);
-  GRAPH_DEBUG("ONNX graph after shape inference: ", prettyPrint(model_proto));
+  onnx::shape_inference::InferShapes(*model_proto);
+  GRAPH_DEBUG("ONNX graph after shape inference: ", prettyPrint(*model_proto));
 
-  UpdateOutputTypeByONNXProto(n, clone_node, model_proto);
+  UpdateOutputTypeByONNXProto(n, clone_node, *model_proto);
   GRAPH_DEBUG(
       "Torch graph after shape inference:", n->owningGraph()->toString());
 }
torch/csrc/jit/python/python_ir.cpp (4 changes: 3 additions & 1 deletion)
@@ -236,8 +236,9 @@ void initPythonIRBindings(PyObject* module_) {
             bool use_external_data_format,
             const std::string& onnx_file_path) {
           std::string graph;
+          std::shared_ptr<::ONNX_NAMESPACE::ModelProto> model_proto;
           RawDataExportMap export_map;
-          std::tie(graph, export_map) = export_onnx(
+          std::tie(model_proto, export_map) = export_onnx(
               g,
               initializers,
               onnx_opset_version,
@@ -261,6 +262,7 @@
             python_serialized_export_map[kv.first] =
                 py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
           }
+          graph = serialize_model_proto_to_string(model_proto);
           return std::make_tuple(
               py::bytes(graph), python_serialized_export_map);
         },
torch/csrc/jit/serialization/export.cpp (14 changes: 10 additions & 4 deletions)
@@ -854,7 +854,8 @@ std::string pretty_print_onnx(
 // conform to the ONNX op specification. Thus, the output will not
 // be interpretable by a ONNX-compatible framework. However, PyTorch or
 // libtorch will be able to import the IR and play it back.
-std::tuple<std::string, RawDataExportMap> export_onnx(
+std::tuple<std::shared_ptr<::ONNX_NAMESPACE::ModelProto>, RawDataExportMap>
+export_onnx(
     const std::shared_ptr<Graph>& graph,
     const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
@@ -888,9 +889,14 @@ std::tuple<std::string, RawDataExportMap> export_onnx(
       "Exporting model exceed maximum protobuf size of 2GB. "
       "Please call torch.onnx.export with use_external_data_format=True.");
   GRAPH_UPDATE("onnx proto:", prettyPrint(graph_encoder.get_model_proto()));
-  return std::make_tuple(
-      graph_encoder.get_model_proto().SerializeAsString(),
-      graph_encoder.get_raw_data_export_map());
+  std::shared_ptr<onnx::ModelProto> model_proto =
+      std::make_shared<onnx::ModelProto>(graph_encoder.get_model_proto());
+  return std::make_tuple(model_proto, graph_encoder.get_raw_data_export_map());
 }
 
+std::string serialize_model_proto_to_string(
+    const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto) {
+  return model_proto->SerializeAsString();
+}
+
 void check_onnx_proto(const std::string& proto_string) {
torch/csrc/jit/serialization/export.h (41 changes: 25 additions & 16 deletions)
@@ -8,6 +8,10 @@
 
 #include <ostream>
 
+namespace ONNX_NAMESPACE {
+class ModelProto;
+}
+
 namespace torch {
 namespace jit {
 
@@ -21,22 +25,27 @@ namespace jit {
 // file contents being the raw tensor data.
 using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
 
-TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
-    const std::shared_ptr<Graph>& graph,
-    const std::map<std::string, at::Tensor>& initializers,
-    int64_t onnx_opset_version,
-    const std::unordered_map<
-        std::string,
-        std::unordered_map<int64_t, std::string>>& dynamic_axes,
-    bool defer_weight_export = false,
-    ::torch::onnx::OperatorExportTypes operator_export_type =
-        ::torch::onnx::OperatorExportTypes::ONNX,
-    bool strip_doc_string = true,
-    bool keep_initializers_as_inputs = true,
-    const std::map<std::string, int>& custom_opsets = {},
-    bool add_node_names = true,
-    bool use_external_data_format = false,
-    const std::string& onnx_file_path = std::string());
+TORCH_API std::
+    tuple<std::shared_ptr<::ONNX_NAMESPACE::ModelProto>, RawDataExportMap>
+    export_onnx(
+        const std::shared_ptr<Graph>& graph,
+        const std::map<std::string, at::Tensor>& initializers,
+        int64_t onnx_opset_version,
+        const std::unordered_map<
+            std::string,
+            std::unordered_map<int64_t, std::string>>& dynamic_axes,
+        bool defer_weight_export = false,
+        ::torch::onnx::OperatorExportTypes operator_export_type =
+            ::torch::onnx::OperatorExportTypes::ONNX,
+        bool strip_doc_string = true,
+        bool keep_initializers_as_inputs = true,
+        const std::map<std::string, int>& custom_opsets = {},
+        bool add_node_names = true,
+        bool use_external_data_format = false,
+        const std::string& onnx_file_path = std::string());
+
+TORCH_API std::string serialize_model_proto_to_string(
+    const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto);
 
 TORCH_API void check_onnx_proto(const std::string& proto_string);
 
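For reference, a minimal hypothetical caller against the new declarations, leaning on the defaulted parameters (the graph variable and opset number are illustrative, not from the patch):

    std::shared_ptr<::ONNX_NAMESPACE::ModelProto> model_proto;
    RawDataExportMap raw_data;
    std::tie(model_proto, raw_data) = export_onnx(
        graph,                      // std::shared_ptr<Graph> prepared by the exporter
        {},                         // no extra initializers
        /*onnx_opset_version=*/12,  // illustrative opset
        {});                        // no dynamic axes
    // Serialize exactly once, at the point the bytes are required.
    check_onnx_proto(serialize_model_proto_to_string(model_proto));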
