# Imports

In [None]:
import itertools as it
from pathlib import Path

import dask
import distributed
import h5py
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
from cytoolz import partial
from dask_jobqueue import SLURMCluster
from distributed import Client
from IPython.display import Video
from PIL import Image, ImageDraw, ImageFont
from tqdm.auto import tqdm, trange

In [None]:
from dask.diagnostics import ProgressBar

# Register the progress bar globally (register() rather than a `with` block),
# so it stays active for the rest of the session.
# NOTE(review): dask.diagnostics callbacks are documented for the local
# schedulers; confirm whether this has any effect on work submitted through
# the distributed `Client` created further down.
pbar = ProgressBar()
pbar.register()

In [None]:
# Load the autoreload extension and switch it to mode 2, then load
# pyinstrument's IPython extension (a `%%pyinstrument` cell magic appears,
# commented out, in the Overlay section).
%load_ext autoreload
%autoreload 2
%load_ext pyinstrument

In [None]:
from paulssonlab.image_analysis import mosaic, workflow
from paulssonlab.util.ui import display_image

# Config

In [None]:
# Acquisition rendered by this notebook; the commented line is an alternative
# dataset that can be swapped in.
nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/220704/220704rbs_library_fish.nd2"
# nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"

In [None]:
# Reader for the acquisition; used below for `nd2.sizes["t"]`.
# NOTE(review): the reader is never closed in this notebook, and the
# "Grid debugging" section opens a second one on the same file.
nd2 = nd2reader.ND2Reader(nd2_filename)

In [None]:
# Channels rendered by default, and the display colour (hex RGB) per channel
# name. Both "RFP-PENTA" and "RFP-Penta" are mapped to the same colour --
# presumably because the channel name is capitalised differently between
# acquisitions (TODO confirm).
# default_channels = ("YFP-DUAL", "RFP-Penta")
default_channels = ("YFP-DUAL", "RFP-PENTA")
channel_to_color = {
    "BF": "#ffffff",
    "RFP-PENTA": "#C500BB",  # "#FF5AF6",
    "RFP-Penta": "#C500BB",
    # "RFP-Penta": "#e22400",
    # "YFP-DUAL": "#13FF00",
    "YFP-DUAL": "#FAFF00",
    # "GFP": "#76ba40",
    "Cy5": "#e292fe",
    # "Cy7": "#FF0000"
    # "BFP": "#3a87fd",
}

In [None]:
# Font handed to mosaic.square_overlay below. No size is passed here, so
# Pillow's default size applies unless the overlay code re-sizes it
# (NOTE(review): confirm which).
font = ImageFont.truetype("fira/FiraSans-Medium.ttf")

In [None]:
# Set the scheduler's "allowed-failures" option to 10.
# NOTE(review): this is set before the cluster/client are created in the next
# section; confirm that this is where the scheduler picks the value up.
dask.config.set({"distributed.scheduler.allowed-failures": 10})

In [None]:
# SLURM-backed dask cluster: jobs go to the "short" queue with a 2 h walltime.
# NOTE(review): per the dask-jobqueue documentation `cores`, `memory` and
# `processes` are per-job totals (i.e. 2 single-threaded worker processes
# sharing 16GB per job) -- confirm this matches the intended sizing.
cluster = SLURMCluster(
    queue="short",
    walltime="02:00:00",
    memory="16GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=2,
    processes=2,
)
# Client used for client.compute / client.gather / client.restart below.
client = Client(cluster)

In [None]:
# Adaptive scaling with an upper bound of 10.
cluster.adapt(maximum=10)

In [None]:
# Display the cluster widget/repr.
cluster

In [None]:
# Manual scaling to 40.
# NOTE(review): this cell and the adapt(maximum=10) cell above request
# different sizes; they look like alternatives to be run one or the other,
# not both -- confirm how the two interact if both are executed.
cluster.scale(40)

# Mosaic

In [None]:
%%time
# Measuring the intensity range from the data is currently disabled; fixed
# ranges are used instead (see the scaling_funcs cell below).
# extrema = mosaic.get_intensity_extrema(nd2, ("YFP-DUAL", "RFP-Penta"))

In [None]:
# NOTE(review): `extrema` is only assigned by the commented-out line in the
# previous cell, so on a fresh kernel this cell raises NameError. Either
# re-enable that line (note it spells the channel "RFP-Penta", unlike
# `default_channels`) or skip this cell.
extrema

In [None]:
# Per-channel intensity scaling built from hard-coded pairs, keyed by the
# same channel names as `default_channels`. The pairs are presumably
# (min, max) display limits -- confirm against mosaic.get_scaling_funcs.
# scaling_funcs = mosaic.get_scaling_funcs(extrema)
# scaling_funcs = mosaic.get_scaling_funcs(
#     {"YFP-DUAL": (262, 8000), "RFP-Penta": (278, 8000)}
# )
scaling_funcs = mosaic.get_scaling_funcs(
    {"YFP-DUAL": (300, 2000), "RFP-PENTA": (300, 2000)}
)

In [None]:
def positions_func(positions):
    """Replace the grid indices of `positions` with ones parsed from `position_name`.

    Each position name must have the form ``"<y>.<x>"``: two integers
    separated by a dot. The first number becomes ``y_idx`` and the second
    ``x_idx``.

    Parameters
    ----------
    positions : pandas.DataFrame
        Table with a ``position_name`` column and existing ``x_idx`` and
        ``y_idx`` columns (the latter two are discarded).

    Returns
    -------
    pandas.DataFrame
        A new frame with the same index and remaining columns as `positions`,
        followed by integer ``y_idx`` and ``x_idx`` columns. The input frame
        is not modified. An empty input yields an empty frame with the same
        columns.

    Raises
    ------
    KeyError
        If ``position_name``, ``x_idx`` or ``y_idx`` is missing.
    ValueError
        If a name does not consist of exactly two dot-separated integers.
    """
    # Drop first so a missing column fails before any parsing work is done.
    without_indices = positions.drop(columns=["x_idx", "y_idx"])

    y_values = []
    x_values = []
    for name in positions["position_name"]:
        parts = name.split(".")
        if len(parts) != 2:
            raise ValueError(
                f"position name {name!r} is not of the form '<y_idx>.<x_idx>'"
            )
        y_values.append(int(parts[0]))
        x_values.append(int(parts[1]))

    # Build both columns in one go (instead of one pd.Series per row) and pin
    # the dtype so an empty table still gets integer index columns.
    grid_indices = pd.DataFrame(
        {
            "y_idx": np.asarray(y_values, dtype=np.int64),
            "x_idx": np.asarray(x_values, dtype=np.int64),
        },
        index=positions.index,
    )
    return without_indices.join(grid_indices)

In [None]:
%%time
# Build the zoom animation lazily: `scale` and `timepoints` are parallel
# sequences (presumably consumed element-wise, one frame per pair -- confirm
# against mosaic.mosaic_animate_scale). The commented-out assignments below
# are earlier parameter choices kept for reference.
num_t = nd2.sizes["t"]
# scale = [80,40,20,10,5,1]#it.repeat(80)
# scale = np.geomspace(80, 0.1, 900)
# scale = np.geomspace(0.5, 0.1, 60)
# scale = [20]
# scale = [0.3]
# scale = it.repeat(1)
# scale = [0.333]
# scale = it.repeat(0.05)
# scale = it.repeat(0.1)
# scale = it.repeat(0.2)
# scale = it.repeat(0.4)
# scale = np.geomspace(20, 0.05, num_t*3)# + np.linspace(0, 0, 0) #it.repeat(0.3)  # [0.3, 0.3, 0.3, 0.3]
# scale = np.geomspace(80, 0.05, 120*2)
# timepoints = range(0, 119, 30)
# timepoints = range(0, 110, 10)#[20,40,60]
# timepoints = [60]  # it.repeat(60)  # [20]  # [20,40,60,80]
# timepoints = range(num_t)
# timepoints = [num_t-1]
# timepoints = [0,60,num_t-1]
# timepoints = it.chain(range(num_t), range(num_t), range(num_t)) #[20]#[20, 40, 60]
# timepoints = [0, 30, 60, 90, 120]
# scale = [80]
# Geometric sweep of the scale from 80 down to 0.03 over 2 * num_t values.
scale = np.geomspace(80, 0.03, 2 * num_t)
# timepoints = it.repeat(0)
# timepoints = it.cycle(range(num_t))
# Every timepoint index appears twice in a row, giving 2 * num_t entries to
# match the length of `scale`.
timepoints = np.repeat(np.arange(num_t), 2)
offset = [-20, 820]  # [0,0]#np.array([604, 354])
# offset = [4000, -7000]
# offset = [0,0]
animation_delayed = mosaic.mosaic_animate_scale(
    nd2_filename,
    scale,
    timepoints=timepoints,
    scaling_funcs=scaling_funcs,
    offset=offset,
    # rotation=np.deg2rad(-0.65),
    rotation=np.deg2rad(0.2),  # 0.2 degrees, converted to radians
    channels=default_channels,
    channel_to_color=channel_to_color,
    # overlay_only=True,
    # Overlay drawn via mosaic.square_overlay with these settings pre-bound;
    # the remaining arguments are supplied by mosaic_animate_scale.
    overlay_func=partial(
        mosaic.square_overlay,
        min_scale=40,
        min_n=0,
        min_width=0.5,
        max_scale=0.1,
        max_n=5,
        max_width=0.9,
        font=font,
    ),
    # positions_func=positions_func,
    # presumably (width, height); the Overlay section reverses output_dims to
    # build an array shape, which is consistent with that -- TODO confirm
    output_dims=(3840, 2160),
    # output_dims=(1024, 1024),
    # output_dims=(1024, 512),
    # The result is passed to dask.compute / client.compute in the cells
    # below, i.e. nothing is rendered yet at this point.
    delayed=True,
)

In [None]:
%%time
# Local check: compute the delayed animation in this process with the "sync"
# scheduler (no cluster involved) and keep the computed frames in `a`.
a = dask.compute(animation_delayed, scheduler="sync")[0]

In [None]:
# Preview the first frame. The next three cells repeat the same call
# (presumably kept to compare outputs across re-runs with other parameters).
display_image(a[0])

In [None]:
display_image(a[0])

In [None]:
display_image(a[0])

In [None]:
display_image(a[0])

In [None]:
%%time
# Cluster run: submit every frame separately, giving one future per frame
# (the commented line submits the whole animation in a single call instead).
# animation_future = client.compute(animation_delayed)
animation_future = [client.compute(frame) for frame in tqdm(animation_delayed)]

In [None]:
%%time
# Pull the computed frames back into this process; this overwrites the `a`
# from the local "sync" run above.
a = client.gather(animation_future)

In [None]:
# Write the mosaics next to the acquisition they were rendered from, so that
# switching `nd2_filename` in the Config section switches the output directory
# too. For the current file this is
# /home/jqs1/scratch/jqs1/microscopy/220704/mosaics/, the path that used to be
# hard-coded here.
video_dir = Path(nd2_filename).parent / "mosaics"
# parents=True: do not fail when an intermediate directory is missing.
video_dir.mkdir(parents=True, exist_ok=True)

In [None]:
%%time
# Encode the animation to an mp4 at 30 frames per second. The futures are
# passed directly; the commented `a` is the gathered-frames alternative.
# NOTE(review): this relies on mosaic.export_video accepting futures --
# confirm, otherwise pass `a`.
mosaic.export_video(
    # a,
    animation_future,
    video_dir / "230425zoom_slow.mp4",
    fps=30,
)

In [None]:
# Restart the cluster workers; the trailing semicolon suppresses the cell's
# output. NOTE(review): `animation_future` should be treated as stale after
# this -- confirm before re-running the export cell without recomputing.
client.restart();

In [None]:
# Embed a rendered video in the notebook.
# NOTE(review): this filename differs from the one written by the export cell
# above ("230425zoom_slow.mp4"); confirm that showing the older file is
# intended.
Video(
    video_dir / "230423zoomtest_linescaling.mp4",
    embed=True,
)

# Grid debugging

In [None]:
from paulssonlab.image_analysis.workflow import (
    get_filename_image_limits,
    get_nd2_frame,
    get_position_metadata,
    parse_nd2_metadata,
)

In [None]:
# Open the acquisition and parse its metadata, keyed by filename, then derive
# the per-position table from that mapping.
nd2 = nd2reader.ND2Reader(nd2_filename)
nd2s = {nd2_filename: nd2}
metadata = {path: parse_nd2_metadata(reader) for path, reader in nd2s.items()}
positions = get_position_metadata(metadata)

In [None]:
# Re-derive y_idx / x_idx from the position names (see positions_func).
positions = positions_func(positions)

In [None]:
positions

In [None]:
# All positions in the grid row with the smallest y_idx.
first_row = positions[positions["y_idx"] == positions["y_idx"].min()]

In [None]:
first_row

In [None]:
# End points of that row: the entries with the smallest and largest x_idx.
# squeeze() collapses a one-row selection into a Series so that ["x"] / ["y"]
# below yield scalars.
upper_left = first_row[first_row["x_idx"] == first_row["x_idx"].min()].squeeze()

In [None]:
upper_right = first_row[first_row["x_idx"] == first_row["x_idx"].max()].squeeze()

In [None]:
upper_right["y"]

In [None]:
# Angle (radians) of the vector from upper_left to upper_right, i.e. the tilt
# of the first grid row relative to the x axis. `z` is not used again in this
# notebook.
z = np.arctan2(upper_right["y"] - upper_left["y"], upper_right["x"] - upper_left["x"])

In [None]:
# Same measurement in degrees, but for the opposite direction (upper_right to
# upper_left), so it differs from `z` by 180 degrees.
# NOTE(review): confirm which direction matches the sign convention of the
# `rotation` argument used in the Mosaic section.
np.rad2deg(
    np.arctan2(upper_left["y"] - upper_right["y"], upper_left["x"] - upper_right["x"])
)

In [None]:
positions

# Overlay

In [None]:
# %%pyinstrument
# Stand-alone test of the overlay on a uniform mid-grey canvas.
# Other canvas sizes tried: (3840, 2160) and (384, 216).
output_dims = (1024, 1024)
# output_dims is reversed to give the array's leading two axes; the last axis
# holds 3 channels.
frame = np.full((output_dims[1], output_dims[0], 3), 0.5, dtype=np.float32)
img = mosaic.square_overlay(
    frame,
    0,
    20,
    font=font,
    min_scale=80,
    max_scale=0.1,
    min_n=0,
    max_n=5,
    min_width=0.5,
    max_width=0.9,
)
display_image(img)

In [None]:
import cv2
from skimage.transform import SimilarityTransform, warp

In [None]:
# b=cv2.resize(img, img.shape[:-1], fx=0.2, fy=0.2, interpolation=cv2.INTER_AREA)
# transform = SimilarityTransform(scale=0.2)
transform = mosaic.scale_around_center(0.2, np.array(output_dims) / 2)
b = cv2.warpAffine(
    img,
    transform.params[:2, :],
    img.shape[:-1],
    # flags=cv2.INTER_AREA + cv2.WARP_INVERSE_MAP,
    flags=(cv2.INTER_LANCZOS4),
)

In [None]:
b.shape

In [None]:
display_image(b)

In [None]:
# transform = mosaic.scale_around_center(0.2, np.array(output_dims)/2)
# b = cv2.warpAffine(img, transform.params[:2, :],
#                     img.shape[:-1],
#                     # flags=cv2.INTER_AREA + cv2.WARP_INVERSE_MAP,
#                     flags=(cv2.INTER_LANCZOS4))
c = cv2.resize(
    img,
    np.ceil(np.array(img.shape[:-1]) * 0.2).astype(np.int32),
    interpolation=cv2.INTER_AREA,
)

In [None]:
display_image(c)

In [None]:
transform = SimilarityTransform(
    translation=np.array(output_dims) / 2
)  # mosaic.scale_around_center(0.2, np.array(output_dims)/2)
d = cv2.warpAffine(
    c,
    transform.params[:2, :],
    img.shape[:-1],
    # flags=cv2.INTER_AREA + cv2.WARP_INVERSE_MAP,
    flags=(cv2.INTER_LANCZOS4),
)

In [None]:
display_image(d)