Skip to content

Commit

Permalink
Special case dask.array.broadcast_arrays which produces unncessary co…
Browse files Browse the repository at this point in the history
…mms in the general case
  • Loading branch information
sjperkins committed Apr 6, 2024
1 parent 350415c commit cec7723
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions daskms/experimental/katdal/msv2_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
################################################################################

from functools import partial
from operator import getitem

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -126,14 +127,7 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe

flags = DaskLazyIndexer(dataset.flags, (), (rechunk, flag_transpose))
weights = DaskLazyIndexer(dataset.weights, (), (rechunk, weight_transpose))
vis = DaskLazyIndexer(
dataset.vis,
(),
transforms=(
rechunk,
vis_transpose,
),
)
vis = DaskLazyIndexer(dataset.vis, (), (rechunk, vis_transpose))

time = da.from_array(time_mjds[:, None], chunks=(t_chunks, 1))
ant1 = da.from_array(cp_info.ant1_index[None, :], chunks=(1, cpi.shape[0]))
Expand All @@ -147,7 +141,32 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe
row=self._row_view,
)

time, ant1, ant2 = da.broadcast_arrays(time, ant1, ant2)
# Better graph than da.broadcast_arrays
bcast = da.blockwise(
np.broadcast_arrays,
("time", "bl"),
time,
("time", "bl"),
ant1,
("time", "bl"),
ant2,
("time", "bl"),
align_arrays=False,
adjust_chunks={"time": time.chunks[0], "bl": ant1.chunks[1]},
meta=np.empty((0,) * 2, dtype=np.int32),
)

time = da.blockwise(
getitem, ("time", "bl"), bcast, ("time", "bl"), 0, None, dtype=time.dtype
)

ant1 = da.blockwise(
getitem, ("time", "bl"), bcast, ("time", "bl"), 1, None, dtype=ant1.dtype
)

ant2 = da.blockwise(
getitem, ("time", "bl"), bcast, ("time", "bl"), 2, None, dtype=ant2.dtype
)

if self._row_view:
primary_dims = ("row",)
Expand Down

0 comments on commit cec7723

Please sign in to comment.