[onnx.export] Avoid linear loop over symbol_dim_map (#123029)
This PR is part of an effort to speed up torch.onnx.export (#121422).

- Doing a reverse look-up in `symbol_dim_map` costs time linear in the number of symbols. This happens for every node, so the export as a whole incurs a quadratic cost.
- Add a reverse look-up table `dim_symbol_map` that is kept in parallel with `symbol_dim_map`. Each reverse look-up then becomes a logarithmic-time map query instead of a linear scan; a minimal sketch of the pattern follows below.
- This is a highly pragmatic solution. If someone more familiar with the code base has a better approach, I'd be interested to hear it.
- Resolves (9) in #121422.

(partial fix of #121422)
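
For illustration, a minimal sketch of the forward-plus-reverse-map pattern this PR applies. The `Symbol` alias is a stand-in for `c10::ShapeSymbol`, and the helper names (`findSymbolLinear`, `findSymbolIndexed`, `recordSymbol`) are hypothetical, not functions in the PyTorch code base:

#include <map>
#include <optional>
#include <string>

using Symbol = int; // stand-in for c10::ShapeSymbol
using SymbolDimMap = std::map<Symbol, std::string>;
using DimSymbolMap = std::map<std::string, Symbol>;

// Before: the reverse look-up scans every entry -- O(n) per query,
// and O(n^2) across an export that queries once per node.
std::optional<Symbol> findSymbolLinear(
    const SymbolDimMap& symbol_dim_map,
    const std::string& dim_param) {
  for (const auto& pair : symbol_dim_map) {
    if (pair.second == dim_param) {
      return pair.first;
    }
  }
  return std::nullopt;
}

// After: query the parallel reverse map directly -- O(log n) per look-up.
std::optional<Symbol> findSymbolIndexed(
    const DimSymbolMap& dim_symbol_map,
    const std::string& dim_param) {
  auto it = dim_symbol_map.find(dim_param);
  if (it != dim_symbol_map.end()) {
    return it->second;
  }
  return std::nullopt;
}

// Every insertion writes both maps so they never diverge.
void recordSymbol(
    SymbolDimMap& symbol_dim_map,
    DimSymbolMap& dim_symbol_map,
    Symbol sym,
    const std::string& dim_param) {
  symbol_dim_map[sym] = dim_param;
  dim_symbol_map[dim_param] = sym;
}

The trade-off is the usual one for a parallel index: slightly more memory and the obligation to update both maps at every write site (as the `ConvertGraphToONNXProto` change below does), in exchange for eliminating the per-node linear scan.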

Pull Request resolved: #123029
Approved by: https://github.com/justinchuby
gustavla authored and ZelboK committed May 19, 2024
1 parent 35117bf commit 3ca1ae4
Showing 4 changed files with 53 additions and 23 deletions.
14 changes: 14 additions & 0 deletions torch/csrc/jit/passes/onnx/constant_map.cpp
@@ -227,6 +227,10 @@ SymbolDimMap& ConstantValueMap::GetSymbolDimMap() {
return ConstantValueMap::getInstance().symbolDimMap;
}

DimSymbolMap& ConstantValueMap::GetDimSymbolMap() {
return ConstantValueMap::getInstance().dimSymbolMap;
}

template <typename Map>
void UpdateStrKey(
Map& map,
@@ -271,6 +275,7 @@ void ConstantValueMap::ClearMaps() {
ConstantValueMap::getInstance().shapeValueMap.clear();
ConstantValueMap::getInstance().inferredShapeData.clear();
ConstantValueMap::getInstance().symbolDimMap.clear();
ConstantValueMap::getInstance().dimSymbolMap.clear();
ConstantValueMap::getInstance().allGraphInputsStatic = c10::nullopt;
}

@@ -359,6 +364,15 @@ void ConstantValueMap::PrintMaps() {
std::cout << std::endl;
}
}
std::cout << "DimSymbol Map:" << std::endl;
count = 0;
for (const auto& x : ConstantValueMap::getInstance().dimSymbolMap) {
std::cout << "(" << x.first << ": " << x.second << "), ";
count++;
if (count % 10 == 0) {
std::cout << std::endl;
}
}
}

} // namespace jit
2 changes: 2 additions & 0 deletions torch/csrc/jit/passes/onnx/constant_map.h
@@ -70,6 +70,7 @@ class ConstantValueMap {
static ShapeDataMap& GetInferredShapeData();

static SymbolDimMap& GetSymbolDimMap();
static DimSymbolMap& GetDimSymbolMap();

static void UpdateValueName(
const std::string& old_name,
@@ -104,6 +105,7 @@
// during future node-level shape inference.
ShapeDataMap inferredShapeData;
SymbolDimMap symbolDimMap;
DimSymbolMap dimSymbolMap;
// Stores if all graph-level inputs have static shape
c10::optional<bool> allGraphInputsStatic;
};
59 changes: 36 additions & 23 deletions torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -87,35 +87,40 @@ namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;
namespace diagnostics = ::torch::onnx::diagnostics;

// SymbolDimMap is a Torch-to-ONNX shape look-up. This is built so it can be
// returned by the export function. During the export however, when we come
// across new ONNX shapes, the reverse look-up is needed. To avoid incurring
// a linear-time look-up, we maintain DimSymbolMap in parallel.
c10::ShapeSymbol ONNXDimToShapeSymbol(
const onnx::TensorShapeProto_Dimension& dim,
SymbolDimMap& symbol_dim_map) {
SymbolDimMap& symbol_dim_map,
DimSymbolMap& dim_symbol_map) {
if (dim.has_dim_value()) {
return c10::ShapeSymbol::fromStaticSize(dim.dim_value());
}
std::optional<c10::ShapeSymbol> sym = c10::nullopt;
if (dim.has_dim_param()) {
// If this param is already known, assign the same Symbol.
GRAPH_UPDATE("Got dim_param:", dim.dim_param());
for (const auto& pair : symbol_dim_map) {
if (pair.second == dim.dim_param()) {
sym = pair.first;
break;
}
auto maybe_symbol = dim_symbol_map.find(dim.dim_param());
if (maybe_symbol != dim_symbol_map.end()) {
sym = maybe_symbol->second;
}
}
if (!sym) {
sym = c10::ShapeSymbol::newSymbol();
// If dim.dim_param() is empty, no need to keep track
// because there won't be duplicates.
symbol_dim_map[sym.value()] = dim.dim_param();
dim_symbol_map[dim.dim_param()] = sym.value();
}
return sym.value();
}

TensorTypePtr TorchTensorTypeFromONNX(
const onnx::TypeProto_Tensor& onnx_tensor_type,
SymbolDimMap& symbol_dim_map) {
SymbolDimMap& symbol_dim_map,
DimSymbolMap& dim_symbol_map) {
std::optional<at::ScalarType> scalar_type;
if (onnx_tensor_type.has_elem_type()) {
scalar_type = ONNXTypeToATenType(onnx_tensor_type.elem_type());
@@ -132,8 +137,8 @@ TensorTypePtr TorchTensorTypeFromONNX(
const auto& onnx_shape = onnx_tensor_type.shape();

for (const auto i : c10::irange(onnx_shape.dim_size())) {
sizes.emplace_back(
ONNXDimToShapeSymbol(onnx_shape.dim(i), symbol_dim_map));
sizes.emplace_back(ONNXDimToShapeSymbol(
onnx_shape.dim(i), symbol_dim_map, dim_symbol_map));
}
v_type = TensorType::create(scalar_type, at::kCPU, sizes.size(), {});
v_type = v_type->withSymbolicShapes(c10::SymbolicShape(sizes));
@@ -150,13 +155,14 @@

ListTypePtr TorchListTypeFromONNX(
const onnx::TypeProto_Sequence& onnx_sequence_type,
SymbolDimMap& symbol_dim_map) {
SymbolDimMap& symbol_dim_map,
DimSymbolMap& dim_symbol_map) {
if (onnx_sequence_type.has_elem_type()) {
const auto& onnx_seq_elem_type = onnx_sequence_type.elem_type();
if (onnx_seq_elem_type.has_tensor_type()) {
const auto& onnx_tensor_type = onnx_seq_elem_type.tensor_type();
const auto v_tensor_type =
TorchTensorTypeFromONNX(onnx_tensor_type, symbol_dim_map);
const auto v_tensor_type = TorchTensorTypeFromONNX(
onnx_tensor_type, symbol_dim_map, dim_symbol_map);
auto v_type = ListType::create(v_tensor_type);
return v_type;
}
@@ -167,21 +173,22 @@ ListTypePtr TorchListTypeFromONNX(
void UpdateTorchValueByOnnxValueInfo(
Value* v,
const onnx::ValueInfoProto& p_info,
SymbolDimMap& symbol_dim_map) {
SymbolDimMap& symbol_dim_map,
DimSymbolMap& dim_symbol_map) {
if (!p_info.has_type()) {
return;
}

const auto& p_type = p_info.type();
if (p_type.has_tensor_type()) {
const auto torch_tensor_type =
TorchTensorTypeFromONNX(p_type.tensor_type(), symbol_dim_map);
const auto torch_tensor_type = TorchTensorTypeFromONNX(
p_type.tensor_type(), symbol_dim_map, dim_symbol_map);
if (torch_tensor_type) {
MergeInferredTypeAndSetMap(v, v->type(), torch_tensor_type);
}
} else if (p_type.has_sequence_type()) {
const auto torch_list_type =
TorchListTypeFromONNX(p_type.sequence_type(), symbol_dim_map);
const auto torch_list_type = TorchListTypeFromONNX(
p_type.sequence_type(), symbol_dim_map, dim_symbol_map);
if (torch_list_type) {
MergeInferredTypeAndSetMap(v, v->type(), torch_list_type);
}
@@ -377,6 +384,7 @@ void ConvertGraphToONNXProto(
std::shared_ptr<Graph> graph,
std::shared_ptr<onnx::ModelProto>& model_proto,
SymbolDimMap& symbol_dim_map,
DimSymbolMap& dim_symbol_map,
int opset_version) {
RawDataExportMap export_map;
bool val_use_external_data_format;
@@ -402,6 +410,9 @@
false,
std::string());
symbol_dim_map.insert(new_symbol_dim_map.begin(), new_symbol_dim_map.end());
for (const auto& pair : new_symbol_dim_map) {
dim_symbol_map[pair.second] = pair.first;
}
for (int i = 0; i < model_proto->graph().output_size(); ++i) {
model_proto->mutable_graph()->mutable_output(i)->clear_type();
}
@@ -1796,7 +1807,8 @@ void UpdateOutputTypeByONNXProto(
Node* n,
Node* clone_node,
const onnx::ModelProto& model_proto,
SymbolDimMap& symbol_dim_map) {
SymbolDimMap& symbol_dim_map,
DimSymbolMap& dim_symbol_map) {
const auto& graph_proto = model_proto.graph();

// get data from value_info and updated original graph.
@@ -1805,7 +1817,7 @@
for (size_t i = 0; i < n->outputs().size(); ++i) {
if (clone_node->output(i)->debugName() == v_info.name()) {
UpdateTorchValueByOnnxValueInfo(
n->output(i), v_info, symbol_dim_map);
n->output(i), v_info, symbol_dim_map, dim_symbol_map);
}
}
};
@@ -2040,6 +2052,7 @@ void ONNXShapeTypeInference(
auto& original_shape_data = ConstantValueMap::GetInferredShapeData();
ShapeDataMap inferred_shape_data;
auto& symbol_dim_map = ConstantValueMap::GetSymbolDimMap();
auto& dim_symbol_map = ConstantValueMap::GetDimSymbolMap();

SetGraphInputTypeReliable(n->owningGraph());
GRAPH_UPDATE(
@@ -2094,7 +2107,7 @@
// e.g: ListConstruct, ListUnpack, etc.
std::shared_ptr<onnx::ModelProto> model_proto;
ConvertGraphToONNXProto(
n_graph, model_proto, symbol_dim_map, opset_version);
n_graph, model_proto, symbol_dim_map, dim_symbol_map, opset_version);
GRAPH_DEBUG(
"ONNX graph to run shape inference: ", prettyPrint(*model_proto));

@@ -2119,7 +2132,7 @@
}
}
UpdateOutputTypeByONNXProto(
n, clone_node, *model_proto, symbol_dim_map);
n, clone_node, *model_proto, symbol_dim_map, dim_symbol_map);
} catch (std::runtime_error& ex) {
// TODO: include this as warning once we have a more consolidated
// warning system.
@@ -2161,8 +2174,8 @@ void ONNXShapeTypeInference(
int rank = inferred_shape.dim_size();
std::vector<::c10::ShapeSymbol> final_shape(rank);
for (int i = 0; i < rank; ++i) {
final_shape[i] =
ONNXDimToShapeSymbol(inferred_shape.dim(i), symbol_dim_map);
final_shape[i] = ONNXDimToShapeSymbol(
inferred_shape.dim(i), symbol_dim_map, dim_symbol_map);
}
c10::SymbolicShape shape_value(final_shape);
// Store data propagation result into shapeValueMap
1 change: 1 addition & 0 deletions torch/csrc/jit/serialization/export.h
@@ -30,6 +30,7 @@ namespace jit {
using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;

using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;
using DimSymbolMap = std::map<std::string, c10::ShapeSymbol>;

using NodeNameMap = std::unordered_map<const Node*, std::string>;

