Support dimension ID coordinates (ROWID++) #292

Draft · wants to merge 17 commits into master
2 changes: 2 additions & 0 deletions daskms/dask_ms.py
@@ -22,6 +22,7 @@ def xds_to_table(
table_keywords=None,
column_keywords=None,
table_proxy=False,
force_shapes=None,
):
"""
Generates a list of Datasets representing write operations from the
@@ -107,6 +108,7 @@ def xds_to_table(
table_keywords=table_keywords,
column_keywords=column_keywords,
table_proxy=table_proxy,
force_shapes=force_shapes,
)

# Unpack table proxy if it was requested
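
The new force_shapes argument is simply threaded through xds_to_table to the underlying write machinery. A minimal sketch of how it might be used when writing a new column from a selected dataset, where the full on-disk (chan, corr) extent can no longer be inferred from the data itself — the MS path, chunking and the (16, 4) shape are assumptions lifted from the test added in this PR, not a documented recipe:

import dask.array as da
from daskms import xds_from_ms, xds_to_table

# Read a dataset, select a channel/correlation subset and derive a new
# column from the selection.
ds = xds_from_ms("test.ms", columns=["DATA"], chunks={"row": 10000})[0]
ds = ds.isel({"chan": [3, 4, 5], "corr": [0, 3]})
ds = ds.assign({"NEW_DATA": (ds.DATA.dims, 100 * da.ones_like(ds.DATA.data))})

# The selection hides the full extent of the chan and corr dimensions, so
# force_shapes pins the shape of the new column when it is created.
writes = xds_to_table([ds], "test.ms", ["NEW_DATA"],
                      force_shapes={"NEW_DATA": (16, 4)})
da.compute(writes)
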
28 changes: 28 additions & 0 deletions daskms/ordering.py
@@ -5,6 +5,7 @@
from dask.array.core import normalize_chunks
from dask.highlevelgraph import HighLevelGraph
import numpy as np
from itertools import product, accumulate

from daskms.query import select_clause, groupby_clause, orderby_clause
from daskms.optimisation import cached_array
@@ -218,3 +219,30 @@ def group_row_ordering(group_order_taql, group_cols, index_cols, chunks):
)

return ordering_arrays


def multidim_locators(dim_runs):
blcs = []
trcs = []

# Convert to lists i.e. iterables we can use with itertools.
list_dim_runs = [dim_run.tolist() for dim_run in dim_runs]

for locs in product(*list_dim_runs):
blc, trc = (), ()

for start, step in locs:
blc += (start,)
trc += (start + step - 1,) # Inclusive index.
blcs.append(blc)
trcs.append(trc)

offsets = [[0] + list(accumulate([s for _, s in ldr])) for ldr in list_dim_runs]
ax_slices = [
[slice(offset[i], offset[i + 1]) for i in range(len(offset) - 1)]
for offset in offsets
]

slices = [sl for sl in product(*ax_slices)]

return blcs, trcs, slices
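
For reference, a small worked example of what multidim_locators produces (values computed by hand from the function above, not taken from the PR). Each entry of dim_runs is assumed to be an array of (start, length) runs along one non-row dimension; here the runs correspond to the chan=[3, 4, 5], corr=[0, 3] selection used in the new tests:

import numpy as np
from daskms.ordering import multidim_locators

# One run covering channels 3-5 and two single-element runs covering
# correlations 0 and 3.
dim_runs = [np.array([[3, 3]]), np.array([[0, 1], [3, 1]])]

blcs, trcs, slices = multidim_locators(dim_runs)

# Inclusive bottom-left/top-right corners of each contiguous block on disk.
assert blcs == [(3, 0), (3, 3)]
assert trcs == [(5, 0), (5, 3)]

# Matching slices into the (3, 2) selected array.
assert slices == [(slice(0, 3), slice(0, 1)), (slice(0, 3), slice(1, 2))]
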
48 changes: 29 additions & 19 deletions daskms/reads.py
@@ -216,12 +216,15 @@ def _dataset_variable_factory(

Returns
-------
dict
A dictionary looking like :code:`{column: (arrays, dims)}`.
dataset_vars: dict
A dictionary of data variables looking like :code:`{column: (arrays, dims)}`.
dataset_coords: dict
A dictionary of coordinates looking like :code:`{column: (arrays, dims)}`.
"""

sorted_rows, row_runs = orders
dataset_vars = {"ROWID": (("row",), sorted_rows)}
dataset_coords = {"ROWID": (("row",), sorted_rows)}
dataset_vars = {}

for column in select_cols:
try:
Expand All @@ -234,6 +237,18 @@ def _dataset_variable_factory(
continue

full_dims = ("row",) + meta.dims

for dim in full_dims:
if f"{dim.upper()}ID" not in dataset_coords:
dim_idx = meta.dims.index(dim)
dim_size = meta.shape[dim_idx]
dim_chunks = chunks.get(dim, -1)
dim_label = f"{dim.upper()}ID"
dataset_coords[dim_label] = (
(dim,),
da.arange(dim_size, dtype=np.int32, chunks=dim_chunks),
)

args = [row_runs, ("row",)]

# We only need to pass in dimension extent arrays if
@@ -280,7 +295,7 @@ def _dataset_variable_factory(
# Assign into variable and dimension dataset
dataset_vars[column] = (full_dims, dask_array)

return dataset_vars
return dataset_vars, dataset_coords


def _col_keyword_getter(table):
@@ -335,7 +350,7 @@ def _single_dataset(self, table_proxy, orders, exemplar_row=0):

table_schema = self._table_schema()
select_cols = set(self.select_cols or table_proxy.colnames().result())
variables = _dataset_variable_factory(
variables, coords = _dataset_variable_factory(
table_proxy,
table_schema,
select_cols,
@@ -345,12 +360,9 @@ def _single_dataset(self, table_proxy, orders, exemplar_row=0):
short_table_name,
)

try:
rowid = variables.pop("ROWID")
except KeyError:
# No coordinate case - this is an empty dict.
if not coords:
coords = None
else:
coords = {"ROWID": rowid}

attrs = {DASKMS_PARTITION_KEY: ()}

@@ -385,7 +397,7 @@ def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
array_suffix = f"[{gid_str}]-{short_table_name}"

# Create dataset variables
group_var_dims = _dataset_variable_factory(
group_var_dims, group_var_coords = _dataset_variable_factory(
table_proxy,
table_schema,
select_cols,
@@ -395,13 +407,9 @@ def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
array_suffix,
)

# Extract ROWID
try:
rowid = group_var_dims.pop("ROWID")
except KeyError:
coords = None
else:
coords = {"ROWID": rowid}
# No coordinate case - this is an empty dict.
if not group_var_coords:
group_var_coords = None

# Assign values for the dataset's grouping columns
# as attributes
@@ -414,7 +422,9 @@ def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
group_id = [gid.item() for gid in group_id]
attrs.update(zip(self.group_cols, group_id))

datasets.append(Dataset(group_var_dims, attrs=attrs, coords=coords))
datasets.append(
Dataset(group_var_dims, attrs=attrs, coords=group_var_coords)
)

return datasets

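
With _dataset_variable_factory now returning coordinates for every dimension it encounters (named f"{dim.upper()}ID" and built with da.arange), datasets read from a standard Measurement Set should expose CHANID and CORRID alongside ROWID. A rough sketch of inspecting them — the coordinate names follow from the diff together with the usual MS schema dims ("chan", "corr"), and the path and chunking are assumed:

from daskms import xds_from_ms

ds = xds_from_ms("test.ms", columns=["DATA"],
                 chunks={"row": 10000, "chan": 4, "corr": 4})[0]

# ROWID is still the row numbers of the underlying table.
print(ds.ROWID.data)

# New per-dimension ID coordinates: 0..n-1 along each non-row dimension,
# chunked consistently with the requested chunking.
print(ds.CHANID.data)
print(ds.CORRID.data)
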
63 changes: 63 additions & 0 deletions daskms/tests/test_dataset.py
@@ -509,6 +509,69 @@ def _vis_factory(chan, corr):
assert_array_almost_equal(layer[(name, 2, 0, 0)]["r2"], data["r10"])


@pytest.mark.parametrize(
"chunks,selection",
[({"row": (-1,), "chan": (4,), "corr": (4,)}, {"chan": [3, 4, 5], "corr": [0, 3]})],
)
def test_write_selection_update(ms, chunks, selection):
datasets = read_datasets(
ms, ["DATA"], ["DATA_DESC_ID", "FIELD_ID", "SCAN_NUMBER"], [], chunks=chunks
)

datasets = [ds.isel(selection) for ds in datasets]

datasets = [
ds.assign({"DATA": (ds.DATA.dims, 100 * da.ones_like(ds.DATA.data))})
for ds in datasets
]

writes = write_datasets(ms, datasets, ["DATA"])

da.compute(writes)

datasets = read_datasets(
ms, ["DATA"], ["DATA_DESC_ID", "FIELD_ID", "SCAN_NUMBER"], [], chunks=chunks
)

for ds in datasets:
assert_array_equal(ds.isel(selection).DATA.data, 100)


@pytest.mark.parametrize(
"chunks,selection",
[({"row": (-1,), "chan": (4,), "corr": (4,)}, {"chan": [3, 4, 5], "corr": [0, 3]})],
)
def test_write_selection_create(ms, chunks, selection):
datasets = read_datasets(
ms, ["DATA"], ["DATA_DESC_ID", "FIELD_ID", "SCAN_NUMBER"], [], chunks=chunks
)

datasets = [ds.isel(selection) for ds in datasets]

datasets = [
ds.assign({"NEW_DATA": (ds.DATA.dims, 100 * da.ones_like(ds.DATA.data))})
for ds in datasets
]

writes = write_datasets(
ms, datasets, ["NEW_DATA"], force_shapes={"NEW_DATA": (16, 4)}
)

da.compute(writes)

datasets = read_datasets(
ms,
["NEW_DATA"],
["DATA_DESC_ID", "FIELD_ID", "SCAN_NUMBER"],
[],
chunks=chunks,
table_schema=["MS", {"NEW_DATA": {"dims": ("chan", "corr")}}],
)

for ds in datasets:
assert_array_equal(ds.isel(selection).NEW_DATA.data, 100)


def test_dataset_computes_and_values(ms):
datasets = read_datasets(ms, [], [], [])
assert len(datasets) == 1