Skip to content

Commit

Permalink
Make embedded CSV test strings easier to read. (#3907)
Browse files · Browse the repository at this point in the history
  • Loading branch information
ghuls committed Jul 6, 2022
1 parent 93da58c commit 004958f
Showing 1 changed file with 144 additions and 97 deletions.
241 changes: 144 additions & 97 deletions py-polars/tests/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gzip
import io
import os
import textwrap
import zlib
from datetime import date
from pathlib import Path
Expand Down Expand Up @@ -49,77 +50,94 @@ def test_read_web_file() -> None:


def test_csv_null_values() -> None:
csv = """
a,b,c
na,b,c
a,na,c"""
csv = textwrap.dedent(
"""\
a,b,c
na,b,c
a,na,c
"""
)
f = io.StringIO(csv)

df = pl.read_csv(f, null_values="na")
assert df[0, "a"] is None
assert df[1, "b"] is None

csv = """
a,b,c
na,b,c
a,n/a,c"""
csv = textwrap.dedent(
"""\
a,b,c
na,b,c
a,n/a,c
"""
)
f = io.StringIO(csv)
df = pl.read_csv(f, null_values=["na", "n/a"])
assert df[0, "a"] is None
assert df[1, "b"] is None

csv = """
a,b,c
na,b,c
a,n/a,c"""
csv = textwrap.dedent(
"""\
a,b,c
na,b,c
a,n/a,c
"""
)
f = io.StringIO(csv)
df = pl.read_csv(f, null_values={"a": "na", "b": "n/a"})
assert df[0, "a"] is None
assert df[1, "b"] is None


def test_datetime_parsing() -> None:
csv = """
timestamp,open,high
2021-01-01 00:00:00,0.00305500,0.00306000
2021-01-01 00:15:00,0.00298800,0.00300400
2021-01-01 00:30:00,0.00298300,0.00300100
2021-01-01 00:45:00,0.00299400,0.00304000
"""
csv = textwrap.dedent(
"""\
timestamp,open,high
2021-01-01 00:00:00,0.00305500,0.00306000
2021-01-01 00:15:00,0.00298800,0.00300400
2021-01-01 00:30:00,0.00298300,0.00300100
2021-01-01 00:45:00,0.00299400,0.00304000
"""
)

f = io.StringIO(csv)
df = pl.read_csv(f, parse_dates=True)
assert df.dtypes == [pl.Datetime, pl.Float64, pl.Float64]


def test_partial_dtype_overwrite() -> None:
csv = """
a,b,c
1,2,3
1,2,3
"""
csv = textwrap.dedent(
"""\
a,b,c
1,2,3
1,2,3
"""
)
f = io.StringIO(csv)
df = pl.read_csv(f, dtypes=[pl.Utf8])
assert df.dtypes == [pl.Utf8, pl.Int64, pl.Int64]


def test_dtype_overwrite_with_column_name_selection() -> None:
csv = """
a,b,c,d
1,2,3,4
1,2,3,4
"""
csv = textwrap.dedent(
"""\
a,b,c,d
1,2,3,4
1,2,3,4
"""
)
f = io.StringIO(csv)
df = pl.read_csv(f, columns=["c", "b", "d"], dtypes=[pl.Int32, pl.Utf8])
assert df.dtypes == [pl.Utf8, pl.Int32, pl.Int64]


def test_dtype_overwrite_with_column_idx_selection() -> None:
csv = """
a,b,c,d
1,2,3,4
1,2,3,4
"""
csv = textwrap.dedent(
"""\
a,b,c,d
1,2,3,4
1,2,3,4
"""
)
f = io.StringIO(csv)
df = pl.read_csv(f, columns=[2, 1, 3], dtypes=[pl.Int32, pl.Utf8])
# Columns without an explicit dtype set will get pl.Utf8 if dtypes is a list
Expand All @@ -130,11 +148,13 @@ def test_dtype_overwrite_with_column_idx_selection() -> None:


def test_partial_column_rename() -> None:
csv = """
a,b,c
1,2,3
1,2,3
"""
csv = textwrap.dedent(
"""\
a,b,c
1,2,3
1,2,3
"""
)
f = io.StringIO(csv)
for use in [True, False]:
f.seek(0)
Expand All @@ -148,10 +168,13 @@ def test_partial_column_rename() -> None:
def test_read_csv_columns_argument(
col_input: Union[List[int], List[str]], col_out: List[str]
) -> None:
csv = """a,b,c
1,2,3
1,2,3
"""
csv = textwrap.dedent(
"""\
a,b,c
1,2,3
1,2,3
"""
)
f = io.StringIO(csv)
df = pl.read_csv(f, columns=col_input)
assert df.shape[0] == 2
Expand All @@ -172,11 +195,13 @@ def test_read_csv_buffer_ownership() -> None:


def test_column_rename_and_dtype_overwrite() -> None:
csv = """
a,b,c
1,2,3
1,2,3
"""
csv = textwrap.dedent(
"""\
a,b,c
1,2,3
1,2,3
"""
)
f = io.StringIO(csv)
df = pl.read_csv(
f,
Expand All @@ -194,10 +219,12 @@ def test_column_rename_and_dtype_overwrite() -> None:
)
assert df.dtypes == [pl.Utf8, pl.Float32]

csv = """
1,2,3
1,2,3
"""
csv = textwrap.dedent(
"""\
1,2,3
1,2,3
"""
)
f = io.StringIO(csv)
df = pl.read_csv(
f,
Expand All @@ -210,12 +237,14 @@ def test_column_rename_and_dtype_overwrite() -> None:

def test_compressed_csv() -> None:
# gzip compression
csv = """
a,b,c
1,a,1.0
2,b,2.0,
3,c,3.0
"""
csv = textwrap.dedent(
"""\
a,b,c
1,a,1.0
2,b,2.0,
3,c,3.0
"""
)
fout = io.BytesIO()
with gzip.GzipFile(fileobj=fout, mode="w") as f:
f.write(csv.encode())
Expand Down Expand Up @@ -271,18 +300,20 @@ def test_empty_bytes() -> None:


def test_csq_quote_char() -> None:
rolling_stones = """
linenum,last_name,first_name
1,Jagger,Mick
2,O"Brian,Mary
3,Richards,Keith
4,L"Etoile,Bennet
5,Watts,Charlie
6,Smith,D"Shawn
7,Wyman,Bill
8,Woods,Ron
9,Jones,Brian
"""
rolling_stones = textwrap.dedent(
"""\
linenum,last_name,first_name
1,Jagger,Mick
2,O"Brian,Mary
3,Richards,Keith
4,L"Etoile,Bennet
5,Watts,Charlie
6,Smith,D"Shawn
7,Wyman,Bill
8,Woods,Ron
9,Jones,Brian
"""
)

assert pl.read_csv(rolling_stones.encode(), quote_char=None).shape == (9, 3)

Expand All @@ -293,11 +324,15 @@ def test_csv_empty_quotes_char() -> None:


def test_ignore_parse_dates() -> None:
csv = """a,b,c
1,i,16200126
2,j,16250130
3,k,17220012
4,l,17290009""".encode()
csv = textwrap.dedent(
"""\
a,b,c
1,i,16200126
2,j,16250130
3,k,17220012
4,l,17290009
"""
).encode()

headers = ["a", "b", "c"]
dtypes: Dict[str, Type[DataType]] = {
Expand All @@ -308,14 +343,17 @@ def test_ignore_parse_dates() -> None:


def test_csv_date_handling() -> None:
csv = """date
1745-04-02
1742-03-21
1743-06-16
1730-07-22
""
1739-03-16
"""
csv = textwrap.dedent(
"""\
date
1745-04-02
1742-03-21
1743-06-16
1730-07-22
""
1739-03-16
"""
)
expected = pl.DataFrame(
{
"date": [
Expand Down Expand Up @@ -368,13 +406,16 @@ def test_csv_globbing(examples_dir: str) -> None:


def test_csv_schema_offset(foods_csv: str) -> None:
csv = """metadata
line
foo,bar
1,2
3,4
5,6
""".encode()
csv = textwrap.dedent(
"""\
metadata
line
foo,bar
1,2
3,4
5,6
"""
).encode()
df = pl.read_csv(csv, skip_rows=2)
assert df.columns == ["foo", "bar"]
assert df.shape == (3, 2)
Expand Down Expand Up @@ -409,11 +450,13 @@ def test_write_csv_delimiter() -> None:


def test_escaped_null_values() -> None:
csv = """
"a","b","c"
"a","n/a","NA"
"None","2","3.0"
"""
csv = textwrap.dedent(
"""\
"a","b","c"
"a","n/a","NA"
"None","2","3.0"
"""
)
f = io.StringIO(csv)
df = pl.read_csv(
f,
Expand Down Expand Up @@ -444,10 +487,14 @@ def quoting_round_trip() -> None:


def fallback_chrono_parser() -> None:
data = """date_1,date_2
data = textwrap.dedent(
"""\
date_1,date_2
2021-01-01,2021-1-1
2021-02-02,2021-2-2
2021-10-10,2021-10-10"""
2021-10-10,2021-10-10
"""
)
assert pl.read_csv(data.encode(), parse_dates=True).null_count().row(0) == (0, 0)


Expand Down

0 comments on commit 004958f

Please sign in to comment.