Skip to content

Commit

Permalink
fix windows test, add inline comment, further streamline
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Mar 3, 2024
1 parent 8cd5331 commit 21e4a27
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
14 changes: 7 additions & 7 deletions py-polars/polars/io/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

class _ArrowDriverProperties_(TypedDict):
# name of the method that fetches all arrow data; tuple form
# calls the fetch_all method with the give chunk size
# calls the fetch_all method with the given chunk size (int)
fetch_all: str | tuple[str, int]
# name of the method that fetches arrow data in batches
fetch_batches: str | None
Expand Down Expand Up @@ -70,6 +70,7 @@ class _ArrowDriverProperties_(TypedDict):
"repeat_batch_calls": False,
},
"kuzu": {
# 'get_as_arrow' currently takes a mandatory chunk size
"fetch_all": ("get_as_arrow", 10_000),
"fetch_batches": None,
"exact_batch_size": None,
Expand Down Expand Up @@ -180,12 +181,11 @@ def _arrow_batches(
"""Yield Arrow data in batches, or as a single 'fetchall' batch."""
fetch_batches = driver_properties["fetch_batches"]
if not iter_batches or fetch_batches is None:
fetch_method = driver_properties["fetch_all"]
if not isinstance(fetch_method, tuple):
yield getattr(self.result, fetch_method)()
else:
fetch_method, sz = fetch_method
yield getattr(self.result, fetch_method)(sz)
fetch_method, sz = driver_properties["fetch_all"], []
if isinstance(fetch_method, tuple):
fetch_method, chunk_size = fetch_method
sz = [chunk_size]
yield getattr(self.result, fetch_method)(*sz)
else:
size = batch_size if driver_properties["exact_batch_size"] else None
repeat_batch_calls = driver_properties["repeat_batch_calls"]
Expand Down
6 changes: 5 additions & 1 deletion py-polars/tests/unit/io/test_database_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,11 @@ def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None:
if (kuzu_test_db := (tmp_path / "kuzu_test.db")).exists():
kuzu_test_db.unlink()

db = kuzu.Database(str(kuzu_test_db))
test_db = str(kuzu_test_db)
if sys.platform == "win32":
test_db = test_db.replace("\\", "/")

db = kuzu.Database(test_db)
conn = kuzu.Connection(db)
conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))")
conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")
Expand Down

0 comments on commit 21e4a27

Please sign in to comment.