In [None]:
%matplotlib widget

import h5py
import ipywidgets
import matplotlib.pyplot as plt
import numpy as np

from eidein import Eidein

## Redshift

In [None]:
CACHEFILE = "data/dr16q_superset_cache.hdf5"
# Slicing with [:] copies each dataset into a NumPy array, so the values
# remain usable after the file is closed.
with h5py.File(CACHEFILE, "r") as cachefile:
    X, y_cache = (cachefile[key][:] for key in ("X", "y_cache"))
X.shape, y_cache.shape

In [None]:
# Observed-frame wavelength grid: N_FEATURES points evenly spaced in log10(λ).
LOGLAMMIN, LOGLAMMAX = 3.5832, 3.9583
N_FEATURES = 3752
LAM = np.power(10, np.linspace(LOGLAMMIN, LOGLAMMAX, N_FEATURES))
# Rest-frame wavelengths [Å] and display names of the emission lines annotated
# by plot_spec (ordered by wavelength).
LINES = [
    (1033.82, "O VI"),
    (1215.24, "Lyα"),
    (1549.48, "C IV"),
    (1908.734, "C III"),
    (2326.0, "C II"),
    (2799.117, "Mg II"),
    (4102.89, "Hδ"),  # Balmer lines use lowercase Greek letters: Hδ, Hγ, Hβ, Hα
    (4341.68, "Hγ"),
    (4862.68, "Hβ"),
    (6564.61, "Hα")]
# Arrow style for the line annotations.
ARROWPROPS = {"arrowstyle": "-|>", "facecolor": "black"}

def z2lam_emit(z, lam_obsv):
    """Convert an observed wavelength to the rest (emitted) frame at redshift ``z``.

    Computes ``lam_obsv / (1 + z)``; ``lam_obsv`` may also be a NumPy array,
    in which case the division is applied element-wise.
    """
    stretch = 1 + z
    return lam_obsv / stretch

def plot_spec(ax, idx, flux, y, label, y_cache=y_cache, lam=LAM):
    """Plot one spectrum in the rest frame implied by the redshift being entered.

    Used as the plotting callback handed to ``Eidein``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    idx : int
        Index of the spectrum; shown in the legend and used to look up ``y_cache``.
    flux : array-like
        Flux values, one per wavelength in ``lam``.
    y : float
        Redshift shown as z-hat in the legend (presumably the model prediction,
        since the caller passes ``ypred_query`` -- confirm against ``Eidein``).
    label : float
        Redshift currently entered by the labeller; used to shift ``lam``.
    y_cache : array-like
        Cached redshifts, indexed by ``idx``. Bound to the global ``y_cache``
        at the time this cell is executed.
    lam : array-like
        Observed-frame wavelength grid; ``LAM`` by default.
    """
    # Raw f-strings: the LaTeX backslashes (\hat, \mathrm) are not valid Python
    # escape sequences and trigger a SyntaxWarning in non-raw literals.
    label_str = "\n".join([
        f"{idx}",
        rf"$\hat{{z}} = {y:.2f}$",
        f"$z = {label:.2f}$ (shown)",
        rf"$z_{{\mathrm{{cache}}}} = {y_cache[idx]:.2f}$",
    ])
    # Shift the whole grid to the rest frame in one vectorised division.
    lam_emit = z2lam_emit(label, np.asarray(lam))
    ax.plot(lam_emit, flux, label=label_str)
    ax.legend()
    ax.set_xlabel("Rest Frame Wavelength [Å]")
    ax.set_ylabel("Flux [10$^{-17}$ erg cm$^{-2}$ s$^{-1}$ Å$^{-1}$]")
    # Mark the expected rest-frame position of each emission line with an
    # arrow drawn from y=-2 up to y=0 in data coordinates.
    for line, name in LINES:
        ax.annotate(name, xy=(line, 0), xytext=(line, -2),
                    arrowprops=ARROWPROPS, horizontalalignment='center')

In [None]:
i = 1

with h5py.File("data/human.hdf5", "r") as datafile:
    # Read the three per-iteration arrays for round i into memory.
    idx_query, entr_query, ypred_query = (
        datafile[f"{key}_{i}"][:]
        for key in ("idx_query", "entr_query", "ypred_query"))
    # Indices were written 1-based (Julia); convert to Python's 0-based indexing.
    idx_query -= 1

X_query = X[idx_query]
idx_query.shape, X_query.shape, entr_query.shape, ypred_query.shape

In [None]:
# Distribution of entr_query over the queried spectra
# (presumably the predictive entropy used to select them -- TODO confirm).
fig, ax = plt.subplots()
ax.hist(entr_query)
ax.set(title="Distribution of entr_query for queried spectra",
       xlabel="entr_query", ylabel="Count");

In [None]:
# Free-form numeric input for the redshift; step=0.01 matches the two-decimal
# precision shown in plot_spec's legend.
label_widget = ipywidgets.FloatText(step=0.01, description="Redshift:")
eidein = Eidein(
    idx_query, X_query, ypred_query, entr_query,
    plot_spec, label_widget,
)
display(eidein)

In [None]:
# Progress check: number of labelled spectra, followed by the labels themselves.
(len(eidein.labelled), eidein.labelled)

In [None]:
# Persist the labelling round: queried indices go to human.hdf5 (converted back
# to Julia's 1-based indexing) and the entered redshifts go into the cache file.
# Explicit int64 so an empty selection still yields an integer dataset.
idx_label = np.array(list(eidein.labelled.keys()), dtype=np.int64)
with h5py.File("data/human.hdf5", "r+") as hdf5file:
    # create_dataset raises if f"idx_label_{i}" already exists, so this cell
    # cannot be re-run for the same i without removing that dataset first.
    hdf5file.create_dataset(f"idx_label_{i}", data=(idx_label + 1))

with h5py.File(CACHEFILE, "r+") as cachefile:
    # Use a distinct name for the on-disk dataset: rebinding the global y_cache
    # to it would leave a handle that is invalid once the file is closed.
    y_cache_ds = cachefile["y_cache"]
    for idx, z in eidein.labelled.items():
        y_cache_ds[idx] = z
        # Keep the in-memory array (the object plot_spec captured as its
        # default, unless y_cache was rebound) in sync with the file.
        y_cache[idx] = z

## Digits

In [None]:
def plot_mnist(ax, identifier, x, y, label, shape=(32, 32)):
    """Draw one flattened digit image in greyscale; used as the ``Eidein`` plot callback.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    identifier : object
        Sample identifier supplied by the caller; not used here, but part of
        the callback signature shared with ``plot_spec``.
    x : numpy.ndarray
        Flattened pixel values; must have ``shape[0] * shape[1]`` elements.
    y : object
        Value associated with the sample (the caller passes ``y_query``).
    label : object
        Label currently selected in the widget.
    shape : tuple of int, optional
        Image shape used to un-flatten ``x``; defaults to 32x32.
    """
    ax.imshow(x.reshape(shape), cmap='gray')
    # Show both the stored value and the current label; formatting only `y`
    # would silently discard `label`.
    ax.set_title(f"y = {y}, label = {label}")

In [None]:
DATAFILE = "data/human.hdf5"

with h5py.File(DATAFILE, "r") as datafile:
    # Copy both arrays into memory. Note that this rebinds X, replacing the
    # redshift-section X loaded from the cache file.
    X, y = datafile["X"][:], datafile["y"][:]
X.shape, y.shape

In [None]:
i = 1

with h5py.File(DATAFILE, "r") as datafile:
    index_query = datafile[f"index_query_{i}"][:]
    entr_query = datafile[f"entr_query_{i}"][:]
    # Indices were written 1-based (Julia); convert to Python's 0-based indexing.
    index_query -= 1

X_query, y_query = X[index_query], y[index_query]
X_query.shape, y_query.shape, index_query.shape, entr_query.shape

In [None]:
# Class choices 1..10 (presumably 1-based to match the Julia side -- confirm).
label_widget = ipywidgets.RadioButtons(options=np.arange(1, 11))

# NOTE(review): this call passes 7 positional arguments (an extra X_query before
# the plot callback), whereas the redshift section passes 6. Confirm which
# matches the current Eidein signature.
eidein = Eidein(index_query, X_query.reshape(-1, 1024), y_query, entr_query,
                X_query, plot_mnist, label_widget)
# The digit images are small, so shrink the data figure.
eidein.data_fig.set_size_inches(2, 2)
display(eidein)

In [None]:
# Progress check: number of labelled digits, followed by the labels themselves.
(len(eidein.labelled), eidein.labelled)

In [None]:
# Labelled indices converted back to Julia's 1-based indexing. Explicit int64 so
# an empty selection still produces an integer dataset.
index_label = np.array(list(eidein.labelled.keys()), dtype=np.int64) + 1
with h5py.File(DATAFILE, "r+") as datafile:
    # create_dataset raises if the key already exists, so this cell cannot be
    # re-run for the same i without removing that dataset first.
    datafile.create_dataset(f"index_label_{i}", data=index_label)