Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support dimension ID coordinates (ROWID++) #292

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions daskms/ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dask.array.core import normalize_chunks
from dask.highlevelgraph import HighLevelGraph
import numpy as np
from itertools import product

from daskms.query import select_clause, groupby_clause, orderby_clause
from daskms.optimisation import cached_array
Expand Down Expand Up @@ -218,3 +219,34 @@ def group_row_ordering(group_order_taql, group_cols, index_cols, chunks):
)

return ordering_arrays


def multidim_locators(dim_runs):

blcs = []
trcs = []

# Convert to lists i.e. iterables we can use with itertools.
list_dim_runs = [dim_run.tolist() for dim_run in dim_runs]

for locs in product(*list_dim_runs):
blc, trc = (), ()

for start, step in locs:
blc += (start,)
trc += (start + step - 1,) # Inclusive index.
blcs.append(blc)
trcs.append(trc)

offsets = [
np.cumsum([0, *[s for _, s in ldr]]).tolist() for ldr in list_dim_runs
]
JSKenyon marked this conversation as resolved.
Show resolved Hide resolved
ax_slices = [
[
slice(offset[i], offset[i + 1]) for i in range(len(offset) - 1)
] for offset in offsets
]

slices = [sl for sl in product(*ax_slices)]

return blcs, trcs, slices
42 changes: 25 additions & 17 deletions daskms/reads.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ def _dataset_variable_factory(
"""
JSKenyon marked this conversation as resolved.
Show resolved Hide resolved

sorted_rows, row_runs = orders
dataset_vars = {"ROWID": (("row",), sorted_rows)}
dataset_coords = {"ROWID": (("row",), sorted_rows)}
dataset_vars = {}

for column in select_cols:
try:
Expand All @@ -234,6 +235,18 @@ def _dataset_variable_factory(
continue

full_dims = ("row",) + meta.dims

for dim in full_dims:
if f"{dim.upper()}ID" not in dataset_coords:
dim_idx = meta.dims.index(dim)
dim_size = meta.shape[dim_idx]
dim_chunks = chunks.get(dim, -1)
dim_label = f"{dim.upper()}ID"
dataset_coords[dim_label] = (
(dim,),
da.arange(dim_size, dtype=np.int32, chunks=dim_chunks),
)

args = [row_runs, ("row",)]

# We only need to pass in dimension extent arrays if
Expand Down Expand Up @@ -280,7 +293,7 @@ def _dataset_variable_factory(
# Assign into variable and dimension dataset
dataset_vars[column] = (full_dims, dask_array)

return dataset_vars
return dataset_vars, dataset_coords


def _col_keyword_getter(table):
Expand Down Expand Up @@ -335,7 +348,7 @@ def _single_dataset(self, table_proxy, orders, exemplar_row=0):

table_schema = self._table_schema()
select_cols = set(self.select_cols or table_proxy.colnames().result())
variables = _dataset_variable_factory(
variables, coords = _dataset_variable_factory(
table_proxy,
table_schema,
select_cols,
Expand All @@ -345,12 +358,9 @@ def _single_dataset(self, table_proxy, orders, exemplar_row=0):
short_table_name,
)

try:
rowid = variables.pop("ROWID")
except KeyError:
# No coordinate case - this is an empty dict.
if not coords:
coords = None
else:
coords = {"ROWID": rowid}

attrs = {DASKMS_PARTITION_KEY: ()}

Expand Down Expand Up @@ -385,7 +395,7 @@ def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
array_suffix = f"[{gid_str}]-{short_table_name}"

# Create dataset variables
group_var_dims = _dataset_variable_factory(
group_var_dims, group_var_coords = _dataset_variable_factory(
table_proxy,
table_schema,
select_cols,
Expand All @@ -395,13 +405,9 @@ def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
array_suffix,
)

# Extract ROWID
try:
rowid = group_var_dims.pop("ROWID")
except KeyError:
coords = None
else:
coords = {"ROWID": rowid}
# No coordinate case - this is an empty dict.
if not group_var_coords:
group_var_coords = None

# Assign values for the dataset's grouping columns
# as attributes
Expand All @@ -414,7 +420,9 @@ def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
group_id = [gid.item() for gid in group_id]
attrs.update(zip(self.group_cols, group_id))

datasets.append(Dataset(group_var_dims, attrs=attrs, coords=coords))
datasets.append(
Dataset(group_var_dims, attrs=attrs, coords=group_var_coords)
)

return datasets

Expand Down
57 changes: 36 additions & 21 deletions daskms/writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
import numpy as np
import pyrap.tables as pt

from daskms.columns import dim_extents_array
from daskms.constants import DASKMS_PARTITION_KEY
from daskms.dataset import Dataset
from daskms.dataset_schema import DatasetSchema
from daskms.descriptors.builder import AbstractDescriptorBuilder
from daskms.descriptors.builder_factory import filename_builder_factory
from daskms.descriptors.builder_factory import string_builder_factory
from daskms.ordering import row_run_factory
from daskms.ordering import row_run_factory, multidim_locators
from daskms.table import table_exists
from daskms.table_executor import executor_key
from daskms.table_proxy import TableProxy, WRITELOCK
Expand Down Expand Up @@ -68,17 +67,23 @@ def multidim_str_putcol(row_runs, table_future, column, data):
table.unlock()


def ndarray_putcolslice(row_runs, blc, trc, table_future, column, data):
def ndarray_putcolslice(row_runs, dim_runs, table_future, column, data):
"""Put data into the table"""
table = table_future.result()
putcolslice = table.putcolslice
rr = 0

blcs, trcs, slices = multidim_locators(dim_runs)

table.lock(write=True)

try:
for rs, rl in row_runs:
putcolslice(column, data[rr : rr + rl], blc, trc, startrow=rs, nrow=rl)
row_slice = slice(rr, rr + rl)
for blc, trc, sl in zip(blcs, trcs, slices):
putcolslice(
column, data[(row_slice, *sl)], blc, trc, startrow=rs, nrow=rl
)
rr += rl

table.flush()
Expand All @@ -87,20 +92,24 @@ def ndarray_putcolslice(row_runs, blc, trc, table_future, column, data):
table.unlock()


def multidim_str_putcolslice(row_runs, blc, trc, table_future, column, data):
def multidim_str_putcolslice(row_runs, dim_runs, table_future, column, data):
"""Put multidimensional string data into the table"""
table = table_future.result()
putcol = table.putcol
putcol = table.putcolslice
rr = 0

blcs, trcs, slices = multidim_locators(dim_runs)

table.lock(write=True)

try:
for rs, rl in row_runs:
# Construct a dict with the shape and a flattened list
chunk = data[rr : rr + rl]
chunk = {"shape": chunk.shape, "array": chunk.ravel().tolist()}
putcol(column, chunk, blc, trc, startrow=rs, nrow=rl)
row_slice = slice(rr, rr + rl)
for blc, trc, sl in zip(blcs, trcs, slices):
# Construct a dict with the shape and a flattened list
chunk = data[(row_slice, *sl)]
chunk = {"shape": chunk.shape, "array": chunk.ravel().tolist()}
putcol(column, chunk, blc, trc, startrow=rs, nrow=rl)
rr += rl

table.flush()
Expand Down Expand Up @@ -204,16 +213,17 @@ def putter_wrapper(row_orders, *args):

# There are other dimensions beside row
if nextent_args > 0:
blc, trc = zip(*args[:nextent_args])
dim_runs = [dim_run for dim_run, _ in args[:nextent_args]]
fn = (
multidim_str_putcolslice
if multidim_str
else multidim_dict_putvarcol
if dict_data
else ndarray_putcolslice
)
# TODO: Signature of fn is not consistent across all three calls.
table_proxy._ex.submit(
fn, row_runs, blc, trc, table_proxy._table_future, column, data
fn, row_runs, dim_runs, table_proxy._table_future, column, data
).result()
else:
fn = (
Expand Down Expand Up @@ -520,7 +530,6 @@ def cached_row_order(rowid):
if not layer_name.startswith("row-") and not layer_name.startswith(
"group-rows-"
):

log.warning(
"Unusual ROWID layer %s. "
"This is probably OK but "
Expand All @@ -537,7 +546,6 @@ def cached_row_order(rowid):
layer_names[0].startswith("group-rows-")
and layer_names[1].startswith("rechunk-merge-")
):

log.warning(
"Unusual ROWID layers %s for "
"the group ordering case. "
Expand Down Expand Up @@ -661,13 +669,20 @@ def _write_datasets(
f"ROWID shape and/or chunking does not match that of {column}"
)

if not all(len(c) == 1 for c in array.chunks[1:]):
# Add extent arrays
for d, c in zip(full_dims[1:], array.chunks[1:]):
extent_array = dim_extents_array(d, c)
args.append(extent_array)
inlinable_arrays.append(extent_array)
args.append((d,))
for dim in full_dims[1:]:
dim_idx = full_dims.index(dim)
dim_size = variable.shape[dim_idx]
dim_chunks = variable.chunks[dim_idx]
dim_label = f"{dim.upper()}ID"
dim_array = variable.coords.get(
dim_label, da.arange(dim_size, dtype=np.int32, chunks=dim_chunks)
)
dim_runs = getattr(dim_array, 'data', dim_array).map_blocks(
row_run_factory, sort=False, dtype=object
)
args.append(dim_runs)
args.append((dim,))
# inlinable_arrays.append(dim_runs) # TODO: Seems to break things.

# Add other variables
args.extend([table_proxy, None, column, None, array, full_dims])
Expand Down
Loading