In [None]:
import io
import os
import inspect
import logging
import threading
import random

import numpy as np

from astropy.io import fits
from astropy import units as u
from astropy import wcs
from astropy import visualization
from astropy.visualization import wcsaxes
from astropy.nddata import CCDData
from astropy.utils.data import download_file

import ipywidgets

from astroquery import mast

from IPython import display

import flask

%matplotlib inline
from matplotlib import pyplot as plt

In [None]:
# Remote FITS image shown by default; it is downloaded in the next cell.
default_file_name = 'https://astropy.stsci.edu/data/photometry/spitzer_example_image.fits'

# Loaded CCDData images; the tile server always renders ``ccdis[0]``.
ccdis = list()

In [None]:
# cache=True keeps the download in the astropy cache, so re-running this cell
# (or Restart & Run All) reuses the local copy instead of fetching it again.
filename = download_file(default_file_name, cache=True)
filename

In [None]:
# Put the image in slot 0, the slot that re_stretch and load_file use.
# Slice assignment replaces ccdis[0] when it exists and appends otherwise, so
# re-running this cell does not keep growing the list.
ccdis[:1] = [CCDData.read(filename)]

In [None]:

"""
Writing PNGs
"""

#-----------------------------------------------------------------------------
# Copyright (c) 2013, yt Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------

import struct
import zlib

try:
    # Private libpng wrapper; it only exists in matplotlib < 3.3.
    import matplotlib._png as _png
except ImportError:
    _png = None

from io import BytesIO as StringIO

# PNG colour type for each supported number of 8-bit channels.
_PNG_COLOR_TYPES = {1: 0,   # grayscale
                    2: 4,   # grayscale + alpha
                    3: 2,   # RGB
                    4: 6}   # RGBA


def _png_chunk(tag, data):
    """Return one PNG chunk: length, 4-byte tag, payload, CRC-32 of tag+payload."""
    return (struct.pack('>I', len(data)) + tag + data
            + struct.pack('>I', zlib.crc32(tag + data) & 0xFFFFFFFF))


def _encode_png(buffer, dpi):
    """Encode an 8-bit image array as PNG bytes using only zlib and struct.

    ``buffer`` is (height, width) or (height, width, channels) with 1-4
    channels; it is cast to uint8.  A pHYs chunk records ``dpi`` when it is
    positive.

    Raises ValueError for any other array shape.
    """
    pixels = np.asarray(buffer, dtype=np.uint8)
    if pixels.ndim == 2:
        pixels = pixels[:, :, np.newaxis]
    if pixels.ndim != 3 or pixels.shape[2] not in _PNG_COLOR_TYPES:
        raise ValueError('expected a (height, width[, 1-4 channels]) array, '
                         'got shape %r' % (pixels.shape,))
    height, width, channels = pixels.shape
    # Every scanline is prefixed with filter type 0 ("None").
    rows = pixels.reshape(height, width * channels)
    raw = np.hstack([np.zeros((height, 1), dtype=np.uint8), rows]).tobytes()
    header = struct.pack('>IIBBBBB', width, height, 8,
                         _PNG_COLOR_TYPES[channels], 0, 0, 0)
    parts = [b'\x89PNG\r\n\x1a\n', _png_chunk(b'IHDR', header)]
    if dpi > 0:
        pixels_per_metre = int(dpi / 0.0254)
        parts.append(_png_chunk(b'pHYs', struct.pack(
            '>IIB', pixels_per_metre, pixels_per_metre, 1)))
    parts.append(_png_chunk(b'IDAT', zlib.compress(raw)))
    parts.append(_png_chunk(b'IEND', b''))
    return b''.join(parts)


def call_png_write_png(buffer, width, height, filename, dpi):
    """Write ``buffer`` as a PNG to ``filename``.

    ``filename`` may be a path or a binary file-like object with ``write``.
    ``width`` and ``height`` are accepted for backward compatibility only;
    the image size always comes from ``buffer.shape``.

    Uses matplotlib's private ``_png`` module when it can be imported, and
    otherwise falls back to the pure zlib/struct encoder above, since
    ``matplotlib._png`` is absent from current matplotlib releases.
    """
    if _png is not None:
        _png.write_png(buffer, filename, dpi)
        return
    data = _encode_png(buffer, dpi)
    if hasattr(filename, 'write'):
        filename.write(data)
    else:
        with open(filename, 'wb') as fh:
            fh.write(data)


def write_png(buffer, filename, dpi=100):
    """Write image array ``buffer`` (height, width[, channels]) to ``filename`` as PNG."""
    height, width = buffer.shape[:2]
    call_png_write_png(buffer, width, height, filename, dpi)


def write_png_to_string(buffer, dpi=100, gray=0):
    """Return the PNG encoding of image array ``buffer`` as ``bytes``.

    ``gray`` is unused; it is kept so existing call signatures keep working.
    """
    height, width = buffer.shape[:2]
    with StringIO() as fileobj:
        call_png_write_png(buffer, width, height, fileobj, dpi)
        return fileobj.getvalue()

In [None]:
# SERVER:

# Display-ready copy of ccdis[0] (written by re_stretch) and the stretch used.
visdat = None
_last_stretch = None


def re_stretch(stretch):
    """Rebuild the module-level ``visdat`` from ``ccdis[0]`` using ``stretch``.

    ``stretch`` is any callable mapping the raw pixel array to display values
    (below: astropy visualization stretch + interval combinations).  The
    result is stored with its row order reversed, and ``stretch`` is recorded
    in ``_last_stretch`` only after it has been applied successfully.
    """
    global visdat, _last_stretch
    stretched = stretch(ccdis[0].data)
    # Reverse the rows; presumably so row 0 is the top of the picture,
    # matching the top-down tile rows the map client requests.
    visdat = np.flip(stretched, axis=0)
    _last_stretch = stretch


re_stretch(visualization.LogStretch() + visualization.PercentileInterval(99))

def ccd_to_pngstr_app(dat):
    """Encode a 2-D array of display values as an 8-bit grayscale PNG.

    ``dat`` is expected to hold values in [0, 1] (0 -> black, 1 -> white).
    Values outside that range are clipped.  NaN is written as black; NaN
    appears, for example, in the padding that get_subfits adds to partial
    edge tiles.  Without this, ``astype('uint8')`` would wrap out-of-range
    values around and give NaN an undefined value.

    Returns the PNG file contents produced by ``write_png_to_string``.
    """
    levels = np.nan_to_num(np.clip(dat, 0.0, 1.0), nan=0.0) * 255
    # Trailing length-1 axis: a (height, width, 1) single-channel buffer.
    return write_png_to_string(levels.astype('uint8')[:, :, np.newaxis])

app = flask.Flask(__name__)

# Route root-logger output into an in-memory buffer (read it back with
# ``logstream.getvalue()``); presumably to keep server log messages out of
# the notebook output.
logstream = io.StringIO()
logging.basicConfig(stream=logstream)


@app.route('/')
def hello_world():
    """Minimal page for checking that the tile server is up."""
    greeting = 'Hello, World!'
    return greeting


def _tile_array(data, z, x, y, tile_size=256):
    """Cut one ``tile_size`` x ``tile_size`` tile out of a 2-D image array.

    Parameters
    ----------
    data : 2-D numpy array
        The image to tile; row 0 is treated as the top edge.
    z : int
        Internal zoom level.  ``z == 0`` maps one image pixel to one tile
        pixel, ``z < 0`` keeps every ``2**-z``-th pixel, and ``z > 0``
        magnifies each image pixel to a ``2**z`` x ``2**z`` block.
    x, y : int
        Tile column and row index.
    tile_size : int, optional
        Side length of the returned tile in pixels (a power of two, so every
        magnification divides it evenly).

    Returns
    -------
    numpy.ndarray or None
        A ``(tile_size, tile_size)`` array, NaN where the tile extends past
        the right/bottom edge of ``data``; ``None`` if the tile lies entirely
        outside ``data``.
    """
    if z < 0:
        step = 2 ** -z
        magnify = 1
        span = tile_size * step      # image pixels spanned by one tile side
        side = tile_size             # side of the sliced array for a full tile
    else:
        step = 1
        magnify = 2 ** z
        span = side = tile_size // magnify

    tile = data[y * span:(y + 1) * span:step, x * span:(x + 1) * span:step]
    if 0 in tile.shape:
        return None

    if tile.shape != (side, side):
        # Partial tile at the right/bottom edge: pad with NaN *before*
        # magnifying, so edge tiles are drawn at the same scale as the rest.
        padded = np.full((side, side), np.nan,
                         dtype=np.promote_types(tile.dtype, np.float32))
        padded[:tile.shape[0], :tile.shape[1]] = tile
        tile = padded

    if magnify > 1:
        # Upsample by 2**z so magnified tiles are full tile_size tiles.
        tile = tile.repeat(magnify, axis=0).repeat(magnify, axis=1)
    return tile


@app.route('/fits<int:cachebuster>/<string:z>/<int:x>/<int:y>.png')
def get_subfits(z, x, y, cachebuster):
    """Serve one 256x256 grayscale PNG tile of the module-level ``visdat``.

    ``cachebuster`` is part of the URL only so that clients can be forced to
    re-fetch tiles after ``visdat`` changes; its value is not used here.
    Responds 404 for a non-integer zoom or a tile entirely outside the image.
    """
    tile_size = 256
    try:
        leaflet_zoom = int(z)
    except ValueError:
        flask.abort(404)

    # z0 is the (fractional) internal zoom at which the full image height fits
    # in one tile.  At Leaflet zoom 1 the internal zoom is ceil(z0), and each
    # further Leaflet level doubles the magnification.
    z0 = np.log2(tile_size / visdat.shape[0])
    internal_zoom = int(np.ceil(leaflet_zoom - 1 + z0))

    tile = _tile_array(visdat, internal_zoom, x, y, tile_size)
    if tile is None:
        flask.abort(404)
    return flask.Response(ccd_to_pngstr_app(tile), mimetype='image/png')

#app.run(debug=True, use_reloader=False, port=5013)

In [None]:
# Serve tiles from a background thread so the notebook stays usable.
# daemon=True: the server thread does not keep the kernel process alive when
# the kernel shuts down or restarts.
# Re-running this cell while a server thread is alive is a no-op; a second
# server would presumably fail because port 5013 is already in use.
_previous_server = globals().get('th')
if not (isinstance(_previous_server, threading.Thread)
        and _previous_server.is_alive()):
    th = threading.Thread(
        target=app.run,
        kwargs=dict(debug=False, use_reloader=False, port=5013),
        name='fits-tile-server',
        daemon=True,
    )
    th.start()

In [None]:
import ipyleaflet

In [None]:
# Tile URL served by get_subfits; 'NUM' is swapped for a cache-busting
# integer (see refresh) and {z}/{x}/{y} are filled in by the map client.
url_templ = 'http://127.0.0.1:5013/fitsNUM/{z}/{x}/{y}.png'

m = ipyleaflet.Map(
    center=(60, 0),
    zoom=2,
    min_zoom=1,
    layers=[],
    scroll_wheel_zoom=True,
)

# Random starting cache-buster; kept on the layer so refresh can bump it.
cachebuster_int = random.randint(0, 1000000)
local_fits_layer = ipyleaflet.basemap_to_tiles(
    dict(url=url_templ.replace('NUM', str(cachebuster_int)),
         attribution='fitsfile')
)
local_fits_layer.cachebuster_int = cachebuster_int
m.add_layer(local_fits_layer)

m

# Interactivity: 

In [None]:
def refresh():
    """Make the map re-download every tile.

    Bumps the layer's cache-buster number and substitutes it into the tile
    URL, so the new tile URLs differ from any the client already fetched.
    """
    layer = local_fits_layer
    layer.cachebuster_int = layer.cachebuster_int + 1
    layer.url = url_templ.replace('NUM', str(layer.cachebuster_int))


def rere_stretch(stretch):
    """Apply ``stretch`` through re_stretch, then refresh the map tiles."""
    re_stretch(stretch)
    refresh()


def load_file(fn, **kwargs):
    """Replace the displayed image with ``CCDData.read(fn, **kwargs)``.

    The most recently used stretch (``_last_stretch``) is re-applied to the
    new image and the map is refreshed.
    """
    new_image = CCDData.read(fn, **kwargs)
    ccdis[0] = new_image
    rere_stretch(_last_stretch)

In [None]:
# Re-render with a log stretch over the central 95% of the pixel values.
rere_stretch(
    visualization.LogStretch()
    + visualization.PercentileInterval(95)
)

In [None]:
# The cells below use a specific HST exposure (jclj01tvq_flc.fits) from MAST.
# Download it into the working directory unless it is already there.  It is
# fetched under a temporary name and renamed only once complete, so an
# interrupted download is not mistaken for a finished one on the next run.
import os
from urllib.request import urlretrieve

if not os.path.exists('jclj01tvq_flc.fits'):
    urlretrieve('https://mast.stsci.edu/api/v0.1/Download/file/?uri=mast:HST/product/jclj01tvq_flc.fits',
                'jclj01tvq_flc.fits.part')
    os.replace('jclj01tvq_flc.fits.part', 'jclj01tvq_flc.fits')

In [None]:
# HDU 4 is presumably the second SCI extension of this HST *_flc file
# (TODO confirm); the pixel unit is set explicitly to counts.
load_file('jclj01tvq_flc.fits', hdu=4, unit=u.count)

In [None]:
# Log stretch between the 50th and 99.9th percentiles of the pixel values;
# everything at or below the median maps to black.
rere_stretch(
    visualization.LogStretch()
    + visualization.AsymmetricPercentileInterval(50, 99.9)
)