Skip to content

Commit

Permalink
fix(rust, python): set string cache if lazy schema contains categoric…
Browse files Browse the repository at this point in the history
…al (#5225)
  • Loading branch information
ritchie46 committed Oct 16, 2022
1 parent 3f1faeb commit c2b0cdb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
14 changes: 14 additions & 0 deletions polars/polars-io/src/csv/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,20 @@ where
let mut csv_reader = self.core_reader(Some(&schema), to_cast)?;
csv_reader.as_df()?
} else {
#[cfg(feature = "dtype-categorical")]
{
let has_cat = self
.schema
.map(|schema| {
schema
.iter_dtypes()
.any(|dtype| matches!(dtype, DataType::Categorical(_)))
})
.unwrap_or(false);
if has_cat {
_cat_lock = Some(polars_core::IUseStringCache::new())
}
}
let mut csv_reader = self.core_reader(self.schema, vec![])?;
csv_reader.as_df()?
};
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/slow/test_csv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import os

import polars as pl

Expand All @@ -9,3 +10,13 @@ def test_csv_statistics_offset() -> None:
# the lines at the end have larger rows as the numbers increase
csv = "\n".join(str(x) for x in range(5_000))
assert pl.read_csv(io.StringIO(csv), n_rows=5000).height == 4999


def test_csv_scan_categorical() -> None:
    """Lazily scanning a CSV with a Categorical dtype override must succeed.

    Regression test for the lazy CSV reader: a column requested as
    ``pl.Categorical`` through ``dtypes`` should come back as Categorical
    after ``collect()``.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    n_rows = 5_000
    # A private temporary directory instead of a fixed "/tmp/..." path keeps the
    # test portable (no POSIX-only location, so no platform guard is needed),
    # avoids collisions between concurrent runs, and removes the file afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, "test_csv_scan_categorical.csv")
        pl.DataFrame({"x": ["A"] * n_rows}).write_csv(csv_path)
        # Collect inside the context so the file still exists while it is read.
        df = pl.scan_csv(csv_path, dtypes={"x": pl.Categorical}).collect()
    assert df["x"].dtype == pl.Categorical

0 comments on commit c2b0cdb

Please sign in to comment.