In [None]:
import io
import inspect
import logging
import threading

import numpy as np

from astropy.io import fits
from astropy import units as u
from astropy import wcs
from astropy import visualization
from astropy.visualization import wcsaxes
from astropy.nddata import CCDData

import ipywidgets

from astroquery import mast

from IPython import display

import flask

%matplotlib inline
from matplotlib import pyplot as plt

In [None]:

"""
Writing PNGs
"""

#-----------------------------------------------------------------------------
# Copyright (c) 2013, yt Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------

import matplotlib._png as _png

from io import BytesIO as StringIO

def call_png_write_png(buffer, width, height, filename, dpi):
    """Write ``buffer`` as a PNG to ``filename`` via matplotlib's private encoder.

    ``width`` and ``height`` are accepted so the call signature matches the
    other helpers, but they are not forwarded: ``_png.write_png`` is called
    with only the buffer, destination and dpi (presumably it reads the size
    from the buffer itself).  ``filename`` may be a path or a writable binary
    file object (``write_png_to_string`` passes a BytesIO).

    NOTE(review): ``matplotlib._png`` is a private module that newer
    matplotlib releases (3.3+) no longer ship.
    """
    del width, height  # intentionally unused, see docstring
    _png.write_png(buffer, filename, dpi)

def write_png(buffer, filename, dpi=100):
    """Write an image array to ``filename`` as a PNG.

    ``buffer`` is indexed as (rows, columns, ...): the width is taken from
    axis 1 and the height from axis 0 before delegating to
    ``call_png_write_png``.
    """
    width, height = buffer.shape[1], buffer.shape[0]
    call_png_write_png(buffer, width, height, filename, dpi)

def write_png_to_string(buffer, dpi=100, gray=0):
    """Encode an image array as PNG and return the encoded bytes.

    Parameters
    ----------
    buffer : array
        Image indexed as (rows, columns, channels); the callers in this
        notebook pass uint8 arrays with a single trailing channel.
    dpi : int
        Resolution forwarded to the PNG writer.
    gray : int
        Unused; kept only so existing callers keep working.

    Returns
    -------
    bytes
        The PNG file contents (despite the function name, not a ``str``).
    """
    width, height = buffer.shape[1], buffer.shape[0]
    # ``StringIO`` is io.BytesIO aliased at import time.  The context manager
    # releases the buffer even when encoding raises (previously the explicit
    # close() was skipped on error).
    with StringIO() as fileobj:
        call_png_write_png(buffer, width, height, fileobj, dpi)
        return fileobj.getvalue()

In [None]:
obses = mast.Observations.query_region('M32', radius=30*u.arcsec)

# Keep HST ACS/WFC F555W observations of M32 and pick the deepest one
# (longest exposure time).
is_target = ((obses['instrument_name'] == 'ACS/WFC')
             & (obses['filters'] == 'F555W')
             & (obses['obs_collection'] == 'HST'))
candidates = obses[is_target]
if len(candidates) == 0:
    # np.argmax on an empty selection would fail with an opaque message.
    raise ValueError('No HST ACS/WFC F555W observations found within 30 arcsec of M32')
acsm32 = candidates[np.argmax(candidates['t_exptime'])]
acsm32

In [None]:
# List every data product of the chosen observation and show only the
# drizzled (DRZ) images.
products = mast.Observations.get_product_list(acsm32)
drz_mask = products['productSubGroupDescription'] == 'DRZ'
products[drz_mask]

In [None]:
# Download the DRZ products and open the first downloaded file.
# NOTE(review): if several DRZ files match, only the first one is used.
drz_products = products[products['productSubGroupDescription'] == 'DRZ']
downloaded = mast.Observations.download_products(drz_products)
first_path = downloaded['Local Path'][0]
f = fits.open(first_path)
f.info()

In [None]:
# Wrap extension 1 of the drizzled file in a CCDData.  The whole HDUList is
# passed to WCS so it can look up distortion arrays stored in other extensions.
# NOTE(review): drizzled science arrays are often in electrons/s rather than
# electrons -- check the BUNIT header keyword.
sci_hdu = f[1]
c1 = CCDData(sci_hdu.data, unit=u.electron,
             wcs=wcs.WCS(sci_hdu.header, f), header=sci_hdu.header)

# Server 

In [None]:
# Image currently served by the tile server: c1's pixels after the active
# stretch, flipped vertically (row 0 becomes the last image row, presumably so
# tile row 0 is the top of the picture).  get_subfits reads this global.
visdat = None


def re_stretch(stretch):
    """Recompute the global ``visdat`` by applying ``stretch`` to ``c1.data``.

    ``stretch`` is any callable mapping an array to display values, e.g. an
    astropy stretch optionally composed with an interval.
    """
    global visdat
    stretched = stretch(c1.data)
    visdat = stretched[::-1]


re_stretch(visualization.stretch.LinearStretch())

In [None]:
def ccd_to_pngstr_app(dat):
    """Encode a 2-D array of stretched values as a single-channel PNG.

    Values are expected in [0, 1] (the range a stretch produces).  They are
    clipped to that range, and NaNs -- which ``get_subfits`` uses to pad
    partial edge tiles -- are mapped to 0 before scaling to 0..255.  Casting
    NaN or out-of-range floats to uint8 directly is undefined and can wrap.

    Returns the PNG file contents as bytes.
    """
    clipped = np.clip(dat, 0.0, 1.0)  # +inf -> 1, -inf -> 0, NaN passes through
    clipped = np.where(np.isnan(clipped), 0.0, clipped)
    return write_png_to_string((clipped * 255).astype('uint8')[:, :, np.newaxis])

In [None]:
# Send logging output (presumably including the Flask/Werkzeug request log)
# to an in-memory buffer instead of the notebook; read it with
# logstream.getvalue().
# NOTE(review): logging.basicConfig does nothing if the root logger already has
# handlers, in which case this redirect silently has no effect.
logstream = io.StringIO()
logging.basicConfig(stream=logstream)

In [None]:
app = flask.Flask(__name__)


def hello_world():
    """Liveness check: GET / answers with a fixed greeting."""
    return 'Hello, World!'


app.add_url_rule('/', view_func=hello_world)


@app.route('/fits<int:cachebuster>/<string:z>/<int:x>/<int:y>.png')
def get_subfits(z, x, y, cachebuster):
    """Serve one 256x256 Leaflet tile of the current ``visdat`` as a PNG.

    ``z`` arrives as a string because Flask's ``int`` converter does not match
    negative numbers, and the viewer allows negative zoom (``minZoom: -3``).
    ``cachebuster`` is unused here; it only changes the URL so the browser
    refetches tiles after a re-stretch.

    Zoom handling (every returned tile is 256x256 pixels):
      * ``z <= 0``: a ``256 * 2**-z`` pixel square is sampled every
        ``2**-z`` pixels (decimation).
      * ``z > 0``: a ``256 // 2**z`` pixel square is block-replicated
        ``2**z`` times along each axis.

    Tiles entirely outside the image return 404.  Partial tiles at the
    right/bottom edge are padded with NaN.
    """
    tile_size = 256
    z = int(z)
    if z <= 0:
        step = 2**-z
        span = tile_size * step
        subdat = visdat[y*span:(y+1)*span:step, x*span:(x+1)*span:step]
    else:
        # Previously this repeated by ``z`` (and not at all for z == 1), so
        # interior tiles were smaller than 256 px while padded edge tiles were
        # 256 px, which made edge tiles render at the wrong scale.
        zoom = 2**z
        span = tile_size // zoom
        subdat = visdat[y*span:(y+1)*span, x*span:(x+1)*span]
        subdat = subdat.repeat(zoom, 0).repeat(zoom, 1)

    if 0 in subdat.shape:
        # The tile lies completely outside the image.
        flask.abort(404)

    if subdat.shape != (tile_size, tile_size):
        padded = np.full((tile_size, tile_size), np.nan, dtype=subdat.dtype)
        padded[:subdat.shape[0], :subdat.shape[1]] = subdat
        subdat = padded
    return flask.Response(ccd_to_pngstr_app(subdat), mimetype='image/png')

## A threading approach 

In [None]:
# Run the tile server in a background thread so the notebook stays usable.
# daemon=True lets the interpreter exit without waiting for the server.  The
# thread itself still cannot be stopped from the notebook.
th = threading.Thread(
    target=app.run,
    kwargs=dict(debug=False, use_reloader=False, port=5001),
    daemon=True,
)
th.start()

# Start up the viewer (run the cells below once the tile server is running)

In [None]:
# HTML/JS for a Leaflet map that loads tiles from the local Flask server on
# port 5001, using L.CRS.Simple (plain pixel coordinates).
# The {cachebuster} URL field comes from map.redrawint (random at page load);
# update() increments it and calls fitslayer.redraw() so the browser
# refetches re-stretched tiles instead of reusing cached ones.
# <HEIGHTPX> is replaced with the map height in pixels.
viewerstr= """
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.4.0/dist/leaflet.css" />
<script src="https://unpkg.com/leaflet@1.4.0/dist/leaflet.js"></script>
  <style>
    #map{ height: <HEIGHTPX>px;}
  </style>
  <div id="map"></div>

  <script>


var map = L.map('map', {crs: L.CRS.Simple}).setView([1, 1], 1);

map.redrawint = Math.floor( Math.random() * 200000 ) + 1
var getRedrawInteger = function() {
    return map.redrawint;
};

var fitslayer = L.tileLayer('http://127.0.0.1:5001/fits{cachebuster}/{z}/{x}/{y}.png', {
    attribution: 'fitsfile', minZoom: -3, maxZoom:5, cachebuster: getRedrawInteger
}).addTo(map);

  </script>
""".replace('<HEIGHTPX>', '600')

slider = ipywidgets.FloatSlider(value=99, min=90, max=100, step=.01, description='Perc:')
o2 = ipywidgets.Output()


def _constructible_without_args(cls):
    """Return True if ``cls()`` can be called with no arguments."""
    try:
        inspect.signature(cls).bind()
    except (TypeError, ValueError):
        return False
    return True


# Offer every concrete stretch class from astropy.visualization.stretch.
# update() instantiates the selected class with no arguments, so classes with
# required constructor parameters (e.g. ContrastBiasStretch, HistEqStretch)
# are skipped; otherwise selecting them would raise TypeError.
stretches = []
for k, v in visualization.stretch.__dict__.items():
    if (inspect.isclass(v)
            and issubclass(v, visualization.BaseStretch)
            and v is not visualization.BaseStretch
            and _constructible_without_args(v)):
        stretches.append((k, v))

stretch_dropdown = ipywidgets.Dropdown(options=stretches)

def update():
    """Re-stretch the image from the current widget values and redraw the map.

    The dropdown holds a stretch class, which is instantiated and composed
    with a percentile interval taken from the slider.  The map is then told
    to refetch its tiles by running a small JavaScript snippet through the
    ``o2`` output widget, which is cleared right afterwards.
    """
    composite = stretch_dropdown.value() + visualization.PercentileInterval(slider.value)
    re_stretch(composite)
    redraw_js = display.Javascript('map.redrawint += 1;fitslayer.redraw();')
    with o2:
        display.display(redraw_js)
    o2.clear_output()

# History of every value emitted by the controls (debugging aid).
vals = []


def _record_and_refresh(change):
    """Append the widget's new value to ``vals`` and re-render the image."""
    vals.append(change['new'])
    update()


def on_slide_change(change):
    """Observer for the percentile slider."""
    _record_and_refresh(change)


def on_dropdown_change(change):
    """Observer for the stretch dropdown."""
    _record_and_refresh(change)


slider.observe(on_slide_change, names='value')
stretch_dropdown.observe(on_dropdown_change, names='value')

# Render the Leaflet viewer inside an Output widget so it can be placed in
# the layouts below (inline VBox or sidecar).
o_js = ipywidgets.Output()
with o_js:
    display.display(display.HTML(viewerstr))
    
# Apply the stretch/interval from the current widget values and trigger a redraw.
update()

## Inline: 

In [None]:
# Viewer on top, slider and stretch dropdown below; o2 only carries the
# transient JavaScript that triggers tile redraws.
ipywidgets.VBox([o_js, ipywidgets.HBox([slider, stretch_dropdown]), o2])

## OR in a sidecar: 

In [None]:
from sidecar import Sidecar
from ipywidgets import IntSlider  # NOTE(review): not used in this cell

# Put the viewer (and the redraw-trigger output) in a sidecar panel and show
# the controls inline as this cell's output.
# NOTE(review): o_js and o2 are the same widget instances as in the inline
# layout above, so running both cells displays them in two places.
sc = Sidecar(title='Imageviewer')
with sc:
    display.display(ipywidgets.VBox([o_js, o2]))
ipywidgets.VBox([slider, stretch_dropdown])

# CAVEATS/PROBLEMS:

* The threading approach is problematic because the server thread cannot be killed. It might be better to use multiprocessing, or a server other than Flask.
* A separate server provides the tiles, so that server connection has to be managed, which is error-prone when the notebook is not running on the local machine. Attaching to the Jupyter server itself would be better.

In [None]:
# Display the stretch dropdown on its own.
stretch_dropdown

# Blocking version 

In [None]:
# BLOCKING!
# Runs the tile server in the foreground: this cell does not return until the
# server stops (interrupt the kernel).  Alternative to the threaded cell above.
# NOTE(review): debug=True enables the Werkzeug interactive debugger (unlike
# the threaded version, which uses debug=False); do not expose this port.
app.run(debug=True, use_reloader=False, port=5001)  #requires ipykernel>4.9 ...

## A multiprocessing approach 

In [None]:
import multiprocessing

# Run the tile server in a child process, which (unlike a thread) can be
# stopped with pserv.terminate().
# NOTE(review): a lambda target only works with the "fork" start method; under
# "spawn" (default on Windows, and on macOS since Python 3.8) it cannot be pickled.
pserv = multiprocessing.Process(target=lambda:app.run(debug=True, use_reloader=False, port=5001))
pserv.start()

# Initial experiments on the data

In [None]:
# Two stacked WCS panels showing c1 and c4 with the same asinh stretch and a
# 97th-percentile interval.
# FIXME: ``c4`` is never defined in this notebook (only ``c1`` is), so this
# cell raises NameError on a fresh kernel; define it before running.
fig = plt.figure(figsize=(10, 10))
ax1 = fig.add_subplot(2, 1, 1, projection=c1.wcs)
ax2 = fig.add_subplot(2, 1, 2, projection=c4.wcs)

v1 = visualization.imshow_norm(c1, ax=ax1,
                               stretch=visualization.AsinhStretch(),
                               interval=visualization.PercentileInterval(97))
# The former extra imshow_norm(c4, ...) call without ``ax`` drew c4 a second
# time onto the current axes (ax2) and overwrote v4; it has been removed.
v4 = visualization.imshow_norm(c4, ax=ax2,
                               stretch=visualization.AsinhStretch(),
                               interval=visualization.PercentileInterval(97))

In [None]:
# Time the two composition orders of a log stretch and a 99th-percentile
# interval.
# NOTE(review): in astropy, ``a + b`` builds a composite that applies ``b``
# first, so ``comb`` is interval-then-stretch and ``comb2`` is
# stretch-then-interval -- confirm.
ls = visualization.stretch.LogStretch()
pi = visualization.PercentileInterval(99)
comb = ls + pi
comb2 = pi + ls

%timeit comb(c1)
%timeit comb2(c1)

In [None]:
# Time stretching the whole image plus encoding it as one single-channel PNG.
%timeit pstr = write_png_to_string((comb(c1)*255).astype('uint8')[:,:,np.newaxis])

In [None]:
def ccd_to_pngstr(ccdd, stretch):
    """Stretch ``ccdd.data`` and encode it as a single-channel PNG.

    ``stretch`` is any callable returning display values, nominally in
    [0, 1].  The result is clipped to that range and NaNs are mapped to 0
    before scaling to 0..255, matching the tile-server encoder; casting NaN
    or out-of-range floats to uint8 directly is undefined and can wrap.

    Returns the PNG file contents as bytes.
    """
    stretched = np.clip(stretch(ccdd.data), 0.0, 1.0)
    stretched = np.where(np.isnan(stretched), 0.0, stretched)
    return write_png_to_string((stretched * 255).astype('uint8')[:, :, np.newaxis])

In [None]:
# How does stretch + encode time scale with image size?  Each case uses the
# top-left l x w pixels of c1.
for l,w in [(2048, 4096),(2048, 2048),(1024, 1024),(512, 512)]:
    print(l, 'x', w)
    %timeit ccd_to_pngstr(c1[:l,:w], comb)

In [None]:
# Extrapolation: the pixel-count ratio of a 2048x4096 image to a 512x512 tile,
# times 19.7 (presumably the measured 512x512 time in ms from the cell above).
(2048*4096)/(512*512)  * 19.7

In [None]:
# Preview the top-left 512x512 of c1 with a log stretch over a 99% interval.
display.Image(data=ccd_to_pngstr(c1[:512, :512], visualization.stretch.LogStretch() + visualization.PercentileInterval(99)))