Support table keywords #58

Merged
merged 10 commits on Sep 3, 2019
1 change: 1 addition & 0 deletions HISTORY.rst
@@ -5,6 +5,7 @@ History
0.2.0 (YYYY-MM-DD)
------------------

* Support table and column keywords (:pr:`58`)
* Support concurrent access of multiple independent tables (:pr:`57`)
* Fix WEIGHT_SPECTRUM schema dimensions (:pr:`56`)
* Pin python-casacore to 3.0.0 (:pr:`54`)
42 changes: 40 additions & 2 deletions daskms/dask_ms.py
@@ -23,7 +23,8 @@
log = logging.getLogger(__name__)


def xds_to_table(xds, table_name, columns, descriptor=None):
def xds_to_table(xds, table_name, columns, descriptor=None,
table_keywords=None, column_keywords=None):
"""
Generates a dask array representing a series of writes from the
specified arrays in :class:`xarray.Dataset`'s into
@@ -37,8 +38,10 @@ def xds_to_table(xds, table_name, columns, descriptor=None):
dataset(s) containing the specified columns. If a list of datasets
is provided, the concatenation of the columns in
sequential datasets will be written.

table_name : str
CASA table path

columns : tuple or list or "ALL"
list of column names to write to the table.

@@ -56,6 +59,15 @@ def xds_to_table(xds, table_name, columns, descriptor=None):

If None, defaults are used.

table_keywords : dict, optional
Dictionary of table keywords to add to existing keywords.
The operation is performed immediately, not lazily.

column_keywords : dict, optional
Dictionary of :code:`{column: keywords}` to add to existing
column keywords.
The operation is performed immediately, not lazily.

Returns
-------
writes : :class:`dask.array.Array`
@@ -97,7 +109,9 @@ def xds_to_table(xds, table_name, columns, descriptor=None):

# Write the datasets
return write_datasets(table_name, datasets, columns,
descriptor=descriptor)
descriptor=descriptor,
table_keywords=table_keywords,
column_keywords=column_keywords)
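As a usage sketch of the two new arguments (the table path is illustrative; the keyword values mirror the tests added in this PR):

import dask
from daskms import xds_from_ms, xds_to_table

# Open an existing (illustrative) Measurement Set
datasets = xds_from_ms("test.ms")

# Keyword updates are applied immediately; computing the returned
# graph flushes any column writes
writes = xds_to_table(datasets, "test.ms", [],
                      table_keywords={'bob': 'qux'},
                      column_keywords={'STATE_ID': {'bob': 'qux'}})
dask.compute(writes)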


def xds_from_table(table_name, columns=None,
@@ -197,6 +211,14 @@ def xds_from_table(table_name, columns=None,

["MS", {"UVW": {'dims': ('my-uvw',)}}]

table_keywords : {False, True}, optional
If True, also returns table keywords.
Changes the return type of the function to a tuple.

column_keywords : {False, True}, optional
If True, also returns keywords for each column in the table.
Changes the return type of the function to a tuple.

taql_where : str, optional
TAQL where clause. For example, to exclude auto-correlations

@@ -219,6 +241,11 @@
-------
datasets : list of :class:`xarray.Dataset`
datasets for each group, each ordered by indexing columns
table_keywords : dict, optional
Returned if ``table_keywords==True``
column_keywords : dict, optional
Returned if ``column_keywords==True``

"""
columns = promote_columns(columns, [])
index_cols = promote_columns(index_cols, [])
@@ -234,6 +261,14 @@

xarray_datasets = []

# Extract the dataset list when table and/or column keywords are also returned
if isinstance(dask_datasets, tuple):
extra = dask_datasets[1:]
dask_datasets = dask_datasets[0]
else:
extra = ()

# Convert each dask dataset into an xarray dataset
for ds in dask_datasets:
data_vars = collections.OrderedDict()
coords = collections.OrderedDict()
@@ -248,6 +283,9 @@
attrs=dict(ds.attrs),
coords=coords))

if len(extra) > 0:
return (xarray_datasets,) + extra

return xarray_datasets
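
On the read side, a sketch of the tuple return convention (the table path is illustrative):

from daskms import xds_from_table

# With both flags set, a 3-tuple is returned:
# (datasets, table keywords, column keywords)
datasets, table_kw, column_kw = xds_from_table("test.ms",
                                               table_keywords=True,
                                               column_keywords=True)

# With neither flag set, the plain list of datasets is returned
datasets = xds_from_table("test.ms")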


33 changes: 27 additions & 6 deletions daskms/reads.py
@@ -256,6 +256,11 @@ def _dataset_variable_factory(table_proxy, table_schema, select_cols,
return dataset_vars


def _col_keyword_getter(table):
""" Gets column keywords for all columns in table """
return {c: table.getcolkeywords(c) for c in table.colnames()}
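
This helper is not called directly; later in this diff it is submitted to the table proxy's executor under a read lock. A minimal standalone sketch, assuming the TableProxy(factory, *args, **kwargs) construction pattern used elsewhere in daskms and an illustrative path:

import pyrap.tables as pt

from daskms.table_proxy import TableProxy, READLOCK

# Submit the getter to the proxy's executor; the call runs with a
# read lock held and the future resolves to {column: keywords}
table_proxy = TableProxy(pt.table, "test.ms", ack=False)
keywords = table_proxy.submit(_col_keyword_getter, READLOCK).result()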


class DatasetFactory(object):
def __init__(self, table, select_cols, group_cols, index_cols, **kwargs):
if not table_exists(table):
@@ -276,6 +281,8 @@ def __init__(self, table, select_cols, group_cols, index_cols, **kwargs):
self.chunks = chunks
self.table_schema = kwargs.pop('table_schema', None)
self.taql_where = kwargs.pop('taql_where', '')
self.table_keywords = kwargs.pop('table_keywords', False)
self.column_keywords = kwargs.pop('column_keywords', False)

if len(kwargs) > 0:
raise ValueError("Unhandled kwargs: %s" % kwargs)
@@ -292,7 +299,6 @@ def _single_dataset(self, orders, exemplar_row=0):
table_proxy = self._table_proxy()
table_schema = self._table_schema()
select_cols = set(self.select_cols or table_proxy.colnames().result())

variables = _dataset_variable_factory(table_proxy, table_schema,
select_cols, exemplar_row,
orders, self.chunks[0],
@@ -353,6 +359,7 @@ def _group_datasets(self, groups, exemplar_rows, orders):
# Assign values for the dataset's grouping columns
# as attributes
attrs = dict(zip(self.group_cols, group_id))

datasets.append(Dataset(group_var_dims, attrs=attrs,
coords=coords))

@@ -366,7 +373,7 @@ def datasets(self):
order_taql = ordering_taql(table_proxy, self.index_cols,
self.taql_where)
orders = row_ordering(order_taql, self.index_cols, self.chunks[0])
return [self._single_dataset(orders)]
datasets = [self._single_dataset(orders)]
# Group by row
elif len(self.group_cols) == 1 and self.group_cols[0] == "__row__":
order_taql = ordering_taql(table_proxy, self.index_cols,
@@ -386,9 +393,9 @@ def datasets(self):
# dataset as an attribute
np_sorted_row = sorted_rows.compute()

return [self._single_dataset((row_blocks[r], run_blocks[r]),
exemplar_row=er)
for r, er in enumerate(np_sorted_row)]
datasets = [self._single_dataset((row_blocks[r], run_blocks[r]),
exemplar_row=er)
for r, er in enumerate(np_sorted_row)]
# Grouping column case
else:
order_taql = group_ordering_taql(table_proxy, self.group_cols,
@@ -400,7 +407,21 @@ def datasets(self):
exemplar_rows = order_taql.getcol("__firstrow__").result()
assert len(orders) == len(exemplar_rows)

return self._group_datasets(groups, exemplar_rows, orders)
datasets = self._group_datasets(groups, exemplar_rows, orders)

ret = (datasets,)

if self.table_keywords is True:
ret += (table_proxy.getkeywords().result(),)

if self.column_keywords is True:
keywords = table_proxy.submit(_col_keyword_getter, READLOCK)
ret += (keywords.result(),)

if len(ret) == 1:
return ret[0]

return ret


def read_datasets(ms, columns, group_cols, index_cols, **kwargs):
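The public read_datasets wrapper inherits the same convention via DatasetFactory; a short sketch with an illustrative path:

from daskms.reads import read_datasets

# Plain call: a list of datasets
datasets = read_datasets("test.ms", [], [], [])

# With a keyword flag set, the list is wrapped in a tuple
datasets, table_kw = read_datasets("test.ms", [], [], [],
                                   table_keywords=True)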
2 changes: 2 additions & 0 deletions daskms/table_proxy.py
@@ -46,12 +46,14 @@
("getvarcol", READLOCK),
("getcell", READLOCK),
("getcellslice", READLOCK),
("getkeywords", READLOCK),
("getcolkeywords", READLOCK),
# Writes
("putcol", WRITELOCK),
("putcolnp", WRITELOCK),
("putvarcol", WRITELOCK),
("putcellslice", WRITELOCK),
("putkeywords", WRITELOCK),
("putcolkeywords", WRITELOCK)]


19 changes: 8 additions & 11 deletions daskms/tests/test_dataset.py
@@ -122,8 +122,7 @@ def test_dataset_updates(ms, select_cols,

# Create write operations and execute them
for i, ds in enumerate(datasets):
state_attrs = {"keywords": {"state-%d" % i: "foo"}}
state_var = (("row",), ds.STATE_ID.data + 1, state_attrs)
state_var = (("row",), ds.STATE_ID.data + 1)
data_var = (("row", "chan", "corr"), ds.DATA.data + 1, {})
states.append(state_var[1])
datas.append(data_var[1])
@@ -144,12 +143,9 @@
datasets = read_datasets(ms, select_cols, group_cols,
index_cols, chunks=chunks)

expected_kws = {"state-%d" % i: "foo" for i in range(len(datasets))}

for i, (ds, state, data) in enumerate(zip(datasets, states, datas)):
assert_array_equal(ds.STATE_ID.data, state)
assert_array_equal(ds.DATA.data, data)
assert ds.STATE_ID.attrs['keywords'] == expected_kws

del ds, datasets
assert_liveness(0, 0)
@@ -282,12 +278,13 @@ def test_dataset_add_column(ms, dtype):
# Create the dask array
bitflag = da.zeros_like(ds.DATA.data, dtype=dtype)
# Assign keyword attribute
bitflag_attrs = {"keywords": {'FLAGSETS': 'legacy,cubical',
'FLAGSET_legacy': 1,
'FLAGSET_cubical': 2}}
col_kw = {"BITFLAG": {'FLAGSETS': 'legacy,cubical',
'FLAGSET_legacy': 1,
'FLAGSET_cubical': 2}}
# Assign variable onto the dataset
nds = ds.assign(BITFLAG=(("row", "chan", "corr"), bitflag, bitflag_attrs))
writes = write_datasets(ms, nds, ["BITFLAG"], descriptor='ratt_ms')
nds = ds.assign(BITFLAG=(("row", "chan", "corr"), bitflag))
writes = write_datasets(ms, nds, ["BITFLAG"], descriptor='ratt_ms',
column_keywords=col_kw)

dask.compute(writes)

Expand All @@ -296,7 +293,7 @@ def test_dataset_add_column(ms, dtype):

with pt.table(ms, readonly=False, ack=False, lockoptions='auto') as T:
bf = T.getcol("BITFLAG")
assert T.getcoldesc("BITFLAG")['keywords'] == bitflag_attrs['keywords']
assert T.getcoldesc("BITFLAG")['keywords'] == col_kw['BITFLAG']
assert bf.dtype == dtype


81 changes: 81 additions & 0 deletions daskms/tests/test_keywords.py
@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import dask
import pyrap.tables as pt
import pytest

from daskms.example_data import example_ms
from daskms import xds_to_table, xds_from_ms


@pytest.fixture(scope='module')
def keyword_ms():
try:
yield example_ms()
finally:
pass


@pytest.mark.parametrize("table_kw", [True, False])
@pytest.mark.parametrize("column_kw", [True, False])
def test_keyword_read(keyword_ms, table_kw, column_kw):
# Create an example MS
with pt.table(keyword_ms, ack=False, readonly=True) as T:
desc = T._getdesc(actual=True)

ret = xds_from_ms(keyword_ms,
table_keywords=table_kw,
column_keywords=column_kw)

if isinstance(ret, tuple):
ret_pos = 1

if table_kw is True:
assert desc["_keywords_"] == ret[ret_pos]
ret_pos += 1

if column_kw is True:
colkw = ret[ret_pos]

for column, keywords in colkw.items():
assert desc[column]['keywords'] == keywords

ret_pos += 1
else:
assert table_kw is False
assert column_kw is False
assert isinstance(ret, list)


def test_keyword_write(ms):
datasets = xds_from_ms(ms)

# Add to table keywords
writes = xds_to_table([], ms, [], table_keywords={'bob': 'qux'})
dask.compute(writes)

with pt.table(ms, ack=False, readonly=True) as T:
assert T.getkeywords()['bob'] == 'qux'

# Add to column keywords
writes = xds_to_table(datasets, ms, [],
column_keywords={'STATE_ID': {'bob': 'qux'}})
dask.compute(writes)

with pt.table(ms, ack=False, readonly=True) as T:
assert T.getcolkeywords("STATE_ID")['bob'] == 'qux'

# Remove from column and table keywords
from daskms.writes import DELKW
writes = xds_to_table(datasets, ms, [],
table_keywords={'bob': DELKW},
column_keywords={'STATE_ID': {'bob': DELKW}})
dask.compute(writes)

with pt.table(ms, ack=False, readonly=True) as T:
assert 'bob' not in T.getkeywords()
assert 'bob' not in T.getcolkeywords("STATE_ID")