Skip to content

Commit

Permalink
respect dtype overwrite when schema is overwritten in lazy csv scanner (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 16, 2022
1 parent 43f7e39 commit 98a3d85
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
13 changes: 11 additions & 2 deletions polars/polars-lazy/src/frame/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,22 @@ impl<'a> LazyCsvReader<'a> {
self.delimiter,
self.infer_schema_length,
self.has_header,
self.schema_overwrite,
// we set it to None and modify them after the schema is updated
None,
&mut skip_rows,
self.comment_char,
self.quote_char,
None,
)?;
let schema = f(schema)?;
let mut schema = f(schema)?;

// the dtypes set may be for the new names, so update again
if let Some(overwrite_schema) = self.schema_overwrite {
for (name, dtype) in overwrite_schema.iter() {
schema.with_column(name.clone(), dtype.clone())
}
}

Ok(self.with_schema(Arc::new(schema)))
}

Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/io/test_lazy_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,16 @@ def test_row_count(foods_csv: str) -> None:
)

assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35]


def test_scan_csv_schema_overwrite_and_dtypes_overwrite(foods_csv: str) -> None:
assert (
pl.scan_csv(
foods_csv,
dtypes={"calories_foo": pl.Utf8, "fats_g_foo": pl.Float32},
with_column_names=lambda names: [f"{a}_foo" for a in names],
)
.collect()
.dtypes
== [pl.Utf8, pl.Utf8, pl.Float32, pl.Int64]
)

0 comments on commit 98a3d85

Please sign in to comment.