diff --git a/compiler/include/compiler/frontend/parser/type_registry.hpp b/compiler/include/compiler/frontend/parser/type_registry.hpp index 3df2db10..bcbb0352 100644 --- a/compiler/include/compiler/frontend/parser/type_registry.hpp +++ b/compiler/include/compiler/frontend/parser/type_registry.hpp @@ -1,8 +1,5 @@ #pragma once -#include -#include - #include "compiler/ast/types.hpp" #include "compiler/frontend/lexer/token.hpp" @@ -10,9 +7,6 @@ namespace parser { class TypeRegistry { - private: - static std::map userDefinedTypes; - public: static bool isTypename(const lexer::Token &token); static ast::TypeId typeId(const lexer::Token &token); diff --git a/compiler/lib/frontend/parser/parser.cpp b/compiler/lib/frontend/parser/parser.cpp index aaa4f13f..bf1810ad 100644 --- a/compiler/lib/frontend/parser/parser.cpp +++ b/compiler/lib/frontend/parser/parser.cpp @@ -12,6 +12,7 @@ #include "compiler/ast/node.hpp" #include "compiler/ast/node_type.hpp" +#include "compiler/ast/types.hpp" #include "compiler/utils/error_buffer.hpp" #include "lexer/token.hpp" @@ -471,6 +472,28 @@ std::stack generatePostfixForm(TokenIterator tokenIterBegin, Toke return postfixForm; } +bool isElementaryType(TypeId typeId) { + return typeId == IntType || typeId == FloatType || typeId == BoolType || typeId == NoneType; +} + +void parseType(ParserContext &ctx) { + ctx.node = ctx.pushChildNode(NodeType::TypeName); + ctx.node->value = TypeRegistry::typeId(ctx.token()); + if (ctx.token().is(Keyword::List)) { + const Token &leftBrace = (ctx.goNextToken(), ctx.token()); + const Token &varTypeList = (ctx.goNextToken(), ctx.token()); + auto typeId = TypeRegistry::typeId(varTypeList); + const Token &rightBrace = (ctx.goNextToken(), ctx.token()); + if (!leftBrace.is(Operator::RectLeftBrace) || !isElementaryType(typeId) || + !rightBrace.is(Operator::RectRightBrace)) { + ctx.pushError("Unexpected syntax for list declaration"); + } + auto node = ctx.pushChildNode(NodeType::TypeName); + node->value = typeId; + } + ctx.goParentNode(); +} + void parseSimpleStatement(ParserContext &ctx) { assert(ctx.tokenIter->is(Keyword::Break) || ctx.tokenIter->is(Keyword::Continue) || ctx.tokenIter->is(Keyword::Pass)); @@ -593,26 +616,23 @@ void parseFunctionArguments(ParserContext &ctx) { assert(ctx.token().is(Operator::LeftBrace)); ctx.goNextToken(); while (!ctx.token().is(Operator::RightBrace)) { - const Token &argName = *ctx.tokenIter; - const Token &colon = *std::next(ctx.tokenIter); - const Token &argType = *std::next(ctx.tokenIter, 2); + const Token &argName = ctx.token(); + const Token &colon = (ctx.goNextToken(), ctx.token()); + const Token &argType = (ctx.goNextToken(), ctx.token()); if (argName.type != TokenType::Identifier || !colon.is(Special::Colon) || !TypeRegistry::isTypename(argType)) { ctx.pushError("Function argument declaration is ill-formed"); while (!ctx.token().is(Operator::RightBrace) && !ctx.token().is(Special::Colon)) ctx.goNextToken(); break; } - auto node = ctx.pushChildNode(NodeType::FunctionArgument); - auto argTypeNode = ParserContext::pushChildNode(node, NodeType::TypeName, argType.ref); - argTypeNode->value = TypeRegistry::typeId(argType); - auto argNameNode = ParserContext::pushChildNode(node, NodeType::VariableName, argName.ref); + ctx.node = ctx.pushChildNode(NodeType::FunctionArgument); + parseType(ctx); + auto argNameNode = ParserContext::pushChildNode(ctx.node, NodeType::VariableName, argName.ref); argNameNode->value = argName.id(); - - const Token &last = *std::next(ctx.tokenIter, 3); - if (last.is(Operator::Comma)) - std::advance(ctx.tokenIter, 4); - else - std::advance(ctx.tokenIter, 3); + ctx.goParentNode(); + ctx.goNextToken(); + if (ctx.token().is(Operator::Comma)) + ctx.goNextToken(); } ctx.goParentNode(); ctx.goNextToken(); @@ -638,7 +658,10 @@ void parseFunctionDefinition(ParserContext &ctx) { if (!TypeRegistry::isTypename(ctx.token())) { ctx.pushError("Type name not found"); } - ctx.pushChildNode(NodeType::FunctionReturnType)->value = TypeRegistry::typeId(ctx.token()); + auto retTypeId = TypeRegistry::typeId(ctx.token()); + if (!isElementaryType(retTypeId)) + ctx.pushError("Function return type must be one of the following: int, float, bool, None"); + ctx.pushChildNode(NodeType::FunctionReturnType)->value = retTypeId; ctx.goNextToken(); if (!ctx.token().is(Special::Colon)) { ctx.pushError("Colon expected at the end of function header"); @@ -697,24 +720,11 @@ void parseReturnStatement(ParserContext &ctx) { } void parseVariableDeclaration(ParserContext &ctx) { + const Token &varName = ctx.token(); ctx.goNextToken(); - const Token &colon = ctx.token(); - const Token &varName = *std::prev(ctx.tokenIter); - const Token &varType = (std::advance(ctx.tokenIter, 1), ctx.token()); - auto node = ctx.pushChildNode(NodeType::TypeName); - node->value = TypeRegistry::typeId(varType); - bool isListType = varType.is(Keyword::List); - if (isListType) { - const Token &leftBrace = (std::advance(ctx.tokenIter, 1), ctx.token()); - const Token &varTypeList = (std::advance(ctx.tokenIter, 1), ctx.token()); - const Token &rightBrace = (std::advance(ctx.tokenIter, 1), ctx.token()); - if (!leftBrace.is(Operator::RectLeftBrace) || !rightBrace.is(Operator::RectRightBrace)) { - ctx.pushError("Unexepted syntax for list declaration"); - } - auto listTypeNode = ParserContext::pushChildNode(node, NodeType::TypeName, ctx.tokenIter->ref); - listTypeNode->value = TypeRegistry::typeId(varTypeList); - } - node = ctx.pushChildNode(NodeType::VariableName); + const Token &varType = (ctx.goNextToken(), ctx.token()); + parseType(ctx); + auto node = ctx.pushChildNode(NodeType::VariableName); node->value = varName.id(); auto endOfDecl = std::next(ctx.tokenIter); @@ -725,7 +735,7 @@ void parseVariableDeclaration(ParserContext &ctx) { } else if (endOfDecl->is(Operator::Assign)) { // declaration with definition ctx.node = ctx.pushChildNode(NodeType::Expression); - if (isListType) { + if (varType.is(Keyword::List)) { ctx.node = ctx.pushChildNode(NodeType::ListStatement); } std::advance(ctx.tokenIter, 2); diff --git a/compiler/lib/frontend/parser/type_registry.cpp b/compiler/lib/frontend/parser/type_registry.cpp index af2d7cf9..b9ca44b0 100644 --- a/compiler/lib/frontend/parser/type_registry.cpp +++ b/compiler/lib/frontend/parser/type_registry.cpp @@ -1,16 +1,25 @@ #include "parser/type_registry.hpp" + +#include +#include + #include "compiler/ast/types.hpp" + #include "lexer/token_types.hpp" using ast::TypeId; using namespace lexer; using namespace parser; -std::map TypeRegistry::userDefinedTypes = {}; +namespace { + +std::map userDefinedTypes = {}; + +} // namespace bool TypeRegistry::isTypename(const Token &token) { return token.is(Keyword::Int) || token.is(Keyword::Float) || token.is(Keyword::Bool) || token.is(Keyword::Str) || - token.is(Keyword::None) || + token.is(Keyword::None) || token.is(Keyword::List) || (token.type == TokenType::Identifier && userDefinedTypes.find(token.id()) != userDefinedTypes.end()); } diff --git a/compiler/tests/frontend/parser.cpp b/compiler/tests/frontend/parser.cpp index 29dd4231..3d0727f6 100644 --- a/compiler/tests/frontend/parser.cpp +++ b/compiler/tests/frontend/parser.cpp @@ -1448,3 +1448,37 @@ TEST(Parser, can_throw_error_for_out_of_range_scientific_float_literals) { TokenList tokens = Lexer::process(source); ASSERT_ANY_THROW(Parser::process(tokens)); } + +TEST(Parser, can_parse_function_arguments) { + StringVec source = { + "def main(v1: int, v2: float, v3: bool, v5: list[int], v6: list[float]) -> None:", + " return", + }; + TokenList tokens = Lexer::process(source); + SyntaxTree tree = Parser::process(tokens); + std::string expected = "ProgramRoot\n" + " FunctionDefinition\n" + " FunctionName: main\n" + " FunctionArguments\n" + " FunctionArgument\n" + " TypeName: IntType\n" + " VariableName: v1\n" + " FunctionArgument\n" + " TypeName: FloatType\n" + " VariableName: v2\n" + " FunctionArgument\n" + " TypeName: BoolType\n" + " VariableName: v3\n" + " FunctionArgument\n" + " TypeName: ListType\n" + " TypeName: IntType\n" + " VariableName: v5\n" + " FunctionArgument\n" + " TypeName: ListType\n" + " TypeName: FloatType\n" + " VariableName: v6\n" + " FunctionReturnType: NoneType\n" + " BranchRoot\n" + " ReturnStatement\n"; + ASSERT_EQ(expected, tree.dump()); +}