Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from torch.nn.utils import rnn as rnn_utils
from model_defs.lstm_flattening_result import LstmFlatteningResult
from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfNoLapack, enableScriptTest
from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion, enableScriptTest,
skipIfNoLapack)
from test_pytorch_common import BATCH_SIZE
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
import model_defs.word_language_model as word_language_model
Expand Down Expand Up @@ -124,6 +125,78 @@ def _run_test(m):
_run_test(script_model)
_run_test(model)

def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7,
                                      example_outputs=None, do_constant_folding=True,
                                      dynamic_axes=None, input_names=None, output_names=None,
                                      ort_optim_on=True):
    """Export ``model`` with ONNX external data format and verify it in ONNX Runtime.

    The model is exported into a temporary directory (so the externally
    stored parameter files live next to the .onnx file and are cleaned up
    afterwards), loaded into an onnxruntime InferenceSession, and the ORT
    outputs are compared against the eager PyTorch outputs.

    Args:
        model: torch.nn.Module to export.
        input: a tensor, or tuple of tensors, fed to the model.
        rtol, atol: tolerances passed through to ort_test_with_input.
        ort_optim_on: when False, ONNX Runtime graph optimizations are
            disabled (ORT's optimizer may not support external data format).
    """
    import os
    import tempfile

    model.eval()
    with torch.no_grad():
        if isinstance(input, torch.Tensor):
            input = (input,)
        # In-place operators will update input tensor data as well.
        # Thus inputs are replicated before every forward call.
        input_copy = copy.deepcopy(input)
        output = model(*input_copy)
        if isinstance(output, torch.Tensor):
            output = (output,)

        # Export the model to ONNX, storing large parameters externally.
        with tempfile.TemporaryDirectory() as tmpdirname:
            model_file_name = os.path.join(tmpdirname, 'model.onnx')
            input_copy = copy.deepcopy(input)
            torch.onnx.export(model, input_copy, model_file_name,
                              opset_version=self.opset_version,
                              example_outputs=output,
                              verbose=False,
                              do_constant_folding=do_constant_folding,
                              keep_initializers_as_inputs=self.keep_initializers_as_inputs,
                              dynamic_axes=dynamic_axes,
                              input_names=input_names, output_names=output_names,
                              use_external_data_format=True)
            # Compute the onnxruntime output prediction and compare.
            ort_sess_opt = onnxruntime.SessionOptions()
            ort_sess_opt.graph_optimization_level = \
                onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED if ort_optim_on else \
                onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
            ort_sess = onnxruntime.InferenceSession(model_file_name, sess_options=ort_sess_opt)
            input_copy = copy.deepcopy(input)
            ort_test_with_input(ort_sess, input_copy, output, rtol, atol)


@skipIfUnsupportedMinOpsetVersion(9)  # Because external data format was released with Opset 9.
def test_embedding_model_with_external_data(self):
    """Round-trips a small embedding+linear model through external-data export."""
    class LargeModel(torch.nn.Module):
        def __init__(self):
            super(LargeModel, self).__init__()
            dim = 15
            n = 4 * 100
            self.emb = torch.nn.Embedding(n, dim)
            self.lin1 = torch.nn.Linear(dim, 1)
            self.seq = torch.nn.Sequential(
                self.emb,
                self.lin1,
            )

        def forward(self, input):
            return self.seq(input)

    model = LargeModel()
    x = torch.tensor([2], dtype=torch.long)
    self.run_model_test_with_external_data(model, x)

@skipIfUnsupportedMinOpsetVersion(9)  # Because external data format was released with Opset 9.
def test_mobilenet_v2_with_external_data(self):
    """Exports pretrained mobilenet_v2 with external data format and checks it in ORT."""
    model = torchvision.models.mobilenet_v2(pretrained=True)
    x = torch.randn(2, 3, 224, 224, requires_grad=True)
    # We turn ONNX Runtime optimization off in this test, because external
    # data format is not supported in the ORT optimizer.
    # Once that support is added, we can set ort_optim_on=True (default).
    self.run_model_test_with_external_data(model, x, rtol=1e-3, atol=1e-5,
                                           ort_optim_on=False)

# Export Torchvision models

def test_alexnet(self):
Expand Down
128 changes: 111 additions & 17 deletions torch/csrc/jit/export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include <sstream>
#include <string>
#include <vector>

#include <regex>
namespace torch {
namespace jit {

Expand Down Expand Up @@ -110,6 +110,47 @@ void validateGraph(
EliminateDeadCode(graph->block(), true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}

// Returns the directory component of the given ONNX model file path, with
// backslashes normalized to forward slashes and trailing slashes stripped.
// If the path has no directory component, "." (the cwd) is returned.
// Note: plain string scanning replaces the previous std::regex-based
// trailing-slash strip; std::regex is far heavier than needed here.
std::string GetFileRootPath(const std::string& rootPath) {
  // First, making slash consistent.
  std::string normalized = rootPath;
  std::replace(normalized.begin(), normalized.end(), '\\', '/');
  // Second, remove trailing slashes, if any.
  // npos + 1 == 0 correctly yields an empty root for all-slash input.
  const std::string root =
      normalized.substr(0, normalized.find_last_not_of('/') + 1);
  // substr(0, npos) returns the whole string when no '/' is present.
  const std::string folder = root.substr(0, root.find_last_of('/'));
  if (folder == normalized) { // If no root folder specified, select cwd.
    return std::string(".");
  }
  return folder;
}

// Derives the on-disk file name for a tensor's external data from its
// (required) name, replacing characters that are illegal in file names
// with '_'. external_ref must hold a value; .value() throws otherwise.
std::string GetExternalFileName(const c10::optional<std::string> external_ref) {
  auto tensorName = external_ref.value();
  const std::string illegalChars = "\\/:?\"<>|";
  // Sanitize in place; a range-for avoids the signed/unsigned comparison
  // of the previous int-indexed loop.
  for (char& c : tensorName) {
    if (illegalChars.find(c) != std::string::npos) {
      c = '_';
    }
  }
  return tensorName;
}

// Custom deleter for FILE* handles owned by std::unique_ptr (see
// CreateExternalFile); unique_ptr only invokes it for non-null pointers.
void CloseFile(FILE* fp) {
  fclose(fp);
}

// Writes the raw contents of `tensor` to a file named `tensorName` located
// in the same directory as the ONNX model file. Used when exporting with
// use_external_data_format=true.
// Throws std::runtime_error if the file cannot be opened or fully written.
void CreateExternalFile(const at::Tensor& tensor, const std::string& tensorName,
    const std::string& onnx_file_path) {
  // onnx_file_path may name a file or a folder; GetFileRootPath yields the
  // containing directory in either case.
  auto folder = GetFileRootPath(onnx_file_path);
  std::string fullFilePath = folder + "/" + tensorName;
  std::unique_ptr<FILE, decltype(&CloseFile)> fp(fopen(fullFilePath.c_str(), "wb"),
      &CloseFile);
  if (fp == nullptr) {
    throw std::runtime_error(std::string("ONNX export failed. Could not open file or directory: ") + fullFilePath);
  }
  // fwrite returns the number of items written; report short writes instead
  // of silently producing a truncated external data file.
  const size_t itemsWritten =
      fwrite(tensor.data_ptr(), tensor.element_size(), tensor.numel(), fp.get());
  if (itemsWritten != static_cast<size_t>(tensor.numel())) {
    throw std::runtime_error(std::string("ONNX export failed. Short write to file: ") + fullFilePath);
  }
} // fclose() called here through CloseFile(), if FILE* is not a null pointer.

class EncoderBase {
public:
EncoderBase(
Expand All @@ -135,7 +176,9 @@ class EncoderBase {
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>(),
bool keep_initializers_as_inputs = true,
bool add_node_names = true);
bool add_node_names = true,
bool use_external_data_format = false,
const std::string& onnx_file_path = std::string());

void EncodeBlock(
onnx::GraphProto* graph_proto,
Expand All @@ -145,12 +188,16 @@ class EncoderBase {
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>(),
bool keep_initializers_as_inputs = true,
bool add_node_names = true);
bool add_node_names = true,
bool use_external_data_format = false,
const std::string& onnx_file_path = std::string());

virtual void EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref = {}) = 0;
const c10::optional<std::string> external_ref = {},
const bool use_external_data_format = false,
const std::string& onnx_file_path = std::string()) = 0;

virtual void EncodeIntermediateValueInfo(
onnx::GraphProto* graph_proto,
Expand All @@ -174,6 +221,13 @@ class EncoderBase {
onnx_torch::OperatorExportTypes operator_export_type_;
bool strip_doc_;
std::set<std::string> domains_;

// For large models, the parameters can be stored in separate binary files.
// This parameter sets a threshold on the number of elements in the parameter
// tensor, beyond which the parameter is stored in a separate file (if API
// argument use_external_data_format is set to True). This threshold is in place
// so as not to create too many external files.
const size_t ParamSizeThresholdForExternalStorage = 1024;
};

onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
Expand Down Expand Up @@ -264,9 +318,12 @@ void EncoderBase::EncodeGraph(
const std::map<std::string, at::Tensor>& initializers,
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
bool keep_initializers_as_inputs,
bool add_node_names) {
bool add_node_names,
bool use_external_data_format,
const std::string& onnx_file_path) {
EncodeBlock(graph_proto, graph->block(), initializers, dynamic_axes,
keep_initializers_as_inputs, add_node_names);
keep_initializers_as_inputs, add_node_names, use_external_data_format,
onnx_file_path);
}

void EncoderBase::EncodeBlock(
Expand All @@ -275,7 +332,9 @@ void EncoderBase::EncodeBlock(
const std::map<std::string, at::Tensor>& initializers,
const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
bool keep_initializers_as_inputs,
bool add_node_names) {
bool add_node_names,
bool use_external_data_format,
const std::string& onnx_file_path) {
AT_ASSERT(graph_proto != nullptr);
std::string block_name = "torch-jit-export";
if (num_blocks_) {
Expand Down Expand Up @@ -401,7 +460,8 @@ void EncoderBase::EncodeBlock(
for (auto& name_tensor_pair : initializers) {
auto p = graph_proto->add_initializer();
p->set_name(name_tensor_pair.first);
EncodeTensor(p, name_tensor_pair.second, name_tensor_pair.first);
EncodeTensor(p, name_tensor_pair.second, name_tensor_pair.first,
use_external_data_format, onnx_file_path);
}
}

Expand Down Expand Up @@ -481,7 +541,9 @@ class GraphEncoder : public EncoderBase {
bool strip_doc,
bool keep_initializers_as_inputs,
const std::map<std::string, int>& custom_opsets,
bool add_node_names);
bool add_node_names,
bool use_external_data_format,
const std::string& onnx_file_path);

RawDataExportMap get_raw_data_export_map() {
return raw_data_export_map_;
Expand All @@ -491,7 +553,9 @@ class GraphEncoder : public EncoderBase {
void EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref = {}) override;
const c10::optional<std::string> external_ref = {},
const bool use_external_data_format = false,
const std::string& onnx_file_path = std::string()) override;

RawDataExportMap raw_data_export_map_;
bool defer_weight_export_;
Expand All @@ -507,7 +571,9 @@ GraphEncoder::GraphEncoder(
bool strip_doc,
bool keep_initializers_as_inputs,
const std::map<std::string, int>& custom_opsets,
bool add_node_names)
bool add_node_names,
bool use_external_data_format,
const std::string& onnx_file_path)
: EncoderBase(operator_export_type, strip_doc),
defer_weight_export_(defer_weight_export) {
if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) {
Expand All @@ -519,7 +585,8 @@ GraphEncoder::GraphEncoder(
imp->set_version(onnx_opset_version);

EncodeGraph(model_proto_.mutable_graph(), graph, initializers, dynamic_axes,
keep_initializers_as_inputs, add_node_names);
keep_initializers_as_inputs, add_node_names, use_external_data_format,
onnx_file_path);

for (const std::string& domain : domains_) {
auto* opset = model_proto_.add_opset_import();
Expand All @@ -544,7 +611,9 @@ GraphEncoder::GraphEncoder(
void GraphEncoder::EncodeTensor(
onnx::TensorProto* tensor_proto,
const at::Tensor& tensor,
const c10::optional<std::string> external_ref) {
const c10::optional<std::string> external_ref,
const bool use_external_data_format,
const std::string& onnx_file_path) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: onnx_param_file_path or onnx_weights_file_path?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, at this point this is the file path for the ONNX model file itself. We derive the param/weight file name from this (variable named fullFilePath). So I feel that this name may be appropriate. Let me know if that's OK, otherwise I can update it.

for (auto d : tensor.sizes()) {
tensor_proto->add_dims(d);
}
Expand All @@ -559,6 +628,11 @@ void GraphEncoder::EncodeTensor(
else {
t = tensor.contiguous().cpu();
}

// Either defer_weight_export should be true and external_ref must be present,
// or use_external_data_format should be true, not both at the same time. They can
// both be false at the same time (for ONNX export for regular model size).
AT_ASSERT(!((defer_weight_export_ && external_ref) && use_external_data_format));
// Add a buffer to the raw_data_export_map for the caller to dump into an
// external data store. If external_ref is not specified, we instead dump
// the contiguous data into the protobuf itself
Expand All @@ -571,8 +645,22 @@ void GraphEncoder::EncodeTensor(
tensor_proto->set_raw_data("__EXTERNAL");
} else {
AT_ASSERT(t.is_contiguous());
tensor_proto->set_raw_data(std::string(
size_t tensorSize = static_cast<size_t>(std::accumulate(std::begin(tensor.sizes()),
std::end(tensor.sizes()), static_cast<int64_t>(1), std::multiplies<int64_t>()));
if (use_external_data_format && tensorSize > ParamSizeThresholdForExternalStorage) {
AT_ASSERT(!onnx_file_path.empty());
AT_ASSERT((external_ref != c10::nullopt) && (external_ref.value() == tensor_proto->name()));
auto tensorName = GetExternalFileName(external_ref);
CreateExternalFile(t, tensorName, onnx_file_path);
onnx::StringStringEntryProto* location = tensor_proto->mutable_external_data()->Add();
location->set_key("location");
location->set_value(tensorName);
tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL);
}
else {
tensor_proto->set_raw_data(std::string(
static_cast<char*>(t.data_ptr()), t.element_size() * t.numel()));
}
}
}

Expand Down Expand Up @@ -773,7 +861,9 @@ std::string pretty_print_onnx(
true,
keep_initializers_as_inputs,
custom_opsets,
add_node_names);
add_node_names,
false,
std::string());
if (google_printer) {
return graph_encoder.get_model_proto().DebugString();
}
Expand All @@ -795,7 +885,9 @@ std::tuple<std::string, RawDataExportMap> export_onnx(
bool strip_doc_string,
bool keep_initializers_as_inputs,
const std::map<std::string, int>& custom_opsets,
bool add_node_names) {
bool add_node_names,
bool use_external_data_format,
const std::string& onnx_file_path) {
auto graph_encoder = GraphEncoder(
graph,
onnx_opset_version,
Expand All @@ -806,7 +898,9 @@ std::tuple<std::string, RawDataExportMap> export_onnx(
strip_doc_string,
keep_initializers_as_inputs,
custom_opsets,
add_node_names);
add_node_names,
use_external_data_format,
onnx_file_path);
return std::make_tuple(
graph_encoder.get_model_proto().SerializeAsString(),
graph_encoder.get_raw_data_export_map());
Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/jit/export.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
bool strip_doc_string = true,
bool keep_initializers_as_inputs = true,
const std::map<std::string, int>& custom_opsets = {},
bool add_node_names = true);
bool add_node_names = true,
bool use_external_data_format = false,
const std::string& onnx_file_path = std::string());

TORCH_API void check_onnx_proto(const std::string& proto_string);

Expand Down
12 changes: 9 additions & 3 deletions torch/csrc/jit/python_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ void initPythonIRBindings(PyObject* module_) {
bool strip_doc_string,
bool keep_initializers_as_inputs,
const std::map<std::string, int>& custom_opsets,
bool add_node_names) {
bool add_node_names,
bool use_external_data_format,
const std::string& onnx_file_path) {
std::string graph;
RawDataExportMap export_map;
std::tie(graph, export_map) = export_onnx(
Expand All @@ -249,7 +251,9 @@ void initPythonIRBindings(PyObject* module_) {
strip_doc_string,
keep_initializers_as_inputs,
custom_opsets,
add_node_names);
add_node_names,
use_external_data_format,
onnx_file_path);
std::unordered_map<std::string, py::bytes>
python_serialized_export_map;
for (auto& kv : export_map) {
Expand All @@ -273,7 +277,9 @@ void initPythonIRBindings(PyObject* module_) {
py::arg("strip_doc_string") = true,
py::arg("keep_initializers_as_inputs") = true,
py::arg("custom_opsets"),
py::arg("add_node_names") = true)
py::arg("add_node_names") = true,
py::arg("use_external_data_format") = false,
py::arg("onnx_file_path") = std::string())
.def(
"_pretty_print_onnx",
[](const std::shared_ptr<Graph> g,
Expand Down
Loading