Skip to content

Commit

Permalink
Add more tests (#2144)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj committed Dec 23, 2021
1 parent f3013a7 commit 8112482
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 5 deletions.
8 changes: 4 additions & 4 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
from typing_extensions import Literal # pragma: no cover

import numpy as np

Expand Down Expand Up @@ -221,7 +221,7 @@ def __init__(

elif _PANDAS_AVAILABLE and isinstance(data, pd.DataFrame):
if not _PYARROW_AVAILABLE:
raise ImportError(
raise ImportError( # pragma: no cover
"'pyarrow' is required for converting a pandas DataFrame to a polars DataFrame."
)
self._df = pandas_to_pydf(data, columns=columns)
Expand Down Expand Up @@ -616,7 +616,7 @@ def to_arrow(self) -> "pa.Table":
- CategoricalType
"""
if not _PYARROW_AVAILABLE:
raise ImportError(
raise ImportError( # pragma: no cover
"'pyarrow' is required for converting a polars DataFrame to an Arrow Table."
)
record_batches = self._df.to_arrow()
Expand Down Expand Up @@ -1037,7 +1037,7 @@ def to_parquet(

if use_pyarrow:
if not _PYARROW_AVAILABLE:
raise ImportError(
raise ImportError( # pragma: no cover
"'pyarrow' is required when using 'to_parquet(..., use_pyarrow=True)'."
)

Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
from typing_extensions import TypeGuard # pragma: no cover


def _process_null_values(
Expand Down
54 changes: 54 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,12 @@ def test_concat() -> None:
_ = pl.concat([a, a, a])
assert a.shape == (2, 2)

with pytest.raises(ValueError):
_ = pl.concat([]) # type: ignore

with pytest.raises(ValueError):
pl.concat([df, df], how="rubbish")


def test_arg_where() -> None:
s = pl.Series([True, False, True, False])
Expand Down Expand Up @@ -1035,6 +1041,35 @@ def test_to_json(df: pl.DataFrame) -> None:
out = pl.read_json(s)
assert df.frame_equal(out, null_equal=True)

file = BytesIO()
df.to_json(file)
file.seek(0)
s = file.read().decode("utf8")
out = pl.read_json(s)
assert df.frame_equal(out, null_equal=True)


def test_to_csv() -> None:
df = pl.DataFrame(
{
"foo": [1, 2, 3, 4, 5],
"bar": [6, 7, 8, 9, 10],
"ham": ["a", "b", "c", "d", "e"],
}
)
expected = "foo,bar,ham\n1,6,a\n2,7,b\n3,8,c\n4,9,d\n5,10,e\n"

# if no file argument is supplied, to_csv() will return the string
s = df.to_csv()
assert s == expected

# otherwise it will write to the file/iobuffer
file = BytesIO()
df.to_csv(file)
file.seek(0)
s = file.read().decode("utf8")
assert s == expected


def test_from_rows() -> None:
df = pl.from_records([[1, 2, "foo"], [2, 3, "bar"]], orient="row")
Expand Down Expand Up @@ -1626,3 +1661,22 @@ def test_pivot_list() -> None:

out = df.groupby("a").pivot("a", "b").first()["a", "1", "2", "3"].sort("a")
assert out.frame_equal(expected, null_equal=True)


@pytest.mark.parametrize("as_series,inner_dtype", [(True, pl.Series), (False, list)])
def test_to_dict(as_series: bool, inner_dtype: tp.Any) -> None:
df = pl.DataFrame(
{
"A": [1, 2, 3, 4, 5],
"fruits": ["banana", "banana", "apple", "apple", "banana"],
"B": [5, 4, 3, 2, 1],
"cars": ["beetle", "audi", "beetle", "beetle", "beetle"],
"optional": [28, 300, None, 2, -30],
}
)

s = df.to_dict(as_series=as_series)
assert isinstance(s, dict)
for v in s.values():
assert isinstance(v, inner_dtype)
assert len(v) == len(df)
4 changes: 4 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,10 @@ def test_compare_series_type_mismatch() -> None:
with pytest.raises(AssertionError, match="Series are different\n\nType mismatch"):
testing.assert_series_equal(srs1, srs2) # type: ignore

srs3 = pl.Series([1.0, 2.0, 3.0])
with pytest.raises(AssertionError, match="Series are different\n\nDtype mismatch"):
testing.assert_series_equal(srs1, srs3)


def test_compare_series_name_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3], name="srs1")
Expand Down

0 comments on commit 8112482

Please sign in to comment.