# Cubeviz base functionality with leaflet

1. Load a data cube **(DONE)**
2. Scrub through the channels, displaying 2D and 1D **(DONE)**
3. Median-collapse the cube **(FUTURE WORK)**
4. Stretch goal: Hover stats at cursor (x, y, counts, RA, Dec) **(DONE)**

In [None]:
# STDLIB
import io
import logging
import random
import threading
import warnings

# THIRD PARTY: SCIENTIFIC
import numpy as np
from astropy import visualization
from astropy.nddata import CCDData
from skimage import transform

# THIRD PARTY: VIZ
import flask
import ipyleaflet
import ipywidgets
from IPython.display import display
from matplotlib import pyplot as plt

Create a sample data cube or use an existing FITS file. The `ngc6946.fits` file (Subaru data) used here can be downloaded from https://stsci.box.com/s/tg6m48hccnmn8jjs3447e0yuillz56ld (ask P. L. Lim for Box folder access).

The random cube can be used to stress-test performance, while the FITS file shows how this works with proper WCS and other relevant metadata.

In [None]:
# Choose 'realworld' to read ngc6946.fits (real WCS and metadata); any
# other value builds a large random cube for stress-testing performance.
demo_type = 'realworld'

if demo_type != 'realworld':
    # NOTE: Tried np.arange but doesn't viz well.
    np.random.seed(1234)
    cube = np.random.random(10000000).reshape((1000, 100, 100))
    wcs = None
    xlbl_1d = 'index'
    ylbl_1d = flux_unit = 'count'
else:
    from astropy.io import fits
    from astropy.wcs import WCS

    # Silence any warnings emitted while opening the file and building the WCS.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        with fits.open('ngc6946.fits') as pf:
            cube = pf[0].data
            hdr = pf[0].header
            wcs = WCS(hdr)
            ylbl_1d, xlbl_1d = hdr['BUNIT'], hdr['CTYPE3']
            flux_unit = 'Jy'

# Bundle data, unit and WCS; the viewer code reads c1.data and c1.wcs.
c1 = CCDData(cube, unit=flux_unit, wcs=wcs)

print('Loaded data for {}'.format(demo_type))
print('--- WCS ---', c1.wcs, sep='\n')

Some functions are copied from `html_view_into_leaflet.ipynb` and `Leaflet_image_viewer.ipynb`, originally written by Erik Tollerud.

In [None]:
"""Writing PNGs.

These were copied over from yt. No need to modify.

"""
#-----------------------------------------------------------------------------
# Copyright (c) 2013, yt Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------

import matplotlib._png as _png

from io import BytesIO as StringIO


def call_png_write_png(buffer, width, height, filename, dpi):
    """Encode *buffer* as PNG into *filename* using matplotlib's ``_png``.

    ``width`` and ``height`` are accepted for call compatibility only;
    they are not used.
    """
    png_writer = _png.write_png
    png_writer(buffer, filename, dpi)

    
def write_png(buffer, filename, dpi=100):
    """Write an image array to *filename* (a path or file-like object) as PNG.

    Height and width are taken from the first two axes of *buffer*.
    """
    height, width = buffer.shape[0], buffer.shape[1]
    return call_png_write_png(buffer, width, height, filename, dpi)

    
def write_png_to_string(buffer, dpi=100, gray=0):
    """Encode an image array as PNG and return the encoded bytes.

    ``gray`` is accepted but unused. ``StringIO`` here is an alias for
    ``io.BytesIO``, so the result is ``bytes``.
    """
    height, width = buffer.shape[0], buffer.shape[1]
    with StringIO() as sink:
        call_png_write_png(buffer, width, height, sink, dpi)
        encoded = sink.getvalue()
    return encoded

Set up the tile server. In the real world, a tile server serves pre-generated tiles at a given zoom (`z`) and tile position (`x` and `y`). In this demo, the tiles are instead generated on the fly by chopping the current cube slice into the appropriate sections.

In [None]:
# Stores the currently active stretched slice for display.
# This is used to generate tiles. Assigned by re_stretch(); it stays
# None until the first slice is displayed.
visdat = None

# Used for conversion from tile coordinates to pixels.
# Updated by get_subfits() whenever a tile is served, based on the zoom
# of that request; read by the hover/click handlers.
px_in_tiles = 0


# NOTE: Unlike html_view_into_leaflet.ipynb, there is extra
# arg named i_slice here for slicing the cube.
def re_stretch(stretch, i_slice):
    """Make slice ``i_slice`` of ``c1`` the active display image.

    The slice is passed through ``stretch`` and then flipped along its
    first axis, so the origin is consistent with typical astronomy viz
    tools. The result replaces the global ``visdat`` from which tiles
    are cut.

    Parameters
    ----------
    stretch : obj
        Stretching from `astropy.visualization`.

    i_slice : int
        Index of the cube slice desired.

    Raises
    ------
    IndexError
        Slice is out of bounds.

    """
    global visdat
    plane = c1.data[i_slice]
    stretched = stretch(plane)
    visdat = np.flip(stretched, 0)


def ccd_to_pngstr_app(dat):
    """Format given data into PNG for display.
    This is the final step in the display processing.

    Parameters
    ----------
    dat : array-like
        2D image whose display values are expected in [0, 1] (e.g. the
        output of an astropy stretch). NaNs, which are used to pad
        partial tiles, are rendered as 0; values outside [0, 1] are
        clipped.

    Returns
    -------
    bytes
        Single-channel 8-bit PNG data for display.

    """
    # Casting NaN or out-of-range floats to uint8 is undefined in NumPy,
    # so clip to [0, 1] and map NaN to 0 before scaling to 0..255.
    scaled = np.nan_to_num(np.clip(dat, 0, 1)) * 255
    return write_png_to_string(scaled.astype('uint8')[:, :, np.newaxis])


# Send log records handled by the root logger to an in-memory buffer
# instead of the notebook output. Presumably this keeps the Flask/werkzeug
# request logs out of the cell output; inspect with logstream.getvalue().
logstream = io.StringIO()
logging.basicConfig(stream=logstream)

# We use Flask to serve as tile server.
app = flask.Flask(__name__)


# Adapted from Leaflet_image_viewer.ipynb
# TODO: How to get rid of tile borders?
@app.route('/fits<int:cachebuster>/<string:z>/<int:x>/<int:y>.png')
def get_subfits(z, x, y, cachebuster):
    """Generate a tile.
    See https://en.wikipedia.org/wiki/Tiled_web_map for
    more info.

    Tiles are cut on the fly from the global ``visdat`` (the currently
    stretched slice). For ``z >= 1``, the longer image dimension is split
    into ``z`` tiles of equal pixel width, and each tile is resized to
    256 x 256. Partial tiles at the image edge are padded with NaN. As a
    side effect, the global ``px_in_tiles`` is updated for the hover
    handlers.

    Parameters
    ----------
    z : str
        Zoom, as captured from the URL; converted to int here.

    x, y : int
        Tile position.

    cachebuster : int
        Changed by the client to invalidate cached tiles; it does not
        affect the tile content.

    Returns
    -------
    png
        PNG data for display.

    Raises
    ------
    werkzeug.exceptions.NotFound
        Raised via ``flask.abort(404)`` when no slice has been displayed
        yet or when the tile lies entirely outside the image.

    """
    global px_in_tiles

    z = int(z)
    tile_size = 256

    # The map can request tiles before the first slice is stretched.
    if visdat is None:
        flask.abort(404)

    max_dim = max(visdat.shape)

    if z < 0:  # TODO: Not used? The viewer maps are created with min_zoom=1.
        factor = 2 ** -z
        wid = tile_size * factor
        xrng = slice(x * wid, (x + 1) * wid, factor)
        yrng = slice(y * wid, (y + 1) * wid, factor)

        subdat = visdat[yrng, xrng]
        px_in_tiles = wid  # Untested

        # Tile lies entirely outside the image.
        if 0 in subdat.shape:
            flask.abort(404)

        # The strided slice holds at most tile_size x tile_size pixels;
        # pad partial tiles out to that size with NaNs.
        if subdat.shape != (tile_size, tile_size):
            padded = np.full((tile_size, tile_size), np.nan, dtype=subdat.dtype)
            padded[:subdat.shape[0], :subdat.shape[1]] = subdat
            subdat = padded

        return ccd_to_pngstr_app(subdat)

    # TODO: Optimize?
    # Number of image pixels covered by each of the z tiles along the
    # longer dimension, rounded up so the whole image fits.
    # NOTE: z == 0 would divide by zero here.
    wid = max_dim // z
    if wid * z < max_dim:
        wid += 1

    px_in_tiles = wid * z

    ix1 = x * wid
    iy1 = y * wid

    # Tile lies entirely outside the image.
    if ix1 >= visdat.shape[1] or iy1 >= visdat.shape[0]:
        flask.abort(404)

    # Clip the tile's far edges to the image bounds.
    ix2 = min((x + 1) * wid, visdat.shape[1])
    iy2 = min((y + 1) * wid, visdat.shape[0])

    subdat = np.full((wid, wid), np.nan, dtype=visdat.dtype)
    subdat[:(iy2 - iy1), :(ix2 - ix1)] = visdat[iy1:iy2, ix1:ix2]

    # Prepare tile: scale the wid x wid cutout to the display tile size.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        subdat = transform.resize(
            subdat, (tile_size, tile_size),
            mode='constant', cval=np.nan, anti_aliasing=False)

    return ccd_to_pngstr_app(subdat)

Set your favorite display stretch and cuts.

In [None]:
# Display stretch: the 95% percentile interval (cuts) followed by a log
# stretch. Per astropy's `+` composition rule, the right operand is
# applied first.
fav_stretch = visualization.LogStretch() + visualization.PercentileInterval(95)

**TODO: Do users need to worry about server warning from `th.start()`?**

In [None]:
# Run the Flask tile server in a background thread so the notebook stays
# usable. daemon=True: a non-daemon thread blocked in app.run() would keep
# the Python process from exiting when the kernel shuts down.
th = threading.Thread(
    target=app.run,
    kwargs=dict(debug=False, use_reloader=False, port=5013),
    daemon=True)
th.start()

Create an empty `ipyleaflet` viewer. The cube slices will be displayed here later.

**TODO: How to make display more user friendly? Will this work with voila?**

**TODO: It looks like ipyleaflet is meant solely for Earth map views; is using it like this too hacky?**

**TODO: Theoretically, a subset of this code can be repurposed to display a median-collapsed cube in the same way.**

In [None]:
# Tile URL template matching the Flask route. 'NUM' is a placeholder for
# the cache-busting integer; {z}/{x}/{y} are filled in per tile request.
url_templ = 'http://127.0.0.1:5013/fitsNUM/{z}/{x}/{y}.png'

# Image viewer using the flat 'Simple' CRS. It starts with no layers; the
# FITS tile layer is added later in this cell.
m = ipyleaflet.Map(crs='Simple', center=(0, 0), zoom=1, min_zoom=1,
                   layers=[], scroll_wheel_zoom=True)

# NOTE: Need to put this here for 1D plotting to update.
# NOTE: In-place plot update does not work with inline.
# TODO: Does not work in Jupyter Lab
%matplotlib notebook

# For 1D plotting
fig, ax = plt.subplots()

# https://github.com/jupyter-widgets/ipyleaflet/issues/332
# Hover info (value, X, Y and RA/Dec when available) is written here.
lbl = ipywidgets.Label()
display(lbl)

# For debugging
lbl2 = ipywidgets.Label()
display(lbl2)

# Event names taken from
# https://github.com/jupyter-widgets/ipyleaflet/blob/5f27207ac7a3f2a08f45181613e9ed9ab37eb759/ipyleaflet/leaflet.py#L100
def handle_interaction(**kwargs):
    """Handle mouse events on the cube map ``m``.

    Registered via ``m.on_interaction``. Only the ``type`` and
    ``coordinates`` (lat, lng) entries of the event kwargs are used.

    * Always: write debug info (map coords, zoom, tile geometry) to ``lbl2``.
    * 'mousemove': write the value under the cursor, its image X/Y and,
      when the data has a WCS, RA/Dec to ``lbl``.
    * 'click': plot the 1D data through the clicked pixel on ``ax`` and
      mark the current slice with a red dot.

    Only the debug info is updated until a slice has been displayed
    (``visdat`` is None).
    """
    event_type = kwargs.get('type')
    coo = kwargs.get('coordinates')

    # Display tile size. Use true division: 256 // 2 ** zoom is 0 for
    # zoom > 8, and the pixel conversions below divide by this value.
    disp_size = m.zoom * (256 / (2 ** m.zoom))

    # Debug info.
    lbl2.value = 'LEAF_X: {:.2f}, LEAF_Y: {:.2f}, Zoom: {}, px_in_tiles: {}, disp_size: {}'.format(
        coo[1], coo[0], m.zoom, px_in_tiles, disp_size)

    if visdat is None:
        return

    # Rough map coord to X, Y translation
    img_y = (coo[0] + disp_size - (px_in_tiles - visdat.shape[0])) * px_in_tiles / disp_size
    img_x = coo[1] * px_in_tiles / disp_size

    # Round down to the containing pixel. int() truncates toward zero,
    # which would wrongly map coordinates in (-1, 0) onto pixel 0.
    iy = int(np.floor(img_y))
    ix = int(np.floor(img_x))
    iz = slider_obj.value

    within_bounds = (0 <= ix < visdat.shape[1] and
                     0 <= iy < visdat.shape[0])

    # Grab pixel value from translated X, Y
    img_val = cube[iz, iy, ix] if within_bounds else np.nan

    # TODO: Can we make this more robust?
    # See https://leafletjs.com/examples/crs-simple/crs-simple.html
    if event_type == 'mousemove':
        if c1.wcs is None:
            radec_str = ''
        else:
            skycoord = c1.wcs.pixel_to_world(img_x, img_y, iz)[0]
            radec_str = ' (RA: {}, Dec: {})'.format(str(skycoord.ra), str(skycoord.dec))

        # Hover info.
        lbl.value = 'Value: {:.4e}, X: {:.2f}, Y: {:.2f} '.format(
            img_val, img_x, img_y) + radec_str

    elif event_type == 'click':
        # Matplotlib for 1D
        ax.clear()
        if within_bounds:
            ax.plot(cube[:, iy, ix])
            ax.plot(iz, img_val, 'ro')
            ax.set_ylabel(ylbl_1d)
            ax.set_xlabel(xlbl_1d)
        # Redraw even for an off-image click so the cleared axes are shown.
        fig.canvas.draw_idle()


# NOTE: Random int here is how they do cache invalidation...
# The integer replaces 'NUM' in the tile URL. Changing it (see refresh())
# changes every tile URL, so previously fetched tiles are not reused.
cachebuster_int = random.randint(0, 1000000)
local_fits_layer = ipyleaflet.basemap_to_tiles(
    {'url': url_templ.replace('NUM', str(cachebuster_int)), 
     'attribution': 'fitsfile'})
# Keep the current counter on the layer object so refresh() can bump it.
local_fits_layer.cachebuster_int = cachebuster_int


def refresh():
    """Invalidate tile caching.

    Bumps the layer's cache-buster counter and rebuilds its URL from
    ``url_templ`` with the new value, so fresh tiles are requested.
    """
    next_id = local_fits_layer.cachebuster_int + 1
    local_fits_layer.cachebuster_int = next_id
    local_fits_layer.url = url_templ.replace('NUM', str(next_id))
    
    
def rere_stretch(stretch, i_slice):
    """Display given slice.

    Re-stretches cube slice ``i_slice`` into the display buffer and then
    changes the tile URL so the map fetches new tiles.
    """
    re_stretch(stretch, i_slice)  # new pixels for the tile server to cut
    refresh()                     # new URL, so cached tiles are not reused

    
def change_slice(x):
    """Slider callback: show cube slice ``x`` using ``fav_stretch``.

    The parameter must stay named ``x``: ``ipywidgets.interact`` binds
    the slider to it by keyword.
    """
    chosen_slice = x
    rere_stretch(fav_stretch, chosen_slice)

    
# Bind map to FITS layer and attach event handling.
m.add_layer(local_fits_layer)
m.on_interaction(handle_interaction)

# Display map. Image tiles only become available once a slice has been
# stretched into `visdat` (done through the slider in the next cell).
m

The following hooks up the slider to display chosen slice in the cube on the viewer above.

In [None]:
# Slider over the cube's first axis (slice index). interact() calls
# change_slice with x=<slider value>; it also calls it once on creation,
# which displays the first slice.
slider_obj = ipywidgets.IntSlider(min=0, max=cube.shape[0]-1, step=1, value=0)
interact_obj = ipywidgets.interact(change_slice, x=slider_obj)

**Meat of the demo: Zoom in/out. Mouse over the display to see the info update. Move the slider. Then click on a point of interest to see the 1D plot above. Repeat as you wish.**

But what about that median-collapsed cube?

In [None]:
# Per-pixel median over the first (slice) axis, giving a 2D image.
collapsed_cube = np.median(cube, axis=0)
print(collapsed_cube.shape)

**TODO: The cell below does not work yet. We need to figure out how to have a tile server for the collapsed cube that is separate from the one for the cube above. This can be investigated in a future sprint, if desired.**

In [None]:
# Stretched, flipped collapsed image; assigned by rere_stretch_m2().
visdat_m2 = None

# NOTE: This is the same template, and hence the same Flask route, as the
# cube map. Tiles for m2 are therefore cut from `visdat` (the current cube
# slice), not from `visdat_m2`; this is the open TODO for this cell.
url_templ = 'http://127.0.0.1:5013/fitsNUM/{z}/{x}/{y}.png'

m2 = ipyleaflet.Map(crs='Simple', center=(0, 0), zoom=1, min_zoom=1,
                    layers=[], scroll_wheel_zoom=True)


# https://github.com/jupyter-widgets/ipyleaflet/issues/332
# Hover info for the collapsed image.
m2lbl = ipywidgets.Label()
display(m2lbl)

# For debugging
m2lbl2 = ipywidgets.Label()
display(m2lbl2)

# Event names taken from
# https://github.com/jupyter-widgets/ipyleaflet/blob/5f27207ac7a3f2a08f45181613e9ed9ab37eb759/ipyleaflet/leaflet.py#L100
def handle_interaction_m2(**kwargs):
    """Handle mouse events on the collapsed-cube map ``m2``.

    Registered via ``m2.on_interaction``. Writes debug info to ``m2lbl2``
    on every event. On 'mousemove', it also writes the collapsed value
    under the cursor, its image X/Y and (when a WCS exists) RA/Dec to
    ``m2lbl``. Only the debug info is updated while ``visdat_m2`` is None.
    """
    event_type = kwargs.get('type')
    coo = kwargs.get('coordinates')

    # Display tile size. Use true division: 256 // 2 ** zoom is 0 for
    # zoom > 8, and the pixel conversions below divide by this value.
    disp_size = m2.zoom * (256 / (2 ** m2.zoom))

    # Debug info.
    m2lbl2.value = 'LEAF_X: {:.2f}, LEAF_Y: {:.2f}, Zoom: {}, px_in_tiles: {}, disp_size: {}'.format(
        coo[1], coo[0], m2.zoom, px_in_tiles, disp_size)

    if visdat_m2 is None:
        return

    # Rough map coord to X, Y translation
    img_y = (coo[0] + disp_size - (px_in_tiles - collapsed_cube.shape[0])) * px_in_tiles / disp_size
    img_x = coo[1] * px_in_tiles / disp_size

    # Round down to the containing pixel. int() truncates toward zero,
    # which would wrongly map coordinates in (-1, 0) onto pixel 0.
    iy = int(np.floor(img_y))
    ix = int(np.floor(img_x))

    within_bounds = (0 <= ix < collapsed_cube.shape[1] and
                     0 <= iy < collapsed_cube.shape[0])

    # Grab pixel value from translated X, Y
    img_val = collapsed_cube[iy, ix] if within_bounds else np.nan

    # TODO: Can we make this more robust?
    # See https://leafletjs.com/examples/crs-simple/crs-simple.html
    if event_type == 'mousemove':
        if c1.wcs is None:
            radec_str = ''
        else:
            # Assume FOV the same across cube.
            skycoord = c1.wcs.pixel_to_world(img_x, img_y, 0)[0]
            radec_str = ' (RA: {}, Dec: {})'.format(str(skycoord.ra), str(skycoord.dec))

        # Hover info.
        m2lbl.value = 'Value: {:.4e}, X: {:.2f}, Y: {:.2f} '.format(
            img_val, img_x, img_y) + radec_str


# NOTE: Random int here is how they do cache invalidation...
# The integer replaces 'NUM' in the tile URL; refresh_m2() bumps it so
# the tile URLs change and previously fetched tiles are not reused.
cachebuster_int = random.randint(0, 1000000)
local_fits_layer_m2 = ipyleaflet.basemap_to_tiles(
    {'url': url_templ.replace('NUM', str(cachebuster_int)), 
     'attribution': 'fitsfile'})
# Keep the current counter on the layer object for refresh_m2().
local_fits_layer_m2.cachebuster_int = cachebuster_int


def refresh_m2():
    """Invalidate tile caching.

    Bumps the second layer's cache-buster counter and rebuilds its URL
    from ``url_templ`` with the new value, so fresh tiles are requested.
    """
    next_id = local_fits_layer_m2.cachebuster_int + 1
    local_fits_layer_m2.cachebuster_int = next_id
    local_fits_layer_m2.url = url_templ.replace('NUM', str(next_id))
    
    
def rere_stretch_m2(stretch):
    """Display the median-collapsed image on the second map.

    Stretches ``collapsed_cube``, flips it along its first axis (the same
    orientation convention as ``re_stretch``), stores the result in the
    global ``visdat_m2`` and invalidates the second layer's tile cache.

    NOTE: The tile route still cuts tiles from ``visdat``, not
    ``visdat_m2``, so the tile images themselves do not change yet.

    Parameters
    ----------
    stretch : obj
        Stretching from `astropy.visualization`.
    """
    global visdat_m2
    visdat_m2 = np.flip(stretch(collapsed_cube), 0)
    refresh_m2()

    
# Bind map to FITS layer and attach event handling.
# NOTE: rere_stretch_m2() is not called in this cell, so visdat_m2 stays
# None and handle_interaction_m2 only updates the debug label.
m2.add_layer(local_fits_layer_m2)
m2.on_interaction(handle_interaction_m2)

# Display map.
m2

End of prototype demo.