diff --git a/dltype/_lib/_parser.py b/dltype/_lib/_parser.py
index f243255..1149558 100644
--- a/dltype/_lib/_parser.py
+++ b/dltype/_lib/_parser.py
@@ -3,7 +3,9 @@
 from __future__ import annotations
 
 import enum
+import itertools
 import logging
+import math
 import re
 from typing import Final
 
@@ -35,6 +37,18 @@ class _DLTypeOperator(enum.Enum):
     DIV = "/"
     MIN = "min"
     MAX = "max"
+    ISQRT = "isqrt"
+
+    def evaluate_unary(self, a: int) -> int:
+        """Evaluate a single-argument operator.
+
+        ``isqrt`` returns the floor of the square root of ``a``; math.isqrt
+        raises ValueError when ``a`` is negative. Any operator that is not
+        unary raises NotImplementedError.
+        """
+        if self is _DLTypeOperator.ISQRT:
+            return math.isqrt(a)
+        raise NotImplementedError(self)
 
     def evaluate(self, a: int, b: int) -> int:  # noqa: PLR0911
         """Evaluate the operator."""
@@ -63,9 +77,12 @@ def evaluate(self, a: int, b: int) -> int:  # noqa: PLR0911
     _DLTypeOperator.EXP: 3,
     _DLTypeOperator.MIN: 0,
     _DLTypeOperator.MAX: 0,
+    _DLTypeOperator.ISQRT: 0,
 }
 
-_functional_operators: Final = frozenset({_DLTypeOperator.MIN, _DLTypeOperator.MAX})
+_unary_operators: Final = frozenset({_DLTypeOperator.ISQRT})
+_binary_operators: Final = frozenset({_DLTypeOperator.MIN, _DLTypeOperator.MAX})
+_functional_operators: Final = _unary_operators | _binary_operators
 _valid_operators: frozenset[str] = frozenset(
     {op.value for op in _DLTypeOperator if op not in _functional_operators},
 )
@@ -174,6 +191,9 @@ def evaluate(self, scope: dict[str, int]) -> int:
                 stack.append(scope[token])
             elif isinstance(token, _DLTypeOperator):  # pyright: ignore[reportUnnecessaryIsInstance]
                 b = stack.pop()
+                if token in _unary_operators:
+                    stack.append(token.evaluate_unary(b))
+                    continue
                 a = stack.pop()
                 stack.append(token.evaluate(a, b))
             else:
@@ -273,9 +293,10 @@ def _maybe_parse_functional_expression(
     # Remove closing parenthesis
     content = content[:-1]
 
-    # Find the comma that separates arguments (accounting for nesting)
+    # Split the arguments on top-level commas (accounting for nesting)
     depth = 0
-    comma_index = -1
+    balanced_content: list[str] = []
+    span_start = 0
 
     for i, char in enumerate(content):
         if char == "(":
@@ -286,25 +307,24 @@ def _maybe_parse_functional_expression(
                 msg = f"Unbalanced parentheses in function expression: {expression}"
                 raise SyntaxError(msg)
         elif char == "," and depth == 0:
-            comma_index = i
-            break
+            balanced_content.append(content[span_start:i])
+            span_start = i + 1
+    balanced_content.append(content[span_start:])
 
-    if comma_index == -1:
-        msg = f"Invalid function expression: {expression}, expected two arguments separated by comma"
+    if function in _binary_operators and len(balanced_content) != 2:  # noqa: PLR2004
+        msg = f"Function {function.value} requires 2 arguments, got {len(balanced_content)} in {expression=}"
+        raise SyntaxError(msg)
+    if function in _unary_operators and len(balanced_content) != 1:
+        msg = f"Function {function.value} requires 1 argument, got {len(balanced_content)} in {expression=}"
         raise SyntaxError(msg)
 
-    # Split arguments and parse recursively
-    arg1 = content[:comma_index].strip()
-    arg2 = content[comma_index + 1 :].strip()
-
-    # Recursively parse both arguments
-    expr_1 = expression_from_string(arg1)
-    expr_2 = expression_from_string(arg2)
+    # Recursively parse every argument (whitespace around arguments is ignored)
+    expressions = [expression_from_string(arg.strip()) for arg in balanced_content]
 
-    # Build postfix expression: [arg1 tokens, arg2 tokens, function]
+    # Build postfix expression: [*argument tokens, function]
     return DLTypeDimensionExpression(
         identifier,
-        [*expr_1.parsed_expression, *expr_2.parsed_expression, function],
+        [*itertools.chain.from_iterable(exp.parsed_expression for exp in expressions), function],
     )
 
 
diff --git a/dltype/tests/parser_test.py b/dltype/tests/parser_test.py
index e28144a..31d79e7 100644
--- a/dltype/tests/parser_test.py
+++ b/dltype/tests/parser_test.py
@@ -31,6 +31,11 @@
         ("max(3^x,min(3-y,99))", {"x": 2, "y": 4}, 9),
         ("min(3^x,3-y)", {"x": 2, "y": 4}, -1),
         ("variable_name_with_underscores", {"variable_name_with_underscores": 1}, 1),
+        ("isqrt(5)", {}, 2),
+        ("isqrt(16)", {}, 4),
+        ("isqrt(x-y)", {"x": 20, "y": 5}, 3),
+        ("min(isqrt(20),isqrt(16))", {}, 4),
+        ("max(isqrt(20),isqrt(16))", {}, 4),
     ],
 )
 def test_parse_expression(
@@ -52,6 +57,11 @@ def test_parse_expression(
         ("*batch", {}),
         ("3**2", {}),
         ("^", {}),
+        ("isqrt(4, 5)", {}),
+        ("isqrt()", {}),
+        ("max(1)", {}),
+        ("min()", {}),
+        ("isqrt(16", {}),
     ],
 )
 def test_parse_invalid_expression(expression: str, scope: dict[str, int]) -> None: