In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell execution (useful while editing library code).
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from ipywidgets import interact

import bokeh
from bokeh.palettes import viridis, Category10
from bokeh.plotting import figure, output_file, show, output_notebook, save, ColumnDataSource
from bokeh.models import GlyphRenderer, LinearColorMapper, LinearAxis, Range1d, Span, Label, LabelSet, Legend
from bokeh.io import push_notebook, export_png, reset_output, export_svgs, curdoc
from bokeh.layouts import gridplot, layout, row, column
from bokeh.themes import built_in_themes

#from numba import jit, njit

bokeh.__version__

In [None]:
from mpdaf.obj.bokeh import *

In [None]:
# Send Bokeh output to the notebook (inline display) for the show() calls below.
output_notebook()

In [None]:
# NOTE(review): `Source` is used in the next cell but never imported by name in
# this notebook; presumably this magic (or the star import above) brings it
# into the namespace -- confirm.
%mpdaf

In [None]:
# Load the source object `s` that every plotting cell below reads from.
# NOTE(review): the path is an absolute, machine-specific location; the
# commented-out lines are alternative files/locations for the same step.
#s = Source.from_file('/home/simon/data/UDF/udf_origin_00223.fits')
#s = Source.from_file('/home/simon/data/UDF/udf10_c031_e021/udf_udf10_00002.fits')
s = Source.from_file('/home/simon/data/UDF/udf_udf10_00002.fits')
#s = Source.from_file('/muse/UDF/private/analysis/UDF-10/udf10_sources/udf_udf10_00002.fits')

In [None]:
from astropy.visualization import ZScaleInterval, PercentileInterval, MinMaxInterval
from bokeh.models import LogColorMapper, LogTicker, ColorBar, LinearColorMapper

def plot_image(im, size=(350, 350), title=None, colorbar=True, palette="Viridis256", scale='minmax',
               axis=True, x_range=None, y_range=None, center=None, catalog=None):
    """Render a 2D image as a Bokeh figure.

    Parameters
    ----------
    im : image object
        Must expose ``.data`` (a masked array supporting ``.filled()``),
        ``.shape`` and, when ``center`` is given, ``.wcs.sky2pix``.
    size : tuple
        ``(width, height)`` of the figure in pixels.
    title : str, optional
        Figure title.
    colorbar : bool
        If True, add a colour bar on the right-hand side.
    palette : str or sequence
        Palette handed to ``LinearColorMapper``.
    scale : {'minmax', 'zscale', 'percentile'}
        Astropy interval used to derive the colour limits
        ('percentile' uses the 99% interval).
    axis : bool
        Whether the figure axes are visible.
    x_range, y_range : optional
        Ranges forwarded to ``figure``; pass another figure's ranges to link
        pan/zoom. They default to the full image extent.
    center : optional
        Sky position forwarded to ``im.wcs.sky2pix``; the result is unpacked
        as ``(y, x)`` and marked with a cross-hair and a red circle.
    catalog : optional
        Bokeh data source with ``x``, ``y`` and ``ID`` columns (plus ``RA``,
        ``DEC`` and ``MAG_F775W`` for the hover tooltips). Each row is drawn
        as a white circle labelled with its ID.

    Returns
    -------
    The Bokeh figure.

    Raises
    ------
    ValueError
        If ``scale`` is not one of the supported names.
    """
    if scale == 'zscale':
        interval = ZScaleInterval()
    elif scale == 'percentile':
        interval = PercentileInterval(99)
    elif scale == 'minmax':
        interval = MinMaxInterval()
    else:
        # Previously an unknown name left `interval` unbound and failed later
        # with an UnboundLocalError.
        raise ValueError(
            f"unknown scale {scale!r}; expected 'zscale', 'percentile' or 'minmax'")

    vmin, vmax = interval.get_limits(im.data)
    color_mapper = LinearColorMapper(palette=palette, low=vmin, high=vmax)

    # A 2D array handed to `figure.image` is drawn with rows along y and
    # columns along x, so the width comes from the second axis and the height
    # from the first. (The two were swapped before, which only works for
    # square images.)
    ny, nx = im.shape[0], im.shape[1]
    if x_range is None:
        x_range = (0, nx)
    if y_range is None:
        y_range = (0, ny)

    # Explicit None checks: truthiness is ambiguous (or raises) for array-like
    # and DataFrame inputs.
    has_catalog = catalog is not None
    if has_catalog:
        tooltips = [
            ("ID", "@ID"),
            ("pos", "(@RA, @DEC)"),
            ("mag F775W", "@MAG_F775W"),
        ]
    else:
        tooltips = None

    p = figure(plot_width=size[0], plot_height=size[1], tooltips=tooltips,
               x_range=x_range, y_range=y_range, title=title)
    p.image(image=[im.data.filled()], x=[0], y=[0], dw=[nx], dh=[ny],
            color_mapper=color_mapper)
    p.grid.visible = False
    p.axis.visible = axis

    if center is not None:
        # sky2pix output is unpacked in (y, x) order.
        y, x = im.wcs.sky2pix(center)[0]
        for location, dimension in ((x, 'height'), (y, 'width')):
            p.add_layout(Span(location=location, dimension=dimension,
                              line_color='white', line_width=1, line_alpha=0.5))
        p.circle(x, y, size=10, line_color='red', line_width=2, line_alpha=0.6,
                 fill_color=None)

    if has_catalog:
        p.circle('x', 'y', source=catalog, size=5, line_color='white',
                 line_width=1, line_alpha=0.6, fill_color=None)
        p.add_layout(LabelSet(x='x', y='y', source=catalog, x_offset=2, y_offset=2,
                              text='ID', text_font_size='10px', text_color='white'))

    if colorbar:
        color_bar = ColorBar(color_mapper=color_mapper, label_standoff=12,
                             border_line_color=None, location=(0, 0))
        p.add_layout(color_bar, 'right')

    return p

In [None]:
def images(s, size=(350, 350), params=None):
    """Build one figure per image of a source, with optional linked ranges.

    Parameters
    ----------
    s : source object
        Must expose ``images`` (mapping name -> image), ``tables['HST_CAT']``
        (a table with ``RA``/``DEC`` columns), ``RA`` and ``DEC``.
    size : tuple
        ``(width, height)`` of each figure in pixels.
    params : dict, optional
        Mapping ``image name -> options``. Recognised options are ``scale``
        (required), ``palette`` (default ``'Viridis256'``) and ``link`` (name
        of an earlier entry whose x/y ranges are shared). Entries are plotted
        in dict order, so a linked image must come after its target. Defaults
        to the MUSE/HST set used throughout this notebook.

    Returns
    -------
    list
        The figures, in the order of ``params``.
    """
    if params is None:
        params = {
            'MUSE_WHITE': {'scale': 'zscale'},
            'MUSE_EXPMAP': {'scale': 'minmax', 'palette': 'Greys256', 'link': 'MUSE_WHITE'},
            'MASK_OBJ': {'scale': 'minmax', 'palette': 'Greys256', 'link': 'MUSE_WHITE'},
            'HST_F606W': {'scale': 'percentile'},
            'HST_F775W': {'scale': 'percentile', 'link': 'HST_F606W'},
        }

    # Work on a copy: pixel-coordinate columns are added per image below.
    cat = s.tables['HST_CAT'].copy(copy_data=True)
    skypos = np.array([cat['DEC'], cat['RA']]).T

    figures = {}
    for name, opts in params.items():
        im = s.images[name]
        kw = dict(size=size, title=name, colorbar=False, axis=False,
                  scale=opts['scale'], palette=opts.get('palette', 'Viridis256'))

        link = opts.get('link')
        if link is not None:
            kw['x_range'] = figures[link].x_range
            kw['y_range'] = figures[link].y_range

        # Only these images get the catalogue overlay, so project the
        # catalogue into pixel coordinates just for them.
        if name.startswith(('HST_', 'MUSE_WHITE')):
            cat['y'], cat['x'] = im.wcs.sky2pix(skypos).T
            kw['catalog'] = ColumnDataSource(cat.to_pandas())

        figures[name] = plot_image(im, center=(s.DEC, s.RA), **kw)

    return list(figures.values())

In [None]:
#show(row(images(s, size=(250, 250))))

In [None]:
def spectrum(s, size=(800, 350), axis_labels=True, lbrange=None, show_legend=True,
             snames=['MUSE_TOT_SKYSUB', 'MUSE_PSF_SKYSUB', 'MUSE_WHITE_SKYSUB']):
    """Plot one or more spectra of a source, with line markers and variance.

    Parameters
    ----------
    s : source object
        Must expose ``spectra`` (mapping name -> spectrum) and ``lines``
        (rows with ``LBDA_OBS`` and ``LINE``).
    size : tuple
        ``(width, height)`` of the figure in pixels.
    axis_labels : bool
        If True, label the axes with the wavelength and flux units.
    lbrange : tuple, optional
        ``(min, max)`` wavelength range; each spectrum is cropped to it with
        ``subspec`` before plotting.
    show_legend : bool
        Whether the legend is visible.
    snames : sequence of str
        Names of the spectra to plot. Only the first one is visible
        initially; the others can be toggled from the legend.

    Returns
    -------
    The Bokeh figure.

    Raises
    ------
    ValueError
        If ``snames`` is empty.
    """
    if len(snames) == 0:
        raise ValueError('snames must contain at least one spectrum name')

    s1 = figure(plot_width=size[0], plot_height=size[1])
    palette = Category10[8]

    # Mark each catalogued line with a dashed vertical span and a rotated label.
    # Guard in case the source carries no line table.
    for line in (s.lines if s.lines is not None else []):
        span = Span(location=line['LBDA_OBS'], dimension='height', line_color='black',
                    line_width=1, line_alpha=0.6, line_dash='dashed')
        s1.add_layout(span)
        line_label = Label(x=line['LBDA_OBS'], y=size[1]-120, y_units='screen', angle=90,
                           angle_units='deg', text=line['LINE'], text_font_size='10px')
        s1.add_layout(line_label)

    prefix = 'MUSE_'
    legend_items = []
    smin, smax = np.inf, -np.inf
    for i, sname in enumerate(snames):
        sp = s.spectra[sname]
        if lbrange:
            sp = sp.subspec(lbrange[0], lbrange[1])
        # Take the extrema after cropping so the y-range fits what is drawn.
        smin = min(smin, sp.data.min())
        smax = max(smax, sp.data.max())
        renderer = s1.line(sp.wave.coord(), sp.data, color=palette[i % len(palette)])
        # Remove the literal prefix only. str.lstrip('MUSE_') strips any run of
        # the characters M/U/S/E/_ and would mangle e.g. 'MUSE_SKY' into 'KY'.
        short_name = sname[len(prefix):] if sname.startswith(prefix) else sname
        legend_items.append((short_name, [renderer]))
        if i > 0:
            renderer.visible = False

    s1.y_range = Range1d(smin - 20, smax + 20)

    # Variance of the last spectrum in `snames`, drawn on its own y-range.
    # NOTE(review): the spectrum visible by default is the first one; confirm
    # which variance is meant to be shown.
    if sp.var is not None:
        s1.extra_y_ranges = {"var": Range1d(start=0, end=sp.var.max())}
        s1.line(sp.wave.coord(), sp.var, line_color='gray', line_alpha=0.6, y_range_name="var")
        #s1.add_layout(LinearAxis(y_range_name="var"), 'left')

    legend_box = Legend(items=legend_items, location=(0, 0))
    s1.add_layout(legend_box, 'above')

    s1.legend.location = "top_left"
    s1.legend.visible = show_legend
    s1.legend.label_text_font_size = '12px'
    s1.legend.padding = 0
    s1.legend.background_fill_alpha = 0.5
    s1.legend.orientation = "horizontal"
    s1.legend.click_policy = "hide"

    s1.yaxis.major_label_orientation = "vertical"
    if axis_labels:
        s1.xaxis.axis_label = f'Wavelength ({sp.wave.unit})'
        s1.yaxis.axis_label = f'Flux ({sp.unit})'

    return s1

In [None]:
#show(spectrum(s, size=(750, 400), axis_labels=False))

In [None]:
from bokeh.models.widgets import Div

def legend(s, size=(800, 200)):
    """Build the header banner for a source.

    The banner is an HTML ``<h4>`` with the source ID, its (RA, DEC) position
    and the redshift taken from the first row of ``s.z`` whose ``Z_DESC`` is
    ``'MUSE'``.

    Parameters
    ----------
    s : source object
        Must expose ``ID``, ``RA``, ``DEC`` and the redshift table ``z``.
    size : tuple
        ``(width, height)`` of the Div in pixels.

    Returns
    -------
    list
        A one-element list holding the Div, usable as a layout row.
    """
    muse_rows = s.z[s.z['Z_DESC'] == 'MUSE']
    z = muse_rows[0]['Z']
    banner = f"""<h4>Source #{s.ID}, ({s.RA:.6f}, {s.DEC:.6f}), z={z:.3f}</h4>"""
    return [Div(text=banner, width=size[0], height=size[1])]

In [None]:
from mpdaf.sdetect import get_emlines

def plot_lines(s, size=(250,250), nlines=2):
    """Plot zoomed views of the spectrum around the source's strongest lines.

    Lines are ranked by their ``FLUX_REF`` column; the ``nlines`` rows with
    the largest values are plotted in ascending order, each in a window of
    +/-25 wavelength units around ``LBDA_OBS`` of the ``MUSE_TOT_SKYSUB``
    spectrum.

    Parameters
    ----------
    s : source object
        Must expose ``lines`` (table with ``FLUX_REF``, ``LBDA_OBS``,
        ``LINE``) and ``spectra['MUSE_TOT_SKYSUB']``.
    size : tuple
        ``(width, height)`` of each figure in pixels.
    nlines : int
        Number of lines to plot (default 2).

    Returns
    -------
    list
        One Bokeh figure per selected line; empty when ``nlines`` is not
        positive or the source has no line table.
    """
    if nlines <= 0 or s.lines is None:
        return []

    palette = Category10[8]
    sp = s.spectra['MUSE_TOT_SKYSUB']

    # Sort a copy: sorting `s.lines` in place would reorder the source's own
    # table as a side effect of plotting.
    ranked = s.lines.copy()
    ranked.sort('FLUX_REF')

    figures = []
    for line in ranked[-nlines:]:
        lbda = line['LBDA_OBS']
        fig = figure(plot_width=size[0], plot_height=size[1], title=line['LINE'])
        subsp = sp.subspec(lbda-25, lbda+25)
        fig.line(subsp.wave.coord(), subsp.data, color=palette[0])
        figures.append(fig)

    return figures

In [None]:
#output_file('output.html')

#curdoc().theme = 'dark_minimal'

# Per-image display options (colour scale, palette, and which image to share
# pan/zoom ranges with). Deliberately NOT named `images`: that name is the
# helper function defined earlier, which other cells of this notebook call,
# and rebinding it to a dict makes those calls fail.
image_params = {
    'MUSE_WHITE': {'scale': 'zscale'},
    'MUSE_EXPMAP': {'scale': 'minmax', 'palette': 'Greys256', 'link': 'MUSE_WHITE'},
    'MASK_OBJ': {'scale': 'minmax', 'palette': 'Greys256', 'link': 'MUSE_WHITE'},
    'HST_F606W': {'scale': 'percentile'},
    'HST_F775W': {'scale': 'percentile', 'link': 'HST_F606W'}
}

# The first branch uses plot_src_images / plot_spectrum / plot_spectrum_lines,
# which are not defined in this notebook (presumably provided by the
# `from mpdaf.obj.bokeh import *` cell -- confirm). The else branch builds the
# same page from the helpers defined in this notebook.
if True:
    page = gridplot([
        legend(s, size=(1250, 20)),
        plot_src_images(s, params=image_params, size=(250,250)),
        [plot_spectrum(s, size=(750, 300), axis_labels=False)] + plot_spectrum_lines(s, size=(250,300), nlines=2)
    ], sizing_mode='fixed')
else:
    page = layout([
        legend(s, size=(1250, 20)),
        images(s, size=(250,250)),
        [spectrum(s, size=(750, 300), axis_labels=False)] + plot_lines(s, size=(250,300))
    ], sizing_mode='stretch_both')

show(page)
#plot = show(p, notebook_handle=True)

In [None]:
# NOTE(review): `p` is first assigned in cells further down this notebook, so
# on a fresh kernel run top-to-bottom this cell presumably raises NameError
# (unless the star import above happens to define `p`) -- confirm or move it.
p.plot_height

In [None]:
# Side-by-side 200x200 thumbnails of the MUSE white-light and HST F606W images,
# without colour bars or axes. `p` and `p2` stay in the notebook namespace.
p = plot_image(s.images['MUSE_WHITE'], colorbar=False, size=(200, 200), axis=False)
p2 = plot_image(s.images['HST_F606W'], colorbar=False, size=(200, 200), axis=False)
show(gridplot([[p, p2]]))

In [None]:
# Single-row grid holding every image figure of the source.
# `images` must be the helper function here, not a dict of options.
thumbnail_row = images(s, size=(200, 200))
show(gridplot([thumbnail_row]))

In [None]:
# Two stacked rows: the image thumbnails on top, then the full spectrum next
# to the zoomed line plots.
image_row = row(images(s, size=(200, 200)))
spectrum_row = row(spectrum(s, size=(750, 250)), *plot_lines(s))
show(column(image_row, spectrum_row))

In [None]:
# Build an N x N RGBA test image: each uint32 pixel is viewed as four uint8
# channels. Channel 0 ramps with the row index, channel 2 with the column
# index, channel 1 is fixed at 158 and channel 3 at 255.
N = 20
img = np.empty((N, N), dtype=np.uint32)
view = img.view(dtype=np.uint8).reshape((N, N, 4))

# 255 * k / N truncated to an integer, for k = 0 .. N-1.
ramp = (255 * np.arange(N) / N).astype(np.uint8)
view[..., 0] = ramp[:, np.newaxis]   # varies along rows
view[..., 1] = 158
view[..., 2] = ramp[np.newaxis, :]   # varies along columns
view[..., 3] = 255

In [None]:
#img = s.images['MUSE_WHITE'].data.filled()

In [None]:
# Stand-alone layout experiment: two identical RGBA image panels above a line
# plot, shown with a notebook handle so the plot can be updated in place.
#output_file("grid.html", )

# Both panels draw the `img` array built in the earlier cell over a 10x10 area.
p = figure(plot_width=200, plot_height=200, x_range=(0, 10), y_range=(0, 10))
p.image_rgba(image=[img], x=[0], y=[0], dw=[10], dh=[10])

p2 = figure(plot_width=200, plot_height=200, x_range=(0, 10), y_range=(0, 10))
p2.image_rgba(image=[img], x=[0], y=[0], dw=[10], dh=[10])

# `r` is the line renderer whose data source the `update` callback rewrites.
p3 = figure(plot_width=400, plot_height=200)
r = p3.line(np.array([1, 2, 3, 4, 5]), np.array([6, 7, 2, 4, 5])*10000, line_width=2)
p3.xaxis.axis_label = f'Wavelength'
p3.yaxis.axis_label = f'Flux'

l = gridplot([
    [p, p2],
    [p3]
], sizing_mode='fixed')

# Set after the grid is built, but before it is shown.
p.xaxis.axis_label = 'foo'

#p3.axis.major_tick_in = 10
#p3.axis.major_tick_out = 0
#p3.axis.minor_tick_in = 5
#p3.axis.minor_tick_out = 0
#p3.axis.axis_label_standoff = 50
#p3.yaxis.major_label_orientation = "vertical"

#p3.xaxis.major_label_standoff = -25
#p3.yaxis.major_label_standoff = -45

#show(l)
# `t` is the notebook handle that `update` passes to push_notebook.
t = show(l, notebook_handle=True)

In [None]:
#push_notebook(handle=t)

In [None]:
from ipywidgets import interact

In [None]:
def update(f, w=1, A=1, phi=0):
    """Replace the demo line's y data with ``A * f(w * x + phi)`` and redraw.

    Reads the notebook-level renderer ``r`` and the notebook handle ``t``
    created by the cell that shows the demo layout.

    Parameters
    ----------
    f : {'sin', 'cos'}
        Name of the waveform to apply.
    w, A, phi : number
        Multiplier of x, amplitude and phase offset.

    Raises
    ------
    ValueError
        If ``f`` is not a supported waveform name (previously this failed
        with an UnboundLocalError on ``func``).
    """
    waveforms = {"sin": np.sin, "cos": np.cos}
    if f not in waveforms:
        raise ValueError(f"unknown function {f!r}; expected 'sin' or 'cos'")
    func = waveforms[f]

    r.data_source.data['y'] = A * func(w * r.data_source.data['x'] + phi)
    # Push the modified data source to the already-displayed plot.
    push_notebook(handle=t)

In [None]:
# Build widgets that call `update` on change: a choice of "sin"/"cos" for f,
# and numeric ranges for w (0-50), A (1-10) and phi (0-20, step 0.1).
interact(update, f=["sin", "cos"], w=(0,50), A=(1,10), phi=(0, 20, 0.1))

In [None]:
# Display the current contents of the demo line's data source (the 'y' entry
# is what `update` overwrites).
r.data_source.data