From 50ba9eb41baaa3df739d17360192027746a8444d Mon Sep 17 00:00:00 2001
From: zhxchen17
Date: Tue, 21 Dec 2021 12:10:58 -0800
Subject: [PATCH] [jit][edge] Reclaim some binary size.

Should give us 1-2KB back on mobile due to removal of redundant container
copy and shared_ptr decrement.

Differential Revision: [D33230514](https://our.internmc.facebook.com/intern/diff/D33230514/)

[ghstack-poisoned]
---
 .../csrc/jit/frontend/schema_type_parser.cpp  | 178 +++++++++---------
 torch/csrc/jit/mobile/model_compatibility.cpp |   5 +-
 torch/csrc/jit/mobile/type_parser.cpp         |   6 +-
 torch/csrc/jit/mobile/type_parser.h           |   6 +-
 4 files changed, 101 insertions(+), 94 deletions(-)

diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp
index f09aa60ba72a..9d51220d90b7 100644
--- a/torch/csrc/jit/frontend/schema_type_parser.cpp
+++ b/torch/csrc/jit/frontend/schema_type_parser.cpp
@@ -289,96 +289,104 @@ TypePtr SchemaTypeParser::parseRefinedTensor() {
 }
 
 std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
-  TypePtr value;
   c10::optional<AliasInfo> alias_info;
   // Tuple type
-  if (L.cur().kind == '(') {
-    std::vector<TypePtr> types;
-    parseList('(', ',', ')', [&] {
-      auto r = parseType();
-      types.push_back(std::move(r.first));
-      if (alias_info && r.second) {
-        alias_info->addContainedType(std::move(*r.second));
-      }
-    });
-    value = TupleType::create(std::move(types));
-  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
-    L.next(); // Future
-    L.expect('(');
-    auto p = parseType();
-    auto subtype = std::move(p.first);
-    auto subalias = std::move(p.second);
-    L.expect(')');
-    value = FutureType::create(subtype);
-  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
-    L.next(); // RRef
-    L.expect('(');
-    auto p = parseType();
-    auto subtype = std::move(p.first);
-    auto subalias = std::move(p.second);
-    L.expect(')');
-    value = RRefType::create(subtype);
-  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
-    L.next();
-    value = TensorType::get();
-    alias_info = parseAliasAnnotation();
-  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
-    L.next();
-    L.expect('(');
-    auto key_type = parseType().first;
-    L.expect(',');
-    auto value_type = parseType().first;
-    L.expect(')');
-    alias_info = parseAliasAnnotation();
-    value = DictType::create(key_type, value_type);
-  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
-    L.next();
-    L.expect('(');
-    std::vector<TypePtr> types;
-    types.emplace_back(parseType().first);
-    while (L.cur().kind != ')') {
+  auto parseHead = [&]() -> TypePtr {
+    if (L.cur().kind == '(') {
+      std::vector<TypePtr> types;
+      parseList('(', ',', ')', [&] {
+        auto r = parseType();
+        types.push_back(std::move(r.first));
+        if (alias_info && r.second) {
+          alias_info->addContainedType(std::move(*r.second));
+        }
+      });
+      return TupleType::create(std::move(types));
+    } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
+      L.next(); // Future
+      L.expect('(');
+      auto p = parseType();
+      auto subtype = std::move(p.first);
+      auto subalias = std::move(p.second);
+      L.expect(')');
+      return FutureType::create(subtype);
+    } else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
+      L.next(); // RRef
+      L.expect('(');
+      auto p = parseType();
+      auto subtype = std::move(p.first);
+      auto subalias = std::move(p.second);
+      L.expect(')');
+      return RRefType::create(subtype);
+    } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
+      L.next();
+      auto value = TensorType::get();
+      alias_info = parseAliasAnnotation();
+      return value;
+    } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
+      L.next();
+      L.expect('(');
+      auto key_type = parseType().first;
       L.expect(',');
+      auto value_type = parseType().first;
+      L.expect(')');
+      alias_info = parseAliasAnnotation();
+      return DictType::create(key_type, value_type);
+    } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
+      L.next();
+      L.expect('(');
+      std::vector<TypePtr> types;
       types.emplace_back(parseType().first);
+      while (L.cur().kind != ')') {
+        L.expect(',');
+        types.emplace_back(parseType().first);
+      }
+      L.expect(')');
+      alias_info = parseAliasAnnotation();
+      return UnionType::create(types);
+    } else if (
+        complete_tensor_types && L.cur().kind == TK_IDENT &&
+        parseTensorDType(L.cur().text())) {
+      auto value = parseRefinedTensor();
+      alias_info = parseAliasAnnotation();
+      return value;
+    } else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") {
+      L.next();
+      L.expect('.');
+      auto torch_tok = L.expect(TK_IDENT);
+      if (torch_tok.text() != "torch") {
+        throw ErrorReport(torch_tok.range)
+            << "Expected classes namespace but got " << torch_tok.text();
+      }
+      L.expect('.');
+      auto classes_tok = L.expect(TK_IDENT);
+      if (classes_tok.text() != "classes") {
+        throw ErrorReport(classes_tok.range)
+            << "Expected classes namespace but got " << classes_tok.text();
+      }
+      L.expect('.');
+      auto ns_tok = L.expect(TK_IDENT);
+      L.expect('.');
+      auto class_tok = L.expect(TK_IDENT);
+      auto value = getCustomClass(
+          std::string("__torch__.torch.classes.") + ns_tok.text() + "." +
+          class_tok.text());
+      if (!value) {
+        throw ErrorReport(class_tok.range)
+            << "Unknown custom class type "
+            << ns_tok.text() + "." + class_tok.text()
+            << ". Please ensure it is registered.";
+      }
+      return value;
+    } else {
+      auto value = parseBaseType();
+      alias_info = parseAliasAnnotation();
+      return value;
     }
-    L.expect(')');
-    alias_info = parseAliasAnnotation();
-    value = UnionType::create(types);
-  } else if (
-      complete_tensor_types && L.cur().kind == TK_IDENT &&
-      parseTensorDType(L.cur().text())) {
-    value = parseRefinedTensor();
-    alias_info = parseAliasAnnotation();
-  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") {
-    L.next();
-    L.expect('.');
-    auto torch_tok = L.expect(TK_IDENT);
-    if (torch_tok.text() != "torch") {
-      throw ErrorReport(torch_tok.range)
-          << "Expected classes namespace but got " << torch_tok.text();
-    }
-    L.expect('.');
-    auto classes_tok = L.expect(TK_IDENT);
-    if (classes_tok.text() != "classes") {
-      throw ErrorReport(classes_tok.range)
-          << "Expected classes namespace but got " << classes_tok.text();
-    }
-    L.expect('.');
-    auto ns_tok = L.expect(TK_IDENT);
-    L.expect('.');
-    auto class_tok = L.expect(TK_IDENT);
-    value = getCustomClass(
-        std::string("__torch__.torch.classes.") + ns_tok.text() + "." +
-        class_tok.text());
-    if (!value) {
-      throw ErrorReport(class_tok.range)
-          << "Unknown custom class type "
-          << ns_tok.text() + "." + class_tok.text()
-          << ". Please ensure it is registered.";
-    }
-  } else {
-    value = parseBaseType();
-    alias_info = parseAliasAnnotation();
-  }
+  };
+
+  TypePtr value = parseHead();
+
   while (true) {
     if (L.cur().kind == '[' && L.lookahead().kind == ']') {
       L.next(); // [
diff --git a/torch/csrc/jit/mobile/model_compatibility.cpp b/torch/csrc/jit/mobile/model_compatibility.cpp
index 687e1a94e3d4..61d3db2ebf3a 100644
--- a/torch/csrc/jit/mobile/model_compatibility.cpp
+++ b/torch/csrc/jit/mobile/model_compatibility.cpp
@@ -334,7 +334,7 @@ ModelCompatCheckResult is_compatible(
     result.errors.emplace_back(s.str());
   }
 
-  std::unordered_set<std::string> supported_type = runtime_info.supported_types;
+  const auto& supported_type = runtime_info.supported_types;
 
   // Check type table
   for (const auto& type_name : model_info.type_table) {
@@ -348,8 +348,7 @@ ModelCompatCheckResult is_compatible(
   }
 
   // Check operators
-  std::unordered_map<std::string, OperatorInfo> operator_info =
-      model_info.operator_info;
+  const auto& operator_info = model_info.operator_info;
   for (auto const& op : operator_info) {
     std::string op_name = op.first;
     OperatorInfo model_op_info = op.second;
diff --git a/torch/csrc/jit/mobile/type_parser.cpp b/torch/csrc/jit/mobile/type_parser.cpp
index 2a724b2477e9..d90cf3c3cb60 100644
--- a/torch/csrc/jit/mobile/type_parser.cpp
+++ b/torch/csrc/jit/mobile/type_parser.cpp
@@ -64,14 +64,14 @@ std::vector<TypePtr> TypeParser::parseList() {
 }
 
 // The list of non-simple types supported by currrent parser.
-std::unordered_set<std::string> TypeParser::getNonSimpleType() {
+const std::unordered_set<std::string>& TypeParser::getNonSimpleType() {
   static std::unordered_set<std::string> nonSimpleTypes{
       "List", "Optional", "Future", "Dict", "Tuple"};
   return nonSimpleTypes;
 }
 
 // The list of custom types supported by currrent parser.
-std::unordered_set<std::string> TypeParser::getCustomType() {
+const std::unordered_set<std::string>& TypeParser::getCustomType() {
   static std::unordered_set<std::string> customeTypes{
       kTypeTorchbindCustomClass, kTypeNamedTuple};
   return customeTypes;
@@ -81,7 +81,7 @@ std::unordered_set<std::string> TypeParser::getCustomType() {
 // compatibility check between model and runtime. For example:
 // PyThon string: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
 // contained type is: [Dict, int, Tuple, Tensor]
-std::unordered_set<std::string> TypeParser::getContainedTypes() {
+const std::unordered_set<std::string>& TypeParser::getContainedTypes() {
   return contained_types_;
 }
 
diff --git a/torch/csrc/jit/mobile/type_parser.h b/torch/csrc/jit/mobile/type_parser.h
index 647df8a0d05e..4b82b47c6835 100644
--- a/torch/csrc/jit/mobile/type_parser.h
+++ b/torch/csrc/jit/mobile/type_parser.h
@@ -19,9 +19,9 @@ class TORCH_API TypeParser {
   template <typename T>
   TypePtr parse();
   std::vector<TypePtr> parseList();
-  static std::unordered_set<std::string> getNonSimpleType();
-  static std::unordered_set<std::string> getCustomType();
-  std::unordered_set<std::string> getContainedTypes();
+  static const std::unordered_set<std::string>& getNonSimpleType();
+  static const std::unordered_set<std::string>& getCustomType();
+  const std::unordered_set<std::string>& getContainedTypes();
 
  private:
   // Torchbind custom class always starts with the follow prefix, so use it as
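
Note on the container-copy half of the claimed win: the `type_parser` and `model_compatibility` hunks above switch from returning and holding `std::unordered_set` copies to handing out `const` references. A minimal standalone sketch of that pattern follows; the names (`supportedOps`, `supportedOpsByValue`) are hypothetical stand-ins, not PyTorch APIs.

```cpp
// Sketch only: return a const& to a function-local static instead of a copy.
#include <iostream>
#include <string>
#include <unordered_set>

// Before: every call copy-constructs a fresh std::unordered_set.
std::unordered_set<std::string> supportedOpsByValue() {
  static const std::unordered_set<std::string> ops{"aten::add", "aten::mul"};
  return ops; // copy happens here
}

// After: callers share the single static instance; no per-call copy and no
// extra allocation/deallocation code emitted at each call site.
const std::unordered_set<std::string>& supportedOps() {
  static const std::unordered_set<std::string> ops{"aten::add", "aten::mul"};
  return ops;
}

int main() {
  // Binding with `const auto&` keeps the no-copy property at the call site,
  // mirroring `const auto& supported_type = runtime_info.supported_types;`.
  const auto& ops = supportedOps();
  std::cout << "supported ops: " << ops.size() << '\n';
  std::cout << std::boolalpha << (ops.count("aten::mul") > 0) << '\n';
  return 0;
}
```

The caller-side binding matters as much as the signature: a plain `auto` would reintroduce the copy that the `const&` return type is meant to avoid.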
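For the `schema_type_parser.cpp` half, the old code assigned each branch's result into a shared `TypePtr value` before the trailing list/optional handling, while the new code funnels the branches through a local `parseHead` lambda that returns the head type directly, so `value` is initialized once. A toy sketch of that shape, under the assumption that per-branch shared_ptr assignments are the cost being removed (the `Type`/`parse` names here are hypothetical, not the real JIT parser):

```cpp
// Sketch only: consolidate an if/else chain into a head-parsing lambda.
#include <iostream>
#include <memory>
#include <string>

struct Type {
  std::string name;
};
using TypePtr = std::shared_ptr<Type>;

TypePtr parse(const std::string& token) {
  // Before: each branch assigned into a shared `TypePtr value` declared up
  // front, i.e. one extra shared_ptr assignment per branch:
  //   TypePtr value;
  //   if (token == "int") { value = std::make_shared<Type>(Type{"int"}); }
  //   ...
  //
  // After: a local lambda returns from each branch, so `value` below is
  // constructed exactly once from the lambda's return value.
  auto parseHead = [&]() -> TypePtr {
    if (token == "int") {
      return std::make_shared<Type>(Type{"int"});
    } else if (token == "str") {
      return std::make_shared<Type>(Type{"str"});
    } else {
      return std::make_shared<Type>(Type{"unknown"});
    }
  };

  TypePtr value = parseHead();
  // ...suffix handling (e.g. container modifiers) would keep reusing `value`.
  return value;
}

int main() {
  std::cout << parse("int")->name << '\n';   // prints "int"
  std::cout << parse("float")->name << '\n'; // prints "unknown"
  return 0;
}
```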