From 246113d4d615bf24f272cb18f32a76d8890dd6dd Mon Sep 17 00:00:00 2001 From: Roberto Raggi Date: Thu, 28 Aug 2025 18:49:27 +0200 Subject: [PATCH] Clean up symbol instantiation --- src/parser/cxx/ast_rewriter.cc | 249 +++++++++----------- src/parser/cxx/ast_rewriter.h | 19 +- src/parser/cxx/ast_rewriter_declarations.cc | 3 + src/parser/cxx/ast_rewriter_names.cc | 4 +- src/parser/cxx/binder.cc | 6 +- src/parser/cxx/parser.cc | 10 +- src/parser/cxx/symbols.cc | 22 ++ src/parser/cxx/symbols.h | 25 ++ 8 files changed, 175 insertions(+), 163 deletions(-) diff --git a/src/parser/cxx/ast_rewriter.cc b/src/parser/cxx/ast_rewriter.cc index a441ef53..c5f9b303 100644 --- a/src/parser/cxx/ast_rewriter.cc +++ b/src/parser/cxx/ast_rewriter.cc @@ -34,6 +34,98 @@ namespace cxx { +namespace { +struct GetTemplateDeclaration { + auto operator()(ClassSymbol* symbol) -> TemplateDeclarationAST* { + return symbol->templateDeclaration(); + } + + auto operator()(VariableSymbol* symbol) -> TemplateDeclarationAST* { + return symbol->templateDeclaration(); + } + + auto operator()(TypeAliasSymbol* symbol) -> TemplateDeclarationAST* { + return symbol->templateDeclaration(); + } + + auto operator()(Symbol*) -> TemplateDeclarationAST* { return nullptr; } +}; + +struct GetDeclaration { + auto operator()(ClassSymbol* symbol) -> AST* { return symbol->declaration(); } + + auto operator()(VariableSymbol* symbol) -> AST* { + return symbol->templateDeclaration()->declaration; + } + + auto operator()(TypeAliasSymbol* symbol) -> AST* { + return symbol->templateDeclaration()->declaration; + } + + auto operator()(Symbol*) -> AST* { return nullptr; } +}; + +struct GetSpecialization { + const std::vector& templateArguments; + + auto operator()(ClassSymbol* symbol) -> Symbol* { + return symbol->findSpecialization(templateArguments); + } + + auto operator()(VariableSymbol* symbol) -> Symbol* { + return symbol->findSpecialization(templateArguments); + } + + auto operator()(TypeAliasSymbol* symbol) -> Symbol* { + return symbol->findSpecialization(templateArguments); + } + + auto operator()(Symbol*) -> Symbol* { return nullptr; } +}; + +struct Instantiate { + ASTRewriter& rewriter; + + auto operator()(ClassSymbol* symbol) -> Symbol* { + auto classSpecifier = ast_cast(symbol->declaration()); + if (!classSpecifier) return nullptr; + + auto instance = + ast_cast(rewriter.specifier(classSpecifier)); + + if (!instance) return nullptr; + + return instance->symbol; + } + + auto operator()(VariableSymbol* symbol) -> Symbol* { + auto declaration = symbol->templateDeclaration()->declaration; + auto instance = ast_cast( + rewriter.declaration(ast_cast(declaration))); + + if (!instance) return nullptr; + + auto instantiatedSymbol = instance->initDeclaratorList->value->symbol; + auto instantiatedVariable = symbol_cast(instantiatedSymbol); + + return instantiatedVariable; + } + + auto operator()(TypeAliasSymbol* symbol) -> Symbol* { + auto declaration = symbol->templateDeclaration()->declaration; + + auto instance = ast_cast( + rewriter.declaration(ast_cast(declaration))); + + if (!instance) return nullptr; + + return instance->symbol; + } + + auto operator()(Symbol*) -> Symbol* { return nullptr; } +}; +} // namespace + ASTRewriter::ASTRewriter(TranslationUnit* unit, ScopeSymbol* scope, const std::vector& templateArguments) : unit_(unit), templateArguments_(templateArguments), binder_(unit_) { @@ -91,69 +183,18 @@ auto ASTRewriter::getParameterPack(ExpressionAST* ast) -> ParameterPackSymbol* { return nullptr; } -auto ASTRewriter::instantiateClassTemplate( - TranslationUnit* unit, List* templateArgumentList, - ClassSymbol* classSymbol) -> ClassSymbol* { - auto templateDecl = classSymbol->templateDeclaration(); - - if (!classSymbol->declaration()) return nullptr; - - auto templateArguments = - make_substitution(unit, templateDecl, templateArgumentList); - - auto is_primary_template = [&]() -> bool { - int expected = 0; - for (const auto& arg : templateArguments) { - if (!std::holds_alternative(arg)) return false; - - auto ty = type_cast(std::get(arg)->type()); - if (!ty) return false; - - if (ty->index() != expected) return false; - ++expected; - } - return true; - }; - - if (is_primary_template()) { - // if this is a primary template, we can just return the class symbol - return classSymbol; - } - - auto subst = classSymbol->findSpecialization(templateArguments); - if (subst) { - return subst; - } - - auto classSpecifier = ast_cast(classSymbol->declaration()); - if (!classSpecifier) return nullptr; - - auto parentScope = classSymbol->enclosingNonTemplateParametersScope(); - - auto rewriter = ASTRewriter{unit, parentScope, templateArguments}; - rewriter.depth_ = templateDecl->depth; - - rewriter.binder().setInstantiatingSymbol(classSymbol); - - auto instance = - ast_cast(rewriter.specifier(classSpecifier)); - - if (!instance) return nullptr; - - auto classInstance = instance->symbol; - - return classInstance; -} - -auto ASTRewriter::instantiateTypeAliasTemplate( - TranslationUnit* unit, List* templateArgumentList, - TypeAliasSymbol* typeAliasSymbol) -> TypeAliasSymbol* { - auto templateDecl = typeAliasSymbol->templateDeclaration(); +auto ASTRewriter::instantiate(TranslationUnit* unit, + List* templateArgumentList, + Symbol* symbol) -> Symbol* { + auto classSymbol = symbol_cast(symbol); + auto variableSymbol = symbol_cast(symbol); + auto typeAliasSymbol = symbol_cast(symbol); - auto aliasDeclaration = - ast_cast(templateDecl->declaration); + auto templateDecl = visit(GetTemplateDeclaration{}, symbol); + if (!templateDecl) return nullptr; - if (!aliasDeclaration) return nullptr; + auto declaration = visit(GetDeclaration{}, symbol); + if (!declaration) return nullptr; auto templateArguments = make_substitution(unit, templateDecl, templateArgumentList); @@ -174,93 +215,21 @@ auto ASTRewriter::instantiateTypeAliasTemplate( if (is_primary_template()) { // if this is a primary template, we can just return the class symbol - return typeAliasSymbol; + return symbol; } -#if false - auto subst = typeAliasSymbol->findSpecialization(templateArguments); - if (subst) { - return subst; - } -#endif - - auto parentScope = typeAliasSymbol->parent(); - while (parentScope->isTemplateParameters()) { - parentScope = parentScope->parent(); - } - - auto rewriter = ASTRewriter{unit, parentScope, templateArguments}; - - rewriter.binder().setInstantiatingSymbol(typeAliasSymbol); - - auto instance = - ast_cast(rewriter.declaration(aliasDeclaration)); - - if (!instance) return nullptr; - - return instance->symbol; -} - -auto ASTRewriter::instantiateVariableTemplate( - TranslationUnit* unit, List* templateArgumentList, - VariableSymbol* variableSymbol) -> VariableSymbol* { - auto templateDecl = variableSymbol->templateDeclaration(); - - if (!templateDecl) { - unit->error(variableSymbol->location(), "not a template"); - return nullptr; - } + auto specialization = visit(GetSpecialization{templateArguments}, symbol); - auto variableDeclaration = - ast_cast(templateDecl->declaration); + if (specialization) return specialization; - if (!variableDeclaration) return nullptr; - - auto templateArguments = - make_substitution(unit, templateDecl, templateArgumentList); - - auto is_primary_template = [&]() -> bool { - int expected = 0; - for (const auto& arg : templateArguments) { - if (!std::holds_alternative(arg)) return false; - - auto ty = type_cast(std::get(arg)->type()); - if (!ty) return false; - - if (ty->index() != expected) return false; - ++expected; - } - return true; - }; - - if (is_primary_template()) { - // if this is a primary template, we can just return the class symbol - return variableSymbol; - } - - auto subst = variableSymbol->findSpecialization(templateArguments); - if (subst) { - return subst; - } - - auto parentScope = variableSymbol->parent(); - while (parentScope->isTemplateParameters()) { - parentScope = parentScope->parent(); - } + auto parentScope = symbol->enclosingNonTemplateParametersScope(); auto rewriter = ASTRewriter{unit, parentScope, templateArguments}; + rewriter.depth_ = templateDecl->depth; - rewriter.binder().setInstantiatingSymbol(variableSymbol); - - auto instance = - ast_cast(rewriter.declaration(variableDeclaration)); - - if (!instance) return nullptr; - - auto instantiatedSymbol = instance->initDeclaratorList->value->symbol; - auto instantiatedVariable = symbol_cast(instantiatedSymbol); + rewriter.binder().setInstantiatingSymbol(symbol); - return instantiatedVariable; + return visit(Instantiate{rewriter}, symbol); } auto ASTRewriter::make_substitution( diff --git a/src/parser/cxx/ast_rewriter.h b/src/parser/cxx/ast_rewriter.h index 99181236..e97a0586 100644 --- a/src/parser/cxx/ast_rewriter.h +++ b/src/parser/cxx/ast_rewriter.h @@ -34,17 +34,9 @@ class Arena; class ASTRewriter { public: - [[nodiscard]] static auto instantiateClassTemplate( + [[nodiscard]] static auto instantiate( TranslationUnit* unit, List* templateArgumentList, - ClassSymbol* symbol) -> ClassSymbol*; - - [[nodiscard]] static auto instantiateTypeAliasTemplate( - TranslationUnit* unit, List* templateArgumentList, - TypeAliasSymbol* typeAliasSymbol) -> TypeAliasSymbol*; - - [[nodiscard]] static auto instantiateVariableTemplate( - TranslationUnit* unit, List* templateArgumentList, - VariableSymbol* variableSymbol) -> VariableSymbol*; + Symbol* symbol) -> Symbol*; explicit ASTRewriter(TranslationUnit* unit, ScopeSymbol* scope, const std::vector& templateArguments); @@ -58,6 +50,10 @@ class ASTRewriter { TemplateDeclarationAST* templateHead = nullptr) -> DeclarationAST*; + [[nodiscard]] auto specifier(SpecifierAST* ast, + TemplateDeclarationAST* templateHead = nullptr) + -> SpecifierAST*; + [[nodiscard]] static auto make_substitution( TranslationUnit* unit, TemplateDeclarationAST* templateDecl, List* templateArgumentList) @@ -91,9 +87,6 @@ class ASTRewriter { [[nodiscard]] auto designator(DesignatorAST* ast) -> DesignatorAST*; [[nodiscard]] auto templateParameter(TemplateParameterAST* ast) -> TemplateParameterAST*; - [[nodiscard]] auto specifier(SpecifierAST* ast, - TemplateDeclarationAST* templateHead = nullptr) - -> SpecifierAST*; [[nodiscard]] auto ptrOperator(PtrOperatorAST* ast) -> PtrOperatorAST*; [[nodiscard]] auto coreDeclarator(CoreDeclaratorAST* ast) -> CoreDeclaratorAST*; diff --git a/src/parser/cxx/ast_rewriter_declarations.cc b/src/parser/cxx/ast_rewriter_declarations.cc index 6dfa2fed..018c7ec8 100644 --- a/src/parser/cxx/ast_rewriter_declarations.cc +++ b/src/parser/cxx/ast_rewriter_declarations.cc @@ -462,6 +462,9 @@ auto ASTRewriter::DeclarationVisitor::operator()(AliasDeclarationAST* ast) auto symbol = binder()->declareTypeAlias(copy->identifierLoc, copy->typeId, addSymbolToParentScope); + if (!addSymbolToParentScope) { + ast->symbol->addSpecialization(rewrite.templateArguments(), symbol); + } // symbol->setTemplateDeclaration(templateHead); copy->symbol = symbol; diff --git a/src/parser/cxx/ast_rewriter_names.cc b/src/parser/cxx/ast_rewriter_names.cc index c073b7d8..1cb63226 100644 --- a/src/parser/cxx/ast_rewriter_names.cc +++ b/src/parser/cxx/ast_rewriter_names.cc @@ -305,10 +305,10 @@ auto ASTRewriter::NestedNameSpecifierVisitor::operator()( auto classSymbol = symbol_cast(copy->symbol); - auto instance = ASTRewriter::instantiateClassTemplate( + auto instance = ASTRewriter::instantiate( translationUnit(), copy->templateId->templateArgumentList, classSymbol); - copy->symbol = instance; + copy->symbol = symbol_cast(instance); return copy; } diff --git a/src/parser/cxx/binder.cc b/src/parser/cxx/binder.cc index d80e1dcf..67ccadc8 100644 --- a/src/parser/cxx/binder.cc +++ b/src/parser/cxx/binder.cc @@ -914,7 +914,7 @@ auto Binder::resolve(NestedNameSpecifierAST* nestedNameSpecifier, if (auto classSymbol = symbol_cast(templateId->symbol)) { // todo: delay - auto instance = ASTRewriter::instantiateClassTemplate( + auto instance = ASTRewriter::instantiate( unit_, templateId->templateArgumentList, classSymbol); return instance; @@ -922,7 +922,7 @@ auto Binder::resolve(NestedNameSpecifierAST* nestedNameSpecifier, if (auto typeAliasSymbol = symbol_cast(templateId->symbol)) { - auto instance = ASTRewriter::instantiateTypeAliasTemplate( + auto instance = ASTRewriter::instantiate( unit_, templateId->templateArgumentList, typeAliasSymbol); return instance; @@ -973,7 +973,7 @@ void Binder::bind(IdExpressionAST* ast) { if (!var) { error(templateId->firstSourceLocation(), std::format("not a template")); } else { - auto instance = ASTRewriter::instantiateVariableTemplate( + auto instance = ASTRewriter::instantiate( unit_, templateId->templateArgumentList, var); ast->symbol = instance; diff --git a/src/parser/cxx/parser.cc b/src/parser/cxx/parser.cc index 98658c00..4f51931d 100644 --- a/src/parser/cxx/parser.cc +++ b/src/parser/cxx/parser.cc @@ -1057,13 +1057,13 @@ auto Parser::parse_template_nested_name_specifier( if (config().checkTypes) { // replace with instantiated class template if (auto classSymbol = symbol_cast(templateId->symbol)) { - auto instance = ASTRewriter::instantiateClassTemplate( + auto instance = ASTRewriter::instantiate( unit, templateId->templateArgumentList, classSymbol); - ast->symbol = instance; + ast->symbol = symbol_cast(instance); } else if (auto typeAliasSymbol = symbol_cast(templateId->symbol)) { - auto instance = ASTRewriter::instantiateTypeAliasTemplate( + auto instance = ASTRewriter::instantiate( unit, templateId->templateArgumentList, typeAliasSymbol); ast->symbol = symbol_cast(instance); @@ -9758,7 +9758,7 @@ auto Parser::parse_explicit_instantiation(DeclarationAST*& yyast) -> bool { } if (config().checkTypes) { - auto instance = ASTRewriter::instantiateClassTemplate( + auto instance = ASTRewriter::instantiate( unit, templateId->templateArgumentList, classSymbol); (void)instance; @@ -9783,7 +9783,7 @@ auto Parser::parse_explicit_instantiation(DeclarationAST*& yyast) -> bool { } if (config().checkTypes) { - auto instance = ASTRewriter::instantiateClassTemplate( + auto instance = ASTRewriter::instantiate( unit, templateId->templateArgumentList, classSymbol); (void)instance; diff --git a/src/parser/cxx/symbols.cc b/src/parser/cxx/symbols.cc index e788837b..3fedef1c 100644 --- a/src/parser/cxx/symbols.cc +++ b/src/parser/cxx/symbols.cc @@ -725,6 +725,28 @@ void TypeAliasSymbol::setTemplateDeclaration( templateDeclaration_ = declaration; } +auto TypeAliasSymbol::specializations() const + -> std::span> { + if (!templateInfo_) return {}; + return templateInfo_->specializations(); +} + +auto TypeAliasSymbol::findSpecialization( + const std::vector& arguments) const -> TypeAliasSymbol* { + if (!templateInfo_) return {}; + return templateInfo_->findSpecialization(arguments); +} + +void TypeAliasSymbol::addSpecialization(std::vector arguments, + TypeAliasSymbol* specialization) { + if (!templateInfo_) { + templateInfo_ = std::make_unique>(this); + } + auto index = templateInfo_->specializations().size(); + specialization->setSpecializationInfo(this, index); + templateInfo_->addSpecialization(std::move(arguments), specialization); +} + VariableSymbol::VariableSymbol(ScopeSymbol* enclosingScope) : Symbol(Kind, enclosingScope) {} diff --git a/src/parser/cxx/symbols.h b/src/parser/cxx/symbols.h index 0cdd63ea..ec732b5b 100644 --- a/src/parser/cxx/symbols.h +++ b/src/parser/cxx/symbols.h @@ -563,8 +563,33 @@ class TypeAliasSymbol final : public Symbol { [[nodiscard]] auto templateDeclaration() const -> TemplateDeclarationAST*; void setTemplateDeclaration(TemplateDeclarationAST* declaration); + [[nodiscard]] auto specializations() const + -> std::span>; + + [[nodiscard]] auto findSpecialization( + const std::vector& arguments) const -> TypeAliasSymbol*; + + void addSpecialization(std::vector arguments, + TypeAliasSymbol* specialization); + + void setSpecializationInfo(TypeAliasSymbol* templateVariable, + std::size_t index) { + templateVariable_ = templateVariable; + templateSepcializationIndex_ = index; + } + + [[nodiscard]] auto templateArguments() const + -> std::span { + if (!templateVariable_) return {}; + return templateVariable_->specializations()[templateSepcializationIndex_] + .arguments; + } + private: TemplateDeclarationAST* templateDeclaration_ = nullptr; + std::unique_ptr> templateInfo_; + TypeAliasSymbol* templateVariable_ = nullptr; + std::size_t templateSepcializationIndex_ = 0; }; class VariableSymbol final : public Symbol {