Skip to content

Commit

Permalink
Add chunk specification to the katdal export (#318)
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Mar 28, 2024
1 parent 32e866b commit ed44bd6
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 39 deletions.
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Add chunk specification to ``dask-ms katdal import`` (:pr:`318`)
* Add a ``dask-ms katdal import`` application for exporting SARAO archive data directly to zarr (:pr:`315`)
* Define dask-ms command line applications with click (:pr:`317`)
* Make poetry dev and docs groups optional (:pr:`316`)
Expand Down
31 changes: 2 additions & 29 deletions daskms/apps/convert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ast
from argparse import ArgumentTypeError
from collections import defaultdict
import logging
Expand All @@ -8,37 +7,11 @@

from daskms.apps.formats import TableFormat, CasaFormat
from daskms.fsspec_store import DaskMSStore
from daskms.utils import parse_chunks_dict

log = logging.getLogger(__name__)


class ChunkTransformer(ast.NodeTransformer):
    """Convert a chunk specification string's AST into a plain Python dict.

    Visiting the parsed AST of an expression like ``"{time: 10, chan: (4, 4)}"``
    yields ``{"time": 10, "chan": (4, 4)}`` — bare identifiers become string
    keys and literals become their Python values — without ``eval``-ing the
    untrusted input string.
    """

    def visit_Module(self, node):
        # The chunk spec must be exactly one expression statement.
        if len(node.body) != 1 or not isinstance(node.body[0], ast.Expr):
            raise ValueError("Module must contain a single expression")

        expr = node.body[0]

        # ...and that expression must be a dict literal.
        if not isinstance(expr.value, ast.Dict):
            raise ValueError("Expression must contain a dictionary")

        # NodeTransformer replaces children with visit results, so after
        # visiting, ``expr.value`` holds the dict built by visit_Dict.
        return self.visit(expr).value

    def visit_Dict(self, node):
        keys = [self.visit(k) for k in node.keys]
        values = [self.visit(v) for v in node.values]
        return {k: v for k, v in zip(keys, values)}

    def visit_Name(self, node):
        # Bare identifiers (e.g. ``time``) become string keys.
        return node.id

    def visit_Tuple(self, node):
        return tuple(self.visit(v) for v in node.elts)

    def visit_Constant(self, node):
        # Use ``value``: ``Constant.n`` is a deprecated numbers-era alias
        # (from the old ast.Num node) and emits DeprecationWarning on 3.12+.
        return node.value


NONUNIFORM_SUBTABLES = ["SPECTRAL_WINDOW", "POLARIZATION", "FEED", "SOURCE"]


Expand Down Expand Up @@ -88,7 +61,7 @@ def _check_exclude_columns(ctx, param, value):


def parse_chunks(ctx, param, value):
return ChunkTransformer().visit(ast.parse(value))
return parse_chunks_dict(value)


def col_converter(ctx, param, value):
Expand Down
12 changes: 10 additions & 2 deletions daskms/apps/katdal_import.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import click

from daskms.utils import parse_chunks_dict


@click.group()
@click.pass_context
Expand Down Expand Up @@ -58,10 +60,16 @@ def convert(self, value, param, ctx):
"'K,B,G'. Use 'default' for L1 + L2 and 'all' for "
"all available products.",
)
def _import(ctx, rdb_url, no_auto, pols_to_use, applycal, output_store):
@click.option(
"--chunks",
callback=lambda c, p, v: parse_chunks_dict(v),
default="{time: 10}",
help="Chunking values to apply to each dimension",
)
def _import(ctx, rdb_url, output_store, no_auto, pols_to_use, applycal, chunks):
"""Export an observation in the SARAO archive to zarr format
RDB_URL is the SARAO archive link"""
from daskms.experimental.katdal import katdal_import

katdal_import(rdb_url, output_store, no_auto, applycal)
katdal_import(rdb_url, output_store, no_auto, applycal, chunks)
13 changes: 11 additions & 2 deletions daskms/experimental/katdal/katdal_import.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import logging
import os
import urllib

import dask

from daskms.fsspec_store import DaskMSStore
from daskms.utils import requires

log = logging.getLogger(__file__)

try:
import katdal
from katdal.dataset import DataSet
Expand All @@ -31,20 +35,25 @@ def default_output_name(url):


@requires("pip install dask-ms[katdal]", import_error)
def katdal_import(url: str, out_store: str, no_auto: bool, applycal: str):
def katdal_import(url: str, out_store: str, no_auto: bool, applycal: str, chunks: dict):
if isinstance(url, str):
        dataset = katdal.open(url, applycal=applycal)
elif isinstance(url, DataSet):
dataset = url
else:
raise TypeError(f"{url} must be a string or a katdal DataSet")

facade = XarrayMSV2Facade(dataset, no_auto=no_auto)
facade = XarrayMSV2Facade(dataset, no_auto=no_auto, chunks=chunks)
main_xds, subtable_xds = facade.xarray_datasets()

if not out_store:
out_store = default_output_name(url)

out_store = DaskMSStore(out_store)
if out_store.exists():
log.warn("Removing previously existing %s", out_store)
out_store.rm("", recursive=True)

writes = [
xds_to_zarr(main_xds, out_store),
*(xds_to_zarr(ds, f"{out_store}::{k}") for k, ds in subtable_xds.items()),
Expand Down
27 changes: 25 additions & 2 deletions daskms/experimental/katdal/msv2_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,22 @@ def to_mjds(timestamp: Timestamp):
return timestamp.to_mjd() * 24 * 60 * 60


DEFAULT_TIME_CHUNKS = 100
DEFAULT_CHAN_CHUNKS = 4096
DEFAULT_CHUNKS = {"time": DEFAULT_TIME_CHUNKS, "chan": DEFAULT_CHAN_CHUNKS}


class XarrayMSV2Facade:
"""Provides a simplified xarray Dataset view over a katdal dataset"""

def __init__(self, dataset: DataSet, no_auto: bool = True, row_view: bool = True):
def __init__(
self,
dataset: DataSet,
no_auto: bool = True,
row_view: bool = True,
chunks: dict = None,
):
self._chunks = chunks or DEFAULT_CHUNKS
self._dataset = dataset
self._no_auto = no_auto
self._row_view = row_view
Expand All @@ -81,6 +93,10 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe
time_utc = dataset.timestamps
t_chunks, chan_chunks, cp_chunks = dataset.vis.dataset.chunks

# Override time and channel chunking
t_chunks = self._chunks.get("time", t_chunks)
chan_chunks = self._chunks.get("chan", chan_chunks)

# Modified Julian Date in Seconds
time_mjds = np.asarray([to_mjds(t) for t in map(Timestamp, time_utc)])

Expand Down Expand Up @@ -110,7 +126,14 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe

flags = DaskLazyIndexer(dataset.flags, (), (rechunk, flag_transpose))
weights = DaskLazyIndexer(dataset.weights, (), (rechunk, weight_transpose))
vis = DaskLazyIndexer(dataset.vis, (), transforms=(vis_transpose,))
vis = DaskLazyIndexer(
dataset.vis,
(),
transforms=(
rechunk,
vis_transpose,
),
)

time = da.from_array(time_mjds[:, None], chunks=(t_chunks, 1))
ant1 = da.from_array(cp_info.ant1_index[None, :], chunks=(1, cpi.shape[0]))
Expand Down
13 changes: 13 additions & 0 deletions daskms/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from daskms.utils import (
parse_chunks_dict,
promote_columns,
natural_order,
table_path_split,
Expand All @@ -15,6 +16,18 @@
)


def test_parse_chunks_dict():
    """Chunk-spec strings parse to dicts; malformed input raises SyntaxError."""
    cases = [
        ("{row: 1000}", {"row": 1000}),
        ("{row: 1000, chan: 64}", {"row": 1000, "chan": 64}),
        ("{row: (10, 10), chan: (4, 4)}", {"row": (10, 10), "chan": (4, 4)}),
    ]

    for spec, expected in cases:
        assert parse_chunks_dict(spec) == expected

    # An unbalanced brace fails in the underlying ast.parse call
    with pytest.raises(SyntaxError):
        parse_chunks_dict("row:1000}")


def test_natural_order():
data = [f"{i}.parquet" for i in reversed(range(20))]
expected = [f"{i}.parquet" for i in range(20)]
Expand Down
36 changes: 32 additions & 4 deletions daskms/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
# -*- coding: utf-8 -*-

import ast
from collections import OrderedDict
import importlib.util
import logging
from pathlib import PurePath, Path
import re
import sys
import time
import inspect
import warnings

from dask.utils import funcname

# The numpy module may disappear during interpreter shutdown
# so explicitly import ndarray
from numpy import ndarray
Expand All @@ -21,6 +18,37 @@
log = logging.getLogger(__name__)


class ChunkTransformer(ast.NodeTransformer):
    """Convert a chunk specification string's AST into a plain Python dict.

    Visiting the parsed AST of an expression like ``"{time: 10, chan: (4, 4)}"``
    yields ``{"time": 10, "chan": (4, 4)}`` — bare identifiers become string
    keys and literals become their Python values — without ``eval``-ing the
    untrusted input string.
    """

    def visit_Module(self, node):
        # The chunk spec must be exactly one expression statement.
        if len(node.body) != 1 or not isinstance(node.body[0], ast.Expr):
            raise ValueError("Module must contain a single expression")

        expr = node.body[0]

        # ...and that expression must be a dict literal.
        if not isinstance(expr.value, ast.Dict):
            raise ValueError("Expression must contain a dictionary")

        # NodeTransformer replaces children with visit results, so after
        # visiting, ``expr.value`` holds the dict built by visit_Dict.
        return self.visit(expr).value

    def visit_Dict(self, node):
        keys = [self.visit(k) for k in node.keys]
        values = [self.visit(v) for v in node.values]
        return {k: v for k, v in zip(keys, values)}

    def visit_Name(self, node):
        # Bare identifiers (e.g. ``time``) become string keys.
        return node.id

    def visit_Tuple(self, node):
        return tuple(self.visit(v) for v in node.elts)

    def visit_Constant(self, node):
        # Use ``value``: ``Constant.n`` is a deprecated numbers-era alias
        # (from the old ast.Num node) and emits DeprecationWarning on 3.12+.
        return node.value


def parse_chunks_dict(chunks_str):
    """Parse a chunk specification string such as ``"{time: 10}"`` into a dict."""
    tree = ast.parse(chunks_str)
    return ChunkTransformer().visit(tree)


def natural_order(key):
return tuple(
int(c) if c.isdigit() else c.lower() for c in re.split(r"(\d+)", str(key))
Expand Down

0 comments on commit ed44bd6

Please sign in to comment.