Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,21 @@ import xarray as xr
import xarray_sql as xql


# Open a year of ARCO-ERA5 — all 273 variables. Selecting a year up front
# keeps Dask's partition setup cheap before any chunks are read from GCS.
ds = (
xr.open_zarr('gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3',
chunks=dict(time=1),
storage_options={'token': 'anon'}) # Anonymous read from the public GCS bucket — no auth required.
.sel(time='2020')
# Open ARCO-ERA5 — a weather dataset with 273 variables since 1940.
# Turning off dask means we don't have to wait to construct a task graph.
ds = xr.open_zarr(
'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3',
chunks=None, # Turn dask off
storage_options={'token': 'anon'} # Anonymous read from the public GCS bucket — no auth required.
)

ctx = xql.XarrayContext()
ctx.from_dataset('era5', ds, table_names={
# Pass `chunks` here: dask is off (chunks=None above), so this is what groups the data into partitions.
ctx.from_dataset('era5', ds, chunks=dict(time=6), table_names={
('time', 'latitude', 'longitude'): 'surface',
('time', 'level', 'latitude', 'longitude'): 'atmosphere',
})
# Registration: ~0.5s for a full year of hourly ERA5, all variables.

# Registration takes ~10s on my machine.

# Heads up: ARCO-ERA5 has 262 surface + 11 atmospheric variables. The library
# pushes column projection down to Zarr, so SELECT only fetches what you ask
Expand Down Expand Up @@ -81,13 +80,26 @@ result = ctx.sql('''
# | 775 | -2.3064649711534457 |
# +-------+----------------------+

avg_temp_ds = result.to_dataset(dims=["level"])
# <xarray.Dataset> Size: 592B
# Dimensions: (level: 37)
ctx.sql('''
SELECT latitude, longitude, AVG("2m_temperature") - 273.15 AS avg_c
FROM era5.surface
WHERE time BETWEEN TIMESTAMP '2020-01-01'
AND TIMESTAMP '2020-01-01 05:00:00'
GROUP BY latitude, longitude
ORDER BY latitude DESC, longitude
''').to_dataset(dims=['latitude', 'longitude'], template=ds)
# <xarray.Dataset> Size: 8MB
# Dimensions: (latitude: 721, longitude: 1440)
# Coordinates:
# * level (level) int64 296B 1000 975 950 925 900 875 850 ... 20 10 7 5 3 2 1
# * latitude (latitude) float32 3kB 90.0 89.75 89.5 ... -89.5 -89.75 -90.0
# * longitude (longitude) float32 6kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
# Data variables:
# avg_c (level) float64 296B 6.621 5.186 4.028 ... -21.51 -13.36 -9.021
# avg_c (latitude, longitude) float64 8MB -26.84 -26.84 ... -27.38 -27.38
# Attributes:
# last_updated: 2026-06-20 02:33:34.265980+00:00
# valid_time_start: 1940-01-01
# valid_time_stop: 2025-12-31
# valid_time_stop_era5t: 2026-06-14
```

_(A runnable version of this example lives at
Expand Down
63 changes: 43 additions & 20 deletions perf_tests/era5_temp_profile.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
#!/usr/bin/env python3
"""Surface and global-atmospheric temperatures on 2020-01-01, in SQL.

Two queries against ARCO-ERA5 on the morning of January 1, 2020:
Three queries against ARCO-ERA5 on the morning of January 1, 2020:

* **Surface (local).** Average 2m-temperature over a small grid covering the
New York City area for the first six hours.
* **Atmosphere (global).** Average temperature per pressure level, computed
over the entire planet for the same six hours — a classic atmospheric
temperature profile (surface around 1000 hPa is warmest, tropopause near
100 hPa is coldest).
* **Surface (global, gridded).** Average 2m-temperature per (lat, lon) cell
for the same six hours, returned as an xarray Dataset.

Both queries express their filters entirely in SQL: ``xr.open_zarr`` is given
a single calendar year and no spatial slicing. The library's table provider
prunes time partitions for ``WHERE time …`` filters, and pushes ``WHERE
All filters live in SQL: the dataset is opened with no time or spatial
slicing on the xarray side. The library's table provider prunes time
partitions for ``WHERE time …`` filters, and pushes ``WHERE
latitude/longitude …`` down to dimension columns.

ARCO-ERA5's atmospheric variables are stored in native Zarr chunks of shape
``(1, 37, 721, 1440)`` — about 150 MB per hour. We align Dask chunks to that
shape with ``chunks={'time': 1}`` so chunks fetch from GCS concurrently. The
global atmospheric query scans ~230M rows after pruning.
``(1, 37, 721, 1440)`` — about 150 MB per hour. ``chunks=dict(time=6)`` groups
six native chunks per DataFusion partition: large enough to keep partition
count (and registration time) low, small enough that a 6-hour WHERE clause
hits a single partition with no wasted I/O.

The Zarr is read anonymously from the public GCS bucket — no auth required.
"""
Expand All @@ -34,27 +37,29 @@


def main() -> None:
full = xr.open_zarr(URL, chunks=None, storage_options={"token": "anon"})

# Open a full calendar year — all 273 variables. No spatial slicing on
# the xarray side; SQL WHERE clauses below express the filters.
#
# Heads up: the library pushes column projection down to Zarr, so SELECT
# only fetches what you ask for — but `SELECT * FROM era5.surface` would
# try to read every variable across the year (terabytes from GCS).
# Always SELECT specific columns.
ds = full.sel(time="2020").chunk({"time": 1})
# Open the full ARCO-ERA5 archive — all 273 variables since 1940. No
# time or spatial slicing on the xarray side; SQL WHERE clauses below
# express the filters. Turning dask off (chunks=None) skips task-graph
# construction at open time.
ds = xr.open_zarr(URL, chunks=None, storage_options={"token": "anon"})
print(
"ARCO-ERA5 opened: year 2020, "
"ARCO-ERA5 opened: "
f"{ds.sizes['time']:,} hourly time steps, "
f"{len(ds.data_vars)} variables (no spatial pre-slicing)."
f"{len(ds.data_vars)} variables (no pre-slicing)."
)

# Heads up: ARCO-ERA5 has 262 surface + 11 atmospheric variables. The
# library pushes column projection down to Zarr, so SELECT only fetches
# what you ask for — but `SELECT * FROM era5.surface` would try to pull
# every variable across the archive (terabytes from GCS).
# ---> Always SELECT specific columns. <---
ctx = xql.XarrayContext()
t0 = time.perf_counter()
# Pass `chunks` here: dask is off (chunks=None above), so this is what
# groups native Zarr chunks into DataFusion partitions.
ctx.from_dataset(
"era5",
ds,
chunks=dict(time=6),
table_names={
("time", "latitude", "longitude"): "surface",
("time", "level", "latitude", "longitude"): "atmosphere",
Expand Down Expand Up @@ -94,7 +99,25 @@ def main() -> None:
"""
).to_pandas()
print(profile.to_string(index=False))
print(f" ({time.perf_counter() - t0:.2f}s, ~230M rows scanned)")
print(f" ({time.perf_counter() - t0:.2f}s)")

print(
"\nAverage 2m-temperature per (lat, lon) cell, globally, "
"2020-01-01 00:00-05:00 UTC (°C):"
)
t0 = time.perf_counter()
gridded = ctx.sql(
"""
SELECT latitude, longitude, AVG("2m_temperature") - 273.15 AS avg_c
FROM era5.surface
WHERE time BETWEEN TIMESTAMP '2020-01-01'
AND TIMESTAMP '2020-01-01 05:00:00'
GROUP BY latitude, longitude
ORDER BY latitude DESC, longitude
"""
).to_dataset(dims=["latitude", "longitude"], template=ds)
print(gridded)
print(f" ({time.perf_counter() - t0:.2f}s)")


if __name__ == "__main__":
Expand Down
59 changes: 59 additions & 0 deletions tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DEFAULT_BATCH_SIZE,
_parse_schema,
block_slices,
compute_chunks,
dataset_to_record_batch,
explode,
from_map,
Expand Down Expand Up @@ -473,3 +474,61 @@ def test_read_xarray_table_memory_bounds(large_ds):
)
finally:
tracemalloc.stop()


# ---------------------------------------------------------------------------
# compute_chunks: arithmetic replacement for ds.chunk(...).chunks.
# Dask serves as the source of truth.
# ---------------------------------------------------------------------------


def _dask_chunks(ds: xr.Dataset, chunks: dict) -> dict:
    """Return the chunk layout dask produces when ``ds`` is rechunked.

    A shallow copy of ``ds`` is rechunked with ``chunks`` and its ``.chunks``
    mapping is converted to ``{str(dim): tuple(sizes)}`` so it can be compared
    directly with the normalised output of ``compute_chunks``.
    """
    rechunked = ds.copy(data=None, deep=False).chunk(chunks)
    return {str(k): tuple(v) for k, v in rechunked.chunks.items()}


def _normalise(result: dict) -> dict:
return {str(k): tuple(v) for k, v in result.items()}


def _simple_ds(shape: tuple[int, ...], dims: tuple[str, ...]) -> xr.Dataset:
    """Build a small in-memory dataset for chunking tests.

    The dataset has a single all-zeros variable ``"v"`` of the given
    ``shape`` over ``dims``, and an integer ``arange`` coordinate for each
    dimension.
    """
    return xr.Dataset(
        {"v": (dims, np.zeros(shape))},
        coords={d: np.arange(s) for d, s in zip(dims, shape)},
    )


@pytest.mark.parametrize(
    "ds,chunks",
    [
        # Even divide on a single dim.
        (_simple_ds((10,), ("x",)), {"x": 5}),
        # Uneven divide: trailing remainder chunk.
        (_simple_ds((10,), ("x",)), {"x": 3}),
        # Requested chunk size larger than the dim → single chunk.
        (_simple_ds((5,), ("x",)), {"x": 100}),
        # Multi-dim spec with a dim left unspecified (kept as one chunk).
        (_simple_ds((4, 6), ("x", "y")), {"x": 2}),
        # Multi-dim spec rechunking every dim.
        (_simple_ds((7, 11, 13), ("a", "b", "c")), {"a": 3, "b": 4, "c": 5}),
    ],
)
def test_compute_chunks_matches_dask(ds, chunks):
    """``compute_chunks`` must agree with dask's own rechunking result."""
    assert _normalise(compute_chunks(ds, chunks)) == _dask_chunks(ds, chunks)


def test_compute_chunks_preserves_existing_dask_chunking():
    """Rechunking one dim of a dask-backed dataset matches dask's result."""
    # When the dataset is already dask-backed, rechunking one dim must
    # leave other dims' existing chunk tuples alone.
    ds = _simple_ds((4, 5), ("x", "y")).chunk({"x": 1, "y": 2})
    chunks = {"x": 2}
    assert _normalise(compute_chunks(ds, chunks)) == _dask_chunks(ds, chunks)


def test_compute_chunks_tuples_sum_to_dim_size():
    """Every per-dim chunk tuple must sum to that dimension's size."""
    # Dask-independent invariant: every per-dim chunk tuple must fully
    # cover its dimension.
    ds = _simple_ds((7, 11, 13), ("a", "b", "c"))
    result = compute_chunks(ds, {"a": 3, "b": 4, "c": 5})
    # Guard against a vacuous pass: the loop below checks nothing if the
    # result is empty or silently omits a dimension.
    assert {str(dim) for dim in result} == {"a", "b", "c"}
    for dim, tup in result.items():
        assert sum(tup) == ds.sizes[dim], (
            f"chunks for {dim!r} sum to {sum(tup)}, "
            f"expected {ds.sizes[dim]}"
        )
38 changes: 38 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,44 @@ def test_lat_filter_prunes_partitions(self):
f"Expected exactly 2 partitions for lat < 0, got {tracker.iteration_count}"
)

def test_unchunked_dim_filter_still_prunes(self):
    """Filters on an *unchunked* dim still prune via static bounds.

    ``read_xarray_table`` precomputes bounds for unchunked dims once
    rather than re-scanning their full coord array on every partition.
    Regression guard: if the static-range merge ever stops attaching
    those bounds to each partition, the Rust pruner falls back to
    "never prune" for the unchunked dim and reads every partition.
    """
    np.random.seed(42)
    # Named ``time_index`` so the local cannot shadow a ``time`` module.
    time_index = pd.date_range("2020-01-01", periods=100, freq="D")
    lat = np.linspace(-90, 90, 50)  # unchunked
    data = np.random.rand(100, 50).astype(np.float32)

    ds = xr.Dataset(
        {"temperature": (["time", "lat"], data)},
        coords={"time": time_index, "lat": lat},
    )

    tracker = IterationTracker()
    # Chunk only on time → lat is "static" (one chunk spanning the axis).
    # Every partition still spans lat -90 to +90.
    table = read_xarray_table(
        ds, chunks={"time": 25}, _iteration_callback=tracker
    )

    ctx = SessionContext()
    ctx.register_table("test", table)

    # WHERE lat > 100 matches no rows. With static bounds (lat ∈ [-90, 90])
    # attached to every partition, the pruner drops *all* partitions and
    # the table is never iterated. Without them, all 4 are read.
    ctx.sql("SELECT COUNT(*) FROM test WHERE lat > 100").collect()
    assert tracker.iteration_count == 0, (
        "Static lat bounds should let the pruner skip every partition; "
        f"got {tracker.iteration_count} partitions read."
    )

def test_no_pruning_for_data_column_filters(self, time_chunked_ds):
"""Filters on data columns (not dimensions) should not prune."""
tracker = IterationTracker()
Expand Down
Loading
Loading