[jit][edge] Reclaim some binary size.
Pull Request resolved: #70255

Should give us 1-2 KB back on mobile due to the removal of redundant container copies and shared_ptr decrements.
ghstack-source-id: 146227585

Differential Revision: [D33230514](https://our.internmc.facebook.com/intern/diff/D33230514/)
zhxchen17 committed Dec 24, 2021
1 parent b32a290 commit 509561f
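
The summary names two costs: copying a container member by value, and the atomic refcount traffic that comes with copying and later destroying a shared_ptr-backed TypePtr. The standalone sketch below is only an illustration of those two patterns and their cheaper replacements; RuntimeInfo, supported_by_value, and supported_by_ref are invented names, not code from this commit.

#include <memory>
#include <string>
#include <unordered_set>

struct RuntimeInfo {
  std::unordered_set<std::string> supported_types;
};

// By-value access: materializes a full copy of the set on every call and
// pulls the set's copy machinery into the binary.
std::unordered_set<std::string> supported_by_value(const RuntimeInfo& info) {
  return info.supported_types;
}

// Const-reference access: no copy, no extra generated code.
const std::unordered_set<std::string>& supported_by_ref(const RuntimeInfo& info) {
  return info.supported_types;
}

using TypePtr = std::shared_ptr<std::string>;  // stand-in for a shared_ptr-backed type handle

int main() {
  RuntimeInfo info{{"List", "Dict"}};
  const auto& types = supported_by_ref(info);  // binds, no copy
  auto copy = supported_by_value(info);        // copies the whole set
  (void)types;
  (void)copy;

  // Copying a shared_ptr costs an atomic increment now and a decrement when
  // the copy dies; moving it transfers ownership with no refcount traffic.
  TypePtr a = std::make_shared<std::string>("Tensor");
  TypePtr copied = a;            // ++refcount now, --refcount later
  TypePtr moved = std::move(a);  // no refcount change
  (void)copied;
  (void)moved;
  return 0;
}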
Showing 5 changed files with 122 additions and 120 deletions.
47 changes: 21 additions & 26 deletions aten/src/ATen/core/union_type.cpp
@@ -111,10 +111,8 @@ void filterDuplicateSubtypes(std::vector<TypePtr>* types) {
size_t end_idx = types->size()-1;
for (size_t i = types->size()-1; i > 0; --i) {
for (size_t j = std::min(i-1, end_idx); ; --j) {
c10::optional<TypePtr> unified;
unified = get_supertype((*types)[i], (*types)[j]);
if (unified) {
(*types)[j] = *unified;
if (auto unified = get_supertype((*types)[i], (*types)[j])) {
(*types)[j] = std::move(*unified);
(*types)[i] = (*types)[end_idx];
--end_idx;
break;
@@ -160,13 +158,13 @@ void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten) {
"passed a `nullptr`");
std::vector<TypePtr> to_fill;
standardizeVectorForUnion(*to_flatten, &to_fill);
*to_flatten = to_fill;
*to_flatten = std::move(to_fill);
}

OptionalType::OptionalType(TypePtr contained)
: UnionType({contained, NoneType::get()}, TypeKind::OptionalType) {
bool is_numbertype = false;
if (auto as_union = contained->cast<UnionType>()) {
if (auto as_union = contained->castRaw<UnionType>()) {
is_numbertype = as_union->containedTypes().size() == 3 &&
as_union->canHoldType(*NumberType::get());
}
@@ -195,20 +193,17 @@ UnionType::UnionType(std::vector<TypePtr> reference, TypeKind kind) : Type(kind)
// Gate the assert in a regular conditional so that we don't create
// this long error message unnecessarily
if (types_.size() == 1) {
std::stringstream msg;
msg << "After type unification was performed, the Union with the "
<< "original types {";
std::string msg = "After type unification was performed, the Union with the original types {";
for (const auto i : c10::irange(reference.size())) {
msg << reference[i]->repr_str();
msg += reference[i]->repr_str();
if (i > 0) {
msg << ",";
msg += ",";
}
msg << " ";
msg += " ";
}
msg << "} has the single type " << types_[0]->repr_str()
<< ". Use the common supertype instead of creating a Union"
<< "type";
TORCH_INTERNAL_ASSERT(false, msg.str());
msg += "} has the single type " + types_[0]->repr_str()
+ ". Use the common supertype instead of creating a Union type";
TORCH_INTERNAL_ASSERT(false, msg);
}

can_hold_none_ = false;
@@ -357,7 +352,7 @@ bool UnionType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {

std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str)
const {
std::stringstream ss;
std::string ss;

bool can_hold_numbertype = this->canHoldType(*NumberType::get());

@@ -375,33 +370,33 @@ std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str)
std::string open_delimeter = is_annotation_str ? "[" : "(";
std::string close_delimeter = is_annotation_str ? "]" : ")";

ss << "Union" + open_delimeter;
ss = "Union" + open_delimeter;
bool printed = false;
for (size_t i = 0; i < types_.size(); ++i) {
if (!can_hold_numbertype || !is_numbertype(types_[i])) {
if (i > 0) {
ss << ", ";
ss += ", ";
printed = true;
}
if (is_annotation_str) {
ss << this->containedTypes()[i]->annotation_str(printer);
ss += this->containedTypes()[i]->annotation_str(printer);
} else {
ss << this->containedTypes()[i]->str();
ss += this->containedTypes()[i]->str();
}
}
}
if (can_hold_numbertype) {
if (printed) {
ss << ", ";
ss += ", ";
}
if (is_annotation_str) {
ss << NumberType::get()->annotation_str(printer);
ss += NumberType::get()->annotation_str(printer);
} else {
ss << NumberType::get()->str();
ss += NumberType::get()->str();
}
}
ss << close_delimeter;
return ss.str();
ss += close_delimeter;
return ss;
}

std::string UnionType::str() const {
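
The union_type.cpp changes mostly swap std::stringstream for plain std::string appends, which keeps the iostream formatting machinery out of the generated code, and move values into place instead of copying them. A minimal sketch of the string-building side, assuming an invented helper join_types rather than the real functions:

#include <cstddef>
#include <string>
#include <vector>

// Builds "Union[a, b, c]" with std::string appends; no std::stringstream,
// so no operator<< instantiations are dragged into the binary.
std::string join_types(const std::vector<std::string>& names) {
  std::string msg = "Union[";
  for (std::size_t i = 0; i < names.size(); ++i) {
    if (i > 0) {
      msg += ", ";
    }
    msg += names[i];
  }
  msg += "]";
  return msg;  // returned via NRVO/move, not copied
}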
178 changes: 93 additions & 85 deletions torch/csrc/jit/frontend/schema_type_parser.cpp
@@ -289,96 +289,104 @@ TypePtr SchemaTypeParser::parseRefinedTensor() {
}

std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
TypePtr value;
c10::optional<AliasInfo> alias_info;
// Tuple type
if (L.cur().kind == '(') {
std::vector<TypePtr> types;
parseList('(', ',', ')', [&] {
auto r = parseType();
types.push_back(std::move(r.first));
if (alias_info && r.second) {
alias_info->addContainedType(std::move(*r.second));
}
});
value = TupleType::create(std::move(types));
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
L.next(); // Future
L.expect('(');
auto p = parseType();
auto subtype = std::move(p.first);
auto subalias = std::move(p.second);
L.expect(')');
value = FutureType::create(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
L.next(); // RRef
L.expect('(');
auto p = parseType();
auto subtype = std::move(p.first);
auto subalias = std::move(p.second);
L.expect(')');
value = RRefType::create(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
L.next();
value = TensorType::get();
alias_info = parseAliasAnnotation();
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
L.next();
L.expect('(');
auto key_type = parseType().first;
L.expect(',');
auto value_type = parseType().first;
L.expect(')');
alias_info = parseAliasAnnotation();
value = DictType::create(key_type, value_type);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
L.next();
L.expect('(');
std::vector<TypePtr> types;
types.emplace_back(parseType().first);
while (L.cur().kind != ')') {
auto parseHead = [&]() -> TypePtr {
if (L.cur().kind == '(') {
std::vector<TypePtr> types;
parseList('(', ',', ')', [&] {
auto r = parseType();
types.push_back(std::move(r.first));
if (alias_info && r.second) {
alias_info->addContainedType(std::move(*r.second));
}
});
return TupleType::create(std::move(types));
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
L.next(); // Future
L.expect('(');
auto p = parseType();
auto subtype = std::move(p.first);
auto subalias = std::move(p.second);
L.expect(')');
return FutureType::create(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
L.next(); // RRef
L.expect('(');
auto p = parseType();
auto subtype = std::move(p.first);
auto subalias = std::move(p.second);
L.expect(')');
return RRefType::create(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
L.next();
auto value = TensorType::get();
alias_info = parseAliasAnnotation();
return value;
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
L.next();
L.expect('(');
auto key_type = parseType().first;
L.expect(',');
auto value_type = parseType().first;
L.expect(')');
alias_info = parseAliasAnnotation();
return DictType::create(key_type, value_type);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
L.next();
L.expect('(');
std::vector<TypePtr> types;
types.emplace_back(parseType().first);
while (L.cur().kind != ')') {
L.expect(',');
types.emplace_back(parseType().first);
}
L.expect(')');
alias_info = parseAliasAnnotation();
return UnionType::create(types);
} else if (
complete_tensor_types && L.cur().kind == TK_IDENT &&
parseTensorDType(L.cur().text())) {
auto value = parseRefinedTensor();
alias_info = parseAliasAnnotation();
return value;
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") {
L.next();
L.expect('.');
auto torch_tok = L.expect(TK_IDENT);
if (torch_tok.text() != "torch") {
throw ErrorReport(torch_tok.range)
<< "Expected classes namespace but got " << torch_tok.text();
}
L.expect('.');
auto classes_tok = L.expect(TK_IDENT);
if (classes_tok.text() != "classes") {
throw ErrorReport(classes_tok.range)
<< "Expected classes namespace but got " << classes_tok.text();
}
L.expect('.');
auto ns_tok = L.expect(TK_IDENT);
L.expect('.');
auto class_tok = L.expect(TK_IDENT);
auto value = getCustomClass(
std::string("__torch__.torch.classes.") + ns_tok.text() + "." +
class_tok.text());
if (!value) {
throw ErrorReport(class_tok.range)
<< "Unknown custom class type "
<< ns_tok.text() + "." + class_tok.text()
<< ". Please ensure it is registered.";
}
return value;
} else {
auto value = parseBaseType();
alias_info = parseAliasAnnotation();
return value;
}
L.expect(')');
alias_info = parseAliasAnnotation();
value = UnionType::create(types);
} else if (
complete_tensor_types && L.cur().kind == TK_IDENT &&
parseTensorDType(L.cur().text())) {
value = parseRefinedTensor();
alias_info = parseAliasAnnotation();
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "__torch__") {
L.next();
L.expect('.');
auto torch_tok = L.expect(TK_IDENT);
if (torch_tok.text() != "torch") {
throw ErrorReport(torch_tok.range)
<< "Expected classes namespace but got " << torch_tok.text();
}
L.expect('.');
auto classes_tok = L.expect(TK_IDENT);
if (classes_tok.text() != "classes") {
throw ErrorReport(classes_tok.range)
<< "Expected classes namespace but got " << classes_tok.text();
}
L.expect('.');
auto ns_tok = L.expect(TK_IDENT);
L.expect('.');
auto class_tok = L.expect(TK_IDENT);
value = getCustomClass(
std::string("__torch__.torch.classes.") + ns_tok.text() + "." +
class_tok.text());
if (!value) {
throw ErrorReport(class_tok.range)
<< "Unknown custom class type "
<< ns_tok.text() + "." + class_tok.text()
<< ". Please ensure it is registered.";
}
} else {
value = parseBaseType();
alias_info = parseAliasAnnotation();
}
};

TypePtr value = parseHead();

while (true) {
if (L.cur().kind == '[' && L.lookahead().kind == ']') {
L.next(); // [
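
The schema_type_parser.cpp change folds the branch chain into a parseHead lambda that returns its result, so value is initialized once from whichever branch fires rather than default-constructed and reassigned (assignment to a shared_ptr has to release whatever it previously held, which is the extra decrement code the summary mentions). A reduced sketch of the pattern with invented names:

#include <memory>
#include <string>

using NodePtr = std::shared_ptr<std::string>;  // stand-in for TypePtr

NodePtr parse_kind(int kind) {
  // Before (sketch): NodePtr value; if (...) value = ...; else value = ...;
  // After: one helper returns by value, so `value` is constructed in place.
  auto parseHead = [&]() -> NodePtr {
    if (kind == 0) {
      return std::make_shared<std::string>("Tensor");
    }
    return std::make_shared<std::string>("int");
  };
  NodePtr value = parseHead();  // direct initialization, nothing to overwrite
  return value;
}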
5 changes: 2 additions & 3 deletions torch/csrc/jit/mobile/model_compatibility.cpp
@@ -335,7 +335,7 @@ ModelCompatCheckResult is_compatible(
result.errors.emplace_back(s.str());
}

std::unordered_set<std::string> supported_type = runtime_info.supported_types;
const auto& supported_type = runtime_info.supported_types;

// Check type table
for (const auto& type_name : model_info.type_table) {
@@ -349,8 +349,7 @@ }
}

// Check operators
std::unordered_map<std::string, OperatorInfo> operator_info =
model_info.operator_info;
const auto& operator_info = model_info.operator_info;
for (auto const& op : operator_info) {
std::string op_name = op.first;
OperatorInfo model_op_info = op.second;
6 changes: 3 additions & 3 deletions torch/csrc/jit/mobile/type_parser.cpp
@@ -72,14 +72,14 @@ std::vector<TypePtr> TypeParser::parseList() {
}

// The list of non-simple types supported by currrent parser.
std::unordered_set<std::string> TypeParser::getNonSimpleType() {
const std::unordered_set<std::string>& TypeParser::getNonSimpleType() {
static std::unordered_set<std::string> nonSimpleTypes{
"List", "Optional", "Dict", "Tuple"};
return nonSimpleTypes;
}

// The list of custom types supported by currrent parser.
std::unordered_set<std::string> TypeParser::getCustomType() {
const std::unordered_set<std::string>& TypeParser::getCustomType() {
static std::unordered_set<std::string> customeTypes{
kTypeTorchbindCustomClass, kTypeNamedTuple};
return customeTypes;
@@ -89,7 +89,7 @@ std::unordered_set<std::string> TypeParser::getCustomType() {
// compatibility check between model and runtime. For example:
// PyThon string: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
// contained type is: [Dict, int, Tuple, Tensor]
std::unordered_set<std::string> TypeParser::getContainedTypes() {
const std::unordered_set<std::string>& TypeParser::getContainedTypes() {
return contained_types_;
}

6 changes: 3 additions & 3 deletions torch/csrc/jit/mobile/type_parser.h
@@ -12,9 +12,9 @@ class TORCH_API TypeParser {

TypePtr parse();
std::vector<TypePtr> parseList();
static std::unordered_set<std::string> getNonSimpleType();
static std::unordered_set<std::string> getCustomType();
std::unordered_set<std::string> getContainedTypes();
static const std::unordered_set<std::string>& getNonSimpleType();
static const std::unordered_set<std::string>& getCustomType();
const std::unordered_set<std::string>& getContainedTypes();

private:
TypePtr parseNamedTuple(const std::string& qualified_name);
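
The type_parser changes have the set-returning accessors hand out a const reference to their function-local static instead of a fresh copy per call. A simplified free-function sketch of that pattern (not the actual TypeParser members):

#include <string>
#include <unordered_set>

// One shared instance, built on first use and handed out by const reference;
// returning by value here would copy the set at every call site.
const std::unordered_set<std::string>& nonSimpleTypes() {
  static const std::unordered_set<std::string> kTypes{
      "List", "Optional", "Dict", "Tuple"};
  return kTypes;
}

bool isNonSimpleType(const std::string& name) {
  return nonSimpleTypes().count(name) > 0;
}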
