# Imports

In [None]:
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path

import dask
import distributed
import h5py
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import scipy
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from tqdm.auto import tqdm, trange

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from paulssonlab.image_analysis import *
from paulssonlab.util.ui import display_image

In [None]:
%load_ext pyinstrument

In [None]:
hv.extension("bokeh")

# Config

In [None]:
# Create a SLURM-backed Dask cluster and connect a distributed client to it.
# Workers are requested separately via cluster.scale / cluster.adapt in the
# cells that follow.
cluster = SLURMCluster(
    queue="short",
    walltime="06:00:00",
    memory="2GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
cluster

In [None]:
cluster.scale(1)

In [None]:
cluster.adapt(maximum=20)

# Trench detection

In [None]:
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
# filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2" #v=8,t=10
filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"

In [None]:
nd2 = nd2reader.ND2Reader(filename)

In [None]:
nd2.metadata["channels"]

In [None]:
nd2.sizes

In [None]:
img = nd2.get_frame_2D(v=8, c=0, t=0)

In [None]:
display_image(img, scale=0.9)

In [None]:
display_image(img, scale=0.99)

# Radial distortion correction

In [None]:
k1 = 8.947368421052635e-10
# k1 = 2e-8

In [None]:
%%time
img_t = image.correct_radial_distortion(img, k1=k1)

In [None]:
%%time
diag = util.tree()
trenches, info = trench_detection.find_trenches(
    img_t, width_to_pitch_ratio=1.4 / 3.5, join_info=False, diagnostics=diag
)

In [None]:
info

In [None]:
diag["bboxes"]

# Radial distortion optimization

In [None]:
# Sweep candidate radial-distortion coefficients. For each candidate, correct
# the image, compute its Hough line-intensity transform, and record the
# per-column standard deviation of the differenced accumulator together with
# the angle at which the (optionally smoothed) standard deviation peaks.
#
# The loop variable is deliberately not called `k1`: reusing that name would
# overwrite the calibrated `k1` defined earlier in the notebook with the last
# candidate of the sweep.
k1s = np.linspace(-1e-9, 2e-9, 20)
theta_grid = np.linspace(-np.pi / 50, np.pi / 50, 400)  # loop-invariant
smooth = 4  # sigma passed to gaussian_filter1d; a falsy value disables smoothing
res = []
for k1_candidate in tqdm(k1s):
    img_corrected = image.correct_radial_distortion(img, k1=k1_candidate)
    h, theta, rho = trench_detection.hough.hough_line_intensity(
        img_corrected, theta=theta_grid
    )
    diff_h = np.diff(h.astype(np.int_), axis=1)  # TODO: is diff necessary??
    diff_h_std = diff_h.std(axis=0)  # / diff_h.max(axis=0)
    if smooth:
        diff_h_std_smoothed = scipy.ndimage.gaussian_filter1d(diff_h_std, smooth)
    else:
        diff_h_std_smoothed = diff_h_std
    theta_idx = diff_h_std_smoothed.argmax()
    diff_h_std_max = diff_h_std_smoothed[theta_idx]
    angle = theta[theta_idx]
    res.append(
        dict(
            k1=k1_candidate,
            h=h,
            diff_h=diff_h,
            diff_h_std=diff_h_std,
            diff_h_std_smoothed=diff_h_std_smoothed,
            angle=angle,
            theta_idx=theta_idx,
            img_corrected=img_corrected,
        )
    )

In [None]:
plt.figure(figsize=(20, 20))
for idx in range(len(res)):
    plt.plot(res[idx]["diff_h_std"][120:200], label=idx)
plt.legend()
plt.plot(res[12]["diff_h_std"][120:200], lw=4)

In [None]:
plt.figure(figsize=(20, 20))
idx = 12
plt.plot(res[idx - 2]["diff_h_std"][120:200])
plt.plot(res[idx]["diff_h_std"][120:200], lw=4)
plt.plot(res[idx + 2]["diff_h_std"][120:200])

In [None]:
res[12]["k1"]

In [None]:
display_image(res[-3]["img_corrected"], scale=True)

# FISH correction

In [None]:
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
# filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2" #v=8,t=10
filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"
fish_filename = Path(filename).parent / "FISH/real_run"

In [None]:
k1 = 8.947368421052635e-10

In [None]:
def calibrate_image(img, k1=0):
    """Convert ``img`` to float32, then apply radial distortion correction.

    ``k1`` is forwarded unchanged to ``image.correct_radial_distortion``.
    """
    as_float32 = skimage.img_as_float32(img)
    return image.correct_radial_distortion(as_float32, k1=k1)

In [None]:
%%time
# NOTE: this rebinds the name `delayed` (imported from dask at the top of the
# notebook) to whatever util.get_delayed(True) returns.
delayed = util.get_delayed(True)
fish_frames = {}  # (t, channel) -> float32 frame wrapped by `delayed`
fish_crops = {}
fish_channels = set()
fish_timepoints = set()
for msg in readers.send_eaton_fish(
    fish_filename,
    r"fov=(?P<v>\d+)_config=(?P<c>\w+)_t=(?P<t>\d+)",
    slices=dict(t=None, v=[8]),
    delayed=delayed,
):
    fish_img = msg["image"]
    # Radial-distortion calibration is currently switched off here; the
    # alternative is delayed(calibrate_image)(fish_img, k1=k1).
    fish_img_corrected = delayed(skimage.img_as_float32)(fish_img)
    t, channel = msg["metadata"]["t"], msg["metadata"]["channel"]
    fish_channels.add(channel)
    fish_timepoints.add(t)
    fish_frames[t, channel] = fish_img_corrected
# sorted() already returns a list, so the sets become ordered lists here.
fish_channels = sorted(fish_channels)
fish_timepoints = sorted(fish_timepoints)

In [None]:
fish_frames0 = dask.compute(fish_frames)[0]

In [None]:
%%time
# Assemble one array per channel, indexed by position in fish_timepoints along
# the first axis, from the computed frames in fish_frames0.
stacks = {}
for channel in fish_channels:
    for timepoint_idx, timepoint in enumerate(fish_timepoints):
        img = fish_frames0[(timepoint, channel)]
        if channel not in stacks:
            # Allocate lazily, once the frame shape is known from the first
            # frame of this channel. The stack starts out filled with NaN.
            stacks[channel] = np.full((len(fish_timepoints), *img.shape), np.nan)
        stacks[channel][timepoint_idx, :, :] = img

In [None]:
stacks["GFP"].shape

In [None]:
stacks["GFP"][3:9].max(axis=0)

In [None]:
stacks["GFP"][3:9].max(axis=0)

In [None]:
display_image(stacks["GFP"][3:9].max(axis=0), scale=0.99)

In [None]:
display_image(stacks["GFP"][3:9].min(axis=0), scale=0.99)

In [None]:
display_image(
    stacks["GFP"][3:9].max(axis=0) - stacks["GFP"][3:9].min(axis=0), scale=0.99
)

In [None]:
display_image(stacks["GFP"][:9].max(axis=0) - stacks["GFP"][:9].min(axis=0), scale=0.99)

In [None]:
info

# Drift correction

In [None]:
imgs = {t: nd2.get_frame_2D(v=8, c=0, t=t)[:500, :500] for t in trange(225)}

In [None]:
hv.HoloMap({k: ui.RevImage(v) for k, v in imgs.items()})

In [None]:
%%time
diag = util.tree()
trenches, info = trench_detection.find_trenches(
    imgs[0], width_to_pitch_ratio=1.4 / 3.5, join_info=False, diagnostics=diag
)

In [None]:
diag["bboxes"]

In [None]:
%%time
diag2 = util.tree()
trenches2, info2 = trench_detection.find_trenches(
    imgs[20], width_to_pitch_ratio=2 / 3.5, join_info=False, diagnostics=diag2
)

In [None]:
diag2["bboxes"]

In [None]:
%%time
diag3 = util.tree()
trenches3, info3 = trench_detection.find_trenches(
    imgs[210], width_to_pitch_ratio=1.4 / 3.5, join_info=False, diagnostics=diag3
)

In [None]:
diag3["bboxes"]

In [None]:
ui.RevImage(imgs[210]) * trench_detection.plot_trenches(trenches2)

In [None]:
ui.RevImage(imgs[210]) * trench_detection.plot_trenches(trenches)

In [None]:
trench_detection.plot_trenches(trenches2).opts(
    hv.opts.Rectangles(line_color="blue")
) * trench_detection.plot_trenches(trenches3)

In [None]:
# idx=59
idx = 80
crop = get_crop(imgs[idx], trenches2, 10)
pts = trench_cell_endpoints(crop)
ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(color="red", size=4)
# hv.Points(pts)

In [None]:
plots = {}
for t in trange(60, 90, 1):
    crop = get_crop(imgs[t], trenches2, 11)
    pts = trench_cell_endpoints(crop)
    plots[t] = ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(
        color="red", size=4
    )

In [None]:
hv.HoloMap(plots)

In [None]:
import warnings

# Column-name prefixes that carry coordinates; "_x"/"_y" suffixes are appended
# by the helpers below.
TRENCH_COORDINATE_COLUMNS = {"top", "bottom", "ul", "lr"}


def trench_cell_endpoints(img, sigma=2, k=2, min_height=0.3, margin_factor=1):
    """Locate two endpoint features in a crop from its row-intensity profile.

    The row-mean profile of ``img`` is Gaussian-smoothed (``sigma``) and
    differentiated. ``y1`` is the first peak of the derivative and ``y2`` the
    last peak of the negated derivative, keeping only peaks whose height is at
    least ``min_height`` times the derivative's extreme value. For each end,
    an x coordinate is computed as a centroid of column indices weighted by
    ``(pixel * row_mean) ** k`` over the rows from that end of the crop up to
    the peak plus a margin. The margin is ``margin_factor`` times the mean of
    the two peak widths, rounded up.

    Returns:
        A 2x2 array ``[[x1, y1], [x2, y2]]``.

    Raises:
        IndexError: if no qualifying peak is found for either end.
    """
    img = skimage.img_as_float(img)
    profile = img.mean(axis=1)
    smoothed_profile = scipy.ndimage.gaussian_filter1d(profile, sigma)
    grad = misc.holoborodko_diff.holo_diff(1, smoothed_profile)

    peak_warning = scipy.signal._peak_finding_utils.PeakPropertyWarning
    with warnings.catch_warnings(action="ignore", category=peak_warning):
        rising, rising_props = scipy.signal.find_peaks(
            grad, height=min_height * grad.max(), width=(None, None)
        )
        falling, falling_props = scipy.signal.find_peaks(
            -grad, height=-min_height * grad.min(), width=(None, None)
        )

    y1 = rising[0]
    y2 = falling[-1]
    width_sum = rising_props["widths"][0] + falling_props["widths"][-1]
    margin = int(np.ceil(margin_factor * width_sum / 2))

    # Row windows: from the top of the crop to just past y1, and from just
    # before y2 to the bottom of the crop, clamped to the crop bounds.
    head_rows = slice(None, min(y1 + 1 + margin, img.shape[0]))
    tail_rows = slice(max(y2 - margin, 0), None)
    column_index = np.arange(img.shape[1])[np.newaxis, :]

    def weighted_x(rows):
        # Centroid of column indices weighted by (pixel * row mean) ** k.
        weighted = (img[rows, :] * profile[rows, np.newaxis]) ** k
        return (column_index * weighted).sum() / weighted.sum()

    x1 = weighted_x(head_rows)
    x2 = weighted_x(tail_rows)
    return np.array([[x1, y1], [x2, y2]])


def get_crop(img, trenches, trench_idx):
    """Return the sub-image of ``img`` for the trench at position ``trench_idx``.

    The bounding box is read positionally from the ``ul_x``, ``ul_y``,
    ``lr_x`` and ``lr_y`` columns of ``trenches``; the lower-right corner is
    inclusive.
    """
    left, top, right, bottom = (
        trenches[column].values[trench_idx]
        for column in ("ul_x", "ul_y", "lr_x", "lr_y")
    )
    return img[top : bottom + 1, left : right + 1]


def _coordinate_columns(columns):
    """Return ``(x_columns, y_columns)``: the coordinate names found in ``columns``.

    Candidate names are each prefix in ``TRENCH_COORDINATE_COLUMNS`` followed
    by ``_x`` or ``_y``; only candidates present in ``columns`` are returned.
    """

    def present(axis):
        candidates = {f"{prefix}_{axis}" for prefix in TRENCH_COORDINATE_COLUMNS}
        return candidates & set(columns)

    return present("x"), present("y")


def filter_trenches(trenches, image_limits):
    """Keep only the rows whose coordinate columns all lie within the limits.

    ``image_limits[0]`` and ``image_limits[1]`` are unpacked as the arguments
    of ``Series.between`` for the x and y coordinate columns respectively.
    """
    x_lim, y_lim = image_limits[0], image_limits[1]
    cols_x, cols_y = _coordinate_columns(trenches.columns)
    inside_x = [trenches[col].between(*x_lim) for col in cols_x]
    inside_y = [trenches[col].between(*y_lim) for col in cols_y]
    keep = np.logical_and.reduce(inside_x) & np.logical_and.reduce(inside_y)
    return trenches[keep]


def shift_trenches(trenches, shift):
    """Return a copy of ``trenches`` with coordinate columns offset by ``shift``.

    ``shift[0]`` is added to every x coordinate column and ``shift[1]`` to
    every y coordinate column; other columns are left as they are.
    """
    cols_x, cols_y = _coordinate_columns(trenches.columns)

    def offset(cols, axis):
        return {col: trenches[col].values + shift[axis] for col in cols}

    return trenches.assign(**offset(cols_x, 0), **offset(cols_y, 1))


def find_trench_drift(
    img1,
    img2,
    trenches,
    tolerance=1,
    feature_func=trench_cell_endpoints,
    diagnostics=None,
):
    """Pair up per-trench features detected in two images.

    ``feature_func`` is applied to every trench crop of ``img1`` and ``img2``
    (trenches that do not fit inside the limits derived from ``img1.shape``
    are dropped first). For every trench present in both images, the i-th
    feature of ``img1`` is paired with the i-th feature of ``img2``.

    ``feature_func`` may return ``None`` for a crop to signal that no features
    were found; such trenches contribute no pairs.

    ``tolerance`` and ``diagnostics`` are accepted but currently unused, and
    the trench shift applied to ``img2`` is fixed at zero (single pass).

    Returns:
        A list of ``[feature_in_img1, feature_in_img2]`` pairs.
    """
    image_limits = geometry.get_image_limits(img1.shape)

    def extract_features(img, rois):
        # feature_func works on the crop alone; the corner offset `ul` yielded
        # by iter_crops(corner=True) is added afterwards (presumably mapping
        # crop coordinates back to image coordinates). A None result must be
        # kept as None rather than added to `ul`, otherwise the None check in
        # the pairing loop below could never be reached.
        features = {}
        for roi_idx, crop, ul in geometry.iter_crops(img, rois, corner=True):
            roi_features = feature_func(crop)
            if roi_features is not None:
                roi_features = roi_features + ul[np.newaxis, ...]
            features[roi_idx] = roi_features
        return features

    features1 = extract_features(img1, filter_trenches(trenches, image_limits))
    shift = np.array([0, 0], dtype=np.int64)
    features2 = extract_features(
        img2, filter_trenches(shift_trenches(trenches, shift), image_limits)
    )

    plot_lines = []
    for roi_idx in features1.keys() & features2.keys():
        roi_features1 = features1[roi_idx]
        roi_features2 = features2[roi_idx]
        if roi_features1 is None or roi_features2 is None:
            continue
        # zip stops at the shorter of the two feature lists.
        plot_lines.extend(
            [feature1, feature2]
            for feature1, feature2 in zip(roi_features1, roi_features2)
        )
    return plot_lines

In [None]:
%%time
x = find_trench_drift(imgs[20], imgs[40], trenches2)

In [None]:
# Robustly fit a Euclidean transform to the matched feature pairs in `x`.
# np.array(x) has the pair member (first image / second image) on axis 1, so
# swapping axes 0 and 1 and unpacking yields the two point sets that are
# passed to ransac as a data tuple.
model_robust, inliers = skimage.measure.ransac(
    (*np.array(x).swapaxes(0, 1),),
    skimage.transform.EuclideanTransform,
    min_samples=3,
    residual_threshold=2,
    max_trials=100,
)

In [None]:
y = [
    [(*c[0], "red" if inlier else "gray"), (*c[1], "red" if inlier else "gray")]
    for c, inlier in zip(x, inliers)
]

In [None]:
hv.Path(y, vdims=["color"]).opts(color="color", line_width=2)

In [None]:
model_robust.translation

In [None]:
hv.Path(x)

In [None]:
class TranslationTransform(skimage.transform.EuclideanTransform):
    """Euclidean transform whose ``estimate`` fits a translation only."""

    def estimate(self, src, dst):
        """Set the translation to the mean displacement from ``src`` to ``dst``.

        Args:
            src: (N, D) array of source coordinates.
            dst: (N, D) array of destination coordinates.

        Returns:
            ``True``, following the estimator convention of reporting success.
        """
        # Average over points (axis 0) so there is one offset per coordinate.
        # Averaging over axis 1 would instead mix the x and y displacement of
        # each point and produce N values rather than D.
        translation = (np.asarray(dst) - np.asarray(src)).mean(axis=0)
        self.params[0 : self.dimensionality, self.dimensionality] = translation
        return True

In [None]:
plots = {}
for t in trange(225):
    crop = get_crop(imgs[t], trenches2, 5)
    pts = trench_cell_endpoints(crop)
    # plots[t] = ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(
    #     color="red", size=4
    # )
    plots[t] = hv.Curve([(0, t), (2, 4)])

In [None]:
plots = {}
for t in trange(225):
    crop = get_crop(imgs[t], trenches2, 5)
    pts = trench_cell_endpoints(crop)
    plots[t] = ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(
        color="red", size=4
    )

In [None]:
hv.HoloMap(plots)

In [None]:
hv.HoloMap(plots)