In [None]:
import holoviews as hv

import paulssonlab.deaton.trenchripper.trenchripper as tr

# Load the Bokeh plotting backend for HoloViews in this notebook.
hv.extension("bokeh")

In [None]:
# Experiment directory (absolute, cluster-specific path). Later code appends
# "/dask", "/kymograph/..." and "/metadata.hdf5" to it via string concatenation.
headpath = "/home/de64/scratch/de64/sync_folder/2021-05-27_lDE18_20x_run_1/mVenus"

In [None]:
# Start a (non-local) dask cluster through trenchripper's controller with
# 4 workers, memory="32GB" (presumably per worker -- confirm) and a 2h walltime;
# its working directory is headpath + "/dask".
dask_controller = tr.trcluster.dask_controller(
    walltime="02:00:00",
    local=False,
    n_workers=4,
    memory="32GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

In [None]:
# Show the dask dashboard for monitoring the cluster started above.
dask_controller.displaydashboard()

In [None]:
# Restart the dask client's workers (clears any state held on them).
dask_controller.daskclient.restart()

In [None]:
import dask.array as da
import dask.dataframe as dd
import dask.delayed as delayed
import h5py
import numpy as np
import xarray as xr
from holoviews.operation.datashader import datashade, dynspread, rasterize, spread
from holoviews.selection import link_selections


def fetch_hdf5(filename, channel):
    """Load the entire ``channel`` dataset of HDF5 file ``filename`` into memory.

    The file is opened read-only and is closed again before the array is
    handed back, so the result does not depend on the file staying open.
    """
    with h5py.File(filename, "r") as hdf5_file:
        # Slicing with [:] materialises the dataset as an in-memory array
        # before the context manager closes the file.
        return hdf5_file[channel][:]


def put_index_first(df, col_name):
    """Return ``df`` with the column ``col_name`` moved to the front.

    All other columns keep their relative order and ``df`` itself is left
    untouched. Raises ``ValueError`` if ``col_name`` is not a column.
    """
    columns = list(df)
    position = columns.index(col_name)  # ValueError when the column is absent
    reordered = [columns[position]] + columns[:position] + columns[position + 1 :]
    return df[reordered]


def linked_scatter(
    df, x_dim_label, y_dim_label, minperc=0, maxperc=99, height=400, **scatterkwargs
):
    """Build a lasso-selectable scatter plot linked to a table of the selected rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a ``"trenchid"`` column plus the two plotted columns.
    x_dim_label, y_dim_label : str
        Column names plotted on the x and y axes.
    minperc, maxperc : float
        Percentiles (0-100) of each plotted column used as the axis limits.
        NaN values are ignored when computing them.
    height : int
        Plot height in pixels.
    **scatterkwargs
        Extra options forwarded to ``hv.Scatter.opts``.

    Returns
    -------
    scatter_display
        The scatter plot laid out next to a table of the selected rows.
    select_scatter : hv.streams.Selection1D
        Selected scatter point indices, exposed as ``scatterselect``.
    select_trenchid : hv.streams.Selection1D
        Selected table row indices, exposed as ``trenchid_index``.
    """
    # Put trenchid first so it is the leading column of the selection table.
    df = put_index_first(df, "trenchid")

    dataset = hv.Dataset(df)

    # nanpercentile: with plain percentile a single NaN makes the limits NaN.
    x_low, x_high = np.nanpercentile(df[x_dim_label], [minperc, maxperc])
    y_low, y_high = np.nanpercentile(df[y_dim_label], [minperc, maxperc])

    scatter = hv.Scatter(data=dataset, vdims=[y_dim_label], kdims=[x_dim_label])
    # TODO: datashade/rasterize this scatter once linked brushing is made to
    # work on datashaded output.

    scatter = scatter.opts(
        tools=["hover", "doubletap", "lasso_select"],
        xlim=(x_low, x_high),
        ylim=(y_low, y_high),
        height=height,
        responsive=True,
        **scatterkwargs,
    )

    select_scatter = hv.streams.Selection1D(
        source=scatter, index=[0], rename={"index": "scatterselect"}
    )

    def get_scatter_trenchids(scatterselect, dataset=dataset):
        # Selection indices are row positions in `dataset` (the scatter's source).
        filtered_dataset = dataset.iloc[scatterselect]

        return hv.Table(filtered_dataset)

    trenchid_table = hv.DynamicMap(get_scatter_trenchids, streams=[select_scatter])

    select_trenchid = hv.streams.Selection1D(
        source=trenchid_table, index=[0], rename={"index": "trenchid_index"}
    )

    scatter_display = scatter + trenchid_table

    return scatter_display, select_scatter, select_trenchid


def linked_kymograph_for_scatter(
    xrstack,
    df,
    x_dim_label,
    y_dim_label,
    select_scatter,
    select_trenchid,
    y_scale=3,
    x_window_size=300,
):
    """Build a kymograph viewer driven by the streams from ``linked_scatter``.

    The trench shown is the one in the table row picked via ``select_trenchid``
    among the points picked on the scatter via ``select_scatter``. When nothing
    is selected, the image is blank and the label empty. If a table row index is
    left over from an earlier, larger selection, it falls back to the first row
    rather than raising.

    Parameters
    ----------
    xrstack : xarray.DataArray
        Kymograph stack looked up as ``xrstack.loc[channel, trenchid]`` with a
        ``Channel`` coordinate. Axis 2 sets the image height and axis 3 the
        image width.
    df : pandas.DataFrame
        The same table given to ``linked_scatter`` (the row order must match).
        Must contain a ``"trenchid"`` column.
    x_dim_label, y_dim_label : str
        Unused here; kept for call-site symmetry with ``linked_scatter``.
    select_scatter, select_trenchid : hv.streams.Selection1D
        Streams returned by ``linked_scatter``.
    y_scale : float
        Scale factor applied to the image height (in pixels).
    x_window_size : int
        Width, in image x-coordinates, of the initially visible window.

    Returns
    -------
    hv.DynamicMap
        Kymograph image (with a ``Channel`` selector) overlaid with the trenchid.
    """
    width, height = xrstack.shape[3], int(xrstack.shape[2] * y_scale)
    x_window_scale = x_window_size / xrstack.shape[3]
    x_size = int(xrstack.shape[3] * (y_scale * x_window_scale))

    # Row i of df corresponds to point i of the linked scatter.
    all_trenchids = df["trenchid"].to_numpy()
    # Shown when no trench is selected, so the callback never raises.
    blank_image = np.zeros(xrstack.shape[2:], dtype=xrstack.dtype)

    def _selected_trenchid(scatterselect, trenchid_index):
        """Trenchid of the chosen table row, or None if nothing is selected."""
        if len(scatterselect) == 0:
            return None
        candidates = all_trenchids[np.asarray(scatterselect, dtype=int)]
        row = trenchid_index[0] if len(trenchid_index) > 0 else 0
        # Stale or emptied table selections fall back to the first row.
        if not 0 <= row < len(candidates):
            row = 0
        return candidates[row]

    def select_pt_load_image(
        channel, scatterselect, trenchid_index, width=width, height=height
    ):
        trenchid = _selected_trenchid(scatterselect, trenchid_index)
        if trenchid is None:
            arr = blank_image
        else:
            arr = xrstack.loc[channel, trenchid].values
        return hv.Image(arr, bounds=(0, 0, width, height))

    def print_trenchid(scatterselect, trenchid_index):
        trenchid = _selected_trenchid(scatterselect, trenchid_index)
        text = "" if trenchid is None else str(trenchid)
        return hv.Text(3.0, 20.0, text, fontsize=30)

    def set_bounds(
        fig, element, y_dim=height, x_dim=width, x_window_size=x_window_size
    ):
        # Bokeh hook: clamp panning to the image and start zoomed to the
        # first `x_window_size` x-units.
        sy = y_dim - 0.5
        sx = x_dim - 0.5

        fig.state.y_range.bounds = (-0.5, sy)
        fig.state.x_range.bounds = (0, sx)
        fig.state.x_range.start = 0
        fig.state.x_range.reset_start = 0
        fig.state.x_range.end = x_window_size
        fig.state.x_range.reset_end = x_window_size

    image_stack = hv.DynamicMap(
        select_pt_load_image,
        kdims=["Channel"],
        streams=[select_scatter, select_trenchid],
    )

    kymograph_display = image_stack.opts(
        plot={
            "Image": dict(
                colorbar=True, tools=["hover"], hooks=[set_bounds], aspect="equal"
            ),
        }
    )
    kymograph_display = kymograph_display.opts(
        cmap="Greys_r", height=height, width=x_size
    )

    kymograph_display = kymograph_display.redim.range(trenchid=(0, xrstack.shape[1]))
    kymograph_display = kymograph_display.redim.values(
        Channel=xrstack.coords["Channel"].values.tolist()
    )

    trenchid_display = hv.DynamicMap(
        print_trenchid, streams=[select_scatter, select_trenchid]
    )
    trenchid_display = trenchid_display.opts(text_align="left", text_color="white")

    output_display = kymograph_display * trenchid_display

    return output_display


def _hist_column_mask(values, edges, histcolumn):
    """Return a boolean mask of ``values`` lying in the selected histogram bin.

    ``histcolumn`` is the index list of a ``Selection1D`` stream on the
    histogram; only its first entry is used. Bins follow ``np.histogram``:
    half-open ``[lo, hi)`` except the last bin, which is closed ``[lo, hi]``,
    so each table row is exactly a value counted in the clicked bar.
    Returns ``None`` when no valid bin is selected.
    """
    if len(histcolumn) == 0:
        return None
    column = histcolumn[0]
    n_bins = len(edges) - 1
    if not 0 <= column < n_bins:
        return None
    lo, hi = edges[column], edges[column + 1]
    values = np.asarray(values)
    # Only the last bin includes its upper edge (np.histogram convention).
    upper_ok = values <= hi if column == n_bins - 1 else values < hi
    return (values >= lo) & upper_ok


def linked_histogram(
    df, label, bins=50, minperc=0, maxperc=99, height=400, **histkwargs
):
    """Build a clickable histogram linked to a table of rows in the chosen bin.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a ``"trenchid"`` column and the ``label`` column.
    label : str
        Column to histogram.
    bins : int
        Number of histogram bins.
    minperc, maxperc : float
        Percentiles (0-100) of ``label`` bounding the histogram range
        (NaN values are ignored).
    height : int
        Plot height in pixels.
    **histkwargs
        Extra options forwarded to ``hv.Histogram.opts``.

    Returns
    -------
    hist_display
        Histogram laid out next to a table of the rows in the selected bin.
    edges : numpy.ndarray
        Bin edges (length ``bins + 1``), needed by ``linked_kymograph_for_hist``.
    select_histcolumn : hv.streams.Selection1D
        Selected bin indices, exposed as ``histcolumn``.
    select_trenchid : hv.streams.Selection1D
        Selected table row indices, exposed as ``trenchid_index``.
    """
    df = put_index_first(df, "trenchid")

    data_vals = df[label]

    # nanpercentile: a NaN would otherwise give a non-finite histogram range.
    x_low, x_high = np.nanpercentile(data_vals, [minperc, maxperc])

    frequencies, edges = np.histogram(data_vals, bins=bins, range=(x_low, x_high))
    hist = hv.Histogram((edges, frequencies))

    hist = hist.opts(
        tools=["hover", "doubletap"], height=height, responsive=True, **histkwargs
    )

    select_histcolumn = hv.streams.Selection1D(
        source=hist, index=[0], rename={"index": "histcolumn"}
    )

    def get_hist_trenchids(histcolumn, df=df, label=label, edges=edges):
        mask = _hist_column_mask(df[label], edges, histcolumn)
        # No bin selected: show an empty table with the same columns.
        filtered_df = df.iloc[:0] if mask is None else df[mask]
        filtered_dataset = hv.Dataset(filtered_df)

        return hv.Table(filtered_dataset)

    trenchid_table = hv.DynamicMap(get_hist_trenchids, streams=[select_histcolumn])

    select_trenchid = hv.streams.Selection1D(
        source=trenchid_table, index=[0], rename={"index": "trenchid_index"}
    )

    hist_display = hist + trenchid_table

    return hist_display, edges, select_histcolumn, select_trenchid


def linked_kymograph_for_hist(
    xrstack,
    df,
    label,
    edges,
    select_histcolumn,
    select_trenchid,
    y_scale=3,
    x_window_size=300,
):
    """Build a kymograph viewer driven by the streams from ``linked_histogram``.

    The trench shown is the one in the table row picked via ``select_trenchid``
    among the rows in the histogram bin picked via ``select_histcolumn``. The
    bin filter is the same one used for the table. When nothing is selected,
    the image is blank and the label empty. A stale row index (larger than the
    current bin's row count) falls back to the first row.

    Parameters
    ----------
    xrstack : xarray.DataArray
        Kymograph stack looked up as ``xrstack.loc[channel, trenchid]`` with a
        ``Channel`` coordinate. Axis 2 sets the image height and axis 3 the
        image width.
    df : pandas.DataFrame
        The same table given to ``linked_histogram`` (the row order must match).
    label : str
        Histogrammed column; must match the one given to ``linked_histogram``.
    edges : numpy.ndarray
        Bin edges returned by ``linked_histogram``.
    select_histcolumn, select_trenchid : hv.streams.Selection1D
        Streams returned by ``linked_histogram``.
    y_scale : float
        Scale factor applied to the image height (in pixels).
    x_window_size : int
        Width, in image x-coordinates, of the initially visible window.

    Returns
    -------
    hv.DynamicMap
        Kymograph image (with a ``Channel`` selector) overlaid with the trenchid.
    """
    width, height = xrstack.shape[3], int(xrstack.shape[2] * y_scale)
    x_window_scale = x_window_size / xrstack.shape[3]
    x_size = int(xrstack.shape[3] * (y_scale * x_window_scale))

    # Hoisted column lookups; row order matches the table in linked_histogram.
    all_trenchids = df["trenchid"].to_numpy()
    label_values = df[label].to_numpy()
    # Shown when no trench is selected, so the callback never raises.
    blank_image = np.zeros(xrstack.shape[2:], dtype=xrstack.dtype)

    def _selected_trenchid(histcolumn, trenchid_index):
        """Trenchid of the chosen table row, or None if nothing is selected."""
        mask = _hist_column_mask(label_values, edges, histcolumn)
        if mask is None:
            return None
        candidates = all_trenchids[mask]
        if len(candidates) == 0:
            return None
        row = trenchid_index[0] if len(trenchid_index) > 0 else 0
        # Stale or emptied table selections fall back to the first row.
        if not 0 <= row < len(candidates):
            row = 0
        return candidates[row]

    def select_pt_load_image(
        channel, histcolumn, trenchid_index, width=width, height=height
    ):
        trenchid = _selected_trenchid(histcolumn, trenchid_index)
        if trenchid is None:
            arr = blank_image
        else:
            arr = xrstack.loc[channel, trenchid].values
        return hv.Image(arr, bounds=(0, 0, width, height))

    def print_trenchid(histcolumn, trenchid_index):
        trenchid = _selected_trenchid(histcolumn, trenchid_index)
        text = "" if trenchid is None else str(trenchid)
        return hv.Text(3.0, 20.0, text, fontsize=30)

    def set_bounds(
        fig, element, y_dim=height, x_dim=width, x_window_size=x_window_size
    ):
        # Bokeh hook: clamp panning to the image and start zoomed to the
        # first `x_window_size` x-units.
        sy = y_dim - 0.5
        sx = x_dim - 0.5

        fig.state.y_range.bounds = (-0.5, sy)
        fig.state.x_range.bounds = (0, sx)
        fig.state.x_range.start = 0
        fig.state.x_range.reset_start = 0
        fig.state.x_range.end = x_window_size
        fig.state.x_range.reset_end = x_window_size

    image_stack = hv.DynamicMap(
        select_pt_load_image,
        kdims=["Channel"],
        streams=[select_histcolumn, select_trenchid],
    )

    kymograph_display = image_stack.opts(
        plot={
            "Image": dict(
                colorbar=True, tools=["hover"], hooks=[set_bounds], aspect="equal"
            ),
        }
    )
    kymograph_display = kymograph_display.opts(
        cmap="Greys_r", height=height, width=x_size
    )

    kymograph_display = kymograph_display.redim.range(trenchid=(0, xrstack.shape[1]))
    kymograph_display = kymograph_display.redim.values(
        Channel=xrstack.coords["Channel"].values.tolist()
    )

    trenchid_display = hv.DynamicMap(
        print_trenchid, streams=[select_histcolumn, select_trenchid]
    )
    trenchid_display = trenchid_display.opts(text_align="left", text_color="white")

    output_display = kymograph_display * trenchid_display

    return output_display


def kymo_xarr(headpath, subset=None, in_memory=False, in_distributed_memory=False):
    """Assemble the kymograph HDF5 files under ``headpath`` into an xarray DataArray.

    Data are read lazily through ``dask.delayed``. Each per-file array is
    concatenated along axis 0 (trenches), and the channels are stacked on a
    new leading axis. Axes 2 and 4 of that stack are then tiled into a single
    "xt" axis.

    Parameters
    ----------
    headpath : str
        Experiment directory containing ``kymograph/`` and ``metadata.hdf5``.
    subset : slice, int list or bool mask, optional
        Selection along the trench axis. The ``trenchid`` coordinate keeps
        the original (pre-subset) trench indices, so ``.loc`` lookups by
        trenchid stay correct for any subset.
    in_memory : bool
        If True, compute the array into local memory.
    in_distributed_memory : bool
        If True (and ``in_memory`` is False), ``persist`` the array on the cluster.

    Returns
    -------
    xarray.DataArray
        uint16 array named "Data" with dims ``("Channel", "trenchid", "y", "xt")``.
        The ``Channel`` coordinate holds ``metadata["channels"]``.
    """
    data_parquet = dd.read_parquet(headpath + "/kymograph/metadata")
    meta_handle = tr.pandas_hdf5_handler(headpath + "/metadata.hdf5")
    metadata = meta_handle.read_df("global", read_metadata=True).metadata
    channels = metadata["channels"]
    file_indices = data_parquet["File Index"].unique().compute().to_list()

    filenames = [
        headpath + "/kymograph/kymograph_" + str(file_idx) + ".hdf5"
        for file_idx in file_indices
    ]

    # Eagerly read one array only to learn the shape/dtype for the lazy chunks.
    # NOTE(review): every file/channel is assumed to share this shape -- confirm
    # (a file with a different trench count would break the concatenation).
    sample = fetch_hdf5(filenames[0], channels[0])

    delayed_fetch_hdf5 = delayed(fetch_hdf5)
    channel_arr = []
    for channel in channels:
        da_file_arrays = [
            da.from_delayed(
                delayed_fetch_hdf5(fn, channel), shape=sample.shape, dtype=sample.dtype
            )
            for fn in filenames
        ]
        channel_arr.append(da.concatenate(da_file_arrays, axis=0))
    da_channel_arr = da.stack(channel_arr, axis=0)

    # Sizes taken from the stacked array before the axes are rearranged.
    n_channels = da_channel_arr.shape[0]
    n_trenches = da_channel_arr.shape[1]
    n_y = da_channel_arr.shape[3]
    da_channel_arr = (
        da_channel_arr.swapaxes(3, 4)
        .reshape(n_channels, n_trenches, -1, n_y)
        .swapaxes(2, 3)
    )

    # Original trench indices, carried through the subset so that coordinate
    # labels keep matching the trenchids used elsewhere.
    trenchids = np.arange(n_trenches)
    if subset is not None:
        da_channel_arr = da_channel_arr[:, subset]
        trenchids = trenchids[subset]
    if in_memory:
        da_channel_arr = da_channel_arr.compute()
    elif in_distributed_memory:
        da_channel_arr = da_channel_arr.persist()

    # defining xarr
    dims = ["Channel", "trenchid", "y", "xt"]
    coords = {d: np.arange(s) for d, s in zip(dims, da_channel_arr.shape)}
    coords["Channel"] = np.array(channels)
    coords["trenchid"] = trenchids
    kymo_array = xr.DataArray(
        da_channel_arr, dims=dims, coords=coords, name="Data"
    ).astype("uint16")

    return kymo_array

In [None]:
# Load the first 10000 trenches (subset=slice(0, 10000)) into local memory.
# NOTE(review): this rebinds the name `kymo_xarr` from the loader function to
# its DataArray result, so re-running this cell without first re-running the
# definition cell fails. A rename would also need the later cells that pass
# `kymo_xarr` to be updated.
kymo_xarr = kymo_xarr(headpath, subset=slice(0, 10000), in_memory=True)

In [None]:
# Per-trench percentile table produced by the barcode step; only index labels
# up to 6500000 are pulled into pandas for the interactive plots.
barcode_percentiles_path = (
    "/home/de64/scratch/de64/sync_folder/2021-05-27_lDE18_20x_run_1/Barcodes/percentiles"
)
kymo_df = dd.read_parquet(barcode_percentiles_path)
test_df = kymo_df.loc[:6500000].compute()

In [None]:
# Scatter of RFP vs Cy5 98th-percentile intensity, linked to a selection table.
scatter_display, select_scatter, select_scatter_trenchid = linked_scatter(
    test_df, "RFP 98th Percentile", "Cy5 98th Percentile", cmap="gray_r", height=600
)

In [None]:
# Kymograph viewer driven by the scatter and table selection streams above.
scatter_kymograph_display = linked_kymograph_for_scatter(
    kymo_xarr,
    test_df,
    "RFP 98th Percentile",
    "Cy5 98th Percentile",
    select_scatter,
    select_scatter_trenchid,
    y_scale=3,
    x_window_size=300,
)

In [None]:
# Scratch inspection: wrap test_df in a Dataset. The linked-plot helpers build
# their own Datasets internally and do not use this global.
dataset = hv.Dataset(test_df)

In [None]:
# Scratch inspection: trenchid value(s) of positional row 10, shown as a Dataset.
hv.Dataset(dataset.iloc[10]["trenchid"])

In [None]:
# 50-bin histogram of RFP 98th-percentile intensity (0th-99th percentile range),
# linked to a table of the rows in the clicked bin.
hist_display, edges, select_histcolumn, select_hist_trenchid = linked_histogram(
    test_df, "RFP 98th Percentile", bins=50, minperc=0, maxperc=99, height=400
)

In [None]:
# Kymograph viewer driven by the histogram-bin and table selection streams;
# `edges` must be the bin edges returned by linked_histogram above.
hist_kymograph_display = linked_kymograph_for_hist(
    kymo_xarr,
    test_df,
    "RFP 98th Percentile",
    edges,
    select_histcolumn,
    select_hist_trenchid,
    y_scale=3,
    x_window_size=300,
)

In [None]:
# Display the scatter plot and its selection table.
scatter_display

In [None]:
# Display the kymograph for the trench selected via the scatter/table.
scatter_kymograph_display

In [None]:
# Display the histogram and its selection table.
hist_display

In [None]:
# Display the kymograph for the trench selected via the histogram/table.
hist_kymograph_display

In [None]:
# toying with links to enable selection on datashader output

from holoviews.selection import link_selections

# Seeded generator so the toy point cloud is identical on every run.
points = hv.Scatter(np.random.default_rng(0).standard_normal((10000, 2)))


def selected_info(points):
    """Relabel the (selected) points with their mean x and y coordinates."""
    label = "Mean x, y: %.3f, %.3f" % tuple(points.array().mean(axis=0))
    return points.relabel(label)


selected = points.apply(selected_info)
mpg_ls = link_selections.instance()
test_sel = mpg_ls(datashade(selected))

In [None]:
# Display the datashaded scatter wrapped by link_selections.
test_sel

In [None]:
# Inspect the selection expression currently held by the link_selections instance.
mpg_ls.selection_expr