# Imports

In [None]:
import numpy as np
import random
from tqdm.auto import tqdm
import nd2reader
import skimage
from skimage.transform import SimilarityTransform, warp
import holoviews as hv
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color
from paulssonlab.image_analysis import workflow

In [None]:
# Load HoloViews with the Bokeh plotting backend for this notebook.
hv.extension("bokeh")

# Config

In [None]:
# ND2 acquisition file(s) to load.
# NOTE(review): machine-specific absolute path; consider making it configurable.
nd2_filenames = ["/home/jqs1/scratch/jqs1/microscopy/220704/220704rbs_library_fish.nd2"]

In [None]:
# Reader for the first file; the single-frame cells below read frames from it directly.
nd2 = nd2reader.ND2Reader(nd2_filenames[0])

In [None]:
# Hex display color for each channel name, looked up by name when compositing.
channel_to_color = {
    "BF": "#ffffff",
    "RFP-PENTA": "#e22400",
    "YFP-DUAL": "#f5eb00",
    # "GFP": "#76ba40",
    # "CY5": "#e292fe",
    # "BFP": "#3a87fd",
}

# Mosaic

In [None]:
# Open each ND2 file once, parse its metadata, then derive the per-position
# table and per-file image limits from the combined metadata.
nd2s = {path: nd2reader.ND2Reader(path) for path in nd2_filenames}
metadata = {
    path: workflow.parse_nd2_metadata(reader) for path, reader in nd2s.items()
}
positions = workflow.get_position_metadata(metadata)
image_limits = workflow.get_filename_image_limits(metadata)

In [None]:
# nd2 = nd2reader.ND2Reader(nd_filenames[0])
# all_frames, metadata = workflow.get_nd2_frame_list(nd2_filenames)#

In [None]:
# Inspect the per-position metadata table.
positions

In [None]:
# Largest x (column) grid index among the positions.
positions["x_idx"].max()

In [None]:
# Largest y (row) grid index among the positions.
positions["y_idx"].max()

In [None]:
def composite_channels(imgs, hexcolors, scale=True):
    """Tint each channel image with a hex color and blend them into one RGB image.

    ``hexcolors`` holds one ``"#rrggbb"`` string per image. Each is converted to
    an RGB triple with ``hex2color`` before the work is delegated to
    ``_composite_channels``.
    """
    rgb_colors = list(map(hex2color, hexcolors))
    return _composite_channels(imgs, rgb_colors, scale=scale)


def _composite_channels(channel_imgs, colors, scale=True):
    """Tint single-channel images and combine them with "screen" blending.

    Each image in ``channel_imgs`` is multiplied by the RGB triple at the same
    position in ``colors``. The tinted layers are merged with
    ``1 - (1 - a) * (1 - b)``, so inputs in [0, 1] stay in [0, 1].

    Args:
        channel_imgs: sequence of 2-D arrays (or ``(H, W, 1)`` arrays), one per
            channel, all the same shape.
        colors: sequence of RGB triples, one per image.
        scale: if True, divide each image by its 99.9th percentile and clip to
            [0, 1]. If False, images are used as-is.

    Returns:
        float array of shape ``(H, W, 3)``.

    Raises:
        ValueError: if the numbers of images and colors differ, or if no
            images are given.
    """
    if len(channel_imgs) != len(colors):
        raise ValueError("expecting equal numbers of channels and colors")
    if not len(channel_imgs):
        # With no image there is no output shape to build a placeholder from.
        raise ValueError("expecting at least one channel")
    layers = []
    for channel_img, color in zip(channel_imgs, colors):
        layer = np.asarray(channel_img)
        if scale:
            upper = np.percentile(layer, 99.9)
            if upper <= 0:
                # Mostly-empty image: fall back to the max so a zero divisor
                # cannot produce nan/inf.
                upper = layer.max()
            if upper > 0:
                layer = np.clip(layer / upper, 0, 1)
            else:
                layer = np.zeros(layer.shape)
        if layer.ndim == 2:
            # A trailing axis lets the length-3 color broadcast to (H, W, 3);
            # this is needed on the unscaled path too.
            layer = layer[:, :, np.newaxis]
        layers.append(layer * np.asarray(color))
    img = layers[0]
    for layer in layers[1:]:
        img = 1 - (1 - img) * (1 - layer)
    return img

In [None]:
channels = ["YFP-DUAL", "RFP-PENTA"]
# One 2-D frame per channel at v=50, t=50, with the channel index looked up by name.
imgs = [
    nd2.get_frame_2D(c=nd2.metadata["channels"].index(name), t=50, v=50)
    for name in channels
]
img = composite_channels(imgs, [channel_to_color[name] for name in channels])

In [None]:
# Re-display the positions table.
positions

In [None]:
# Cache one frame per channel (v=20, t=50) for quick experiments.
saved_imgs = [
    nd2.get_frame_2D(c=nd2.metadata["channels"].index(name), t=50, v=20)
    for name in ("YFP-DUAL", "RFP-PENTA")
]

In [None]:
def colorized_frame(nd2, t=0, v=0, channels=("YFP-DUAL", "RFP-PENTA")):
    """Read one frame per channel from ``nd2`` and composite them into RGB.

    Previously this ignored ``nd2``, ``t`` and ``v``: it always composited the
    cached ``saved_imgs`` with random colors. It now reads the requested frame
    and uses the configured channel colors.

    Args:
        nd2: open ``nd2reader.ND2Reader`` to read from.
        t: timepoint index passed to ``get_frame_2D``.
        v: position (field of view) index passed to ``get_frame_2D``.
        channels: channel names; each must appear in
            ``nd2.metadata["channels"]`` and in ``channel_to_color``.

    Returns:
        RGB composite from ``composite_channels``.
    """
    channel_names = nd2.metadata["channels"]
    imgs = [
        nd2.get_frame_2D(v=v, t=t, c=channel_names.index(channel))
        for channel in channels
    ]
    return composite_channels(
        imgs, [channel_to_color[channel] for channel in channels]
    )

In [None]:
# First row of the positions table as an (index, Series) pair; next() avoids
# materialising every row just to take the first one.
next(positions.iterrows())

In [None]:
# Keep only positions in the first three grid columns and rows (x_idx < 3, y_idx < 3).
small_positions = positions[positions["x_idx"].lt(3) & positions["y_idx"].lt(3)]

In [None]:
def rectangles_intersect(ul1, lr1, ul2, lr2):
    """Return True if two axis-aligned rectangles overlap.

    Each rectangle is given by its upper-left ``ul`` and lower-right ``lr``
    corners as ``(x, y)`` pairs with ``ul <= lr`` component-wise. Bounds are
    inclusive, so rectangles that only share an edge or a corner count as
    intersecting.
    """
    # The rectangles are disjoint exactly when one lies entirely to one side of
    # the other along x or along y.
    return not (
        ul1[0] > lr2[0] or lr1[0] < ul2[0] or ul1[1] > lr2[1] or lr1[1] < ul2[1]
    )


def scale_around_center(scale, center):
    """Return a similarity transform that scales by ``scale`` about ``center``.

    The point ``center`` is moved to the origin, scaled, then moved back.
    skimage applies the left operand of ``+`` first.
    """
    cx, cy = center
    to_origin = SimilarityTransform(translation=(-cx, -cy))
    zoom = SimilarityTransform(scale=scale)
    from_origin = SimilarityTransform(translation=(cx, cy))
    return to_origin + zoom + from_origin


def output_transformation(input_width, input_height, output_width, output_height):
    """Map output-canvas coordinates to input-image coordinates, letterboxed.

    A single uniform factor (the larger of the width and height ratios) makes
    the whole input fit inside the output, centred along the axis with slack.
    The returned transform goes output -> input, which is the direction
    ``skimage.transform.warp`` expects for its ``inverse_map``.
    """
    scale = max(input_width / output_width, input_height / output_height)
    offset_x = (input_width / scale - output_width) / 2
    offset_y = (input_height / scale - output_height) / 2
    recentre = SimilarityTransform(translation=(offset_x, offset_y))
    return recentre + SimilarityTransform(scale=scale)


def mosaic_frame(
    get_frame_func,
    positions,
    image_dims,
    timepoint,
    center=None,
    scale=1,
    output_dims=(1024, 1024),
):
    """Render a viewport onto the tile grid described by ``positions``.

    Work in progress: every tile is currently drawn from the notebook-global
    ``img``, with a random brightness factor per tile so boundaries are
    visible. ``get_frame_func`` and ``timepoint`` are not used yet.

    Args:
        get_frame_func: intended per-tile frame source; currently unused.
        positions: table with ``x_idx`` / ``y_idx`` grid-index columns. The
            grid spans ``min..max`` of each column, laid out from tile (0, 0).
        image_dims: ``(width, height)`` of one tile in pixels.
        timepoint: intended timepoint to render; currently unused.
        center: ``(x, y)`` mosaic coordinate shown at the centre of the
            output. Defaults to the centre of the tile grid.
        scale: zoom factor passed as ``1 / scale`` to ``scale_around_center``.
        output_dims: ``(width, height)`` of the rendered image.

    Returns:
        float array of shape ``(output_dims[1], output_dims[0], 3)``.
    """
    tile_width, tile_height = image_dims
    # Grid extent comes from positions, consistent with the default center
    # (the loop used to be hard-coded to 3x3).
    columns = positions["x_idx"].max() - positions["x_idx"].min() + 1
    rows = positions["y_idx"].max() - positions["y_idx"].min() + 1
    if center is None:
        center = (tile_width * columns / 2, tile_height * rows / 2)
    # Output-pixel -> mosaic-coordinate map. It is the same for every tile, so
    # build it once: fit one tile into the output, zoom about the tile centre,
    # then move the view to ``center``.
    viewport_transform = (
        output_transformation(tile_width, tile_height, *output_dims)
        + scale_around_center(1 / scale, (tile_width / 2, tile_height / 2))
        + SimilarityTransform(
            translation=(center[0] - tile_width / 2, center[1] - tile_height / 2)
        )
    )
    output_img = np.zeros((output_dims[1], output_dims[0], 3))
    # Inclusive pixel bounds, matching the inclusive test in
    # rectangles_intersect and the (dims - 1) tile corners below.
    viewport_ul = (0, 0)
    viewport_lr = (output_dims[0] - 1, output_dims[1] - 1)
    for x_col in range(columns):
        for y_row in range(rows):
            # Shift mosaic coordinates into this tile's own pixel frame.
            frame_transform = viewport_transform + SimilarityTransform(
                translation=(-tile_width * x_col, -tile_height * y_row)
            )
            # frame_transform maps output -> tile pixels, so its inverse gives
            # the tile's corners in output pixels.
            frame_ul = frame_transform.inverse((0, 0))[0]
            frame_lr = frame_transform.inverse((tile_width - 1, tile_height - 1))[0]
            if not rectangles_intersect(viewport_ul, viewport_lr, frame_ul, frame_lr):
                continue  # tile is entirely off-screen; skip the warp
            # NOTE(review): placeholder source (global img) and random
            # brightness are debugging aids; see docstring.
            output_img += warp(
                img, frame_transform, output_shape=output_dims[::-1]
            ) * (0.5 + 0.5 * np.random.random())
    return output_img


# Render a zoomed-out (scale=0.1) view of the x_idx < 3, y_idx < 3 subset at timepoint 0.
out_img = mosaic_frame(
    colorized_frame,
    small_positions,
    image_dims=(5056, 2960),
    timepoint=0,
    scale=0.1,
)
plt.imshow(out_img)

# output_shape = (512, 512)
# tr = (SimilarityTransform(scale=2, translation=(-256, -256)) +
#       output_transformation(*(5056, 2960), *output_shape))
#       # + SimilarityTransform(translation=(0, 2960))
# img2 = warp(img, tr, output_shape=output_dims[::-1])
# plt.imshow(img2)

In [None]:
# Shape of the current composite image.
img.shape

In [None]:
# Image limits returned by workflow.get_filename_image_limits.
image_limits

In [None]:
# Warp the composite by a pure translation of (0, 2959 / 2); the output keeps the input's shape.
img2 = warp(
    img,
    inverse_map=SimilarityTransform(translation=(0, 2959 / 2), scale=1),
)  # , output_shape=(512,512))

In [None]:
# Show the translated composite.
plt.imshow(img2)

In [None]:
# Scratch: inspect the transform obtained by composing two SimilarityTransforms with `+`.
SimilarityTransform(scale=0.5, translation=(10, 0)) + SimilarityTransform(
    scale=1, translation=(10, 0)
)

In [None]:
# Scratch: the identity similarity transform (scale 1, no translation), for comparison.
SimilarityTransform(scale=1, translation=(0, 0))

In [None]:
# Re-display the image limits.
image_limits

In [None]:
def mosaic_animation(
    nd2_filenames,
    timepoints=slice(None),
    grid_size=None,
    downsize=1,
    width=512,
    height=512,
    frame_rate=1,
):
    """Render an animated mosaic from ND2 files (unfinished).

    So far this only opens the files and parses their metadata; the rendering
    step has not been written. It raises instead of silently returning
    ``None``, and closes every reader it opened on the way out.

    Args:
        nd2_filenames: paths of the ND2 files to combine.
        timepoints: intended timepoint selection; unused so far.
        grid_size: unused so far.
        downsize: unused so far.
        width: intended output width in pixels; unused so far.
        height: intended output height in pixels; unused so far.
        frame_rate: intended animation frame rate; unused so far.

    Raises:
        NotImplementedError: always, until rendering is implemented.
    """
    nd2s = {}
    try:
        # Open one file at a time so readers opened before a failure are
        # still closed in ``finally``.
        for filename in nd2_filenames:
            nd2s[filename] = nd2reader.ND2Reader(filename)
        metadata = {
            filename: workflow.parse_nd2_metadata(nd2) for filename, nd2 in nd2s.items()
        }
        positions = workflow.get_position_metadata(metadata)
        image_limits = workflow.get_filename_image_limits(metadata)
        # TODO: render one mosaic per timepoint from positions/image_limits and
        # assemble the frames into an animation at frame_rate.
        raise NotImplementedError("mosaic_animation is not implemented yet")
    finally:
        for nd2 in nd2s.values():
            nd2.close()