Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions src/parser/cxx/ast_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,10 @@ auto ASTInterpreter::ExpressionVisitor::operator()(IdExpressionAST* ast)
return enumerator->value();
}

if (auto var = symbol_cast<VariableSymbol>(ast->symbol)) {
return var->constValue();
}

return std::nullopt;
}

Expand Down Expand Up @@ -1989,33 +1993,49 @@ auto ASTInterpreter::ExpressionVisitor::operator()(BinaryExpressionAST* ast)

case TokenKind::T_STAR:
if (control()->is_floating_point(ast->type))
return std::visit(ArithmeticCast<double>{}, *left) +
return std::visit(ArithmeticCast<double>{}, *left) *
std::visit(ArithmeticCast<double>{}, *right);
else if (control()->is_unsigned(ast->type))
return std::visit(ArithmeticCast<std::uint64_t>{}, *left) +
return std::visit(ArithmeticCast<std::uint64_t>{}, *left) *
std::visit(ArithmeticCast<std::uint64_t>{}, *right);
else
return std::visit(ArithmeticCast<std::int64_t>{}, *left) +
return std::visit(ArithmeticCast<std::int64_t>{}, *left) *
std::visit(ArithmeticCast<std::int64_t>{}, *right);

case TokenKind::T_SLASH:
if (control()->is_floating_point(ast->type))
return std::visit(ArithmeticCast<double>{}, *left) +
std::visit(ArithmeticCast<double>{}, *right);
else if (control()->is_unsigned(ast->type))
return std::visit(ArithmeticCast<std::uint64_t>{}, *left) +
std::visit(ArithmeticCast<std::uint64_t>{}, *right);
else
return std::visit(ArithmeticCast<std::int64_t>{}, *left) +
std::visit(ArithmeticCast<std::int64_t>{}, *right);
case TokenKind::T_SLASH: {
if (control()->is_floating_point(ast->type)) {
auto l = std::visit(ArithmeticCast<double>{}, *left);
auto r = std::visit(ArithmeticCast<double>{}, *right);
if (r == 0.0) return std::nullopt;
return l / r;
}

case TokenKind::T_PERCENT:
if (control()->is_unsigned(ast->type))
return std::visit(ArithmeticCast<std::uint64_t>{}, *left) %
std::visit(ArithmeticCast<std::uint64_t>{}, *right);
else
return std::visit(ArithmeticCast<std::int64_t>{}, *left) %
std::visit(ArithmeticCast<std::int64_t>{}, *right);
if (control()->is_unsigned(ast->type)) {
auto l = std::visit(ArithmeticCast<std::uint64_t>{}, *left);
auto r = std::visit(ArithmeticCast<std::uint64_t>{}, *right);
if (r == 0) return std::nullopt;
return l / r;
}

auto l = std::visit(ArithmeticCast<std::int64_t>{}, *left);
auto r = std::visit(ArithmeticCast<std::int64_t>{}, *right);
if (r == 0) return std::nullopt;
return l / r;
}

case TokenKind::T_PERCENT: {
if (control()->is_unsigned(ast->type)) {
auto l = std::visit(ArithmeticCast<std::uint64_t>{}, *left);
auto r = std::visit(ArithmeticCast<std::uint64_t>{}, *right);
if (r == 0) return std::nullopt;
return l % r;
}

auto l = std::visit(ArithmeticCast<std::int64_t>{}, *left);
auto r = std::visit(ArithmeticCast<std::int64_t>{}, *right);
if (r == 0) return std::nullopt;
return l % r;
}

case TokenKind::T_PLUS:
if (control()->is_floating_point(ast->type))
Expand Down
35 changes: 14 additions & 21 deletions src/parser/cxx/ast_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2343,35 +2343,28 @@ auto ASTRewriter::ExpressionVisitor::operator()(NestedExpressionAST* ast)

auto ASTRewriter::ExpressionVisitor::operator()(IdExpressionAST* ast)
-> ExpressionAST* {
if (auto x = symbol_cast<NonTypeParameterSymbol>(ast->symbol);
x && x->depth() == 0 && x->index() < rewrite.templateArguments_.size()) {
auto initializerPtr =
std::get_if<ExpressionAST*>(&rewrite.templateArguments_[x->index()]);
if (!initializerPtr) {
cxx_runtime_error("expected initializer for non-type template parameter");
}

auto initializer = rewrite(*initializerPtr);

if (auto eq = ast_cast<EqualInitializerAST>(initializer)) {
return eq->expression;
}

if (auto bracedInit = ast_cast<BracedInitListAST>(initializer)) {
if (bracedInit->expressionList && !bracedInit->expressionList->next) {
return bracedInit->expressionList->value;
}
}
}

auto copy = make_node<IdExpressionAST>(arena());

copy->valueCategory = ast->valueCategory;
copy->type = ast->type;
copy->nestedNameSpecifier = rewrite(ast->nestedNameSpecifier);
copy->templateLoc = ast->templateLoc;
copy->unqualifiedId = rewrite(ast->unqualifiedId);

copy->symbol = ast->symbol;

if (auto x = symbol_cast<NonTypeParameterSymbol>(copy->symbol);
x && x->depth() == 0 && x->index() < rewrite.templateArguments_.size()) {
auto initializerPtr =
std::get_if<Symbol*>(&rewrite.templateArguments_[x->index()]);
if (!initializerPtr) {
cxx_runtime_error("expected initializer for non-type template parameter");
}

copy->symbol = *initializerPtr;
copy->type = copy->symbol->type();
}

copy->isTemplateIntroduced = ast->isTemplateIntroduced;

return copy;
Expand Down
2 changes: 2 additions & 0 deletions src/parser/cxx/name_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ struct NamePrinter {

auto operator()(const ConstValue& value) const -> std::string { return {}; }

auto operator()(const Symbol* symbol) const -> std::string { return {}; }

auto operator()(ExpressionAST* value) const -> std::string { return {}; }

} template_argument_to_string;
Expand Down
5 changes: 4 additions & 1 deletion src/parser/cxx/names_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ class Name;
CXX_FOR_EACH_NAME(PROCESS_NAME)
#undef PROCESS_NAME

using TemplateArgument = std::variant<const Type*, ConstValue, ExpressionAST*>;
class Symbol;

using TemplateArgument =
std::variant<const Type*, Symbol*, ConstValue, ExpressionAST*>;

enum class IdentifierInfoKind {
kTypeTrait,
Expand Down
8 changes: 8 additions & 0 deletions src/parser/cxx/symbols.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,14 @@ void VariableSymbol::setInitializer(ExpressionAST* initializer) {
initializer_ = initializer;
}

auto VariableSymbol::constValue() const -> const std::optional<ConstValue>& {
return constValue_;
}

void VariableSymbol::setConstValue(std::optional<ConstValue> value) {
constValue_ = std::move(value);
}

FieldSymbol::FieldSymbol(Scope* enclosingScope)
: Symbol(Kind, enclosingScope) {}

Expand Down
4 changes: 4 additions & 0 deletions src/parser/cxx/symbols.h
Original file line number Diff line number Diff line change
Expand Up @@ -547,10 +547,14 @@ class VariableSymbol final : public Symbol {
[[nodiscard]] auto initializer() const -> ExpressionAST*;
void setInitializer(ExpressionAST*);

[[nodiscard]] auto constValue() const -> const std::optional<ConstValue>&;
void setConstValue(std::optional<ConstValue> value);

private:
TemplateParametersSymbol* templateParameters_ = nullptr;
TemplateDeclarationAST* templateDeclaration_ = nullptr;
ExpressionAST* initializer_ = nullptr;
std::optional<ConstValue> constValue_;

union {
std::uint32_t flags_{};
Expand Down
85 changes: 78 additions & 7 deletions tests/api_tests/test_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,39 +89,110 @@ using Func = void(T, const U&);
TEST(Rewriter, Var) {
auto source = R"(
template <int i>
const int c = i + 321;
const int c = i + 321 + i;

constexpr int x = 123;
constexpr int x = 123 * 2;

constexpr int y = c<123 * 2>;
)"_cxx;

auto interp = ASTInterpreter{&source.unit};

auto control = source.control();

auto c = source.getAs<VariableSymbol>("c");
ASSERT_TRUE(c != nullptr);
auto templateDeclaration = c->templateDeclaration();
ASSERT_TRUE(templateDeclaration != nullptr);

// extract the expression 123 * 2 from the AST
auto x = source.getAs<VariableSymbol>("x");
ASSERT_TRUE(x != nullptr);
auto xinit = x->initializer();
auto xinit = ast_cast<EqualInitializerAST>(x->initializer())->expression;
ASSERT_TRUE(xinit != nullptr);

// synthesize const auto i = 123 * 2;

// ### need to set scope and location
auto templArg = control->newVariableSymbol(nullptr, {});
templArg->setInitializer(xinit);
templArg->setType(control->add_const(x->type()));
templArg->setConstValue(interp.evaluate(xinit));
ASSERT_TRUE(templArg->constValue().has_value());

auto instance = subst(
source, getTemplateBodyAs<SimpleDeclarationAST>(templateDeclaration),
{xinit});
{templArg});

auto decl = instance->initDeclaratorList->value;
ASSERT_TRUE(decl != nullptr);

auto init = ast_cast<EqualInitializerAST>(decl->initializer);
ASSERT_TRUE(init);

ASTInterpreter interp{&source.unit};

auto value = interp.evaluate(init->expression);

ASSERT_TRUE(value.has_value());

ASSERT_EQ(std::visit(ArithmeticCast<int>{}, *value), 123 + 321);
ASSERT_EQ(std::visit(ArithmeticCast<int>{}, *value), 123 * 2 + 321 + 123 * 2);
}

// simulate a template-id instantiation
TEST(Rewriter, TemplateId) {
auto source = R"(
template <int i>
const int c = i + 321 + i;

constexpr int y = c<123 * 2>;
)"_cxx;

auto interp = ASTInterpreter{&source.unit};

auto control = source.control();

auto y = source.getAs<VariableSymbol>("y");
ASSERT_TRUE(y != nullptr);
auto yinit = ast_cast<EqualInitializerAST>(y->initializer())->expression;
ASSERT_TRUE(yinit != nullptr);

auto idExpr = ast_cast<IdExpressionAST>(yinit);
ASSERT_TRUE(idExpr != nullptr);

ASSERT_TRUE(idExpr->symbol);

auto templateId = ast_cast<SimpleTemplateIdAST>(idExpr->unqualifiedId);
ASSERT_TRUE(templateId != nullptr);

// get the primary template declaration
auto templateSym =
symbol_cast<VariableSymbol>(templateId->primaryTemplateSymbol);
ASSERT_TRUE(templateSym != nullptr);
auto templateDecl = getTemplateBodyAs<SimpleDeclarationAST>(
templateSym->templateDeclaration());
ASSERT_TRUE(templateDecl != nullptr);

std::vector<TemplateArgument> templateArguments;
for (auto arg : ListView{templateId->templateArgumentList}) {
if (auto exprArg = ast_cast<ExpressionTemplateArgumentAST>(arg)) {
auto expr = exprArg->expression;
// ### need to set scope and location
auto templArg = control->newVariableSymbol(nullptr, {});
templArg->setInitializer(expr);
templArg->setType(control->add_const(expr->type));
templArg->setConstValue(interp.evaluate(expr));
ASSERT_TRUE(templArg->constValue().has_value());
templateArguments.push_back(templArg);
}
}

auto instance = subst(source, templateDecl, templateArguments);
ASSERT_TRUE(instance != nullptr);

auto decl = instance->initDeclaratorList->value;
ASSERT_TRUE(decl != nullptr);
auto init = ast_cast<EqualInitializerAST>(decl->initializer);
ASSERT_TRUE(init != nullptr);
auto value = interp.evaluate(init->expression);
ASSERT_TRUE(value.has_value());
ASSERT_EQ(std::visit(ArithmeticCast<int>{}, *value), 123 * 2 + 321 + 123 * 2);
}