[ONNX] Adding ONNX large model export support in exporter #33062
Changes from all commits
@@ -15,7 +15,8 @@
 from torch.nn.utils import rnn as rnn_utils
 from model_defs.lstm_flattening_result import LstmFlatteningResult
 from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
-from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfNoLapack, enableScriptTest
+from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion, enableScriptTest,
+                                 skipIfNoLapack)
 from test_pytorch_common import BATCH_SIZE
 from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
 import model_defs.word_language_model as word_language_model
@@ -124,6 +125,78 @@ def _run_test(m):
         _run_test(script_model)
         _run_test(model)

+    def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7,
+                                          example_outputs=None, do_constant_folding=True,
+                                          dynamic_axes=None, input_names=None, output_names=None,
+                                          ort_optim_on=True):
+        import os
+        import tempfile
+
+        model.eval()
+        with torch.no_grad():
+            if isinstance(input, torch.Tensor):
+                input = (input,)
+            # In-place operators will update input tensor data as well.
+            # Thus inputs are replicated before every forward call.
+            input_copy = copy.deepcopy(input)
+            output = model(*input_copy)
+            if isinstance(output, torch.Tensor):
+                output = (output,)
+
+            # Export the model to ONNX.
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model_file_name = os.path.join(tmpdirname, 'model.onnx')
+                input_copy = copy.deepcopy(input)
+                torch.onnx.export(model, input_copy, model_file_name,
+                                  opset_version=self.opset_version,
+                                  example_outputs=output,
+                                  verbose=False,
+                                  do_constant_folding=do_constant_folding,
+                                  keep_initializers_as_inputs=self.keep_initializers_as_inputs,
+                                  dynamic_axes=dynamic_axes,
+                                  input_names=input_names, output_names=output_names,
+                                  use_external_data_format=True)
+                # Compute the onnxruntime output prediction.
+                ort_sess_opt = onnxruntime.SessionOptions()
+                ort_sess_opt.graph_optimization_level = \
+                    onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED if ort_optim_on else \
+                    onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
+                ort_sess = onnxruntime.InferenceSession(model_file_name, sess_options=ort_sess_opt)
+                input_copy = copy.deepcopy(input)
+                ort_test_with_input(ort_sess, input_copy, output, rtol, atol)
+
+    @skipIfUnsupportedMinOpsetVersion(9)  # External data format was released with Opset 9.
+    def test_embedding_model_with_external_data(self):
+        class LargeModel(torch.nn.Module):
+            def __init__(self):
+                super(LargeModel, self).__init__()
+                dim = 15
+                n = 4 * 100
+                self.emb = torch.nn.Embedding(n, dim)
+                self.lin1 = torch.nn.Linear(dim, 1)
+                self.seq = torch.nn.Sequential(
+                    self.emb,
+                    self.lin1,
+                )
+
+            def forward(self, input):
+                return self.seq(input)
+
+        model = LargeModel()
+        x = torch.tensor([2], dtype=torch.long)
+        self.run_model_test_with_external_data(model, x)
+
+    @skipIfUnsupportedMinOpsetVersion(9)  # External data format was released with Opset 9.
+    def test_mobilenet_v2_with_external_data(self):
+        model = torchvision.models.mobilenet_v2(pretrained=True)
+        x = torch.randn(2, 3, 224, 224, requires_grad=True)
+        # We turn ONNX Runtime optimization off in this test because the
+        # external data format is not yet supported by the ORT optimizer.
+        # Once that support is added, we can set ort_optim_on=True (default).
+        self.run_model_test_with_external_data(model, x, rtol=1e-3, atol=1e-5,
+                                               ort_optim_on=False)
+
     # Export Torchvision models

     def test_alexnet(self):
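Taken together, these tests show the intended user-facing flow for the new flag. A minimal standalone sketch of that flow (the model choice, opset version, and file name are illustrative, and it assumes a PyTorch build that contains this change):

import torch
import torchvision

model = torchvision.models.mobilenet_v2(pretrained=True).eval()
dummy_input = torch.randn(1, 3, 224, 224)

# With use_external_data_format=True, parameter tensors above the size
# threshold are written as separate binary files next to the .onnx file,
# which keeps the protobuf itself below the 2 GB serialization limit.
torch.onnx.export(model, dummy_input, 'model.onnx',
                  opset_version=9,
                  use_external_data_format=True)

Loading the result requires keeping the side files next to model.onnx, since the protobuf only records their relative locations.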
@@ -21,7 +21,7 @@
 #include <sstream>
 #include <string>
 #include <vector>
-
+#include <regex>
 namespace torch {
 namespace jit {
@@ -110,6 +110,47 @@ void validateGraph(
   EliminateDeadCode(graph->block(), true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
 }

+std::string GetFileRootPath(const std::string& rootPath) {
+  std::string rootPath_ = rootPath;
+  // First, make the slashes consistent.
+  std::replace(rootPath_.begin(), rootPath_.end(), '\\', '/');
+  // Second, remove trailing slashes, if any.
+  std::regex trailer("/+$");
+  std::string root = std::regex_replace(rootPath_, trailer, std::string());
+  std::string folder = root.substr(0, root.find_last_of('/'));
+  if (folder == rootPath_) { // If no root folder is specified, select the cwd.
+    return std::string(".");
+  }
+  return folder;
+}
+
+std::string GetExternalFileName(const c10::optional<std::string> external_ref) {
+  auto tensorName = external_ref.value();
+  const std::string illegalChars = "\\/:?\"<>|";
+  for (size_t i = 0; i < tensorName.size(); i++) {
+    if (illegalChars.find(tensorName[i]) != std::string::npos) {
+      tensorName[i] = '_';
+    }
+  }
+  return tensorName;
+}
+
+void CloseFile(FILE* fp) {
+  fclose(fp);
+}
+
+void CreateExternalFile(const at::Tensor& tensor, const std::string& tensorName,
+                        const std::string& onnx_file_path) {
+  auto folder = GetFileRootPath(onnx_file_path);
+  std::string fullFilePath = folder + "/" + tensorName;
+  std::unique_ptr<FILE, decltype(&CloseFile)> fp(fopen(fullFilePath.c_str(), "wb"),
+                                                 &CloseFile);
+  if (fp == nullptr) {
+    throw std::runtime_error(std::string("ONNX export failed. Could not open file or directory: ") + fullFilePath);
+  }
+  fwrite(tensor.data_ptr(), tensor.element_size(), tensor.numel(), fp.get());
+} // fclose() is called here through CloseFile(), if FILE* is not a null pointer.
+
 class EncoderBase {
  public:
   EncoderBase(
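The three helpers above define the on-disk layout: each externalized tensor is written into the same folder as the ONNX file, under its tensor name with filesystem-illegal characters replaced by underscores. A hedged Python mirror of that naming scheme, for illustration only (these helper names are hypothetical, not part of the PR):

import os

ILLEGAL_CHARS = set('\\/:?"<>|')

def external_file_name(tensor_name):
    # Mirrors GetExternalFileName: illegal file-name characters become '_'.
    return ''.join('_' if c in ILLEGAL_CHARS else c for c in tensor_name)

def external_file_path(onnx_file_path, tensor_name):
    # Mirrors GetFileRootPath + CreateExternalFile: the side file lives in
    # the same folder as the .onnx file (or the cwd if no folder is given).
    folder = os.path.dirname(onnx_file_path.replace('\\', '/').rstrip('/')) or '.'
    return folder + '/' + external_file_name(tensor_name)

print(external_file_path('/tmp/export/model.onnx', 'encoder/layer.0:weight'))
# /tmp/export/encoder_layer.0_weight

Writing side files next to the model keeps the recorded locations relative, so the whole export stays relocatable as a unit.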
@@ -135,7 +176,9 @@ class EncoderBase {
       const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
           std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>(),
       bool keep_initializers_as_inputs = true,
-      bool add_node_names = true);
+      bool add_node_names = true,
+      bool use_external_data_format = false,
+      const std::string& onnx_file_path = std::string());

   void EncodeBlock(
       onnx::GraphProto* graph_proto,
@@ -145,12 +188,16 @@ class EncoderBase {
       const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
           std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>(),
       bool keep_initializers_as_inputs = true,
-      bool add_node_names = true);
+      bool add_node_names = true,
+      bool use_external_data_format = false,
+      const std::string& onnx_file_path = std::string());

   virtual void EncodeTensor(
       onnx::TensorProto* tensor_proto,
       const at::Tensor& tensor,
-      const c10::optional<std::string> external_ref = {}) = 0;
+      const c10::optional<std::string> external_ref = {},
+      const bool use_external_data_format = false,
+      const std::string& onnx_file_path = std::string()) = 0;

   virtual void EncodeIntermediateValueInfo(
       onnx::GraphProto* graph_proto,
@@ -174,6 +221,13 @@ class EncoderBase {
   onnx_torch::OperatorExportTypes operator_export_type_;
   bool strip_doc_;
   std::set<std::string> domains_;
+
+  // For large models, the parameters can be stored in separate binary files.
+  // This member sets a threshold on the number of elements in a parameter
+  // tensor; beyond it, the parameter is stored in a separate file (when the
+  // API argument use_external_data_format is set to True). The threshold
+  // exists so as not to create too many external files.
+  const size_t ParamSizeThresholdForExternalStorage = 1024;
 };

 onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
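As a worked example against the embedding test above: the Embedding weight holds 400 × 15 = 6000 elements, above the 1024-element threshold, so it goes to a side file, while the Linear weight (15 elements) and bias (1 element) stay inline in the protobuf. A quick sketch of that arithmetic:

import torch

emb = torch.nn.Embedding(4 * 100, 15)  # weight: 400 * 15 = 6000 elements
lin = torch.nn.Linear(15, 1)           # weight: 15 elements, bias: 1 element

THRESHOLD = 1024  # ParamSizeThresholdForExternalStorage
for name, p in [('emb.weight', emb.weight),
                ('lin.weight', lin.weight),
                ('lin.bias', lin.bias)]:
    print(name, p.numel(), 'external' if p.numel() > THRESHOLD else 'inline')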
@@ -264,9 +318,12 @@ void EncoderBase::EncodeGraph(
     const std::map<std::string, at::Tensor>& initializers,
     const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
     bool keep_initializers_as_inputs,
-    bool add_node_names) {
+    bool add_node_names,
+    bool use_external_data_format,
+    const std::string& onnx_file_path) {
   EncodeBlock(graph_proto, graph->block(), initializers, dynamic_axes,
-              keep_initializers_as_inputs, add_node_names);
+              keep_initializers_as_inputs, add_node_names, use_external_data_format,
+              onnx_file_path);
 }

 void EncoderBase::EncodeBlock(
@@ -275,7 +332,9 @@ void EncoderBase::EncodeBlock(
     const std::map<std::string, at::Tensor>& initializers,
     const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
     bool keep_initializers_as_inputs,
-    bool add_node_names) {
+    bool add_node_names,
+    bool use_external_data_format,
+    const std::string& onnx_file_path) {
   AT_ASSERT(graph_proto != nullptr);
   std::string block_name = "torch-jit-export";
   if (num_blocks_) {
@@ -401,7 +460,8 @@ void EncoderBase::EncodeBlock(
   for (auto& name_tensor_pair : initializers) {
     auto p = graph_proto->add_initializer();
     p->set_name(name_tensor_pair.first);
-    EncodeTensor(p, name_tensor_pair.second, name_tensor_pair.first);
+    EncodeTensor(p, name_tensor_pair.second, name_tensor_pair.first,
+                 use_external_data_format, onnx_file_path);
   }
 }
@@ -481,7 +541,9 @@ class GraphEncoder : public EncoderBase {
       bool strip_doc,
       bool keep_initializers_as_inputs,
       const std::map<std::string, int>& custom_opsets,
-      bool add_node_names);
+      bool add_node_names,
+      bool use_external_data_format,
+      const std::string& onnx_file_path);

   RawDataExportMap get_raw_data_export_map() {
     return raw_data_export_map_;
@@ -491,7 +553,9 @@ class GraphEncoder : public EncoderBase {
   void EncodeTensor(
       onnx::TensorProto* tensor_proto,
       const at::Tensor& tensor,
-      const c10::optional<std::string> external_ref = {}) override;
+      const c10::optional<std::string> external_ref = {},
+      const bool use_external_data_format = false,
+      const std::string& onnx_file_path = std::string()) override;

   RawDataExportMap raw_data_export_map_;
   bool defer_weight_export_;
@@ -507,7 +571,9 @@ GraphEncoder::GraphEncoder(
     bool strip_doc,
     bool keep_initializers_as_inputs,
     const std::map<std::string, int>& custom_opsets,
-    bool add_node_names)
+    bool add_node_names,
+    bool use_external_data_format,
+    const std::string& onnx_file_path)
     : EncoderBase(operator_export_type, strip_doc),
       defer_weight_export_(defer_weight_export) {
   if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) {
@@ -519,7 +585,8 @@ GraphEncoder::GraphEncoder(
   imp->set_version(onnx_opset_version);

   EncodeGraph(model_proto_.mutable_graph(), graph, initializers, dynamic_axes,
-              keep_initializers_as_inputs, add_node_names);
+              keep_initializers_as_inputs, add_node_names, use_external_data_format,
+              onnx_file_path);

   for (const std::string& domain : domains_) {
     auto* opset = model_proto_.add_opset_import();
@@ -544,7 +611,9 @@ GraphEncoder::GraphEncoder(
 void GraphEncoder::EncodeTensor(
     onnx::TensorProto* tensor_proto,
     const at::Tensor& tensor,
-    const c10::optional<std::string> external_ref) {
+    const c10::optional<std::string> external_ref,
+    const bool use_external_data_format,
+    const std::string& onnx_file_path) {
   for (auto d : tensor.sizes()) {
     tensor_proto->add_dims(d);
   }
@@ -559,6 +628,11 @@ void GraphEncoder::EncodeTensor(
   else {
     t = tensor.contiguous().cpu();
   }
+
+  // Either defer_weight_export_ is true and external_ref is present, or
+  // use_external_data_format is true; not both at once. Both may be false
+  // (the regular ONNX export path for normal-sized models).
+  AT_ASSERT(!((defer_weight_export_ && external_ref) && use_external_data_format));
   // Add a buffer to the raw_data_export_map for the caller to dump into an
   // external data store. If external_ref is not specified, we instead dump
   // the contiguous data into the protobuf itself
@@ -571,8 +645,22 @@ void GraphEncoder::EncodeTensor(
     tensor_proto->set_raw_data("__EXTERNAL");
   } else {
     AT_ASSERT(t.is_contiguous());
-    tensor_proto->set_raw_data(std::string(
+    size_t tensorSize = static_cast<size_t>(std::accumulate(std::begin(tensor.sizes()),
+        std::end(tensor.sizes()), static_cast<int64_t>(1), std::multiplies<int64_t>()));
+    if (use_external_data_format && tensorSize > ParamSizeThresholdForExternalStorage) {
+      AT_ASSERT(!onnx_file_path.empty());
+      AT_ASSERT((external_ref != c10::nullopt) && (external_ref.value() == tensor_proto->name()));
+      auto tensorName = GetExternalFileName(external_ref);
+      CreateExternalFile(t, tensorName, onnx_file_path);
+      onnx::StringStringEntryProto* location = tensor_proto->mutable_external_data()->Add();
+      location->set_key("location");
+      location->set_value(tensorName);
+      tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL);
+    }
+    else {
+      tensor_proto->set_raw_data(std::string(
         static_cast<char*>(t.data_ptr()), t.element_size() * t.numel()));
+    }
   }
 }
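After export, each externalized initializer carries a data_location of EXTERNAL plus a location key/value entry naming its side file. A hedged sketch for inspecting that metadata with the onnx Python package (not part of this PR; it assumes an onnx version whose load() accepts load_external_data):

import onnx

def externalized_initializers(model_path):
    # Load only the graph structure; leave tensor payloads in their side files.
    model = onnx.load(model_path, load_external_data=False)
    result = []
    for init in model.graph.initializer:
        if init.data_location == onnx.TensorProto.EXTERNAL:
            entries = {e.key: e.value for e in init.external_data}
            result.append((init.name, entries.get('location')))
    return result

ONNX Runtime resolves these location entries relative to the model file, which is why the tests above can load the exported model directly from the temporary directory.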
@@ -773,7 +861,9 @@ std::string pretty_print_onnx(
       true,
       keep_initializers_as_inputs,
       custom_opsets,
-      add_node_names);
+      add_node_names,
+      false,
+      std::string());
   if (google_printer) {
     return graph_encoder.get_model_proto().DebugString();
   }
@@ -795,7 +885,9 @@ std::tuple<std::string, RawDataExportMap> export_onnx(
     bool strip_doc_string,
     bool keep_initializers_as_inputs,
     const std::map<std::string, int>& custom_opsets,
-    bool add_node_names) {
+    bool add_node_names,
+    bool use_external_data_format,
+    const std::string& onnx_file_path) {
   auto graph_encoder = GraphEncoder(
       graph,
       onnx_opset_version,
@@ -806,7 +898,9 @@ std::tuple<std::string, RawDataExportMap> export_onnx(
     strip_doc_string,
     keep_initializers_as_inputs,
     custom_opsets,
-    add_node_names);
+    add_node_names,
+    use_external_data_format,
+    onnx_file_path);
   return std::make_tuple(
       graph_encoder.get_model_proto().SerializeAsString(),
       graph_encoder.get_raw_data_export_map());