diff --git a/daskms/parallel_table.py b/daskms/parallel_table.py
new file mode 100644
index 00000000..2f5ecde0
--- /dev/null
+++ b/daskms/parallel_table.py
@@ -0,0 +1,204 @@
+import logging
+import threading
+from daskms.table_proxy import (TableProxy,
+                                TableProxyMetaClass,
+                                _proxied_methods,
+                                NOLOCK,
+                                READLOCK,
+                                WRITELOCK,
+                                _LOCKTYPE_STRINGS,
+                                _PROXY_DOCSTRING,
+                                STANDARD_EXECUTOR)
+from weakref import finalize, WeakValueDictionary
+from daskms.utils import arg_hasher
+
+_table_cache = WeakValueDictionary()
+_table_lock = threading.Lock()
+
+log = logging.getLogger(__name__)
+
+
+_parallel_methods = [
+    "getcol",
+    "getcolnp",
+    "getcolslice",
+    "getvarcol",
+    "getcell",
+    "getcellslice",
+    "getkeywords",
+    "getcolkeywords"
+]
+
+
+def _parallel_table_finalizer(table_cache):
+
+    for table in table_cache.cache.values():
+        table.close()
+
+
+def proxied_method_factory(method, locktype):
+    """
+    Proxy pyrap.tables.table.method calls.
+
+    Creates a private implementation function which performs
+    the call locked according to ``locktype``.
+
+    The private implementation is accessed by a public ``method``
+    which submits a call to the implementation
+    on a concurrent.futures.ThreadPoolExecutor.
+    """
+
+    if locktype == NOLOCK:
+        def _impl(table_future, args, kwargs):
+            if isinstance(table_future, TIDCache):
+                table = table_future.get()
+            else:
+                table = table_future.result()
+
+            try:
+                return getattr(table, method)(*args, **kwargs)
+            except Exception:
+                if logging.DEBUG >= log.getEffectiveLevel():
+                    log.exception("Exception in %s", method)
+                raise
+
+    elif locktype == READLOCK:
+        def _impl(table_future, args, kwargs):
+            if isinstance(table_future, TIDCache):
+                table = table_future.get()
+            else:
+                table = table_future.result()
+            table.lock(write=False)
+
+            try:
+                return getattr(table, method)(*args, **kwargs)
+            except Exception:
+                if logging.DEBUG >= log.getEffectiveLevel():
+                    log.exception("Exception in %s", method)
+                raise
+            finally:
+                table.unlock()
+
+    elif locktype == WRITELOCK:
+        def _impl(table_future, args, kwargs):
+            if isinstance(table_future, TIDCache):
+                table = table_future.get()
+            else:
+                table = table_future.result()
+            table.lock(write=True)
+
+            try:
+                return getattr(table, method)(*args, **kwargs)
+            except Exception:
+                if logging.DEBUG >= log.getEffectiveLevel():
+                    log.exception("Exception in %s", method)
+                raise
+            finally:
+                table.unlock()
+
+    else:
+        raise ValueError(f"Invalid locktype {locktype}")
+
+    _impl.__name__ = method + "_impl"
+    _impl.__doc__ = ("Calls table.%s, wrapped in a %s."
+                     % (method, _LOCKTYPE_STRINGS[locktype]))
+
+    if method in _parallel_methods:
+        def public_method(self, *args, **kwargs):
+            """
+            Submits _impl(args, kwargs) to the executor
+            and returns a Future
+            """
+            return self._ex.submit(_impl, self._cached_tables, args, kwargs)
+    else:
+        def public_method(self, *args, **kwargs):
+            """
+            Submits _impl(args, kwargs) to the executor
+            and returns a Future
+            """
+            return self._ex.submit(_impl, self._table_future, args, kwargs)
+
+    public_method.__name__ = method
+    public_method.__doc__ = _PROXY_DOCSTRING % method
+
+    return public_method
+
+
+class ParallelTableProxyMetaClass(TableProxyMetaClass):
+    """
+    https://en.wikipedia.org/wiki/Multiton_pattern
+
+    """
+    def __new__(cls, name, bases, dct):
+        for method, locktype in _proxied_methods:
+            proxy_method = proxied_method_factory(method, locktype)
+            dct[method] = proxy_method
+
+        return type.__new__(cls, name, bases, dct)
+
+    def __call__(cls, *args, **kwargs):
+        key = arg_hasher((cls,) + args + (kwargs,))
+
+        with _table_lock:
+            try:
+                return _table_cache[key]
+            except KeyError:
+                instance = type.__call__(cls, *args, **kwargs)
+                _table_cache[key] = instance
+                return instance
+
+
+class ParallelTableProxy(TableProxy, metaclass=ParallelTableProxyMetaClass):
+
+    @classmethod
+    def _from_args_kwargs(cls, factory, args, kwargs):
+        """Support pickling of kwargs in ParallelTableProxy.__reduce__."""
+        return cls(factory, *args, **kwargs)
+
+    def __init__(self, factory, *args, **kwargs):
+
+        super().__init__(factory, *args, **kwargs)
+
+        # This is a bit hacky, as noted by Simon in TableProxy. Maybe storing
+        # a sanitised version would be better?
+        kwargs = self._kwargs.copy()
+        kwargs.pop("__executor_key__", STANDARD_EXECUTOR)
+
+        self._cached_tables = TIDCache(factory, *args, **kwargs)
+
+        finalize(self, _parallel_table_finalizer, self._cached_tables)
+
+    def __reduce__(self):
+        """ Defer to _from_args_kwargs to support kwarg pickling """
+        return (
+            self._from_args_kwargs,
+            (self._factory, self._args, self._kwargs)
+        )
+
+
+class TIDCache(object):
+
+    def __init__(self, fn, *args, **kwargs):
+        """A cache keyed on thread ID.
+
+        When a key is not found, it is added with a value given by
+        fn(*args, **kwargs).
+ """ + + self.cache = {} + self.fn = fn + self.args = args + self.kwargs = kwargs + + def get(self): + """Return or create fn(*args, **args) for the calling thread.""" + + thread_id = threading.get_ident() + + try: + item = self.cache[thread_id] + except KeyError: + print(f"Opening item in {thread_id}.") + self.cache[thread_id] = item = self.fn(*self.args, **self.kwargs) + + return item diff --git a/daskms/reads.py b/daskms/reads.py index c9643634..e7aa3899 100644 --- a/daskms/reads.py +++ b/daskms/reads.py @@ -17,6 +17,7 @@ from daskms.dataset import Dataset from daskms.table_executor import executor_key from daskms.table import table_exists +from daskms.parallel_table import ParallelTableProxy from daskms.table_proxy import TableProxy, READLOCK from daskms.table_schemas import lookup_table_schema from daskms.utils import table_path_split @@ -26,106 +27,74 @@ log = logging.getLogger(__name__) -def ndarray_getcol(row_runs, table_future, column, result, dtype): +def ndarray_getcol(row_runs, table_proxy, column, result, dtype): """ Get numpy array data """ - table = table_future.result() - getcolnp = table.getcolnp + getcolnp = table_proxy.getcolnp rr = 0 - table.lock(write=False) + for rs, rl in row_runs: + getcolnp(column, result[rr:rr + rl], startrow=rs, nrow=rl).result() + rr += rl - try: - for rs, rl in row_runs: - getcolnp(column, result[rr:rr + rl], startrow=rs, nrow=rl) - rr += rl - finally: - table.unlock() - return result - - -def ndarray_getcolslice(row_runs, table_future, column, result, +def ndarray_getcolslice(row_runs, table_proxy, column, result, blc, trc, dtype): """ Get numpy array data """ - table = table_future.result() - getcolslicenp = table.getcolslicenp + getcolslicenp = table_proxy.getcolslicenp rr = 0 - table.lock(write=False) + for rs, rl in row_runs: + getcolslicenp(column, result[rr:rr + rl], + blc=blc, trc=trc, + startrow=rs, nrow=rl).result() + rr += rl - try: - for rs, rl in row_runs: - getcolslicenp(column, result[rr:rr + rl], - blc=blc, trc=trc, - startrow=rs, nrow=rl) - rr += rl - finally: - table.unlock() - return result - - -def object_getcol(row_runs, table_future, column, result, dtype): +def object_getcol(row_runs, table_proxy, column, result, dtype): """ Get object list data """ - table = table_future.result() - getcol = table.getcol + getcol = table_proxy.getcol rr = 0 - table.lock(write=False) - - try: - for rs, rl in row_runs: - data = getcol(column, rs, rl) + for rs, rl in row_runs: + data = getcol(column, rs, rl).result() - # Multi-dimensional string arrays are returned as a - # dict with 'array' and 'shape' keys. Massage the data. - if isinstance(data, dict): - data = (np.asarray(data['array'], dtype=dtype) - .reshape(data['shape'])) + # Multi-dimensional string arrays are returned as a + # dict with 'array' and 'shape' keys. Massage the data. + if isinstance(data, dict): + data = (np.asarray(data['array'], dtype=dtype) + .reshape(data['shape'])) - # NOTE(sjperkins) - # Dask wants ndarrays internally, so we asarray objects - # the returning list of objects. - # See https://github.com/ska-sa/dask-ms/issues/42 - result[rr:rr + rl] = np.asarray(data, dtype=dtype) + # NOTE(sjperkins) + # Dask wants ndarrays internally, so we asarray objects + # the returning list of objects. 
+        # the returning list of objects.
+        # See https://github.com/ska-sa/dask-ms/issues/42
+        result[rr:rr + rl] = np.asarray(data, dtype=dtype)
 
-            rr += rl
-    finally:
-        table.unlock()
-
-    return result
+        rr += rl
 
 
-def object_getcolslice(row_runs, table_future, column, result,
+def object_getcolslice(row_runs, table_proxy, column, result,
                        blc, trc, dtype):
     """ Get object list data """
-    table = table_future.result()
-    getcolslice = table.getcolslice
+    getcolslice = table_proxy.getcolslice
     rr = 0
 
-    table.lock(write=False)
+    for rs, rl in row_runs:
+        data = getcolslice(column, blc, trc, startrow=rs, nrow=rl).result()
 
-    try:
-        for rs, rl in row_runs:
-            data = getcolslice(column, blc, trc, startrow=rs, nrow=rl)
+        # Multi-dimensional string arrays are returned as a
+        # dict with 'array' and 'shape' keys. Massage the data.
+        if isinstance(data, dict):
+            data = (np.asarray(data['array'], dtype=dtype)
+                    .reshape(data['shape']))
 
-            # Multi-dimensional string arrays are returned as a
-            # dict with 'array' and 'shape' keys. Massage the data.
-            if isinstance(data, dict):
-                data = (np.asarray(data['array'], dtype=dtype)
-                        .reshape(data['shape']))
+        # NOTE(sjperkins)
+        # Dask wants ndarrays internally, so we asarray objects
+        # the returning list of objects.
+        # See https://github.com/ska-sa/dask-ms/issues/42
+        result[rr:rr + rl] = np.asarray(data, dtype=dtype)
 
-            # NOTE(sjperkins)
-            # Dask wants ndarrays internally, so we asarray objects
-            # the returning list of objects.
-            # See https://github.com/ska-sa/dask-ms/issues/42
-            result[rr:rr + rl] = np.asarray(data, dtype=dtype)
-
-            rr += rl
-    finally:
-        table.unlock()
-
-    return result
+        rr += rl
 
 
 def getter_wrapper(row_orders, *args):
@@ -153,11 +122,15 @@ def getter_wrapper(row_orders, *args):
         io_fn = (object_getcolslice
                  if np.dtype == object
                  else ndarray_getcolslice)
-        # Submit table I/O on executor
-        future = table_proxy._ex.submit(io_fn, row_runs,
-                                        table_proxy._table_future,
-                                        column, result,
-                                        blc, trc, dtype)
+        io_fn(
+            row_runs,
+            table_proxy,
+            column,
+            result,
+            blc,
+            trc,
+            dtype
+        )
     # In this case, the full resolution data
     # for each row is requested, so we defer to getcol
     else:
@@ -165,16 +138,19 @@ def getter_wrapper(row_orders, *args):
         io_fn = (object_getcol
                  if dtype == object
                  else ndarray_getcol)
-        # Submit table I/O on executor
-        future = table_proxy._ex.submit(io_fn, row_runs,
-                                        table_proxy._table_future,
-                                        column, result, dtype)
+        io_fn(
+            row_runs,
+            table_proxy,
+            column,
+            result,
+            dtype
+        )
 
     # Resort result if necessary
     if resort is not None:
-        return future.result()[resort]
+        return result[resort]
 
-    return future.result()
+    return result
 
 
 def _dataset_variable_factory(table_proxy, table_schema, select_cols,
@@ -209,7 +185,6 @@ def _dataset_variable_factory(table_proxy, table_schema, select_cols,
     dict
         A dictionary looking like :code:`{column: (arrays, dims)}`.
""" - sorted_rows, row_runs = orders dataset_vars = {"ROWID": (("row",), sorted_rows)} @@ -305,9 +280,14 @@ def __init__(self, table, select_cols, group_cols, index_cols, **kwargs): raise ValueError(f"Unhandled kwargs: {kwargs}") def _table_proxy_factory(self): - return TableProxy(pt.table, self.table_path, ack=False, - readonly=True, lockoptions='user', - __executor_key__=executor_key(self.canonical_name)) + return ParallelTableProxy( + pt.table, + self.table_path, + ack=False, + readonly=True, + lockoptions='user', + __executor_key__=executor_key(self.canonical_name) + ) def _table_schema(self): return lookup_table_schema(self.canonical_name, self.table_schema) diff --git a/daskms/table_executor.py b/daskms/table_executor.py index 653fa0cb..a046665d 100644 --- a/daskms/table_executor.py +++ b/daskms/table_executor.py @@ -38,7 +38,7 @@ def __call__(cls, key=STANDARD_EXECUTOR): class Executor(object, metaclass=ExecutorMetaClass): def __init__(self, key=STANDARD_EXECUTOR): # Initialise a single thread - self.impl = impl = cf.ThreadPoolExecutor(1) + self.impl = impl = DummyThreadPoolExecutor(1) self.key = key # Register a finaliser shutting down the @@ -54,6 +54,30 @@ def __repr__(self): __str__ = __repr__ +class DummyThreadPoolExecutor(object): + + def __init__(self, nthread): + pass + + def submit(self, fn, *args, **kwargs): + + return DummyFuture(fn(*args, **kwargs)) + + def shutdown(self, wait=True): + pass + + +class DummyFuture(object): + + def __init__(self, value): + + self.value = value + + def result(self): + + return self.value + + def executor_key(table_name): """ Product an executor key from table_name diff --git a/daskms/tests/test_table_proxy.py b/daskms/tests/test_table_proxy.py index 92c4d295..8e32e925 100644 --- a/daskms/tests/test_table_proxy.py +++ b/daskms/tests/test_table_proxy.py @@ -222,3 +222,36 @@ def _getcol(tp, column): del futures, data, u, tab_fut assert_liveness(0, 0) + + +@pytest.mark.parametrize("scheduler", + ["sync", "threads", "processes", "distributed"]) +def test_softlinks(ms, scheduler): + + import dask.array as da + from dask.distributed import Client, LocalCluster + from daskms import xds_from_ms + + if scheduler == "distributed": + + cluster = LocalCluster( + processes=True, + n_workers=2, + threads_per_worker=2, + memory_limit=0 + ) + + client = Client(cluster) # noqa + + # ms = "/home/jonathan/reductions/3C147/msdir/C147_unflagged.MS" + ms = "/home/jonathan/recipes/proxy_experiments/vla_empty.ms" + + xdsl = xds_from_ms( + ms, + group_cols=["DATA_DESC_ID", "SCAN_NUMBER", "FIELD_ID"], + chunks={"row": 30000, "chan": -1, "corr": -1} + ) + + result = [xds.DATA.data.map_blocks(lambda x: x[:1, :1, :1]) for xds in xdsl] + + da.compute(result, scheduler=scheduler, num_workers=4)