# Imports

In [None]:
import itertools as it
import os
import re
from collections import namedtuple
from functools import partial
from pathlib import Path

import dask
import distributed
import h5py
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import scipy
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from tqdm.auto import tqdm

IDX = pd.IndexSlice

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.image_analysis.new as new
from paulssonlab.image_analysis import *
from paulssonlab.util.ui import display_image

In [None]:
# %load_ext pyinstrument

In [None]:
hv.extension("bokeh")

# Config

In [None]:
# Configure a dask-jobqueue SLURM cluster. Each job is requested on the "short"
# queue with a 6-hour walltime, 2GB of memory, one core and one worker process.
# NOTE: local_directory and log_directory are user/site-specific paths.
cluster = SLURMCluster(
    queue="short",
    walltime="06:00:00",
    memory="2GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
# Attach a distributed client to the cluster so later dask work runs on it.
client = Client(cluster)

In [None]:
# Display the cluster object in the notebook.
cluster

In [None]:
# Request a single worker to start with.
cluster.scale(1)

In [None]:
# Switch to adaptive scaling with an upper bound of 20.
cluster.adapt(maximum=20)

# Trench detection

In [None]:
# Path of the ND2 acquisition to analyze. Earlier datasets are kept commented
# out for quick switching; the trailing "#v=..,t=.." notes presumably record the
# frame indices that were used with each of them (TODO confirm).
# filename = "/home/jqs1/scratch/jqs1/microscopy/230213/230213induction.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230215/230215induction.nd2" #v=7
# filename = "/home/jqs1/scratch/jqs1/microscopy/230326/230326promrbs.nd2" #v=8,t=10
filename = "/home/jqs1/scratch/jqs1/microscopy/230404/230404_rbsprom.nd2"

In [None]:
# Open the ND2 file. The reader is left open for the rest of the notebook.
nd2 = nd2reader.ND2Reader(filename)

In [None]:
# List the channel names recorded in the file metadata.
nd2.metadata["channels"]

In [None]:
# Single 2D frame used by every cell below: v=8, c=0, t=10.
img = nd2.get_frame_2D(v=8, c=0, t=10)

In [None]:
# Show the frame normalized by its maximum, with a display gain of 4.
display_image(img / img.max() * 4)

In [None]:
%%time
# Run trench detection on the frame; intermediate results are written into
# `diag` (util.tree() is presumably an auto-nesting dict -- TODO confirm).
diag = util.tree()
trenches = trench_detection.find_trenches(img, diagnostics=diag)

In [None]:
%%time
# Same detection with peak_func passed explicitly, for comparison against the
# run above (which uses whatever find_trenches' default peak_func is).
diag2 = util.tree()
trenches2 = trench_detection.find_trenches(
    img, peak_func=trench_detection.peaks.find_peaks, diagnostics=diag2
)

In [None]:
trenches

In [None]:
# Compare the "image_with_trenches" diagnostic of the two runs (this cell and
# the next one).
diag["label_1"]["find_trench_ends"]["image_with_trenches"]

In [None]:
diag2["label_1"]["find_trench_ends"]["image_with_trenches"]

In [None]:
# Inspect what the peak-finding step recorded for the first run.
diag["label_1"]["find_trench_lines"]["hough_0"]["peak_func"].keys()

In [None]:
diag["label_1"]["find_trench_lines"]["hough_0"]["peak_func"]["spectrum"]

In [None]:
diag["labeling"].keys()

In [None]:
%%time
# Run only the binarization step, with its own diagnostics container.
diag3 = util.tree()
img_bin = trench_detection.set_finding.binarize_trench_image(img, diagnostics=diag3)

In [None]:
img_bin

In [None]:
# Pixel-intensity histogram of the frame (300 bins, log-scaled counts); the
# trailing semicolon suppresses the notebook's echo of the return value.
plt.hist(img.flat, bins=300, log=True);

In [None]:
from paulssonlab.image_analysis.image import (
    gaussian_box_approximation,
    normalize_componentwise,
    remove_large_objects,
)

In [None]:
# Low-pass version of the frame with radius 500; it is subtracted from img in
# display cells further down. (gaussian_box_approximation is a project helper;
# presumably a box-filter approximation of a Gaussian blur -- TODO confirm.)
lowpass_radius = 500
img_lowpass = gaussian_box_approximation(img, lowpass_radius)

In [None]:
0

In [None]:
%%time
# Alternative background estimate: rolling ball with radius 30.
rb = skimage.restoration.rolling_ball(img, radius=30)

In [None]:
display_image(rb, scale=True)

In [None]:
# Frame minus the rolling-ball background, normalized by the frame maximum and
# shown with a display gain of 20.
display_image((img - rb) / img.max() * 20)

In [None]:
?skimage.filters.threshold_sauvola

In [None]:
# Binarize against a Sauvola threshold computed with window_size=7.
display_image(img > skimage.filters.threshold_sauvola(img, window_size=7))

In [None]:
# Binarize against the global Otsu threshold. threshold_otsu must be *called*
# on the image (it returns a scalar threshold) and the display_image call closed.
display_image(img > skimage.filters.threshold_otsu(img))

In [None]:
# Frame normalized by its maximum, with a display gain of 30.
display_image(img / img.max() * 30)

In [None]:
# Frame with the low-pass background (img_lowpass) subtracted.
display_image(img - img_lowpass, scale=True)

In [None]:
# (Duplicate of the previous cell.)
display_image(img - img_lowpass, scale=True)

In [None]:
# Second element of the binarize_trench_image result.
display_image(img_bin[1])

In [None]:
# Inspect the binarization diagnostics recorded by the find_trenches run.
diag["labeling"]["binarize_trench_image"].keys()

In [None]:
diag["labeling"]["binarize_trench_image"]["thresholded_image"]

In [None]:
diag["label_1"]["find_trench_lines"]["hough_0"]["peak_func"]["refined_points"]

In [None]:
diag["label_2"]["find_trench_ends"].keys()

In [None]:
??trench_detection.hough.find_periodic_lines

In [None]:
%%time
# NOTE: img_t is assigned by the skimage.transform.warp cell in the "Dewarping"
# section further down; that cell has to be executed before this one.
# Hough transform evaluated at 1000 angles evenly spaced in [-pi/5, pi/5].
h, theta, rho = trench_detection.hough.hough_line_intensity(
    img_t, theta=np.linspace(-np.pi / 5, np.pi / 5, 1000)
)

In [None]:
display_image(h, scale=True)

In [None]:
%%time
# Angle profile of the Hough accumulator `h`: first difference along axis 1,
# then the standard deviation over axis 0, optionally Gaussian-smoothed.
# `smooth` is the sigma handed to gaussian_filter1d; 0/None disables smoothing.
smooth = 4
diff_h = np.diff(h.astype(np.int_), axis=1)  # TODO: is diff necessary??
diff_h_std = diff_h.std(axis=0)  # / diff_h.max(axis=0)
if smooth:
    diff_h_std_smoothed = scipy.ndimage.gaussian_filter1d(diff_h_std, smooth)
else:
    # Keep the name bound when smoothing is turned off, so the plotting cell
    # below neither fails nor silently reuses a stale value from an earlier run
    # (same fallback as the k1 grid-search loop later in this notebook).
    diff_h_std_smoothed = diff_h_std

In [None]:
# Plot the (smoothed) angle profile computed in the previous cell.
plt.plot(diff_h_std_smoothed)

In [None]:
h.shape

In [None]:
h.shape

In [None]:
# NOTE(review): `imor` is not defined anywhere in the visible notebook; unless
# the wildcard import above provides it, this cell raises NameError. It looks
# like a stray, unfinished input -- confirm and remove.
imor

In [None]:
scipy.signal.periodogram

In [None]:
%%time
# Periodogram of the Hough accumulator along axis 0 (Hann window, zero-padded
# FFT of length nfft), computed independently for every column of h.
nfft = 2**16
max_period = None
freqs, spectrum = scipy.signal.periodogram(
    h, window="hann", nfft=nfft, scaling="spectrum", axis=0
)
# NOTE(review): max_period is used here as a slice bound on the spectrum's
# first axis, i.e. it counts frequency bins rather than a period -- confirm the
# intended units. With max_period = None this branch is skipped.
if max_period:
    spectrum[:max_period] = 0
# pitch_idx = spectrum.argmax()
# pitch = 1 / freqs[pitch_idx]

In [None]:
spectrum.shape

In [None]:
# Spectrum normalized by its maximum, with a display gain of 5.
display_image(spectrum / spectrum.max() * 5)

In [None]:
h[0].shape

In [None]:
# 5 degrees expressed as a fraction of pi.
np.deg2rad(5) / np.pi

# Dewarping

In [None]:
# Frame normalized by its maximum, display gain 5, downsampled by 4.
display_image(img / img.max() * 5, downsample=4)

In [None]:
img.shape

In [None]:
# The values 0 and 5055 shifted by -5055/2, i.e. made symmetric about zero.
np.array([0, 5055]) - 5055 / 2

In [None]:
# NOTE: quadratic_root is defined in the *next* cell, so that cell must be run
# first. NOTE(review): with a=0 and b>0 the denominator -b + sqrt(b**2) cancels
# to (numerically) zero, so expect inf rather than the linear root -c/b.
quadratic_root(0, 1 / 8574218.5, 1)

In [None]:
# SEE: https://math.stackexchange.com/a/2007723
def quadratic_root(a, b, c):
    """Return the root ``2c / (-b + sqrt(b**2 - 4ac))`` of ``a*x**2 + b*x + c = 0``.

    This is the rationalized form of the quadratic formula: it never divides
    by ``a``, so it stays finite for ``a == 0`` as long as the denominator is
    nonzero. Arguments may be scalars or broadcastable NumPy arrays. A negative
    discriminant produces NaN (from ``np.sqrt``), and a vanishing denominator
    (e.g. ``a == 0`` with ``b > 0``) produces inf.
    """
    discriminant = b**2 - 4 * a * c
    return 2 * c / (-b + np.sqrt(discriminant))


# SEE: http://www.cs.ait.ac.th/~mdailey/papers/Bukhari-RadialDistortion.pdf
def barrel_transform(coords, center=np.array((0, 0)), k1=0):
    """Map points through a one-parameter radial (barrel) distortion.

    For each point at radius ``r_u`` from ``center`` the new radius ``r_d``
    solves ``k1 * r_u * r_d**2 - r_d + r_u = 0``; the point is moved along its
    ray from ``center`` to that radius. The signature matches what
    ``skimage.transform.warp`` expects of a callable ``inverse_map`` (``center``
    and ``k1`` are supplied through ``map_args``).

    Parameters
    ----------
    coords : ndarray
        2-D array with one point per row, in the same axis order as ``center``.
    center : array-like, optional
        Center of the distortion. Defaults to the origin. Not modified.
    k1 : float, optional
        Distortion coefficient; ``0`` yields the identity mapping.

    Returns
    -------
    ndarray
        Transformed coordinates, same shape as ``coords``. Entries are NaN
        where ``1 - 4 * k1 * r_u**2`` is negative (no real solution).
    """
    coords_centered = coords - center
    r_u_squared = (coords_centered**2).sum(axis=1)[:, np.newaxis]
    # Solve for the ratio s = r_d / r_u directly: dividing the quadratic above
    # by r_u gives (k1 * r_u**2) * s**2 - s + 1 = 0. Computing r_d first and
    # then dividing by r_u evaluates 0/0 = NaN for a point exactly at `center`
    # (r_u == 0); in this form that point gets s == 1 and maps to itself.
    scale = quadratic_root(k1 * r_u_squared, -1, 1)
    return center + scale * coords_centered

In [None]:
%%time
# Distortion coefficient used for the dewarped frame. The active value appears
# to be entry 12 of the k1s grid scanned in the "Dewarping optimization"
# section below; the commented-out lines are earlier trial values.
# k1 = -5e-9
# k1 = 2e-9
# k1 = 1.5e-9
k1 = 8.947368421052635e-10
# k1 = 0
# Resample the frame through barrel_transform. The center is the middle of the
# image with img.shape reversed, i.e. ordered (last axis, first axis).
# preserve_range=True keeps the input intensity range rather than rescaling.
img_t = skimage.transform.warp(
    img,
    barrel_transform,
    map_args=dict(center=(np.array(img.shape)[::-1] - 1) / 2, k1=k1),
    preserve_range=True,
)

In [None]:
# Dewarped and original frame shown with identical display settings.
display_image(img_t / img_t.max() * 5, downsample=4)

In [None]:
display_image(img / img.max() * 5, downsample=4)

In [None]:
# Copy of the dewarped frame with any NaNs replaced by 0.
img_t2 = np.nan_to_num(img_t, nan=0)

In [None]:
# NaN-ignoring intensity extrema of the dewarped frame.
np.nanmax(img_t)

In [None]:
np.nanmin(img_t)

In [None]:
# Difference between the dewarped and the original frame.
display_image((img_t - img), scale=True, downsample=4)

# Dewarping optimization

In [None]:
# Grid search over the radial-distortion coefficient k1. For every candidate
# the frame is warped with barrel_transform, a Hough transform is taken over a
# narrow angle range, and the angle profile (std over axis 0 of the first
# difference of h along axis 1) is recorded together with its peak.
#
# NOTE: the loop rebinds the notebook-level names k1, h, theta and rho, so
# after it finishes they hold the values from the *last* candidate.
# NOTE(review): unlike the warp that produces img_t, this call does not pass
# preserve_range=True, so skimage may rescale the intensities before the
# Hough transform (and before the astype(np.int_) truncation) -- confirm that
# this difference is intended.
k1s = np.linspace(-1e-9, 2e-9, 20)

# Loop-invariant inputs, computed once rather than on every iteration.
warp_center = (np.array(img.shape)[::-1] - 1) / 2
hough_thetas = np.linspace(-np.pi / 50, np.pi / 50, 400)
smooth = 4  # sigma for gaussian_filter1d; 0/None disables smoothing

res = []
for k1 in tqdm(k1s):
    img_corrected = skimage.transform.warp(
        img,
        barrel_transform,
        map_args=dict(center=warp_center, k1=k1),
    )
    h, theta, rho = trench_detection.hough.hough_line_intensity(
        img_corrected, theta=hough_thetas
    )
    diff_h = np.diff(h.astype(np.int_), axis=1)  # TODO: is diff necessary??
    diff_h_std = diff_h.std(axis=0)  # / diff_h.max(axis=0)
    if smooth:
        diff_h_std_smoothed = scipy.ndimage.gaussian_filter1d(diff_h_std, smooth)
    else:
        diff_h_std_smoothed = diff_h_std
    theta_idx = diff_h_std_smoothed.argmax()
    diff_h_std_max = diff_h_std_smoothed[theta_idx]
    angle = theta[theta_idx]
    res.append(
        dict(
            k1=k1,
            h=h,
            diff_h=diff_h,
            diff_h_std=diff_h_std,
            diff_h_std_smoothed=diff_h_std_smoothed,
            # Peak height of the smoothed profile; previously computed but
            # discarded. Stored so candidates can be compared directly.
            diff_h_std_max=diff_h_std_max,
            angle=angle,
            theta_idx=theta_idx,
            img_corrected=img_corrected,
        )
    )

In [None]:
# Overlay a fixed window (samples 120:200) of the unsmoothed angle profile for
# every k1 candidate, labelled by its index in `res`.
plt.figure(figsize=(20, 20))
for idx, candidate in enumerate(res):
    plt.plot(candidate["diff_h_std"][120:200], label=idx)
# The legend is created before the highlighted curve is added, so the thick
# line for candidate 12 gets no legend entry of its own.
plt.legend()
plt.plot(res[12]["diff_h_std"][120:200], lw=4)

In [None]:
# Candidate 12 (thick line) against its neighbours two grid steps to either
# side, over the same 120:200 window of the unsmoothed angle profile.
plt.figure(figsize=(20, 20))
idx = 12
plt.plot(res[idx - 2]["diff_h_std"][120:200])
plt.plot(res[idx]["diff_h_std"][120:200], lw=4)
plt.plot(res[idx + 2]["diff_h_std"][120:200])

In [None]:
# k1 value of candidate 12.
res[12]["k1"]

In [None]:
# Warped frame for the third-from-last candidate.
display_image(res[-3]["img_corrected"], scale=True)

In [None]:
# Dewarped minus original frame at full resolution (img_t comes from the
# "Dewarping" section, not from the loop above).
display_image((img_t - img), scale=True, downsample=1)

In [None]:
k1s

In [None]:
img_t

In [None]:
np.asarray(img)

In [None]:
# Intensity maxima of the dewarped and the original frame.
img_t.max()

In [None]:
img.max()

In [None]:
%%time
# Re-run trench detection on the original frame (overwrites the earlier `diag`
# and `trenches`) ...
diag = util.tree()
trenches = trench_detection.find_trenches(img, diagnostics=diag)

In [None]:
%%time
# ... and on the dewarped frame, with a separate diagnostics container, so the
# two results can be compared in the cells below.
diag_t = util.tree()
trenches_t = trench_detection.find_trenches(img_t, diagnostics=diag_t)

In [None]:
diag_t["label_1"]["find_trench_ends"]["image_with_trenches"]

In [None]:
diag["label_1"]["find_trench_ends"]["image_with_trenches"]

In [None]:
diag["labeling"]["set_finding"].keys()

In [None]:
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
# NOTE(review): elsewhere in this notebook "find_trench_lines" is read from
# under "label_1", not "labeling" -- confirm this key path is the intended one.
diag["labeling"]["find_trench_lines"]["hough_0"]["profile"]

In [None]:
diag_t["labeling"]["set_finding"]["image_with_lines"]