# Select reg2ref parameters

The registration to reference uses a few parameters that can be tuned to improve
the registration. The parameters are:

- `reg_channels`: The channels of the moving image to use. The default is `None`, which
uses the average of all channels.
- `ref_channels`: The channels of the fixed image to use. The default is `None`, which
uses the average of all channels.
- `binarise_quantile`: The intensity quantile used to binarise the moving and the fixed
images. The default is `0.7`.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import iss_preprocess as iss

data_path = "becalia_rabies_barseq/BRAC8501.6a/chamber_07"  # passed to load_ops and tile loading
reg_prefix = "barcode_round"  # acquisition registered to the reference (moving image)
ref_prefix = "genes_round"  # acquisition used as the reference (fixed image)
ref_tile_index = 0  # which of the reference tiles do we want to use for plots

# PARAMETERS:
reg_channels = None  # None (average all channels), a single int, or a list of channels
ref_channels = None  # same options as reg_channels, for the reference (fixed) image
binarise_quantile = 0.7  # intensity quantile used to binarise both the ref and reg images

In [None]:
# Acquisition name without its suffix, e.g. "barcode_round" -> "barcode";
# the ops entry listing the reference tiles is keyed on it.
short_pref = reg_prefix.partition("_")[0]
ops = iss.io.load.load_ops(data_path)
tile_key = f"{short_pref}_ref_tiles"
ref_tiles = ops[tile_key]
n_ref_tiles = len(ref_tiles)
print(f"{n_ref_tiles} reference tiles found. Using {ref_tile_index}.")
ref_tile = ref_tiles[ref_tile_index]

## Channel selection

The first parameter is the channel selection. The registration to reference is performed
on the average of the selected channels. The hope is to improve the registration by
getting the background from the average of the channels, which is more conserved
between rounds than the rolony positions.

In [None]:
# load the data
from scipy.ndimage import median_filter
from skimage.morphology import disk


def _average_channels(stack, channels):
    """Select `channels` along axis 2 of `stack`, then nan-average over axes 2 and 3.

    `channels` may be None (keep every channel), a single integer (Python or
    numpy), or a list of channel indices. A single integer is wrapped in a list
    so that axis 2 is kept and the average over axes (2, 3) stays valid.
    NOTE(review): assumes `stack` is (y, x, channel, round), as implied by the
    indexing on axis 2 and the averaging over axes 2 and 3 — confirm.
    """
    if channels is not None:
        if isinstance(channels, (int, np.integer)):
            channels = [channels]
        stack = stack[:, :, channels]
    return np.nanmean(stack, axis=(2, 3))


ref_all_channels, _ = iss.pipeline.load_and_register_tile(
    data_path=data_path,
    tile_coors=ref_tile,
    prefix=ref_prefix,
    filter_r=False,
)
reg_all_channels, _ = iss.pipeline.load_and_register_tile(
    data_path=data_path,
    tile_coors=ref_tile,
    prefix=reg_prefix,
    filter_r=False,
)

ref = _average_channels(ref_all_channels, ref_channels)
reg = _average_channels(reg_all_channels, reg_channels)

if ops["reg_median_filter"]:
    print("Median filtering...")
    # build the disk footprint once and reuse it for both images
    footprint = disk(ops["reg_median_filter"])
    ref = median_filter(ref, footprint=footprint, axes=(0, 1))
    reg = median_filter(reg, footprint=footprint, axes=(0, 1))

In [None]:
# Side-by-side view of the channel-averaged reference and target images.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, (image, title) in zip(axes, ((ref, "Reference"), (reg, "Target"))):
    iss.vis.plot_matrix_with_colorbar(image, ax=ax)
    ax.set_title(title)
    ax.axis("off")

## Binarisation threshold

The second parameter is the binarisation threshold. The registration to reference is
performed on the binarised image. The binarisation is performed using a threshold set
at the `binarise_quantile` quantile of each image's pixel intensities.
The threshold can be tuned to improve the registration.

In [None]:
# apply binarization
if binarise_quantile is not None:
    # nanquantile: the averaged images may contain NaNs (they come from
    # np.nanmean), and a NaN threshold would make every comparison False.
    reg_b = reg > np.nanquantile(reg, binarise_quantile)
    ref_b = ref > np.nanquantile(ref, binarise_quantile)
else:
    # No binarisation requested: keep the grey-level images so that `ref_b` /
    # `reg_b` are always (re)defined by this cell, never stale or missing.
    reg_b = reg
    ref_b = ref
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
iss.vis.plot_matrix_with_colorbar(ref_b, ax=axes[0])
axes[0].set_title("Reference")
iss.vis.plot_matrix_with_colorbar(reg_b, ax=axes[1])
axes[1].set_title("Target")
for ax in axes:
    ax.axis("off")

## Actual registration

This is the slow part. If registration fails on this reference tile, there is little hope
for the rest of the tiles.

In [None]:
# The registration is described above as running on the binarised images;
# registering the grey-level averages would make `binarise_quantile` have no
# effect on the estimated transform. The boolean masks are cast to float (0/1).
# NOTE(review): confirm estimate_rotation_translation expects 0/1 float input.
if binarise_quantile is not None:
    ref_for_reg = ref_b.astype(float)
    reg_for_reg = reg_b.astype(float)
else:
    ref_for_reg, reg_for_reg = ref, reg
angle, shift, xcorr = iss.reg.estimate_rotation_translation(
    ref_for_reg,
    reg_for_reg,
    angle_range=1.0,
    niter=3,
    nangles=15,
    max_shift=ops["rounds_max_shift"],
    debug=True,
)
print(f"Estimated rotation: {angle} degrees")
print(f"Estimated shifts: {shift}")

In [None]:
# transform the reg image to match the ref
reg_t = iss.reg.transform_image(reg, angle=angle, shift=shift)
reg_bt = iss.reg.transform_image(reg_b, angle=angle, shift=shift)
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()
iss.vis.plot_matrix_with_colorbar(ref, ax=axes[0])
axes[0].set_title("Reference")
iss.vis.plot_matrix_with_colorbar(reg, ax=axes[1])
axes[1].set_title("Target")
iss.vis.plot_matrix_with_colorbar(reg_t, ax=axes[2])
axes[2].set_title("Transformed")
iss.vis.plot_matrix_with_colorbar(ref_b, ax=axes[3])
axes[3].set_title("Reference (binarized)")
iss.vis.plot_matrix_with_colorbar(reg_b, ax=axes[4])
axes[4].set_title("Target (binarized)")
iss.vis.plot_matrix_with_colorbar(reg_bt, ax=axes[5])
axes[5].set_title("Transformed (binarized)")
for ax in axes:
    ax.axis("off")

In [None]:
# make a rgb overlay
vmins = [np.percentile(ref, 1), np.percentile(reg_t, 1)]
vmaxs = [np.percentile(ref, 99.5), np.percentile(reg_t, 99.5)]
rgb = iss.vis.to_rgb(
    np.stack([ref, reg_t], axis=2),
    colors=([1, 0, 0], [0, 1, 0]),
    vmin=vmins,
    vmax=vmaxs,
)
rgb_b = iss.vis.to_rgb(
    np.stack([ref_b, reg_bt], axis=2),
    colors=([1, 0, 0], [0, 1, 0]),
    vmin=[0, 0],
    vmax=[1, 1],
)
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
center = np.array(ref.shape) / 2
for i in range(2):
    axes[0, i].imshow(rgb)
    axes[1, i].imshow(rgb_b)
    axes[i, 1].set_xlim(center[1] - 250, center[1] + 250)
    axes[i, 1].set_ylim(center[0] + 250, center[0] - 250)

for ax in axes.flatten():
    ax.axis("off")

## Debugging

If the registration fails, looking at the cross-correlation can help understand why.

In [None]:
# List the debug outputs returned by estimate_rotation_translation.
print("There are the following keys in xcorr:")
for key_name in xcorr.keys():
    line = f"    - {key_name}"
    print(line)

In [None]:
key = "phase_corr"
x = xcorr[key]
max_shift = ops["rounds_max_shift"]
center = np.array(x.shape) / 2
cy, cx = center[0], center[1]

# Show the selected correlation map, zoomed to the +/- max_shift search window.
plt.imshow(x, cmap="viridis")
plt.xlim(cx - max_shift, cx + max_shift)
plt.ylim(cy + max_shift, cy - max_shift)  # inverted so row 0 stays at the top
# Dashed cross at the map centre (taken as zero shift, as the marker below assumes).
plt.axvline(cx, color="k", linestyle="--", lw=0.5)
plt.axhline(cy, color="k", linestyle="--", lw=0.5)
# Hollow red circle at the estimated shift, offset from the map centre.
marker_size = max_shift / 5
plt.scatter(
    shift[1] + cx,
    shift[0] + cy,
    color="r",
    s=marker_size,
    marker="o",
    fc="none",
)