From 8f51c5c8b37393d2e4a5863490001fc0afcf5ca1 Mon Sep 17 00:00:00 2001 From: Meghan Lele Date: Tue, 10 Nov 2020 21:35:18 -0800 Subject: [PATCH] [JIT] Resolve string literal type annotations using `Resolver::resolveType` **Summary** This commit modifies `ScriptTypeParser::parseTypeFromExpr` so that string literal type annotations are resolved using `Resolver::resolveType`. At present, they are parsed in `parseBaseTypeName`, which inadvertently allows any key from `string_to_type_lut` to be used as a string literal type annotation. **Test Plan** Existing unit tests (most notably `TestClassType.test_self_referential_method` which tests the main feature, self-referential class type annotations, that make use of string literal type annotations). **Fixes** This commit fixes #47570. ghstack-source-id: 9e5a2c83f6211084a378df1863e662e2ced936ce Pull Request resolved: https://github.com/pytorch/pytorch/pull/47731 --- torch/csrc/jit/frontend/script_type_parser.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index bec9e879a397..9ce631cede92 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -149,9 +149,6 @@ c10::optional ScriptTypeParser::parseBaseTypeName( case TK_NONE: { return "None"; } - case TK_STRINGLITERAL: { - return StringLiteral(expr).text(); - } case '.': { auto select = Select(expr); const std::string& name = select.selector().name(); @@ -190,6 +187,15 @@ TypePtr ScriptTypeParser::parseTypeFromExprImpl(const Expr& expr) const { } return subscriptToType(*value_name, subscript); + } else if (expr.kind() == TK_STRINGLITERAL) { + auto type_name = StringLiteral(expr).text(); + if (resolver_) { + if (auto typePtr = resolver_->resolveType(type_name, expr.range())) { + return typePtr; + } + } + + throw ErrorReport(expr) << "Unknown type name '" << type_name << "'"; } else if (auto name = parseBaseTypeName(expr)) { auto itr = string_to_type_lut().find(*name); if (itr != string_to_type_lut().end()) {