# Imports

In [None]:
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path

import dask
import distributed
import h5py
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import scipy
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.auto import tqdm, trange

# Shorthand for pandas MultiIndex slicing (used as IDX[...] inside .loc).
IDX = pd.IndexSlice

In [None]:
# Reload edited modules automatically before running each cell.
%load_ext autoreload
%autoreload 2

In [None]:
# The star import presumably supplies the project modules used below
# (util, readers, image, ui, trench_detection, drift, geometry, ...) — TODO confirm.
from paulssonlab.image_analysis import *
from paulssonlab.image_analysis.ui import display_image

In [None]:
%load_ext pyinstrument

In [None]:
# Render holoviews plots with the bokeh backend.
hv.extension("bokeh")

# Config

In [None]:
# Dask cluster on SLURM: each job is one worker with 1 core, 1 process and 2 GB.
cluster = SLURMCluster(
    queue="short",
    walltime="06:00:00",
    memory="2GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
cluster

In [None]:
# Start with a single job.
cluster.scale(1)

In [None]:
# Let dask scale adaptively, up to 20 jobs.
cluster.adapt(maximum=20)

# Trench detection

In [None]:
# Commented-out paths are earlier datasets; the trailing "#v=/#t=" notes
# presumably record which FOV/timepoint was used for each — TODO confirm.
# filename = "/home/jqs1/scratch/jqs1/microscopy/210511/RBS_ramp.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
# filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2" #v=8,t=10
filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"

In [None]:
nd2 = nd2reader.ND2Reader(filename)

# FISH correction

In [None]:
# Same dataset as the trench-detection cell; the FISH images are read from
# the FISH/real_run directory next to the .nd2 file.
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
# filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2" #v=8,t=10
filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"
fish_filename = Path(filename).parent / "FISH/real_run"

In [None]:
# Radial distortion coefficient, passed as k1 to image.correct_radial_distortion.
k1 = 8.947368421052635e-10

In [None]:
def calibrate_image(img, k1=0):
    """Convert ``img`` to float32 and apply radial distortion correction.

    ``k1`` is forwarded unchanged to ``image.correct_radial_distortion``
    (a default of 0 presumably means no correction — TODO confirm).
    """
    as_float = skimage.img_as_float32(img)
    return image.correct_radial_distortion(as_float, k1=k1)

In [None]:
%%time
# NOTE(review): rebinding `delayed` shadows `from dask import delayed` imported
# at the top (presumably util.get_delayed(True) returns an equivalent — TODO confirm).
delayed = util.get_delayed(True)
fish_frames = {}
fish_crops = {}
fish_channels = set()
fish_timepoints = set()
# Lazily load the FISH frames for FOV v=8 (t=None presumably means all
# timepoints); results are keyed by (t, channel).
for msg in readers.send_eaton_fish(
    fish_filename,
    r"fov=(?P<v>\d+)_config=(?P<c>\w+)_t=(?P<t>\d+)",
    slices=dict(t=None, v=[8]),
    delayed=delayed,
):
    # print(msg["metadata"],msg["image"].shape)
    fish_img = msg["image"]
    # Distortion correction is currently disabled; frames are only converted to float32.
    # fish_img_corrected = delayed(calibrate_image)(fish_img, k1=k1)
    fish_img_corrected = delayed(skimage.img_as_float32)(fish_img)
    t = msg["metadata"]["t"]
    channel = msg["metadata"]["channel"]
    fish_channels.add(channel)
    fish_timepoints.add(t)
    fish_frames[(t, channel)] = fish_img_corrected
fish_channels = list(sorted(fish_channels))
fish_timepoints = list(sorted(fish_timepoints))

In [None]:
# Materialize every delayed frame in one dask.compute call (it returns a tuple).
fish_frames0 = dask.compute(fish_frames)[0]

In [None]:
%%time
# Build one (n_timepoints, *frame_shape) array per channel, ordered by sorted timepoint.
# NOTE(review): raises KeyError if any channel is missing a timepoint.
stacks = {}
for channel in fish_channels:
    for timepoint_idx, timepoint in enumerate(fish_timepoints):
        img = fish_frames0[(timepoint, channel)]
        if channel not in stacks:
            stacks[channel] = np.full((len(fish_timepoints), *img.shape), np.nan)
        stacks[channel][timepoint_idx, :, :] = img

In [None]:
stacks["GFP"].shape

In [None]:
# Max projection over timepoint slots 3..8.
stacks["GFP"][3:9].max(axis=0)

In [None]:
# NOTE(review): duplicate of the previous cell.
stacks["GFP"][3:9].max(axis=0)

In [None]:
display_image(stacks["GFP"][3:9].max(axis=0), scale=0.99)

In [None]:
display_image(stacks["GFP"][3:9].min(axis=0), scale=0.99)

In [None]:
# Per-pixel range (max - min) over timepoint slots 3..8.
display_image(
    stacks["GFP"][3:9].max(axis=0) - stacks["GFP"][3:9].min(axis=0), scale=0.99
)

In [None]:
# Per-pixel range over timepoint slots 0..8.
display_image(stacks["GFP"][:9].max(axis=0) - stacks["GFP"][:9].min(axis=0), scale=0.99)

In [None]:
# NOTE(review): `info` is first assigned by find_trenches in the Drift
# correction section below; running top-to-bottom raises NameError here.
info

# Drift correction

In [None]:
# First 225 timepoints of FOV 8, channel 0, cropped to the top-left 500x500 pixels.
imgs = {t: nd2.get_frame_2D(v=8, c=0, t=t)[:500, :500] for t in trange(225)}

In [None]:
hv.HoloMap({k: ui.RevImage(v) for k, v in imgs.items()})

In [None]:
%%time
# Detect trenches at t=0 (compared below with t=20 and t=210).
diag = util.tree()
trenches, info = trench_detection.find_trenches(
    imgs[0], width_to_pitch_ratio=1.4 / 3.5, join_info=False, diagnostics=diag
)

In [None]:
diag["bboxes"]

In [None]:
%%time
diag2 = util.tree()
trenches2, info2 = trench_detection.find_trenches(
    imgs[20], width_to_pitch_ratio=1.4 / 3.5, join_info=False, diagnostics=diag2
)

In [None]:
diag2["bboxes"]

In [None]:
%%time
diag3 = util.tree()
trenches3, info3 = trench_detection.find_trenches(
    imgs[210], width_to_pitch_ratio=1.4 / 3.5, join_info=False, diagnostics=diag3
)

In [None]:
diag3["bboxes"]

In [None]:
# Overlay the t=20 trenches on the t=210 frame to eyeball the drift.
ui.RevImage(imgs[210]) * trench_detection.plot_trenches(trenches2)

In [None]:
ui.RevImage(imgs[210]) * trench_detection.plot_trenches(trenches)

In [None]:
# Compare trench sets directly: t=20 in blue vs t=210 in the default color.
trench_detection.plot_trenches(trenches2).opts(
    hv.opts.Rectangles(line_color="blue")
) * trench_detection.plot_trenches(trenches3)

In [None]:
# Crop one trench (index 12 of the t=20 set) every 3 frames from t=70 to t=120
# and overlay its detected cell endpoints. get_crop / trench_cell_endpoints are
# presumably provided by the star import — TODO confirm.
plots = {}
t_min = 70
t_delta = 50
for t in range(t_min, t_min + t_delta + 1, 3):
    # for t in [t_min, t_min+t_delta+1]:
    crop = get_crop(imgs[t], trenches2, 12)
    pts = trench_cell_endpoints(crop)
    plots[t] = ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(
        color="red", size=4
    )
hv.HoloMap(plots)

In [None]:
# idx=59
idx = 100
crop = get_crop(imgs[idx], trenches2, 12)
pts = trench_cell_endpoints(crop)
ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(color="red", size=4)
# hv.Points(pts)

In [None]:
%%time
# Estimate drift between t=20 and t=101 using the trenches detected at t=0.
diag_d = util.tree()
shift = find_trench_drift(imgs[20], imgs[101], trenches, diagnostics=diag_d)

In [None]:
diag_d["features1"] * diag_d["features2"].opts(color="red")

In [None]:
diag_d["correspondences"]

In [None]:
# NOTE(review): `pts` is computed but unused here; plots[t] is a placeholder
# curve (this cell looks like a timing/debug variant of the next one).
plots = {}
for t in trange(225):
    crop = get_crop(imgs[t], trenches2, 5)
    pts = trench_cell_endpoints(crop)
    # plots[t] = ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(
    #     color="red", size=4
    # )
    plots[t] = hv.Curve([(0, t), (2, 4)])

In [None]:
plots = {}
for t in trange(225):
    crop = get_crop(imgs[t], trenches2, 5)
    pts = trench_cell_endpoints(crop)
    plots[t] = ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(
        color="red", size=4
    )

In [None]:
%%time
# Full-frame, distortion-corrected images at t=70 and t=110.
k1 = 8.947368421052635e-10
img1 = image.correct_radial_distortion(nd2.get_frame_2D(v=8, t=70, c=0), k1=k1)
img2 = image.correct_radial_distortion(nd2.get_frame_2D(v=8, t=110, c=0), k1=k1)

In [None]:
%%time
# NOTE(review): rebinds `trenches` (replacing the t=0 result above), with a
# different width_to_pitch_ratio and without join_info=False.
trenches = trench_detection.find_trenches(img1, width_to_pitch_ratio=2 / 3.5)

In [None]:
regrid(
    hv.HoloMap({t: ui.RevImage(x) for t, x in enumerate([img1, img2])})
) * hv.HoloMap({t: trench_detection.plot_trenches(trenches) for t in range(2)})

In [None]:
%%time
diag_d = util.tree()
shift = find_trench_drift(img1, img2, trenches, diagnostics=diag_d)

In [None]:
len(trenches)

In [None]:
# NOTE(review): the earlier diagnostics cell reads "features1"/"features2";
# confirm that find_trench_drift also writes a "features" key.
diag_d["features"]

# Drift correction test

In [None]:
# 225 distortion-corrected frames (FOV 8, channel 0), cropped to 500x500.
k1 = 8.947368421052635e-10
imgs = {
    t: image.correct_radial_distortion(nd2.get_frame_2D(v=8, c=0, t=t), k1=k1)[
        :500, :500
    ]
    for t in trange(225)
}

In [None]:
# NOTE(review): this overwrites the corrected 225-frame `imgs` above with only
# 22 uncorrected frames. The drift loop below indexes imgs[t] up to t=224
# (t1 = 225) and would raise KeyError. Run only one of these two cells.
imgs = {t: nd2.get_frame_2D(v=8, c=0, t=t)[:500, :500] for t in trange(22)}

In [None]:
# Test pattern: one bright pixel at (row=167, col=476) with its four
# neighbours zeroed (img_x is not used elsewhere in this section).
img_x = imgs[20].copy()
y = 167
x = 476
img_x[y, x] = 20_000
img_x[y + 1, x] = 0
img_x[y - 1, x] = 0
img_x[y, x - 1] = 0
img_x[y, x + 1] = 0

In [None]:
%%time
diag = util.tree()
trenches, info = trench_detection.find_trenches(
    imgs[20],
    # angle=np.deg2rad(89),
    # pitch=16.482897384305836,
    width_to_pitch_ratio=1.4 / 3.5,
    join_info=False,
    diagnostics=diag,
)

In [None]:
info

In [None]:
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
diag["labeling"]["set_finding"].keys()

In [None]:
diag["labeling"]["set_finding"]["profiles"]

In [None]:
diag["bboxes"]

In [None]:
# NOTE(review): uses `ts`, `shifts` and `image_limits`, which are defined in
# later cells, and `features`, which is never assigned in this section.
# Running top-to-bottom raises NameError here.
hv.HoloMap({t: ui.RevImage(imgs[t]) for t in ts}) * hv.HoloMap(
    {
        t: trench_detection.plot_trenches(
            geometry.filter_rois(geometry.shift_rois(trenches, shifts[t]), image_limits)
        )
        for t in ts
    },
) * hv.HoloMap({t: features[t] for t in ts}).opts(color="red")

In [None]:
image_limits = geometry.get_image_limits(imgs[0].shape)

In [None]:
def mock_features(img):
    """Return two fixed stand-in feature points at opposite image corners.

    Intended as a debugging substitute for a real ``feature_func``. The points
    are ``(x, y)`` pairs: the first pixel ``(0, 0)`` and the last pixel
    ``(n_cols - 1, n_rows - 1)``.
    """
    n_rows = img.shape[0]
    n_cols = img.shape[1]
    corners = [[0, 0], [n_cols - 1, n_rows - 1]]
    return np.array(corners)

In [None]:
def median_translation(data, diagnostics=None, **kwargs):
    """Estimate a translation as the per-axis median displacement.

    ``data`` is indexed as ``data[:, 0, :]`` (start points) and
    ``data[:, 1, :]`` (end points). It is presumably an (n_pairs, 2, n_dims)
    array of matched feature pairs — TODO confirm. ``diagnostics`` and any
    extra keyword arguments are accepted but unused.
    """
    starts = data[:, 0, :]
    ends = data[:, 1, :]
    displacements = ends - starts
    return np.median(displacements, axis=0)

In [None]:
# Scratch array (rebinds `x`, previously the test-pixel column).
x = imgs[0][:100, :100].copy()

In [None]:
x[:] = 0

In [None]:
# Register every frame t in (t0, t1) directly against the reference frame t0,
# always starting from the t0 shift (zero) rather than chaining frame to frame.
# NOTE(review): if `imgs` holds only the 22 frames from the "trange(22)" cell,
# imgs[t] raises KeyError once t reaches 22.
rois = trenches  # [(trenches["trench_set"] == 1)][21:25]
t0 = 20
t1 = 225
ts = np.arange(t0 + 1, t1)
# ts = np.arange(200, t1)
# ts = [160,183]
# ts = [100]
shifts = {}
shifts[t0] = np.array([0, 0])
# shifts[t0] = np.array([0, 0])
features_plot = {}
rois_plot = {}
rois_final_plot = {}
for t in tqdm(ts):
    # Plain dict here, whereas other call sites pass util.tree(); the keys read
    # below must therefore be written by find_feature_drift itself.
    diag = {}
    # shift = drift.find_feature_drift(
    #     imgs[t - 1],
    #     imgs[t],
    #     trenches,
    #     initial_shift=shifts[t - 1],
    #     estimation_func=median_translation,
    #     max_iterations=3,
    #     diagnostics=diag,
    # )
    # if t > 180:
    #     f = drift.trench_cell_endpoints
    # else:
    #     f = mock_features
    shift = drift.find_feature_drift(
        imgs[t0],
        # imgs[t-1],
        imgs[t],
        rois,
        initial_shift=shifts[t0],
        # feature_func=mock_features,
        # feature_func=drift.centroid,
        estimation_func=median_translation,
        max_iterations=2,
        diagnostics=diag,
    )
    shifts[t] = shift
    # correspondences[t] = diag["correspondences"]
    features_plot[t] = diag["features2"]
    rois_plot[t] = diag["rois2"]
    rois_final_plot[t] = diag["rois_final"]

In [None]:
# Inspect only the last 30 registered frames.
ts2 = ts[-30:]
hv.HoloMap({t: ui.RevImage(imgs[t]) for t in ts2}) * hv.HoloMap(
    {t: rois_final_plot[t] for t in ts2},
) * hv.HoloMap({t: features_plot[t] for t in ts2}).opts(color="red")

In [None]:
hv.HoloMap({t: ui.RevImage(imgs[t]) for t in ts}) * hv.HoloMap(
    {t: rois_final_plot[t] for t in ts},
) * hv.HoloMap({t: features_plot[t] for t in ts}).opts(color="red")

In [None]:
# NOTE(review): duplicate of the previous cell.
hv.HoloMap({t: ui.RevImage(imgs[t]) for t in ts}) * hv.HoloMap(
    {t: rois_final_plot[t] for t in ts},
) * hv.HoloMap({t: features_plot[t] for t in ts}).opts(color="red")

In [None]:
# Unshifted trenches for comparison.
# NOTE(review): `features` is never assigned in this section (the loop above
# fills `features_plot`); presumably a leftover name — TODO confirm.
hv.HoloMap({t: ui.RevImage(imgs[t]) for t in ts}) * hv.HoloMap(
    {
        t: trench_detection.plot_trenches(geometry.filter_rois(trenches, image_limits))
        for t in ts
    },
) * hv.HoloMap({t: features[t] for t in ts}).opts(color="red")

In [None]:
features[100].options(invert_yaxis=True)

In [None]:
trenches

In [None]:
# Endpoints of the second ROI as (x, y) tuples.
roi = trenches.iloc[1]
top = (roi["top_x"], roi["top_y"])
bottom = (roi["bottom_x"], roi["bottom_y"])

In [None]:
# Endpoints are reversed to (y, x) before calling profile_line.
res = trench_detection.profile.profile_line(
    imgs[21], top[::-1], bottom[::-1], linewidth=10
)

In [None]:
plt.imshow(res)

In [None]:
plt.plot(res.mean(axis=0));

In [None]:
def iter_roi_lines(rois):
    """Yield the index label and line endpoints of every ROI in ``rois``.

    Parameters
    ----------
    rois : pandas.DataFrame
        Must have ``top_x``, ``top_y``, ``bottom_x`` and ``bottom_y`` columns.

    Yields
    ------
    tuple
        ``(roi_idx, top, bottom)``. ``roi_idx`` is the row's index label, and
        ``top``/``bottom`` are ``np.array([x, y])`` endpoint arrays.
    """
    # Pull each column out once as a plain array and walk them in lockstep.
    # This avoids per-row pandas access and the manual range(len(...)) indexing.
    for roi_idx, tx, ty, bx, by in zip(
        rois.index.values,
        rois["top_x"].values,
        rois["top_y"].values,
        rois["bottom_x"].values,
        rois["bottom_y"].values,
    ):
        yield roi_idx, np.array([tx, ty]), np.array([bx, by])


def line_shift(img, rois, linewidth=8):
    """Sample an intensity profile along each ROI's top-to-bottom line in ``img``.

    NOTE(review): no shift is estimated yet. This only collects the per-ROI
    line profiles and returns them, so a shift estimate can be built on top.

    Parameters
    ----------
    img : array
        Image to sample.
    rois : pandas.DataFrame
        ROI table with ``top_x``/``top_y``/``bottom_x``/``bottom_y`` columns,
        read through ``iter_roi_lines``.
    linewidth : int
        Forwarded to ``trench_detection.profile.profile_line``.

    Returns
    -------
    dict
        Maps each ROI index label to the profile returned for that ROI.
    """
    profiles = {}
    for roi_idx, top, bottom in iter_roi_lines(rois):
        # Endpoints come out as (x, y); they are reversed to (y, x) here,
        # presumably the (row, col) order profile_line expects — TODO confirm.
        profiles[roi_idx] = trench_detection.profile.profile_line(
            img, top[::-1], bottom[::-1], linewidth=linewidth
        )
    return profiles


def find_line_drift(img1, img2, trenches, diagnostics=None):
    """Placeholder for a line-profile-based drift estimate between two frames.

    Not implemented yet: every argument is ignored and ``None`` is returned.
    """
    return None

In [None]:
# Line profile (linewidth 10) of every ROI in frame 0, keyed by ROI index label.
profiles = {}
for roi_idx, top, bottom in iter_roi_lines(rois):
    res = trench_detection.profile.profile_line(
        imgs[0], top[::-1], bottom[::-1], linewidth=10
    )
    profiles[roi_idx] = res

In [None]:
# Same profiles for frame 40.
# NOTE(review): raises KeyError if `imgs` holds only the 22-frame subset.
profiles2 = {}
for roi_idx, top, bottom in iter_roi_lines(rois):
    res = trench_detection.profile.profile_line(
        imgs[40], top[::-1], bottom[::-1], linewidth=10
    )
    profiles2[roi_idx] = res

In [None]:
# Compare one ROI between frames 0 and 40 (assumes index label 20 exists).
idx = 20
plt.plot(profiles[idx].mean(axis=0))
plt.plot(profiles2[idx].mean(axis=0));

In [None]:
# Average axis-0 mean across all ROIs; the concatenation requires all profiles
# to share the same size along axis 1.
plt.plot(np.mean(np.concatenate(list(profiles.values())), axis=0))

In [None]:
# Per-ROI axis-0 means, each normalized to its own maximum.
for idx, res in profiles.items():
    horiz = res.mean(axis=0)
    horiz = horiz / horiz.max()
    plt.plot(horiz);

In [None]:
# Per-ROI axis-1 means (unnormalized).
for idx, res in profiles.items():
    vert = res.mean(axis=1)
    # vert = vert / vert.max()
    plt.plot(vert);