# Registering mCherry to reference acquisition

This is hard because mCherry is the only acquisition taken before sequencing, which
means before we have any background fluorescence. Let's first register the two mCherry
channels.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np

import iss_preprocess as issp

In [None]:
# Session to analyse: `project`, `mouse` and `chamber` are joined into
# `data_path` in the next cell.
project = "becalia_rabies_barseq"
mouse = "BRAC8498.3e"
chamber = "chamber_07"
# Acquisition prefix of the mCherry round registered in this notebook.
mcherry_prefix = "mCherry_1"
# ROI passed to the stitching and registration calls below.
roi = 9

In [None]:
# Build the session path and read the prefix / channel settings from its ops.
data_path = "/".join((project, mouse, chamber))
ops = issp.io.load_ops(data_path)
ref_prefix = ops["reference_prefix"]
signal_ch = ops["mcherry_signal_channel"]
background_ch = ops["mcherry_background_channel"]
mcherry_channels = [signal_ch, background_ch]  # signal first, background second
ref_channel = ops["ref_ch"]

# Register mCherry within the acquisition

In [None]:
import matplotlib.pyplot as plt

# Load one mCherry tile of the ROI through `load_and_register_tile`
# (with `filter_r=False`).
example_tile, bd = issp.pipeline.load_and_register_tile(
    data_path,
    tile_coors=[roi, 3, 3],
    prefix=mcherry_prefix,
    filter_r=False,
)

# One panel per channel, all displayed on the same 0-100 intensity range.
fig, axes = plt.subplots(nrows=2, ncols=2)
for ch_idx, panel in enumerate(axes.flat):
    panel.imshow(example_tile[..., ch_idx, 0], vmin=0, vmax=100)
    panel.set_title(f"Channel {ch_idx}")
    panel.set_axis_off()

# Colour-coded overlay of the channels, drawn in its own figure.
channel_colors = [(1, 0, 0), (0, 1, 0), (1, 0, 1), (0, 1, 1)]
rgb = issp.vis.to_rgb(
    example_tile[..., 0], colors=channel_colors, vmin=0, vmax=100
)
plt.figure()
plt.imshow(rgb)
plt.axis("off")

In [None]:
# Run the within-acquisition registration of the mCherry round for this ROI,
# on the "max" projection, without illumination correction and without
# reloading previous results; the diagnostic plot is saved.
# NOTE(review): `ref_ch=3` is hard-coded although `ref_channel` was read from
# ops above — confirm the two are meant to be independent.
o = issp.pipeline.stitch.register_within_acquisition(
    data_path,
    roi,
    prefix=mcherry_prefix,
    ref_ch=3,
    suffix="max",
    correct_illumination=False,
    reload=False,
    save_plot=True,
    dimension_prefix="genes_round_1_1",
    min_corrcoef=0.4,
    max_delta_shift=20,
    verbose=2,
)

In [None]:
# Reload ops and set the options of the registration to the reference.
ops = issp.io.load_ops(data_path)
use_masked_correlation = True
save_plot = True
downsample = 5
estimate_rotation = True

reg_prefix = mcherry_prefix
# Fall back on the ops reference prefix when none is set.
if ref_prefix in (None, "None"):
    ref_prefix = ops["reference_prefix"]
ref_channels = ops["reg2ref_reference_channels"]
spref = reg_prefix.split("_")[0]  # short prefix
# Channel lookup, most specific first: the entry for this exact acquisition,
# then the entry for its short prefix, then the reference channels.
reg_channels = ops.get(
    f"reg2ref_{reg_prefix}_channels",
    ops.get(f"reg2ref_{spref}_channels", ref_channels),
)

# Report the settings that will be used.
settings_report = (
    f"Registering {reg_prefix} to {ref_prefix} for ROI {roi}",
    f"    mask: {use_masked_correlation}",
    f"    ref_channels: {ref_channels}",
    f"    reg_channels: {reg_channels}",
    f"    estimate_rotation: {estimate_rotation}",
    f"    downsample: {downsample}",
    f"    save_plot: {save_plot}",
)
for report_line in settings_report:
    print(report_line)

# Within-acquisition registration, reference first and then the mCherry
# round, both with the same options.
within_args = dict(roi=roi, reload=True, save_plot=True, use_slurm=False)
issp.pipeline.register_within_acquisition(
    data_path, prefix=ref_prefix, **within_args
)
issp.pipeline.register_within_acquisition(
    data_path, prefix=reg_prefix, **within_args
)

In [None]:
# Stitch the target (mCherry) acquisition and average it over `target_ch`.
target_prefix = reg_prefix
reference_prefix = (
    ref_prefix if ref_prefix is not None else ops["reference_prefix"]
)
# NOTE(review): channels are hard-coded here; `reg_channels` derived from ops
# in the previous cell is not used — confirm this is intended.
target_ch = [2, 3]
ref_ch = ref_channels
target_projection = "max"

# The reference projection is looked up in ops from the lower-cased short
# prefix of the reference acquisition.
ref_short = reference_prefix.split("_")[0].lower()
ref_projection = ops[f"{ref_short}_projection"]

target_ch = [target_ch] if isinstance(target_ch, int) else target_ch
target_stitch_args = dict(
    suffix=target_projection,
    roi=roi,
    shifts_prefix=target_prefix,
    correct_illumination=True,
)
stitched_stack_target = None
for ch in target_ch:
    # single precision to save memory
    stitched = issp.pipeline.stitch.stitch_tiles(
        data_path, target_prefix, ich=ch, **target_stitch_args
    ).astype(np.single)
    if stitched_stack_target is not None:
        stitched_stack_target += stitched
    else:
        stitched_stack_target = stitched
stitched_stack_target /= len(target_ch)

In [None]:
# Stitch the reference acquisition and average it over `ref_ch`.
ref_ch = [ref_ch] if isinstance(ref_ch, int) else ref_ch
reference_stitch_args = dict(
    prefix=reference_prefix,
    suffix=ref_projection,
    roi=roi,
    shifts_prefix=reference_prefix,
    correct_illumination=True,
)
stitched_stack_reference = None
for ch in ref_ch:
    stitched = issp.pipeline.stitch.stitch_tiles(
        data_path, ich=ch, **reference_stitch_args
    ).astype(np.single)
    if stitched_stack_reference is not None:
        stitched_stack_reference += stitched
    else:
        stitched_stack_reference = stitched
stitched_stack_reference /= len(ref_ch)

In [None]:
# If they have different shapes, crop to the smallest size
import warnings

from skimage import transform

# Row 0: shape of the target stack, row 1: shape of the reference stack.
# Computed unconditionally because a later cell displays `stacks_shape`,
# which would otherwise be undefined when both shapes already match.
stacks_shape = np.vstack(
    (stitched_stack_target.shape, stitched_stack_reference.shape)
)
if stitched_stack_target.shape != stitched_stack_reference.shape:
    # The stacks are cropped (not padded) to the smallest size on each axis.
    warnings.warn("Stitched stacks have different shapes. Cropping to match.")
    fshape = np.min(stacks_shape, axis=0)
    stitched_stack_target = stitched_stack_target[: fshape[0], : fshape[1]]
    stitched_stack_reference = stitched_stack_reference[: fshape[0], : fshape[1]]
else:
    fshape = stitched_stack_target.shape


def prep_stack(stack, downsample):
    """Normalise and downsample a stack before registration.

    Non-boolean stacks are clipped to ``[0, p99]``, where ``p99`` is the
    NaN-ignoring 99th percentile of the stack, and then divided by ``p99``.
    Boolean stacks (masks) skip the normalisation. The result is resized to
    the original shape integer-divided by ``downsample``.

    Args:
        stack (np.ndarray): Image or boolean mask to prepare.
        downsample (int): Factor by which each axis length is divided.

    Returns:
        np.ndarray: The resized stack, as returned by ``transform.resize``.
    """
    if stack.dtype != bool:
        ma = np.nanpercentile(stack, 99)
        stack = np.clip(stack, 0, ma)
        # A zero percentile (e.g. an empty image) would turn the whole stack
        # into NaN through 0 / 0; keep the clipped zeros in that case.
        if ma != 0:
            stack = stack / ma
    # downsample
    new_size = np.array(stack.shape) // downsample
    return transform.resize(stack, new_size)


# Prepare the downsampled images first, then gather every argument of the
# rotation / translation estimate in one dictionary.
reference_small = prep_stack(stitched_stack_reference, downsample)
target_small = prep_stack(stitched_stack_target, downsample)
kwargs = {
    "angle_range": 1.0,
    "niter": 3,
    "nangles": 11,
    "upsample": False,
    "debug": True,
    # the maximum shift from ops is divided by the same downsampling factor
    "max_shift": ops["max_shift2ref"] // downsample,
    "min_shift": 0,
    "reference": reference_small,
    "target": target_small,
}
if use_masked_correlation:
    # Masks flag the non-zero pixels of each stitched image.
    kwargs.update(
        target_mask=prep_stack(stitched_stack_target != 0, downsample),
        reference_mask=prep_stack(stitched_stack_reference != 0, downsample),
    )

In [None]:
# Show the normalised, downsampled reference image given to the registration.
plt.imshow(kwargs["reference"])

In [None]:
# Show the normalised, downsampled target image given to the registration.
plt.imshow(kwargs["target"])

In [None]:
# Same target image, restricted to a sub-region; the y limits are given as
# (4500, 2500) so the larger value is at the bottom of the axis.
plt.imshow(kwargs["target"])
plt.ylim(4500, 2500)
plt.xlim(2000, 4500)

In [None]:
from image_tools.similarity_transforms import transform_image

# Read the flag from the registration arguments instead of setting it a second
# time, so the unpacking below cannot disagree with what was requested.
debug = kwargs["debug"]

out = issp.reg.estimate_rotation_translation(**kwargs)
# The third returned element is unpacked as a debug dictionary when `debug`
# is set, and as a scale otherwise.
if debug:
    angle, shift, debug_dict = out
else:
    angle, shift, scale = out

# Apply the estimated angle and shift to the downsampled target, with scale 1.
trans_targ2 = transform_image(kwargs["target"], angle=angle, shift=shift, scale=1)

In [None]:
# Red/green overlay of the downsampled reference and a transformed target.
# NOTE(review): `trans_targ` is not assigned in any visible cell (only
# `trans_targ2` is), so this cell presumably relies on stale kernel state —
# confirm which transformed image is meant.
rgb = np.dstack([kwargs["reference"], trans_targ])
rgb = issp.vis.to_rgb(
    rgb,
    colors=[(1, 0, 0), (0, 1, 0)],
    vmin=0,
    # 99th percentile over the two image axes, i.e. one value per layer
    vmax=np.nanpercentile(rgb, 99, axis=(0, 1)),
)
plt.figure(figsize=(10, 10))
plt.imshow(rgb)

In [None]:
# Red/green overlay of the downsampled reference (red) and the transformed
# target `trans_targ2` (green).
overlay = np.dstack((kwargs["reference"], trans_targ2))
# upper display limit: 99th percentile over the image axes, one per layer
overlay_vmax = np.nanpercentile(overlay, 99, axis=(0, 1))
rgb = issp.vis.to_rgb(
    overlay, colors=[(1, 0, 0), (0, 1, 0)], vmin=0, vmax=overlay_vmax
)
plt.figure(figsize=(10, 10))
plt.imshow(rgb)

In [None]:
# NOTE(review): `prefix` is not assigned in any visible cell of this notebook
# (the acquisition prefix used elsewhere is `mcherry_prefix`), and the disabled
# call targets ROI 13 while the rest of the notebook uses `roi` — this cell
# presumably depends on stale kernel state; confirm before running.
if False:
    # Disabled: register a single tile of the acquisition.
    issp.pipeline.register_fluorescent_tile(
        data_path,
        tile_coors=(13, 0, 0),
        prefix=prefix,
        reference_prefix=None,
    )
# Correct the shifts for `prefix`, then run the shift-correction diagnostic.
issp.pipeline.correct_hyb_shifts(
    data_path,
    prefix,
)
job2 = issp.pipeline.diagnostics.check_shift_correction(
    data_path,
    prefix,
    roi_dimension_prefix=prefix,
    within=False,
)

In [None]:
# Shapes of the target and reference stitched images, stacked row-wise.
stacks_shape

In [None]:
# Common shape retained for both stitched images.
fshape

In [None]:
# Make stitched images with `stitch_registered`: the reference acquisition on
# `ref_channel`, and the mCherry acquisition on its signal and background
# channels (`mcherry_channels`).
stitched_ref = issp.pipeline.stitch_registered(
    data_path, ref_prefix, roi=roi, channels=ref_channel
)
stitched_mcherry = issp.pipeline.stitch_registered(
    data_path, mcherry_prefix, roi=roi, channels=mcherry_channels
)

In [None]:
# plot a downsampled version

import cv2
import matplotlib.pyplot as plt
import numpy as np

# Display scale factor handed to `cv2.resize`.
# NOTE: this rebinds `downsample`, which earlier held the integer factor (5)
# used for the registration.
downsample = 0.1

# Shrink the reference first, then the mCherry image, by the same factor.
stitched_ref_downsampled, stitched_mcherry_downsampled = (
    cv2.resize(full_res, (0, 0), fx=downsample, fy=downsample)
    for full_res in (stitched_ref, stitched_mcherry)
)

colors = [(1, 0, 0), (0, 1, 0), (1, 0, 1), (0, 0, 1)]
# mCherry layers come first, the reference layer(s) after.
st = np.dstack((stitched_mcherry_downsampled, stitched_ref_downsampled))
# display range per layer: 0.01th to 99.9th percentile
vmin, vmax = (np.nanpercentile(st, q, axis=(0, 1)) for q in (0.01, 99.9))
rgb = issp.vis.to_rgb(st, colors, vmax=vmax, vmin=vmin)

fig, axes = plt.subplots(1, 1, figsize=(10, 10), squeeze=False)
panel = axes[0, 0]
panel.imshow(rgb)
panel.axis("off")
fig.tight_layout()

In [None]:
# Run the packaged registration of the mCherry acquisition to the reference
# for this ROI. The acquisition is taken from `mcherry_prefix` (instead of a
# second hard-coded "mCherry_1") so this call always targets the acquisition
# that the surrounding cells stitch and plot. Reference prefix and channels are
# left to `None`.
issp.pipeline.register.register_to_ref_using_stitched_registration(
    data_path,
    roi,
    reg_prefix=mcherry_prefix,
    ref_prefix=None,
    ref_channels=None,
    reg_channels=None,
    estimate_rotation=True,
    target_suffix=None,
    use_masked_correlation=True,
    # kept as a literal: `downsample` was rebound to 0.1 by the plotting cell
    downsample=5,
    save_plot=True,
)

In [None]:
# Stitch the mCherry channels again, after the registration call above, and
# keep the result under a new name so it can be compared with
# `stitched_mcherry` in the following cells.
stitched_mcherry_new = issp.pipeline.stitch_registered(
    data_path, mcherry_prefix, roi=roi, channels=mcherry_channels
)

In [None]:
from iss_preprocess.vis.utils import get_stack_part

# x / y limits handed to `get_stack_part` and used as image extent.
xl = [11000, 18000]
yl = [8000, 15000]
# Previously stitched mCherry layers followed by the reference, same region.
st_p = np.dstack(
    [get_stack_part(stitched_mcherry, xl, yl), get_stack_part(stitched_ref, xl, yl)]
)
# display range per layer: 1st to 99th percentile
vmin, vmax = (np.nanpercentile(st_p, q, axis=(0, 1)) for q in (1, 99))
rgb = issp.vis.to_rgb(st_p, colors, vmax=vmax, vmin=vmin)
plt.figure(figsize=(10, 10))
# extent is (x start, x end, y end, y start)
plt.imshow(rgb, extent=[*xl, *reversed(yl)])
plt.axis("off")

In [None]:
from iss_preprocess.vis.utils import get_stack_part

# x / y limits handed to `get_stack_part` and used as image extent.
xl = [11000, 18000]
yl = [8000, 15000]
# Newly stitched mCherry layers followed by the reference, same region.
st_p = np.dstack(
    [
        get_stack_part(stitched_mcherry_new, xl, yl),
        get_stack_part(stitched_ref, xl, yl),
    ]
)
# display range per layer: 1st to 99th percentile
vmin, vmax = (np.nanpercentile(st_p, q, axis=(0, 1)) for q in (1, 99))
rgb = issp.vis.to_rgb(st_p, colors, vmax=vmax, vmin=vmin)
plt.figure(figsize=(10, 10))
# extent is (x start, x end, y end, y start)
plt.imshow(rgb, extent=[*xl, *reversed(yl)])
plt.axis("off")

In [None]:
# Colour list passed to `issp.vis.to_rgb` in the overlays.
colors

In [None]:
# Overlay layer 0 of the newly stitched mCherry image and of the previously
# stitched one over the same region, then draw the spots on top.
# NOTE(review): `spots_df` is only created in a later cell (the one calling
# `pd.read_pickle`), so this cell fails on a fresh top-to-bottom run.
st_p_new = get_stack_part(stitched_mcherry_new[..., 0], xl, yl)
st_p = get_stack_part(stitched_mcherry[..., 0], xl, yl)
rgb = np.dstack([st_p_new, st_p])
# Both layers use the limits computed for layer 0 in a previous cell.
rgb = issp.vis.to_rgb(rgb, colors[:2], vmax=vmax[0], vmin=vmin[0])
from iss_preprocess.vis.utils import get_spot_part

sp = get_spot_part(spots_df, xl, yl)
plt.figure(figsize=(20, 20))
plt.imshow(rgb, extent=[xl[0], xl[1], yl[1], yl[0]])
plt.scatter(sp["x"], sp["y"], c="cyan", s=10, alpha=0.5)
# NOTE(review): the same image is drawn a second time after the scatter —
# looks redundant; confirm before removing.
plt.imshow(rgb, extent=[xl[0], xl[1], yl[1], yl[0]])

In [None]:
# Session path used by every call in this notebook.
data_path

In [None]:
# Load the transform file of one tile from the "reg" folder (`new`) and the
# file of the same name from the sibling "temp_mcherry_reg" folder (`old`,
# presumably an earlier copy — confirm).
reg_folder = issp.io.get_processed_path(data_path) / "reg"
t = [9, 3, 4]
tile_tag = "_".join(str(coord) for coord in t)
fname = f"tforms_to_ref_{mcherry_prefix}_{tile_tag}.npz"
backup_folder = reg_folder.parent / "temp_mcherry_reg"
new = np.load(reg_folder / fname)
old = np.load(backup_folder / fname)

In [None]:
# Between-channel matrix stored in the file from "temp_mcherry_reg".
old["matrix_between_channels"]

In [None]:
# Between-channel matrix stored in the file from "reg".
new["matrix_between_channels"]

In [None]:
import pandas as pd

# The barcode table sits one level above the processed session folder.
processed_parent = issp.io.get_processed_path(data_path).parent
rabies_spot = processed_parent / "error_corrected_barcodes_10.pkl"
# Load it and keep an independent copy of the rows of this chamber and ROI.
spots_df = (
    pd.read_pickle(rabies_spot)
    .query("chamber == @chamber and roi == @roi")
    .copy()
)

spots_df.shape

In [None]:
from iss_preprocess.vis.utils import get_spot_part

# Region of the previously stitched mCherry image, coloured with the first
# two entries of `colors`.
st_p = get_stack_part(stitched_mcherry, xl, yl)
# display range per layer: 1st to 99th percentile
vmin, vmax = (np.nanpercentile(st_p, q, axis=(0, 1)) for q in (1, 99))
rgb = issp.vis.to_rgb(st_p, colors[:2], vmax=vmax, vmin=vmin)

# Spots of the same region, drawn over the image as faint cyan dots.
sp = get_spot_part(spots_df, xl, yl)
plt.figure(figsize=(20, 20))
plt.imshow(rgb, extent=[*xl, *reversed(yl)])
plt.scatter(sp["x"], sp["y"], c="cyan", s=5, alpha=0.2)
plt.axis("off")