diff --git a/py-polars/docs/source/reference/dataframe/miscellaneous.rst b/py-polars/docs/source/reference/dataframe/miscellaneous.rst index 116c1d577231..b60056e7fbce 100644 --- a/py-polars/docs/source/reference/dataframe/miscellaneous.rst +++ b/py-polars/docs/source/reference/dataframe/miscellaneous.rst @@ -8,6 +8,7 @@ Miscellaneous DataFrame.apply DataFrame.corr + DataFrame.equals DataFrame.frame_equal DataFrame.lazy DataFrame.map_rows diff --git a/py-polars/docs/source/reference/series/miscellaneous.rst b/py-polars/docs/source/reference/series/miscellaneous.rst index 949cfc320a9c..b310ee6f4992 100644 --- a/py-polars/docs/source/reference/series/miscellaneous.rst +++ b/py-polars/docs/source/reference/series/miscellaneous.rst @@ -7,6 +7,7 @@ Miscellaneous :toctree: api/ Series.apply + Series.equals Series.map_elements Series.reinterpret Series.series_equal diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index a1d079f87d66..dff1c47e41d6 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -4578,9 +4578,9 @@ def bottom_k( ) ) - def frame_equal(self, other: DataFrame, *, null_equal: bool = True) -> bool: + def equals(self, other: DataFrame, *, null_equal: bool = True) -> bool: """ - Check if DataFrame is equal to other. + Check whether the DataFrame is equal to another DataFrame. Parameters ---------- @@ -4589,6 +4589,10 @@ def frame_equal(self, other: DataFrame, *, null_equal: bool = True) -> bool: null_equal Consider null values as equal. + See Also + -------- + assert_frame_equal + Examples -------- >>> df1 = pl.DataFrame( @@ -4605,13 +4609,13 @@ def frame_equal(self, other: DataFrame, *, null_equal: bool = True) -> bool: ... "ham": ["c", "b", "a"], ... } ... ) - >>> df1.frame_equal(df1) + >>> df1.equals(df1) True - >>> df1.frame_equal(df2) + >>> df1.equals(df2) False """ - return self._df.frame_equal(other._df, null_equal) + return self._df.equals(other._df, null_equal) @deprecate_function( "DataFrame.replace is deprecated and will be removed in a future version. " @@ -10477,6 +10481,23 @@ def replace_at_idx(self, index: int, new_column: Series) -> Self: """ return self.replace_column(index, new_column) + @deprecate_renamed_function("equals", version="0.19.16") + def frame_equal(self, other: DataFrame, *, null_equal: bool = True) -> bool: + """ + Check whether the DataFrame is equal to another DataFrame. + + .. deprecated:: 0.19.16 + This method has been renamed to :func:`equals`. + + Parameters + ---------- + other + DataFrame to compare with. + null_equal + Consider null values as equal. + """ + return self.equals(other, null_equal=null_equal) + def _prepare_other_arg(other: Any, length: int | None = None) -> Series: # if not a series create singleton series such that it will broadcast diff --git a/py-polars/polars/expr/datetime.py b/py-polars/polars/expr/datetime.py index 3f8d1ff8ca82..2cd2107cc386 100644 --- a/py-polars/polars/expr/datetime.py +++ b/py-polars/polars/expr/datetime.py @@ -144,9 +144,9 @@ def truncate( │ 2001-01-01 18:00:00 │ │ 2001-01-01 22:00:00 │ └─────────────────────┘ - >>> df.select(pl.col("datetime").dt.truncate("1h")).frame_equal( - ... df.select(pl.col("datetime").dt.truncate(timedelta(hours=1))) - ... ) + >>> truncate_str = df.select(pl.col("datetime").dt.truncate("1h")) + >>> truncate_td = df.select(pl.col("datetime").dt.truncate(timedelta(hours=1))) + >>> truncate_str.equals(truncate_td) True >>> df = pl.datetime_range( diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index cde33ef27402..17feaa9fa838 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -1675,7 +1675,9 @@ def round( 2001-01-01 19:00:00 2001-01-01 22:00:00 ] - >>> s.dt.round("1h").series_equal(s.dt.round(timedelta(hours=1))) + >>> round_str = s.dt.round("1h") + >>> round_td = s.dt.round(timedelta(hours=1)) + >>> round_str.equals(round_td) True >>> start = datetime(2001, 1, 1) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 72680c04b40f..10690921d944 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3737,11 +3737,11 @@ def explode(self) -> Series: """ - def series_equal( + def equals( self, other: Series, *, null_equal: bool = True, strict: bool = False ) -> bool: """ - Check if series is equal with another Series. + Check whether the Series is equal to another Series. Parameters ---------- @@ -3753,17 +3753,20 @@ def series_equal( Don't allow different numerical dtypes, e.g. comparing `pl.UInt32` with a `pl.Int64` will return `False`. + See Also + -------- + assert_series_equal + Examples -------- - >>> s = pl.Series("a", [1, 2, 3]) + >>> s1 = pl.Series("a", [1, 2, 3]) >>> s2 = pl.Series("b", [4, 5, 6]) - >>> s.series_equal(s) + >>> s1.equals(s1) True - >>> s.series_equal(s2) + >>> s1.equals(s2) False - """ - return self._s.series_equal(other._s, null_equal, strict) + return self._s.equals(other._s, null_equal, strict) def len(self) -> int: """ @@ -7169,6 +7172,28 @@ def map_dict( """ return self.replace(mapping, default=default, return_dtype=return_dtype) + @deprecate_renamed_function("equals", version="0.19.16") + def series_equal( + self, other: Series, *, null_equal: bool = True, strict: bool = False + ) -> bool: + """ + Check whether the Series is equal to another Series. + + .. deprecated:: 0.19.16 + This method has been renamed to :meth:`equals`. + + Parameters + ---------- + other + Series to compare with. + null_equal + Consider null values as equal. + strict + Don't allow different numerical dtypes, e.g. comparing `pl.UInt32` with a + `pl.Int64` will return `False`. + """ + return self.equals(other, null_equal=null_equal, strict=strict) + # Keep the `list` and `str` properties below at the end of the definition of Series, # as to not confuse mypy with the type annotation `str` and `list` diff --git a/py-polars/src/dataframe.rs b/py-polars/src/dataframe.rs index 4c9c847741dd..7cae9f43a7df 100644 --- a/py-polars/src/dataframe.rs +++ b/py-polars/src/dataframe.rs @@ -1125,7 +1125,7 @@ impl PyDataFrame { Ok(mask.into_series().into()) } - pub fn frame_equal(&self, other: &PyDataFrame, null_equal: bool) -> bool { + pub fn equals(&self, other: &PyDataFrame, null_equal: bool) -> bool { if null_equal { self.df.frame_equal_missing(&other.df) } else { diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index efee7e96840b..151ed4f1e503 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -301,7 +301,7 @@ impl PySeries { self.series.has_validity() } - fn series_equal(&self, other: &PySeries, null_equal: bool, strict: bool) -> bool { + fn equals(&self, other: &PySeries, null_equal: bool, strict: bool) -> bool { if strict && (self.series.dtype() != other.series.dtype()) { return false; } diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 07245b547a2b..db3d03e2063a 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -619,7 +619,7 @@ def test_to_dummies_drop_first() -> None: assert dd.columns == ["foo_1", "foo_2", "bar_4", "bar_5", "baz_y", "baz_z"] assert set(dm.columns) - set(dd.columns) == {"foo_0", "bar_3", "baz_x"} - assert dm.select(dd.columns).frame_equal(dd) + assert_frame_equal(dm.select(dd.columns), dd) assert dd.rows() == [ (0, 0, 0, 0, 0, 0), (1, 0, 1, 0, 1, 0), @@ -3288,52 +3288,6 @@ def test_iter_slices() -> None: assert batches[1].rows() == df[50:].rows() -def test_frame_equal() -> None: - # Values are checked - df1 = pl.DataFrame( - { - "foo": [1, 2, 3], - "bar": [6.0, 7.0, 8.0], - "ham": ["a", "b", "c"], - } - ) - df2 = pl.DataFrame( - { - "foo": [3, 2, 1], - "bar": [8.0, 7.0, 6.0], - "ham": ["c", "b", "a"], - } - ) - - assert df1.frame_equal(df1) - assert not df1.frame_equal(df2) - - # Column names are checked - df3 = pl.DataFrame( - { - "a": [1, 2, 3], - "b": [6.0, 7.0, 8.0], - "c": ["a", "b", "c"], - } - ) - assert not df1.frame_equal(df3) - - # Datatypes are NOT checked - df = pl.DataFrame( - { - "foo": [1, 2, None], - "bar": [6.0, 7.0, None], - "ham": ["a", "b", None], - } - ) - assert df.frame_equal(df.with_columns(pl.col("foo").cast(pl.Int8))) - assert df.frame_equal(df.with_columns(pl.col("ham").cast(pl.Categorical))) - - # The null_equal parameter determines if None values are considered equal - assert df.frame_equal(df) - assert not df.frame_equal(df, null_equal=False) - - def test_format_empty_df() -> None: df = pl.DataFrame( [ diff --git a/py-polars/tests/unit/dataframe/test_equals.py b/py-polars/tests/unit/dataframe/test_equals.py new file mode 100644 index 000000000000..c8c5ec1c2e64 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_equals.py @@ -0,0 +1,47 @@ +import polars as pl + + +def test_equals() -> None: + # Values are checked + df1 = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": [6.0, 7.0, 8.0], + "ham": ["a", "b", "c"], + } + ) + df2 = pl.DataFrame( + { + "foo": [3, 2, 1], + "bar": [8.0, 7.0, 6.0], + "ham": ["c", "b", "a"], + } + ) + + assert df1.equals(df1) is True + assert df1.equals(df2) is False + + # Column names are checked + df3 = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [6.0, 7.0, 8.0], + "c": ["a", "b", "c"], + } + ) + assert df1.equals(df3) is False + + # Datatypes are NOT checked + df = pl.DataFrame( + { + "foo": [1, 2, None], + "bar": [6.0, 7.0, None], + "ham": ["a", "b", None], + } + ) + assert df.equals(df.with_columns(pl.col("foo").cast(pl.Int8))) is True + assert df.equals(df.with_columns(pl.col("ham").cast(pl.Categorical))) is True + + # The null_equal parameter determines if None values are considered equal + assert df.equals(df) is True + assert df.equals(df, null_equal=False) is False diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 083b0842aaa6..bb59b86aba35 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -397,19 +397,18 @@ def test_list_any() -> None: def test_list_min_max() -> None: - for dt in pl.NUMERIC_DTYPES: - if dt == pl.Decimal: - continue + for dt in pl.INTEGER_DTYPES | pl.FLOAT_DTYPES: df = pl.DataFrame( {"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]}, schema={"a": pl.List(dt)}, ) - assert df.select(pl.col("a").list.min())["a"].series_equal( - df.select(pl.col("a").list.first())["a"] - ) - assert df.select(pl.col("a").list.max())["a"].series_equal( - df.select(pl.col("a").list.last())["a"] - ) + result = df.select(pl.col("a").list.min()) + expected = df.select(pl.col("a").list.first()) + assert_frame_equal(result, expected) + + result = df.select(pl.col("a").list.max()) + expected = df.select(pl.col("a").list.last()) + assert_frame_equal(result, expected) df = pl.DataFrame( {"a": [[1], [1, 5, -1, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5], None]}, diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index f94a689ebc04..4743e647fb29 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -427,8 +427,8 @@ def test_timezone() -> None: # different timezones are not considered equal # we check both `null_equal=True` and `null_equal=False` # https://github.com/pola-rs/polars/issues/5023 - assert not s.series_equal(tz_s, null_equal=False) - assert not s.series_equal(tz_s, null_equal=True) + assert s.equals(tz_s, null_equal=False) is False + assert s.equals(tz_s, null_equal=True) is False assert_series_not_equal(tz_s, s) assert_series_equal(s.cast(int), tz_s.cast(int)) diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py index 090ce2a86cdc..e476d61aae17 100644 --- a/py-polars/tests/unit/io/test_hive.py +++ b/py-polars/tests/unit/io/test_hive.py @@ -129,10 +129,10 @@ def test_hive_partitioned_projection_pushdown( parallel=parallel, # type: ignore[arg-type] ) - expect = q.collect().select("category") - actual = q.select("category").collect() + expected = q.collect().select("category") + result = q.select("category").collect() - assert expect.frame_equal(actual) + assert_frame_equal(result, expected) @pytest.mark.write_disk() diff --git a/py-polars/tests/unit/namespaces/string/test_string.py b/py-polars/tests/unit/namespaces/string/test_string.py index e4ffa4ff32df..b76f84a0409e 100644 --- a/py-polars/tests/unit/namespaces/string/test_string.py +++ b/py-polars/tests/unit/namespaces/string/test_string.py @@ -244,11 +244,9 @@ def test_str_to_integer_df() -> None: "hex": ["fa1e", "ff00", "cafe", "invalid", None], } ) - out = df.with_columns( - [ - pl.col("bin").str.to_integer(base=2, strict=False), - pl.col("hex").str.to_integer(base=16, strict=False), - ] + result = df.with_columns( + pl.col("bin").str.to_integer(base=2, strict=False), + pl.col("hex").str.to_integer(base=16, strict=False), ) expected = pl.DataFrame( @@ -257,7 +255,7 @@ def test_str_to_integer_df() -> None: "hex": [64030, 65280, 51966, None, None], } ) - assert out.frame_equal(expected) + assert_frame_equal(result, expected) with pytest.raises(pl.ComputeError): df.with_columns( diff --git a/py-polars/tests/unit/namespaces/test_binary.py b/py-polars/tests/unit/namespaces/test_binary.py index 0e08b753df94..7e6eba929c15 100644 --- a/py-polars/tests/unit/namespaces/test_binary.py +++ b/py-polars/tests/unit/namespaces/test_binary.py @@ -1,6 +1,7 @@ import pytest import polars as pl +from polars.testing import assert_frame_equal from polars.type_aliases import TransferEncoding @@ -132,7 +133,7 @@ def test_compare_encode_between_lazy_and_eager_6814(encoding: TransferEncoding) result_eager = df.select(expr) dtype = result_eager["x"].dtype result_lazy = df.lazy().select(expr).select(pl.col(dtype)).collect() - assert result_eager.frame_equal(result_lazy) + assert_frame_equal(result_eager, result_lazy) @pytest.mark.parametrize( @@ -148,4 +149,4 @@ def test_compare_decode_between_lazy_and_eager_6814(encoding: TransferEncoding) result_eager = df.select(expr) dtype = result_eager["x"].dtype result_lazy = df.lazy().select(expr).select(pl.col(dtype)).collect() - assert result_eager.frame_equal(result_lazy) + assert_frame_equal(result_eager, result_lazy) diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index 870d4a4fad96..648d00d2843d 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -492,11 +492,10 @@ def test_list_gather_logical_type() -> None: def test_list_unique() -> None: - assert ( - pl.Series([[1, 1, 2, 2, 3], [3, 3, 3, 2, 1, 2]]) - .list.unique(maintain_order=True) - .series_equal(pl.Series([[1, 2, 3], [3, 2, 1]])) - ) + s = pl.Series([[1, 1, 2, 2, 3], [3, 3, 3, 2, 1, 2]]) + result = s.list.unique(maintain_order=True) + expected = pl.Series([[1, 2, 3], [3, 2, 1]]) + assert_series_equal(result, expected) def test_list_to_struct() -> None: diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index de7a2ada9cd0..e2d887641e1d 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -807,18 +807,25 @@ def test_group_by_list_scalar_11749() -> None: def test_group_by_with_expr_as_key() -> None: gb = pl.select(x=1).group_by(pl.col("x").alias("key")) - assert gb.agg(pl.all().first()).frame_equal(gb.agg(pl.first("x"))) + result = gb.agg(pl.all().first()) + expected = gb.agg(pl.first("x")) + assert_frame_equal(result, expected) # tests: 11766 - assert gb.head(0).frame_equal(gb.agg(pl.col("x").head(0)).explode("x")) - assert gb.tail(0).frame_equal(gb.agg(pl.col("x").tail(0)).explode("x")) + result = gb.head(0) + expected = gb.agg(pl.col("x").head(0)).explode("x") + assert_frame_equal(result, expected) + + result = gb.tail(0) + expected = gb.agg(pl.col("x").tail(0)).explode("x") + assert_frame_equal(result, expected) def test_lazy_group_by_reuse_11767() -> None: lgb = pl.select(x=1).lazy().group_by("x") a = lgb.count() b = lgb.count() - assert a.collect().frame_equal(b.collect()) + assert_frame_equal(a, b) def test_group_by_double_on_empty_12194() -> None: diff --git a/py-polars/tests/unit/series/test_equals.py b/py-polars/tests/unit/series/test_equals.py new file mode 100644 index 000000000000..7eb1cd21a240 --- /dev/null +++ b/py-polars/tests/unit/series/test_equals.py @@ -0,0 +1,27 @@ +from datetime import datetime + +import polars as pl + + +def test_equals() -> None: + s1 = pl.Series("a", [1.0, 2.0, None], pl.Float64) + s2 = pl.Series("a", [1, 2, None], pl.Int64) + + assert s1.equals(s2) is True + assert s1.equals(s2, strict=True) is False + assert s1.equals(s2, null_equal=False) is False + + df = pl.DataFrame( + {"dtm": [datetime(2222, 2, 22, 22, 22, 22)]}, + schema_overrides={"dtm": pl.Datetime(time_zone="UTC")}, + ).with_columns( + s3=pl.col("dtm").dt.convert_time_zone("Europe/London"), + s4=pl.col("dtm").dt.convert_time_zone("Asia/Tokyo"), + ) + s3 = df["s3"].rename("b") + s4 = df["s4"].rename("b") + + assert s3.equals(s4) is False + assert s3.equals(s4, strict=True) is False + assert s3.equals(s4, null_equal=False) is False + assert s3.dt.convert_time_zone("Asia/Tokyo").equals(s4) is True diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index a785905dd8ae..c931183a2056 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -254,30 +254,6 @@ def test_concat() -> None: assert s.len() == 3 -def test_equal() -> None: - s1 = pl.Series("a", [1.0, 2.0, None], Float64) - s2 = pl.Series("a", [1, 2, None], Int64) - - assert s1.series_equal(s2) is True - assert s1.series_equal(s2, strict=True) is False - assert s1.series_equal(s2, null_equal=False) is False - - df = pl.DataFrame( - {"dtm": [datetime(2222, 2, 22, 22, 22, 22)]}, - schema_overrides={"dtm": Datetime(time_zone="UTC")}, - ).with_columns( - s3=pl.col("dtm").dt.convert_time_zone("Europe/London"), - s4=pl.col("dtm").dt.convert_time_zone("Asia/Tokyo"), - ) - s3 = df["s3"].rename("b") - s4 = df["s4"].rename("b") - - assert s3.series_equal(s4) is False - assert s3.series_equal(s4, strict=True) is False - assert s3.series_equal(s4, null_equal=False) is False - assert s3.dt.convert_time_zone("Asia/Tokyo").series_equal(s4) is True - - @pytest.mark.parametrize( "dtype", [pl.Int64, pl.Float64, pl.Utf8, pl.Boolean], diff --git a/py-polars/tests/unit/sql/test_sql.py b/py-polars/tests/unit/sql/test_sql.py index 78c24237203f..90959fa4a4f7 100644 --- a/py-polars/tests/unit/sql/test_sql.py +++ b/py-polars/tests/unit/sql/test_sql.py @@ -1186,10 +1186,11 @@ def test_sql_expr() -> None: "SUBSTR(b,1,2) AS b2", ] ) + result = df.select(*sql_exprs) expected = pl.DataFrame( - {"a": [1, 1, 1], "aa": [1, 4, 27], "b2": ["yz", "bc", None]} + {"a": [1, 1, 1], "aa": [1.0, 4.0, 27.0], "b2": ["yz", "bc", None]} ) - assert df.select(*sql_exprs).frame_equal(expected) + assert_frame_equal(result, expected) # expect expressions that can't reasonably be parsed as expressions to raise # (for example: those that explicitly reference tables and/or use wildcards) @@ -1249,12 +1250,11 @@ def test_sql_date() -> None: ) with pl.SQLContext(df=df, eager_execution=True) as ctx: - expected = pl.DataFrame({"date": [True, False, False]}) - assert ctx.execute("SELECT date < DATE('2021-03-20') from df").frame_equal( - expected - ) + result = ctx.execute("SELECT date < DATE('2021-03-20') from df") + + expected = pl.DataFrame({"date": [True, False, False]}) + assert_frame_equal(result, expected) + result = pl.select(pl.sql_expr("""CAST(DATE('2023-03', '%Y-%m') as STRING)""")) expected = pl.DataFrame({"literal": ["2023-03-01"]}) - assert pl.select( - pl.sql_expr("""CAST(DATE('2023-03', '%Y-%m') as STRING)""") - ).frame_equal(expected) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/test_empty.py b/py-polars/tests/unit/test_empty.py index 1cc55353e059..d757c8bab179 100644 --- a/py-polars/tests/unit/test_empty.py +++ b/py-polars/tests/unit/test_empty.py @@ -28,10 +28,10 @@ def test_empty_cross_join() -> None: def test_empty_string_replace() -> None: s = pl.Series("", [], dtype=pl.Utf8) - assert s.str.replace("a", "b", literal=True).series_equal(s) - assert s.str.replace("a", "b").series_equal(s) - assert s.str.replace("ab", "b", literal=True).series_equal(s) - assert s.str.replace("ab", "b").series_equal(s) + assert_series_equal(s.str.replace("a", "b", literal=True), s) + assert_series_equal(s.str.replace("a", "b"), s) + assert_series_equal(s.str.replace("ab", "b", literal=True), s) + assert_series_equal(s.str.replace("ab", "b"), s) def test_empty_window_function() -> None: diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index e756eb16446e..f3f26433a57d 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -15,8 +15,7 @@ from polars import lit, when from polars.datatypes import FLOAT_DTYPES from polars.exceptions import ComputeError, PolarsInefficientMapWarning -from polars.testing import assert_frame_equal -from polars.testing.asserts import assert_series_equal +from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: from _pytest.capture import CaptureFixture @@ -1536,7 +1535,7 @@ def test_compare_aggregation_between_lazy_and_eager_6904( result_eager = df.select(func.over("y")).select("x") dtype_eager = result_eager["x"].dtype result_lazy = df.lazy().select(func.over("y")).select(pl.col(dtype_eager)).collect() - assert result_eager.frame_equal(result_lazy) + assert_frame_equal(result_eager, result_lazy) @pytest.mark.parametrize( diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py index 9a9f94598d17..914d77f9ea34 100644 --- a/py-polars/tests/unit/test_predicates.py +++ b/py-polars/tests/unit/test_predicates.py @@ -6,6 +6,7 @@ import polars as pl from polars.testing import assert_frame_equal +from polars.testing.asserts.series import assert_series_equal def test_predicate_4906() -> None: @@ -108,10 +109,10 @@ def test_predicate_arr_first_6573() -> None: def test_fast_path_comparisons() -> None: s = pl.Series(np.sort(np.random.randint(0, 50, 100))) - assert (s > 25).series_equal(s.set_sorted() > 25) - assert (s >= 25).series_equal(s.set_sorted() >= 25) - assert (s < 25).series_equal(s.set_sorted() < 25) - assert (s <= 25).series_equal(s.set_sorted() <= 25) + assert_series_equal(s > 25, s.set_sorted() > 25) + assert_series_equal(s >= 25, s.set_sorted() >= 25) + assert_series_equal(s < 25, s.set_sorted() < 25) + assert_series_equal(s <= 25, s.set_sorted() <= 25) def test_predicate_pushdown_block_8661() -> None: @@ -252,7 +253,9 @@ def test_predicate_pushdown_boundary_12102() -> None: .filter(pl.col("y") > 2) ) - assert lf.collect().frame_equal(lf.collect(predicate_pushdown=False)) + result = lf.collect() + result_no_ppd = lf.collect(predicate_pushdown=False) + assert_frame_equal(result, result_no_ppd) def test_take_can_block_predicate_pushdown() -> None: