diff --git a/aten/src/ATen/core/union_type.cpp b/aten/src/ATen/core/union_type.cpp index f763879b77b9f..5d4edc70d8a73 100644 --- a/aten/src/ATen/core/union_type.cpp +++ b/aten/src/ATen/core/union_type.cpp @@ -111,10 +111,8 @@ void filterDuplicateSubtypes(std::vector* types) { size_t end_idx = types->size()-1; for (size_t i = types->size()-1; i > 0; --i) { for (size_t j = std::min(i-1, end_idx); ; --j) { - c10::optional unified; - unified = get_supertype((*types)[i], (*types)[j]); - if (unified) { - (*types)[j] = *unified; + if (auto unified = get_supertype((*types)[i], (*types)[j])) { + (*types)[j] = std::move(*unified); (*types)[i] = (*types)[end_idx]; --end_idx; break; @@ -160,13 +158,13 @@ void standardizeVectorForUnion(std::vector* to_flatten) { "passed a `nullptr`"); std::vector to_fill; standardizeVectorForUnion(*to_flatten, &to_fill); - *to_flatten = to_fill; + *to_flatten = std::move(to_fill); } OptionalType::OptionalType(TypePtr contained) : UnionType({contained, NoneType::get()}, TypeKind::OptionalType) { bool is_numbertype = false; - if (auto as_union = contained->cast()) { + if (auto as_union = contained->castRaw()) { is_numbertype = as_union->containedTypes().size() == 3 && as_union->canHoldType(*NumberType::get()); } @@ -195,20 +193,17 @@ UnionType::UnionType(std::vector reference, TypeKind kind) : SharedType // Gate the assert in a regular conditional so that we don't create // this long error message unnecessarily if (types_.size() == 1) { - std::stringstream msg; - msg << "After type unification was performed, the Union with the " - << "original types {"; + std::string msg = "After type unification was performed, the Union with the original types {"; for (const auto i : c10::irange(reference.size())) { - msg << reference[i]->repr_str(); + msg += reference[i]->repr_str(); if (i > 0) { - msg << ","; + msg += ","; } - msg << " "; + msg += " "; } - msg << "} has the single type " << types_[0]->repr_str() - << ". Use the common supertype instead of creating a Union" - << "type"; - TORCH_INTERNAL_ASSERT(false, msg.str()); + msg += "} has the single type " + types_[0]->repr_str() + + ". Use the common supertype instead of creating a Union type"; + TORCH_INTERNAL_ASSERT(false, msg); } can_hold_none_ = false; @@ -357,7 +352,7 @@ bool UnionType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const { std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) const { - std::stringstream ss; + std::string ss; bool can_hold_numbertype = this->canHoldType(*NumberType::get()); @@ -375,33 +370,33 @@ std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) std::string open_delimeter = is_annotation_str ? "[" : "("; std::string close_delimeter = is_annotation_str ? "]" : ")"; - ss << "Union" + open_delimeter; + ss = "Union" + open_delimeter; bool printed = false; for (size_t i = 0; i < types_.size(); ++i) { if (!can_hold_numbertype || !is_numbertype(types_[i])) { if (i > 0) { - ss << ", "; + ss += ", "; printed = true; } if (is_annotation_str) { - ss << this->containedTypes()[i]->annotation_str(printer); + ss += this->containedTypes()[i]->annotation_str(printer); } else { - ss << this->containedTypes()[i]->str(); + ss += this->containedTypes()[i]->str(); } } } if (can_hold_numbertype) { if (printed) { - ss << ", "; + ss += ", "; } if (is_annotation_str) { - ss << NumberType::get()->annotation_str(printer); + ss += NumberType::get()->annotation_str(printer); } else { - ss << NumberType::get()->str(); + ss += NumberType::get()->str(); } } - ss << close_delimeter; - return ss.str(); + ss += close_delimeter; + return ss; } std::string UnionType::str() const { diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index f09aa60ba72a9..9d51220d90b79 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -289,96 +289,104 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { } std::pair> SchemaTypeParser::parseType() { - TypePtr value; c10::optional alias_info; // Tuple type - if (L.cur().kind == '(') { - std::vector types; - parseList('(', ',', ')', [&] { - auto r = parseType(); - types.push_back(std::move(r.first)); - if (alias_info && r.second) { - alias_info->addContainedType(std::move(*r.second)); - } - }); - value = TupleType::create(std::move(types)); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") { - L.next(); // Future - L.expect('('); - auto p = parseType(); - auto subtype = std::move(p.first); - auto subalias = std::move(p.second); - L.expect(')'); - value = FutureType::create(subtype); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") { - L.next(); // RRef - L.expect('('); - auto p = parseType(); - auto subtype = std::move(p.first); - auto subalias = std::move(p.second); - L.expect(')'); - value = RRefType::create(subtype); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") { - L.next(); - value = TensorType::get(); - alias_info = parseAliasAnnotation(); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") { - L.next(); - L.expect('('); - auto key_type = parseType().first; - L.expect(','); - auto value_type = parseType().first; - L.expect(')'); - alias_info = parseAliasAnnotation(); - value = DictType::create(key_type, value_type); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") { - L.next(); - L.expect('('); - std::vector types; - types.emplace_back(parseType().first); - while (L.cur().kind != ')') { + auto parseHead = [&]() -> TypePtr { + if (L.cur().kind == '(') { + std::vector types; + parseList('(', ',', ')', [&] { + auto r = parseType(); + types.push_back(std::move(r.first)); + if (alias_info && r.second) { + alias_info->addContainedType(std::move(*r.second)); + } + }); + return TupleType::create(std::move(types)); + } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") { + L.next(); // Future + L.expect('('); + auto p = parseType(); + auto subtype = std::move(p.first); + auto subalias = std::move(p.second); + L.expect(')'); + return FutureType::create(subtype); + } else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") { + L.next(); // RRef + L.expect('('); + auto p = parseType(); + auto subtype = std::move(p.first); + auto subalias = std::move(p.second); + L.expect(')'); + return RRefType::create(subtype); + } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") { + L.next(); + auto value = TensorType::get(); + alias_info = parseAliasAnnotation(); + return value; + } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") { + L.next(); + L.expect('('); + auto key_type = parseType().first; L.expect(','); + auto value_type = parseType().first; + L.expect(')'); + alias_info = parseAliasAnnotation(); + return DictType::create(key_type, value_type); + } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") { + L.next(); + L.expect('('); + std::vector types; types.emplace_back(parseType().first); + while (L.cur().kind != ')') { + L.expect(','); + types.emplace_back(parseType().first); + } + L.expect(')'); + alias_info = parseAliasAnnotation(); + return UnionType::create(types); + } else if ( + complete_tensor_types && L.cur().kind == TK_IDENT && + parseTensorDType(L.cur().text())) { + auto value = parseRefinedTensor(); + alias_info = parseAliasAnnotation(); + return value; + } else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") { + L.next(); + L.expect('.'); + auto torch_tok = L.expect(TK_IDENT); + if (torch_tok.text() != "torch") { + throw ErrorReport(torch_tok.range) + << "Expected classes namespace but got " << torch_tok.text(); + } + L.expect('.'); + auto classes_tok = L.expect(TK_IDENT); + if (classes_tok.text() != "classes") { + throw ErrorReport(classes_tok.range) + << "Expected classes namespace but got " << classes_tok.text(); + } + L.expect('.'); + auto ns_tok = L.expect(TK_IDENT); + L.expect('.'); + auto class_tok = L.expect(TK_IDENT); + auto value = getCustomClass( + std::string("__torch__.torch.classes.") + ns_tok.text() + "." + + class_tok.text()); + if (!value) { + throw ErrorReport(class_tok.range) + << "Unknown custom class type " + << ns_tok.text() + "." + class_tok.text() + << ". Please ensure it is registered."; + } + return value; + } else { + auto value = parseBaseType(); + alias_info = parseAliasAnnotation(); + return value; } - L.expect(')'); - alias_info = parseAliasAnnotation(); - value = UnionType::create(types); - } else if ( - complete_tensor_types && L.cur().kind == TK_IDENT && - parseTensorDType(L.cur().text())) { - value = parseRefinedTensor(); - alias_info = parseAliasAnnotation(); - } else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") { - L.next(); - L.expect('.'); - auto torch_tok = L.expect(TK_IDENT); - if (torch_tok.text() != "torch") { - throw ErrorReport(torch_tok.range) - << "Expected classes namespace but got " << torch_tok.text(); - } - L.expect('.'); - auto classes_tok = L.expect(TK_IDENT); - if (classes_tok.text() != "classes") { - throw ErrorReport(classes_tok.range) - << "Expected classes namespace but got " << classes_tok.text(); - } - L.expect('.'); - auto ns_tok = L.expect(TK_IDENT); - L.expect('.'); - auto class_tok = L.expect(TK_IDENT); - value = getCustomClass( - std::string("__torch__.torch.classes.") + ns_tok.text() + "." + - class_tok.text()); - if (!value) { - throw ErrorReport(class_tok.range) - << "Unknown custom class type " - << ns_tok.text() + "." + class_tok.text() - << ". Please ensure it is registered."; - } - } else { - value = parseBaseType(); - alias_info = parseAliasAnnotation(); - } + }; + + TypePtr value = parseHead(); + while (true) { if (L.cur().kind == '[' && L.lookahead().kind == ']') { L.next(); // [ diff --git a/torch/csrc/jit/mobile/model_compatibility.cpp b/torch/csrc/jit/mobile/model_compatibility.cpp index 6dd2aeab88ad2..f2a9d82f62e8c 100644 --- a/torch/csrc/jit/mobile/model_compatibility.cpp +++ b/torch/csrc/jit/mobile/model_compatibility.cpp @@ -335,7 +335,7 @@ ModelCompatCheckResult is_compatible( result.errors.emplace_back(s.str()); } - std::unordered_set supported_type = runtime_info.supported_types; + const auto& supported_type = runtime_info.supported_types; // Check type table for (const auto& type_name : model_info.type_table) { @@ -349,8 +349,7 @@ ModelCompatCheckResult is_compatible( } // Check operators - std::unordered_map operator_info = - model_info.operator_info; + const auto& operator_info = model_info.operator_info; for (auto const& op : operator_info) { std::string op_name = op.first; OperatorInfo model_op_info = op.second; diff --git a/torch/csrc/jit/mobile/type_parser.cpp b/torch/csrc/jit/mobile/type_parser.cpp index 2863fcf6d5b5c..9a40d1200acd6 100644 --- a/torch/csrc/jit/mobile/type_parser.cpp +++ b/torch/csrc/jit/mobile/type_parser.cpp @@ -89,7 +89,7 @@ const std::unordered_set& TypeParser::getCustomType() { // compatibility check between model and runtime. For example: // PyThon string: "Dict[int, Tuple[Tensor, Tensor, Tensor]]" // contained type is: [Dict, int, Tuple, Tensor] -std::unordered_set TypeParser::getContainedTypes() { +const std::unordered_set& TypeParser::getContainedTypes() { return contained_types_; } diff --git a/torch/csrc/jit/mobile/type_parser.h b/torch/csrc/jit/mobile/type_parser.h index e215b17a73a57..4a8a274c69605 100644 --- a/torch/csrc/jit/mobile/type_parser.h +++ b/torch/csrc/jit/mobile/type_parser.h @@ -14,7 +14,7 @@ class TORCH_API TypeParser { std::vector parseList(); static const std::unordered_set& getNonSimpleType(); static const std::unordered_set& getCustomType(); - std::unordered_set getContainedTypes(); + const std::unordered_set& getContainedTypes(); private: TypePtr parseNamedTuple(const std::string& qualified_name);