# Prototype application for the NIRSpec MOS pre-imaging planner

In [None]:
import io
import os
import warnings

from astropy.io import fits
from IPython.display import display, HTML
from ipyvuetify.extra import FileInput
import ipywidgets as ipw

from jdaviz.app import Application
from jdaviz.core.config import get_configuration
from jdaviz.configs.imviz.helper import Imviz

from novt import footprints as fp
from novt import display as nd

# Set True to display the debug output widget (text printed while handling
# uploads) at the bottom of the app.
DEBUG = False

In [None]:
# Set up viewer sizing for the Voila app: inject CSS that pins the min/max
# height of elements with the `jdaviz__content--not-in-notebook` class to 80%
# of the viewport height (presumably the class Jdaviz uses outside a notebook).
display(HTML("<style>.jdaviz__content--not-in-notebook {min-height: 80vh; max-height:80vh}</style>"))

# Button subclass that carries an extra `value` payload
class ButtonWithValue(ipw.Button):
    """An ipywidgets ``Button`` that remembers an arbitrary ``value``.

    The ``value`` keyword argument (default ``''``) is stripped out and
    stored on the instance before all remaining arguments are forwarded to
    ``ipw.Button``. Click handlers can then read ``button.value`` to tell
    which button fired.
    """

    def __init__(self, *args, **kwargs):
        payload = kwargs.pop('value', '')
        self.value = payload
        super().__init__(*args, **kwargs)
        
# make a handler to wrap angles for nudging up and down
def wrap_angle(event):
    """Fold the observed widget's angle into the range [0, 360).

    ``event.owner`` is the widget whose ``value`` changed. Out-of-range
    values are replaced by ``value % 360``; in-range values are left alone
    and no assignment is made.
    """
    widget = event.owner
    angle = widget.value
    out_of_range = angle < 0 or angle >= 360
    if out_of_range:
        widget.value = angle % 360

In [None]:
def _config():
    """Return a locked-down Imviz configuration for the embedded viewer.

    Starts from the stock 'imviz' configuration (modeled on the MAST Jdaviz
    setup). It then records the original configuration name as
    ``viewer_spec`` and renames the configuration to 'novt'. It hides the
    menu bar, toolbar, tray and tab headers, and drops the data-import and
    viewer-creation tools, so users cannot import files or create viewers
    through Jdaviz itself.
    """
    config = get_configuration('imviz')
    settings = config['settings']
    settings['viewer_spec'] = settings.get('configuration', 'default')
    settings['configuration'] = 'novt'
    settings['visible'] = dict(menu_bar=False, toolbar=False, tray=False,
                               tab_headers=False)

    toolbar = config['toolbar']
    unwanted_tools = ('g-data-tools', 'g-viewer-creator',
                      'g-image-viewer-creator')
    for tool_name in unwanted_tools:
        if tool_name in toolbar:
            toolbar.remove(tool_name)
    return config

# start imviz viewer
# Build the Jdaviz application from the restricted _config(), wrap it in an
# Imviz helper, and keep a handle on its default image viewer. All footprint
# and catalog drawing below targets this viewer's figure.
app = Application(_config())
viz = Imviz(app)
viewer = viz.default_viewer

# callback for data loading
def on_load(event):
    """Re-center both instrument configurations on new reference data.

    Registered on the viewer state's ``reference_data`` property. When the
    reference data has coordinates, the WCS reference value (``crval``) is
    copied into the NIRSpec and NIRCam RA/Dec widgets, in that order.
    """
    reference = viewer.state.reference_data
    if reference is None:
        return
    coords = reference.coords
    if coords is None:
        return
    ra, dec = coords.wcs.crval
    for widget, new_value in ((nirspec_center_ra, ra),
                              (nirspec_center_dec, dec),
                              (nircam_center_ra, ra),
                              (nircam_center_dec, dec)):
        widget.value = new_value


viewer.state.add_callback('reference_data', on_load)

In [None]:
# track instrument footprints
# Instrument names; one show/hide FOV button is created per entry, and the
# same names are used as keys in `patches`.
instruments = ['NIRSpec', 'NIRCam Long', 'NIRCam Short']
# Currently displayed footprint patches keyed by instrument name; an entry
# is added when a footprint is shown and deleted when it is hidden.
patches = {}


# track the source catalog
# Displayed catalog marks, stored under 'primary' and 'filler' by on_cat_clicked.
catalog_markers = {}


In [None]:
# NIRSpec configuration widgets: pointing center (RA, Dec) and position angle
nirspec_label = ipw.Label('NIRSpec Configuration:', style={'font_weight': 'bold'})
nirspec_center_ra = ipw.BoundedFloatText(
    description='RA (deg)', min=0, max=360,
    step=5/3600, continuous_update=False,
    style={'description_width': 'initial'})
# Declination covers -90..+90 deg. A lower bound of 0 would make the bounded
# widget silently clamp southern-hemisphere pointings (including values
# copied from an image WCS on load) to the equator.
nirspec_center_dec = ipw.BoundedFloatText(
    description='Dec (deg)', min=-90, max=90,
    step=5/3600, continuous_update=False,
    style={'description_width': 'initial'})
# PA is deliberately unbounded so the step arrows can nudge past 0/360;
# wrap_angle folds the result back into [0, 360).
nirspec_pa = ipw.FloatText(
    description='PA (deg)',
    step=5, continuous_update=False,
    style={'description_width': 'initial'})


def on_nrs_center_changed(event):
    """Redraw the NIRSpec footprint at the current RA, Dec and PA.

    Only a footprint that is already displayed (present in ``patches``) is
    updated, and only when the reference image has coordinates. The
    existing patches are passed as ``update_patches`` so they are updated
    rather than recreated.
    """
    instrument = 'NIRSpec'
    if viewer.state.reference_data is not None:
        wcs = viewer.state.reference_data.coords
        if instrument in patches and wcs is not None:
            patches[instrument] = nd.bqplot_footprint(
                viewer.figure, instrument,
                nirspec_center_ra.value, nirspec_center_dec.value,
                nirspec_pa.value, wcs,
                fill='inside', alpha=0.6, update_patches=patches[instrument])


# Observe only 'value': observe() without names fires for every trait change
# on the widget, causing spurious redraws.
nirspec_center_ra.observe(on_nrs_center_changed, names='value')
nirspec_center_dec.observe(on_nrs_center_changed, names='value')
nirspec_pa.observe(wrap_angle, names='value')
nirspec_pa.observe(on_nrs_center_changed, names='value')

In [None]:
# NIRCam configuration widgets: pointing center (RA, Dec) and position angle
nircam_label = ipw.Label('NIRCam Configuration:', style={'font_weight': 'bold'})
nircam_center_ra = ipw.BoundedFloatText(
    description='RA (deg)', min=0, max=360,
    step=5/3600, continuous_update=False,
    style={'description_width': 'initial'})
# Declination covers -90..+90 deg. A lower bound of 0 would make the bounded
# widget silently clamp southern-hemisphere pointings (including values
# copied from an image WCS on load) to the equator.
nircam_center_dec = ipw.BoundedFloatText(
    description='Dec (deg)', min=-90, max=90,
    step=5/3600, continuous_update=False,
    style={'description_width': 'initial'})
# PA is deliberately unbounded so the step arrows can nudge past 0/360;
# wrap_angle folds the result back into [0, 360).
nircam_pa = ipw.FloatText(
    description='PA (deg)',
    step=5, continuous_update=False,
    style={'description_width': 'initial'})


def on_nrc_center_changed(event):
    """Redraw the NIRCam short- and long-wave footprints at the current pointing.

    Each NIRCam footprint that is already displayed (present in ``patches``)
    is updated in place from the NIRCam RA, Dec and PA widgets. This only
    happens when the reference image has coordinates.
    """
    # Named distinctly so it does not shadow the module-level `instruments`.
    nircam_instruments = ['NIRCam Short', 'NIRCam Long']
    if viewer.state.reference_data is not None:
        wcs = viewer.state.reference_data.coords
        for instrument in nircam_instruments:
            if instrument in patches and wcs is not None:
                patches[instrument] = nd.bqplot_footprint(
                    viewer.figure, instrument,
                    nircam_center_ra.value, nircam_center_dec.value,
                    nircam_pa.value, wcs,
                    fill='inside', alpha=0.6, update_patches=patches[instrument])


# Observe only 'value': observe() without names fires for every trait change
# on the widget, causing spurious redraws.
nircam_center_ra.observe(on_nrc_center_changed, names='value')
nircam_center_dec.observe(on_nrc_center_changed, names='value')
nircam_pa.observe(wrap_angle, names='value')
nircam_pa.observe(on_nrc_center_changed, names='value')

In [None]:
# todo: add/remove patches changes view canvas for some reason

def on_footprint_clicked(b):
    """Toggle an instrument field-of-view footprint on the image viewer.

    ``b`` is a ``ButtonWithValue`` whose ``value`` is the instrument name.

    On "Show", the footprint is drawn at the instrument's current RA, Dec
    and PA, using the NIRSpec widgets for NIRSpec and the NIRCam widgets
    otherwise. The reference image WCS places it. The result is stored in
    ``patches`` and the button is relabeled to "Hide". Nothing happens if no
    reference image or coordinates are available.

    On "Hide", the stored footprint, if any, is removed from the figure.
    """
    if b.description.startswith('Show'):
        reference = viewer.state.reference_data
        if reference is None:
            return
        wcs = reference.coords
        if wcs is None:
            return

        if 'NIRS' in b.value:
            ra = nirspec_center_ra.value
            dec = nirspec_center_dec.value
            pa = nirspec_pa.value
        else:
            ra = nircam_center_ra.value
            dec = nircam_center_dec.value
            pa = nircam_pa.value
        patches[b.value] = nd.bqplot_footprint(
            viewer.figure, b.value, ra, dec, pa, wcs,
            fill='inside', alpha=0.6)
        # Relabel only after the footprint exists. Otherwise a failed draw
        # would leave a "Hide" button with no entry in `patches`. This matches
        # the catalog button, which also relabels only on success.
        b.description = f'Hide {b.value} FOV'
    else:
        b.description = f'Show {b.value} FOV'
        # Tolerate a missing entry instead of raising KeyError.
        if b.value in patches:
            nd.remove_bqplot_patches(viewer.figure, patches[b.value])
            del patches[b.value]


# One show/hide toggle per instrument; the instrument name rides on `value`.
fov_buttons = []
for name in instruments:
    button = ButtonWithValue(description=f'Show {name} FOV', value=name,
                             layout=ipw.Layout(width='auto'))
    button.on_click(on_footprint_clicked)
    fov_buttons.append(button)

In [None]:
# Image and catalog upload widgets
image_label = ipw.Label('Image file (.fits):', style={'font_weight': 'bold'})
image_file_upload = FileInput(accept='.fits', multiple=False)

catalog_label = ipw.Label('Catalog file (.radec):', style={'font_weight': 'bold'})
catalog_file_upload = FileInput(accept='.radec', multiple=False)

# Output area that captures text printed by callbacks (shown when DEBUG is set)
debug_view = ipw.Output()


# todo: clear viewer if file removed or replaced
# todo: consider allowing multiple, eg for mosaicked field
@debug_view.capture(clear_output=True)
def on_upload_image(event):
    """Load the first uploaded FITS file into Imviz.

    The file info is printed into ``debug_view`` (cleared on each call). The
    file is opened with ``astropy.io.fits`` and passed to ``viz.load_data``
    under its upload name, with all warnings suppressed during loading.
    """
    files = event.owner.get_files()
    if not files:
        return
    first = files[0]
    print(first)
    hdul = fits.open(first['file_obj'])
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        viz.load_data(hdul, data_label=first['name'])


# watch for uploaded files
image_file_upload.observe(on_upload_image, names='file_info')


# make catalog show/hide button
def on_cat_clicked(b):
    """Toggle display of the uploaded source catalog over the image.

    On "Show", the first uploaded catalog file is drawn with
    ``nd.bqplot_catalog`` using the reference image's coordinates. The
    returned primary and filler marks are stored in ``catalog_markers``, and
    the button is relabeled "Hide Catalog" only on success. Nothing happens
    if there is no reference image, no uploaded catalog, or no WCS.

    On "Hide", any stored marks are removed from the figure.
    """
    if b.description.startswith('Show'):
        reference = viewer.state.reference_data
        if reference is None:
            return
        uploaded_files = catalog_file_upload.get_files()
        if not uploaded_files:
            return
        wcs = reference.coords
        if wcs is None:
            # No coordinates to place catalog positions with.
            return
        catalog = uploaded_files[0]['file_obj']
        try:
            primary, filler = nd.bqplot_catalog(
                viewer.figure, catalog, wcs, alpha=0.6)
        except Exception as err:
            # Best effort: a bad catalog must not break the app, but report
            # the failure instead of discarding it silently. The message lands
            # in the debug output (visible when DEBUG is set).
            # todo: surface errors/status in the main UI
            with debug_view:
                print(f'Could not display catalog: {err!r}')
            return
        catalog_markers['primary'] = primary
        catalog_markers['filler'] = filler
        b.description = 'Hide Catalog'
    else:
        b.description = 'Show Catalog'
        for key in ('primary', 'filler'):
            if key in catalog_markers:
                nd.remove_bqplot_patches(
                    viewer.figure, [catalog_markers[key]])
                del catalog_markers[key]


cat_button = ipw.Button(description='Show Catalog', layout=ipw.Layout(width='auto'))
cat_button.on_click(on_cat_clicked)

In [None]:
# set layout
# Each control row is a left-aligned horizontal flex box. The rows and the
# Imviz app are stacked in a vertical column spanning 95% of the viewport width.
button_layout = ipw.Layout(display='flex', flex_flow='row', justify_content='flex-start', padding='5px')
box_layout = ipw.Layout(display='flex', flex_flow='column', align_items='stretch', width='95vw')

# Rows, top to bottom: image upload, catalog upload, NIRSpec pointing,
# NIRCam pointing, and the FOV/catalog show-hide buttons; the Imviz app last.
b1 = ipw.Box(children=[image_label, image_file_upload], layout=button_layout)
b2 = ipw.Box(children=[catalog_label, catalog_file_upload], layout=button_layout)
b3 = ipw.Box(children=[nirspec_label, nirspec_center_ra, nirspec_center_dec, nirspec_pa], layout=button_layout)
b4 = ipw.Box(children=[nircam_label, nircam_center_ra, nircam_center_dec, nircam_pa], layout=button_layout)
b5 = ipw.Box(children=fov_buttons + [cat_button], layout=button_layout)
box = ipw.Box(children=[b1, b2, b3, b4, b5, viz.app], layout=box_layout)

In [None]:
# When DEBUG is enabled, show the output area that captures callback prints.
if DEBUG:
    display(debug_view)

In [None]:
# display widgets: the control rows stacked above the Imviz viewer
display(box)