In [1]:
## Enables anywidget hot module replacement: edits to the JS/CSS files are picked up without re-running the cell (a development aid only; not needed in the final code)
# %env ANYWIDGET_HMR=1


#For Widget
from anywidget import AnyWidget

#For Initial Conditions and communication between js and python
from traitlets import Unicode, Float, Int, Bool, link
import traitlets


#For Cutting Fits
from astrocut import FITSCutout

#To build usable sky coordinates and WCS objects
from astropy.coordinates import SkyCoord 
from astropy.wcs import WCS
#For reading the Fits
from astropy.io import fits
from astropy.io.fits import getheader
from astropy.time import Time

#For Displaying FITS
import matplotlib.pyplot as plt 
%matplotlib inline

import ipywidgets as widgets
import numpy as np
from IPython.display import display
from ipywidgets import Output
from ipywidgets import GridspecLayout
from ipywidgets import VBox, HBox
import warnings

def _first_2d_image(path):
    """Return ``(header, shape)`` for the first 2-D image HDU in a FITS file.

    The file is opened only long enough to copy the header and read the data
    shape, so no file handle is left open afterwards.

    Parameters
    ----------
    path : str
        Local path or URL accepted by ``astropy.io.fits.open``.

    Returns
    -------
    header : astropy.io.fits.Header
        A copy of the image HDU's header, safe to modify.
    shape : tuple of int
        ``(naxis2, naxis1)``, i.e. ``(rows, columns)`` of the image array.

    Raises
    ------
    ValueError
        If no HDU in the file holds 2-D data.
    """
    with fits.open(path) as hdulist:
        img_hdu = next(
            (h for h in hdulist if h.data is not None and h.data.ndim == 2),
            None)
        if img_hdu is None:
            raise ValueError(f"No 2-D image HDU found in {path!r}")
        # Read everything we need before the context manager closes the file.
        return img_hdu.header.copy(), img_hdu.data.shape


def _ensure_mjd_obs(header):
    """Add MJD-OBS to ``header`` (in place) when DATE-OBS is present but MJD-OBS is not.

    Done before any WCS is built so WCS construction already sees MJD-OBS.
    NOTE(review): intent is to avoid astropy's header-fix warnings about a
    missing MJD-OBS; confirm for the astropy version in use.

    Returns the same ``header`` object for convenient chaining.
    """
    if 'DATE-OBS' in header and 'MJD-OBS' not in header:
        try:
            header['MJD-OBS'] = Time(header['DATE-OBS'], format='isot').mjd
        except Exception:
            pass  # best effort: a malformed DATE-OBS just leaves MJD-OBS unset
    return header


def _smart_wcs_from_header(header):
    """Build a ``WCS`` from a copy of ``header`` with consistent SIP handling.

    * If SIP polynomial orders (A_ORDER / B_ORDER) are present but neither
      CTYPE1 nor CTYPE2 carries the ``-SIP`` suffix, the suffix is appended so
      the distortion terms are kept.
    * If no SIP orders are present, any stray ``A_*``, ``B_*``, ``AP_*`` and
      ``BP_*`` keywords are removed so the WCS carries no SIP distortion.

    The caller's header is never modified.
    """
    hdr = header.copy()
    has_sip = any(k in hdr for k in ('A_ORDER', 'B_ORDER'))

    if has_sip:
        sip_flagged = any('-SIP' in hdr.get(k, '') for k in ('CTYPE1', 'CTYPE2'))
        if not sip_flagged:
            for key in ('CTYPE1', 'CTYPE2'):
                if key in hdr:
                    hdr[key] = hdr[key].strip() + '-SIP'
    else:
        # A_ORDER/B_ORDER/A_DMAX/B_DMAX are covered by the 'A_'/'B_' prefixes.
        stray = [k for k in hdr if k.startswith(('A_', 'B_', 'AP_', 'BP_'))]
        for key in stray:
            hdr.remove(key, ignore_missing=True, remove_all=True)

    return WCS(hdr)


def SnipPyFits(input_files):
    """Display an interactive cutout viewer for one or more FITS images.

    The first 2-D image HDU of ``input_files[0]`` sets the initial view. The crop
    is centred on that image's central pixel and sized to the full image.
    Four anywidget control panels (front end in ``./sharedwidget_V1.js``) drive
    the view:

    * TOP sets the centre (RA/Dec, degrees) and the crop size.
    * LEFT shares the crop size and can save the cutout as FITS or PNG.
    * RIGHT sets the display stretch, min/max percentiles and inversion.
    * BOTTOM picks which file's cutout is shown.

    The layout is displayed directly; the function returns ``None``.

    Parameters
    ----------
    input_files : sequence of str
        Local paths or URLs of FITS files, passed to ``fits.open`` (first file
        only) and to ``astrocut.FITSCutout`` (all files).

    Raises
    ------
    ValueError
        If ``input_files`` is empty or its first file has no 2-D image HDU.
    """
    if not input_files:
        raise ValueError("input_files must contain at least one FITS path or URL")

    # ---- geometry of the reference image ----------------------------------
    header, (naxis2, naxis1) = _first_2d_image(input_files[0])
    wcs = _smart_wcs_from_header(_ensure_mjd_obs(header))

    # Sky position of the central pixel. Pixel coords are 0-based (origin=0),
    # so the centre is at (n - 1) / 2. Mapping the centre pixel directly, instead
    # of averaging corner RA/Dec, stays correct across the RA = 0/360 deg wrap.
    center_ra, center_dec = wcs.wcs_pix2world((naxis1 - 1) / 2, (naxis2 - 1) / 2, 0)
    mid_ra, mid_dec = float(center_ra), float(center_dec)  # 0-d arrays -> float

    # ---- widget definitions ---------------------------------------------------
    # Every panel loads the same JS/CSS. `component` is presumably read by
    # sharedwidget_V1.js to pick which panel to render (TODO confirm).
    class TopControlPanel(AnyWidget):
        _esm = "./sharedwidget_V1.js"
        _css = "./sharedwidget_V1.css"
        component = Unicode("TopControlPanel").tag(sync=True)
        ra = traitlets.Float(mid_ra).tag(sync=True)
        dec = traitlets.Float(mid_dec).tag(sync=True)
        cropwidth = traitlets.Float(naxis1).tag(sync=True)
        cropheight = traitlets.Float(naxis2).tag(sync=True)

    class LeftControlPanel(AnyWidget):
        _esm = "./sharedwidget_V1.js"
        _css = "./sharedwidget_V1.css"
        component = Unicode("LeftControlPanel").tag(sync=True)
        cropwidth = traitlets.Float(naxis1).tag(sync=True)
        cropheight = traitlets.Float(naxis2).tag(sync=True)
        save_fits = traitlets.Bool(False).tag(sync=True)  # True = "save now" request
        save_png = traitlets.Bool(False).tag(sync=True)   # True = "save now" request

    class RightControlPanel(AnyWidget):
        _esm = "./sharedwidget_V1.js"
        _css = "./sharedwidget_V1.css"
        component = Unicode("RightControlPanel").tag(sync=True)
        min_percent = traitlets.Float(10.0).tag(sync=True)
        max_percent = traitlets.Float(99.0).tag(sync=True)
        invert = traitlets.Bool(False).tag(sync=True)
        stretch = traitlets.Unicode("linear").tag(sync=True)

    class BottomControlPanel(AnyWidget):
        component = Unicode("BottomControlPanel").tag(sync=True)
        _esm = "./sharedwidget_V1.js"
        _css = "./sharedwidget_V1.css"
        index = traitlets.Int(0).tag(sync=True)  # which input file's cutout to show

    TOP = TopControlPanel()
    LEFT = LeftControlPanel()
    RIGHT = RightControlPanel()
    BOTTOM = BottomControlPanel()
    ShowImage = Output()
    # TOP and LEFT both expose the crop size; keep them in sync.
    link((TOP, 'cropwidth'), (LEFT, 'cropwidth'))
    link((TOP, 'cropheight'), (LEFT, 'cropheight'))

    def _make_cutout(single_outfile):
        """Build a FITSCutout for the current TOP centre and crop size."""
        return FITSCutout(
            input_files=input_files,
            coordinates=SkyCoord(TOP.ra, TOP.dec, unit="deg"),
            cutout_size=[TOP.cropwidth, TOP.cropheight],
            single_outfile=single_outfile)

    def Run_image(change=None):
        """Re-cut and redraw the selected image; show a placeholder on failure."""
        with ShowImage:
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    cutouts = _make_cutout(single_outfile=True).get_image_cutouts(
                        stretch=RIGHT.stretch,
                        invert=RIGHT.invert,
                        minmax_percent=[RIGHT.min_percent, RIGHT.max_percent])
                    fits_img = cutouts[BOTTOM.index]

                    ShowImage.clear_output(wait=True)
                    plt.figure(edgecolor='#01617e', linewidth=15)
                    plt.imshow(fits_img, cmap='gray', origin='lower')
                    plt.axis('off')
                    plt.show()
            except (IndexError, ValueError):
                ShowImage.clear_output(wait=True)
                plt.figure(figsize=(3, 3), edgecolor='#01617e', linewidth=10)
                plt.text(0.5, 0.5, "No FITS Found", fontsize=20, ha='center', va='center')
                plt.axis('off')
                plt.show()

    # cropwidth/cropheight are linked TOP<->LEFT, so observing TOP alone sees every
    # size change exactly once. Observing LEFT too rendered each change twice.
    TOP.observe(Run_image, names=['ra', 'dec', 'cropwidth', 'cropheight'])
    RIGHT.observe(Run_image, names=['min_percent', 'max_percent', 'invert', 'stretch'])
    BOTTOM.observe(Run_image, names='index')

    def saveFITS(change):
        """Write the current cutout(s) as FITS when LEFT.save_fits becomes True."""
        # The reset to False in `finally` re-triggers this observer; without this
        # guard every save would run (and write files) twice.
        if not change['new']:
            return
        try:
            _make_cutout(single_outfile=False).write_as_fits()
        except Exception as err:
            with ShowImage:
                print(f"Could not save FITS cutout: {err}")
        finally:
            LEFT.save_fits = False

    LEFT.observe(saveFITS, names="save_fits")

    def savePNG(change):
        """Write the current cutout(s) as PNG when LEFT.save_png becomes True."""
        if not change['new']:
            return  # ignore the reset to False below (see saveFITS)
        try:
            _make_cutout(single_outfile=True).write_as_img(
                stretch=RIGHT.stretch,
                invert=RIGHT.invert,
                minmax_percent=[RIGHT.min_percent, RIGHT.max_percent],
                output_format='png')
        except Exception as err:
            with ShowImage:
                print(f"Could not save PNG cutout: {err}")
        finally:
            LEFT.save_png = False

    LEFT.observe(savePNG, names="save_png")

    Run_image()  # initial render

    app_layout = HBox([LEFT, VBox([ShowImage, BOTTOM]), VBox([TOP, RIGHT])])
    display(app_layout)



# SnipPyFits(["https://mast.stsci.edu/search/hst/api/v0.1/retrieve_product?product_name=J9FB06020%2Fj9fb06020_drz.fits",
#             "https://mast.stsci.edu/search/hst/api/v0.1/retrieve_product?product_name=J9FB06030%2Fj9fb06030_drz.fits",
#              "https://mast.stsci.edu/search/hst/api/v0.1/retrieve_product?product_name=J9FB08010%2Fj9fb08010_drz.fits"
#            ])
# SnipPyFits(['./cutout_148.9710279_69.6801016_294.0-x-285.0_astrocut.fits'])

In [2]:
# %env ANYWIDGET_HMR=1
# SnipPyFits(["https://mast.stsci.edu/search/hst/api/v0.1/retrieve_product?product_name=J9FB06020%2Fj9fb06020_drz.fits",
#             "https://mast.stsci.edu/search/hst/api/v0.1/retrieve_product?product_name=J9FB06030%2Fj9fb06030_drz.fits",
#             "https://mast.stsci.edu/search/hst/api/v0.1/retrieve_product?product_name=J9FB08010%2Fj9fb08010_drz.fits"
#            ])

# CANDELS COSMOS "cos-tot-sect23" ACS drizzled images in F606W and F814W
# (per the file names); the bottom panel switches between the two cutouts.
candels_cosmos_sect23 = [
    "https://archive.stsci.edu/pub/hlsp/candels/cosmos/cos-tot/v1.0/hlsp_candels_hst_acs_cos-tot-sect23_f606w_v1.0_drz.fits",
    "https://archive.stsci.edu/pub/hlsp/candels/cosmos/cos-tot/v1.0/hlsp_candels_hst_acs_cos-tot-sect23_f814w_v1.0_drz.fits",
]
SnipPyFits(candels_cosmos_sect23)

HBox(children=(LeftControlPanel(), VBox(children=(Output(), BottomControlPanel())), VBox(children=(TopControlP…