diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index 4582b2a7396d..7908260c4c3f 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -9,6 +9,7 @@ from mypy.nodes import ( ComplexExpr, + ConditionalExpr, Expression, FloatExpr, IntExpr, @@ -73,6 +74,11 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non value = constant_fold_expr(expr.expr, cur_mod_id) if value is not None: return constant_fold_unary_op(expr.op, value) + elif isinstance(expr, ConditionalExpr): + cond = constant_fold_expr(expr.cond, cur_mod_id) + if cond is not None: + value_expr = expr.if_expr if cond else expr.else_expr + return constant_fold_expr(value_expr, cur_mod_id) return None diff --git a/mypyc/irbuild/constant_fold.py b/mypyc/irbuild/constant_fold.py index 12a4b15dd40c..102f9ccf0fbf 100644 --- a/mypyc/irbuild/constant_fold.py +++ b/mypyc/irbuild/constant_fold.py @@ -16,6 +16,7 @@ from mypy.nodes import ( BytesExpr, ComplexExpr, + ConditionalExpr, Expression, FloatExpr, IntExpr, @@ -72,6 +73,11 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | value = constant_fold_expr(builder, expr.expr) if value is not None and not isinstance(value, bytes): return constant_fold_unary_op(expr.op, value) + elif isinstance(expr, ConditionalExpr): + cond = constant_fold_expr(builder, expr.cond) + if cond is not None: + value_expr = expr.if_expr if cond else expr.else_expr + return constant_fold_expr(builder, value_expr) return None diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 59ecc4ac2c5c..f239f3a1fee7 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -653,6 +653,7 @@ def try_gen_slice_op(builder: IRBuilder, base: Value, index: SliceExpr) -> Value return None +@folding_candidate def transform_conditional_expr(builder: IRBuilder, expr: ConditionalExpr) -> Value: if_body, else_body, next_block = BasicBlock(), BasicBlock(), BasicBlock() diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index cd953c84c541..f5e041cf2ceb 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -478,3 +478,16 @@ L0: r3 = (-1.5+2j) neg_2 = r3 return 1 + +[case testConditionalConstantFolding] +from typing import Final + +constant: Final = 1 + +def f() -> None: + a = "t" if constant else "f" +[out] +def f(): + a :: str +L0: + a = "t"