Skip to content

Commit

Permalink
ensure schema update of row_count arguments (#4280)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 6, 2022
1 parent 5487704 commit 20c066e
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 2 deletions.
9 changes: 8 additions & 1 deletion polars/polars-lazy/src/frame/ipc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,19 @@ impl LazyFrame {
n_rows: args.n_rows,
cache: args.cache,
with_columns: None,
row_count: args.row_count,
row_count: None,
rechunk: args.rechunk,
memmap: args.memmap,
};
let row_count = args.row_count;
let mut lf: LazyFrame = LogicalPlanBuilder::scan_ipc(path, options)?.build().into();
lf.opt_state.file_caching = true;

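// Slightly hacky: the row count is left out of the scan options (`row_count: None`)
// and applied through `with_row_count` instead, because that call also updates the schema.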
if let Some(row_count) = row_count {
lf = lf.with_row_count(&row_count.name, Some(row_count.offset))
}

Ok(lf)
}

Expand Down
8 changes: 7 additions & 1 deletion polars/polars-lazy/src/frame/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,16 @@ impl LazyFrame {
low_memory: bool,
) -> Result<Self> {
let mut lf: LazyFrame = LogicalPlanBuilder::scan_parquet(
path, n_rows, cache, parallel, row_count, rechunk, low_memory,
path, n_rows, cache, parallel, None, rechunk, low_memory,
)?
.build()
.into();

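// Slightly hacky: `None` is passed as the scan's row_count argument and the row count
// is applied through `with_row_count` instead, because that call also updates the schema.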
if let Some(row_count) = row_count {
lf = lf.with_row_count(&row_count.name, Some(row_count.offset))
}

lf.opt_state.file_caching = true;
Ok(lf)
}
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/io/test_lazy_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,9 @@ def test_is_in_type_coercion(foods_ipc: str) -> None:
.collect()
)
assert out.shape == (7, 1)


def test_row_count_schema(foods_ipc: str) -> None:
    """Check that the row-count column requested from `scan_ipc` can be selected.

    Scans the IPC fixture lazily with ``row_count_name="id"``, selects the
    ``id`` column together with ``category`` and verifies the collected dtypes.
    """
    lazy_frame = pl.scan_ipc(foods_ipc, row_count_name="id")
    collected = lazy_frame.select(["id", "category"]).collect()
    expected_dtypes = [pl.UInt32, pl.Utf8]
    assert collected.dtypes == expected_dtypes
8 changes: 8 additions & 0 deletions py-polars/tests/io/test_lazy_parquet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from os import path
from pathlib import Path

Expand Down Expand Up @@ -101,3 +102,10 @@ def test_parquet_stats(io_test_dir: str) -> None:
assert (
pl.scan_parquet(file).filter(4 < pl.col("a")).select(pl.col("a").sum())
).collect()[0, "a"] == 10.0


def test_row_count_schema(io_test_dir: str) -> None:
    """Check that the row-count column requested from `scan_parquet` can be selected.

    Scans ``small.parquet`` lazily with ``row_count_name="id"``, selects the
    ``id`` column together with ``b`` and verifies the collected dtypes.
    """
    parquet_file = os.path.join(io_test_dir, "..", "files", "small.parquet")
    lazy_frame = pl.scan_parquet(parquet_file, row_count_name="id")
    collected = lazy_frame.select(["id", "b"]).collect()
    expected_dtypes = [pl.UInt32, pl.Utf8]
    assert collected.dtypes == expected_dtypes

0 comments on commit 20c066e

Please sign in to comment.