# Imports

In [None]:
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path

import dask
import distributed
import h5py
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import scipy
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.auto import tqdm, trange

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from paulssonlab.image_analysis import *
from paulssonlab.image_analysis.ui import display_image

In [None]:
%load_ext pyinstrument

In [None]:
hv.extension("bokeh")

# Config

In [None]:
# Dask cluster backed by SLURM jobs: each job requests 1 core running a
# single worker process with 2GB of memory on the "short" queue, with a
# 6-hour walltime. Worker logs go to /home/jqs1/log.
cluster = SLURMCluster(
    queue="short",
    walltime="06:00:00",
    memory="2GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
cluster

In [None]:
# Request one worker job immediately.
cluster.scale(1)

In [None]:
# Switch to adaptive scaling, allowing up to 20 workers.
cluster.adapt(maximum=20)

# Trench detection

In [None]:
# ND2 time-lapse read by the trench-detection / drift-correction cells below.
# Earlier datasets are kept commented out; the trailing "#v=..,t=.." notes
# presumably record the FOV/timepoint of interest for each one.
# filename = "/home/jqs1/scratch/jqs1/microscopy/210511/RBS_ramp.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
# filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2" #v=8,t=10
filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"

In [None]:
nd2 = nd2reader.ND2Reader(filename)

# FISH correction

In [None]:
# Locate the FISH acquisition: the "FISH/real_run" directory is resolved
# relative to the ND2 file's parent directory. (filename is re-set here to
# the same value used in the trench-detection section.)
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
# filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2" #v=8,t=10
filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"
fish_filename = Path(filename).parent / "FISH/real_run"

In [None]:
# Radial distortion coefficient passed to image.correct_radial_distortion.
k1 = 8.947368421052635e-10

In [None]:
def calibrate_image(img, k1=0):
    """Convert ``img`` to float32, then apply radial distortion correction.

    Parameters
    ----------
    img : array-like
        Raw image frame.
    k1 : float, optional
        Radial distortion coefficient forwarded unchanged to
        ``image.correct_radial_distortion`` (default 0).

    Returns
    -------
    The result of ``image.correct_radial_distortion`` applied to the
    float32-converted frame.
    """
    as_float = skimage.img_as_float32(img)
    return image.correct_radial_distortion(as_float, k1=k1)

In [None]:
%%time
# Collect FISH frames for FOV 8 (regex group "v" is the fov number; t=None
# presumably means all timepoints), converting each frame to float32 and
# keying it by (t, channel). Radial-distortion correction via
# calibrate_image is currently commented out.
# NOTE(review): this rebinds `delayed`, shadowing the `from dask import
# delayed` import for the rest of the session. util.get_delayed(True)
# presumably returns a dask-style delayed wrapper (frames are materialised
# with dask.compute in the next cell). fish_crops is not used below.
delayed = util.get_delayed(True)
fish_frames = {}
fish_crops = {}
fish_channels = set()
fish_timepoints = set()
for msg in readers.send_eaton_fish(
    fish_filename,
    r"fov=(?P<v>\d+)_config=(?P<c>\w+)_t=(?P<t>\d+)",
    slices=dict(t=None, v=[8]),
    delayed=delayed,
):
    # print(msg["metadata"],msg["image"].shape)
    fish_img = msg["image"]
    # fish_img_corrected = delayed(calibrate_image)(fish_img, k1=k1)
    fish_img_corrected = delayed(skimage.img_as_float32)(fish_img)
    t = msg["metadata"]["t"]
    channel = msg["metadata"]["channel"]
    fish_channels.add(channel)
    fish_timepoints.add(t)
    fish_frames[(t, channel)] = fish_img_corrected
# Sorted lists of the distinct channels / timepoints seen above.
fish_channels = list(sorted(fish_channels))
fish_timepoints = list(sorted(fish_timepoints))

In [None]:
# Compute all frames in a single dask.compute call (it returns a tuple).
fish_frames0 = dask.compute(fish_frames)[0]

In [None]:
%%time
# Stack frames into one array per channel, shape
# (n_timepoints, *frame_shape), ordered by fish_timepoints.
# NOTE: np.full with a NaN fill value allocates float64, so each stack is
# float64 even though the frames were converted to float32.
stacks = {}
for channel in fish_channels:
    for timepoint_idx, timepoint in enumerate(fish_timepoints):
        img = fish_frames0[(timepoint, channel)]
        if channel not in stacks:
            stacks[channel] = np.full((len(fish_timepoints), *img.shape), np.nan)
        stacks[channel][timepoint_idx, :, :] = img

In [None]:
stacks["GFP"].shape

In [None]:
# Max projection over timepoint indices 3-8 (this cell is repeated below).
stacks["GFP"][3:9].max(axis=0)

In [None]:
stacks["GFP"][3:9].max(axis=0)

In [None]:
display_image(stacks["GFP"][3:9].max(axis=0), scale=0.99)

In [None]:
display_image(stacks["GFP"][3:9].min(axis=0), scale=0.99)

In [None]:
# Per-pixel range (max - min) across timepoint indices 3-8, highlighting
# pixels whose GFP intensity changes over those timepoints.
display_image(
    stacks["GFP"][3:9].max(axis=0) - stacks["GFP"][3:9].min(axis=0), scale=0.99
)

In [None]:
# Same per-pixel range, over timepoint indices 0-8.
display_image(stacks["GFP"][:9].max(axis=0) - stacks["GFP"][:9].min(axis=0), scale=0.99)

In [None]:
# NOTE(review): `info` is first assigned in the "Drift correction" section
# below; run top-to-bottom, this cell raises NameError.
info

# Drift correction

In [None]:
# Channel-0 frames of FOV 8 for t = 0..224, cropped to the top-left
# 500x500 pixels.
imgs = {t: nd2.get_frame_2D(v=8, c=0, t=t)[:500, :500] for t in trange(225)}

In [None]:
hv.HoloMap({k: ui.RevImage(v) for k, v in imgs.items()})

In [None]:
%%time
# Detect trenches at t=0 here, and at t=20 / t=210 in the next cells, to
# compare how detection varies over the time-lapse.
diag = util.tree()
trenches, info = trench_detection.find_trenches(
    imgs[0], width_to_pitch_ratio=1.4 / 3.5, join_info=False, diagnostics=diag
)

In [None]:
diag["bboxes"]

In [None]:
%%time
diag2 = util.tree()
trenches2, info2 = trench_detection.find_trenches(
    imgs[20], width_to_pitch_ratio=1.4 / 3.5, join_info=False, diagnostics=diag2
)

In [None]:
diag2["bboxes"]

In [None]:
%%time
diag3 = util.tree()
trenches3, info3 = trench_detection.find_trenches(
    imgs[210], width_to_pitch_ratio=1.4 / 3.5, join_info=False, diagnostics=diag3
)

In [None]:
diag3["bboxes"]

In [None]:
# Trenches detected at t=20 drawn over the t=210 frame.
ui.RevImage(imgs[210]) * trench_detection.plot_trenches(trenches2)

In [None]:
# Trenches detected at t=0 drawn over the t=210 frame.
ui.RevImage(imgs[210]) * trench_detection.plot_trenches(trenches)

In [None]:
# t=20 detection (blue outlines) vs t=210 detection.
trench_detection.plot_trenches(trenches2).opts(
    hv.opts.Rectangles(line_color="blue")
) * trench_detection.plot_trenches(trenches3)

In [None]:
# Cell endpoints in one trench crop (index 12 of trenches2, presumably a
# trench index) for every 3rd timepoint from 70 to 118.
# NOTE(review): get_crop / trench_cell_endpoints / find_trench_drift are not
# defined in any visible cell; presumably they come from the star import of
# paulssonlab.image_analysis (later cells use drift.trench_cell_endpoints
# and geometry.get_roi_crop instead).
plots = {}
t_min = 70
t_delta = 50
for t in range(t_min, t_min + t_delta + 1, 3):
    # for t in [t_min, t_min+t_delta+1]:
    crop = get_crop(imgs[t], trenches2, 12)
    pts = trench_cell_endpoints(crop)
    plots[t] = ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(
        color="red", size=4
    )
hv.HoloMap(plots)

In [None]:
# idx=59
idx = 100
crop = get_crop(imgs[idx], trenches2, 12)
pts = trench_cell_endpoints(crop)
ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(color="red", size=4)
# hv.Points(pts)

In [None]:
%%time
# Estimate drift between t=20 and t=101 using the t=0 trench detection.
diag_d = util.tree()
shift = find_trench_drift(imgs[20], imgs[101], trenches, diagnostics=diag_d)

In [None]:
diag_d["features1"] * diag_d["features2"].opts(color="red")

In [None]:
diag_d["correspondences"]

In [None]:
# NOTE(review): crops and endpoints are computed for every timepoint, but a
# placeholder hv.Curve is stored instead of the (commented-out) image
# overlay; looks like a debugging variant of the next cell.
plots = {}
for t in trange(225):
    crop = get_crop(imgs[t], trenches2, 5)
    pts = trench_cell_endpoints(crop)
    # plots[t] = ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(
    #     color="red", size=4
    # )
    plots[t] = hv.Curve([(0, t), (2, 4)])

In [None]:
plots = {}
for t in trange(225):
    crop = get_crop(imgs[t], trenches2, 5)
    pts = trench_cell_endpoints(crop)
    plots[t] = ui.RevImage(crop).opts(frame_width=40) * hv.Points(pts + 0.5).opts(
        color="red", size=4
    )

In [None]:
%%time
# Full-size, radial-distortion-corrected frames at t=70 and t=110 (FOV 8).
k1 = 8.947368421052635e-10
img1 = image.correct_radial_distortion(nd2.get_frame_2D(v=8, t=70, c=0), k1=k1)
img2 = image.correct_radial_distortion(nd2.get_frame_2D(v=8, t=110, c=0), k1=k1)

In [None]:
%%time
# NOTE: rebinds `trenches` (previously the t=0 detection) using a different
# width_to_pitch_ratio (2 / 3.5) and without join_info=False.
trenches = trench_detection.find_trenches(img1, width_to_pitch_ratio=2 / 3.5)

In [None]:
# Toggle between the two frames, with the t=70 trenches overlaid on both.
regrid(
    hv.HoloMap({t: ui.RevImage(x) for t, x in enumerate([img1, img2])})
) * hv.HoloMap({t: trench_detection.plot_trenches(trenches) for t in range(2)})

In [None]:
%%time
diag_d = util.tree()
shift = find_trench_drift(img1, img2, trenches, diagnostics=diag_d)

In [None]:
len(trenches)

In [None]:
# NOTE(review): the earlier drift cell reads diag_d["features1"] and
# diag_d["features2"]; confirm "features" is a key find_trench_drift fills.
diag_d["features"]

# Drift correction test

In [None]:
# Radial-distortion-corrected crops (top-left 500x500) of channel 0, FOV 8,
# for t = 0..224. The correction is applied to the full frame before cropping.
k1 = 8.947368421052635e-10
imgs = {
    t: image.correct_radial_distortion(nd2.get_frame_2D(v=8, c=0, t=t), k1=k1)[
        :500, :500
    ]
    for t in trange(225)
}

In [None]:
# NOTE: overwrites the corrected `imgs` above with uncorrected crops of only
# t = 0..21; later cells index imgs[0], imgs[20] and imgs[21].
imgs = {t: nd2.get_frame_2D(v=8, c=0, t=t)[:500, :500] for t in trange(22)}

In [None]:
# Copy of the t=20 crop with one bright marker pixel at (row 167, col 476)
# and its four direct neighbours zeroed. img_x is not used in later cells.
img_x = imgs[20].copy()
y = 167
x = 476
img_x[y, x] = 20_000
img_x[y + 1, x] = 0
img_x[y - 1, x] = 0
img_x[y, x - 1] = 0
img_x[y, x + 1] = 0

In [None]:
%%time
# Trench detection at t=20 with the angle (50 degrees) and pitch passed
# explicitly (presumably skipping their estimation).
diag = util.tree()
trenches, info = trench_detection.find_trenches(
    imgs[20],
    angle=np.deg2rad(50),
    pitch=16.482897384305836,
    width_to_pitch_ratio=1.4 / 3.5,
    join_info=False,
    diagnostics=diag,
)

In [None]:
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
info

In [None]:
diag["labeling"]["set_finding"].keys()

In [None]:
# NOTE(review): "XXX" looks like a placeholder label; confirm which element
# of the "profiles" diagnostic was meant.
diag["labeling"]["set_finding"]["profiles"].Curve.XXX

In [None]:
diag["labeling"]["set_finding"]["stacked_profile"]

In [None]:
diag["bboxes"]

In [None]:
diag["label_1"]["find_trench_ends"].keys()

In [None]:
# NOTE(review): incomplete expression -- the string literal below is never
# closed, so this cell raises SyntaxError. Fill in the intended key.
diag["label_1"]["find_trench_ends"]["

In [None]:
# Synthetic test image. The first assignment is immediately overwritten:
# x ends up as a 500x500 float ramp with x[row, col] == col.
x = np.random.random((500, 500)) + np.arange(500)[:, np.newaxis] / 100
x = (
    np.arange(500)[np.newaxis, :] * np.ones(500)[:, np.newaxis]
)  # * np.random.random((500,500))# / 100

In [None]:
# ROI crop (29 presumably an ROI index) of the mask x == 467, which is True
# only in column 467.
plt.figure(figsize=(20, 20))
plt.imshow(geometry.get_roi_crop(x == 467, trenches, 29)[:, :])

In [None]:
ui.RevImage(x == 467) * trench_detection.plot_trenches(
    trenches, lines=True, labels=True
)

In [None]:
ui.RevImage(imgs[20]) * trench_detection.plot_trenches(trenches, lines=True)

In [None]:
plt.figure(figsize=(20, 20))
plt.imshow(geometry.get_roi_crop(imgs[20], trenches, 29)[160:].T)

In [None]:
image_limits = geometry.get_image_limits(imgs[0].shape)

In [None]:
def median_translation(data, diagnostics=None, **kwargs):
    """Estimate a drift translation from matched point pairs.

    Only the first coordinate of the shift is estimated, as the median
    displacement between matched points; the second coordinate is always
    returned as 0.

    Parameters
    ----------
    data : numpy.ndarray
        Matched points indexed ``data[i, j, k]``, where ``j == 0`` selects
        the point in the first image and ``j == 1`` its match in the second
        (presumably shape ``(n_matches, 2, 2)`` -- TODO confirm against
        ``drift.find_feature_drift``).
    diagnostics, **kwargs
        Accepted so this can be passed as ``estimation_func`` to
        ``drift.find_feature_drift``; ignored.

    Returns
    -------
    numpy.ndarray
        ``[median displacement along the first coordinate, 0]``.
    """
    # plt.hist(data[:, 1, 0] - data[:, 0, 0], bins=10);
    src, dst = data[:, 0, :], data[:, 1, :]
    dx = np.median(dst - src, axis=0)[0]
    return np.array([dx, 0])

In [None]:
def mock_features(img):
    """Return two fixed stand-in feature points for ``img``.

    Both points lie on the horizontal midpoint ``x = img.shape[1] / 2``,
    5 pixels in from the top and bottom edges. Points are ``(x, y)``, with
    ``x`` measured along ``img.shape[1]`` and ``y`` along ``img.shape[0]``.

    Returns
    -------
    numpy.ndarray
        Array ``[[x_mid, 5], [x_mid, img.shape[0] - 5]]``.
    """
    # print(img.shape);0/0
    x_mid = img.shape[1] / 2
    height = img.shape[0]
    margin = 5
    return np.array([[x_mid, margin], [x_mid, height - margin]])

In [None]:
# Drift of one ROI (the third row of trench set 2 with top_x < 100) relative
# to frame t0, estimated with median_translation (shift along the first
# coordinate only). Only ts = [21] is evaluated right now; t1 and
# `correspondences` are unused (their uses are commented out).
rois = trenches[(trenches["trench_set"] == 2) & (trenches["top_x"] < 100)][2:3]
t0 = 20
t1 = 225
# ts = np.arange(t0 + 1, t1)
# ts = [160,183]
ts = [21]
shifts = {}
shifts[t0] = np.array([0, 0])
# shifts[t0] = np.array([0, 0])
correspondences = {}
features = {}
for t in tqdm(ts):
    diag = {}
    # shift = drift.find_feature_drift(
    #     imgs[t - 1],
    #     imgs[t],
    #     trenches,
    #     initial_shift=shifts[t - 1],
    #     estimation_func=median_translation,
    #     max_iterations=3,
    #     diagnostics=diag,
    # )
    # if t > 180:
    #     f = drift.trench_cell_endpoints
    # else:
    #     f = mock_features
    shift = drift.find_feature_drift(
        imgs[t0],
        imgs[t],
        rois,
        initial_shift=shifts[t0],
        # feature_func=mock_features,
        estimation_func=median_translation,
        max_iterations=1,
        diagnostics=diag,
    )
    shifts[t] = shift
    # correspondences[t] = diag["correspondences"]
    features[t] = diag["features2"]

In [None]:
# NOTE: uses whatever `x` was last assigned; run top-to-bottom this is the
# synthetic column ramp from above, not the ROI crop computed a few cells
# below.
drift.trench_cell_endpoints(x)

In [None]:
x.shape

In [None]:
x = geometry.get_roi_crop(imgs[20], trenches, 33)

In [None]:
x.shape

In [None]:
rois

In [None]:
# Each frame in ts overlaid with all detected trenches (not just `rois`)
# shifted by the estimated drift and filtered against image_limits, plus the
# detected features in red.
hv.HoloMap({t: ui.RevImage(imgs[t]) for t in ts}) * hv.HoloMap(
    {
        t: trench_detection.plot_trenches(
            geometry.filter_rois(geometry.shift_rois(trenches, shifts[t]), image_limits)
        )
        for t in ts
    },
) * hv.HoloMap({t: features[t] for t in ts}).opts(color="red")

In [None]:
shifts

In [None]:
# View of FOV 0, channel 1, t=0, keeping every 4th pixel in each dimension.
ui.RevImage(nd2.get_frame_2D(v=0, t=0, c=1)[::4, ::4])

In [None]:
display_image(nd2.get_frame_2D(v=0, t=0, c=1), scale=0.99, downsample=2)

In [None]:
trenches

# RevImage fix

In [None]:
# Scratch experiments comparing hv.Image bounds / extents / invert_yaxis /
# row-flip combinations, overlaid with hv.Points, presumably to get image
# (row-down) coordinates right in ui.RevImage.
# NOTE: rebinds `x` to an ND2 reader for the 230213 dataset.
x = nd2reader.ND2Reader("/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2")

In [None]:
x.metadata["channels"]

In [None]:
z = x.get_frame_2D(t=0, v=0, c=0)

In [None]:
display_image(z, downsample=4, scale=0.99)

In [None]:
# Keep every 4th pixel in each dimension for faster plotting.
zz = z[::4, ::4]

In [None]:
zz.shape[0]

In [None]:
hv.Image(zz, bounds=(0, 0, zz.shape[1], zz.shape[0])).opts(
    aspect=zz.shape[1] / zz.shape[0]
)

In [None]:
hv.Image(zz, bounds=(0, zz.shape[0], zz.shape[1], 0)).opts(
    aspect=zz.shape[1] / zz.shape[0]
)

In [None]:
hv.Image(zz, bounds=(0, zz.shape[0], zz.shape[1], 0)).opts(
    aspect=zz.shape[1] / zz.shape[0],
    invert_yaxis=True,
)

In [None]:
hv.Image(zz, bounds=(0, 0, zz.shape[1], zz.shape[0])).opts(
    aspect=zz.shape[1] / zz.shape[0],
    invert_yaxis=True,
)

In [None]:
hv.Image(zz[::-1], bounds=(0, 0, zz.shape[1], zz.shape[0])).opts(
    aspect=zz.shape[1] / zz.shape[0],
    invert_yaxis=True,
)

In [None]:
# Same as the previous cell, plus a reference point at (100, 100) to check
# where overlay coordinates land.
hv.Image(zz[::-1], bounds=(0, 0, zz.shape[1], zz.shape[0])).opts(
    aspect=zz.shape[1] / zz.shape[0],
    invert_yaxis=True,
) * hv.Points([(100, 100)]).opts(size=100, color="red")

In [None]:
hv.Image(
    zz,
    bounds=(0, 0, zz.shape[1], zz.shape[0]),
    extents=(0, 0, zz.shape[1], zz.shape[0]),
).opts(aspect=zz.shape[1] / zz.shape[0])

In [None]:
hv.Image(
    zz,
    bounds=(0, 0, zz.shape[1], zz.shape[0]),
    extents=(0, 0, zz.shape[1], zz.shape[0]),
).opts(aspect=zz.shape[1] / zz.shape[0], invert_yaxis=True)

In [None]:
hv.Overlay(
    [
        hv.Image(
            zz[::-1],
            bounds=(0, 0, zz.shape[1], zz.shape[0]),
            extents=(0, 0, zz.shape[1], zz.shape[0]),
        ).opts(aspect=zz.shape[1] / zz.shape[0])
    ]
).opts(invert_yaxis=True)

In [None]:
(
    hv.Overlay(
        [
            hv.Image(
                zz[::-1],
                bounds=(0, 0, zz.shape[1], zz.shape[0]),
                extents=(0, 0, zz.shape[1], zz.shape[0]),
            ).opts(aspect=zz.shape[1] / zz.shape[0])
        ]
    ).opts(invert_yaxis=True)
    * hv.Points([(100, 100)]).opts(size=100, color="red")
)

In [None]:
(
    hv.Overlay(
        [
            hv.Image(
                zz[::-1],
                bounds=(0, 0, zz.shape[1], zz.shape[0]),
                # extents=(0, 0, zz.shape[1], zz.shape[0]),
            ).opts(aspect=zz.shape[1] / zz.shape[0])
        ]
    ).opts(invert_yaxis=True)
    * hv.Points([(100, 100)]).opts(size=100, color="red")
).opts(invert_yaxis=True)

In [None]:
(
    hv.Overlay(
        [
            hv.Image(
                zz[::-1],
                bounds=(0, zz.shape[0], zz.shape[1], 0),
                # extents=(0, 0, zz.shape[1], zz.shape[0]),
            ).opts(aspect=zz.shape[1] / zz.shape[0])
        ]
    ).opts(invert_yaxis=True)
    * hv.Points([(100, 100)]).opts(size=100, color="red")
).opts(invert_yaxis=True)

In [None]:
hv.Image(
    zz[::-1],
    bounds=(0, zz.shape[0], zz.shape[1], 0),
    # extents=(0, 0, zz.shape[1], zz.shape[0]),
).opts(
    hv.opts.Image(aspect=zz.shape[1] / zz.shape[0], invert_yaxis=True)
)  # * hv.Points([(100, 100)]).opts(size=100, color="red")).opts(invert_yaxis=True)

In [None]:
hv.Image(
    zz[::-1], bounds=(0, -zz.shape[0], zz.shape[1], 0), extents=(0, 0, 100, 200)
).opts(invert_yaxis=True)

In [None]:
zz.shape[1]

In [None]:
%%time
# NOTE: rebinds `t` (a timepoint index in earlier cells) to the
# find_trenches result for the full-resolution frame z.
diag = util.tree()
t = trench_detection.find_trenches(z, width_to_pitch_ratio=1.4 / 3.5, diagnostics=diag)

In [None]:
diag["bboxes"]

In [None]:
t

In [None]:
ui.RevImage(z, scale=0.997) * hv.Points([(100, 100)]).opts(size=10, color="red")

In [None]:
ui.RevImage(z, scale=0.997) * trench_detection.plot_trenches(
    t
)  # .opts(invert_yaxis=True)

In [None]:
hv.Image(
    zz[::1],
    bounds=(0, zz.shape[0], zz.shape[1], 0),
    extents=(0, -zz.shape[0], zz.shape[1], zz.shape[0]),
).redim.range(z=(0, np.percentile(zz, 99.7))).opts(
    hv.opts.Image(aspect=zz.shape[1] / zz.shape[0], invert_yaxis=True)
) * hv.Points(
    [(100, 100)]
).opts(
    size=10, color="red"
)

In [None]:
hv.Image(
    zz[::1],
    bounds=(0, -zz.shape[0], zz.shape[1], 0),
    extents=(0, 0, zz.shape[1], zz.shape[0]),
).opts(hv.opts.Image(aspect=zz.shape[1] / zz.shape[0], invert_yaxis=True))

In [None]:
hv.Image(
    zz[::-1],
    bounds=(0, zz.shape[0], zz.shape[1], 0),
    extents=(0, 0, zz.shape[1], zz.shape[0]),
).opts(
    hv.opts.Image(aspect=zz.shape[1] / zz.shape[0], invert_yaxis=True)
)  # * hv.Points([(100, 100)]).opts(size=100, color="red")).opts(invert_yaxis=True)

In [None]:
ui.Image(z) * trench_detection.plot_trenches(t).opts(invert_yaxis=True)

In [None]:
np.percentile(z, 99.9)

In [None]:
a = hv.Image(np.random.random((10, 10)))

In [None]:
a.vdims[0].range

In [None]:
ui.Image(z, scale=0.997)

In [None]:
ui.Image(np.random.random((10, 10)))

In [None]:
%%opts?

In [None]:
# HoloViews %opts line magic: presumably makes invert_yaxis=True a
# session-wide default for subsequent outputs.
%opts [invert_yaxis=True]

In [None]:
ui.Image(np.random.random((10, 10))) * hv.Points([(1, 1)]).opts(
    size=10, color="red", invert_yaxis=True
)

In [None]:
ui.Image(np.random.random((10, 10))) * hv.Points([(0, 0)]).opts(size=10, color="red")

In [None]:
# Row-flipped z keeping every 4th pixel, with bounds in full-resolution
# pixel units and the colour range set to [z.min(), 99.7th percentile].
a = (
    hv.Image(
        z[::-4, ::4],
        bounds=(
            0,
            0,
            z.shape[1],
            z.shape[0],
        ),  # , extents=(0, z.shape[0], z.shape[1], 0)
    )
    .redim.range(z=(z.min(), np.percentile(z, 99.7)))
    .opts(
        aspect=z.shape[1] / z.shape[0],
        invert_yaxis=True,
    )
)

In [None]:
ui.RevImage(np.random.random((10, 10)))

In [None]:
(
    ui.RevImage(np.random.random((10, 10)))  # .opts(invert_yaxis=False)
    * hv.Points([(1, 1)]).opts(size=10, color="red")
).opts(hv.opts.Overlay(invert_yaxis=True), hv.opts.Image(invert_yaxis=False))

In [None]:
(
    ui.RevImage(np.random.random((10, 10))).opts(invert_yaxis=False)
    * hv.Points([(1, 1)]).opts(size=10, color="red")
).opts(invert_yaxis=True)

In [None]:
a.opts

In [None]:
# Row-flipped z keeping every 4th pixel (the same data as `a` above), with
# bounds in full-resolution pixel units taken from z.shape so overlay
# coordinates match the original frame.
hv.Image(z[::-4, ::4], bounds=(0, 0, z.shape[1], z.shape[0])).opts(
    hv.opts.Image(invert_yaxis=True)
)

In [None]:
# Same frame passed through np.asarray before display, for comparison with
# the display_image rendering in the next cell.
ui.RevImage(np.asarray(z))

In [None]:
display_image(z, scale=0.9)

# Cross-correlation

In [None]:
# Load frames from the 210511 RBS_ramp dataset for cross-correlation tests.
# NOTE: this rebinds `nd2`; the "Improved trench lines" cells below read
# frames from this file, not the 230404 file opened earlier.
nd2 = nd2reader.ND2Reader("/home/jqs1/scratch/jqs1/microscopy/210511/RBS_ramp.nd2")

In [None]:
# Consecutive frames t=170 and t=171 of FOV 30, channel 2 (f2 is not used
# in any later cell yet).
c = 2
f1 = nd2.get_frame_2D(v=30, t=170, c=c)
f2 = nd2.get_frame_2D(v=30, t=171, c=c)

In [None]:
display_image(f1, scale=0.99)

# Improved trench lines

In [None]:
# Radial-distortion-corrected full frame (t=10, FOV 0, channel 0) from the
# currently bound `nd2` reader.
k1 = 8.947368421052635e-10
img = image.correct_radial_distortion(nd2.get_frame_2D(t=10, v=0, c=0), k1=k1)

In [None]:
%%time
diag = util.tree()
trenches, info = trench_detection.find_trenches(
    img, width_to_pitch_ratio=1.4 / 3.5, join_info=False, diagnostics=diag
)

In [None]:
info

In [None]:
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
display_image(img, scale=0.9)

In [None]:
%%time
# Intensity profiles along 10 parallel lines at -0.5 degrees, with rho
# values spaced evenly between the two values from image.hough_bounds.
diag = util.tree()
angle = np.deg2rad(-0.5)
rhos = np.linspace(*image.hough_bounds(img.shape, angle), 10)
(
    profiles,
    stacked_points,
    anchor_idx,
) = trench_detection.profile.get_trench_line_profiles(
    img, angle, rhos, diagnostics=diag
)

In [None]:
plt.plot(profiles.T);

In [None]:
image.hough_bounds(img.shape, angle)

In [None]:
%%time
# Same lines via angled_profiles, to compare with get_trench_line_profiles.
diag2 = util.tree()
profiles2, stacked_points2 = trench_detection.profile.angled_profiles(
    img, angle, rhos, diagnostics=diag2
)

In [None]:
# TODO: there is a difference!!
plt.plot(profiles2.T);

In [None]:
profiles2.T.shape

In [None]:
(profiles.shape, profiles2.shape)

In [None]:
# Profile index 6 from both methods, with the first 20 samples of
# `profiles` dropped to line them up by eye.
plt.plot(profiles.T[20:, 6])
plt.plot(profiles2.T[:, 6]);

In [None]:
# Difference of the two index-6 profiles after trimming.
# NOTE(review): this uses an offset of 21 while the previous cell uses 20;
# confirm which alignment is intended.
plt.plot(profiles.T[21:, 6] - profiles2.T[:-14, 6]);

In [None]:
# NOTE(review): `theta` is not assigned in any visible cell, so this raises
# NameError as written (perhaps `angle` was intended -- confirm).
eps = 1e-5
assert -np.pi / 2 + eps < theta < np.pi / 2 - eps

In [None]:
x_lim, y_lim = geometry.get_image_limits(img.shape)

In [None]:
import skspatial
from skspatial.objects import Line, LineSegment

In [None]:
# coordinate system is flipped y -> -y

In [None]:
# Sample 2D points including duplicates: (0, 0) and (3, 1) each appear twice.
x = [
    skspatial.objects.Point((0, 0)),
    skspatial.objects.Point((1, 1)),
    skspatial.objects.Point((0, 0)),
    skspatial.objects.Point((3, 1)),
    skspatial.objects.Point((3, 1)),
    skspatial.objects.Point((1, 3)),
]

In [None]:
# Index of the first occurrence of each unique point (row-wise uniqueness).
np.unique(x, axis=0, return_index=True)[1]

In [None]:
# NOTE(review): rho_min / rho_max are only assigned in the next cell; run
# top-to-bottom this raises NameError.
(rho_min, rho_max)

In [None]:
# %%pyinstrument
# Endpoints of 100 lines at a 90-degree angle with rho spread across
# image.hough_bounds; lines where angled_line_profile_endpoints returns
# top=None are skipped.
lines = []
offsets = []
angle = np.deg2rad(90)
rho_min, rho_max = image.hough_bounds(img.shape, angle)
for rho in np.linspace(rho_min, rho_max, 100):
    top, bottom, offset = trench_detection.profile.angled_line_profile_endpoints(
        angle, rho, x_lim, y_lim
    )
    if top is None:
        continue
    lines.append([top, bottom])
    offsets.append(offset)

In [None]:
hv.Path(lines).options(invert_yaxis=True)