# Imports

In [None]:
import glob
import itertools as it
from functools import partial
from pathlib import Path

import dask
import distributed
import h5py
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import skimage
import zarr
from dask_jobqueue import SLURMCluster
from distributed import Client
from IPython.display import Video
from PIL import Image, ImageDraw, ImageFont
from tqdm.auto import tqdm, trange

In [None]:
# Register a global dask progress bar. dask.diagnostics callbacks report on the
# local (single-machine) schedulers, e.g. the scheduler="sync" compute below;
# work submitted through the distributed Client is not shown by this bar.
from dask.diagnostics import ProgressBar

pbar = ProgressBar()
pbar.register()  # global; call pbar.unregister() to turn it off

In [None]:
%load_ext autoreload
%autoreload 2
%load_ext pyinstrument

In [None]:
from paulssonlab.image_analysis import mosaic, workflow
from paulssonlab.image_analysis.ui import display_image

In [None]:
#!micromamba install -y av

# Config

In [None]:
# nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/221227daniel/Experiment.nd2"
# nd2_filename = "/home/jqs1/scratch/jqs1/microscopy/220718/RBS_DEG_library_20x.nd2"
# The 230619 acquisition is stored as several files sharing this prefix; they
# are bundled into one logical input with workflow.SplitFilename.
nd2_glob = "/home/jqs1/scratch/jqs1/microscopy/230619/230619_NAO745_repressilators_split.nd2*"
# glob returns files in filesystem order, so sort for a deterministic order.
# NOTE(review): sorted() is lexicographic; if the part suffixes are numeric and
# go past 9, confirm this is the order SplitFilename expects.
nd2_parts = sorted(glob.glob(nd2_glob))
if not nd2_parts:
    # glob silently returns [] for a typo'd or unmounted path; fail here with a
    # clear message instead of deep inside the ND2 reader.
    raise FileNotFoundError(f"no ND2 files match {nd2_glob!r}")
nd2_filename = workflow.SplitFilename(nd2_parts)

In [None]:
# Open a reader over the (split) ND2 input; used below for metadata and sizes.
nd2 = workflow.get_nd2_reader(nd2_filename)

In [None]:
# Show the channel names stored in the file; the names used in
# default_channels / channel_to_color presumably have to match these.
nd2.metadata["channels"]

In [None]:
# Display color (hex RGB) for each fluorescence channel. The dict's insertion
# order also gives the channel list passed to the mosaic as `channels=`.
channel_to_color = {
    "RFP-EM": "#e22400",
    "YFP-EM": "#faff00",
    "CFP-EM": "#00ffea",
}
default_channels = tuple(channel_to_color)

# Colors used for channel names from earlier experiments:
# "BF": "#ffffff",
# "RFP-PENTA": "#C500BB",  # "#FF5AF6",
# "RFP-Penta": "#C500BB",
# # "RFP-Penta": "#e22400",
# # "YFP-DUAL": "#13FF00",
# "YFP-DUAL": "#FAFF00",
# # "GFP": "#76ba40",
# "Cy5": "#e292fe",
# # "Cy7": "#FF0000"
# # "BFP": "#3a87fd",

In [None]:
# Font for text overlays; only referenced by the commented-out
# mosaic.square_overlay argument in the animation cell. The path is relative,
# so it must resolve from the working directory or PIL's font search paths.
font = ImageFont.truetype("fira/FiraSans-Medium.ttf")

In [None]:
# Let the distributed scheduler retry a task across up to 10 worker failures
# before marking it as failed (presumably to tolerate SLURM workers dying).
dask.config.set({"distributed.scheduler.allowed-failures": 10})

In [None]:
# SLURM-backed dask cluster. Each SLURM job (queue "short", 3 h walltime) gets
# 2 cores and 16GB, split across 2 worker processes. Worker scratch space is
# node-local /tmp; job logs go to /home/jqs1/log.
cluster = SLURMCluster(
    queue="short",
    walltime="03:00:00",
    memory="16GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=2,
    processes=2,
)
client = Client(cluster)

In [None]:
cluster  # show the cluster's notebook repr (dashboard link, workers)

In [None]:
# Request a fixed 20 workers (with processes=2 per job, that is 10 SLURM jobs).
cluster.scale(20)

In [None]:
# Switch to adaptive scaling capped at 10 workers. If run after the scale(20)
# cell, this replaces the fixed size with the adaptive policy.
cluster.adapt(maximum=10)

# Mosaic

In [None]:
%%time
# Per-channel intensity extrema, used to pick the display ranges passed to
# get_scaling_funcs. Computed for the channels this notebook actually renders;
# the old commented-out call used "YFP-DUAL"/"RFP-Penta", which are not in
# default_channels. Without this line the next cell raises NameError.
extrema = mosaic.get_intensity_extrema(nd2, default_channels)

In [None]:
extrema  # inspect; presumably used to choose the (low, high) ranges in scaling_funcs

In [None]:
# Hand-chosen (low, high) raw-intensity range per channel. get_scaling_funcs
# presumably builds, for each channel, a function that rescales pixel values
# from this range for display.
scaling_funcs = mosaic.get_scaling_funcs(
    {"RFP-EM": (400, 2_000), "YFP-EM": (300, 4_000), "CFP-EM": (600, 1_500)}
)

In [None]:
def positions_func(positions):
    """Replace x_idx/y_idx with grid indices parsed from each position name.

    Each ``position_name`` is expected to be ``"<y>.<x>"``: two dot-separated
    integers, y first. The existing ``x_idx``/``y_idx`` columns are dropped and
    the parsed integer ``y_idx`` and ``x_idx`` columns are appended, in that
    order.

    Parameters
    ----------
    positions : pandas.DataFrame
        Must contain ``position_name``, ``x_idx`` and ``y_idx`` columns.

    Returns
    -------
    pandas.DataFrame
        A new frame aligned on ``positions.index``; the input is not modified.

    Raises
    ------
    ValueError
        If a name does not split into exactly two parts, or a part is not an
        integer.
    """
    # Vectorized split instead of a row-wise apply that builds a pd.Series
    # per row (slow for large position tables).
    parts = positions["position_name"].str.split(".")
    # NaN names give NaN lengths, so they are flagged as malformed too.
    malformed = parts.str.len() != 2
    if malformed.any():
        examples = positions.loc[malformed, "position_name"].head(3).tolist()
        raise ValueError(
            f"position_name must look like '<y>.<x>', got e.g. {examples!r}"
        )
    # Build from a plain list so an empty table still yields both columns.
    grid_idx = pd.DataFrame(
        parts.tolist(), index=positions.index, columns=["y_idx", "x_idx"]
    ).astype("int64")
    return positions.drop(["x_idx", "y_idx"], axis=1).join(grid_idx)

In [None]:
# dark = skimage.io.imread(
#     "/home/jqs1/scratch/jqs1/microscopy/221227daniel/40x_DarkImage.tiff"
# )
# flats = {
#     "mCherry": skimage.io.imread(
#         "/home/jqs1/scratch/jqs1/microscopy/221227daniel/mCherry_20x_Ph2_flatfield.tiff"
#     )
# }

In [None]:
# Position table and input frame dimensions read from the ND2 metadata;
# passed to mosaic_animate_scale as positions= / input_dims=.
positions, input_dims = mosaic.get_nd2_metadata(nd2_filename)

In [None]:
# Animation schedule. `scale` and `timepoints` are both passed to
# mosaic.mosaic_animate_scale (presumably iterated in lockstep, one zoom value
# per rendered frame). The live settings use only timepoint 0 at a constant
# zoom of max_zoom.
# NOTE(review): min_zoom (5) > max_zoom (1), so larger scale values presumably
# mean a more zoomed-out view — confirm against mosaic.
min_zoom = 5
max_zoom = 1

# Number of timepoints; currently only used by the commented-out schedules.
num_t = nd2.sizes["t"]
# scale = np.concatenate(
#     (
#         np.repeat(min_zoom, 2 * num_t),
#         np.geomspace(min_zoom, max_zoom, 2 * num_t),
#         np.repeat(max_zoom, num_t // 2),
#         np.geomspace(min_zoom, max_zoom, num_t)[::-1],
#         np.repeat(min_zoom, num_t // 2),
#     )
# )
# timepoints = it.cycle(np.repeat(np.arange(num_t), 2))
scale = it.repeat(max_zoom)  # infinite iterator: same zoom for every frame
timepoints = [0]  # ,20,40,80,120,160,195]

# timepoints = np.arange(num_t)

In [None]:
%%time
# Toggle for scattering dark/flat-field images to the cluster (only used by
# the commented-out correction code). Named use_distributed rather than
# `distributed` so it does not shadow the `distributed` module imported above.
use_distributed = False
# if use_distributed:
#     dark_delayed = client.scatter(dark, broadcast=True)
#     flats_delayed = {k: client.scatter(v, broadcast=True) for k, v in flats.items()}
# else:
#     dark_delayed = dark
#     flats_delayed = flats
# Viewport offset (units as mosaic_animate_scale expects, presumably pixels)
# and output frame size (3840x2160).
offset = (-1000, 870)
output_dims = (3840, 2160)
# Build, but do not compute, a sequence of dask-delayed frames; later cells
# compute them locally, on the cluster, or straight into a zarr store.
animation_delayed = mosaic.mosaic_animate_scale(
    partial(workflow.get_nd2_frame, nd2_filename),
    scale=scale,
    timepoints=timepoints,
    positions=positions,
    scaling_funcs=scaling_funcs,
    offset=offset,
    # rotation=np.deg2rad(-0.15),  # TODO: necessary?
    channels=default_channels,
    channel_to_color=channel_to_color,
    # overlay_only=True,
    # overlay_func=partial(
    #     mosaic.square_overlay,
    #     min_scale=40,
    #     min_n=0,
    #     min_width=0.25,
    #     max_scale=0.05,
    #     max_n=7,
    #     max_width=0.9,
    #     n_range=(4, 8),
    #     font=font,
    # ),
    # dark=dark_delayed,
    # flats=flats_delayed,
    # positions_func=positions_func,
    input_dims=input_dims,
    output_dims=output_dims,
    # output_dims=(1024, 1024),
    # output_dims=(1024, 512),
    delayed=True,
)

In [None]:
%%time
# Render every frame in this process with dask's synchronous scheduler (useful
# for debugging/profiling). dask.compute returns a tuple, hence the [0].
a = dask.compute(animation_delayed, scheduler="sync")[0]

In [None]:
display_image(a[0], downsample=4)  # preview the first frame, downsampled by 4

In [None]:
display_image(a[0], downsample=4)  # same first-frame preview (repeated cell)

In [None]:
display_image(a[0], downsample=4)  # same first-frame preview (repeated cell)

In [None]:
display_image(a[0], downsample=4)  # same first-frame preview (repeated cell)

In [None]:
display_image(a[0], downsample=4)  # same first-frame preview (repeated cell)

In [None]:
# Output directory for the zarr frame store and the exported video.
# video_dir = Path("/home/jqs1/scratch/jqs1/microscopy/220718/mosaics/")
video_dir = Path("/home/jqs1/scratch/jqs1/microscopy/230619/")
# parents=True so a fresh output location with missing intermediate
# directories is created instead of raising FileNotFoundError.
video_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# zarr store holding the rendered frames, indexed by frame number (written by
# mosaic.write_to_zarr, read back for display and video export).
zarr_filename = video_dir / "230619_crop.zarr"

In [None]:
# Read-only handle on the rendered frames. mode="r" requires the store to
# exist already, so the write_to_zarr cell must have run (in this or an earlier
# session) before this cell.
out = zarr.open_array(zarr_filename, mode="r")

In [None]:
%%time
# Submit one cluster task per frame. Each task computes its delayed frame and
# writes it into the zarr store at index frame_num. The total frame count is
# also passed, presumably so write_to_zarr can create/size the array.
# Result: a list of futures. tqdm tracks submission only, not completion.
animation_future = [
    client.compute(
        dask.delayed(mosaic.write_to_zarr)(
            zarr_filename, frame, frame_num, len(animation_delayed)
        )
    )
    for frame_num, frame in enumerate(tqdm(animation_delayed))
]

In [None]:
# Spot-check one stored frame at full resolution.
# NOTE(review): index 700 assumes the store holds at least 701 frames. If it
# was written with the current single-timepoint schedule, this is out of range.
display_image(out[700], downsample=1)

In [None]:
%%time
# Alternative to the zarr path: compute every frame on the cluster and keep the
# results on the workers as futures (overwrites animation_future).
# animation_future = client.compute(animation_delayed)
animation_future = [client.compute(frame) for frame in tqdm(animation_delayed)]

In [None]:
%%time
# Pull the results of whatever animation_future currently holds into this
# process: rendered frames from the cell above, or write_to_zarr's return
# values if the zarr cell ran last.
a = client.gather(animation_future)

In [None]:
%%time
# Encode the frames read from the zarr store (`out`) to an MP4 at 30 fps. The
# commented arguments are alternative frame sources (the in-memory list `a`
# or the list of futures).
mosaic.export_video(
    # a,
    # animation_future,
    out,
    video_dir / "230619_crop.mp4",
    fps=30,
)

In [None]:
# Restart all cluster workers, dropping their memory (including any results
# held for outstanding futures). The trailing ";" suppresses the cell output.
client.restart();

In [None]:
# Embed the MP4 written by the mosaic.export_video cell (same file name) in the
# notebook; embed=True stores the video data in the notebook itself.
Video(
    video_dir / "230619_crop.mp4",
    embed=True,
)