In [1]:
from anywidget import AnyWidget
from traitlets import Unicode, Float, Int, Bool, link
import traitlets
from astrocut import FITSCutout
from astropy.coordinates import SkyCoord 
from astropy.wcs import WCS
from astropy.io import fits
from astropy.io.fits import getheader
from astropy.time import Time
import matplotlib.pyplot as plt 
import ipywidgets as widgets
import numpy as np
from IPython.display import display, HTML
from ipywidgets import Output
from ipywidgets import GridspecLayout
from ipywidgets import VBox, HBox
import warnings


import asdf
from astrocut import ASDFCutout


class SnipPyASDF(AnyWidget):
    class Coordinates(AnyWidget):
        component = Unicode("Coordinates").tag(sync=True)
        _esm = "./sharedwidget_V2.js"
        _css = "./sharedwidget_V2.css"
        ra = traitlets.Float().tag(sync=True)
        dec = traitlets.Float().tag(sync=True)
    class Ratio(AnyWidget):
        component = Unicode("Ratio").tag(sync=True)
        _esm = "./sharedwidget_V2.js"
        _css = "./sharedwidget_V2.css"
        cropwidth = traitlets.Float().tag(sync=True)
        cropheight = traitlets.Float().tag(sync=True)
    class Normalization(AnyWidget):
        component = Unicode("Normalization").tag(sync=True)
        _esm = "./sharedwidget_V2.js"
        _css = "./sharedwidget_V2.css"
        min_percent = traitlets.Float(10.0).tag(sync=True)
        max_percent = traitlets.Float(99.0).tag(sync=True)
        invert = traitlets.Bool(False).tag(sync=True)
        invertbut = traitlets.Bool(False).tag(sync=True)
        stretch = traitlets.Unicode("linear").tag(sync=True)
    class Save(AnyWidget):
        component = Unicode("Save").tag(sync=True)
        _esm = "./sharedwidget_V2.js"
        _css = "./sharedwidget_V2.css"
        save_fits = traitlets.Bool(False).tag(sync=True)
        save_png = traitlets.Bool(False).tag(sync=True)
        save_color = traitlets.Bool(False).tag(sync=True)
        save_memory = traitlets.Bool(False).tag(sync=True)
    class ImageCounter(AnyWidget):
        component = Unicode("ImageCounter").tag(sync=True)
        _esm = "./sharedwidget_V2.js"
        _css = "./sharedwidget_V2.css"
        index = traitlets.Int(0).tag(sync=True)
        total = traitlets.Int(1).tag(sync=True)

    def __init__(self, input_files):
        super().__init__()
        self.input_files = input_files
        
        hdulist = fits.open(input_files[0])
        img_hdu = next(h for h in hdulist if h.data is not None and h.data.ndim == 2)

        header = self.ensure_mjd_obs(img_hdu.header.copy())
        data = img_hdu.data
        naxis2, naxis1 = data.shape
        wcs = self.smart_wcs_from_header(header)
        
        ra_min, dec_min = wcs.wcs_pix2world(0, 0, 0)
        ra_max, dec_max = wcs.wcs_pix2world(naxis1, naxis2, 0)
        mid_ra = ra_max - ((ra_max - ra_min)/2)
        mid_dec = dec_max - ((dec_max - dec_min)/2)


        self.TOP = self.Coordinates()
        self.LEFT = self.Ratio()
        self.RIGHT = self.Normalization()
        self.SAVE = self.Save()
        self.BOTTOM = self.ImageCounter()
        self.ShowImage = Output()
        self.BOTTOM.total = len(input_files)

        
        self.TOP.ra = mid_ra
        self.TOP.dec = mid_dec
        self.LEFT.cropwidth = naxis1
        self.LEFT.cropheight = naxis2
        self.cutout = self.saveMemory()


        for name in ['ra',
                     'dec']:
            self.TOP.observe(self.Run_image, names=name)
    
        for name in ['cropwidth',
                     'cropheight']:
            self.LEFT.observe(self.Run_image, names=name)
        for name in ['min_percent',
                     'max_percent',
                     'invertbut',
                     'stretch']:
            self.RIGHT.observe(self.Run_image, names=name)
        for name in ['index']:
            self.BOTTOM.observe(self.Run_image, names=name)
            
        self.SAVE.observe(self.saveFITS, names="save_fits")
        self.SAVE.observe(self.savePNG, names="save_png")
        self.SAVE.observe(self.saveColor, names="save_color") 
        self.SAVE.observe(self.saveMemory, names="save_memory")

        

        self.Run_image()
        AppLayout = HBox([VBox([self.ShowImage,self.BOTTOM]), VBox([self.TOP,self.LEFT, self.RIGHT, self.SAVE])])
        container = HTML("<style>.Background > .widget-box { background-color: #000f14; padding: 25px; }</style>")
        widget = VBox([AppLayout])
        widget.add_class("Background")
        display(container, widget)




    
    def ensure_mjd_obs(self, header):
        """Add MJD-OBS if it is missing but DATE-OBS is present."""
        if 'DATE-OBS' in header and 'MJD-OBS' not in header:
            try:
                header['MJD-OBS'] = Time(header['DATE-OBS'], format='isot').mjd
                
            except Exception:
                pass
        return header
    
    def smart_wcs_from_header(self, header):
        hdr = header.copy()                       
        has_sip = any(k in hdr for k in ['A_ORDER', 'B_ORDER'])
        sip_ok   = '-SIP' in hdr.get('CTYPE1', '') or '-SIP' in hdr.get('CTYPE2', '')
        if has_sip:
            if not sip_ok:
                for key in ('CTYPE1', 'CTYPE2'):
                    if key in hdr and '-SIP' not in hdr[key]:
                        hdr[key] = hdr[key].strip() + '-SIP'
            wcs = WCS(hdr)                          
        else:
            sip_keys = [k for k in hdr
                        if k.startswith(('A_', 'B_', 'AP_', 'BP_')) or
                           k in ('A_ORDER', 'B_ORDER', 'A_DMAX', 'B_DMAX')]
            for k in sip_keys:
                hdr.remove(k, ignore_missing=True, remove_all=True)
            wcs = WCS(hdr)
        return wcs

    def Run_image(self, change=None):
        
        with self.ShowImage:
            try:
                center_coord = SkyCoord(self.TOP.ra, self.TOP.dec, unit="deg")
                cutout_size = [self.LEFT.cropwidth, self.LEFT.cropheight]
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    warnings.filterwarnings("ignore", category=Warning, module="astropy.wcs")
              
                    fits_cutout = FITSCutout(
                        input_files=self.input_files,
                        coordinates=center_coord,
                        cutout_size=cutout_size,
                        single_outfile=True)

                    cutouts = fits_cutout.get_image_cutouts(
                        stretch=self.RIGHT.stretch,
                        invert=self.RIGHT.invertbut,
                        minmax_percent=[self.RIGHT.min_percent, self.RIGHT.max_percent])
        
                    fits_img = cutouts[self.BOTTOM.index]
    
                    self.ShowImage.clear_output(wait=True)
        
                    fig, ax = plt.subplots(figsize=(5, 5), dpi=100, facecolor='#000f14')

                    # Display the image
                    ax.imshow(fits_img, cmap='gray', origin='upper')
                    
                    # Remove ticks and labels
                    ax.set_xticks([])
                    ax.set_yticks([])
                    
                    # Set a uniform 1px white border
                    for spine in ax.spines.values():
                        spine.set_visible(True)
                        spine.set_color('#ffffff')
                        spine.set_linewidth(1)
                    
                    # Remove all padding
                    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
                    
                    plt.show()



                    
        
            except (IndexError, ValueError) as err:
                    self.ShowImage.clear_output(wait=True)
                    # plt.figure(figsize=(3, 3))
                    plt.figure(figsize=(5, 5),edgecolor='#ffffff', linewidth=1)
                    plt.text(0.5, 0.5, "No FITS Found", fontsize=20, ha='center', va='center')
                    plt.axis('off')
                    plt.show() 


    def saveFITS(self, change):
            try:
                center_coord = SkyCoord(self.TOP.ra, self.TOP.dec, unit="deg")
                cutout_size = [self.LEFT.cropwidth, self.LEFT.cropheight]
    
                fits_cutout = FITSCutout(
                                input_files=self.input_files,
                                coordinates=center_coord,
                                cutout_size=cutout_size,
                                single_outfile=False)
    
                fits_cutout.write_as_fits()
            except:
                pass
            finally:
                self.SAVE.save_fits = False

    
    
    
    
    
        
    
    def savePNG(self, change):
        try:
            center_coord = SkyCoord(self.TOP.ra, self.TOP.dec, unit="deg")
            cutout_size = [self.LEFT.cropwidth, self.LEFT.cropheight]

            fits_cutout = FITSCutout(
                            input_files=self.input_files,
                            coordinates=center_coord,
                            cutout_size=cutout_size,
                            single_outfile=True)
            
            fits_cutout.write_as_img(
                        stretch=self.RIGHT.stretch,
                        invert=self.RIGHT.invertbut,
                        minmax_percent=[self.RIGHT.min_percent, self.RIGHT.max_percent],
                        output_format='png')
        except Exception:
            pass
        finally:
            self.SAVE.save_png = False
    





    def saveColor(self, change):
        try:
            center_coord = SkyCoord(self.TOP.ra, self.TOP.dec, unit="deg")
            cutout_size = [self.LEFT.cropwidth, self.LEFT.cropheight]

            fits_cutout = FITSCutout(
                            input_files=self.input_files,
                            coordinates=center_coord,
                            cutout_size=cutout_size,
                            single_outfile=False)

            fits_cutout.write_as_img(
                        stretch=self.RIGHT.stretch,
                        invert=self.RIGHT.invertbut,
                        minmax_percent=[self.RIGHT.min_percent, self.RIGHT.max_percent],
                        colorize=True,
                        output_format='png')
            
        except Exception:
            pass
        finally:
            self.SAVE.save_color = False





    
    def saveMemory(self, change=None):
        try:
                center_coord = SkyCoord(self.TOP.ra, self.TOP.dec, unit="deg")
                cutout_size = [self.LEFT.cropwidth, self.LEFT.cropheight]
    
                fits_cutout = FITSCutout(
                                input_files=self.input_files,
                                coordinates=center_coord,
                                cutout_size=cutout_size,
                                single_outfile=True)
                self.cutout = fits_cutout.fits_cutouts[0]
        except Exception:
            self.cutout = None
        finally:
            self.SAVE.save_memory = False
    
        return self.cutout


snip = SnipPyFits(["https://archive.stsci.edu/pub/hlsp/candels/cosmos/cos-tot/v1.0/hlsp_candels_hst_acs_cos-tot-sect23_f606w_v1.0_drz.fits"
           ])

VBox(children=(HBox(children=(VBox(children=(Output(), ImageCounter(total=3))), VBox(children=(Coordinates(dec…