Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions dltype/_lib/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from __future__ import annotations

import enum
import itertools
import logging
import math
import re
from typing import Final

Expand Down Expand Up @@ -35,6 +37,13 @@ class _DLTypeOperator(enum.Enum):
DIV = "/"
MIN = "min"
MAX = "max"
ISQRT = "isqrt"

def evaluate_unary(self, a: int) -> int:
"""Evaluate the unary operator."""
if self is _DLTypeOperator.ISQRT:
return math.isqrt(a)
raise NotImplementedError(self)

def evaluate(self, a: int, b: int) -> int: # noqa: PLR0911
"""Evaluate the operator."""
Expand Down Expand Up @@ -63,9 +72,12 @@ def evaluate(self, a: int, b: int) -> int: # noqa: PLR0911
_DLTypeOperator.EXP: 3,
_DLTypeOperator.MIN: 0,
_DLTypeOperator.MAX: 0,
_DLTypeOperator.ISQRT: 0,
}

_functional_operators: Final = frozenset({_DLTypeOperator.MIN, _DLTypeOperator.MAX})
_unary_operators: Final = frozenset({_DLTypeOperator.ISQRT})
_binary_operators: Final = frozenset({_DLTypeOperator.MIN, _DLTypeOperator.MAX})
_functional_operators: Final = frozenset(_unary_operators.union(_binary_operators))
_valid_operators: frozenset[str] = frozenset(
{op.value for op in _DLTypeOperator if op not in _functional_operators},
)
Expand Down Expand Up @@ -174,6 +186,9 @@ def evaluate(self, scope: dict[str, int]) -> int:
stack.append(scope[token])
elif isinstance(token, _DLTypeOperator): # pyright: ignore[reportUnnecessaryIsInstance]
b = stack.pop()
if token in _unary_operators:
stack.append(token.evaluate_unary(b))
continue
a = stack.pop()
stack.append(token.evaluate(a, b))
else:
Expand Down Expand Up @@ -265,19 +280,16 @@ def _maybe_parse_functional_expression(
# Strip function name and opening parenthesis
content = expression[len(function.value) + 1 :]

# Must end with closing parenthesis
if not content.endswith(")"):
msg = f"Invalid function expression: {expression}, missing closing parenthesis"
raise SyntaxError(msg)

# Remove closing parenthesis
content = content[:-1]

# Find the comma that separates arguments (accounting for nesting)
depth = 0
comma_index = -1
balanced_content: list[str] = []
current_span = ""

for i, char in enumerate(content):
for char in content:
current_span += char
if char == "(":
depth += 1
elif char == ")":
Expand All @@ -286,25 +298,23 @@ def _maybe_parse_functional_expression(
msg = f"Unbalanced parentheses in function expression: {expression}"
raise SyntaxError(msg)
elif char == "," and depth == 0:
comma_index = i
break
balanced_content.append(current_span[:-1])
current_span = ""
balanced_content.append(current_span)

if comma_index == -1:
msg = f"Invalid function expression: {expression}, expected two arguments separated by comma"
if function in _binary_operators and len(balanced_content) != 2: # noqa: PLR2004
msg = f"Function {function.value} requires 2 arguments, got {len(balanced_content)} in {expression=}"
raise SyntaxError(msg)
if function in _unary_operators and len(balanced_content) != 1:
msg = f"Function {function.value} requires 1 argument, got {len(balanced_content)} in {expression=}"
raise SyntaxError(msg)

# Split arguments and parse recursively
arg1 = content[:comma_index].strip()
arg2 = content[comma_index + 1 :].strip()

# Recursively parse both arguments
expr_1 = expression_from_string(arg1)
expr_2 = expression_from_string(arg2)
expressions = [expression_from_string(exp) for exp in balanced_content]

# Build postfix expression: [arg1 tokens, arg2 tokens, function]
return DLTypeDimensionExpression(
identifier,
[*expr_1.parsed_expression, *expr_2.parsed_expression, function],
[*list(itertools.chain(*[exp.parsed_expression for exp in expressions])), function],
)


Expand Down
9 changes: 9 additions & 0 deletions dltype/tests/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
("max(3^x,min(3-y,99))", {"x": 2, "y": 4}, 9),
("min(3^x,3-y)", {"x": 2, "y": 4}, -1),
("variable_name_with_underscores", {"variable_name_with_underscores": 1}, 1),
("isqrt(5)", {}, 2),
("isqrt(16)", {}, 4),
("isqrt(x-y)", {"x": 20, "y": 5}, 3),
("min(isqrt(20),isqrt(16))", {}, 4),
("max(isqrt(20),isqrt(16))", {}, 4),
],
)
def test_parse_expression(
Expand All @@ -52,6 +57,10 @@ def test_parse_expression(
("*batch", {}),
("3**2", {}),
("^", {}),
("isqrt(4, 5)", {}),
("isqrt()", {}),
("max(1)", {}),
("min()", {}),
],
)
def test_parse_invalid_expression(expression: str, scope: dict[str, int]) -> None:
Expand Down