From 5b24016155032ebf246d3f0eb1f9e8dd7cb94212 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 12 May 2023 19:41:10 +0400 Subject: [PATCH] fix(rust,python): sql `BETWEEN` bounds should be inclusive (#8818) --- polars/polars-sql/src/sql_expr.rs | 2 +- py-polars/tests/unit/test_sql.py | 30 ++++++++++++++---------------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/polars/polars-sql/src/sql_expr.rs b/polars/polars-sql/src/sql_expr.rs index 30ca8c4af03b..efc81aac6d29 100644 --- a/polars/polars-sql/src/sql_expr.rs +++ b/polars/polars-sql/src/sql_expr.rs @@ -245,7 +245,7 @@ impl SqlExprVisitor<'_> { if negated { Ok(expr.clone().lt(low).or(expr.gt(high))) } else { - Ok(expr.clone().gt(low).and(expr.lt(high))) + Ok(expr.clone().gt_eq(low).and(expr.lt_eq(high))) } } diff --git a/py-polars/tests/unit/test_sql.py b/py-polars/tests/unit/test_sql.py index ed92a02653de..640458ac717f 100644 --- a/py-polars/tests/unit/test_sql.py +++ b/py-polars/tests/unit/test_sql.py @@ -154,31 +154,29 @@ def test_sql_is_between(foods_ipc_path: Path) -> None: """ SELECT * FROM foods1 - WHERE foods1.calories BETWEEN 20 AND 31 - LIMIT 4 + WHERE foods1.calories BETWEEN 22 AND 30 + ORDER BY "calories" DESC, "sugars_g" DESC """ ) - assert out.to_dict(False) == { - "category": ["fruit", "vegetables", "fruit", "vegetables"], - "calories": [30, 25, 30, 22], - "fats_g": [0.0, 0.0, 0.0, 0.0], - "sugars_g": [5, 2, 3, 3], - } + assert out.rows() == [ + ("fruit", 30, 0.0, 5), + ("vegetables", 30, 0.0, 5), + ("fruit", 30, 0.0, 3), + ("vegetables", 25, 0.0, 4), + ("vegetables", 25, 0.0, 3), + ("vegetables", 25, 0.0, 2), + ("vegetables", 22, 0.0, 3), + ] out = c.execute( """ SELECT * FROM foods1 - WHERE calories NOT BETWEEN 20 AND 31 - LIMIT 4 + WHERE calories NOT BETWEEN 22 AND 30 + ORDER BY "calories" ASC """ ) - assert out.to_dict(False) == { - "category": ["vegetables", "seafood", "meat", "fruit"], - "calories": [45, 150, 100, 60], - "fats_g": [0.5, 5.0, 5.0, 0.0], - "sugars_g": [2, 0, 0, 11], - } + assert not any((22 <= cal <= 30) for cal in out["calories"]) def test_sql_union() -> None: