In [None]:
import paulssonlab.deaton.trenchripper.trenchripper as tr

import holoviews as hv

hv.extension("bokeh")

In [None]:
headpath = "/home/de64/scratch/de64/sync_folder/2021-05-27_lDE18_20x_run_1/mVenus"

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="02:00:00",
    local=False,
    n_workers=4,
    memory="32GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

In [None]:
dask_controller.displaydashboard()

In [None]:
dask_controller.daskclient.restart()

In [None]:
import dask.dataframe as dd
import h5py
import xarray as xr
import dask.array as da
import numpy as np
import dask.delayed as delayed

In [None]:
# ## This will be necessary for complex lookups later possibly
# data_parquet = dd.read_parquet("/home/de64/scratch/de64/sync_folder/2021-05-27_lDE18_20x_run_1/mVenus/kymograph/metadata")
# data_parquet = tr.set_new_aligned_index(data_parquet,"Trenchid Timepoint Index")
# meta_handle = tr.pandas_hdf5_handler(headpath+"/metadata.hdf5")
# metadata = meta_handle.read_df("global",read_metadata=True).metadata
# channels = metadata["channels"]
# file_indices = data_parquet["File Index"].unique().compute().to_list()

In [None]:
# Kymograph metadata sits under the same experiment folder as everything else, so derive the
# path from `headpath` instead of repeating the absolute path (changing headpath now updates it).
data_parquet = dd.read_parquet(headpath + "/kymograph/metadata")
meta_handle = tr.pandas_hdf5_handler(headpath + "/metadata.hdf5")
metadata = meta_handle.read_df("global", read_metadata=True).metadata
channels = metadata["channels"]
# Indices N of the kymograph_N.hdf5 files referenced by the metadata; later cells concatenate
# files in this order. NOTE(review): the order is whatever dask's unique() yields — confirm it
# matches the intended trench ordering.
file_indices = data_parquet["File Index"].unique().compute().to_list()

In [None]:
file_indices

In [None]:
def fetch_hdf5(filename, channel):
    """Load the whole dataset named ``channel`` from the HDF5 file ``filename``.

    The file is opened read-only and closed again before returning; the result is an
    in-memory array, so it stays valid after the file handle is gone.
    """
    with h5py.File(filename, "r") as hdf5_file:
        return hdf5_file[channel][:]

In [None]:
def _hdf5_dataset_meta(filename, channel):
    """Return ``(shape, dtype)`` of dataset ``channel`` in ``filename`` without reading its data."""
    with h5py.File(filename, "r") as infile:
        dataset = infile[channel]
        return dataset.shape, dataset.dtype


delayed_fetch_hdf5 = delayed(fetch_hdf5)
# The kymograph file list does not depend on the channel, so build it once.
filenames = [
    headpath + "/kymograph/kymograph_" + str(file_idx) + ".hdf5"
    for file_idx in file_indices
]

channel_arr = []
for channel in channels:
    da_file_arrays = []
    for fn in filenames:
        # Take shape/dtype from each file's own metadata instead of one fully loaded sample:
        # this avoids reading a whole dataset up front, and keeps dask's metadata correct even
        # if files differ in size. NOTE(review): the original assumed every file matched the
        # first file's shape — confirm whether e.g. the last file can be shorter.
        shape, dtype = _hdf5_dataset_meta(fn, channel)
        da_file_arrays.append(
            da.from_delayed(delayed_fetch_hdf5(fn, channel), shape=shape, dtype=dtype)
        )
    # Files are joined along their first axis (presumably trenches — TODO confirm).
    channel_arr.append(da.concatenate(da_file_arrays, axis=0))

# Stack channels in front: (channel, *per-file axes). Assuming the per-file axes are
# (trench, time, y, x) — TODO confirm — the reshape below lays the time points side by side
# along x, yielding (channel, trench, y, time * x), which matches the
# ("Channel", "trenchid", "y", "xt") labels used further down.
da_channel_arr = da.stack(channel_arr, axis=0)
da_channel_arr = (
    da_channel_arr.swapaxes(3, 4)
    .reshape(
        da_channel_arr.shape[0], da_channel_arr.shape[1], -1, da_channel_arr.shape[3]
    )
    .swapaxes(2, 3)
)

In [None]:
da_channel_arr

In [None]:
da_channel_arr_sub = da_channel_arr[:, :10000]

In [None]:
da_channel_arr_sub

In [None]:
da_channel_arr_mem = da_channel_arr_sub.compute()

In [None]:
# Label the in-memory array so images can be looked up by channel name and trench index.
dims = ["Channel", "trenchid", "y", "xt"]
coords = dict(zip(dims, (np.arange(n) for n in da_channel_arr_mem.shape)))
coords["Channel"] = np.array(channels)

labelled = xr.DataArray(da_channel_arr_mem, dims=dims, coords=coords, name="Data")
xrstack = labelled.astype("uint16")

# HoloViews view of the same labelled data.
ds = hv.Dataset(xrstack)

In [None]:
coords["Channel"]

In [None]:
from holoviews.streams import Stream, param

# Display sizing: y_size stretches the image vertically, x_window_scale is the fraction of
# the x extent shown at first, and x_size scales the plot width.
y_size = 3
x_window_scale = 0.6
x_size = y_size * x_window_scale

# Stream holding the trench to show; its `trenchid` value is passed to load_image.
trenchid = Stream.define("trenchid", trenchid=10)
trenchid_stream = trenchid()


def load_image(
    channel,
    trenchid,
    width=da_channel_arr_mem.shape[3],
    height=int(da_channel_arr_mem.shape[2] * y_size),
):
    """Return the kymograph of trench ``trenchid`` in ``channel`` as an ``hv.Image``.

    The image is placed on ``(0, 0, width, height)`` in data coordinates; with the
    defaults the y extent is the pixel height stretched by ``y_size``.
    """
    kymograph = xrstack.loc[channel, trenchid].values
    image_bounds = (0, 0, width, height)
    return hv.Image(kymograph, bounds=image_bounds)


# One image per channel (selected through the "Channel" key dimension); the trench shown is
# driven by trenchid_stream, whose `trenchid` value is passed to load_image.
image_stack = hv.DynamicMap(load_image, kdims=["Channel"], streams=[trenchid_stream])


def set_bounds(
    fig,
    element,
    y_dim=da_channel_arr_mem.shape[2] * y_size,
    x_dim=da_channel_arr_mem.shape[3],
    x_window_scale=x_window_scale,
):
    """Bokeh plot hook: clamp pan/zoom to the image and open a partial x window.

    ``load_image`` places each image on ``(0, 0, width, height)``; with the default
    arguments ``x_dim == width`` and ``y_dim == height``, so the range bounds are the
    image edges themselves (no half-pixel offset, unlike coordinate-based images).

    Parameters
    ----------
    fig : the plot passed to the hook; ``fig.state`` is the Bokeh figure.
    element : the element being rendered (unused).
    y_dim, x_dim : image extent in data coordinates.
    x_window_scale : fraction of the x extent visible initially and after reset.
    """
    fig.state.y_range.bounds = (0, y_dim)
    fig.state.x_range.bounds = (0, x_dim)
    # Start (and reset) on the leftmost part of the long "xt" axis.
    window_end = int(x_dim * x_window_scale)
    fig.state.x_range.start = 0
    fig.state.x_range.reset_start = 0
    fig.state.x_range.end = window_end
    fig.state.x_range.reset_end = window_end


# Greyscale images with a colorbar and hover; the set_bounds hook clamps the view and opens a
# partial x window. Plot size follows the array shape scaled by y_size / x_size.
# NOTE(review): trenchid is a stream parameter, not a key dimension of image_stack, so the
# redim.range call below probably has no effect — confirm.
display_obj = (
    image_stack.opts(
        plot={
            "Image": dict(
                colorbar=True, tools=["hover"], hooks=[set_bounds], aspect="equal"
            ),
        }
    )
    .opts(
        cmap="Greys_r",
        height=int(da_channel_arr_mem.shape[2] * y_size),
        width=int(da_channel_arr_mem.shape[3] * x_size),
    )
    .redim.range(trenchid=(0, coords["trenchid"][-1]))
    .redim.values(Channel=coords["Channel"])
)

In [None]:
display_obj.event(trenchid=1)

In [None]:
display_obj

In [None]:
kymo_df = dd.read_parquet(
    "/home/de64/scratch/de64/sync_folder/2021-05-27_lDE18_20x_run_1/Barcodes/percentiles"
)

In [None]:
kymo_df

In [None]:
test_df = kymo_df.loc[:6500000].compute()

In [None]:
dataset = hv.Dataset(test_df, vdims=("trenchid"))

In [None]:
test_df.iloc[1092]["trenchid"]

In [None]:
from holoviews.streams import Stream, param

# Display sizing for the linked image panel (same values as the trench-slider view).
x_window_scale = 0.6
y_size = 3
x_size = y_size * x_window_scale

# Scatter of Cy7 vs RFP 98th-percentile values; the selected point picks the trench shown in
# the image panel built below.
# NOTE(review): "doubletap" is not a standard Bokeh tool name — confirm it is accepted.
potat = hv.Scatter(
    data=dataset, kdims=["Cy7 98th Percentile"], vdims=["RFP 98th Percentile"]
).opts(
    color="k",
    marker="s",
    size=5,
    tools=["hover", "doubletap"],
    fontscale=2,
    width=800,
    height=600,
    xlim=(0, None),
    ylim=(0, None),
)

# Start with the first point selected so the image panel has something to show.
select1d = hv.streams.Selection1D(source=potat, index=[0])


def select_pt_load_image(
    channel,
    index,
    width=da_channel_arr_mem.shape[3],
    height=int(da_channel_arr_mem.shape[2] * y_size),
):
    """Return the kymograph of the trench selected in the scatter plot as an ``hv.Image``.

    Parameters
    ----------
    channel : channel label to display.
    index : positions (into ``test_df``) of the selected scatter points, as supplied by
        the Selection1D stream. Only the first one is used.
    width, height : image extent in data coordinates; image is placed on
        ``(0, 0, width, height)``.
    """
    # An empty selection (e.g. after clicking off the points) would make index[0] raise;
    # fall back to the first row, matching the stream's initial index=[0].
    row = index[0] if index else 0
    # Column-first access keeps trenchid's own dtype; a row Series from iloc[row] could
    # upcast it to float when other columns are float.
    trenchid = test_df["trenchid"].iloc[row]
    # NOTE(review): xrstack only holds the first trenches loaded into memory; a selected
    # trenchid outside that range will raise a KeyError here.
    arr = xrstack.loc[channel, trenchid].values
    return hv.Image(arr, bounds=(0, 0, width, height))


# One image per channel (via the "Channel" key dimension); the trench shown follows the point
# selected in the scatter plot, since select1d's `index` is passed to select_pt_load_image.
image_stack = hv.DynamicMap(select_pt_load_image, kdims=["Channel"], streams=[select1d])


def set_bounds(
    fig,
    element,
    y_dim=da_channel_arr_mem.shape[2] * y_size,
    x_dim=da_channel_arr_mem.shape[3],
    x_window_scale=x_window_scale,
):
    """Bokeh plot hook: clamp pan/zoom to the image and open a partial x window.

    ``select_pt_load_image`` places each image on ``(0, 0, width, height)``; with the
    default arguments ``x_dim == width`` and ``y_dim == height``, so the range bounds are
    the image edges themselves (no half-pixel offset, unlike coordinate-based images).

    Parameters
    ----------
    fig : the plot passed to the hook; ``fig.state`` is the Bokeh figure.
    element : the element being rendered (unused).
    y_dim, x_dim : image extent in data coordinates.
    x_window_scale : fraction of the x extent visible initially and after reset.
    """
    fig.state.y_range.bounds = (0, y_dim)
    fig.state.x_range.bounds = (0, x_dim)
    # Start (and reset) on the leftmost part of the long "xt" axis.
    window_end = int(x_dim * x_window_scale)
    fig.state.x_range.start = 0
    fig.state.x_range.reset_start = 0
    fig.state.x_range.end = window_end
    fig.state.x_range.reset_end = window_end


# Bokeh plot options for every Image: colorbar, hover, and the set_bounds view hook.
# NOTE(review): this DynamicMap has no trenchid key dimension, so the redim.range call
# below probably has no effect — confirm.
display_obj = (
    image_stack.opts(
        plot={
            "Image": dict(
                colorbar=True, tools=["hover"], hooks=[set_bounds], aspect="equal"
            ),
        }
    )
    .opts(
        cmap="Greys_r",
        height=int(da_channel_arr_mem.shape[2] * y_size),
        width=int(da_channel_arr_mem.shape[3] * x_size),
    )
    .redim.range(trenchid=(0, coords["trenchid"][-1]))
    .redim.values(Channel=coords["Channel"])
)


# # Declare HeatMap
# potat = hv.Scatter(data=dataset,vdims=['RFP 98th Percentile'],kdims=['Cy7 98th Percentile'])
# potat = potat.opts(color='k', marker='s', size=5,tools=["hover","doubletap"])

# # Declare Tap stream with heatmap as source and initial values
# # posxy = hv.streams.DoubleTap(source=potat, x=0., y=0.)

# select1d = hv.streams.Selection1D(source=potat)

# empty = hv.Scatter(dataset,vdims=['RFP 98th Percentile'],kdims=['Cy7 98th Percentile'])
# empty = empty.opts(color='k', marker='s', size=5, tools=["hover"])
# def select_pt(index):
#     if not index:
#         return empty
#     df_selection = test_df.iloc[index[0]]
#     trenchid = test_df.iloc[1092]["trenchid"]
#     outscatter = hv.Scatter(df_selection,vdims=['RFP 98th Percentile'],kdims=['Cy7 98th Percentile'])
#     return outscatter.opts(color='k', marker='s', size=10, tools=["hover"])

# # Define function to compute histogram based on tap location
# # def tap_potat(index):
# #     if not index:
# #         return empty

# #     return hv.Scatter(data=selection,vdims=['RFP 98th Percentile'],kdims=['Cy7 98th Percentile'])

# sel_dmap = hv.DynamicMap(select_pt, streams=[select1d])

In [None]:
display_obj

In [None]:
potat

In [None]:
potat + tap_dmap.opts(framewise=True)

In [None]:
posxy

In [None]:
import pandas as pd

# Declare dataset
df = pd.read_csv("http://assets.holoviews.org/data/diseases.csv.gz", compression="gzip")
dataset = hv.Dataset(df, vdims=("measles", "Measles Incidence"))

# Declare HeatMap
heatmap = hv.HeatMap(
    dataset.aggregate(["Year", "State"], np.mean),
    label="Average Weekly Measles Incidence",
).select(Year=(1928, 2002))

# Declare Tap stream with heatmap as source and initial values
posxy = hv.streams.Tap(source=heatmap, x=1951, y="New York")

# Define function to compute histogram based on tap location
def tap_histogram(x, y):
    """Return the weekly measles curve for the (Year, State) cell tapped on the heatmap.

    ``x`` is the tapped year (truncated with ``int`` for the selection) and ``y`` the state.
    """
    year = int(x)
    weekly = dataset.select(State=y, Year=year)
    return hv.Curve(weekly, kdims="Week", label=f"Year: {x}, State: {y}")


# Connect the Tap stream to the tap_histogram callback
tap_dmap = hv.DynamicMap(tap_histogram, streams=[posxy])

In [None]:
from holoviews import opts

# Display the Heatmap and Curve side by side
heatmap + tap_dmap

In [None]:
moo = hv.DynamicMap(test, kdims=["Channel", "trenchid"])

In [None]:
hv.help(test)

In [None]:
hv.help(test)

In [None]:
moo = moo.redim.values(Channel=selection_coords["Channel"].tolist())
moo = moo.redim.range(trenchid=(0, 10))

In [None]:
# Was `moo.opts)`, a syntax error; display the DynamicMap itself.
moo

In [None]:
def get_display_from_dask_arr(da_channel_arr, channel_names=None):
    """Build a channel/trench-browsable display of greyscale ``hv.Image`` elements.

    Parameters
    ----------
    da_channel_arr : array-like
        4-D array whose axes are labelled ("Channel", "trenchid", "y", "xt"). Any array
        accepted by ``xr.DataArray`` works (the notebook passes the computed in-memory array).
    channel_names : sequence, optional
        Labels for the "Channel" axis. Defaults to the notebook-level ``channels`` so
        existing calls behave exactly as before.

    Returns
    -------
    The result of ``hv.Dataset(...).to(hv.Image, ["xt", "y"], dynamic=True)`` with plot
    options applied; the remaining dimensions ("Channel", "trenchid") become key dimensions.
    """
    if channel_names is None:
        channel_names = channels

    dims = ["Channel", "trenchid", "y", "xt"]
    coords = {d: np.arange(s) for d, s in zip(dims, da_channel_arr.shape)}
    coords["Channel"] = np.array(channel_names)
    xrstack = xr.DataArray(
        da_channel_arr, dims=dims, coords=coords, name="Data"
    ).astype("uint16")

    image_stack = hv.Dataset(xrstack).to(hv.Image, ["xt", "y"], dynamic=True)

    y_size = 2.9
    x_window_scale = 0.7
    x_size = y_size * x_window_scale
    y_pixels = da_channel_arr.shape[2]
    x_pixels = da_channel_arr.shape[3]

    def set_bounds(
        fig,
        element,
        y_dim=y_pixels,
        x_dim=x_pixels,
        x_window_scale=x_window_scale,
    ):
        """Bokeh plot hook: clamp pan/zoom to the image and open a partial x window."""
        # Coordinates run 0..n-1 (np.arange above), so the pixel edges lie at -0.5 and n-0.5.
        fig.state.y_range.bounds = (-0.5, y_dim - 0.5)
        fig.state.x_range.bounds = (-0.5, x_dim - 0.5)
        window_end = int(x_dim * x_window_scale)
        fig.state.x_range.start = 0
        fig.state.x_range.reset_start = 0
        fig.state.x_range.end = window_end
        fig.state.x_range.reset_end = window_end

    display_obj = image_stack.opts(
        plot={
            "Image": dict(
                colorbar=True,
                width=int(x_pixels * x_size),
                height=int(y_pixels * y_size),
                tools=["hover"],
                hooks=[set_bounds],
                aspect="equal",
            ),
        }
    )
    display_obj = display_obj.opts(cmap="Greys_r")

    return display_obj

In [None]:
display_obj = get_display_from_dask_arr(da_channel_arr_mem)

In [None]:
display_obj

In [None]:
from holoviews.streams import Stream, param

trenchid = Stream.define("Trenchid", trenchid=0)

In [None]:
display_obj.event(trenchid=10)

In [None]:
hv.help(trenchid)