Dask ms profiling #114

Open · wants to merge 11 commits into master
1 change: 1 addition & 0 deletions daskms/dask_ms.py
@@ -110,6 +110,7 @@ def xds_to_table(xds, table_name, columns, descriptor=None,
else:
raise TypeError("Invalid Dataset type '%s'" % type(ds))

# print("Table Keywords - dask_ms", type(table_keywords), table_keywords)
# Write the datasets
out_ds = write_datasets(table_name, datasets, columns,
descriptor=descriptor,
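
For context, this write path is what the new test below exercises (editor's note; see test_writelock_profiling in test_profiling.py):

    writes = xds_to_table(nds, ms, ["STATE_ID", "DATA"])
    dask.compute(writes)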
115 changes: 103 additions & 12 deletions daskms/table_proxy.py
@@ -4,6 +4,9 @@
import logging
from threading import Lock
import weakref
import os

from time import time

from dask.base import normalize_token
import pyrap.tables as pt
@@ -15,6 +18,18 @@
_table_cache = weakref.WeakValueDictionary()
_table_lock = Lock()

# Environment variable enabling profiling
# DASK_MS_PROFILE --> [True, False]
dask_ms_profile = os.environ.get('DASK_MS_PROFILE', '').lower() in (
    '1', 'true', 'yes')

# Dictionary storing profiling results per wrapped method:
# {method_name: (run_times, args_and_kwargs, call_count)}
# Keeping call_count current inside the stored tuple is still a problem
# (the integer is copied into the tuple, unlike the mutable lists)
_function_runs = {}
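
# A minimal usage sketch (editor's illustration, not part of this PR):
# enable profiling via the environment, run a dask-ms workload, then
# inspect the collected data.
#
#   $ DASK_MS_PROFILE=1 python my_script.py
#
#   from daskms.table_proxy import _function_runs
#   run_times, args_kwargs, calls = _function_runs['getcol']
#   print(sum(run_times), calls)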

# CASA Table Locking Modes
NOLOCK = 0
READLOCK = 1
@@ -85,39 +100,76 @@ def proxied_method_factory(method, locktype):
    if locktype == NOLOCK:
        def _impl(table_future, args, kwargs):
            try:
                _impl.calls += 1
                _impl.akwargs.append((args, kwargs))
                start_time = time()
                result = getattr(table_future.result(), method)(*args, **kwargs)
                end_time = time()
                _impl.run_time.append(end_time - start_time)
                return result
            except Exception:
                if logging.DEBUG >= log.getEffectiveLevel():
                    log.exception("Exception in %s", method)
                raise
            finally:
                _function_runs[method] = (_impl.run_time, _impl.akwargs,
                                          _impl.calls)

        _impl.calls = 0
        _impl.akwargs = []
        _impl.run_time = []

    elif locktype == READLOCK:
        def _impl(table_future, args, kwargs):
            table = table_future.result()
            table.lock(write=False)

            try:
                _impl.calls += 1
                _impl.akwargs.append((args, kwargs))
                start_time = time()
                result = getattr(table, method)(*args, **kwargs)
                end_time = time()
                _impl.run_time.append(end_time - start_time)
                return result
            except Exception:
                if logging.DEBUG >= log.getEffectiveLevel():
                    log.exception("Exception in %s", method)
                raise
            finally:
                table.unlock()
                _function_runs[method] = (_impl.run_time, _impl.akwargs,
                                          _impl.calls)

        _impl.calls = 0
        _impl.akwargs = []
        _impl.run_time = []

    elif locktype == WRITELOCK:
        def _impl(table_future, args, kwargs):
            table = table_future.result()
            table.lock(write=True)

            try:
                _impl.calls += 1
                _impl.akwargs.append((args, kwargs))
                start_time = time()
                result = getattr(table, method)(*args, **kwargs)
                end_time = time()
                _impl.run_time.append(end_time - start_time)
                return result
            except Exception:
                if logging.DEBUG >= log.getEffectiveLevel():
                    log.exception("Exception in %s", method)
                raise
            finally:
                table.unlock()
                _function_runs[method] = (_impl.run_time, _impl.akwargs,
                                          _impl.calls)

        _impl.calls = 0
        _impl.akwargs = []
        _impl.run_time = []

else:
raise ValueError("Invalid locktype %s" % locktype)
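
Editor's sketch: the three timing blocks above are near-identical and could be factored into a single helper. A minimal version under that assumption (the helper name _profiled is hypothetical, not part of this PR):

    def _profiled(name, fn, *args, **kwargs):
        # Time one call of fn, recording (args, kwargs) and cumulative
        # state in the module-global _function_runs dict.
        run_times, akwargs, calls = _function_runs.get(name, ([], [], 0))
        akwargs.append((args, kwargs))
        start = time()
        try:
            return fn(*args, **kwargs)
        finally:
            run_times.append(time() - start)
            _function_runs[name] = (run_times, akwargs, calls + 1)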
@@ -193,37 +245,73 @@ def taql_factory(query, style='Python', tables=(), readonly=True):
for t in tables:
t.unlock()


def _nolock_runner(table_future, fn, *args, **kwargs):
    """
    _nolock_runner wrapper with profiling
    """
    try:
        _nolock_runner.calls += 1
        _nolock_runner.akwargs.append((args, kwargs))
        start_time = time()
        result = fn(table_future.result(), *args, **kwargs)
        end_time = time()
        _nolock_runner.run_time.append(end_time - start_time)
        _function_runs[fn.__name__] = (_nolock_runner.run_time,
                                       _nolock_runner.akwargs,
                                       _nolock_runner.calls)
        return result
    except Exception:
        if logging.DEBUG >= log.getEffectiveLevel():
            log.exception("Exception in %s", fn.__name__)
        raise


_nolock_runner.calls = 0
_nolock_runner.akwargs = []
_nolock_runner.run_time = []


def _readlock_runner(table_future, fn, args, kwargs):
    """
    _readlock_runner wrapper with profiling
    """
    table = table_future.result()
    table.lock(write=False)

    try:
        _readlock_runner.calls += 1
        _readlock_runner.akwargs.append((args, kwargs))
        start_time = time()
        result = fn(table, *args, **kwargs)
        end_time = time()
        _readlock_runner.run_time.append(end_time - start_time)
        _function_runs[fn.__name__] = (_readlock_runner.run_time,
                                       _readlock_runner.akwargs,
                                       _readlock_runner.calls)
        return result
    except Exception:
        if logging.DEBUG >= log.getEffectiveLevel():
            log.exception("Exception in %s", fn.__name__)
        raise
    finally:
        table.unlock()


_readlock_runner.calls = 0
_readlock_runner.akwargs = []
_readlock_runner.run_time = []

def _writelock_runner(table_future, fn, *args, **kwargs):
    """
    _writelock_runner wrapper with profiling
    """
    table = table_future.result()
    table.lock(write=True)

    try:
        _writelock_runner.calls += 1
        _writelock_runner.akwargs.append((args, kwargs))
        start_time = time()
        result = fn(table, *args, **kwargs)
        table.flush()
        end_time = time()
        _writelock_runner.run_time.append(end_time - start_time)
        _function_runs[fn.__name__] = (_writelock_runner.run_time,
                                       _writelock_runner.akwargs,
                                       _writelock_runner.calls)
        return result
    except Exception:
        if logging.DEBUG >= log.getEffectiveLevel():
            log.exception("Exception in %s", fn.__name__)
        raise
    finally:
        table.unlock()


_writelock_runner.calls = 0
_writelock_runner.akwargs = []
_writelock_runner.run_time = []
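
For reference, these wrappers are what TableProxy.submit dispatches to by locktype; the _put_keywords call in writes.py (further down this diff) runs through _writelock_runner, which is why the new tests can assert on _function_runs['_put_keywords']:

    table_proxy.submit(_put_keywords, WRITELOCK,
                       table_keywords, column_keywords).result()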

def _iswriteable(table_future):
return table_future.result().iswritable()
160 changes: 160 additions & 0 deletions daskms/tests/test_profiling.py
@@ -0,0 +1,160 @@
# -*- coding: utf-8 -*-
from pprint import pprint

import dask
import dask.array as da
import numpy as np
from numpy.testing import assert_array_equal
import pyrap.tables as pt
import pytest

from daskms.table_proxy import TableProxy, taql_factory, _function_runs
from daskms.query import orderby_clause, where_clause
from daskms.utils import group_cols_str, index_cols_str, select_cols_str
from daskms.dask_ms import xds_from_ms, xds_to_table
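
A small reporting helper of the kind these tests could use (editor's sketch; summarise_runs is hypothetical and not part of this PR):

def summarise_runs(runs=_function_runs):
    # Reduce each (run_times, args_kwargs, call_count) entry to
    # (total seconds, call_count) for quick inspection.
    return {name: (sum(times), calls)
            for name, (times, akwargs, calls) in runs.items()}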


@pytest.mark.parametrize('group_cols', [
["DATA_DESC_ID", "SCAN_NUMBER"]], ids=group_cols_str)
@pytest.mark.parametrize('index_cols', [
["TIME", "ANTENNA1", "ANTENNA2"]], ids=index_cols_str)
@pytest.mark.parametrize('select_cols', [
['TIME', 'DATA']], ids=select_cols_str)
def test_readlock_profiling(ms, group_cols, index_cols, select_cols):

xds = xds_from_ms(ms, columns=select_cols, group_cols=group_cols,
index_cols=index_cols)

    assert 'getcol' in _function_runs and 'getcoldesc' in _function_runs
# 2 - __tablerows__ , __firstrow__
assert _function_runs['getcol'][2] == 2 + len(group_cols)
assert _function_runs['getcoldesc'][2] == len(select_cols) * len(xds)

pprint(_function_runs)
getcol_runs = _function_runs['getcol'][2]
getcoldesc_runs = _function_runs['getcoldesc'][2]

order = orderby_clause(index_cols)
np_column_data = []

with TableProxy(pt.table, ms, lockoptions='auto', ack=False) as T:
for ds in xds:
assert "ROWID" in ds.coords
group_col_values = [ds.attrs[a] for a in group_cols]
where = where_clause(group_cols, group_col_values)
query = "SELECT * FROM $1 %s %s" % (where, order)
with TableProxy(taql_factory, query, tables=[T]) as Q:
column_data = {c: Q.getcol(c).result() for c in select_cols}
np_column_data.append(column_data)

assert _function_runs['getcol'][2] == getcol_runs + len(select_cols)*len(xds)
assert _function_runs['getcoldesc'][2] == getcoldesc_runs
del T

for ds, column_data in zip(xds, np_column_data):
for c in select_cols:
dask_data = ds.data_vars[c].data.compute(scheduler='single-threaded')
assert_array_equal(column_data[c], dask_data)

# 1 - __tablerow__
assert _function_runs['getcellslice'][2] == (1 + len(index_cols)) * (len(xds) * len(select_cols))
    # TODO: confirm whether reading the DATA column contributes to this count
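    # Worked example (editor's note): with 3 index_cols and 2 select_cols,
    # the expected count is (1 + 3) * (len(xds) * 2) = 8 * len(xds)
    # getcellslice calls.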


@pytest.mark.parametrize('group_cols', [
["SCAN_NUMBER"]], ids=group_cols_str)
@pytest.mark.parametrize('index_cols', [
["TIME", "ANTENNA1", "ANTENNA2"]], ids=index_cols_str)
@pytest.mark.parametrize('select_cols', [
['DATA', 'STATE_ID']])
def test_writelock_profiling(ms, group_cols, index_cols, select_cols):
# Zero everything to be sure
with TableProxy(pt.table, ms, readonly=False,
lockoptions='auto', ack=False) as T:
nrows = T.nrows().result()
assert 'nrows' in _function_runs.keys()
assert _function_runs['nrows'][2] == 1
# put a new column with zeros
T.putcol("STATE_ID", np.full(nrows, 0, dtype=np.int32)).result()
assert 'putcol' in _function_runs.keys()
assert _function_runs['putcol'][2] == 1
# get the DATA column and create 'data' variable
data = np.zeros_like(T.getcol("DATA").result())
data_dtype = data.dtype
assert 'getcol' in _function_runs.keys()
assert _function_runs['getcol'][2] == 1
getcol_runs = _function_runs['getcol'][2]
# put new data into DATA column
T.putcol("DATA", data).result()
assert _function_runs['putcol'][2] == 2

xds = xds_from_ms(ms, columns=select_cols,
group_cols=group_cols,
index_cols=index_cols,
chunks={"row": 2})
assert 'getcoldesc' in _function_runs.keys()
# 2 - __tablerows__ , __firstrow__
assert _function_runs['getcol'][2] == getcol_runs + 2 + len(group_cols)
assert _function_runs['getcoldesc'][2] == len(select_cols) * len(xds)

written_states = []
written_data = []
writes = []

    # Write out STATE_ID and DATA
for i, ds in enumerate(xds):
dims = ds.dims
chunks = ds.chunks
state = da.arange(i, i + dims["row"], chunks=chunks["row"])
state = state.astype(np.int32)
written_states.append(state)

data = da.arange(i, i + dims["row"]*dims["chan"]*dims["corr"])
data = data.reshape(dims["row"], dims["chan"], dims["corr"])
data = data.rechunk((chunks["row"], chunks["chan"], chunks["corr"]))
data = data.astype(data_dtype)
written_data.append(data)

nds = ds.assign(STATE_ID=(("row",), state),
DATA=(("row", "chan", "corr"), data))

write = xds_to_table(nds, ms, ["STATE_ID", "DATA"])
writes.append(write)

    assert 'colnames' in _function_runs and '_put_keywords' in _function_runs
assert _function_runs['colnames'][2] == len(xds)
assert _function_runs['_put_keywords'][2] == len(xds)
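    # Note (editor's): these two counters are already non-zero *before*
    # dask.compute, because _write_datasets submits _put_keywords eagerly
    # (with .result()) while the write graph is being constructed.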
# Do all writes in parallel
dask.compute(writes)

assert 'getcellslice' in _function_runs.keys()
print(xds[0].dims, xds[1].dims)
print(xds[0].chunks, xds[1].chunks)

# 0 : [ 5 rows, 1 __tablerow__ , 3 index_columns ] : 9
# 1 : [ 5 rows, 1 __tablerow__ , 3 index_columns ] : 9
# or
# 0 : [ 5 rows, 4 corr, 3 index_columns ] : 12
# 1 : [ 5 rows, 4 corr, 3 index_columns ] : 12
# read CASA getcellslice

    # assert _function_runs['getcellslice'][2] == len(index_cols) * len(select_cols) * len(xds)

xds = xds_from_ms(ms, columns=select_cols,
group_cols=group_cols,
index_cols=index_cols,
chunks={"row": 2})

# Check that state and data have been correctly written
it = enumerate(zip(xds, written_states, written_data))
for i, (ds, state, data) in it:
assert_array_equal(ds.STATE_ID.data, state)
assert_array_equal(ds.DATA.data, data)

# TODO: extend the _function_runs assertions to the rest of the test
# suite, since this PR only wraps existing functions rather than adding
# new ones.
5 changes: 4 additions & 1 deletion daskms/writes.py
@@ -483,6 +483,7 @@ def _write_datasets(table, table_proxy, datasets, columns, descriptor,
table_name = '::'.join((table_name, subtable)) if subtable else table_name
row_orders = []

# print("Table Keywords - _write_datasets", type(table_keywords), table_keywords)
# Put table and column keywords
table_proxy.submit(_put_keywords, WRITELOCK,
table_keywords, column_keywords).result()
@@ -602,7 +603,8 @@ def _write_datasets(table, table_proxy, datasets, columns, descriptor,


def _put_keywords(table, table_keywords, column_keywords):
    # print("Table Keywords - _put_keywords", type(table_keywords), table_keywords)
    if table_keywords is not None and \
            not all(item is None for item in table_keywords):
for k, v in table_keywords.items():
if v == DELKW:
table.removekeyword(k)
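            # Illustration (editor's note): table_keywords maps keyword
            # names to values, with the DELKW sentinel marking deletion,
            # e.g. {"MS_VERSION": 2.0, "OBSOLETE_KEY": DELKW}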
@@ -641,6 +643,7 @@ def write_datasets(table, datasets, columns, descriptor=None,
else:
tp = _updated_table(table, datasets, columns, descriptor)

# print("Table Keywords - write_datasets", type(table_keywords), table_keywords)
write_datasets = _write_datasets(table, tp, datasets, columns,
descriptor=descriptor,
table_keywords=table_keywords,