Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jit][edge] Reclaim some binary size. #70255

Closed
wants to merge 9 commits into from
47 changes: 21 additions & 26 deletions aten/src/ATen/core/union_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,8 @@ void filterDuplicateSubtypes(std::vector<TypePtr>* types) {
size_t end_idx = types->size()-1;
for (size_t i = types->size()-1; i > 0; --i) {
for (size_t j = std::min(i-1, end_idx); ; --j) {
c10::optional<TypePtr> unified;
unified = get_supertype((*types)[i], (*types)[j]);
if (unified) {
(*types)[j] = *unified;
if (auto unified = get_supertype((*types)[i], (*types)[j])) {
(*types)[j] = std::move(*unified);
(*types)[i] = (*types)[end_idx];
--end_idx;
break;
Expand Down Expand Up @@ -160,13 +158,13 @@ void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten) {
"passed a `nullptr`");
std::vector<TypePtr> to_fill;
standardizeVectorForUnion(*to_flatten, &to_fill);
*to_flatten = to_fill;
*to_flatten = std::move(to_fill);
}

OptionalType::OptionalType(TypePtr contained)
: UnionType({contained, NoneType::get()}, TypeKind::OptionalType) {
bool is_numbertype = false;
if (auto as_union = contained->cast<UnionType>()) {
if (auto as_union = contained->castRaw<UnionType>()) {
is_numbertype = as_union->containedTypes().size() == 3 &&
as_union->canHoldType(*NumberType::get());
}
Expand Down Expand Up @@ -195,20 +193,17 @@ UnionType::UnionType(std::vector<TypePtr> reference, TypeKind kind) : SharedType
// Gate the assert in a regular conditional so that we don't create
// this long error message unnecessarily
if (types_.size() == 1) {
std::stringstream msg;
msg << "After type unification was performed, the Union with the "
<< "original types {";
std::string msg = "After type unification was performed, the Union with the original types {";
for (const auto i : c10::irange(reference.size())) {
msg << reference[i]->repr_str();
msg += reference[i]->repr_str();
if (i > 0) {
msg << ",";
msg += ",";
}
msg << " ";
msg += " ";
}
msg << "} has the single type " << types_[0]->repr_str()
<< ". Use the common supertype instead of creating a Union"
<< "type";
TORCH_INTERNAL_ASSERT(false, msg.str());
msg += "} has the single type " + types_[0]->repr_str()
+ ". Use the common supertype instead of creating a Union type";
TORCH_INTERNAL_ASSERT(false, msg);
}

can_hold_none_ = false;
Expand Down Expand Up @@ -357,7 +352,7 @@ bool UnionType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {

std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str)
const {
std::stringstream ss;
std::string ss;

bool can_hold_numbertype = this->canHoldType(*NumberType::get());

Expand All @@ -375,33 +370,33 @@ std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str)
std::string open_delimeter = is_annotation_str ? "[" : "(";
std::string close_delimeter = is_annotation_str ? "]" : ")";

ss << "Union" + open_delimeter;
ss = "Union" + open_delimeter;
bool printed = false;
for (size_t i = 0; i < types_.size(); ++i) {
if (!can_hold_numbertype || !is_numbertype(types_[i])) {
if (i > 0) {
ss << ", ";
ss += ", ";
printed = true;
}
if (is_annotation_str) {
ss << this->containedTypes()[i]->annotation_str(printer);
ss += this->containedTypes()[i]->annotation_str(printer);
} else {
ss << this->containedTypes()[i]->str();
ss += this->containedTypes()[i]->str();
}
}
}
if (can_hold_numbertype) {
if (printed) {
ss << ", ";
ss += ", ";
}
if (is_annotation_str) {
ss << NumberType::get()->annotation_str(printer);
ss += NumberType::get()->annotation_str(printer);
} else {
ss << NumberType::get()->str();
ss += NumberType::get()->str();
}
}
ss << close_delimeter;
return ss.str();
ss += close_delimeter;
return ss;
}

std::string UnionType::str() const {
Expand Down
178 changes: 93 additions & 85 deletions torch/csrc/jit/frontend/schema_type_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,96 +289,104 @@ TypePtr SchemaTypeParser::parseRefinedTensor() {
}

std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
TypePtr value;
c10::optional<AliasInfo> alias_info;
// Tuple type
if (L.cur().kind == '(') {
std::vector<TypePtr> types;
parseList('(', ',', ')', [&] {
auto r = parseType();
types.push_back(std::move(r.first));
if (alias_info && r.second) {
alias_info->addContainedType(std::move(*r.second));
}
});
value = TupleType::create(std::move(types));
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
L.next(); // Future
L.expect('(');
auto p = parseType();
auto subtype = std::move(p.first);
auto subalias = std::move(p.second);
L.expect(')');
value = FutureType::create(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
L.next(); // RRef
L.expect('(');
auto p = parseType();
auto subtype = std::move(p.first);
auto subalias = std::move(p.second);
L.expect(')');
value = RRefType::create(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
L.next();
value = TensorType::get();
alias_info = parseAliasAnnotation();
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
L.next();
L.expect('(');
auto key_type = parseType().first;
L.expect(',');
auto value_type = parseType().first;
L.expect(')');
alias_info = parseAliasAnnotation();
value = DictType::create(key_type, value_type);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
L.next();
L.expect('(');
std::vector<TypePtr> types;
types.emplace_back(parseType().first);
while (L.cur().kind != ')') {
auto parseHead = [&]() -> TypePtr {
if (L.cur().kind == '(') {
std::vector<TypePtr> types;
parseList('(', ',', ')', [&] {
auto r = parseType();
types.push_back(std::move(r.first));
if (alias_info && r.second) {
alias_info->addContainedType(std::move(*r.second));
}
});
return TupleType::create(std::move(types));
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
L.next(); // Future
L.expect('(');
auto p = parseType();
auto subtype = std::move(p.first);
auto subalias = std::move(p.second);
L.expect(')');
return FutureType::create(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
L.next(); // RRef
L.expect('(');
auto p = parseType();
auto subtype = std::move(p.first);
auto subalias = std::move(p.second);
L.expect(')');
return RRefType::create(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
L.next();
auto value = TensorType::get();
alias_info = parseAliasAnnotation();
return value;
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
L.next();
L.expect('(');
auto key_type = parseType().first;
L.expect(',');
auto value_type = parseType().first;
L.expect(')');
alias_info = parseAliasAnnotation();
return DictType::create(key_type, value_type);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
L.next();
L.expect('(');
std::vector<TypePtr> types;
types.emplace_back(parseType().first);
while (L.cur().kind != ')') {
L.expect(',');
types.emplace_back(parseType().first);
}
L.expect(')');
alias_info = parseAliasAnnotation();
return UnionType::create(types);
} else if (
complete_tensor_types && L.cur().kind == TK_IDENT &&
parseTensorDType(L.cur().text())) {
auto value = parseRefinedTensor();
alias_info = parseAliasAnnotation();
return value;
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") {
L.next();
L.expect('.');
auto torch_tok = L.expect(TK_IDENT);
if (torch_tok.text() != "torch") {
throw ErrorReport(torch_tok.range)
<< "Expected classes namespace but got " << torch_tok.text();
}
L.expect('.');
auto classes_tok = L.expect(TK_IDENT);
if (classes_tok.text() != "classes") {
throw ErrorReport(classes_tok.range)
<< "Expected classes namespace but got " << classes_tok.text();
}
L.expect('.');
auto ns_tok = L.expect(TK_IDENT);
L.expect('.');
auto class_tok = L.expect(TK_IDENT);
auto value = getCustomClass(
std::string("__torch__.torch.classes.") + ns_tok.text() + "." +
class_tok.text());
if (!value) {
throw ErrorReport(class_tok.range)
<< "Unknown custom class type "
<< ns_tok.text() + "." + class_tok.text()
<< ". Please ensure it is registered.";
}
return value;
} else {
auto value = parseBaseType();
alias_info = parseAliasAnnotation();
return value;
}
L.expect(')');
alias_info = parseAliasAnnotation();
value = UnionType::create(types);
} else if (
complete_tensor_types && L.cur().kind == TK_IDENT &&
parseTensorDType(L.cur().text())) {
value = parseRefinedTensor();
alias_info = parseAliasAnnotation();
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") {
L.next();
L.expect('.');
auto torch_tok = L.expect(TK_IDENT);
if (torch_tok.text() != "torch") {
throw ErrorReport(torch_tok.range)
<< "Expected classes namespace but got " << torch_tok.text();
}
L.expect('.');
auto classes_tok = L.expect(TK_IDENT);
if (classes_tok.text() != "classes") {
throw ErrorReport(classes_tok.range)
<< "Expected classes namespace but got " << classes_tok.text();
}
L.expect('.');
auto ns_tok = L.expect(TK_IDENT);
L.expect('.');
auto class_tok = L.expect(TK_IDENT);
value = getCustomClass(
std::string("__torch__.torch.classes.") + ns_tok.text() + "." +
class_tok.text());
if (!value) {
throw ErrorReport(class_tok.range)
<< "Unknown custom class type "
<< ns_tok.text() + "." + class_tok.text()
<< ". Please ensure it is registered.";
}
} else {
value = parseBaseType();
alias_info = parseAliasAnnotation();
}
};

TypePtr value = parseHead();

while (true) {
if (L.cur().kind == '[' && L.lookahead().kind == ']') {
L.next(); // [
Expand Down
5 changes: 2 additions & 3 deletions torch/csrc/jit/mobile/model_compatibility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ ModelCompatCheckResult is_compatible(
result.errors.emplace_back(s.str());
}

std::unordered_set<std::string> supported_type = runtime_info.supported_types;
const auto& supported_type = runtime_info.supported_types;

// Check type table
for (const auto& type_name : model_info.type_table) {
Expand All @@ -349,8 +349,7 @@ ModelCompatCheckResult is_compatible(
}

// Check operators
std::unordered_map<std::string, OperatorInfo> operator_info =
model_info.operator_info;
const auto& operator_info = model_info.operator_info;
for (auto const& op : operator_info) {
std::string op_name = op.first;
OperatorInfo model_op_info = op.second;
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/mobile/type_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ const std::unordered_set<std::string>& TypeParser::getCustomType() {
// compatibility check between model and runtime. For example:
// PyThon string: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
// contained type is: [Dict, int, Tuple, Tensor]
std::unordered_set<std::string> TypeParser::getContainedTypes() {
const std::unordered_set<std::string>& TypeParser::getContainedTypes() {
return contained_types_;
}

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/mobile/type_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TORCH_API TypeParser {
std::vector<TypePtr> parseList();
static const std::unordered_set<std::string>& getNonSimpleType();
static const std::unordered_set<std::string>& getCustomType();
std::unordered_set<std::string> getContainedTypes();
const std::unordered_set<std::string>& getContainedTypes();

private:
TypePtr parseNamedTuple(const std::string& qualified_name);
Expand Down