# Dask bag-based imaging demonstration

This notebook explores the use of dask bags for parallelisation. For the most part we work with the bags directly. Much of this can be hidden in standard functions.

See imaging-dask notebook for processing with dask delayed

We create the visibility and fill in values with the transform of a number of point sources. 

In [1]:
% matplotlib inline

import os
import sys

from dask import bag
from distributed import Client

results_dir = './results'
os.makedirs(results_dir, exist_ok=True)

from matplotlib import pylab

pylab.rcParams['figure.figsize'] = (12.0, 12.0)
pylab.rcParams['image.cmap'] = 'rainbow'

import numpy

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs.utils import pixel_to_skycoord

from matplotlib import pyplot as plt

from arl.calibration.operations import apply_gaintable, create_gaintable_from_blockvisibility
from arl.calibration.solvers import solve_gaintable
from arl.data.polarisation import PolarisationFrame
from arl.visibility.base import create_visibility, create_blockvisibility
from arl.visibility.operations import concatenate_visibility, subtract_visibility
from arl.visibility.coalesce import decoalesce_visibility
from arl.skycomponent.operations import create_skycomponent
from arl.image.operations import show_image, qa_image, pad_image
from arl.image.deconvolution import deconvolve_cube, restore_cube
from arl.util.testing_support import create_named_configuration, create_test_image, simulate_gaintable
from arl.imaging import create_image_from_visibility, advise_wide_field
from arl.imaging.wstack import predict_wstack_single, invert_wstack_single
from arl.visibility.gather_scatter import visibility_gather_w, visibility_scatter_w
from arl.imaging.weighting import weight_visibility
from arl.graphs.dask_init import get_dask_Client
from arl.graphs.bags import safe_invert_list, safe_predict_list, sum_invert_bag_results
from arl.visibility.coalesce import decoalesce_visibility
from arl.visibility.coalesce import coalesce_visibility

import logging

log = logging.getLogger()
log.setLevel(logging.INFO)
log.addHandler(logging.StreamHandler(sys.stdout))


Define a function to create the visibilities

In [2]:
def ingest_visibility(freq=1e8, chan_width=1e6, reffrequency=[1e8], npixel=512,
                      init=False, block=True, add_errors=False):
    lowcore = create_named_configuration('LOWBD2-CORE')
    times = numpy.linspace(-numpy.pi / 4, numpy.pi / 4, 7)
    frequency = numpy.array([freq])
    channel_bandwidth = numpy.array([chan_width])

    phasecentre = SkyCoord(
        ra=+15.0 * u.deg, dec=-26.7 * u.deg, frame='icrs', equinox='J2000')
    if block:
        vt = create_blockvisibility(
            lowcore,
            times,
            frequency,
            channel_bandwidth=channel_bandwidth,
            weight=1.0,
            phasecentre=phasecentre,
            polarisation_frame=PolarisationFrame("stokesI"))
    else:
        vt = create_visibility(
            lowcore,
            times,
            frequency,
            channel_bandwidth=channel_bandwidth,
            weight=1.0,
            phasecentre=phasecentre,
            polarisation_frame=PolarisationFrame("stokesI"))
    return vt

In [51]:
import pprint
pp=pprint.PrettyPrinter()

In [124]:
nfreqwin=5
vis_dict = [{'freqwin':f, 
            'vis':[ingest_visibility(freq, init=True, block=True, add_errors=True)]}
            for f, freq in  enumerate(numpy.linspace(0.8e8,1.2e8,nfreqwin))]

vis_list = [{'freqwin':f,
             'vis':ingest_visibility(freq, init=True, block=True, add_errors=True)}
            for f, freq in  enumerate(numpy.linspace(0.8e8,1.2e8,nfreqwin))]

vis_bag = bag.from_sequence(vis_list)
pp.pprint(list(vis_bag))

create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
[{'freqwin': 0,
  'vis': <arl.data.data_models.BlockVisibility object at 0x11aa78860>},
 {'freqwin': 1,
  'vis': <arl.data.data_models.BlockVisibility object at 0x112ccf748>},
 {'freqwin': 2,
  'vis': <arl.data.data_models.BlockVisibility object at 0x11a981be0>},
 {'freqwin': 3,
  'vis': <arl.data.data_models.BlockVisibility object at 0x11aa73fd0>},
 {'freqwin': 4,
  'vis': <arl.data.data_models.BlockVisibility object at 0x11a981ef0>}]


In [119]:
pp.pprint(list(vis_bag.map(lambda x: visibility_scatter_time(x[0]))))

tornado.application - ERROR - Future <tornado.concurrent.Future object at 0x11963c668> exception was never retrieved: Traceback (most recent call last):
  File "/Users/timcornwell/anaconda/envs/arlenv/lib/python3.6/site-packages/tornado/gen.py", line 1069, in run
    yielded = self.gen.send(value)
  File "/Users/timcornwell/anaconda/envs/arlenv/lib/python3.6/site-packages/distributed/client.py", line 1269, in wait
    raise AllExit()
distributed.client.AllExit
tornado.application - ERROR - Future <tornado.concurrent.Future object at 0x11961e1d0> exception was never retrieved: Traceback (most recent call last):
  File "/Users/timcornwell/anaconda/envs/arlenv/lib/python3.6/site-packages/tornado/gen.py", line 1069, in run
    yielded = self.gen.send(value)
  File "/Users/timcornwell/anaconda/envs/arlenv/lib/python3.6/site-packages/distributed/client.py", line 1269, in wait
    raise AllExit()
distributed.client.AllExit
tornado.application - ERROR - Future <tornado.concurrent.Future object

tornado.application - ERROR - Future <tornado.concurrent.Future object at 0x11606bba8> exception was never retrieved: Traceback (most recent call last):
  File "/Users/timcornwell/anaconda/envs/arlenv/lib/python3.6/site-packages/tornado/gen.py", line 1069, in run
    yielded = self.gen.send(value)
  File "/Users/timcornwell/anaconda/envs/arlenv/lib/python3.6/site-packages/distributed/client.py", line 1269, in wait
    raise AllExit()
distributed.client.AllExit
tornado.application - ERROR - Future <tornado.concurrent.Future object at 0x11b153cf8> exception was never retrieved: Traceback (most recent call last):
  File "/Users/timcornwell/anaconda/envs/arlenv/lib/python3.6/site-packages/tornado/gen.py", line 1069, in run
    yielded = self.gen.send(value)
  File "/Users/timcornwell/anaconda/envs/arlenv/lib/python3.6/site-packages/distributed/client.py", line 1269, in wait
    raise AllExit()
distributed.client.AllExit


KeyError: 0

In [110]:
pp.pprint(vis_bag.to_delayed())

[Delayed(('from_sequence-01a5b345bb68c976943e3f89ab38ac9c', 0)),
 Delayed(('from_sequence-01a5b345bb68c976943e3f89ab38ac9c', 1)),
 Delayed(('from_sequence-01a5b345bb68c976943e3f89ab38ac9c', 2)),
 Delayed(('from_sequence-01a5b345bb68c976943e3f89ab38ac9c', 3)),
 Delayed(('from_sequence-01a5b345bb68c976943e3f89ab38ac9c', 4))]


In [67]:
vis_bag.map()

('freqwin', 1, <arl.data.data_models.BlockVisibility object at 0x116112780>)


Now make seven of these spanning 800MHz to 1200MHz and put them into a Dask bag.

In [86]:
vis_bag=bag.from_sequence([ingest_visibility(freq, init=True, block=True, add_errors=True) 
                           for f, freq in enumerate(numpy.linspace(0.8e8,1.2e8,nfreqwin))])
print(vis_bag)

create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
dask.bag<from_se..., npartitions=5>


In [94]:
from arl.visibility.gather_scatter import visibility_scatter_time, visibility_scatter_w
list(vis_bag.map(visibility_scatter_time, timeslice='auto'))

[[<arl.data.data_models.BlockVisibility at 0x115f281d0>,
  <arl.data.data_models.BlockVisibility at 0x115f2c198>,
  <arl.data.data_models.BlockVisibility at 0x115f2c1d0>,
  <arl.data.data_models.BlockVisibility at 0x115f2c208>,
  <arl.data.data_models.BlockVisibility at 0x115f2c240>,
  <arl.data.data_models.BlockVisibility at 0x115f2c278>,
  <arl.data.data_models.BlockVisibility at 0x115f2c2b0>],
 [<arl.data.data_models.BlockVisibility at 0x115f28208>,
  <arl.data.data_models.BlockVisibility at 0x115f362e8>,
  <arl.data.data_models.BlockVisibility at 0x115f36320>,
  <arl.data.data_models.BlockVisibility at 0x115f36358>,
  <arl.data.data_models.BlockVisibility at 0x115f36390>,
  <arl.data.data_models.BlockVisibility at 0x115f363c8>,
  <arl.data.data_models.BlockVisibility at 0x115f36400>],
 [<arl.data.data_models.BlockVisibility at 0x115f285f8>,
  <arl.data.data_models.BlockVisibility at 0x115f3d438>,
  <arl.data.data_models.BlockVisibility at 0x115f3d470>,
  <arl.data.data_models.Block

We need to compute the bag in order to use it. First we just need a representative data set to calculate imaging parameters.

In [4]:
npixel=256
facets=4
def get_LSM(vt, cellsize=0.001, reffrequency=[1e8], npixel=512):
    model = pad_image(create_test_image(vt, cellsize=cellsize, frequency=reffrequency, 
                                        phasecentre=vt.phasecentre,
                                        polarisation_frame=PolarisationFrame("stokesI")),
                                        shape=[1, 1, 512, 512])
    return model

vis_bag = list(vis_bag)
model = get_LSM(vis_bag[0])
advice=advise_wide_field(vis_bag[0], guard_band_image=4.0)
vis_slices=101

replicate_image: replicating shape (256, 256) to (1, 1, 256, 256)
advise_wide_field: Maximum wavelength 3.747 (meters)
advise_wide_field: Minimum wavelength 3.747 (meters)
advise_wide_field: Maximum baseline 210.1 (wavelengths)
advise_wide_field: Station/antenna diameter 35.0 (meters)
advise_wide_field: Primary beam 0.107069 (rad) 6.135 (deg)
advise_wide_field: Image field of view 0.428275 (rad) 24.538 (deg)
advise_wide_field: Synthesized beam 0.004759 (rad) 0.273 (deg)
advise_wide_field: Cellsize 0.001586 (rad) 0.091 (deg)
advice_wide_field: Npixels per side = 270
advice_wide_field: Npixels (power of 2, 3) per side = 384
advice_wide_field: W sampling for full image = 0.3 (wavelengths)
advice_wide_field: W sampling for primary beam = 5.6 (wavelengths)
advice_wide_field: Time sampling for full image = 45.4 (s)
advice_wide_field: Time sampling for primary beam = 726.9 (s)
advice_wide_field: Frequency sampling for full image = 42066.2 (Hz)
advice_wide_field: Frequency sampling for primary

In [5]:
client=get_dask_Client()

Creating LocalCluster and Dask Client
<Client: scheduler='tcp://127.0.0.1:52977' processes=8 cores=8>
Diagnostic pages available on port http://127.0.0.1:8787


Now we can set up the prediction of the visibility from the model. We scatter over w and then apply the wstack for a single w plane. Then we concatenate the visibilities back together and convert back to a block visibility.

To save recomputing this, we compute it now and place it into another bag of the same name.

In [10]:
empty_vis_bag=bag.from_sequence([ingest_visibility(freq, init=True, block=True, add_errors=True) 
                           for freq in numpy.linspace(0.8e8,1.2e8,nfreqwin)])

model_bag = empty_vis_bag.map(get_LSM)

gt_bag = empty_vis_bag.map(create_gaintable_from_blockvisibility)\
    .map(simulate_gaintable, phase_error=1.0)

create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB
create_blockvisibility: 7 rows, 0.009 GB


In [13]:
vis_bag = empty_vis_bag\
    .map(visibility_scatter_w, vis_slices=vis_slices)\
    .map(safe_predict_list, model, predict=predict_wstack_single)\
    .map(concatenate_visibility)\
    .map(decoalesce_visibility, empty_vis_bag)\
    .map(apply_gaintable, gt_bag)

future=client.compute(vis_bag)
print(future)
vis_bag=bag.from_sequence(future.result())

<Future: status: pending, key: finalize-57cf508b614f9383ab17b0118f24aa44>


Check out the visibility function. To get the result out of the bag, we do need to compute it but this time it's just a lookup.

In [None]:
coalesced_vis_bag=vis_bag.map(coalesce_visibility)
vt = coalesced_vis_bag.compute()[0]

# To check that we got the prediction right, plot the amplitude of the visibility.
uvdist=numpy.sqrt(vt.data['uvw'][:,0]**2+vt.data['uvw'][:,1]**2)
plt.clf()
plt.plot(uvdist, numpy.abs(vt.data['vis']), '.')
plt.xlabel('uvdist')
plt.ylabel('Amp Visibility')
plt.show()

In [None]:
from arl.graphs.bags import add_invert_results 
dirty_bag=coalesced_vis_bag\
    .map(visibility_scatter_w, vis_slices=vis_slices)\
    .map(safe_invert_list, model, invert_wstack_single, dopsf=False, normalize=True)\
    .map(sum_invert_bag_results)
    
dirty_bag.visualize('imaging-bags-dirty.svg')

psf_bag=coalesced_vis_bag\
    .map(visibility_scatter_w, vis_slices=vis_slices)\
    .map(safe_invert_list, model, invert_wstack_single, dopsf=True, normalize=True)\
    .map(sum_invert_bag_results)

    
dirty_bag=bag.from_sequence(dirty_bag.compute())
psf_bag=bag.from_sequence(psf_bag.compute())

Now we can make the dirty images. As before we will scatter each of the 7 frequency windows (patitions) over w, giving a 2 level nested structure. We make a separate image for each frequency window. The image resolution noticeably improves for the high frequencies.

We deconvolve the one image and then replicate it for the subsequent selfcalibration

In [None]:
comp_bag=dirty_bag.take(1, compute=False)\
    .map(deconvolve_cube, psf_bag.take(1, compute=False), niter=1000, threshold=0.3, fractional_threshold=0.1, 
         window_shape='quarter', gain=0.7, scales=[0, 3, 10, 30])

comp = comp_bag.compute()[0][0]
print(comp)
fig=show_image(comp)

comp_bag=bag.from_sequence(nfreqwin*[comp])

Next we can setup the self-calibration step. Note that we have to decoalesce the model visibility.

In [None]:
model_vis_bag=coalesced_vis_bag\
    .map(visibility_scatter_w, vis_slices=vis_slices)\
    .map(safe_predict_list, comp_bag, predict=predict_wstack_single)\
    .map(concatenate_visibility)
    
# Decoalesce the model visibility
block_model_vis_bag = model_vis_bag.map(decoalesce_visibility, empty_vis_bag)

gt_bag = vis_bag\
    .map(solve_gaintable, block_model_vis_bag, phase_only=True)

Correct the visibilities for the gains. Freeze them in.

In [None]:
coalesced_corrected_vis_bag = bag.from_sequence(vis_bag.map(apply_gaintable, gt_bag, inverse=True)\
                                                .map(coalesce_visibility).compute())

Remake the dirty image and deconvolve again

In [None]:
dirty_bag=coalesced_corrected_vis_bag\
    .map(visibility_scatter_w, vis_slices=vis_slices)\
    .map(safe_invert_list, model, invert_wstack_single, dopsf=False, normalize=True)\
    .map(sum_invert_bag_results)\
    .fold(add_invert_results)
dirty_bag=bag.from_sequence(dirty_bag.compute())
    
comp_bag=dirty_bag.take(1, compute=False)\
    .map(deconvolve_cube, psf_bag.take(1, compute=False), niter=1000, threshold=0.3, fractional_threshold=0.1, 
         window_shape='quarter', gain=0.7, scales=[0, 3, 10, 30])

comp = comp_bag.compute()[0][0]

comp_bag=bag.from_sequence(nfreqwin*[comp])


Now we can calculate the corrected and residual visibility. 

In [None]:
model_vis_bag=coalesced_corrected_vis_bag\
    .map(visibility_scatter_w, vis_slices=101)\
    .map(safe_predict_list, comp_bag, predict=predict_wstack_single)\
    .map(concatenate_visibility)
    
ovt = coalesced_corrected_vis_bag.compute()[0]

res_vis_bag = coalesced_corrected_vis_bag.map(subtract_visibility, model_vis_bag)    
vt = res_vis_bag.compute()[0]

uvdist=numpy.sqrt(vt.data['uvw'][:,0]**2+vt.data['uvw'][:,1]**2)
plt.clf()
plt.plot(uvdist, numpy.abs(ovt.data['vis']), '.', color='b')
plt.plot(uvdist, numpy.abs(vt.data['vis']), '.', color='r')
plt.xlabel('uv distance (wavelengths)')
plt.ylabel('Amp Visibility')
plt.show()

In [None]:
res_image_bag=residual_vis_bag\
    .map(visibility_scatter_w, vis_slices=11)\
    .map(safe_invert_list, model, invert_wstack_single, dopsf=False, normalize=True)\
    .map(sum_invert_bag_results)
    
res_image_bag=bag.from_sequence(res_image_bag.compute()[0])

Now we can restore the images

In [None]:
print(res_image_bag.take(1, compute=False).compute())
restore_bag = comp_bag.take(1, compute=False).map(restore_cube, 
                                   psf_bag.take(1, compute=False), 
                                   res_image_bag.take(1, compute=False))

restored = restore_bag.compute()[0]
fig = show_image(restored, title='Restored image')
plt.show()