Skip to content

Commit

Permalink
fix(python): Check if BatchedCsvReader.next_batches() is None befor… (
Browse files Browse the repository at this point in the history
  • Loading branch information
ghuls committed Oct 19, 2022
1 parent f8d60ee commit 121d16e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
15 changes: 9 additions & 6 deletions py-polars/polars/internals/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,12 @@ def next_batches(self, n: int) -> list[pli.DataFrame] | None:
"""
batches = self._reader.next_batches(n)
if self.new_columns is not None and batches is not None:
return [
pli._update_columns(pli.wrap_df(df), self.new_columns) for df in batches
]
else:
return [pli.wrap_df(df) for df in batches]
if batches is not None:
if self.new_columns:
return [
pli._update_columns(pli.wrap_df(df), self.new_columns)
for df in batches
]
else:
return [pli.wrap_df(df) for df in batches]
return None
21 changes: 20 additions & 1 deletion py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
import polars as pl
from polars import DataType
from polars.internals.type_aliases import TimeUnit
from polars.testing import assert_frame_equal_local_categoricals, assert_series_equal
from polars.testing import (
assert_frame_equal,
assert_frame_equal_local_categoricals,
assert_series_equal,
)


def test_quoted_date() -> None:
Expand Down Expand Up @@ -874,6 +878,21 @@ def test_batched_csv_reader(foods_csv: str) -> None:
}


def test_batched_csv_reader_all_batches(foods_csv: str) -> None:
    """Check that draining the batched CSV reader matches a plain ``read_csv``.

    The comparison is run twice: once with the file's own header
    (``new_columns=None``) and once with an explicit list of replacement
    column names, so that both code paths of ``next_batches`` are exercised.

    Parameters
    ----------
    foods_csv
        Path to the CSV file to read, supplied by the test fixture.
    """
    renamed_columns = ["Category", "Calories", "Fats_g", "Augars_g"]

    for new_columns in (None, renamed_columns):
        # Reference result: the whole file read in one go.
        expected = pl.read_csv(foods_csv, new_columns=new_columns)
        reader = pl.read_csv_batched(
            foods_csv, new_columns=new_columns, batch_size=4
        )

        # Keep requesting batches until the reader hands back a falsy result
        # (``None`` or an empty list), which ends the loop.
        collected = []
        while True:
            chunk = reader.next_batches(5)
            if not chunk:
                break
            collected.extend(chunk)

        combined = pl.concat(collected, rechunk=True)
        assert_frame_equal(expected, combined)


def test_csv_single_categorical_null() -> None:
f = io.BytesIO()
pl.DataFrame(
Expand Down

0 comments on commit 121d16e

Please sign in to comment.