Skip to content

Commit

Permalink
fix(rust, python): take glob into account in scan_csv 'with_schema_mo… (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 30, 2022
1 parent 27240d6 commit 02463f6
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
22 changes: 21 additions & 1 deletion polars/polars-lazy/src/frame/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,27 @@ impl<'a> LazyCsvReader<'a> {
where
F: Fn(Schema) -> PolarsResult<Schema>,
{
let mut file = std::fs::File::open(&self.path)?;
let path;
let path_str = self.path.to_string_lossy();

let mut file = if path_str.contains('*') {
let glob_err = || PolarsError::ComputeError("invalid glob pattern given".into());
let mut paths = glob::glob(&path_str).map_err(|_| glob_err())?;

match paths.next() {
Some(globresult) => {
path = globresult.map_err(|_| glob_err())?;
}
None => {
return Err(PolarsError::ComputeError(
"globbing pattern did not match any files".into(),
));
}
}
std::fs::File::open(&path)
} else {
std::fs::File::open(&self.path)
}?;
let reader_bytes = get_reader_bytes(&mut file).expect("could not mmap file");
let mut skip_rows = self.skip_rows;

Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
"foods1.csv",
)


# Glob pattern for the example CSV files named "foods*.csv" inside EXAMPLES_DIR.
FOODS_CSV_GLOB = os.path.join(
EXAMPLES_DIR,
"foods*.csv",
)

FOODS_PARQUET = os.path.join(
EXAMPLES_DIR,
"foods1.parquet",
Expand Down Expand Up @@ -55,6 +61,11 @@ def foods_csv() -> str:
return FOODS_CSV


@pytest.fixture
def foods_csv_glob() -> str:
    """Return the glob pattern (not a single file path) for the foods CSV examples."""
    # Must be the glob constant: returning FOODS_CSV here would make tests that
    # take this fixture silently re-test the single-file path instead.
    return FOODS_CSV_GLOB


if not os.path.isfile(FOODS_PARQUET):
pl.read_csv(FOODS_CSV).write_parquet(FOODS_PARQUET)

Expand Down
22 changes: 14 additions & 8 deletions py-polars/tests/unit/io/test_lazy_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,22 @@ def test_row_count(foods_csv: str) -> None:
assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35]


def test_scan_csv_schema_overwrite_and_dtypes_overwrite(foods_csv: str) -> None:
assert (
pl.scan_csv(
foods_csv,
def test_scan_csv_schema_overwrite_and_dtypes_overwrite(
foods_csv: str, foods_csv_glob: str
) -> None:
for fn in [foods_csv, foods_csv_glob]:
df = pl.scan_csv(
fn,
dtypes={"calories_foo": pl.Utf8, "fats_g_foo": pl.Float32},
with_column_names=lambda names: [f"{a}_foo" for a in names],
)
.collect()
.dtypes
) == [pl.Utf8, pl.Utf8, pl.Float32, pl.Int64]
).collect()
assert df.dtypes == [pl.Utf8, pl.Utf8, pl.Float32, pl.Int64]
assert df.columns == [
"category_foo",
"calories_foo",
"fats_g_foo",
"sugars_g_foo",
]


def test_lazy_n_rows(foods_csv: str) -> None:
Expand Down

0 comments on commit 02463f6

Please sign in to comment.