# Select parameters to register fluorescence images

This is for acquisitions that have no rounds but multiple channels to align.

- `binarise_quantile`: The quantile used to binarise the moving image (set as
`threshold_quantile` in the parameters cell below). The default is `0.7`.

In [None]:
# Automatically reload imported modules (e.g. iss_preprocess) before executing
# each cell, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import iss_preprocess as iss
import iss_preprocess.io

data_path = "becalia_rabies_barseq/BRAC8498.3e/chamber_06/"
prefix = "mCherry_1"
ops = iss.io.load_ops(data_path)
# Tile to inspect. ops["ref_tile"] is the default, but it is overridden on the
# next line by a hard-coded tile for this dataset; delete that line to use the
# reference tile. NOTE(review): presumably [roi, x, y] — TODO confirm.
tile_coors = ops["ref_tile"]
tile_coors = [19, 5, 5]

## Look at current registration

This runs the registration with the current parameters from ops.

In [None]:
# Run the registration of the selected tile with the parameters currently in
# ops, keeping the debug output and without writing anything to disk.
import iss_preprocess.pipeline.register

register_tile = iss.pipeline.register.register_fluorescent_tile
reg_out, db_output = register_tile(
    data_path,
    tile_coors,
    prefix,
    save_output=False,
    debug=True,
    reference_prefix=None,
)

In [None]:
# Projection used for registration: "<name>_reg_projection" when present in
# ops, otherwise "<name>_projection" (which must exist either way).
ops_prefix = prefix.split("_")[0].lower()
default_projection = ops[f"{ops_prefix}_projection"]
projection = ops.get(f"{ops_prefix}_reg_projection", default_projection)
print(f"Using projection: {projection}")

# Load the tile and apply the current between-channel correction matrices
stack_ori = iss.io.load_tile_by_coors(
    data_path, tile_coors=tile_coors, suffix=projection, prefix=prefix
)
between_channels = reg_out["matrix_between_channels"]
corrected_hyb = iss.reg.rounds_and_channels.apply_corrections(
    stack_ori, matrix=between_channels, cval=0.0
)

In [None]:
# Plot the initial registration: raw channels (left) vs channels corrected
# with the current between-channel matrices (right).
colors = [(1, 0, 0), (0, 1, 0), (1, 0, 1), (0, 1, 1)]
# Per-channel contrast limits from the raw stack, shared by both panels
vmax = np.percentile(stack_ori, 99.9, axis=(0, 1))
vmin = np.percentile(stack_ori, 0.1, axis=(0, 1))
rgb_ori = iss.vis.to_rgb(stack_ori, colors, vmax=vmax, vmin=vmin)
rgb_reg = iss.vis.to_rgb(corrected_hyb, colors, vmax=vmax, vmin=vmin)

# Set zoom to True to look at a small region (pixel coordinates)
zoom = False
zoom_cols = (1000, 1500)
zoom_rows = (1000, 1500)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(rgb_ori)
axes[0].set_title("Before channel registration")
axes[1].imshow(rgb_reg)
axes[1].set_title("Registered")
if zoom:
    for ax in axes:
        ax.axis("off")
        ax.set_xlim(zoom_cols)
        # imshow puts row 0 at the top, so y limits must be (bottom, top),
        # i.e. decreasing; increasing limits show the crop upside down.
        ax.set_ylim(zoom_rows[1], zoom_rows[0])
fig.tight_layout()

# Change parameters

In [None]:
# Parameters to try. They are used by the preview cells below and passed to
# check_affine_channel_registration.
projection = "max-median"  # suffix of the projected tile to load
threshold_quantile = None  # passed as binarise_quantile; None = no binarisation in the preview plot
block_size = 512  # registration block size (presumably in pixels — TODO confirm)
overlap = 0.8  # block overlap (presumably a fraction of block_size — TODO confirm)
max_residual = 5  # passed as max_residual; NOTE(review): presumably the max residual allowed for block shifts — confirm

ref_ch = ops["ref_ch"]  # reference channel, shown in the median-filter preview

## Median filter

The images are first filtered with a median filter to remove noise. The filter is a
disk of radius `ops["reg_median_filter"]`. The size does not need to be large: it is
only used to remove dead pixels and some regular noise.

In [None]:
from scipy.ndimage import median_filter
from skimage.morphology import disk

stack_ori = iss.io.load_tile_by_coors(
    data_path, tile_coors=tile_coors, suffix=projection, prefix=prefix
)

# Median filter if requested in ops (None disables the filter)
median_filter_size = ops["reg_median_filter"]
if median_filter_size is not None:
    # Validate before doing any work. Raise rather than assert so the check
    # is not skipped when Python runs with -O.
    if not isinstance(median_filter_size, int):
        raise TypeError("reg_median_filter must be an integer")
    print(f"Filtering with median filter of size {median_filter_size}")
    # median_filter allocates its output and does not modify its input, so
    # stack_ori is left untouched without an explicit copy. The 2D disk
    # footprint is applied over the two spatial axes, i.e. per channel.
    stack = median_filter(
        stack_ori, footprint=disk(median_filter_size), axes=(0, 1)
    )
else:
    stack = stack_ori

# Compare raw and filtered versions of the top-left w x w corner of the
# reference channel, using the same contrast limits for both.
w = 512
crop = np.s_[:w, :w, ref_ch]
raw_crop = stack_ori[crop]
filtered_crop = stack[crop]

fig = plt.figure(figsize=(7, 5))
vmin, vmax = np.percentile(filtered_crop, [1, 99.9])
plt.subplot(1, 3, 1)
plt.imshow(raw_crop, vmin=vmin, vmax=vmax)
plt.title("Original")
plt.subplot(1, 3, 2)
plt.imshow(filtered_crop, vmin=vmin, vmax=vmax)
plt.title(f"Filtered ({median_filter_size} px disk)")

# Overlay: original in red, filtered in green
rgb = iss.vis.to_rgb(
    np.dstack([raw_crop, filtered_crop]),
    colors=[(1, 0, 0), (0, 1, 0)],
    vmin=(vmin, vmin),
    vmax=(vmax, vmax),
)
plt.subplot(1, 3, 3)
plt.imshow(rgb)
plt.title("Overlay")
for x in fig.axes:
    x.axis("off")

# Binarisation quantile

The binarisation is used to increase the importance of the background shared between
channels in the registration. It is applied to each block of the image independently,
but here, to get an idea, we plot the whole image binarised with a single threshold per
channel.

In [None]:
# Preview of the binarisation on one region, using a single threshold per
# channel computed on the whole (filtered) tile.
tl = (0, 0)  # top left (row, col) of the part to plot
w = 2024  # width of the part to plot
nch = stack.shape[2]
region = (slice(tl[0], tl[0] + w), slice(tl[1], tl[1] + w))
stack_bin = stack.copy()
fig, ax = plt.subplots()
if threshold_quantile is not None:
    for ich in range(nch):
        thresh = np.quantile(stack[:, :, ich], threshold_quantile)
        stack_bin[:, :, ich] = stack[:, :, ich] > thresh
    # Binary masks: the 0-1 range displays them directly
    display_vmin = [0] * nch
    display_vmax = [1] * nch
    ax.set_title(f"Binarised at quantile {threshold_quantile}")
else:
    # Raw intensities are not in [0, 1] in general, so a fixed 0-1 range
    # would saturate the image; use per-channel percentiles like the other
    # plots instead.
    display_vmin = np.percentile(stack_bin[region], 1, axis=(0, 1))
    display_vmax = np.percentile(stack_bin[region], 99.9, axis=(0, 1))
    ax.set_title("NO BINARISATION")
rgb = iss.vis.to_rgb(
    stack_bin[region],
    colors=[(1, 0, 0), (0, 1, 0), (1, 0, 1), (0, 1, 1)],
    vmin=display_vmin,
    vmax=display_vmax,
)

ax.imshow(rgb)
ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout()

# Block size and overlap

The registration will be done in blocks. The blocks should be big enough to contain
enough information shared between the channels, but small enough to capture local
variations. The overlap should be high enough to get a decent fit from the shifts.

In [None]:
# Run the block-wise affine channel registration with the parameters above.
# It returns the tile coordinates, the fitted matrices and debug information.
import warnings

import iss_preprocess.diagnostics.diag_register

# Silence RuntimeWarnings for this call only. catch_warnings restores the
# previous warning filters on exit, even if the registration raises, whereas
# filterwarnings("default") would install a new global filter instead of
# restoring the old ones.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    (
        tile_coors,
        matrices,
        debug_info,
    ) = iss.diagnostics.diag_register.check_affine_channel_registration(
        data_path,
        prefix=prefix,
        tile_coords=tile_coors,
        projection=projection,
        binarise_quantile=threshold_quantile,
        block_size=block_size,
        overlap=overlap,
        max_residual=max_residual,
    )

# Final results

The registered image (right) should show all channels well aligned.

In [None]:
# Apply the matrices fitted above to the unfiltered stack and compare it with
# the raw stack on a zoomed-in region.
corrected_hyb = iss.reg.rounds_and_channels.apply_corrections(
    stack_ori, matrix=matrices, cval=0.0
)
# Per-channel contrast limits from the corrected stack, shared by both panels
vmax = np.percentile(corrected_hyb, 99.9, axis=(0, 1))
vmin = np.percentile(corrected_hyb, 1, axis=(0, 1))
channel_colors = [(1, 0, 0), (0, 1, 0), (1, 0, 1), (0, 1, 1)]

# Zoom region in pixel coordinates
zoom_cols = (3290 - 512, 3290)
zoom_rows = (0, 512)

fig, (ax, ax_reg) = plt.subplots(
    1, 2, figsize=(10, 5), sharex=True, sharey=True
)
rgb = iss.vis.to_rgb(stack_ori, colors=channel_colors, vmin=vmin, vmax=vmax)
ax.imshow(rgb)
ax.set_title("Before channel registration")
rgb = iss.vis.to_rgb(
    corrected_hyb, colors=channel_colors, vmin=vmin, vmax=vmax
)
ax_reg.imshow(rgb)
ax_reg.set_title("Registered")
ax.set_xlim(*zoom_cols)
# imshow puts row 0 at the top, so y limits must be (bottom, top), i.e.
# decreasing; increasing limits display the crop upside down.
ax.set_ylim(zoom_rows[1], zoom_rows[0])
ax.set_xticks([])
ax.set_yticks([])

fig.suptitle(
    f"{block_size} block, {overlap} overlap, {threshold_quantile} quantile binarisation"
)

## [slow] Look at a whole ROI

If you want to look at the whole ROI, you can run the registration on the whole image
and stitch. We are not doing this by default as it is slow.

In [None]:
# Deliberately interrupt "Run all" here: the cells below are slow.
raise RuntimeError(
    "Stopping before the slow whole-ROI cells; run them manually if needed."
)

# Stop here so that the slow cells below only run when explicitly requested

In [None]:
# Raw stitching of this acquisition only. Reuses colors, vmax and vmin set in
# earlier cells.

# Make sure the tiling information exists before stitching
_ = iss.pipeline.register_within_acquisition(
    data_path, prefix=prefix, roi=5, reload=True, save_plot=True, use_slurm=False
)
stitch_options = dict(
    roi=5,
    suffix="max-median",
    correct_illumination=False,
    shifts_prefix=None,
    register_channels=True,
    allow_quick_estimate=False,
)
stitched_registered = []
for channel in range(4):
    stitched_channel = iss.pipeline.stitch.stitch_tiles(
        data_path, prefix, ich=channel, **stitch_options
    )
    stitched_registered.append(stitched_channel)
rgb = iss.vis.to_rgb(np.dstack(stitched_registered), colors, vmax, vmin)
print(rgb.shape)
plt.figure(figsize=(20, 20))
plt.imshow(rgb[5000:10000, 10000:15000, :], interpolation="none")

In [None]:
# Final stitching to reference with stitch_registered
roi_index = 5
stitched_registered = iss.pipeline.stitch.stitch_registered(
    data_path,
    prefix,
    roi_index,
    channels=range(4),
    ref_prefix=None,
    filter_r=False,
    projection=None,
    correct_illumination=False,
)
# Leave out pixels where any channel is exactly 0 when estimating the upper
# contrast limit (NOTE(review): presumably areas not covered by any tile).
has_zero = (stitched_registered == 0).any(axis=2)
valid_pixels = stitched_registered[~has_zero, :]
vmax = np.nanpercentile(valid_pixels, 99.99, axis=0)
rgb = iss.vis.to_rgb(stitched_registered, colors, vmax, vmin=0)
plt.figure(figsize=(20, 20))
plt.imshow(rgb[5000:10000, 10000:15000, :], interpolation="none")