diff --git a/stdlib/@tests/test_cases/check_ast.py b/stdlib/@tests/test_cases/check_ast.py new file mode 100644 index 000000000000..4e99d00d6ccb --- /dev/null +++ b/stdlib/@tests/test_cases/check_ast.py @@ -0,0 +1,44 @@ +import ast +from typing_extensions import assert_type + +# Test with source code strings +assert_type(ast.parse("x = 1"), ast.Module) +assert_type(ast.parse("x = 1", mode="exec"), ast.Module) +assert_type(ast.parse("1 + 1", mode="eval"), ast.Expression) +assert_type(ast.parse("x = 1", mode="single"), ast.Interactive) +assert_type(ast.parse("(int, str) -> None", mode="func_type"), ast.FunctionType) + +# Test with mod objects - Module +mod1: ast.Module = ast.Module([], []) +assert_type(ast.parse(mod1), ast.Module) +assert_type(ast.parse(mod1, mode="exec"), ast.Module) +mod2: ast.Module = ast.Module(body=[ast.Expr(value=ast.Constant(value=42))], type_ignores=[]) +assert_type(ast.parse(mod2), ast.Module) + +# Test with mod objects - Expression +expr1: ast.Expression = ast.Expression(body=ast.Constant(value=42)) +assert_type(ast.parse(expr1, mode="eval"), ast.Expression) + +# Test with mod objects - Interactive +inter1: ast.Interactive = ast.Interactive(body=[]) +assert_type(ast.parse(inter1, mode="single"), ast.Interactive) + +# Test with mod objects - FunctionType +func1: ast.FunctionType = ast.FunctionType(argtypes=[], returns=ast.Constant(value=None)) +assert_type(ast.parse(func1, mode="func_type"), ast.FunctionType) + +# Test that any AST node can be passed and returns the same type +binop: ast.BinOp = ast.BinOp(left=ast.Constant(1), op=ast.Add(), right=ast.Constant(2)) +assert_type(ast.parse(binop), ast.BinOp) + +constant: ast.Constant = ast.Constant(value=42) +assert_type(ast.parse(constant), ast.Constant) + +expr_stmt: ast.Expr = ast.Expr(value=ast.Constant(value=42)) +assert_type(ast.parse(expr_stmt), ast.Expr) + +# Test with additional parameters +assert_type(ast.parse(mod1, filename="test.py"), ast.Module) +assert_type(ast.parse(mod1, type_comments=True), ast.Module) +assert_type(ast.parse(mod1, feature_version=(3, 10)), ast.Module) +assert_type(ast.parse(binop, filename="test.py"), ast.BinOp) diff --git a/stdlib/ast.pyi b/stdlib/ast.pyi index d360c2ed60e5..5f11e7dcd434 100644 --- a/stdlib/ast.pyi +++ b/stdlib/ast.pyi @@ -1744,6 +1744,16 @@ if sys.version_info < (3, 14): _T = _TypeVar("_T", bound=AST) if sys.version_info >= (3, 13): + @overload + def parse( + source: _T, + filename: str | ReadableBuffer | os.PathLike[Any] = "", + mode: Literal["exec", "eval", "func_type", "single"] = "exec", + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + optimize: Literal[-1, 0, 1, 2] = -1, + ) -> _T: ... @overload def parse( source: str | ReadableBuffer, @@ -1823,6 +1833,15 @@ if sys.version_info >= (3, 13): ) -> mod: ... else: + @overload + def parse( + source: _T, + filename: str | ReadableBuffer | os.PathLike[Any] = "", + mode: Literal["exec", "eval", "func_type", "single"] = "exec", + *, + type_comments: bool = False, + feature_version: None | int | tuple[int, int] = None, + ) -> _T: ... @overload def parse( source: str | ReadableBuffer,