Skip to content

Commit

Permalink
fix[rust]: Fix wrong schema inference with skip_rows_after_header (#4728
Browse files Browse the repository at this point in the history
) (#4818)

Before, rows which are skipped with skip_rows_after_header were
still used for schema inference although they didn't appear in the
output dataframe.
  • Loading branch information
ghuls committed Sep 11, 2022
1 parent 0ec3d0e commit 0ac48cb
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 13 deletions.
2 changes: 2 additions & 0 deletions polars/polars-io/src/csv/read_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ impl<'a> CoreReader<'a> {
has_header,
schema_overwrite,
&mut skip_rows,
skip_rows_after_header,
comment_char,
quote_char,
eol_char,
Expand All @@ -230,6 +231,7 @@ impl<'a> CoreReader<'a> {
has_header,
schema_overwrite,
&mut skip_rows,
skip_rows_after_header,
comment_char,
quote_char,
eol_char,
Expand Down
8 changes: 7 additions & 1 deletion polars/polars-io/src/csv/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ pub fn infer_file_schema(
// we take &mut because we maybe need to skip more rows dependent
// on the schema inference
skip_rows: &mut usize,
skip_rows_after_header: usize,
comment_char: Option<u8>,
quote_char: Option<u8>,
eol_char: u8,
Expand Down Expand Up @@ -269,6 +270,7 @@ pub fn infer_file_schema(
has_header,
schema_overwrite,
skip_rows,
skip_rows_after_header,
comment_char,
quote_char,
eol_char,
Expand Down Expand Up @@ -296,7 +298,10 @@ pub fn infer_file_schema(
// needed to prevent ownership going into the iterator loop
let records_ref = &mut lines;

for mut line in records_ref.take(max_read_lines.unwrap_or(usize::MAX)) {
for mut line in records_ref
.take(max_read_lines.unwrap_or(usize::MAX))
.skip(skip_rows_after_header)
{
rows_count += 1;

if let Some(c) = comment_char {
Expand Down Expand Up @@ -423,6 +428,7 @@ pub fn infer_file_schema(
has_header,
schema_overwrite,
skip_rows,
skip_rows_after_header,
comment_char,
quote_char,
eol_char,
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/frame/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ impl<'a> LazyCsvReader<'a> {
// we set it to None and modify them after the schema is updated
None,
&mut skip_rows,
self.skip_rows_after_header,
self.comment_char,
self.quote_char,
self.eol_char,
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ impl LogicalPlanBuilder {
has_header,
schema_overwrite,
&mut skip_rows,
skip_rows_after_header,
comment_char,
quote_char,
eol_char,
Expand Down
31 changes: 19 additions & 12 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,26 +494,33 @@ def test_csv_schema_offset(foods_csv: str) -> None:
"""\
metadata
line
foo,bar
1,2
3,4
5,6
col1,col2,col3
alpha,beta,gamma
1,2.0,"A"
3,4.0,"B"
5,6.0,"C"
"""
).encode()
df = pl.read_csv(csv, skip_rows=2)
assert df.columns == ["foo", "bar"]
assert df.shape == (3, 2)
df = pl.read_csv(csv, skip_rows=2, skip_rows_after_header=2)
assert df.columns == ["foo", "bar"]
assert df.shape == (1, 2)

df = pl.read_csv(csv, skip_rows=3)
assert df.columns == ["alpha", "beta", "gamma"]
assert df.shape == (3, 3)
assert df.dtypes == [pl.Int64, pl.Float64, pl.Utf8]

df = pl.read_csv(csv, skip_rows=2, skip_rows_after_header=1)
assert df.columns == ["col1", "col2", "col3"]
assert df.shape == (3, 3)
assert df.dtypes == [pl.Int64, pl.Float64, pl.Utf8]

df = pl.scan_csv(foods_csv, skip_rows=4).collect()
assert df.columns == ["fruit", "60", "0", "11"]
assert df.shape == (23, 4)
assert df.dtypes == [pl.Utf8, pl.Int64, pl.Float64, pl.Int64]

df = pl.scan_csv(foods_csv, skip_rows_after_header=10).collect()
df = pl.scan_csv(foods_csv, skip_rows_after_header=24).collect()
assert df.columns == ["category", "calories", "fats_g", "sugars_g"]
assert df.shape == (17, 4)
assert df.shape == (3, 4)
assert df.dtypes == [pl.Utf8, pl.Int64, pl.Int64, pl.Int64]


def test_empty_string_missing_round_trip() -> None:
Expand Down

0 comments on commit 0ac48cb

Please sign in to comment.