# Imports

In [None]:
import numpy as np
import random
from tqdm.auto import tqdm
import nd2reader
import h5py
import skimage
from skimage.transform import SimilarityTransform, warp
import holoviews as hv
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color
from cytoolz import partial
from itertools import cycle, repeat, chain
from numbers import Number
from pathlib import Path
import av
from tqdm.auto import trange, tqdm
import dask
import distributed
from distributed import Client
from dask_jobqueue import SLURMCluster
from IPython.display import Video
from paulssonlab.image_analysis import workflow

In [None]:
# Load the Bokeh plotting extension for HoloViews.
hv.extension("bokeh")

# Config

In [None]:
# nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/220704/220704rbs_library_fish.nd2"
# Path of the ND2 file used by the mosaic cells below.
nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"

In [None]:
# Reader for the file above; used below by get_intensity_extrema and for
# spot-checking individual frames.
nd2 = nd2reader.ND2Reader(nd2_filename)

In [None]:
# Channels composited when a caller does not pass `channels` explicitly.
default_channels = ("YFP-DUAL", "RFP-Penta")
# Hex display color for each channel name. Both spellings of the RFP channel
# ("RFP-PENTA" / "RFP-Penta") map to the same color.
channel_to_color = {
    "BF": "#ffffff",
    "RFP-PENTA": "#e22400",
    "RFP-Penta": "#e22400",
    "YFP-DUAL": "#f5eb00",
    # "GFP": "#76ba40",
    "Cy5": "#e292fe",
    # "Cy7": "#FF0000"
    # "BFP": "#3a87fd",
}

In [None]:
# SLURM-backed dask cluster: each job is configured with 1 core, 1 process,
# 16GB of memory and a 2 hour walltime on the "short" queue.
cluster = SLURMCluster(
    queue="short",
    walltime="02:00:00",
    memory="16GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
# Client through which the delayed mosaic frames are computed and gathered below.
client = Client(cluster)

In [None]:
# Request a fixed size of 2 workers.
cluster.scale(2)

In [None]:
# Switch to adaptive scaling, capped at 300.
cluster.adapt(maximum=300)

In [None]:
# Display the cluster object in the notebook.
cluster

# Mosaic

In [None]:
def composite_channels(imgs, hexcolors, scale=True):
    """Composite single-channel images into one image using hex colors.

    Every entry of ``hexcolors`` is converted with ``hex2color``; the
    converted colors are then passed, together with ``imgs`` and ``scale``,
    to ``_composite_channels``, whose result is returned unchanged.
    """
    rgb_colors = list(map(hex2color, hexcolors))
    return _composite_channels(imgs, rgb_colors, scale=scale)


def _composite_channels(channel_imgs, colors, scale=True):
    """Blend single-channel images into one color image.

    Each channel image is multiplied by its color (adding a trailing color
    axis) and the colored layers are merged pairwise with
    ``1 - (1 - a) * (1 - b)``.

    Parameters
    ----------
    channel_imgs : sequence of 2-D numpy arrays
        One image per channel; each is indexed as ``img[:, :, np.newaxis]``.
    colors : sequence of array-likes
        One color per channel, same length as ``channel_imgs``.
    scale : bool
        If true, every channel is divided by its own 99.9th percentile and
        clipped to ``[0, 1]`` before coloring. If false, the images are used
        as given. Note there is no guard against a zero percentile.

    Returns
    -------
    numpy.ndarray
        The blended image, with the color axis last.

    Raises
    ------
    ValueError
        If the numbers of images and colors differ, or if no channel is
        given (there is no image shape from which to build a placeholder).
    """
    if len(channel_imgs) != len(colors):
        raise ValueError("expecting equal numbers of channels and colors")
    if not len(channel_imgs):
        # The original fallback referenced an undefined name here; with zero
        # channels the output shape is unknown, so fail explicitly instead.
        raise ValueError("expecting at least one channel")
    if scale:
        scaled_imgs = [
            np.clip(img / np.percentile(img, 99.9), 0, 1) for img in channel_imgs
        ]
    else:
        scaled_imgs = channel_imgs
    blended = None
    for img, color in zip(scaled_imgs, colors):
        colored = img[:, :, np.newaxis] * np.array(color)
        if blended is None:
            blended = colored
        else:
            blended = 1 - (1 - blended) * (1 - colored)
    return blended

In [None]:
def colorized_frame(
    channel_to_color,
    get_frame_func,
    filename,
    t=0,
    v=0,
    channels=default_channels,
    scaling_funcs=None,
):
    """Load one frame per channel and composite them into a color image.

    Frames are fetched with ``get_frame_func(filename, v, channel, t)`` for
    every entry of ``channels``. When ``scaling_funcs`` is truthy, each frame
    is passed through ``scaling_funcs[channel]`` and the percentile scaling
    inside ``composite_channels`` is switched off; otherwise the frames are
    composited with that scaling enabled. Colors are looked up in
    ``channel_to_color``.

    Raises ``ValueError`` if ``scaling_funcs`` is given but lacks an entry
    for one of the requested channels.
    """
    frames = []
    for name in channels:
        frames.append(get_frame_func(filename, v, name, t))
    if scaling_funcs:
        # All frames are loaded before any scaling is applied.
        for position, name in enumerate(channels):
            if name not in scaling_funcs:
                raise ValueError(f"missing scaling_func for {name}")
            rescale = scaling_funcs[name]
            frames[position] = rescale(frames[position])
    hexcolors = [channel_to_color[name] for name in channels]
    return composite_channels(frames, hexcolors, scale=(not scaling_funcs))

In [None]:
def rectangles_intersect(ul1, lr1, ul2, lr2):
    """Return whether two axis-aligned rectangles overlap.

    Each rectangle is given by two corner points indexed as ``[0]`` and
    ``[1]``: ``ul*`` holds the smaller coordinates and ``lr*`` the larger
    ones. Rectangles that merely touch count as intersecting.
    """
    for axis in (0, 1):
        # Separated along this axis means no overlap at all.
        if ul1[axis] > lr2[axis] or lr1[axis] < ul2[axis]:
            return False
    return True


def scale_around_center(scale, center):
    """Build a transform that scales by ``scale`` about the point ``center``.

    Three ``SimilarityTransform`` objects are chained with ``+`` in this
    order: translate by ``-center``, scale, translate by ``+center``.
    """
    cx, cy = center
    to_origin = SimilarityTransform(translation=(-cx, -cy))
    zoom = SimilarityTransform(scale=scale)
    from_origin = SimilarityTransform(translation=(cx, cy))
    return to_origin + zoom + from_origin


def output_transformation(input_width, input_height, output_width, output_height):
    """Build the transform relating an output canvas to an input image.

    The scale factor is the larger of the width and height input/output
    ratios. The translation offsets are half the difference between the
    output size and the input size divided by that scale, negated. The
    result is the translation chained (``+``) with the scaling.
    """
    scale = max(input_width / output_width, input_height / output_height)
    x = -(output_width - input_width / scale) / 2
    y = -(output_height - input_height / scale) / 2
    offset = SimilarityTransform(translation=(x, y))
    zoom = SimilarityTransform(scale=scale)
    return offset + zoom


def mosaic_frame(
    get_frame_func,
    positions,
    image_dims,
    timepoint,
    center=None,
    scale=1,
    output_dims=(1024, 1024),
):
    """Render one mosaic image by warping every visible tile onto a canvas.

    Parameters
    ----------
    get_frame_func : callable
        Called as ``get_frame_func(t=timepoint, v=pos_num)``; its result is
        warped and added to the canvas.
    positions : table with ``x_idx`` / ``y_idx`` columns
        Iterated with ``.iterrows()``, whose index entries unpack as
        ``(filename, pos_num)`` (presumably a pandas DataFrame with a
        two-level index -- confirm against the caller).
    image_dims : (width, height)
        Size of a single tile.
    timepoint
        Passed through to ``get_frame_func`` as ``t``.
    center : (x, y), optional
        Point of the tile grid to center the view on. Defaults to half the
        grid extent computed from the ``x_idx`` / ``y_idx`` ranges.
    scale : number
        Zoom factor; ``1 / scale`` is applied about the tile center.
    output_dims : (width, height)
        Size of the output canvas.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(output_dims[1], output_dims[0], 3)`` holding the
        sum of all warped tiles. Values are not clipped here.
    """
    if center is None:
        columns = positions["x_idx"].max() - positions["x_idx"].min() + 1
        rows = positions["y_idx"].max() - positions["y_idx"].min() + 1
        center = (image_dims[0] * columns / 2, image_dims[1] * rows / 2)
    output_img = np.zeros((output_dims[1], output_dims[0], 3))
    viewport_ul = (0, 0)
    viewport_lr = (output_dims[0] - 1, output_dims[1] - 1)  # TODO: off-by-one?
    # Everything except the per-tile offset is identical for all positions,
    # so build that part of the transform chain once instead of per tile.
    view_transform = (
        output_transformation(*image_dims, *output_dims)
        + scale_around_center(1 / scale, (image_dims[0] / 2, image_dims[1] / 2))
        + SimilarityTransform(
            translation=(
                center[0] - image_dims[0] / 2,
                center[1] - image_dims[1] / 2,
            )
        )
    )
    for (_filename, pos_num), position in positions.iterrows():
        frame_corner = (
            -image_dims[0] * position["x_idx"],
            -image_dims[1] * position["y_idx"],
        )
        frame_transform = view_transform + SimilarityTransform(
            translation=frame_corner
        )
        # Map the tile's corners through the inverse transform and skip
        # tiles that fall entirely outside the viewport.
        frame_ul = frame_transform.inverse((0, 0))[0]
        frame_lr = frame_transform.inverse((image_dims[0] - 1, image_dims[1] - 1))[0]
        visible = rectangles_intersect(viewport_ul, viewport_lr, frame_ul, frame_lr)
        if visible:
            img = get_frame_func(t=timepoint, v=pos_num)
            output_img += warp(
                img, frame_transform, output_shape=output_dims[::-1], order=2
            )
    return output_img

In [None]:
def export_video(ary, filename, fps=30, codec="h264", crf=22, tune="stillimage"):
    """Encode a sequence of RGB frames into a video file with PyAV.

    Parameters
    ----------
    ary : sequence of arrays
        Frames indexed as ``(row, column, channel)``; values are multiplied
        by 255, so they are expected to be in ``[0, 1]``. Values outside
        that range are saturated rather than wrapped around. The first
        frame determines the video width and height.
    filename
        Output path handed to ``av.open`` in write mode.
    fps
        Stream frame rate.
    codec
        Codec name passed to ``add_stream``.
    crf, tune
        Passed to the encoder as the ``crf`` and ``tune`` options.
    """
    with av.open(filename, mode="w") as container:
        stream = container.add_stream(
            codec, rate=fps, options={"crf": str(crf), "tune": tune}
        )
        stream.width = ary[0].shape[1]
        stream.height = ary[0].shape[0]
        stream.pix_fmt = "yuv420p"
        for img in ary:
            # Clip BEFORE casting: once cast to uint8, out-of-range values
            # have already wrapped around and clipping can no longer help.
            rgb = np.clip(np.round(255 * img), 0, 255).astype(np.uint8)
            frame = av.VideoFrame.from_ndarray(rgb, format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)
        # Flush any frames still buffered in the encoder.
        for packet in stream.encode():
            container.mux(packet)

In [None]:
def get_intensity_extrema(nd2, channels, v=0, step=10):
    """Estimate per-channel intensity bounds by sampling frames over time.

    For each channel, every ``step``-th timepoint of field of view ``v`` is
    read with ``nd2.get_frame_2D``. The lower bound is the smallest pixel
    value seen; the upper bound is the largest 99.9th percentile seen
    (a percentile rather than the maximum is used for the upper bound).

    Parameters
    ----------
    nd2
        Reader exposing ``sizes["t"]``, ``metadata["channels"]`` and
        ``get_frame_2D(v=..., t=..., c=...)``.
    channels : iterable of str
        Channel names; each must appear in ``nd2.metadata["channels"]``.
    v : int
        Field-of-view index passed to ``get_frame_2D``.
    step : int
        Stride between sampled timepoints.

    Returns
    -------
    dict
        ``{channel: (min_value, max_value)}``. If no timepoint is sampled
        for a channel, its entry is ``(-1, -1)``, as before.

    Raises
    ------
    ValueError
        If a channel name is not in ``nd2.metadata["channels"]``.
    """
    extrema = {}
    for channel in channels:
        channel_idx = nd2.metadata["channels"].index(channel)
        # None marks "nothing seen yet"; a numeric sentinel such as -1 could
        # collide with a real frame minimum and reset the running extrema.
        min_value = None
        max_value = None
        for t in range(0, nd2.sizes["t"], step):
            img = nd2.get_frame_2D(v=v, t=t, c=channel_idx)
            frame_min = img.min()
            frame_max = np.percentile(img, 99.9)
            if min_value is None:
                min_value, max_value = frame_min, frame_max
            else:
                min_value = min(min_value, frame_min)
                max_value = max(max_value, frame_max)
        if min_value is None:
            # No frames sampled: keep the historical placeholder values.
            min_value = max_value = -1
        extrema[channel] = (min_value, max_value)
    return extrema


def get_scaling_funcs(extrema):
    """Turn per-channel ``(min, max)`` bounds into normalizing functions.

    Returns a dict mapping each channel of ``extrema`` to a function that
    clips its input to ``[min, max]``, rescales it linearly so that ``min``
    maps to 0 and ``max`` maps to 1, and clips the result to ``[0, 1]``.
    """
    scaling_funcs = {}
    for channel, bounds in extrema.items():
        lower, upper = bounds

        # The bounds are bound as default arguments so that each function
        # keeps its own values instead of the last ones of the loop.
        # SEE: https://stackoverflow.com/questions/1107210/python-create-function-in-a-loop-capturing-the-loop-variable
        # TODO: this should be a single clip...
        # NOTE(review): the inner clip presumably also keeps unsigned inputs
        # from underflowing in the subtraction -- confirm before removing.
        def rescale(x, min_value=lower, max_value=upper):
            bounded = np.clip(x, min_value, max_value)
            normalized = (bounded - min_value) / (max_value - min_value)
            return np.clip(normalized, 0, 1)

        scaling_funcs[channel] = rescale
    return scaling_funcs

In [None]:
def mosaic_animate_scale(
    filename,
    scale=1,
    timepoints=None,
    width=1024,
    height=1024,
    # frame_rate=1, #TODO
    channels=default_channels,
    scaling_funcs=None,
    delayed=True,
    # ignore_exceptions=True,
):
    """Build the frames of a mosaic animation for one ND2 file.

    Parameters
    ----------
    filename
        Path of the ND2 file.
    scale : number or iterable of numbers
        Zoom factor per frame. A single number is applied to every frame;
        an iterable is paired with ``timepoints`` frame by frame.
    timepoints : iterable, optional
        Timepoints to render. With a numeric ``scale`` it defaults to every
        timepoint of the file. With an iterable ``scale`` it defaults to
        cycling through all timepoints for as long as ``scale`` yields
        values -- so an endless ``scale`` needs explicit ``timepoints``,
        otherwise this function never returns.
    width, height : int
        Size of each rendered frame.
    channels
        Channels composited into each tile.
    scaling_funcs : dict, optional
        Per-channel intensity scaling functions (see ``colorized_frame``).
    delayed : bool or callable
        ``True`` wraps each frame in ``dask.delayed(pure=True)``; ``False``
        renders eagerly; any other value is used as the wrapper itself.

    Returns
    -------
    list
        One entry per ``(timepoint, scale)`` pair: delayed objects, or
        rendered frames when ``delayed`` is ``False``.
    """
    if delayed is True:
        delayed = dask.delayed(pure=True)
    elif delayed is False:
        delayed = lambda func, **kwargs: func
    # TODO: optionally wrap frame loading so that exceptions are ignored
    # (the disabled ``ignore_exceptions`` flag in the signature).
    # NOTE(review): the reader opened here is not closed by this function.
    nd2 = nd2reader.ND2Reader(filename)
    metadata = {filename: workflow.parse_nd2_metadata(nd2)}
    positions = workflow.get_position_metadata(metadata)
    # TODO
    # small_positions = positions[(positions["y_idx"] < 3) & (positions["x_idx"] < 3)]
    image_limits = workflow.get_filename_image_limits(metadata)
    get_frame_func = partial(
        colorized_frame,
        channel_to_color,
        workflow.get_nd2_frame,
        filename,
        channels=channels,  # previously ignored: the default was always used
        scaling_funcs=scaling_funcs,
    )
    input_dims = (
        image_limits[filename][0][1] + 1,
        image_limits[filename][1][1] + 1,
    )
    if isinstance(scale, Number):
        # A bare number is not iterable; repeat it so it can be zipped with
        # the timepoints below (this also covers the default ``scale=1``).
        scale = repeat(scale)
        if timepoints is None:
            timepoints = range(nd2.sizes["t"])
    elif timepoints is None:
        timepoints = cycle(range(nd2.sizes["t"]))
    animation = [
        delayed(mosaic_frame)(
            get_frame_func,
            positions,
            input_dims,
            t,
            scale=s,
            output_dims=(width, height),  # previously ignored
        )
        for t, s in zip(timepoints, scale)
    ]
    return animation

In [None]:
%%time
# Sample intensity bounds for the two default channels (field of view v=0,
# every 10th timepoint by default).
extrema = get_intensity_extrema(nd2, ("YFP-DUAL", "RFP-Penta"))

In [None]:
extrema

In [None]:
# Scaling functions derived from the sampled bounds...
scaling_funcs = get_scaling_funcs(extrema)

In [None]:
# ...which are then overridden here by hand-picked fixed ranges.
scaling_funcs = get_scaling_funcs({"YFP-DUAL": (262, 8000), "RFP-Penta": (278, 8000)})

In [None]:
# Spot check of the YFP scaling on a single frame.
img = nd2.get_frame_2D(t=80, v=0, c=nd2.metadata["channels"].index("YFP-DUAL"))

In [None]:
img_scaled = scaling_funcs["YFP-DUAL"](img)

In [None]:
plt.imshow(img_scaled)

In [None]:
# Spot check of the RFP scaling on a single frame.
img = nd2.get_frame_2D(t=10, v=0, c=nd2.metadata["channels"].index("RFP-Penta"))

In [None]:
img_scaled = scaling_funcs["RFP-Penta"](img)

In [None]:
plt.imshow(img_scaled)

In [None]:
%%time
# `scale` is an endless iterator, so the number of frames is determined by
# the finite `timepoints` list it is zipped with.
scale = repeat(0.3)  # [0.3, 0.3, 0.3, 0.3]
# timepoints = range(0, 119, 30)
# timepoints = range(0, 110, 10)#[20,40,60]
timepoints = [20, 40, 60]
animation_delayed = mosaic_animate_scale(
    nd2_filename,
    scale,
    timepoints=timepoints,
    scaling_funcs=scaling_funcs,
    delayed=True,
)

In [None]:
# Submit all delayed frames to the cluster.
animation_future = client.compute(animation_delayed)

In [None]:
# Fetch only the frame at index 1 (timepoint 40 in the list above).
a = client.gather(animation_future[1])

In [None]:
# Maximum over the first row of that frame.
a[0].max()

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(a)

In [None]:
# NOTE(review): `a` is a single frame here, while export_video iterates over
# its first argument as a sequence of frames -- confirm whether the full
# list of gathered frames was intended.
export_video(a, "/home/jqs1/scratch/jqs1/microscopy/220704/mosaics/export_test.mp4")

In [None]:
Video("/home/jqs1/scratch/jqs1/microscopy/220704/mosaics/export_test.mp4", embed=True)

In [None]:
# NOTE(review): `p` is not defined in any cell visible here; this raises
# NameError unless it is defined elsewhere.
p

# FISH

In [None]:
# Directory containing the per-frame FISH files read by get_fish_frame.
fish_dir = Path("/home/jqs1/scratch/jqs1/microscopy/220718/FISH/real_run")

In [None]:
def get_fish_frame(filename, v, channel, t):
    """Load one FISH frame from its HDF5 file.

    ``filename`` is used as a directory (it must support the ``/`` path
    operator); the frame is read from the file named
    ``fov={v}_config={channel}_t={t}`` inside it. The whole ``"data"``
    dataset is read into memory and returned.

    The file is opened read-only so that a missing or mistyped path fails
    instead of being opened for writing by h5py versions whose default mode
    is not read-only.
    """
    frame_path = filename / f"fov={v}_config={channel}_t={t}"
    with h5py.File(frame_path, "r") as f:
        return f["data"][()]

In [None]:
# Hex display color for each FISH channel name.
fish_colors = {
    "BF": "#ffffff",
    "RFP": "#e22400",
    # "YFP-DUAL": "#f5eb00",
    # "GFP": "#76ba40",
    "Cy5": "#e292fe",
    "Cy7": "#00faff"
    # "BFP": "#3a87fd",
}

In [None]:
# Fixed (min, max) intensity range per FISH channel, turned into
# normalizing functions.
fish_scaling_funcs = get_scaling_funcs(
    {
        "BF": (40_000, 65_500),
        "RFP": (8000, 15000),
        "Cy5": (5_000, 40000),
        "Cy7": (2500, 4_000),
    }
)

In [None]:
# Spot check: 2172 is below the Cy7 lower bound (2500), so this clips to 0.
fish_scaling_funcs["Cy7"](2172)

In [None]:
# In all cells below the positional arguments after `fish_dir` are
# t=1 and v=8, followed by the list of channels to composite.

# Cy7 only.
a = colorized_frame(
    fish_colors,
    get_fish_frame,
    fish_dir,
    1,
    8,
    ["Cy7"],
    scaling_funcs=fish_scaling_funcs,
)
plt.figure(figsize=(40, 40))
plt.imshow(a)

In [None]:
# Cy5 only.
a = colorized_frame(
    fish_colors,
    get_fish_frame,
    fish_dir,
    1,
    8,
    ["Cy5"],
    scaling_funcs=fish_scaling_funcs,
)
plt.figure(figsize=(40, 40))
plt.imshow(a)

In [None]:
# RFP only.
a = colorized_frame(
    fish_colors,
    get_fish_frame,
    fish_dir,
    1,
    8,
    ["RFP"],
    scaling_funcs=fish_scaling_funcs,
)
plt.figure(figsize=(40, 40))
plt.imshow(a)

In [None]:
# The three fluorescence channels together.
a = colorized_frame(
    fish_colors,
    get_fish_frame,
    fish_dir,
    1,
    8,
    ["RFP", "Cy5", "Cy7"],
    scaling_funcs=fish_scaling_funcs,
)
plt.figure(figsize=(40, 40))
plt.imshow(a)

In [None]:
# All four channels, including brightfield.
a = colorized_frame(
    fish_colors,
    get_fish_frame,
    fish_dir,
    1,
    8,
    ["BF", "RFP", "Cy5", "Cy7"],
    scaling_funcs=fish_scaling_funcs,
)
plt.figure(figsize=(40, 40))
plt.imshow(a)

In [None]:
# NOTE(review): this cell is identical to the previous one.
a = colorized_frame(
    fish_colors,
    get_fish_frame,
    fish_dir,
    1,
    8,
    ["BF", "RFP", "Cy5", "Cy7"],
    scaling_funcs=fish_scaling_funcs,
)
plt.figure(figsize=(40, 40))
plt.imshow(a)

In [None]:
%%time
# Colorized Cy7 frames of field of view 8 for t = 1..10.
frames = [
    colorized_frame(
        fish_colors,
        get_fish_frame,
        fish_dir,
        t,
        8,
        ["Cy7"],
        scaling_funcs=fish_scaling_funcs,
    )
    for t in trange(1, 11)
]

In [None]:
%%time
# Downscale every frame by a factor of 0.5 (the last axis is the color axis).
rescaled_frames = [skimage.transform.rescale(f, 0.5, anti_aliasing=True, channel_axis=-1) for f in tqdm(frames)]

In [None]:
%%time
# Encode the downscaled frames at 5 frames per second.
export_video(rescaled_frames, "/home/jqs1/_temp/FISH_Cy7.mp4", fps=5)

In [None]:
# NOTE(review): this checks FISH_BF.mp4, while the cell above wrote
# FISH_Cy7.mp4 -- confirm which file was meant.
!du -hs /home/jqs1/_temp/FISH_BF.mp4