# Imports

In [None]:
import glob
import itertools as it
import os
from functools import partial
from pathlib import Path

import dask
import distributed
import h5py
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import scipy
import skimage.measure
import zarr
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress
from holoviews.operation.datashader import regrid
from tqdm.auto import tqdm, trange

# Shorthand for pandas MultiIndex slicing, e.g. df.loc[IDX[a, :], :].
# NOTE(review): not used by any of the visible cells.
IDX = pd.IndexSlice

In [None]:
# TODO: does this help?
# %config InlineBackend.figure_format = "jpg"

In [None]:
%load_ext autoreload
%autoreload 2
# Reload edited modules (e.g. paulssonlab.image_analysis) before each cell
# runs, so library changes are picked up without restarting the kernel.

In [None]:
import paulssonlab.image_analysis.calibration.distortion as distortion
# The star import presumably provides the `image`, `workflow`, `geometry`,
# `util` and `trench_detection` names used below (none of them is imported
# explicitly) — NOTE(review): confirm; explicit imports would be clearer.
from paulssonlab.image_analysis import *
from paulssonlab.image_analysis.ui import display_image

In [None]:
%load_ext pyinstrument
# Enables the %pyinstrument / %%pyinstrument profiling magics.

In [None]:
# Render HoloViews / hvplot output with the Bokeh backend.
hv.extension("bokeh")

# Config

In [None]:
# Dask cluster whose workers are SLURM jobs. Each job runs one single-core
# worker process with 2 GB of memory on the "short" partition (6 h walltime).
# Worker scratch space is /tmp and job logs go to /home/jqs1/log.
cluster = SLURMCluster(
    queue="short",
    walltime="06:00:00",
    memory="2GB",
    local_directory="/tmp",
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)
# NOTE(review): no visible cell submits work through `client` / dask.

In [None]:
# Display the cluster object to inspect its status.
cluster

In [None]:
# Start a single worker job.
cluster.scale(1)

In [None]:
# Switch to adaptive scaling with at most 20 worker jobs.
cluster.adapt(maximum=20)

# Geometric distortion

## Image correction is equivalent to coördinate correction

In [None]:
# Calibration acquisition used to check the radial-distortion correction.
# The commented-out line is the Cy5 acquisition from the same date.
filename = "/home/jqs1/scratch/jqs1/microscopy/230728/calibration/230728_ultrarainbow_64fov_cfp.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230728/calibration/230728_ultrarainbow_64fov_cy5.nd2"
nd2 = nd2reader.ND2Reader(filename)

In [None]:
%%time
# Consistency check: correcting the detected puncta coordinates directly should
# land them where puncta are detected in the distortion-corrected image.
k1 = 7e-10
# NOTE(review): fov1, t, z_idx and channel_idx are not assigned in any visible
# cell (later cells read frames with v=0, t=0, z=0, c=0); define them first.
img_distorted = nd2.get_frame_2D(v=fov1, t=t, z=z_idx, c=channel_idx)
# Path 1: detect puncta in the raw frame, then correct their coordinates.
coords_distorted = distortion.df_to_coords(distortion.find_puncta(img_distorted))
input_shape = img_distorted.shape
input_center = image.center_from_shape(input_shape)
output_shape, output_center = image.radial_distortion_output(
    k1, input_shape, input_center
)
# output_shape not used: only output_center is needed to map coordinates
coords_distorted_corrected = image.radial_distortion(
    coords_distorted, k1=k1, input_center=input_center, output_center=output_center
)
# Path 2: correct the whole image, then detect puncta in it.
# NOTE(review): no input_center is passed here, so this relies on
# correct_radial_distortion's default center matching center_from_shape above.
img_corrected = image.correct_radial_distortion(img_distorted, k1=k1)
coords_corrected = distortion.df_to_coords(distortion.find_puncta(img_corrected))

In [None]:
# Pair each punctum found in the corrected image with its nearest
# coordinate-corrected punctum (presumably returns distances plus indices
# into the second argument).
correspondence_dists, correspondence_idxs = distortion.nearest_neighbors(
    coords_corrected, coords_distorted_corrected
)

In [None]:
# If the two correction paths agree, distances should concentrate near zero.
plt.hist(correspondence_dists, bins=100, log=True);

In [None]:
# Overlay both point sets on the corrected image for visual comparison.
plt.figure(figsize=(30, 30), dpi=200)
distortion.plot_puncta(
    img=img_corrected,
    coords=coords_corrected,
    coords2=coords_distorted_corrected,
    scale=0.99,
)

## Optimization

In [None]:
# Calibration acquisition used for the distortion fit; keep exactly one
# `filename` assignment uncommented.
# filename = "/home/jqs1/scratch/jqs1/microscopy/230726/calibration/230726_ultrarainbow_40x40_Cy5-EM.nd2"
filename = "/home/jqs1/scratch/jqs1/microscopy/230728/calibration/230728_ultrarainbow_64fov_cfp.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230728/calibration/230728_ultrarainbow_64fov_cy5.nd2"
# filename = "/home/jqs1/scratch/jqs1/microscopy/230726/calibration/230726_ultrarainbow_40x40_zstack_nocy7_fov1.nd2"
nd2 = nd2reader.ND2Reader(filename)

In [None]:
%%time
# Consecutive FOV pairs; range(2) yields the single pair (0, 1).
# (itertools.pairwise requires Python 3.10+.)
fov_pairs = list(it.pairwise(range(2)))
# Private helper of the distortion module; tqdm only adds a progress bar.
prepared = distortion._prepare_optimize_correction(nd2, fov_pairs=tqdm(fov_pairs))

In [None]:
%%time
# Passes the precomputed `prepared`, presumably so the fit does not redo it.
res = distortion.optimize_radial_distortion_correction(
    nd2, fov_pairs=fov_pairs, prepared=prepared
)

In [None]:
# Structure of `prepared` as used here:
# - prepared[0] is indexed by FOV (point coordinates).
# - prepared[1] is indexed by FOV pair (transform).
# NOTE(review): inferred from this indexing only; confirm against
# _prepare_optimize_correction.
coords1_distorted = prepared[0][0]
coords2_distorted = prepared[0][1]
transform = prepared[1][(0, 1)]

In [None]:
%%time
# Match puncta between FOV 0 and FOV 1:
# 1. Map FOV 1 points into FOV 0's frame with transform.inverse.
# 2. Take each FOV 0 point's nearest neighbor.
# 3. Keep pairs closer than 5 (pixels, presumably).
# k1 and input_center are only used by the commented-out correction lines,
# so the correspondences are computed on uncorrected coordinates.
k1 = 0  # 1e-9
input_center = image.center_from_shape((nd2.sizes["y"], nd2.sizes["x"]))
coords1 = coords1_distorted
coords2 = coords2_distorted
# coords1 = image.radial_distortion(coords1_distorted, k1=k1, input_center=input_center)
# coords2 = image.radial_distortion(coords2_distorted, k1=k1, input_center=input_center)
coords2_transformed = transform.inverse(coords2)
correspondence_dists, correspondence_idxs = distortion.nearest_neighbors(
    coords1, coords2_transformed
)
# NOTE(review): the cutoff 5 is repeated in the second histogram cell below;
# keep the two in sync.
correspondence_mask = correspondence_dists < 5
coords1_correspondence = coords1[correspondence_mask]
# Matched FOV 1 points are kept in their own (untransformed) frame,
# row-aligned with coords1_correspondence.
coords2_correspondence = coords2[correspondence_idxs][correspondence_mask]
# With this estimate commented out, `transform` keeps the parameters from `prepared`.
# transform.estimate(coords1_correspondence, coords2_correspondence)

In [None]:
plt.hist(correspondence_dists, bins=100, log=True);

In [None]:
# Same histogram restricted to the accepted matches (distance < 5).
plt.hist(correspondence_dists[correspondence_dists < 5], bins=100, log=True);

In [None]:
def correction_func(coords, params):
    """Correct ``coords`` for radial distortion using a flat parameter vector.

    ``params`` is ``(k1, center_x, center_y)``; the grid search in this
    notebook builds it as ``(k1, x, y)``. The same center is supplied as both
    the input and the output center of ``image.radial_distortion``.

    Arguments are forwarded positionally, presumably in the order
    ``(coords, k1, input_center, output_center)`` — NOTE(review): confirm
    against ``image.radial_distortion``'s signature.
    """
    k1, center = params[0], (params[1], params[2])
    return image.radial_distortion(coords, k1, center, center)


def objective_func(params, correction_func, coords1_all, coords2_all, transform):
    # se = 0
    residuals = []
    for coords1, coords2 in zip(coords1_all, coords2_all):
        coords1_corrected = correction_func(coords1, params)
        coords2_corrected = correction_func(coords2, params)
        transform.estimate(coords1_corrected, coords2_corrected)
        # se += ((coords1_corrected - transform.inverse(coords2_corrected)) ** 2).sum()
        residuals.append(
            ((coords1_corrected - transform.inverse(coords2_corrected)) ** 2).sum(
                axis=-1
            )
        )
    rmse = np.median(np.concatenate(residuals))
    # rmse = np.sqrt(se / len(coords1_all))
    # print(params, rmse)
    return rmse

In [None]:
# Default distortion center for this frame size, for reference against the grid below.
image.center_from_shape((nd2.sizes["y"], nd2.sizes["x"]))

In [None]:
# Brute-force grid search of objective_func using the FOV 0/1 correspondences
# matched above:
# - k1: 30 values in [0, 1e-8].
# - Distortion center: a 10 x 10 grid spanning the frame.
# Result layout: obj[k1_idx, y_idx, x_idx].
# k1s = [0]
k1s = np.linspace(0, 1e-8, 30)
# k1 = 5.199999999997459e-10
# xs = [2527.5]
# ys = [1479.5]
xs = np.linspace(0, nd2.sizes["x"] - 1, 10)
ys = np.linspace(0, nd2.sizes["y"] - 1, 10)
obj = np.zeros((len(k1s), len(ys), len(xs)))
for x_idx, x in enumerate(tqdm(xs)):
    for y_idx, y in enumerate(ys):
        for k1_idx, k1 in enumerate(k1s):
            params = (k1, x, y)
            obj[k1_idx, y_idx, x_idx] = objective_func(
                params,
                correction_func,
                [coords1_correspondence],
                [coords2_correspondence],
                # skimage.transform.EuclideanTransform(),
                # A fresh transform per call. objective_func re-estimates it
                # before use, so the initial translation presumably does not
                # affect the score.
                # NOTE(review): skimage.transform is not imported explicitly
                # (only skimage.measure is).
                skimage.transform.EuclideanTransform(translation=(-47, 0)),
            )

In [None]:
# Objective vs k1 at grid centers (y_idx=5, x_idx=5) and (y_idx=9, x_idx=5).
plt.plot(obj[:, 5, 5])
plt.plot(obj[:, 9, 5]);

In [None]:
# Objective over the (y, x) center grid at the second k1 value, k1s[1].
plt.imshow(obj[1]);

In [None]:
argmin = np.unravel_index(obj.argmin(), obj.shape)
argmin

In [None]:
tuple(vals[idx] for vals, idx in zip((k1s, xs, ys), argmin))

In [None]:
# Same two curves as the earlier plot cell, drawn in the opposite order.
plt.plot(obj[:, 9, 5])
plt.plot(obj[:, 5, 5]);

In [None]:
nd2.sizes

In [None]:
# Frame array shape, to compare against nd2.sizes (x/y ordering).
nd2.get_frame_2D().shape

In [None]:
# Sketch of the intended alternating optimization:
# INPUT: use stage x/y to guess initial translations/rotations (0) for each FOV pair

# LOOP until converged:
# FOR EACH FOV PAIR: compute nearest neighbor correspondences
# FOR EACH FOV PAIR: optimize translation/rotation for fixed coordinates
# FOR ALL FOV PAIRS JOINTLY: optimize distortion params

In [None]:
# NOTE(review): `tform` is never assigned in the visible cells (the transform
# loaded from `prepared` is named `transform`); define it before running these.
(tform.translation, tform.rotation)

In [None]:
# TODO: is this right or off-by-one?
# center = (np.array(shape[::-1]) - 1) / 2
# center=None presumably lets radial_distortion pick its default center — confirm.
center = None
k1 = 8.947368421052635e-10
coords1_undistorted = image.radial_distortion(
    coords1_correspondence, input_center=center, k1=k1
)
# NOTE(review): this applies the transform *before* correction,
# i.e. correct(inverse(coords2)). objective_func and the k1 sweep below use
# inverse(correct(coords2)) instead; in general these differ for k1 != 0.
coords2_undistorted = image.radial_distortion(
    tform.inverse(coords2_correspondence), input_center=center, k1=k1
)

In [None]:
# RMS over all x and y components (mean of N*2 squared entries), which equals
# the RMS point-to-point distance divided by sqrt(2).
np.sqrt(((coords1_undistorted - coords2_undistorted) ** 2).mean())

In [None]:
coords1_correspondence

In [None]:
coords1_undistorted

In [None]:
# NOTE(review): vectorfield_difference has no visible definition or explicit
# import; presumably it comes from the star import — confirm.
vectorfield_difference(coords1_undistorted, coords2_undistorted)

In [None]:
vectorfield_difference(coords1_correspondence, coords1_undistorted)

In [None]:
# TODO: is this right or off-by-one?
# 1D sweep of k1 (about 2000 values in [-1e-8, 1e-8)) at the fixed `center`.
# Same per-component RMS as above.
k1s = np.arange(-1e-8, 1e-8, 1e-11)
rmses = []
for k1 in tqdm(k1s):
    coords1_undistorted = image.radial_distortion(
        coords1_correspondence, input_center=center, k1=k1
    )
    coords2_undistorted = tform.inverse(
        image.radial_distortion(coords2_correspondence, input_center=center, k1=k1)
    )
    rms = np.sqrt(((coords1_undistorted - coords2_undistorted) ** 2).mean())
    rmses.append(rms)

In [None]:
plt.plot(k1s, rmses);

In [None]:
# k1 with the lowest RMS in the sweep.
k1s[np.argmin(rmses)]

In [None]:
# (duplicate of the previous cell)
k1s[np.argmin(rmses)]

In [None]:
# NOTE(review): 5.2e-8 is 100x the 5.199999999997459e-10 recorded in other
# cells (same mantissa). It is either a deliberate exaggeration to make the
# effect visible or a typo — confirm.
# Naming here is inverted relative to the rest of the notebook: the raw frame
# is called img_undistorted and the corrected output img_distorted.
k1 = 5.199999999997459e-8
img_undistorted = nd2.get_frame_2D(v=0, t=0, z=0, c=0)
img_distorted = image.correct_radial_distortion(img_undistorted, k1=k1)

In [None]:
%%time
# Same correction with an off-center distortion center at (500, 500).
input_center = (500, 500)
img_distorted2 = image.correct_radial_distortion(
    img_undistorted, k1=k1, input_center=input_center
)

In [None]:
img_distorted2

In [None]:
display_image(img_distorted2, scale=0.99)

In [None]:
display_image(img_distorted, scale=0.99)

In [None]:
display_image(img_undistorted, scale=0.99)

In [None]:
# Re-reads the same raw frame as the first cell of this group.
img_undistorted = nd2.get_frame_2D(v=0, t=0, z=0, c=0)

In [None]:
img_undistorted.shape

In [None]:
# Grid search over k1 (30 values in [1e-10, 3e-9]) and the distortion center
# (10 x 10 grid spanning the frame). Result layout: rmses[k1_idx, y_idx, x_idx],
# using the same per-component RMS as the 1D sweep.
# Overwrites `k1s`, `rmses` and `center` from earlier cells and needs `tform`
# (not assigned in the visible cells).
k1s = np.linspace(1e-10, 3e-9, 30)
# k1 = 5.199999999997459e-10
xs = np.linspace(0, img_undistorted.shape[1] - 1, 10)
ys = np.linspace(0, img_undistorted.shape[0] - 1, 10)
rmses = np.zeros((len(k1s), len(ys), len(xs)))
for x_idx, x in enumerate(tqdm(xs)):
    for y_idx, y in enumerate(ys):
        for k1_idx, k1 in enumerate(k1s):
            center = (x, y)
            coords1_undistorted = image.radial_distortion(
                coords1_correspondence, input_center=center, k1=k1
            )
            coords2_undistorted = tform.inverse(
                image.radial_distortion(
                    coords2_correspondence, input_center=center, k1=k1
                )
            )
            rms = np.sqrt(((coords1_undistorted - coords2_undistorted) ** 2).mean())
            rmses[k1_idx, y_idx, x_idx] = rms

In [None]:
# good: [2527.5, 1979.5]
# NOTE(review): the objective_func sweep records ys = [1479.5] alongside
# xs = [2527.5]; one of the two y values is likely a typo — confirm.

In [None]:
(xs[5], ys[5])

In [None]:
k1s[4]

In [None]:
# RMS vs k1 at grid centers (y_idx=5, x_idx=5) and (y_idx=4, x_idx=4).
plt.plot(rmses[:, 5, 5])
plt.plot(rmses[:, 4, 4]);

In [None]:
# RMS vs y grid index at the first k1 and x_idx=5.
plt.plot(rmses[0, :, 5]);

In [None]:
# RMS over the (y, x) center grid at k1s[2].
plt.imshow(rmses[2]);

# Trench detection

In [None]:
# Trench width / pitch ratio passed to find_trenches (1.4 and 3.5, presumably µm).
width_to_pitch_ratio = 1.4 / 3.5
segmentation_channel = "RFP-EM"
# The acquisition is split across several *.nd2.split.* files. sorted() gives
# SplitFilename a deterministic (lexicographic) order.
# NOTE(review): this rebinds `filename`, while `nd2` still refers to the
# calibration reader opened earlier.
filename = workflow.SplitFilename(
    sorted(
        glob.glob(
            # "/home/jqs1/scratch/jqs1/microscopy/230707/230707_repressilators_restart.nd2.split.a*"
            "/home/jqs1/scratch/jqs1/microscopy/230830/230830_repressilators.nd2.split.*"
        )
    )
)

In [None]:
def get_frame_func(
    filename, position, channel, t, k1=0, center=None, dark=None, flat=None
):
    """Load one frame and correct it for radial lens distortion.

    The frame at (``position``, ``channel``, ``t``) is read with
    ``workflow.get_nd2_frame``, which also receives ``dark`` and ``flat``
    (presumably dark-frame and flat-field references). The frame is then
    converted to a NumPy array and passed to
    ``image.correct_radial_distortion`` with coefficient ``k1``. ``center`` is
    forwarded as ``input_center``; ``None`` presumably selects that function's
    default center.
    """
    raw_frame = workflow.get_nd2_frame(
        filename, position=position, channel=channel, t=t, dark=dark, flat=flat
    )
    corrected = image.correct_radial_distortion(
        np.asarray(raw_frame), k1=k1, input_center=center
    )
    # No cropping is applied; earlier experiments cropped to e.g.
    # [600:2400, 1500:3500] or [550:2350, 1500:3500].
    return corrected

In [None]:
# NOTE(review): img0 is only assigned in a later cell of this section; this
# display relies on out-of-order execution.
display_image(img0, scale=0.99, downsample=4)

In [None]:
# NOTE(review): this reads from `nd2`, the calibration reader, not from the
# trench `filename` set above (the trailing comment shows that alternative).
# Only the shape is used (to derive `center`), so it is valid only if both
# acquisitions share the same frame size — confirm.
img_uncorrected = (
    nd2.get_frame_2D()
)  # get_frame_func(filename, 258, segmentation_channel, 0, k1=0)

In [None]:
img_uncorrected.shape

In [None]:
display_image(img0, downsample=1, scale=0.99)

In [None]:
# Subtracting [0, -500] *adds* 500 to the second component of the frame
# center. Which axis that is depends on center_from_shape's ordering —
# NOTE(review): confirm, and consider writing `+ np.array([0, 500])`.
center = image.center_from_shape(img_uncorrected.shape) - np.array([0, -500])

In [None]:
center

In [None]:
%%time
# Load position 0, t=0 of the segmentation channel with distortion correction
# (k1 and the shifted `center`).
# k1 = 5.199999999997459e-10
# k1 = 7e-10
# k1 = 6.8e-10
k1 = 1e-9
img0 = get_frame_func(filename, 0, segmentation_channel, 0, k1=k1, center=center)
image_limits = geometry.get_image_limits(img0.shape)

In [None]:
%%time
# Detect trenches in img0. `diag` (from util.tree(), presumably an
# auto-vivifying nested dict) collects intermediate diagnostics that the cells
# below display.
diag = util.tree()
rois, info = trench_detection.find_trenches(
    img0,
    width_to_pitch_ratio=width_to_pitch_ratio,
    join_info=False,
    diagnostics=diag,
)
angle = info["angle"]
pitch = info["pitch"]

In [None]:
# The "# k1=..." comments in the following cells label the correction settings
# each cell was run with. Values like -500 / +1000 presumably refer to the
# center offset used when computing `center`. The outputs themselves are not
# part of this export.
# k1=
(angle, pitch)

In [None]:
# k1=5e-10
(angle, pitch)

In [None]:
# k1=0
(angle, pitch)

In [None]:
# k1=1e-9, -500
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
# k1=1e-9, -500
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
# k1=1e-9, +1000
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
# k1=1e-9, -1000
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
# k1=1e-10
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
# k1=8e-10
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
# k1=
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
# k1=0
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
diag["bboxes"]

In [None]:
# NOTE(review): no image, ROIs or diagnostics are passed here — check
# plot_trenches' signature; this call may be incomplete.
trench_detection.plot_trenches(bboxes=False, lines=True)

In [None]:
# The next five cells are identical; presumably re-run after changing k1/center.
diag["labeling"]["find_periodic_lines"]["profile"]  # .keys()

In [None]:
diag["labeling"]["find_periodic_lines"]["profile"]  # .keys()

In [None]:
diag["labeling"]["find_periodic_lines"]["profile"]  # .keys()

In [None]:
diag["labeling"]["find_periodic_lines"]["profile"]  # .keys()

In [None]:
diag["labeling"]["find_periodic_lines"]["profile"]  # .keys()

In [None]:
(angle, pitch)

In [None]:
diag["labeling"]["find_periodic_lines"]["peak_func"]["spectrum"]

In [None]:
diag["labeling"]["find_periodic_lines"]["profile"]  # .keys()

In [None]:
diag["labeling"]["set_finding"]["image_with_lines"]

In [None]:
diag["bboxes"]

In [None]:
diag["labeling"]["set_finding"]["profiles"]