# Advanced `jdaviz` workflow: 
## Collapse spectral cubes in `Cubeviz` to produce pseudo-color images in `Imviz`

In [None]:
import logging
import tempfile
from functools import wraps

import numpy as np
from glue.core.roi import XRangeROI

import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.table import QTable
from astropy.time import Time
from astropy.nddata import NDDataArray
from astropy.utils.masked import Masked
from astropy.wcs import WCS

from astroquery.jplhorizons import Horizons
from astroquery.mast import Observations

from regions import PixCoord, CirclePixelRegion
from specutils import Spectrum1D, SpectralRegion

import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

from jdaviz import Cubeviz, Imviz

In [None]:
# Directory holding the spectral cubes. Alternatives:
#   data_dir = tempfile.gettempdir()
#   data_dir = './moons_planets/'      # relative to the working directory
data_dir = '/home/shared/preloaded-fits/jdaviz_data/moons_planets/'

# Spectral cube (s3d) to visualize. To look at another target, use one of:
#   "jw01248-c1001_t001_miri_ch4-shortmediumlong_s3d.fits"  -> Uranus
#   "jw01373-o002_t023_miri_ch2-shortmediumlong_s3d.fits"   -> Jupiter
#   "jw01247-o766_t634_miri_ch4-shortmediumlong_s3d.fits"   -> Saturn
fn = "jw01373-o031_t007_miri_ch1-shortmediumlong_s3d.fits"  # Io

# To fetch a newer reduction of this product from MAST instead, run:
#   result = Observations.download_file(f"mast:JWST/product/{fn}", local_path=f'{data_dir}/{fn}')

Load the spectral cube into Cubeviz:

In [None]:
# Full path to the selected spectral cube:
cube_path = f'{data_dir}/{fn}'

# Create the Cubeviz app, load the cube, and display the app inline:
cubeviz = Cubeviz()
cubeviz.load_data(cube_path)
cubeviz.show()

Choose how many spectral subsets to make and a color for each, then get the spectral cube with masks for each spectral region:

In [None]:
# How many spectral subsets (and therefore colored image layers) to make:
n_subsets = 5

# Colormap that the per-subset colors are sampled from:
cmap = plt.cm.rainbow

# Sample the colormap at n_subsets evenly spaced points from one end to the
# other, and convert each RGBA row to a hex color string:
hex_colors = list(map(to_hex, cmap(np.linspace(0, 1, n_subsets))))

In [None]:
# The first entry in the Cubeviz data collection is used as the cube:
data = cubeviz.app.data_collection[0]
data_label = data.label
wavelength = data.get_object().wavelength

# Split the full wavelength range into n_subsets contiguous, equal-width bins:
subset_edges = np.linspace(wavelength.min(), wavelength.max(), n_subsets + 1)
subset_labels = [f"Subset {n + 1}" for n in range(n_subsets)]
# [lower, upper] limits of each bin in microns
# (assumes the spectrum viewer's x-axis is in microns -- confirm):
subset_bounds = [subset_edges[n:n + 2].to(u.um).value for n in range(n_subsets)]

spectrum_viewer = cubeviz.app.get_viewer('spectrum-viewer')

bandpasses = []
for subset_label, (lower, upper) in zip(subset_labels, subset_bounds):
    # Clear the active subset so this ROI creates a new subset instead of
    # editing the previous one -- the names in `subset_labels` assume this.
    cubeviz.app.session.edit_subset_mode.edit_subset = None
    spectrum_viewer.apply_roi(XRangeROI(lower, upper))

    # Masked copy of the cube restricted to this subset's spectral range:
    band = data.get_subset_object(subset_label, cls=NDDataArray)
    bandpasses.append(band)

Get the "celestial" (a.k.a. "spatial" or "non-spectral") component of the WCS:

In [None]:
# Keep only the celestial axes of the cube's WCS, dropping the spectral axis.
# NOTE(review): '_orig_spec' is a private jdaviz metadata entry, presumably
# the spectrum object parsed when the cube was loaded -- confirm it still
# exists in the installed jdaviz version.
wcs_celestial = data.meta['_orig_spec'].wcs.celestial

Collapse each masked spectral cube along the spectral axis to produce a 2D image as an `NDDataArray` with the celestial coordinates:

In [None]:
def collapse(band, force_wcs=wcs_celestial, dispersion_axis=None):
    """Sum a masked spectral cube along its spectral axis into a 2D image.

    Parameters
    ----------
    band : NDDataArray
        Cube for one spectral subset. Samples flagged in ``band.mask``
        (e.g. those outside the subset's range) are excluded from the sum.
    force_wcs : WCS
        WCS attached to the returned image. The collapsed array carries no
        coordinates of its own, so the celestial WCS is attached explicitly.
    dispersion_axis : int, optional
        Array axis to sum over. If omitted, ``data.meta['DISPAXIS']`` of the
        module-level ``data`` (the cube loaded into Cubeviz) is used, which
        matches the original behavior.

    Returns
    -------
    NDDataArray
        The collapsed image, with ``force_wcs`` as its WCS.
    """
    if dispersion_axis is None:
        # NOTE(review): DISPAXIS may describe the dispersion direction on the
        # detector rather than the spectral axis of this array. Confirm the
        # two agree for the instrument / jdaviz version in use, or pass
        # ``dispersion_axis`` explicitly.
        dispersion_axis = data.meta['DISPAXIS']

    # Wrap the values (with units) in a masked array so that samples flagged
    # by ``band.mask`` do not contribute to the sum:
    masked_quantity = Masked(band.data << band.unit, mask=band.mask)
    collapsed_image = np.ma.sum(masked_quantity, axis=dispersion_axis)

    # Force the celestial coordinates onto the collapsed image:
    return NDDataArray(collapsed_image, wcs=force_wcs)


collapsed_images = [collapse(band) for band in bandpasses]

Choose Imviz display settings so that the monochromatic layers combine into a neat pseudo-color image:

In [None]:
# Display settings shared by every image layer. Each layer also gets its own
# color from ``hex_colors``, sampled along the colormap in subset order
# (i.e. from the shortest- to the longest-wavelength subset).
defaults = dict(
    # Stretch limits: zero up to a fraction of the peak of the last
    # (longest-wavelength) collapsed image.
    stretch_vmin=0,
    stretch_vmax=float(np.nanmax(collapsed_images[-1])) / 1.5,
    # Semi-transparent layers so that overlapping colors blend. Capped at 1,
    # the largest valid opacity, which 2 / n_subsets would exceed for
    # n_subsets == 1.
    image_opacity=min(1.0, 2 / n_subsets),
    stretch_function='arcsinh',
)

# Per-layer Plot Options settings, keyed by subset label:
img_settings = {
    subset_label: dict(image_color=color, **defaults)
    for subset_label, color in zip(subset_labels, hex_colors)
}


Initialize `Imviz`, load one monochromatic image per color channel, and apply the display settings:

In [None]:
imviz = Imviz()

# One Imviz data entry per spectral subset, labeled like the subset:
for label, image in zip(subset_labels, collapsed_images):
    imviz.load_data(image, data_label=label)

# Align the images through their WCS, using the full (non-affine) transform:
links = imviz.plugins['Links Control']
links.link_type = 'WCS'
links.wcs_use_affine = False

p = imviz.plugins['Plot Options']
p.image_color_mode = 'Monochromatic'

for label, settings in img_settings.items():
    # Select this subset's image layer, then apply each of its options:
    p.layer = f"{label}[DATA]"
    for option, value in settings.items():
        setattr(p, option, value)

    # Loading an NDDataArray also creates a separate "[MASK]" entry for each
    # image. Hide it by removing it from the image viewer (this removes it
    # from the viewer only):
    imviz.app.remove_data_from_viewer('imviz-0', f"{label}[MASK]")

imviz.show()

This step is specific to Io, but you can adapt it to other targets!  
Use JPL Horizons to look up the apparent position of Io, as seen from JWST, throughout the observation, and add a marker at one-minute intervals:

In [None]:
# The ephemeris query below is specific to Io. Use .get so that files
# without a TARGNAME keyword skip this cell instead of raising KeyError, and
# strip whitespace padding before comparing.
target_name = str(data.meta['_primary_header'].get('TARGNAME', ''))
if target_name.strip().lower() == 'io':
    # observing beginning/end times are in the FITS header:
    obs_beg = Time(data.meta["MJD-BEG"], format='mjd', scale='utc')
    obs_end = Time(data.meta["MJD-END"], format='mjd', scale='utc')

    # set up a Horizons query
    io_jwst = Horizons(
        # Jupiter's moon Io:
        id="501",
        # JWST's coordinates (in flight):
        location="500@-170",
        # return ephemeris at 1 min intervals during obs:
        epochs=dict(
            start=obs_beg.utc.iso,
            stop=obs_end.utc.iso,
            step='1m'
        )
    )
    ephemeris = io_jwst.ephemerides(extra_precision=True)

    # Convert the RA/DEC columns to sky coordinates (one per epoch):
    ra, dec = QTable(ephemeris[['RA', 'DEC']]).itercols()
    io_coord = SkyCoord(ra, dec)

    # Draw one filled red marker per epoch on the Imviz image viewer:
    image_viewer = imviz.app.get_viewer('imviz-0')
    coord_table = QTable(dict(coord=io_coord))
    image_viewer.marker = {'color': 'red', 'alpha': 1, 'markersize': 500, 'fill': True}
    image_viewer.add_markers(table=coord_table, use_skycoord=True, marker_name='Io')