# Imports

In [None]:
import itertools as it
from numbers import Number

import cv2
import dask
import dask.array as da
import distributed
import nd2reader
import numpy as np
import pandas as pd
from cytoolz import partial
from dask_jobqueue import SLURMCluster
from distributed import Client
from skimage.transform import SimilarityTransform, warp

In [None]:
from dask.diagnostics import ProgressBar

# Register a global dask diagnostics progress bar for subsequent computes.
# NOTE(review): dask.diagnostics.ProgressBar only hooks the local (threaded/
# process) schedulers; once the distributed Client below becomes the default
# scheduler this bar will likely stay silent — distributed.progress may be
# what's wanted there. TODO confirm.
pbar = ProgressBar()
pbar.register()

In [None]:
%load_ext autoreload
%autoreload 2
%load_ext pyinstrument

In [None]:
from paulssonlab.image_analysis import mosaic, workflow
from paulssonlab.util.ui import display_image

# Config

In [None]:
# ND2 acquisition to mosaic; swap in the commented path to use the other dataset.
nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/220704/220704rbs_library_fish.nd2"
# nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"

In [None]:
# Open the ND2 file for interactive inspection in this notebook.
# NOTE: mosaic_animate_scale opens its own reader from the filename; this
# handle is not passed to it.
nd2 = nd2reader.ND2Reader(nd2_filename)

In [None]:
# SLURM-backed dask cluster: each worker job is a single 1-core process with
# 4 GB of memory and a 2 h walltime on the "short" queue (partition).
# Creating the Client also makes it the default scheduler for later computes.
cluster = SLURMCluster(
    queue="short",
    walltime="02:00:00",
    memory="4GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
# Display the cluster (Jupyter renders the cell's last expression).
cluster

In [None]:
# Request a single worker; raise this for larger renders.
cluster.scale(1)

# Functions

In [None]:
from paulssonlab.image_analysis.workflow import (
    get_filename_image_limits,
    get_nd2_frame,
    get_position_metadata,
    parse_nd2_metadata,
)


def rectangles_intersect(ul1, lr1, ul2, lr2):
    """Return True when two axis-aligned rectangles overlap or touch.

    Each rectangle is described by its upper-left corner ``ul`` and its
    lower-right corner ``lr`` as ``(x, y)`` pairs. Bounds are inclusive, so
    rectangles sharing only an edge or a corner count as intersecting.
    """
    if ul1[0] > lr2[0] or lr1[0] < ul2[0]:
        # Separated along x: no overlap regardless of y.
        return False
    return not (ul1[1] > lr2[1] or lr1[1] < ul2[1])


def scale_around_center(scale, center):
    """Build a similarity transform that scales by ``scale`` about ``center``.

    ``center`` is an ``(x, y)`` pair. The three component transforms are
    combined with ``+`` in application order: shift ``center`` to the
    origin, scale, then shift back. As a result, ``center`` is a fixed point.
    """
    cx, cy = center
    to_origin = SimilarityTransform(translation=(-cx, -cy))
    scaling = SimilarityTransform(scale=scale)
    back_from_origin = SimilarityTransform(translation=(cx, cy))
    return to_origin + scaling + back_from_origin


def fixed_aspect_scale(input_width, input_height, output_width, output_height):
    """Return the uniform scale factor that fits the input inside the output.

    The factor is the smaller of the width and height ratios. This preserves
    the input's aspect ratio and keeps each scaled side within the output.
    """
    return min(output_width / input_width, output_height / input_height)


def transform_to_viewport(
    input_width,
    input_height,
    output_width,
    output_height,
    center_x,
    center_y,
    output_corner_x,
    output_corner_y,
):
    """Build a translation for one frame and test it against the viewport.

    The returned transform is a pure translation by
    ``(center_x - output_width / 2 + output_corner_x,
    center_y - output_height / 2 + output_corner_y)``.

    Visibility works in two steps. First, the input frame's first and last
    pixel corners are mapped through ``transform.inverse``. Then the code
    checks whether that box intersects the viewport rectangle, which runs
    from ``(0, 0)`` to ``(output_width - 1, output_height - 1)`` with
    inclusive edges.

    NOTE(review): this presumably follows the cv2 ``WARP_INVERSE_MAP``
    convention, where the transform maps output coordinates to input
    coordinates and its inverse therefore places input pixels in the output.
    TODO confirm.

    Returns
    -------
    tuple
        ``(transform, visible)``, where ``visible`` is the result of
        ``rectangles_intersect`` on the two boxes.
    """
    offset = (
        center_x - output_width / 2 + output_corner_x,
        center_y - output_height / 2 + output_corner_y,
    )
    transform = SimilarityTransform(translation=offset)
    # Input frame corners mapped via the inverse transform (presumably into
    # output coordinates).
    frame_first = transform.inverse((0, 0))[0]
    # TODO: off-by-one?
    frame_last = transform.inverse((input_width - 1, input_height - 1))[0]
    # Viewport corners (inclusive).
    viewport_first = (0, 0)
    # TODO: off-by-one?
    viewport_last = (output_width - 1, output_height - 1)
    is_visible = rectangles_intersect(
        frame_first, frame_last, viewport_first, viewport_last
    )
    return transform, is_visible


def foo(img, fx=0.1, fy=0.9):
    """Resize ``img`` by independent horizontal/vertical scale factors.

    Scratch helper used by ``mosaic_frame``. ``fx`` scales the width (columns)
    and ``fy`` scales the height (rows); OpenCV derives the output size from
    them.
    """
    # cv2.resize's second positional argument is ``dsize`` (a (width, height)
    # tuple or None), not a scale factor. To scale by factors, dsize must be
    # None and the factors passed as fx/fy.
    return cv2.resize(img, None, fx=fx, fy=fy)


def mosaic_frame(
    get_frame_func,
    channels,
    positions,
    input_dims,
    *,
    timepoint=None,
    scale=1,
    output_dims=(1024, 1024),
    scaling_funcs=None,
    dtype=np.float32,
):
    """Lazily build one mosaic frame per channel by summing per-position tiles.

    Parameters
    ----------
    get_frame_func : callable
        Called lazily (via ``dask.delayed``) as
        ``get_frame_func(pos_num, channel, timepoint)`` to fetch one tile.
    channels : sequence
        Channel names. One output array is produced per channel, in order.
    positions : pandas.DataFrame
        Indexed by ``(filename, pos_num)`` and containing ``x_idx``/``y_idx``
        grid-index columns (presumably integer tile indices).
    input_dims : (width, height)
        Size of a single tile.
    timepoint : optional
        Forwarded unchanged to ``get_frame_func``.
    scale : float
        Multiplies ``output_dims`` when computing the tile fit factor.
    output_dims : (width, height)
        Declared output size. Each tile is declared to dask with shape
        ``(height, width)``.
    scaling_funcs : mapping, optional
        Per-channel function applied to each raw tile before clipping.
    dtype : numpy dtype
        The dtype *declared* to dask for each tile. The data itself is not
        cast to it.

    Returns
    -------
    list of dask.array.Array
        One array per channel: the element-wise sum over positions of the
        clipped tiles.

    NOTE(review): work in progress. The resize and warp steps that would
    place each tile into the viewport are commented out, and the warp refers
    to an undefined ``frame_transform``. Consequences:

    - ``center``, ``rescaled_input_dims``, ``filename`` and ``position`` are
      currently unused.
    - The scratch ``foo`` step does not size tiles to ``output_dims``, so the
      shape declared to ``da.from_delayed`` will likely not match the computed
      data.

    TODO restore the placement step before relying on the output. Note also
    that an empty ``positions`` makes ``da.stack`` raise.
    """
    # Curried dask.delayed: wraps functions with pure=True so identical calls
    # get identical (content-hashed) keys.
    delayed = dask.delayed(pure=True)
    columns = positions["x_idx"].max() - positions["x_idx"].min() + 1
    rows = positions["y_idx"].max() - positions["y_idx"].min() + 1
    # Centre of the full tile grid in input pixels (only used by the disabled warp).
    center = np.array([input_dims[0] * columns / 2, input_dims[1] * rows / 2])
    all_channel_imgs = [[] for _ in range(len(channels))]
    # Uniform factor fitting one tile into the (scale-adjusted) output viewport.
    input_scale = fixed_aspect_scale(
        *input_dims, output_dims[0] * scale, output_dims[1] * scale
    )
    # Tile size after that rescale (only used by the disabled resize).
    rescaled_input_dims = np.ceil(np.array(input_dims) * input_scale).astype(np.int_)
    for (filename, pos_num), position in positions.iterrows():
        for channel, channel_imgs in zip(channels, all_channel_imgs):
            img = delayed(get_frame_func)(pos_num, channel, timepoint)
            if scaling_funcs:
                img = delayed(scaling_funcs[channel])(img)
            # img = delayed(cv2.resize)(
            #     img, rescaled_input_dims, interpolation=cv2.INTER_AREA
            # )
            # img = delayed(cv2.warpAffine)(
            #     img,
            #     frame_transform.params[:2, :],
            #     output_dims,
            #     # flags=cv2.INTER_AREA + cv2.WARP_INVERSE_MAP,
            #     flags=(cv2.INTER_LANCZOS4 + cv2.WARP_INVERSE_MAP),
            # )
            # img = delayed(cv2.resize)(img, 0.1, 0.9)
            # Scratch step standing in for the disabled resize/warp (see docstring).
            img = delayed(foo)(img)
            img = delayed(np.clip)(img, 0, 1)  # LANCZOS4 (disabled warp) can overshoot 0..1
            # Declared (height, width) must match the computed tile for dask to be correct.
            img = da.from_delayed(img, output_dims[::-1], dtype=dtype)
            channel_imgs.append(img)
    # Sum tiles per channel; presumably warped tiles are zero outside their
    # footprint so the sum composes the mosaic — TODO confirm.
    output = [da.stack(channel_imgs).sum(axis=0) for channel_imgs in all_channel_imgs]
    return output


def mosaic_animate_scale(
    filename,
    scale=1,
    timepoints=None,
    output_dims=(3840, 2160),
    scaling_funcs=None,
):
    """Build a lazy animation of mosaic frames over time and/or scale.

    Parameters
    ----------
    filename : str
        Path to the ND2 file. Its metadata is parsed eagerly here, while
        tiles are read lazily through ``get_nd2_frame``.
    scale : number or iterable of numbers
        A single number holds the scale fixed for every frame. An iterable
        supplies one scale per frame.
    timepoints : iterable of int, optional
        Timepoint index for each frame. If omitted, the default depends on
        ``scale``:

        - numeric ``scale``: every timepoint in the file, once;
        - iterable ``scale``: the file's timepoints, cycled indefinitely.
    output_dims : (width, height)
        Passed through to ``mosaic_frame``.
    scaling_funcs : mapping, optional
        Per-channel intensity functions, passed through to ``mosaic_frame``.

    Returns
    -------
    list
        One entry per ``(timepoint, scale)`` pair. Each entry is
        ``mosaic_frame``'s per-channel list of dask arrays.

    Frames are paired with ``zip(timepoints, scales)``, so at least one of the
    two must be finite or this function never returns.
    """
    channels = ["YFP-DUAL", "RFP-PENTA"]
    # NOTE(review): this reader is only used for metadata and is never closed.
    nd2 = nd2reader.ND2Reader(filename)
    metadata = {filename: parse_nd2_metadata(nd2)}
    positions = get_position_metadata(metadata)
    image_limits = get_filename_image_limits(metadata)
    get_frame_func = partial(get_nd2_frame, filename)
    # image_limits[filename][axis][1] is presumably the inclusive max pixel
    # index per axis, so +1 gives the tile (width, height) — TODO confirm.
    input_dims = (
        image_limits[filename][0][1] + 1,
        image_limits[filename][1][1] + 1,
    )
    if isinstance(scale, Number):
        if timepoints is None:
            timepoints = range(nd2.sizes["t"])
        # zip() needs an iterable: repeat a scalar scale for every timepoint.
        scales = it.repeat(scale)
    else:
        if timepoints is None:
            # Loop the movie while stepping through the scale sequence.
            timepoints = it.cycle(range(nd2.sizes["t"]))
        scales = scale
    frame_params = list(zip(timepoints, scales))
    return [
        mosaic_frame(
            get_frame_func,
            channels,
            positions,
            input_dims,
            timepoint=t,
            scale=s,
            scaling_funcs=scaling_funcs,
            output_dims=output_dims,
        )
        for t, s in frame_params
    ]

# Mosaic

In [None]:
# Per-channel intensity scaling. The (300, 2000) pairs are presumably the
# (min, max) raw intensities mapped onto 0..1 (mosaic_frame clips to 0..1).
# TODO confirm against mosaic.get_scaling_funcs.
scaling_funcs = mosaic.get_scaling_funcs(
    {"YFP-DUAL": (300, 2000), "RFP-PENTA": (300, 2000)}
)

In [None]:
%%time
# Build the lazy dask graphs for 20 frames: timepoints 0..19, each at scale 0.1.
# it.repeat(0.1) is infinite, so the finite `timepoints` bounds the frame count.
animation_delayed = mosaic_animate_scale(
    nd2_filename,
    it.repeat(0.1),
    timepoints=range(20),
    scaling_funcs=scaling_funcs,
)

In [None]:
%%time
# Submit each frame (a list of per-channel dask arrays) to the cluster
# separately; the result is one list of futures per frame.
# animation_future = client.compute(animation_delayed)
animation_future = [client.compute(frame) for frame in animation_delayed]

In [None]:
# Drop the futures so the cluster can release their results.
# NOTE(review): running this before the gather cell below makes that cell raise NameError.
del animation_future

In [None]:
%%time
# Block until every frame finishes and pull the results (same nesting) into this process.
a = client.gather(animation_future)

# Pickle

In [None]:
import cloudpickle

In [None]:
# Experiment: check that cv2's compiled resize function can be pickled for shipping to workers.
cloudpickle.dumps(cv2.resize)

In [None]:
# NOTE(review): cloudpickle documents register_pickle_by_value for pure-Python
# modules that are not importable on the workers. cv2 is a compiled extension,
# so this is unlikely to change how its functions pickle. TODO confirm it is needed.
cloudpickle.register_pickle_by_value(cv2)

In [None]:
import distributed.protocol

In [None]:
# Experiment: serialize a delayed cv2 call the way distributed ships tasks.
# The "h"/"b" arguments are placeholders; the call is never executed here.
distributed.protocol.pickle.dumps(dask.delayed(cv2.resize)("h", "b"))