# Imports

In [None]:
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path

import dask
import distributed
import h5py
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import scipy
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from tqdm.auto import tqdm

# Shorthand for pandas MultiIndex slicing (use as df.loc[IDX[...], ...]).
IDX = pd.IndexSlice

In [None]:
# Profiling magics (presumably used via %%pyinstrument while iterating).
%load_ext pyinstrument

In [None]:
# Re-import edited modules before each cell so changes to the paulssonlab
# package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# NOTE(review): the star import presumably supplies `util`, `trench_detection`,
# `image`, `ui` (and possibly `new`) used below — confirm against the package.
from paulssonlab.image_analysis import *
from paulssonlab.util.ui import display_image

In [None]:
# Render HoloViews objects with the Bokeh backend.
hv.extension("bokeh")

# Config

In [None]:
# Dask cluster on SLURM: each job is one single-core, single-process worker
# with 2GB of memory, submitted to the "short" queue with a 6h walltime.
cluster = SLURMCluster(
    queue="short",
    walltime="06:00:00",
    memory="2GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
# Connect this notebook's dask client to the cluster.
client = Client(cluster)

In [None]:
# Echo the cluster object (status display).
cluster

In [None]:
# Request a single worker job.
cluster.scale(1)

In [None]:
# Let the cluster scale adaptively, up to 20 workers.
cluster.adapt(maximum=20)

# Trench detection

In [None]:
# Candidate ND2 acquisitions; only the uncommented one is used. The trailing
# "v=.../t=..." notes appear to record which position/timepoint was inspected.
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2"  # v=8,t=10
# filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"

In [None]:
nd2 = nd2reader.ND2Reader(filename)

In [None]:
# Single 2D frame at v=8, c=0, t=10 (axis names as used by nd2reader).
img = nd2.get_frame_2D(v=8, c=0, t=10)

In [None]:
nd2.metadata["channels"]

In [None]:
# NOTE(review): `trenches` is only assigned in the find_trenches cell further
# down, so this cell raises NameError on a fresh top-to-bottom run.
trenches

In [None]:
nd2.metadata

In [None]:
# Normalize by the frame maximum and boost 4x so dim structure is visible
# (how display_image handles values > 1 is not visible here).
display_image(img / img.max() * 4)

In [None]:
%%time
# Run trench detection on the raw frame with default settings; `diag` is filled
# with intermediate results (util.tree() presumably auto-creates nested keys).
diag = util.tree()
trenches = trench_detection.find_trenches(img, diagnostics=diag)

In [None]:
%%time
# Same detection, but with trench_detection.peaks.find_peaks as the peak
# finder, for side-by-side comparison with the default run above.
diag2 = util.tree()
trenches2 = trench_detection.find_trenches(
    img, peak_func=trench_detection.peaks.find_peaks, diagnostics=diag2
)

In [None]:
trenches

In [None]:
# Detected trenches overlaid on the image: default peak finder ...
diag["label_1"]["find_trench_ends"]["image_with_trenches"]

In [None]:
# ... versus find_peaks.
diag2["label_1"]["find_trench_ends"]["image_with_trenches"]

In [None]:
diag["label_1"]["find_trench_lines"]["hough_0"]["peak_func"].keys()

In [None]:
diag["label_1"]["find_trench_lines"]["hough_0"]["peak_func"]["spectrum"]

In [None]:
diag["labeling"].keys()

In [None]:
%%time
# Run only the binarization step, with its own diagnostics tree.
diag3 = util.tree()
img_bin = trench_detection.set_finding.binarize_trench_image(img, diagnostics=diag3)

In [None]:
img_bin

In [None]:
# Log-scale intensity histogram; the trailing ';' suppresses the return echo.
plt.hist(img.flat, bins=300, log=True);

In [None]:
from paulssonlab.image_analysis.image import (
    gaussian_box_approximation,
    normalize_componentwise,
    remove_large_objects,
)

In [None]:
lowpass_radius = 500
img_lowpass = gaussian_box_approximation(img, lowpass_radius)

In [None]:
0

In [None]:
%%time
rb = skimage.restoration.rolling_ball(img, radius=30)

In [None]:
display_image(rb, scale=True)

In [None]:
display_image((img - rb) / img.max() * 20)

In [None]:
?skimage.filters.threshold_sauvola

In [None]:
display_image(img > skimage.filters.threshold_sauvola(img, window_size=7))

In [None]:
display_image(img > skimage.filters.threshold_otsu

In [None]:
display_image(img / img.max() * 30)

In [None]:
display_image(img - img_lowpass, scale=True)

In [None]:
display_image(img - img_lowpass, scale=True)

In [None]:
# img_bin is indexed here, so binarize_trench_image appears to return a
# sequence; element 1 is shown — confirm what each element holds.
display_image(img_bin[1])

In [None]:
diag["labeling"]["binarize_trench_image"].keys()

In [None]:
diag["labeling"]["binarize_trench_image"]["thresholded_image"]

In [None]:
# NOTE(review): `diag_t` is not assigned anywhere in this notebook; this cell
# raises NameError unless it was created in an earlier, deleted cell.
diag_t["labeling"]["find_trench_lines"]["hough_0"]["peak_func"]["refined_points"]

In [None]:
diag["label_2"]["find_trench_ends"].keys()

In [None]:
??trench_detection.hough.find_periodic_lines

In [None]:
%%time
h, theta, rho = trench_detection.hough.hough_line_intensity(
    img_t, theta=np.linspace(-np.pi / 5, np.pi / 5, 1000)
)

In [None]:
display_image(h, scale=True)

In [None]:
%%time
smooth = 4
diff_h = np.diff(h.astype(np.int_), axis=1)  # TODO: is diff necessary??
diff_h_std = diff_h.std(axis=0)  # / diff_h.max(axis=0)
if smooth:
    diff_h_std_smoothed = scipy.ndimage.gaussian_filter1d(diff_h_std, smooth)

In [None]:
plt.plot(diff_h_std_smoothed)

In [None]:
h.shape

In [None]:
h.shape

In [None]:
imor

In [None]:
scipy.signal.periodogram

In [None]:
%%time
nfft = 2**16
max_period = None
freqs, spectrum = scipy.signal.periodogram(
    h, window="hann", nfft=nfft, scaling="spectrum", axis=0
)
if max_period:
    spectrum[:max_period] = 0
# pitch_idx = spectrum.argmax()
# pitch = 1 / freqs[pitch_idx]

In [None]:
spectrum.shape

In [None]:
display_image(spectrum / spectrum.max() * 5)

In [None]:
h[0].shape

In [None]:
np.deg2rad(5) / np.pi

# Dewarping

In [None]:
%%time
# Radial distortion coefficient for dewarping; earlier candidate values are
# kept commented out for reference.
# k1 = -5e-9
# k1 = 2e-9
# k1 = 1.5e-9
k1 = 8.947368421052635e-10
img_t = image.correct_radial_distortion(img, k1=k1)

In [None]:
%%time
# Periodic-line search on the dewarped frame over ±10 degrees (400 angles).
res = trench_detection.hough.find_periodic_lines(
    img_t, theta=np.linspace(-np.deg2rad(10), np.deg2rad(10), 400)
)

In [None]:
%%time
# Same search restricted to a single angle (0 degrees).
res2 = trench_detection.hough.find_periodic_lines(img_t, theta=[np.deg2rad(0)])

In [None]:
%%time
# Full trench detection on the dewarped frame. This rebinds `diag` and
# `trenches`, replacing the results from the raw-frame run above.
diag = util.tree()
trenches, info = trench_detection.find_trenches(
    img_t,
    # angle=np.deg2rad(0.001),
    join_info=False,
    width=8,
    # width_to_line_width_ratio=2,
    # width_to_pitch_ratio=None,
    # peak_func=trench_detection.peaks.find_peaks,
    diagnostics=diag,
)

In [None]:
info

In [None]:
diag["bboxes"]

In [None]:
# Probe trench_detection.geometry.edge_point on the box [x_lim] x [y_lim].
# For the midpoint of each of the four sides, compare the result for a point
# exactly on that side (x0) with the result for a point nudged `eps` inside
# the box (x1), for the first two angles in `thetas`.
x_lim = np.array([0, 9])
y_lim = np.array([2, 7])
thetas = np.linspace(0, 2 * np.pi, 100)
# One shared inward offset, so every probe is nudged by the same tiny amount
# (a hand-typed 1e9 instead of 1e-9 would put a probe far outside the box).
eps = 1e-9
x_mid = x_lim.sum() / 2
y_mid = y_lim.sum() / 2
for x0, x1 in [
    ((x_lim[0], y_mid), (x_lim[0] + eps, y_mid)),  # x_lim[0] side
    ((x_lim[1], y_mid), (x_lim[1] - eps, y_mid)),  # x_lim[1] side
    ((x_mid, y_lim[0]), (x_mid, y_lim[0] + eps)),  # y_lim[0] side
    ((x_mid, y_lim[1]), (x_mid, y_lim[1] - eps)),  # y_lim[1] side
]:
    for theta in thetas[:2]:
        y0 = trench_detection.geometry.edge_point(x0, theta, x_lim, y_lim)
        y1 = trench_detection.geometry.edge_point(x1, theta, x_lim, y_lim)
        print("======", y0, y1)

In [None]:
diag.keys()

In [None]:
diag["labeling"]["find_periodic_lines"].keys()

In [None]:
diag["labeling"]["find_periodic_lines"]["h_std"]

In [None]:
# Bounding-box overlay, enlarged for inspection.
diag["bboxes"].opts(frame_width=800, aspect=1)

In [None]:
# NOTE(review): `info2` is never assigned in this notebook (trenches2 above is
# not unpacked into a tuple), so this cell raises NameError as written.
info2

In [None]:
# Trenches whose detected line width is exactly 0.
trenches2[trenches2["line_widths"] == 0]

In [None]:
# Distribution of top_y for the zero-line-width trenches.
trenches2[trenches2["line_widths"] == 0]["top_y"].plot.hist(bins=100)

In [None]:
info2

In [None]:
diag2.keys()

In [None]:
diag2["labeling"]["find_periodic_lines"].keys()

In [None]:
diag2["labeling"]["find_periodic_lines"]["profile"]  # .keys()

In [None]:
diag2["labeling"]["set_finding"].keys()

In [None]:
diag2["labeling"]["set_finding"]["profiles"]

In [None]:
# Standalone HoloViews Labels demo (no trench data): 10 seeded random points,
# labelled "A".."J", with text offset from each point.
np.random.seed(9)
data = np.random.rand(10, 2)
points = hv.Points(data)
labels = hv.Labels(
    {("x", "y"): data, "text": [chr(65 + i) for i in range(10)]}, ["x", "y"], "text"
)
overlay = (points * labels).redim.range(x=(-0.2, 1.2), y=(-0.2, 1.2))

overlay.opts(
    hv.opts.Labels(text_font_size="10px", xoffset=0.08, yoffset=0.5),
    hv.opts.Points(color="black", size=5),
)

In [None]:
# First 1000 rows of trenches2, to keep the interactive plot responsive.
trenches_s = trenches2[:1000]

In [None]:
# NOTE(review): `range_stream` is only created in a later cell; run that
# first or this raises NameError.
range_stream.x_range

In [None]:
ls

In [None]:
ls[(1000, 2000), (0, 2000)]

In [None]:
# One red, unfilled rectangle per trench, passed as (x0, y0, x1, y1)
# = (ul_x, lr_y, lr_x, ul_y).
bbox_plot = hv.Rectangles(
    (
        trenches_s["ul_x"],
        trenches_s["lr_y"],
        trenches_s["lr_x"],
        trenches_s["ul_y"],
    )
).opts(fill_color=None, line_color="red")

# Text label (the trench's index value) at each box's upper-left corner.
# NOTE(review): the name `ls` shadows IPython's `ls` automagic. Before this
# cell runs, the bare `ls` cell above may list the directory instead of
# showing these labels.
ls = hv.Labels(
    (trenches_s["ul_x"], trenches_s["ul_y"], trenches_s.index.values.astype(str))
).opts(text_color="black", text_font_size="10pt", xoffset=3, yoffset=3)


def filter_points(points, x_range, y_range):
    """Crop ``points`` to the currently visible x/y range.

    Intended as an ``Element.apply`` callback driven by a ``RangeXY`` stream.
    If either range is ``None`` (the stream has not reported a viewport yet),
    ``points`` is returned untouched. Otherwise the result of
    ``points[x_range, y_range]`` is returned.
    """
    have_both_ranges = x_range is not None and y_range is not None
    if not have_both_ranges:
        return points
    return points[x_range, y_range]


def hover_points(points, threshold=20):
    """Suppress labels while too many would be drawn at once.

    When ``points`` holds more than ``threshold`` items, an empty selection
    (``points.iloc[:0]``) is returned, so labels only appear once the view is
    zoomed in enough. Otherwise ``points`` itself is returned unchanged.
    """
    crowded = len(points) > threshold
    return points.iloc[:0] if crowded else points


# Track the visible x/y range of the bounding-box plot; its x_range/y_range
# values are passed as keyword arguments to the filter_points callback.
range_stream = hv.streams.RangeXY(source=bbox_plot)
streams = [range_stream]

# Labels cropped to the current viewport ...
filtered = ls.apply(filter_points, streams=streams)
# shaded = datashade(filtered, width=400, height=400, streams=streams)
# ... and hidden entirely while more than hover_points' threshold are visible.
hover = filtered.apply(hover_points)

bbox_plot * hover

In [None]:
# Uses the package-private bbox plot helper directly.
trench_detection.core._trench_bbox_plot(trenches2)

In [None]:
# Boxes overlaid on the dewarped frame. NOTE(review): trenches2 was detected
# on the raw `img`, not on `img_t`, so the overlay may be slightly offset.
ui.RevImage(img_t) * trench_detection.core._trench_bbox_plot(trenches2)

In [None]:
diag2["labeling"]["set_finding"]["image_with_lines"]

In [None]:
%%time
# Quick run on the top-left 500x500 crop with join_info=True.
r3 = trench_detection.find_trenches(img[:500, :500], join_info=True)

In [None]:
r3

In [None]:
# NOTE(review): `r` and `r2` are not assigned anywhere in this notebook;
# these cells raise NameError as written.
len(r)

In [None]:
r

In [None]:
r2

In [None]:
r2[1]

In [None]:
# Feed messages from the streaming ND2 reader into a pipeline. slice(1) for
# v and t presumably restricts it to the first position and timepoint.
# NOTE(review): `new`, `handle_message` and `pipeline` are not defined in this
# notebook — `new` may come from the star import; confirm before running.
for msg in new.readers.send_nd2(
    filename,
    slices=dict(v=slice(1), t=slice(1)),
):
    handle_message(pipeline, msg)