From 85459a08871bc6418f2d8f1e16cd9d6381bc17d5 Mon Sep 17 00:00:00 2001 From: Roberto Raggi Date: Sat, 15 Mar 2025 15:44:41 +0100 Subject: [PATCH] Improve substituion of non-type template parameters --- src/parser/cxx/ast_interpreter.cc | 60 ++++++++++++++-------- src/parser/cxx/ast_rewriter.cc | 35 +++++-------- src/parser/cxx/name_printer.cc | 2 + src/parser/cxx/names_fwd.h | 5 +- src/parser/cxx/symbols.cc | 8 +++ src/parser/cxx/symbols.h | 4 ++ tests/api_tests/test_rewriter.cc | 85 ++++++++++++++++++++++++++++--- 7 files changed, 150 insertions(+), 49 deletions(-) diff --git a/src/parser/cxx/ast_interpreter.cc b/src/parser/cxx/ast_interpreter.cc index aed8fa88..9f8e2be3 100644 --- a/src/parser/cxx/ast_interpreter.cc +++ b/src/parser/cxx/ast_interpreter.cc @@ -1634,6 +1634,10 @@ auto ASTInterpreter::ExpressionVisitor::operator()(IdExpressionAST* ast) return enumerator->value(); } + if (auto var = symbol_cast(ast->symbol)) { + return var->constValue(); + } + return std::nullopt; } @@ -1989,33 +1993,49 @@ auto ASTInterpreter::ExpressionVisitor::operator()(BinaryExpressionAST* ast) case TokenKind::T_STAR: if (control()->is_floating_point(ast->type)) - return std::visit(ArithmeticCast{}, *left) + + return std::visit(ArithmeticCast{}, *left) * std::visit(ArithmeticCast{}, *right); else if (control()->is_unsigned(ast->type)) - return std::visit(ArithmeticCast{}, *left) + + return std::visit(ArithmeticCast{}, *left) * std::visit(ArithmeticCast{}, *right); else - return std::visit(ArithmeticCast{}, *left) + + return std::visit(ArithmeticCast{}, *left) * std::visit(ArithmeticCast{}, *right); - case TokenKind::T_SLASH: - if (control()->is_floating_point(ast->type)) - return std::visit(ArithmeticCast{}, *left) + - std::visit(ArithmeticCast{}, *right); - else if (control()->is_unsigned(ast->type)) - return std::visit(ArithmeticCast{}, *left) + - std::visit(ArithmeticCast{}, *right); - else - return std::visit(ArithmeticCast{}, *left) + - std::visit(ArithmeticCast{}, *right); + case TokenKind::T_SLASH: { + if (control()->is_floating_point(ast->type)) { + auto l = std::visit(ArithmeticCast{}, *left); + auto r = std::visit(ArithmeticCast{}, *right); + if (r == 0.0) return std::nullopt; + return l / r; + } - case TokenKind::T_PERCENT: - if (control()->is_unsigned(ast->type)) - return std::visit(ArithmeticCast{}, *left) % - std::visit(ArithmeticCast{}, *right); - else - return std::visit(ArithmeticCast{}, *left) % - std::visit(ArithmeticCast{}, *right); + if (control()->is_unsigned(ast->type)) { + auto l = std::visit(ArithmeticCast{}, *left); + auto r = std::visit(ArithmeticCast{}, *right); + if (r == 0) return std::nullopt; + return l / r; + } + + auto l = std::visit(ArithmeticCast{}, *left); + auto r = std::visit(ArithmeticCast{}, *right); + if (r == 0) return std::nullopt; + return l / r; + } + + case TokenKind::T_PERCENT: { + if (control()->is_unsigned(ast->type)) { + auto l = std::visit(ArithmeticCast{}, *left); + auto r = std::visit(ArithmeticCast{}, *right); + if (r == 0) return std::nullopt; + return l % r; + } + + auto l = std::visit(ArithmeticCast{}, *left); + auto r = std::visit(ArithmeticCast{}, *right); + if (r == 0) return std::nullopt; + return l % r; + } case TokenKind::T_PLUS: if (control()->is_floating_point(ast->type)) diff --git a/src/parser/cxx/ast_rewriter.cc b/src/parser/cxx/ast_rewriter.cc index aca0e2e2..474a59ec 100644 --- a/src/parser/cxx/ast_rewriter.cc +++ b/src/parser/cxx/ast_rewriter.cc @@ -2343,27 +2343,6 @@ auto ASTRewriter::ExpressionVisitor::operator()(NestedExpressionAST* ast) auto ASTRewriter::ExpressionVisitor::operator()(IdExpressionAST* ast) -> ExpressionAST* { - if (auto x = symbol_cast(ast->symbol); - x && x->depth() == 0 && x->index() < rewrite.templateArguments_.size()) { - auto initializerPtr = - std::get_if(&rewrite.templateArguments_[x->index()]); - if (!initializerPtr) { - cxx_runtime_error("expected initializer for non-type template parameter"); - } - - auto initializer = rewrite(*initializerPtr); - - if (auto eq = ast_cast(initializer)) { - return eq->expression; - } - - if (auto bracedInit = ast_cast(initializer)) { - if (bracedInit->expressionList && !bracedInit->expressionList->next) { - return bracedInit->expressionList->value; - } - } - } - auto copy = make_node(arena()); copy->valueCategory = ast->valueCategory; @@ -2371,7 +2350,21 @@ auto ASTRewriter::ExpressionVisitor::operator()(IdExpressionAST* ast) copy->nestedNameSpecifier = rewrite(ast->nestedNameSpecifier); copy->templateLoc = ast->templateLoc; copy->unqualifiedId = rewrite(ast->unqualifiedId); + copy->symbol = ast->symbol; + + if (auto x = symbol_cast(copy->symbol); + x && x->depth() == 0 && x->index() < rewrite.templateArguments_.size()) { + auto initializerPtr = + std::get_if(&rewrite.templateArguments_[x->index()]); + if (!initializerPtr) { + cxx_runtime_error("expected initializer for non-type template parameter"); + } + + copy->symbol = *initializerPtr; + copy->type = copy->symbol->type(); + } + copy->isTemplateIntroduced = ast->isTemplateIntroduced; return copy; diff --git a/src/parser/cxx/name_printer.cc b/src/parser/cxx/name_printer.cc index 1851920b..d5a949a5 100644 --- a/src/parser/cxx/name_printer.cc +++ b/src/parser/cxx/name_printer.cc @@ -36,6 +36,8 @@ struct NamePrinter { auto operator()(const ConstValue& value) const -> std::string { return {}; } + auto operator()(const Symbol* symbol) const -> std::string { return {}; } + auto operator()(ExpressionAST* value) const -> std::string { return {}; } } template_argument_to_string; diff --git a/src/parser/cxx/names_fwd.h b/src/parser/cxx/names_fwd.h index bed5a10b..ef11ea44 100644 --- a/src/parser/cxx/names_fwd.h +++ b/src/parser/cxx/names_fwd.h @@ -47,7 +47,10 @@ class Name; CXX_FOR_EACH_NAME(PROCESS_NAME) #undef PROCESS_NAME -using TemplateArgument = std::variant; +class Symbol; + +using TemplateArgument = + std::variant; enum class IdentifierInfoKind { kTypeTrait, diff --git a/src/parser/cxx/symbols.cc b/src/parser/cxx/symbols.cc index 830633ee..7fef6a22 100644 --- a/src/parser/cxx/symbols.cc +++ b/src/parser/cxx/symbols.cc @@ -563,6 +563,14 @@ void VariableSymbol::setInitializer(ExpressionAST* initializer) { initializer_ = initializer; } +auto VariableSymbol::constValue() const -> const std::optional& { + return constValue_; +} + +void VariableSymbol::setConstValue(std::optional value) { + constValue_ = std::move(value); +} + FieldSymbol::FieldSymbol(Scope* enclosingScope) : Symbol(Kind, enclosingScope) {} diff --git a/src/parser/cxx/symbols.h b/src/parser/cxx/symbols.h index c806a939..1c7be357 100644 --- a/src/parser/cxx/symbols.h +++ b/src/parser/cxx/symbols.h @@ -547,10 +547,14 @@ class VariableSymbol final : public Symbol { [[nodiscard]] auto initializer() const -> ExpressionAST*; void setInitializer(ExpressionAST*); + [[nodiscard]] auto constValue() const -> const std::optional&; + void setConstValue(std::optional value); + private: TemplateParametersSymbol* templateParameters_ = nullptr; TemplateDeclarationAST* templateDeclaration_ = nullptr; ExpressionAST* initializer_ = nullptr; + std::optional constValue_; union { std::uint32_t flags_{}; diff --git a/tests/api_tests/test_rewriter.cc b/tests/api_tests/test_rewriter.cc index b915ab1b..142f3830 100644 --- a/tests/api_tests/test_rewriter.cc +++ b/tests/api_tests/test_rewriter.cc @@ -89,12 +89,15 @@ using Func = void(T, const U&); TEST(Rewriter, Var) { auto source = R"( template -const int c = i + 321; +const int c = i + 321 + i; -constexpr int x = 123; +constexpr int x = 123 * 2; +constexpr int y = c<123 * 2>; )"_cxx; + auto interp = ASTInterpreter{&source.unit}; + auto control = source.control(); auto c = source.getAs("c"); @@ -102,14 +105,24 @@ constexpr int x = 123; auto templateDeclaration = c->templateDeclaration(); ASSERT_TRUE(templateDeclaration != nullptr); + // extract the expression 123 * 2 from the AST auto x = source.getAs("x"); ASSERT_TRUE(x != nullptr); - auto xinit = x->initializer(); + auto xinit = ast_cast(x->initializer())->expression; ASSERT_TRUE(xinit != nullptr); + // synthesize const auto i = 123 * 2; + + // ### need to set scope and location + auto templArg = control->newVariableSymbol(nullptr, {}); + templArg->setInitializer(xinit); + templArg->setType(control->add_const(x->type())); + templArg->setConstValue(interp.evaluate(xinit)); + ASSERT_TRUE(templArg->constValue().has_value()); + auto instance = subst( source, getTemplateBodyAs(templateDeclaration), - {xinit}); + {templArg}); auto decl = instance->initDeclaratorList->value; ASSERT_TRUE(decl != nullptr); @@ -117,11 +130,69 @@ constexpr int x = 123; auto init = ast_cast(decl->initializer); ASSERT_TRUE(init); - ASTInterpreter interp{&source.unit}; - auto value = interp.evaluate(init->expression); ASSERT_TRUE(value.has_value()); - ASSERT_EQ(std::visit(ArithmeticCast{}, *value), 123 + 321); + ASSERT_EQ(std::visit(ArithmeticCast{}, *value), 123 * 2 + 321 + 123 * 2); +} + +// simulate a template-id instantiation +TEST(Rewriter, TemplateId) { + auto source = R"( +template +const int c = i + 321 + i; + +constexpr int y = c<123 * 2>; +)"_cxx; + + auto interp = ASTInterpreter{&source.unit}; + + auto control = source.control(); + + auto y = source.getAs("y"); + ASSERT_TRUE(y != nullptr); + auto yinit = ast_cast(y->initializer())->expression; + ASSERT_TRUE(yinit != nullptr); + + auto idExpr = ast_cast(yinit); + ASSERT_TRUE(idExpr != nullptr); + + ASSERT_TRUE(idExpr->symbol); + + auto templateId = ast_cast(idExpr->unqualifiedId); + ASSERT_TRUE(templateId != nullptr); + + // get the primary template declaration + auto templateSym = + symbol_cast(templateId->primaryTemplateSymbol); + ASSERT_TRUE(templateSym != nullptr); + auto templateDecl = getTemplateBodyAs( + templateSym->templateDeclaration()); + ASSERT_TRUE(templateDecl != nullptr); + + std::vector templateArguments; + for (auto arg : ListView{templateId->templateArgumentList}) { + if (auto exprArg = ast_cast(arg)) { + auto expr = exprArg->expression; + // ### need to set scope and location + auto templArg = control->newVariableSymbol(nullptr, {}); + templArg->setInitializer(expr); + templArg->setType(control->add_const(expr->type)); + templArg->setConstValue(interp.evaluate(expr)); + ASSERT_TRUE(templArg->constValue().has_value()); + templateArguments.push_back(templArg); + } + } + + auto instance = subst(source, templateDecl, templateArguments); + ASSERT_TRUE(instance != nullptr); + + auto decl = instance->initDeclaratorList->value; + ASSERT_TRUE(decl != nullptr); + auto init = ast_cast(decl->initializer); + ASSERT_TRUE(init != nullptr); + auto value = interp.evaluate(init->expression); + ASSERT_TRUE(value.has_value()); + ASSERT_EQ(std::visit(ArithmeticCast{}, *value), 123 * 2 + 321 + 123 * 2); }