In [None]:
%load_ext autoreload
%autoreload 2

# Find duplicates across ROIs

Look for cells that are detected on multiple consecutive ROIs

In [None]:
import iss_preprocess as issp
import iss_analysis as issa
from iss_analysis.barcodes import barcodes as bar
from iss_analysis.barcodes.diagnostics import (
    plot_gmm_clusters,
    plot_error_along_sequence,
)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path

In [None]:
# Dataset selection: project and mouse identifiers used by the cells below.
project = "becalia_rabies_barseq"
mouse = "BRAC8498.3e"

# Name of the error-corrected barcode dataset passed to the iss_analysis calls.
error_correction_ds_name = "BRAC8498.3e_error_corrected_barcodes_10"
data_path = f"{project}/{mouse}"
# Folder under the processed data path where this notebook saves its figures.
analysis_folder = issp.io.get_processed_path(data_path) / "analysis"

# Get rabies data

In [None]:
# Optionally (re)compute and save the ARA information of every chamber / ROI.
# Disabled by default: set RUN_SAVE_ARA_INFO to True to run the loop.
RUN_SAVE_ARA_INFO = False

slurm_folder = Path.home() / "slurm_logs" / "save_ara_info"
# parents=True: do not fail when ~/slurm_logs itself does not exist yet.
slurm_folder.mkdir(exist_ok=True, parents=True)
if RUN_SAVE_ARA_INFO:
    for chamber in [f"chamber_{i:02}" for i in range(7, 11)]:
        for roi in range(1, 11):
            print(f"Processing {chamber} {roi}")
            issa.barcodes.main.save_ara_info(
                project,
                mouse,
                chamber,
                roi,
                error_correction_ds_name,
                atlas_size=10,
                acronyms=True,
                full_scale=True,
                verbose=True,
                use_slurm=False,
                slurm_folder=slurm_folder,
                scripts_name=f"save_ara_info_{chamber}_{roi}",
            )

In [None]:
# Load the rabies barcode data of this mouse. The call returns three objects,
# unpacked here as the spot table, the per-cell barcode table and the
# per-cell property table. `redo=False` is passed, presumably to reuse
# previously saved results -- TODO confirm against iss_analysis.
(
    rab_spot_df,
    rab_cells_barcodes,
    rabies_cell_properties,
) = issa.segment.get_barcode_in_cells(
    project,
    mouse,
    error_correction_ds_name,
    valid_chambers=None,
    save_folder=None,
    verbose=True,
    redo=False,
    add_ara_properties=True,
)

In [None]:
# Load the curated mCherry cell table (prefix "mCherry_1") for this mouse.
mcherry = issa.io.get_mcherry_cells(
    project, mouse, verbose=True, which="curated", prefix="mCherry_1"
)

In [None]:
# One histogram per mCherry measurement, laid out on a 2 x 2 grid
# (row-major order: top-left, top-right, bottom-left, bottom-right).
hist_panels = [
    ("intensity_mean-0", "mCherry intensity mean"),
    ("intensity_max-0", "mCherry intensity max"),
    ("intensity_mean-1", "Background intensity mean"),
    ("clamped_ratio", "Ratio mCherry / Background"),
]

fig, axes = plt.subplots(2, 2, figsize=(5, 5))
for panel, (column, xlabel) in zip(axes.flat, hist_panels):
    panel.hist(mcherry[column], bins=100)
    panel.set_xlabel(xlabel)
fig.tight_layout()

In [None]:
# Match mCherry (starter) cells to barcoded rabies cells. The call returns the
# mCherry cell properties plus updated versions of the rabies spot and cell
# tables, which replace the ones loaded above.
# NOTE(review): `mcherry_cells=None` is passed although `mcherry` was loaded in
# an earlier cell -- presumably the function loads the cells itself; confirm.
(mcherry_cell_properties, rab_spot_df, rabies_cell_properties,
) = issa.segment.match_starter_to_barcodes(
    project,
    mouse,
    rabies_cell_properties,
    rab_spot_df,
    mcherry_cells=None,
    verbose=True,
    max_starter_distance=0.5,
    min_spot_number=1,
)

# Slice identifier "<chamber>_<roi as two digits>" for every rabies cell.
rabies_cell_properties["slice"] = (
    rabies_cell_properties.chamber
    + "_"
    + rabies_cell_properties.roi.map(lambda x: f"{x:02d}")
)
# add rotated ara coordinates
transform = issa.registration.ara_registration.get_ara_to_slice_rotation_matrix(
    spot_df=rab_spot_df
)
rabies_cell_properties = (
    issa.registration.ara_registration.rotate_ara_coordinate_to_slice(
        rabies_cell_properties, transform=transform
    )
)

In [None]:
# Load the curated mCherry masks of one example ROI, once expanded and once
# without expansion, together with the rabies spots of the same ROI.
ch, roi = "chamber_08", 6
mcherry_prefix = "mCherry_1"
px_size = issp.io.get_pixel_size(f"{project}/{mouse}/{ch}")
# NOTE(review): the expansion is 0.5 / px_size pixels, but the expanded mask
# below is named "5um" -- confirm which value is intended.
mask_expansion = int(0.5 / px_size)

# Arguments shared by both mask requests; only `mask_expansion` differs.
mask_request = dict(
    roi=roi,
    prefix=mcherry_prefix,
    projection="corrected",
    curated=True,
)
mcherry_mask_5um = issp.pipeline.segment.get_cell_masks(
    f"{project}/{mouse}/{ch}", mask_expansion=mask_expansion, **mask_request
)
mcherry_mask_raw = issp.pipeline.segment.get_cell_masks(
    f"{project}/{mouse}/{ch}", mask_expansion=0, **mask_request
)
rab_spot_roi = rab_spot_df.query(f"chamber == '{ch}' and roi == {roi}")

In [None]:
# Load the error-corrected barcode table stored next to the analysis folder and
# keep the rows of the example ROI selected above.
# NOTE(review): read_pickle can execute arbitrary code; only load trusted files.
ec_path = analysis_folder.parent / "error_corrected_barcodes_10.pkl"
ec = pd.read_pickle(ec_path)
ec_roi = ec.query(f"chamber == '{ch}' and roi == {roi}")

In [None]:
# List the columns available for each mCherry cell.
mcherry_cell_properties.columns

In [None]:
# Histogram of the `n_barcodes` column of the mCherry cells.
_ = mcherry_cell_properties.n_barcodes.hist()

In [None]:
# Histogram of the `n_barcode_spots` column, one bin per integer from 0 to 34.
_ = mcherry_cell_properties.n_barcode_spots.hist(bins=np.arange(35))

In [None]:
# For every mCherry cell, count the starter and non-starter rabies cells that
# carry one of its valid barcodes. Counts are summed over barcodes, so a cell
# sharing two barcodes is counted twice.
mcherry_cell_properties["n_presynaptic_cells"] = 0
mcherry_cell_properties["n_starter"] = 0
for mch_uid, mch_row in mcherry_cell_properties.iterrows():
    barcodes = mch_row.valid_barcodes
    if len(barcodes) == 0:
        continue
    for barcode in barcodes:
        barcode_spots = rab_spot_df[rab_spot_df["corrected_bases"] == barcode]
        # Keep string mask ids only, which drops the NaN entries.
        labelled = [
            uid for uid in barcode_spots.mask_uid.unique() if isinstance(uid, str)
        ]
        n_start = rabies_cell_properties.loc[labelled].is_starter.sum()
        mcherry_cell_properties.loc[mch_uid, "n_starter"] += n_start
        mcherry_cell_properties.loc[mch_uid, "n_presynaptic_cells"] += (
            len(labelled) - n_start
        )

In [None]:
# Columns again, now including `n_presynaptic_cells` and `n_starter`.
mcherry_cell_properties.columns

In [None]:
import seaborn as sns

# Pair plot of intensity and connectivity measures for the mCherry cells with
# more than 2 barcode spots, coloured by their number of starters.
props = [
    "intensity_mean-0",
    "intensity_mean-1",
    "clamped_ratio",
    "n_presynaptic_cells",
    "n_starter",
]
has_enough_spots = mcherry_cell_properties["n_barcode_spots"] > 2
mch_starter = mcherry_cell_properties[has_enough_spots]
sns.pairplot(mch_starter[props], hue="n_starter")

In [None]:
# Histogram of the presynaptic cell count per mCherry cell (unit-wide bins, 0-299).
_ = mcherry_cell_properties['n_presynaptic_cells'].hist(bins=np.arange(300))

In [None]:
# Number of rabies cells flagged as starter.
rabies_cell_properties.is_starter.sum()

In [None]:
# Split the rabies cells into starter and non-starter ("presynaptic") cells
# and report how many presynaptic cells there are per starter.
starter = rabies_cell_properties.query("is_starter")
presynaptic = rabies_cell_properties.query("~is_starter")
print(f"Starter: {len(starter)}, Non-starter: {len(presynaptic)}")
# Report NaN instead of raising ZeroDivisionError when no starter was found.
presynaptic_per_starter = (
    len(presynaptic) / len(starter) if len(starter) else float("nan")
)
print(f"Presynaptic per starter: {presynaptic_per_starter}")

In [None]:
# Compare the mCherry cells referenced by starter rabies cells with the full
# list of mCherry cells.
mcherry_cells = issa.io.get_mcherry_cells(project, mouse, verbose=True)
# Each starter row holds an iterable of mCherry uids: merge them into one set.
starter_mcherry_cells = set().union(
    *(set(uids) for uids in starter.mcherry_uid.values)
)
non_starter_mcherry_cell = (
    set(mcherry_cells.mcherry_uid.values) - starter_mcherry_cells
)
print(
    f"Starter mcherry cells: {len(starter_mcherry_cells)}, Non-starter mcherry cells: {len(non_starter_mcherry_cell)}"
)

In [None]:
# Histogram of the number of unique barcodes per cell, for starter and
# presynaptic cells side by side; the figure is saved in the analysis folder.
bc_per_starter = starter.n_unique_barcodes
bc_per_non_starter = presynaptic.n_unique_barcodes
barcode_bins = np.arange(1, 8) - 0.5  # bins centred on 1..6 barcodes

fig, panels = plt.subplots(1, 2, figsize=(7, 3))
for ax, counts, title in zip(
    panels,
    (bc_per_starter, bc_per_non_starter),
    ("Starter cells", "Presynaptic cells"),
):
    ax.hist(counts.values, bins=barcode_bins)
    ax.set_title(title)
    ax.set_xlabel("Number of barcodes per cell")
    ax.set_ylabel("Number of cells")
fig.tight_layout()
fig.savefig(analysis_folder / "barcode_per_cell.svg")

# Register slices

In [None]:
from iss_analysis.registration import register_serial_sections

# Serial-section registration for the whole mouse, called with reload=True.
# To reload the data, set reload=True and use_slurm=False
# NOTE(review): later cells index `res` by slice id and by "next"/"previous",
# so it is presumably a dict of per-slice results -- confirm.
res = register_serial_sections.register_all_serial_sections(
    project=project,
    mouse=mouse,
    error_correction_ds_name=error_correction_ds_name,
    correlation_window_size=500,
    min_spots=10,
    max_barcode_number=50,
    gaussian_width=20,
    n_workers=20,
    verbose=True,
    use_slurm=False,
    reload=True,
    slice_window=(-1, 3),
)

In [None]:
# Try to register mCherry cells using the rabies registration.
# `chamber` / `roi` select the reference ROI used by the following cells; note
# that `roi` was already assigned in an earlier cell and is rebound here.
import cv2


chamber = "chamber_07"
roi = 6

In [None]:
# Load the curated mCherry masks of the reference ROI and the curated mCherry
# cell tables of the reference, next and previous ROIs of the same chamber.
chamber_path = f"{project}/{mouse}/{chamber}"
mcherry_mask = issp.pipeline.get_cell_masks(
    data_path=chamber_path, roi=roi, prefix="mCherry_1", curated=True
)
df_fname = (
    issp.io.get_processed_path(chamber_path)
    / "cells"
    / "mCherry_1_cells/mCherry_1_df_corrected_curated.pkl"
)
mcherry_df_all_rois = pd.read_pickle(df_fname)
mcherry_df_ref, mcherry_df_next, mcherry_df_previous = (
    mcherry_df_all_rois.query(f"roi == {one_roi}").copy()
    for one_roi in (roi, roi + 1, roi - 1)
)
print(f"Found {len(mcherry_df_ref)} cells")

# Sanity check: one label per cell of the reference ROI, plus one extra value
# (presumably the background label -- TODO confirm).
masks = np.unique(mcherry_mask)
assert len(masks) == (
    len(mcherry_df_ref) + 1
), f"{len(masks)} != {len(mcherry_df_ref) + 1}"

In [None]:
# Display the chamber currently used as reference.
chamber

In [None]:
# Display the ROI currently used as reference.
roi

In [None]:
# To register the mCherry cells, we need the rabies shift interpolation from
# the reference slice towards both of its neighbours.
# NOTE(review): `res` is rebound by each call below, replacing the value
# assigned by register_all_serial_sections in an earlier cell -- confirm that
# later cells reading `res` expect this object.
interp_fn = issa.registration.register_serial_sections.interpolate_shifts
interp_kwargs = dict(
    ref_slice=f"{chamber}_{roi:02d}",
    error_correction_ds_name=error_correction_ds_name,
    threshold=400,
    smoothing=10,
    vis=False,
)
res, y_shift_interpolator_next, z_shift_interpolator_next = interp_fn(
    project, mouse, target_position="next", **interp_kwargs
)
res, y_shift_interpolator_prev, z_shift_interpolator_prev = interp_fn(
    project, mouse, target_position="previous", **interp_kwargs
)

In [None]:
# add ara coordinates to mCherry cells
from iss_analysis.registration import ara_registration

transform = ara_registration.get_ara_to_slice_rotation_matrix(
    spot_df=rab_spot_df, verbose=False
)
# Keep the returned table, as done for `rabies_cell_properties` earlier in the
# notebook: this is correct whether the function works in place or on a copy.
rab_spot_df = ara_registration.rotate_ara_coordinate_to_slice(
    rab_spot_df, transform=transform, verbose=False
)

rotated_mcherry = []
for df, roi2use in zip(
    [mcherry_df_ref, mcherry_df_next, mcherry_df_previous], [roi, roi + 1, roi - 1]
):
    # spots_ara_infos is given "x"/"y" columns built from the centroids
    df["x"] = df["centroid-1"]
    df["y"] = df["centroid-0"]
    issp.pipeline.ara_registration.spots_ara_infos(
        data_path=f"{project}/{mouse}/{chamber}",
        spots=df,
        roi=roi2use,
        atlas_size=10,
        acronyms=True,
        inplace=True,
        full_scale_coordinates=False,
        reload=True,
        verbose=True,
    )
    rotated_mcherry.append(
        ara_registration.rotate_ara_coordinate_to_slice(
            df, transform=transform, verbose=False
        )
    )
# Rebind the three tables to the rotated versions (same order as the loop).
mcherry_df_ref, mcherry_df_next, mcherry_df_previous = rotated_mcherry
mcherry_df_ref.head()

In [None]:
# Overlay the rabies spots of the reference (C0) and next (C1) slices, before
# (left) and after (right) adding the interpolated shifts to the next slice.
rb_ref_df = rab_spot_df.query(f"slice == '{chamber}_{roi:02d}'").copy()
rb_next_df = rab_spot_df.query(f"slice == '{chamber}_{roi+1:02d}'").copy()

next_yz = rb_next_df[["ara_y_rot", "ara_z_rot"]].values
zsh = z_shift_interpolator_next(next_yz)
ysh = y_shift_interpolator_next(next_yz)

kw = dict(s=1, alpha=0.1)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
ax_raw, ax_registered = axes
for ax in axes:
    ax.scatter(rb_ref_df["ara_z_rot"], rb_ref_df["ara_y_rot"], c="C0", **kw)
ax_raw.scatter(rb_next_df["ara_z_rot"], rb_next_df["ara_y_rot"], c="C1", **kw)
# Shifts are divided by 1000 before being added to the coordinates.
shifted_z = rb_next_df["ara_z_rot"] + zsh / 1000
shifted_y = rb_next_df["ara_y_rot"] + ysh / 1000
ax_registered.scatter(shifted_z, shifted_y, c="C1", **kw)
fig.tight_layout()


for ax in axes:
    ax.set_aspect("equal")
    # Limits are given high-to-low, so both axes are drawn inverted.
    ax.set_xlim(10, 5.5)
    ax.set_ylim(6.5, 2)
    ax.set_xticks([])
    ax.set_yticks([])
ax_raw.set_title("Original")
ax_registered.set_title("Registered")
In [None]:
# Zoomed view of the registered overlay: reference spots (C0) and shifted
# next-slice spots (C1) with larger markers.
rb_ref_df = rab_spot_df.query(f"slice == '{chamber}_{roi:02d}'").copy()
rb_next_df = rab_spot_df.query(f"slice == '{chamber}_{roi+1:02d}'").copy()

next_yz = rb_next_df[["ara_y_rot", "ara_z_rot"]].values
zsh = z_shift_interpolator_next(next_yz)
ysh = y_shift_interpolator_next(next_yz)

kw = dict(s=10, alpha=0.4, edgecolors="none")
fig, ax = plt.subplots(1, 1, figsize=(20, 20))
ax.scatter(rb_ref_df["ara_z_rot"], rb_ref_df["ara_y_rot"], c="C0", **kw)
shifted_z = rb_next_df["ara_z_rot"] + zsh / 1000
shifted_y = rb_next_df["ara_y_rot"] + ysh / 1000
ax.scatter(shifted_z, shifted_y, c="C1", **kw)
fig.tight_layout()
ax.set_aspect("equal")
ax.set_xlim(8.5, 7.5)
ax.set_ylim(3.5, 3)

In [None]:
# Same comparison, but rendered as spot images instead of scatter plots.
shift_it = True  # set to False to render the next slice without shifts

# Arguments shared by the two spot images.
spot_image_kwargs = dict(
    gaussian_width=5,
    dtype="single",
    output_shape=None,
    x_col="ara_y_rot",
    y_col="ara_z_rot",
)

# Coordinates are multiplied by 1000 before building the images.
ref_scaled = rb_ref_df[["ara_y_rot", "ara_z_rot"]] * 1000
ref_img = issp.segment.spots.make_spot_image(ref_scaled, **spot_image_kwargs)

temp_df = rb_next_df[["ara_y_rot", "ara_z_rot"]] * 1000
if shift_it:
    next_yz = rb_next_df[["ara_y_rot", "ara_z_rot"]].values
    zsh = z_shift_interpolator_next(next_yz)
    ysh = y_shift_interpolator_next(next_yz)
    # Shifts are added directly to the scaled coordinates.
    temp_df["ara_y_rot"] += ysh
    temp_df["ara_z_rot"] += zsh
next_img = issp.segment.spots.make_spot_image(temp_df, **spot_image_kwargs)

# keep only the left hemisphere
hemisphere = (slice(5500, 10000), slice(2000, 6000))
ref_img = ref_img[hemisphere]
next_img = next_img[hemisphere]
st = np.dstack([ref_img.T, next_img.T])
rgb = issp.vis.to_rgb(st, colors=[(1, 0, 0), (0, 1, 0)], vmax=[10, 10])

# Top row: individual images; bottom row: red/green overlay, full and zoomed.
fig, axes = plt.subplots(2, 2, figsize=(20, 20))
im = axes[0, 0].imshow(ref_img.T, cmap="inferno", vmax=10)
axes[0, 1].imshow(next_img.T, cmap="inferno", vmax=10)
for overlay_ax in axes[1]:
    overlay_ax.imshow(rgb)
axes[1, 1].set_xlim(2000, 3000)
axes[1, 1].set_ylim(2000, 500)
fig.tight_layout()

## Check serial section registration

Plot around one random spot to check that everything makes sense.

In [None]:
# Diagnostic figure of the registration between the reference slice and the
# next one, around a manually chosen point. Disabled (`if False`): change the
# condition to run it.
if False:
    # Used below as (ara_y_rot, ara_z_rot).
    center_point = np.array([3.5, 8.5])
    fig = issa.vis.diagnostics.check_serial_registration(
        cell_info=pd.Series(
            data=dict(ara_y_rot=center_point[0], ara_z_rot=center_point[1])
        ),
        ref_slice=f"{chamber}_{roi:02d}",
        target_slice=f"{chamber}_{roi+1:02d}",
        rab_spot_df=rab_spot_df,
        rabies_cell_properties=rabies_cell_properties,
        window_size=500,
        min_spots=10,
        max_barcode_number=50,
        gaussian_width=10,
        spots_kwargs=None,
        shifts_to_use=None,
    )

# Register mCherry cells

In [None]:
# Evaluate the rabies-derived shift interpolators at the position of every
# mCherry cell of the next and previous slices.
next_yz = mcherry_df_next[["ara_y_rot", "ara_z_rot"]].values
previous_yz = mcherry_df_previous[["ara_y_rot", "ara_z_rot"]].values

interpolated_shift_z_next = z_shift_interpolator_next(next_yz)
interpolated_shift_y_next = y_shift_interpolator_next(next_yz)
interpolated_shift_z_previous = z_shift_interpolator_prev(previous_yz)
interpolated_shift_y_previous = y_shift_interpolator_prev(previous_yz)

In [None]:
# Scatter the mCherry cells of the reference slice together with those of the
# previous / next slices: raw positions on the top panel, positions with the
# interpolated shifts added on the bottom panel.
fig, axes = plt.subplots(2, 1, figsize=(20, 20))

pts_size = 100

# plot background rabies spots from ref slice
rb_df = rab_spot_df.query(f"slice == '{chamber}_{roi:02d}'").copy()
for ax in axes:
    ax.scatter(
        rb_df["ara_y_rot"],
        rb_df["ara_z_rot"],
        label="rabies",
        s=1,
        color="k",
        alpha=0.1,
    )
    ax.scatter(
        mcherry_df_ref["ara_y_rot"],
        mcherry_df_ref["ara_z_rot"],
        label="ref",
        s=pts_size,
        alpha=0.5,
    )

# plot mcherry cells from next and previous slices, raw and shifted
label = ["previous", "next"]
for i, (w, zsh, ysh) in enumerate(
    zip(
        [mcherry_df_previous, mcherry_df_next],
        (interpolated_shift_z_previous, interpolated_shift_z_next),
        (interpolated_shift_y_previous, interpolated_shift_y_next),
    )
):
    # Column 0 is ara_y_rot, column 1 is ara_z_rot.
    coords = w[["ara_y_rot", "ara_z_rot"]].values
    shifted = coords.copy()
    # Shifts are divided by 1000 before being added to the coordinates.
    shifted[:, 0] += ysh / 1000
    shifted[:, 1] += zsh / 1000
    axes[0].scatter(
        coords[:, 0],
        coords[:, 1],
        label=f"{label[i]} raw",
        s=pts_size,
        alpha=0.5,
        color=f"C{i+1}",
    )
    axes[1].scatter(
        shifted[:, 0],
        shifted[:, 1],
        label=f"{label[i]} shifted",
        s=pts_size,
        alpha=0.5,
        color=f"C{i+1}",
    )

for ax in axes:
    ax.legend()
    ax.set_aspect("equal")
    ax.set_xlim(2.5, 4)
    ax.set_ylim(7.5, 8.5)
fig.tight_layout()

In [None]:
# Get the stitched mCherry image (signal and background channels) of the
# previous, reference and next ROIs.
stitched_stacks = {}
ops = issp.io.load_ops(f"{project}/{mouse}/{chamber}")
mch_chans = [ops["mcherry_signal_channel"], ops["mcherry_background_channel"]]
for i, lab in enumerate(["previous", "reference", "next"]):
    print(f"Processing {lab}")
    # NOTE(review): the bare name `stitch_registered` was never imported in this
    # notebook and raised NameError on a fresh kernel. It is assumed to live in
    # iss_preprocess.pipeline.stitch -- confirm the module path.
    mcherry_full_stack = issp.pipeline.stitch.stitch_registered(
        data_path=f"{project}/{mouse}/{chamber}",
        roi=roi - 1 + i,  # i = 0, 1, 2 -> roi - 1, roi, roi + 1
        prefix="mCherry_1",
        channels=mch_chans,
    )
    stitched_stacks[lab] = mcherry_full_stack

In [None]:
# Transform the stitched reference stack to ARA coordinates. No shift
# interpolators are passed (both are None).
transformed_ref = issa.registration.ara_registration.transform_stack_to_ara(
    project,
    mouse,
    chamber,
    roi,
    prefix="mCherry_1",
    channels=[0, 1],  # because we give the stack
    error_correction_ds_name=error_correction_ds_name,
    output_px_size=2,
    interpolate=True,
    output_folder=None,
    ara_zshifts_interpolator=None,
    ara_yshifts_interpolator=None,
    full_stack=stitched_stacks["reference"],
)

In [None]:
# Show channel 0 of a sub-region of the transformed reference stack.
plt.figure(figsize=(20, 20))
part = [0, 3000, 0, 5000]  # row start, row stop, column start, column stop
plt.imshow(transformed_ref[part[0] : part[1], part[2] : part[3], 0], vmax=100)

In [None]:
# Get the transformed next stack, without shift
# NOTE(review): `roi` (the reference ROI) is passed although `full_stack` is
# the stack of the next ROI (roi + 1) -- confirm whether `roi + 1` is intended.
transformed_next_raw = issa.registration.ara_registration.transform_stack_to_ara(
    project,
    mouse,
    chamber,
    roi,
    prefix="mCherry_1",
    channels=[0, 1],  # because we give the stack
    error_correction_ds_name=error_correction_ds_name,
    output_px_size=2,
    interpolate=True,
    output_folder=None,
    ara_zshifts_interpolator=None,
    ara_yshifts_interpolator=None,
    full_stack=stitched_stacks["next"],
)

In [None]:
# Same but apply shifts: the "next" shift interpolators are passed this time.
# NOTE(review): `roi` (the reference ROI) is passed although `full_stack` is
# the stack of the next ROI (roi + 1) -- confirm whether `roi + 1` is intended.
transformed_next_shifted = issa.registration.ara_registration.transform_stack_to_ara(
    project,
    mouse,
    chamber,
    roi,
    prefix="mCherry_1",
    channels=[0, 1],  # because we give the stack
    error_correction_ds_name=error_correction_ds_name,
    output_px_size=2,
    interpolate=True,
    output_folder=None,
    ara_zshifts_interpolator=z_shift_interpolator_next,
    ara_yshifts_interpolator=y_shift_interpolator_next,
    full_stack=stitched_stacks["next"],
)

In [None]:
# Save the three transformed stacks as compressed TIFF files under
# <processed path>/analysis/rotated_stacks/serial_sec_reg.
folder = (
    issp.io.get_processed_path(f"{project}/{mouse}")
    / "analysis"
    / "rotated_stacks"
    / "serial_sec_reg"
)
folder.mkdir(exist_ok=True, parents=True)
stacks_to_write = [
    (transformed_next_raw, f"mCherry_1_{chamber}_{roi}_next_raw.tif"),
    (transformed_next_shifted, f"mCherry_1_{chamber}_{roi}_next_shifted.tif"),
    (transformed_ref, f"mCherry_1_{chamber}_{roi}_ref.tif"),
]
for stack, file_name in stacks_to_write:
    issp.io.write_stack(stack, folder / file_name, compress=1)

In [None]:
# Plot overlay of ref and next for both raw and shifted
# `part` selects the (rows, columns) sub-region to display.
part = slice(800, 2200), slice(3000, 4800)

# Channel 0 of the reference and of the raw next stack, stacked as 2 planes.
st = np.dstack(
    [transformed_ref[part[0], part[1], 0], transformed_next_raw[part[0], part[1], 0]]
)
rgb_raw = issp.vis.to_rgb(
    st, colors=[(1, 0, 0), (0, 1, 0)], vmin=[0, 0], vmax=[250, 250]
)
# Reuse the same array: replace the second plane by the shifted next stack.
st[..., 1] = transformed_next_shifted[part[0], part[1], 0]
rgb_shifted = issp.vis.to_rgb(
    st, colors=[(1, 0, 0), (0, 1, 0)], vmin=[0, 0], vmax=[250, 250]
)

fig, axes = plt.subplots(2, 1, figsize=(20, 40))
axes[0].imshow(rgb_raw)
axes[0].set_title("Raw")
axes[1].imshow(rgb_shifted)
axes[1].set_title("Shifted")

In [None]:
# Pick the first reference mCherry cell inside a small ARA window and show its
# (y, x) position. `.iloc[0]` raises IndexError if no cell is in the window.
target = mcherry_df_ref.query(
    "ara_z_rot < 8.25 and ara_z_rot > 8 and ara_y_rot < 3.4 and ara_y_rot > 3.2"
).iloc[0]
target[["y", "x"]].values

In [None]:
# Show the stitched mCherry image around the cell closest to `target` in the
# previous, reference and next slices. Top row: closest cell using the raw ARA
# coordinates; bottom row: closest cell after adding the interpolated shifts
# (the reference slice is shown unchanged on both rows).
window = 3000  # half-width of the crop, in pixels of the stitched stack

mcherry_ds = dict(
    previous=mcherry_df_previous, reference=mcherry_df_ref, next=mcherry_df_next
)


ara_coord = target[["ara_z_rot", "ara_y_rot"]].values.astype(float)
y_shifts = dict(previous=interpolated_shift_y_previous, next=interpolated_shift_y_next)
z_shifts = dict(previous=interpolated_shift_z_previous, next=interpolated_shift_z_next)


def _closest_cell_px(cells, target_zy, z_shift=None, y_shift=None):
    """Return the integer (y, x) pixel position of the cell closest to a target.

    Args:
        cells: table with "ara_z_rot", "ara_y_rot", "y" and "x" columns.
        target_zy: (ara_z_rot, ara_y_rot) of the target.
        z_shift, y_shift: optional per-cell shifts; they are divided by 1000
            and added to the cell coordinates before measuring distances.

    Returns:
        Rounded (y, x) of the closest cell as an int array.
    """
    zy = np.array(cells[["ara_z_rot", "ara_y_rot"]].values, dtype=float)
    if z_shift is not None:
        zy[:, 0] += z_shift / 1000
        zy[:, 1] += y_shift / 1000
    nearest = cells.iloc[np.argmin(np.linalg.norm(zy - target_zy, axis=1))]
    return np.round(nearest[["y", "x"]].values.astype(float)).astype(int)


def _crop_to_rgb(stack, center, half_width):
    """Crop `stack` around `center` (y, x) and convert the crop to RGB.

    The bounds are clamped at 0: a negative start index would otherwise be
    read as "from the end" and return an empty or wrong region when the cell
    is closer than `half_width` to the top/left border.
    """
    y0, x0 = np.maximum(center - half_width, 0)
    y1, x1 = np.maximum(center + half_width, 0)
    return issp.vis.to_rgb(
        stack[y0:y1, x0:x1, :],
        colors=[(1, 0, 0), (0, 1, 0)],
        vmin=[0, 0],
        vmax=[200, 100],
    )


fig, axes = plt.subplots(2, 3, figsize=(20, 15))
for i, lab in enumerate(["previous", "reference", "next"]):
    cells = mcherry_ds[lab]
    raw_center = _closest_cell_px(cells, ara_coord)
    rgb = _crop_to_rgb(stitched_stacks[lab], raw_center, window)
    axes[0, i].imshow(rgb)
    if lab == "reference":
        # No shift applies to the reference slice: repeat the same crop.
        axes[1, i].imshow(rgb)
    else:
        shifted_center = _closest_cell_px(
            cells, ara_coord, z_shift=z_shifts[lab], y_shift=y_shifts[lab]
        )
        axes[1, i].imshow(_crop_to_rgb(stitched_stacks[lab], shifted_center, window))
fig.tight_layout()

In [None]:
# Submit the creation of the curated mCherry dataframes for one chamber.
# Disabled (`if False`): change the condition to run it. When enabled, it
# rebinds the global `chamber` to "chamber_10".
if False:
    chamber = "chamber_10"
    slurm_folder = Path.home() / "slurm_logs" / "register_mcherry_cells"
    # NOTE(review): without parents=True this fails if ~/slurm_logs is missing.
    slurm_folder.mkdir(exist_ok=True)
    issp.pipeline.segment.save_curated_dataframes(
        data_path=f"{project}/{mouse}/{chamber}",
        prefix="mCherry_1",
        intensity_channels=None,
        rois=None,
        mask_expansion=None,
        use_slurm=True,
        slurm_folder=slurm_folder,
        scripts_name=f"save_curated_mcherry_{chamber}",
    )

In [None]:
# NOTE(review): leftover scratch expression (evaluates to an empty string);
# this cell can probably be deleted.
"" or ""

In [None]:
# Display a downsampled (x0.3, nearest neighbour) view of the mCherry mask.
# NOTE(review): the mask is cast to int16 before resizing -- labels above
# 32767 would not fit; confirm the label range.
m = cv2.resize(
    mcherry_mask.astype("int16"),
    (0, 0),
    fx=0.3,
    fy=0.3,
    interpolation=cv2.INTER_NEAREST,
).astype(float)
print(np.unique(m))
print(m.shape)
# Label 0 is set to NaN so that it is not coloured.
m[m == 0] = np.nan
# Labels are wrapped modulo 20 to cycle through the 20 colours of "tab20".
plt.imshow(m % 20, cmap="tab20", interpolation="none")

In [None]:
# Collect the per-cell shifts towards the next slice for one reference ROI.
ref_chamber = "chamber_07"
ref_roi = 1
# NOTE(review): `res` was rebound by the interpolate_shifts cell above, so
# `res[0]["next"]` depends on which cell last assigned `res` -- confirm.
df = res[0]["next"]
cells_in_ref = rabies_cell_properties.query(
    f"chamber == '{ref_chamber}' and roi == {ref_roi}"
).copy()
# Positions of the cells that have a shift, in the row order of `df`.
cell_coords = cells_in_ref.loc[df.index, ["x", "y"]].values
shifts = df[["shift_y", "shift_z"]].values
shift_ampl = np.linalg.norm(shifts, axis=1)

In [None]:
# Smooth the shifts and evaluate the interpolators on a regular grid.
# NOTE(review): `interpolate_shifts`, `threshold` and `smoothing` are not
# defined in any visible cell, and this call signature differs from the
# `issa.registration.register_serial_sections.interpolate_shifts` call used
# above -- this cell likely fails on a fresh kernel; confirm or remove.
smooth_shifts, y_shift_interpolator, z_shift_interpolator = interpolate_shifts(
    cell_coords, shifts, threshold, smoothing=smoothing
)

xlims = [cell_coords[:, 0].min(), cell_coords[:, 0].max()]
ylims = [cell_coords[:, 1].min(), cell_coords[:, 1].max()]
print(f"X limits: {xlims}")
print(f"Y limits: {ylims}")
# Regular grid with a step of 500 along both axes, flattened to (n_points, 2).
grid = np.mgrid[xlims[0] : xlims[1] : 500, ylims[0] : ylims[1] : 500]
grid_flat = grid.reshape(2, -1).T

smooth_y_shifts = y_shift_interpolator(grid_flat)
smooth_z_shifts = z_shift_interpolator(grid_flat)
smooth_grid = np.stack([smooth_y_shifts, smooth_z_shifts], axis=1)
smooth_grid_ampl = np.linalg.norm(smooth_grid, axis=1)

In [None]:
# Quiver plots of the raw (left) and smoothed (right) shifts at each cell,
# coloured by shift amplitude. Arrow components are (column 1, column 0) of
# the shift arrays.
smooth_shifts_ampl = np.linalg.norm(smooth_shifts, axis=1)
quiver_style = dict(
    angles="xy",
    scale_units="xy",
    scale=1,
    cmap="inferno",
    clim=[0, threshold],
    edgecolors="k",
)
quiver_panels = (
    ("Raw Shifts", shifts, shift_ampl),
    ("Smoothed shifts", smooth_shifts, smooth_shifts_ampl),
)

fig, axes = plt.subplots(1, 2, figsize=(20, 10))
for ax, (title, vectors, amplitude) in zip(axes, quiver_panels):
    ax.set_aspect("equal")
    qu = ax.quiver(
        cell_coords[:, 0],
        cell_coords[:, 1],
        vectors[:, 1],
        vectors[:, 0],
        amplitude,
        **quiver_style,
    )
    plt.colorbar(qu, ax=ax)
    ax.set_title(title)

In [None]:
# Residual between smoothed and raw shifts: print the median amplitude
# (ignoring NaN) and draw the residual vectors at each cell.
delta = smooth_shifts - shifts
delta_ampl = np.linalg.norm(delta, axis=1)
print(np.nanmedian(delta_ampl))
plt.subplot(1, 1, 1, aspect="equal")
plt.quiver(
    cell_coords[:, 0],
    cell_coords[:, 1],
    delta[:, 1],
    delta[:, 0],
    angles="xy",
    scale_units="xy",
    scale=0.1,
)

In [None]:
# Reference cells selected by the index of the shift table.
cells_in_ref.loc[df.index]

In [None]:
# Shift rows selected by the index of the reference cells.
df.loc[cells_in_ref.index]

In [None]:
# Add empty (NaN) shift columns to the reference cells and preview the shift
# rows aligned on the cell index. The new columns are not filled in this cell.
cells_in_ref["shifts_z"] = np.nan
cells_in_ref["shifts_y"] = np.nan

cell_shifts = df.loc[cells_in_ref.index]
cell_shifts.head()

In [None]:
# Summarise which entries of `res` hold a result: entries whose value is a
# dict are reported as done, all others as TODO.
txt = []
for k, v in res.items():
    if isinstance(v, dict):
        txt.append(f"{k:02}: done")
    else:
        txt.append(f"{k}: TODO")
# Print ten entries per row, for however many entries there are (the previous
# version printed exactly four rows, i.e. only the first 40 entries).
for row_start in range(0, len(txt), 10):
    print(". ".join(txt[row_start : row_start + 10]))

In [None]:
from iss_analysis.vis import diagnostics

# debug plot around one cell
ref_slice = "chamber_09_03"
target_slice = "chamber_09_04"

ref_cells = rabies_cell_properties.query(f"slice == '{ref_slice}'")
cell_info = ref_cells.iloc[100]

# Registration parameters, defined once and passed to the call below.
# `window_size` is 300, the value the call actually used before (the variable
# previously said 500 but was never passed).
window_size = 300
min_spots = 10
max_barcode_number = 50
gaussian_width = 30

fig = diagnostics.check_serial_registration(
    cell_info,
    ref_slice,
    target_slice,
    rab_spot_df,
    rabies_cell_properties,
    window_size=window_size,
    min_spots=min_spots,
    max_barcode_number=max_barcode_number,
    gaussian_width=gaussian_width,
    shifts_to_use=None,
)

# Compute cell distances

For each cell we will register around the cell to the next slice and to the n+2 slice.
Then transform points and find closest cell with same barcode.

In [None]:
# For every rabies cell, find the distance to the closest cell sharing a
# barcode (a) within the same slice and (b) in the next / previous slice after
# adding the registration shift. Pairs closer than `matching_threshold` are
# recorded in `matches`.
# Threshold in the units of ara_*_rot (distances are multiplied by 1000 and
# labelled µm in the plots below, so this is presumably 50 µm -- confirm).
matching_threshold = 50.0 / 1000
coord_cols = ["ara_y_rot", "ara_z_rot"]


def _shares_barcode(candidates, barcodes):
    """Boolean Series: does each row of `candidates` share a barcode with `barcodes`?"""
    return candidates.all_barcodes.map(
        lambda others: any(b in barcodes for b in others)
    )


def _record_matches(matches, cell_id, candidates, distance, slice_type, threshold):
    """Append one record to `matches` per candidate closer than `threshold`."""
    close = distance < threshold
    for match_id, match_distance in zip(candidates.index[close], distance[close]):
        matches.append(
            dict(
                cell=cell_id,
                match=match_id,
                distance=match_distance,
                slice_type=slice_type,
            )
        )


rabies_cell_properties["slice"] = (
    rabies_cell_properties.chamber
    + "_"
    + rabies_cell_properties.roi.map(lambda x: f"{x:02d}")
)
section_infos = issa.io.get_sections_info(project, mouse)
# One NaN-initialised Series per distance type, indexed like the cell table.
distance2previous, distance2next, distancewithin = [
    pd.Series(
        np.full(rabies_cell_properties.shape[0], np.nan),
        index=rabies_cell_properties.index,
    )
    for _ in range(3)
]
all_dist_within = []
all_dist_next_prev = {"next": [], "previous": []}
matches = []
known_slices = set(rabies_cell_properties["slice"].unique())
for slice_id, slice_info in section_infos.iterrows():
    if slice_id not in res:
        raise IndexError(f"Slice {slice_id} not found in res")
    reg_res = res[slice_id]
    # Named `slice_name` so the builtin `slice` (used by other cells) is not shadowed.
    slice_name = f"{slice_info['chamber']}_{slice_info['roi']:02d}"
    if slice_name not in known_slices:
        raise ValueError(f"Slice {slice_name} not found in rabies_cell_properties")
    slice_cells = rabies_cell_properties.query(f"slice == '{slice_name}'")

    # (a) other cells of the same slice sharing a barcode
    for icell, cellinfo in slice_cells.iterrows():
        bc = cellinfo["all_barcodes"]
        cell_pos = cellinfo[coord_cols].values.astype(float)
        is_candidate = _shares_barcode(slice_cells, bc) & (slice_cells.index != icell)
        valid = slice_cells[is_candidate]
        if not valid.shape[0]:
            distance = np.inf
            all_dist_within.append([np.inf])
        else:
            distance = np.linalg.norm(
                valid[coord_cols].values.astype(float) - cell_pos, axis=1
            )
            _record_matches(
                matches, icell, valid, distance, "current", matching_threshold
            )
            all_dist_within.append(distance)
            distance = distance.min()
        distancewithin[icell] = distance

    # (b) cells of the neighbouring slices sharing a barcode
    for name, output in zip(["next", "previous"], [distance2next, distance2previous]):
        if name not in reg_res:
            continue
        reg = reg_res[name]
        slice_change = 1 if name == "next" else -1
        neighbour_info = section_infos.loc[slice_id + slice_change]
        neighbour_slice = f"{neighbour_info['chamber']}_{neighbour_info['roi']:02d}"
        neighbour_cells = rabies_cell_properties.query(f"slice == '{neighbour_slice}'")
        for irow, (icell, cellinfo) in enumerate(slice_cells.iterrows()):
            bc = cellinfo["all_barcodes"]
            valid = neighbour_cells[_shares_barcode(neighbour_cells, bc)]
            if not valid.shape[0]:
                distance = np.inf
            else:
                # TODO change to index when jobs have been rerun
                # shift = reg.loc[icell]
                # Until then the shift is looked up by row position.
                shift = (
                    reg.iloc[irow][["shift_y", "shift_z"]].values.astype(float) / 1000
                )
                if np.isnan(shift).any():
                    distance = np.inf
                    all_dist_next_prev[name].append([np.inf])
                else:
                    valid_pos = valid[coord_cols].values.astype(float) + shift
                    distance = np.linalg.norm(
                        valid_pos - cellinfo[coord_cols].values.astype(float),
                        axis=1,
                    )
                    _record_matches(
                        matches, icell, valid, distance, name, matching_threshold
                    )
                    all_dist_next_prev[name].append(distance)
                    distance = distance.min()
            output[icell] = distance
matches = pd.DataFrame(matches)
print(f"Found {matches.shape[0]} matches")

In [None]:
# Number of matches per slice type ("current", "next", "previous").
matches.slice_type.value_counts()

In [None]:
import seaborn as sns

# Histogram of the closest same-barcode distance per cell, split by type
# (previous / next / within), over a wide range (top) and a zoom (bottom).
closest_by_type = {
    "previous": distance2previous,
    "next": distance2next,
    "within": distancewithin,
}
distances = pd.DataFrame(
    dict(
        # multiplied by 1000 to match the µm axis label
        distance=np.hstack([s.values for s in closest_by_type.values()]) * 1000,
        type=np.hstack(
            [[name] * s.shape[0] for name, s in closest_by_type.items()]
        ),
        cell=np.hstack([s.index for s in closest_by_type.values()]),
    )
)
fig, panels = plt.subplots(2, 1, figsize=(10, 5))
for ax, bins in zip(panels, (np.arange(0, 1500, 5), np.arange(0, 100, 1))):
    sns.histplot(
        distances, x="distance", hue="type", bins=bins, element="step", ax=ax
    )
    ax.set_xlabel("Distance (µm)")
fig.tight_layout()
# Number of rows with a finite distance
print(distances[np.isfinite(distances.distance)].shape[0])

In [None]:
import seaborn as sns

dw = np.hstack(all_dist_within)
dnext = np.hstack(all_dist_next_prev["next"])
dprev = np.hstack(all_dist_next_prev["previous"])
distances = pd.DataFrame(
    dict(
        distance=np.hstack([dw, dnext, dprev]) * 1000,
        type=np.hstack(
            [
                ["within"] * len(dw),
                ["next"] * len(dnext),
                ["previous"] * len(dprev),
            ]
        ),
    )
)
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(2, 1, 1)
sns.histplot(
    distances, x="distance", hue="type", bins=np.arange(0, 1500, 5), element="step"
)
ax = fig.add_subplot(2, 1, 2)
sns.histplot(
    distances, x="distance", hue="type", bins=np.arange(0, 100, 1), element="step"
)
for x in fig.axes:
    x.set_xlabel("Distance (µm)")
fig.tight_layout()
print(distances.shape[0])

In [None]:
_ = plt.hist(
    distance2previous[np.isfinite(distance2previous)],
    bins=np.arange(0, 0.1, 0.001),
    cumulative=True,
    histtype="step",
)

In [None]:
print(f"{(distance2previous < 20e-3).sum()} / {distance2previous.shape[0]}")

In [None]:
rabies_cell_properties.head()

# Create an Allen coordinate version of mCherry.


In [None]:
# Pixels whose three ARA coordinates are all exactly 0 carry no registration data:
# report how many there are and replace them with NaN (in place).
coord_abs_sum = np.abs(ara_coords).sum(axis=2)
bad_parts = coord_abs_sum == 0
n_bad_parts = bad_parts.sum()
if n_bad_parts:
    print(f"Found {n_bad_parts} parts of the ARA image with no data")
ara_coords[bad_parts, :] = np.nan
plt.imshow(np.abs(ara_coords).sum(axis=2))

In [None]:
data_path = f"{project}/{mouse}/chamber_08"
roi = 1
ara_coords = issp.pipeline.ara_registration.load_coordinate_image(
    data_path, roi, full_scale=False
)
plt.imshow(ara_coords[..., 0])

In [None]:
save_folder = (
    issp.io.get_processed_path(f"{project}/{mouse}") / "analysis" / "rotated_stacks"
)
slurm_folder = Path.home() / "slurm_logs" / "ara_registration"
slurm_folder.mkdir(exist_ok=True)
ops = issp.io.load_ops(f"{project}/{mouse}/chamber_07")
metadata = issp.pipeline.ara_registration.load_registration_reference_metadata(
    f"{project}/{mouse}/chamber_07", roi=1
)
prefix = "mCherry_1"  # ops["reference_prefix"],

for chamber in [f"chamber_{i:02d}" for i in range(7, 11)]:
    for roi in range(1, 11):
        rotated_stack = issa.registration.ara_registration.transform_stack_to_ara(
            project,
            mouse,
            chamber,
            roi=roi,
            prefix=prefix,
            channels=[ops["mcherry_signal_channel"], ops["mcherry_background_channel"]],
            error_correction_ds_name=error_correction_ds_name,
            output_folder=save_folder,
            output_px_size=metadata["pixel_size"],
            use_slurm=False,
            slurm_folder=slurm_folder,
            interpolate=True,
            scripts_name=f"trans2ara_{prefix}_{chamber}_{roi}",
        )

In [None]:
# Parameters for a single chamber / ROI, used by the exploratory cells below.
# `error_correction_ds_name` is already defined at the top of the notebook.
chamber = "chamber_09"
roi = 3
prefix = "mCherry_1"
channels = [1, 2]
output_folder = save_folder
output_px_size = 3

In [None]:
rotated_stack = issa.registration.ara_registration.transform_stack_to_ara(
    project,
    mouse,
    chamber,
    roi=roi,
    prefix="mCherry_1",
    channels=[1, 2],
    error_correction_ds_name=error_correction_ds_name,
    output_folder=save_folder,
    output_px_size=5,
    use_slurm=False,
    slurm_folder=slurm_folder,
    interpolate=True,
    scripts_name=f"trans2ara_{chamber}_{roi}",
)

In [None]:
plt.imshow(rgbit(rotated_stack))

In [None]:
from iss_preprocess.io.save import write_stack
from iss_preprocess.io.load import get_processed_path
from iss_preprocess.pipeline.ara_registration import (
    load_coordinate_image,
    make_area_image,
)
from iss_preprocess.pipeline.stitch import stitch_registered
from skimage.transform import resize

data_path = get_processed_path(f"{project}/{mouse}/{chamber}")
full_stack = stitch_registered(
    data_path,
    prefix=prefix,
    roi=roi,
    channels=channels,
)

In [None]:
ara_coords = load_coordinate_image(data_path, roi, full_scale=False)

In [None]:
# downsample stack to ara coordinates shape
stack = np.empty((ara_coords.shape[0], ara_coords.shape[1], full_stack.shape[-1]))
for i in range(full_stack.shape[-1]):
    stack[..., i] = resize(full_stack[..., i], ara_coords.shape[:2])

In [None]:
im = plt.imshow(ara_coords[..., 0], cmap="coolwarm", vmin=6)
plt.imshow(stack[..., 1], vmax=50, alpha=0.5, cmap="gray")
plt.colorbar(im)

In [None]:
import brainglobe_atlasapi as bga

atlas_name = "allen_mouse_10um"
bg_atlas = bga.bg_atlas.BrainGlobeAtlas(atlas_name)

In [None]:
ara_shape_mm = np.array(bg_atlas.shape_um) / 1000
print(ara_shape_mm)
(np.array(bg_atlas.shape_um) / 1000) @ transform

In [None]:
ara_coords.min(axis=(0, 1))

In [None]:
# find the main ara plane of the data using rabies spots
ara_info_folder = get_processed_path(f"{project}/{mouse}") / "analysis" / "ara_infos"
target = (
    ara_info_folder
    / f"{error_correction_ds_name}_{chamber}_{roi}_rabies_spots_ara_info.pkl"
)
transform = issa.registration.ara_registration.get_ara_to_slice_rotation_matrix(
    spot_df=pd.read_pickle(target)
)

# rotate the ara coordinates
ara_coords_rot = ara_coords.reshape(-1, 3) @ transform
ara_coords_rot = ara_coords_rot.reshape(ara_coords.shape)

# transform the stack to ara coordinates
shapes = np.vstack([stack.shape[:2], ara_coords_rot.shape[:2]])
if np.any(np.diff(shapes, axis=0)):
    raise ValueError("Stack and ara_coords_rot must have the same shape")

# Extent of the rotated coordinates, restricted to pixels inside the area image
area_img = make_area_image(data_path, roi, full_scale=False)
ara_lim = np.vstack(
    [ara_coords_rot[area_img > 0].min(axis=0), ara_coords_rot[area_img > 0].max(axis=0)]
)
# `target_px` below uses rotated component 1 as row index and component 2 as column
# index, so the number of rows (h_px) must come from the extent of component 1 and
# the number of columns (w_px) from the extent of component 2.
height, width = np.diff(ara_lim[:, 1:], axis=0).squeeze()
h_px, w_px = np.round(np.array([height, width]) * 1000 / output_px_size).astype(int) + 1
rotated_stack = np.zeros((h_px, w_px, 2))
target_px = np.round((ara_coords_rot - ara_lim[0, :]) * 1000 / output_px_size).astype(
    int
)
target_px = target_px[..., [1, 2]]
# Pixels outside the area image may fall outside the output grid: clip them to its edge
for i in range(2):
    target_px[..., i] = np.clip(target_px[..., i], 0, rotated_stack.shape[i] - 1)

In [None]:
print(target_px.shape)
rotated_stack[target_px[..., 0], target_px[..., 1], :] = stack
print(rotated_stack.shape)

In [None]:
def rgbit(img, percentiles=(0.2, 99.9)):
    """Convert a multi-channel image to RGB with percentile-based contrast.

    Args:
        img (np.ndarray): Image with channels on the last axis. Two colors (red and
            green) are passed to `issp.vis.to_rgb`, so presumably two channels are
            expected -- TODO confirm.
        percentiles (tuple, optional): Lower and upper percentiles used as `vmin`
            and `vmax`. They are computed per channel, using only the pixels where
            every channel is > 0. Defaults to (0.2, 99.9).

    Returns:
        The output of `issp.vis.to_rgb`.
    """
    all_positive = (img > 0).all(axis=-1)
    limits = np.nanpercentile(img[all_positive, :], percentiles, axis=0)
    vmin, vmax = limits
    return issp.vis.to_rgb(img, colors=[(1, 0, 0), (0, 1, 0)], vmin=vmin, vmax=vmax)


plt.imshow(rgbit(rotated_stack))

In [None]:
grid = np.meshgrid(np.arange(h_px), np.arange(w_px), indexing="ij")
grid[0].shape

In [None]:
print(target_px.shape)

In [None]:
from scipy.interpolate import NearestNDInterpolator

out = np.zeros((h_px, w_px, 2))
for axis in range(stack.shape[-1]):
    interp = NearestNDInterpolator(target_px.reshape((-1, 2)), stack[..., axis].ravel())
    grid = np.meshgrid(np.arange(h_px), np.arange(w_px), indexing="ij")
    out[..., axis] = interp(*grid)

In [None]:
plt.imshow(rgbit(out))

In [None]:
plt.imshow(stack[..., 0], vmax=np.nanpercentile(stack[..., 0], 99))

In [None]:
rng = np.random.default_rng()
x = rng.random(10) - 0.5
y = rng.random(10) - 0.5
z = np.hypot(x, y)
X = np.linspace(min(x), max(x))
Y = np.linspace(min(y), max(y))
X, Y = np.meshgrid(X, Y)  # 2D grid for interpolation
interp = NearestNDInterpolator(list(zip(x, y)), z)
print(x)
print(y)
print(list(zip(x, y)))

In [None]:
print(target_px.shape)

In [None]:
# Parameters for a single chamber / ROI.
# `error_correction_ds_name` is already defined at the top of the notebook.
chamber = "chamber_09"
roi = 3
prefix = "mCherry_1"
channels = [1, 2]
output_folder = save_folder

In [None]:
# NOTE(review): `rabies_cell_properties` is used as a table of cells elsewhere in
# this notebook (`.query`, column assignment), whereas `rgbit` indexes its argument
# as an image array. This call looks like a leftover -- confirm which image was meant.
plt.imshow(rgbit(rabies_cell_properties))

In [None]:
ara_info_folder = (
    issp.io.get_processed_path(f"{project}/{mouse}") / "analysis" / "ara_infos"
)
target = (
    ara_info_folder
    / f"{error_correction_ds_name}_{chamber}_{roi}_rabies_spots_ara_info.pkl"
)
spot_df = pd.read_pickle(target)

In [None]:
transform

In [None]:
issa.registration.ara_registration.get_ara_to_slice_rotation_matrix(
    spot_df=pd.read_pickle(target)
)

In [None]:
ref_chamber = "chamber_09"
ref_roi = 3
data_path = f"{project}/{mouse}/{ref_chamber}"
mcherry_ref = issp.pipeline.stitch_registered(
    data_path,
    prefix="mCherry_1",
    roi=ref_roi,
    channels=[2, 3],
)
mcherry_next = issp.pipeline.stitch_registered(
    data_path,
    prefix="mCherry_1",
    roi=ref_roi + 1,
    channels=[2, 3],
)

In [None]:
from skimage.transform import resize

ara_coords_ref = issp.pipeline.ara_registration.load_coordinate_image(
    data_path, ref_roi, full_scale=False
)
ara_coords_next = issp.pipeline.ara_registration.load_coordinate_image(
    data_path, ref_roi + 1, full_scale=False
)

mcherry_ref_ds = np.dstack(
    [
        resize(mcherry_ref[..., i], ara_coords_ref.shape[:2])
        for i in range(mcherry_ref.shape[-1])
    ]
)
mcherry_next_ds = np.dstack(
    [
        resize(mcherry_next[..., i], ara_coords_next.shape[:2])
        for i in range(mcherry_next.shape[-1])
    ]
)

In [None]:
def rgbit(img, percentiles=(0.2, 99.9)):
    """Convert a multi-channel image to RGB with percentile-based contrast.

    Same definition as the `rgbit` defined earlier in this notebook.

    Args:
        img (np.ndarray): Image with channels on the last axis. Two colors (red and
            green) are passed to `issp.vis.to_rgb`, so presumably two channels are
            expected -- TODO confirm.
        percentiles (tuple, optional): Lower and upper percentiles used as `vmin`
            and `vmax`. They are computed per channel, using only the pixels where
            every channel is > 0. Defaults to (0.2, 99.9).

    Returns:
        The output of `issp.vis.to_rgb`.
    """
    all_positive = (img > 0).all(axis=-1)
    limits = np.nanpercentile(img[all_positive, :], percentiles, axis=0)
    vmin, vmax = limits
    return issp.vis.to_rgb(img, colors=[(1, 0, 0), (0, 1, 0)], vmin=vmin, vmax=vmax)


fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(1, 3, 1)
ax.imshow(rgbit(mcherry_ref_ds))
ax.axis("off")
ax = fig.add_subplot(1, 3, 2)
ax.imshow(rgbit(mcherry_next_ds))
ax.axis("off")
fig.tight_layout()

In [None]:
# rotate the ara coordinates
transform = issa.registration.ara_registration.get_ara_to_slice_rotation_matrix(
    spot_df=rab_spot_df
)


def rotate_ara_coords(ara_coords, transform):
    """Apply a 3x3 transform to every coordinate triplet of a coordinate image.

    Args:
        ara_coords (np.ndarray): Array whose last axis has length 3.
        transform (np.ndarray): (3, 3) matrix, applied on the right
            (``coords @ transform``).

    Returns:
        np.ndarray: Transformed coordinates, same shape as ``ara_coords``.
    """
    flat_coords = ara_coords.reshape(-1, 3)
    return (flat_coords @ transform).reshape(ara_coords.shape)

In [None]:
def rotate_stack(stack, ara_coords_rot, area_img=None, px_size_um=10):
    """Resample an image stack onto a regular grid in rotated ARA coordinates.

    Each source pixel is copied to the output pixel given by its rotated coordinates
    (rounded to the nearest pixel, no interpolation). Component 1 of
    ``ara_coords_rot`` gives the output row and component 2 the output column.
    Output pixels that receive no source pixel stay at 0; when several source pixels
    map to the same output pixel only one of them is kept.

    Args:
        stack (np.ndarray): (rows, cols, channels) image on the same pixel grid as
            ``ara_coords_rot``.
        ara_coords_rot (np.ndarray): (rows, cols, 3) rotated coordinates of each
            pixel. They are multiplied by 1000 before being divided by
            ``px_size_um``, so presumably in mm -- TODO confirm.
        area_img (np.ndarray, optional): (rows, cols) image. Only pixels where it is
            > 0 are used to compute the extent of the output grid. Defaults to None,
            in which case all pixels are used.
        px_size_um (float, optional): Pixel size of the output grid. Defaults to 10.

    Returns:
        np.ndarray: (n_rows, n_cols, channels) float array.

    Raises:
        ValueError: If the first two dimensions of ``stack`` and ``ara_coords_rot``
            differ.
    """
    shapes = np.vstack([stack.shape[:2], ara_coords_rot.shape[:2]])
    if np.any(np.diff(shapes, axis=0)):
        raise ValueError("Stack and ara_coords_rot must have the same shape")

    if area_img is None:
        area_img = np.ones_like(stack[..., 0], dtype=bool)
    in_area = area_img > 0
    ara_lim = np.vstack(
        [
            ara_coords_rot[in_area].min(axis=0),
            ara_coords_rot[in_area].max(axis=0),
        ]
    )
    # Extent along the two in-plane components, in the order in which they index the
    # output below: component 1 -> rows, component 2 -> columns. Sizing the rows from
    # component 2 (as done previously) clipped the indices against the wrong extent.
    extent = np.diff(ara_lim[:, 1:], axis=0).squeeze()
    n_rows, n_cols = np.round(extent * 1000 / px_size_um).astype(int) + 1
    rotated_stack = np.zeros((n_rows, n_cols, stack.shape[-1]))
    target_px = np.round((ara_coords_rot - ara_lim[0, :]) * 1000 / px_size_um).astype(
        int
    )
    target_px = target_px[..., [1, 2]]
    # Pixels outside `area_img` may fall outside the output grid: clip them to its edge
    for i in range(2):
        target_px[..., i] = np.clip(target_px[..., i], 0, rotated_stack.shape[i] - 1)
    rotated_stack[target_px[..., 0], target_px[..., 1], :] = stack

    return rotated_stack


ara_coords_ref_rot = rotate_ara_coords(ara_coords_ref, transform)
ara_coords_next_rot = rotate_ara_coords(ara_coords_next, transform)
area_img = issp.pipeline.ara_registration.make_area_image(
    data_path, ref_roi, atlas_size=10, full_scale=False, reload=False
)

rotated_mcherry_ref = rotate_stack(
    mcherry_ref_ds, ara_coords_ref_rot, area_img=area_img
)
area_img = issp.pipeline.ara_registration.make_area_image(
    data_path, ref_roi + 1, atlas_size=10, full_scale=False, reload=False
)

rotated_mcherry_next = rotate_stack(
    mcherry_next_ds, ara_coords_next_rot, area_img=area_img
)

In [None]:
# Show the rotated reference ROI (left) and the next ROI (right) side by side
fig, axes = plt.subplots(1, 2)
for ax, rotated in zip(axes, [rotated_mcherry_ref, rotated_mcherry_next]):
    ax.imshow(rgbit(rotated))
    ax.axis("off")

In [None]:
# Name the file after the reference slice (chamber + zero-padded roi, the same
# convention as the "slice" columns built elsewhere in this notebook). No `slice`
# variable for the reference ROI is defined by the cells above.
issp.io.save.write_stack(
    rotated_mcherry_ref,
    save_folder / f"mcherry_{ref_chamber}_{ref_roi:02d}_rotated.tif",
)

In [None]:
_ = issp.pipeline.ara_registration.crop_overview_registration(
    data_path, rois=3, overview_prefix="DAPI_1_1"
)

area_img = issp.pipeline.ara_registration.make_area_image(
    data_path, ref_roi, atlas_size=10, full_scale=False, reload=False
)

In [None]:
np.max(anchor_ref)

# Try with genes

In [None]:
chamber = "chamber_09"
ref_roi = 3
target_roi = 4

In [None]:
# Get the genes spots
genes_df = issa.io.get_genes_spots(project=project, mouse=mouse)
print(f"Loaded {genes_df.shape[0]} spots")

spots_df = genes_df.query(f"chamber == '{chamber}'")
spots_df = spots_df[spots_df.roi.isin([ref_roi, target_roi])]
print(f"Found {spots_df.shape[0]} spots in the selected rois")

In [None]:
# Add the rotated ARA coordinates and the slice name, unless already present
if "ara_y_rot" not in spots_df.columns:
    spots_df = issa.registration.ara_registration.rotate_ara_coordinate_to_slice(
        spots_df
    )
if "slice" not in spots_df.columns:
    # `spots_df` is a filtered selection of `genes_df`: build a new frame with
    # `assign` rather than setting a column in place (SettingWithCopyWarning).
    spots_df = spots_df.assign(
        slice=spots_df.chamber + "_" + spots_df["roi"].map(lambda x: f"{x:02d}")
    )

In [None]:
ref_spots = spots_df.query(f"roi == {ref_roi}")
target_spots = spots_df.query(f"roi == {target_roi}")
ref_spots.head()

In [None]:
center_point = [4, 9]
window_size = 200 / 1000
plt.subplot(1, 2, 1, aspect="equal")
plt.scatter(ref_spots.ara_y_rot, ref_spots.ara_z_rot, c="r", s=1, alpha=0.1)
plt.xlim(center_point[0] - window_size, center_point[0] + window_size)
plt.ylim(center_point[1] - window_size, center_point[1] + window_size)
plt.subplot(1, 2, 2, aspect="equal")
plt.scatter(target_spots.ara_y_rot, target_spots.ara_z_rot, c="r", s=1, alpha=0.1)
plt.xlim(center_point[0] - window_size, center_point[0] + window_size)
plt.ylim(center_point[1] - window_size, center_point[1] + window_size)

In [None]:
# get the spots around the center point

sp = spots_df.query(
    f"ara_y_rot > {center_point[0] - window_size} "
    + f"and ara_y_rot < {center_point[0] + window_size} "
    + f"and ara_z_rot > {center_point[1] - window_size} "
    + f"and ara_z_rot < {center_point[1] + window_size}"
)
ref_sp = sp.query(f"roi == {ref_roi}").copy()
target_sp = sp.query(f"roi == {target_roi}").copy()

plt.subplot(1, 2, 1, aspect="equal")
plt.scatter(ref_sp.ara_y_rot, ref_sp.ara_z_rot, c="r", s=1, alpha=0.1)
plt.subplot(1, 2, 2, aspect="equal")
plt.scatter(target_sp.ara_y_rot, target_sp.ara_z_rot, c="r", s=1, alpha=0.1)

In [None]:
# make a spot image
ref_sp = sp.query(f"roi == {ref_roi}").copy()
target_sp = sp.query(f"roi == {target_roi}").copy()
ref_sp["x"] = (ref_sp.ara_y_rot - center_point[0] + window_size) * 1000
ref_sp["y"] = (ref_sp.ara_z_rot - center_point[1] + window_size) * 1000
target_sp["x"] = (target_sp.ara_y_rot - center_point[0] + window_size) * 1000
target_sp["y"] = (target_sp.ara_z_rot - center_point[1] + window_size) * 1000

output_shape = np.array([window_size * 2 * 1000, window_size * 2 * 1000], dtype=int)
ref_img = issp.segment.spots.make_spot_image(
    ref_sp, gaussian_width=3, output_shape=output_shape
)
target_img = issp.segment.spots.make_spot_image(
    target_sp, gaussian_width=3, output_shape=output_shape
)

plt.subplot(1, 2, 1)
plt.imshow(ref_img, cmap="Greys", origin="lower")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(target_img, cmap="Greys", origin="lower")
plt.axis("off")

# END Try with genes

# Check ARA reg quality

In [None]:
from znamutils import slurm_it

slurm_folder = Path.home() / "slurm_logs" / "ara_registration"
slurm_folder.mkdir(exist_ok=True)

save_folder = issp.io.get_processed_path(f"{project}/{mouse}/analysis/ara_infos")

for chamber in [f"chamber_{i:02d}" for i in [7, 8, 9, 10]]:
    data_path = f"{project}/{mouse}/{chamber}"
    issp.pipeline.ara_registration.check_reg(
        data_path,
        save_folder=save_folder,
        slurm_folder=slurm_folder,
        use_slurm=True,
        rois=None,
        scripts_name=f"check_reg_{chamber}",
    )

In [None]:
a_img = area_img.astype(float)
a_img[a_img == 0] = np.nan
fig = plt.figure(figsize=(5, 5))
spots = rab_spot_df.query("chamber == 'chamber_09' & roi == 3")
plt.imshow(
    mcherry_ref_ds,
    cmap="gray",
    vmin=0,
    vmax=10,
    extent=[0, mcherry_ref_ds.shape[1] * 8, 0, mcherry_ref_ds.shape[0] * 8],
)
plt.imshow(
    a_img[: mcherry_ref_ds.shape[0], : mcherry_ref_ds.shape[1]] % 20,
    alpha=0.2,
    cmap="tab20",
    vmin=0,
    vmax=20,
    extent=[0, mcherry_ref_ds.shape[1] * 8, 0, mcherry_ref_ds.shape[0] * 8],
)
badies = spots.query("area_id == 0")
# plt.scatter(badies.x, badies.y, c='k', s=20)
# plt.scatter(spots.x, spots.y, c=spots.area_id%20, cmap='tab20', s=10)

In [None]:
fig = plt.figure(figsize=(20, 20))
spots = rab_spot_df.query("chamber == 'chamber_09' & roi == 3")
plt.imshow(
    mcherry_ref_ds,
    cmap="gray",
    vmin=0,
    vmax=10,
    extent=[0, mcherry_ref_ds.shape[1] * 8, 0, mcherry_ref_ds.shape[0] * 8],
)
plt.imshow(
    a_img[: mcherry_ref_ds.shape[0], : mcherry_ref_ds.shape[1]] % 20,
    alpha=0.2,
    cmap="tab20",
    vmin=0,
    vmax=20,
    extent=[0, mcherry_ref_ds.shape[1] * 8, 0, mcherry_ref_ds.shape[0] * 8],
)
badies = spots.query("area_id == 0")
# plt.scatter(badies.x, badies.y, c='k', s=20)
# plt.scatter(spots.x, spots.y, c=spots.area_id%20, cmap='tab20', s=10)
plt.axis("off")

In [None]:
ara_lim = np.vstack(
    [
        ara_coords_ref_rot[area_img > 0].min(axis=0),
        ara_coords_ref_rot[area_img > 0].max(axis=0),
    ]
)
ara_lim

In [None]:
ara_coords_next.shape

In [None]:
plt.imshow(ara_coords[..., 0])

In [None]:
plt.imshow(mcherry_ref[::8, ::8, 0], vmax=100)

In [None]:
ara_coords = spot_df[[f"ara_{i}" for i in "xyz"]].values
rotated_coords = ara_coords @ transform

In [None]:
reg.shape

In [None]:
print(f"NaN reg: {np.isnan(reg.shift_y).sum()}/{reg.shape[0]}")
plt.scatter(reg.shift_y, reg.shift_z)

In [None]:
import seaborn as sns

sns.histplot(distancewithin[~np.isnan(distancewithin)], bins=100)

In [None]:
distance2previous

In [None]:
distance = np.linalg.norm(
    valid[["x", "y"]].values.astype(float) - cellinfo[["x", "y"]].values.astype(float),
    axis=1,
)
distance

In [None]:
# add barcode_id for plotting
# The id of a barcode is its position in the list of unique barcodes. Build the
# lookup once instead of scanning the list with `list.index` for every spot.
barcodes = list(rab_spot_df.corrected_bases.unique())
barcode_ids = {barcode: ibc for ibc, barcode in enumerate(barcodes)}
rab_spot_df["barcode_id"] = rab_spot_df.corrected_bases.map(barcode_ids)

In [None]:
from tqdm import tqdm

window_size = 300
min_spots = 10
max_barcode_number = 50
min_barcode_number = 1
gaussian_width = 30

section_infos = issa.io.get_sections_info(project, mouse)
section_infos["slice"] = (
    section_infos["chamber"] + "_" + section_infos["roi"].map(lambda x: f"{x:02d}")
)
ref_chamber, ref_roi = "chamber_09", 2
ref_sec = section_infos.query("chamber == @ref_chamber and roi == @ref_roi").iloc[0]
ref_slice = ref_sec.slice
print(ref_slice)
# Compare the reference section with itself and with the two following sections
todo = (ref_sec.absolute_section + np.arange(3)).astype(int)

rabies_cell_properties["slice"] = (
    rabies_cell_properties["chamber"]
    + "_"
    + rabies_cell_properties["roi"].map(lambda x: f"{x:02d}")
)
cells = rabies_cell_properties.query("chamber == @ref_chamber and roi == @ref_roi")
ref_spots = rab_spot_df.query("slice == @ref_slice")
# Results are indexed [reference cell, section]. In `dist2closest`, NaN means that no
# distance was computed (too few barcodes or NaN shift) and inf that there was no
# candidate cell. Note that `closest_target` is indexed [section][reference cell].
n_cells, n_sections = cells.shape[0], len(todo)
nbcs_all = np.full((n_cells, n_sections), np.nan)
shifts_all = np.full((n_cells, n_sections, 2), np.nan)
max_corrs_all = np.full((n_cells, n_sections), np.nan)
dist2closest = np.full((n_cells, n_sections), np.nan)
closest_target = [["NaN"] * n_cells for _ in range(n_sections)]
for isec, sec_id in enumerate(todo):
    print(f"Do section {sec_id}")
    slice_df = section_infos.query("absolute_section == @sec_id").iloc[0]
    # The cells of the target slice do not depend on the reference cell: select them
    # once per section rather than once per reference cell.
    slice_cells = rabies_cell_properties.query("slice == @slice_df.slice")
    for ic, (_, cell_info) in tqdm(enumerate(cells.iterrows()), total=cells.shape[0]):
        # Get candidate cells that might be the same: they share at least one barcode
        # with the reference cell and are within 2 * window_size of it
        cell_barcodes = cell_info.all_barcodes
        cells_target = slice_cells[
            slice_cells.all_barcodes.map(
                lambda x: any([bc in cell_barcodes for bc in x])
            )
        ]
        distance = np.sqrt(
            (cells_target.ara_y_rot - cell_info.ara_y_rot) ** 2
            + (cells_target.ara_z_rot - cell_info.ara_z_rot) ** 2
        )
        cells_target = cells_target[distance < window_size * 2 / 1000]
        if cells_target.shape[0] == 0:
            dist2closest[ic, isec] = np.inf
            continue

        (
            shifts,
            maxcorr,
            n_bcs,
        ) = issa.registration.register_serial_sections.register_local_spots(
            center_point=(cell_info.ara_y_rot, cell_info.ara_z_rot),
            spot_df=rab_spot_df,
            ref_slice=ref_slice,
            target_slice=slice_df.slice,
            window_size=window_size,
            min_spots=min_spots,
            max_barcode_number=max_barcode_number,
            gaussian_width=gaussian_width,
            verbose=False,
            debug=False,
        )
        nbcs_all[ic, isec] = n_bcs
        shifts_all[ic, isec] = shifts
        max_corrs_all[ic, isec] = maxcorr

        if n_bcs < min_barcode_number:
            continue
        if any(np.isnan(shifts)):
            print(f"Cell {ic} has NaN shift")
            continue
        # shifts the target cell ara_y and ara_z
        shifted_targets = cells_target.copy()
        if slice_df.slice == ref_slice:
            # remove the reference cell from the target cells if it is in the same slice
            shifted_targets = shifted_targets.drop(cell_info.name)
        if len(shifted_targets) == 0:
            dist2closest[ic, isec] = np.inf
            continue
        # NOTE(review): shifts[1] is added to ara_y_rot and shifts[0] to ara_z_rot
        # here, whereas the diagnostic scatter plot further down subtracts shifts[0]
        # from ara_y_rot and shifts[1] from ara_z_rot. Confirm the axis order and sign
        # of the shifts returned by `register_local_spots`.
        shifted_targets.ara_y_rot += shifts[1] / 1000
        shifted_targets.ara_z_rot += shifts[0] / 1000
        # compute the distance between the target cells and the reference cell
        distance = np.sqrt(
            (shifted_targets.ara_y_rot - cell_info.ara_y_rot) ** 2
            + (shifted_targets.ara_z_rot - cell_info.ara_z_rot) ** 2
        )
        # select the closest cell
        best_target = shifted_targets.loc[distance.idxmin()]
        closest_target[isec][ic] = best_target.name
        dist2closest[ic, isec] = distance.min()

In [None]:
labels = ["within slice", "next slice", "skip one slice"]
fig = plt.figure(figsize=(5, 3))
ax = fig.add_subplot(111)
for iax, label in enumerate(labels):
    ax.hist(
        dist2closest[:, iax] * 1000,
        bins=np.arange(0, 1000, 1),
        label=label,
        alpha=0.5,
        histtype="step",
        lw=3,
        cumulative=True,
        density=True,
    )
ax.set_xlim(0, 100)
ax.set_ylim(0, 0.4)
ax.legend(loc="upper right")
ax.set_xlabel("Distance to closest cell (um)")
ax.set_ylabel("Proportion of cells")

In [None]:
print(f"{dist2closest.shape[0]} cells")
ncells = dist2closest.shape[0]
reg_fail = np.isnan(dist2closest).sum(axis=0)
no_target = np.isinf(dist2closest).sum(axis=0)
has_close_target = ((dist2closest * 1000) < 15).sum(axis=0)
print(f"{no_target[1]}/{ncells} have no neighbour in the next slice.")
left = ncells - no_target[1]
print(f"Registration failed for {reg_fail[1]}/{left} cells.")
left = left - reg_fail[1]
print(f"{has_close_target[1]}/{left} have a close target in the next slice.")
print(np.isnan(dist2closest).sum(axis=0))
print(np.isinf(dist2closest).sum(axis=0))
print(((dist2closest * 1000) < 15).sum(axis=0))

In [None]:
d = dist2closest[:, 1]
bad = np.isnan(d)
print(f"Bad cells: {bad.sum()}")
print(f"Good cells: {(~bad).sum()}")

In [None]:
spot_ref = rab_spot_df.query("slice == @ref_slice")
# cut around ara_y_rot and ara_z_rot of cell_info
spot_ref = spot_ref.query(
    "ara_y_rot > @cell_info.ara_y_rot - @window_size / 1000 and ara_y_rot < @cell_info.ara_y_rot + @window_size / 1000"
)
spot_ref = spot_ref.query(
    "ara_z_rot > @cell_info.ara_z_rot - @window_size / 1000 and ara_z_rot < @cell_info.ara_z_rot + @window_size / 1000"
)
bc_ref = spot_ref.corrected_bases.unique()
spot_target = rab_spot_df.query("slice == @slice_df.slice")
spot_target = spot_target.query(
    "ara_y_rot > @cell_info.ara_y_rot - @window_size / 1000 and ara_y_rot < @cell_info.ara_y_rot + @window_size / 1000"
)
spot_target = spot_target.query(
    "ara_z_rot > @cell_info.ara_z_rot - @window_size / 1000 and ara_z_rot < @cell_info.ara_z_rot + @window_size / 1000"
)
bc_target = spot_target.corrected_bases.unique()
common_bc = set(bc_ref).intersection(bc_target)
print(f"{len(common_bc)} common barcodes")
common_bc = list(common_bc)
spot_ref = spot_ref.query("corrected_bases in @common_bc")
spot_target = spot_target.query("corrected_bases in @common_bc")


nspots_ber_cb = spot_ref.groupby("corrected_bases").size()[common_bc]
nspots_bct_cb = spot_target.groupby("corrected_bases").size()[common_bc]
min_nspots = np.vstack([nspots_ber_cb, nspots_bct_cb]).min(axis=0)
valid = min_nspots > 10
common_bc = np.array(common_bc)[valid]
spot_ref = spot_ref.query("corrected_bases in @common_bc")
spot_target = spot_target.query("corrected_bases in @common_bc")
print(
    f"{len(spot_ref)} spots in reference slice and {len(spot_target)} in target slice"
)

In [None]:
ms = 20
ref_kw = dict(cmap="tab20", marker="s", s=ms, alpha=0.5, c=spot_ref.barcode_id % 20)
target_kw = dict(
    cmap="tab20", marker="o", s=ms, alpha=0.5, c=spot_target.barcode_id % 20
)
# NOTE(review): shifts[0] is subtracted from ara_y_rot and shifts[1] from ara_z_rot
# here, whereas the matching loop above adds shifts[1] to ara_y_rot and shifts[0] to
# ara_z_rot. Confirm the axis order and sign of the shifts.
shifted_y = spot_target.ara_y_rot - shifts[0] / 1000
shifted_z = spot_target.ara_z_rot - shifts[1] / 1000

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# Top row: reference spots, with the reference cell marked in black
for i in range(3):
    axes[0, i].plot(cell_info.ara_y_rot, cell_info.ara_z_rot, "o", color="k")
    axes[0, i].scatter(spot_ref.ara_y_rot, spot_ref.ara_z_rot, **ref_kw)
axes[1, 2].scatter(spot_ref.ara_y_rot, spot_ref.ara_z_rot, **ref_kw)
# Bottom left: target spots before shifting
axes[1, 0].scatter(spot_target.ara_y_rot, spot_target.ara_z_rot, **target_kw)
# Shifted target spots, alone (bottom middle) and over the reference spots
for ax in (axes[1, 1], axes[1, 2], axes[0, 2]):
    ax.scatter(shifted_y, shifted_z, **target_kw)
# Zoom every panel on a window of +/- w around the reference cell
w = 0.2
for x in axes.flatten():
    x.set_aspect("equal")
    x.set_xlim(cell_info.ara_y_rot - w, cell_info.ara_y_rot + w)
    x.set_ylim(cell_info.ara_z_rot - w, cell_info.ara_z_rot + w)
    x.set_xticks([])
    x.set_yticks([])
fig.tight_layout()

In [None]:
(
    shifts,
    maxcorr,
    n_bcs,
    phase_corrs,
    spot_images,
    best_barcodes,
) = issa.registration.register_serial_sections.register_local_spots(
    center_point=(cell_info.ara_y_rot, cell_info.ara_z_rot),
    spot_df=rab_spot_df,
    ref_slice=ref_slice,
    target_slice=slice_df.slice,
    window_size=window_size,
    min_spots=min_spots,
    max_barcode_number=max_barcode_number,
    gaussian_width=gaussian_width,
    verbose=True,
    debug=True,
)

In [None]:
# One row per entry of `spot_images`: its two images and the matching phase correlation.
# squeeze=False keeps `axes` 2-D even when there is a single row.
fig, axes = plt.subplots(len(spot_images), 3, figsize=(5, 50), squeeze=False)
for row, data in enumerate(spot_images):
    axes[row, 0].imshow(data[0])
    axes[row, 1].imshow(data[1])
    axes[row, 2].imshow(phase_corrs[row])
for x in axes.flatten():
    x.set_aspect("equal")
    x.set_xticks([])
    x.set_yticks([])
fig.tight_layout()

In [None]:
# Spots of the main barcode of the current cell, in the reference and target slices
spots_ref = rab_spot_df.query(
    "slice == @ref_slice and corrected_bases == @cell_info.main_barcode"
)
spots_target = rab_spot_df.query(
    "slice == @slice_df.slice and corrected_bases == @cell_info.main_barcode"
)
# Only the last expression of a cell is displayed: show both shapes together
spots_ref.shape, spots_target.shape

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(5, 5))
ax[0, 0].imshow(spot_images[0, 0])
ax[0, 1].imshow(spot_images[0, 1])
ax[1, 0].imshow(phase_corrs[0])

In [None]:
res = pd.DataFrame(
    columns=["shift_y", "shift_z", "maxcorr", "n_barcodes"],
    data=np.vstack([np.hstack(a) for a in assignment_by_bc]),
)

import seaborn as sns

fig = plt.figure(figsize=(8, 2))
for iax, col in enumerate(["shift_y", "shift_z", "n_barcodes"]):
    ax = fig.add_subplot(1, 3, iax + 1)
    sns.histplot(res, x=col, ax=ax, bins=20)

fig.tight_layout()

In [None]:
sns.pairplot(res)

In [None]:
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

fig = plt.figure(figsize=(12, 5))

toplot = np.vstack([np.hstack(a) for a in assignment_by_bc])
labels = ["Shift Y (um)", "Shift Z (um)", "Max Corr", "N barcodes used"]
cmaps = ["RdBu", "RdBu", "viridis", "viridis"]
xycoords = (cells_in_ref[["y", "x"]].values * 0.2) / 1000
ara_coords = cells_in_ref[["ara_z_rot", "ara_y_rot"]].values
# px_size = issp.io.get_pixel_size(data_path)
for iax, label in enumerate(labels):
    for ic, data in enumerate([ara_coords, xycoords]):
        ax = fig.add_subplot(2, 4, 4 * ic + iax + 1, facecolor="k")
        ax_divider = make_axes_locatable(ax)
        # Add an Axes to the right of the main Axes.
        cax = ax_divider.append_axes("right", size="7%", pad="2%")
        vmin, vmax = np.nanpercentile(toplot[:, iax], [1, 99])
        if iax < 2:
            vminmax = max(abs(vmin), abs(vmax))
            vmin, vmax = -vminmax, vminmax
        sc = ax.scatter(
            data[:, 0],
            data[:, 1],
            c=toplot[:, iax],
            s=1,
            cmap=cmaps[iax],
            vmin=vmin,
            vmax=vmax,
            alpha=1,
        )
        ax.set_xlabel("Raw Y (mm)" if ic else "ARA Z (mm)")
        ax.set_ylabel("Raw X (mm)" if ic else "ARA Y (mm)")
        cb = plt.colorbar(sc, cax=cax)
        cb.set_label(label)
        ax.set_aspect("equal")
        if not ic:
            ax.set_ylim(5, 1.8)
            ax.set_xlim(6, 10)
        else:
            ax.set_ylim(1, 3.5)
            ax.set_xlim(0.2, 4)
plt.tight_layout()

In [None]:
res

# End register slices

In [None]:
# find starter
if False:
    starters_positions = issa.io.get_starter_cells(project, mouse)
    rabies_cell_properties = issa.segment.match_starter_to_barcodes(
        project,
        mouse,
        rabies_cell_properties,
        rab_spot_df,
        mcherry_cells=starters_positions,
        redo=True,
    )
    rabies_cell_properties.head()

# Distance between cells

Look at the distance between rabies cells that share a barcode within a single ROI.

In [None]:
# Iterate on barcodes and on chamber/roi, use a KDTree to find the closest pair of
# cells with same barcode and build a full list
from scipy.spatial import KDTree

# One entry per (cell, barcode) pair
ncellbarcode = rabies_cell_properties.all_barcodes.apply(len).sum()
within_slice_distance_px = np.zeros(ncellbarcode)
# For each cell, one distance per barcode, in the same order as `all_barcodes`
rabies_cell_properties["closest_distance_px"] = [
    np.zeros(len(x)) for x in rabies_cell_properties.all_barcodes
]
i = 0
# NOTE: `roi` and `chamber` keep the values of the last group after this loop
for (roi, chamber), roi_df in rabies_cell_properties.groupby(["roi", "chamber"]):
    all_barcodes = set()
    for bc in roi_df["all_barcodes"].values:
        all_barcodes.update(bc)
    for bc in all_barcodes:
        # Cells of this ROI that contain the barcode, and its position in their list
        df = roi_df[roi_df["all_barcodes"].apply(lambda x: bc in x)]
        bc_index = df.all_barcodes.apply(lambda x: x.index(bc))
        if len(df) > 1:
            coords = df[["x", "y"]].values
            tree = KDTree(coords)
            # k=2: column 0 is each cell itself, column 1 its closest other cell
            dist, idx = tree.query(coords, k=2)
            within_slice_distance_px[i : i + len(df)] = dist[:, 1]
            for j in range(len(df)):
                # `bc_index` is indexed by cell label: use `.iloc` for the j-th cell
                rabies_cell_properties.loc[df.index[j], "closest_distance_px"][
                    bc_index.iloc[j]
                ] = dist[j, 1]
        else:
            # A single cell has this barcode in this ROI: no neighbour
            within_slice_distance_px[i] = np.nan
            rabies_cell_properties.loc[df.index[0], "closest_distance_px"][
                bc_index.iloc[0]
            ] = np.nan
        i += len(df)

In [None]:
px_size = issp.io.get_pixel_size(f"{project}/{mouse}/{chamber}")
within_slice_distance_um = within_slice_distance_px * px_size

fig, axes = plt.subplots(2, 1, figsize=(6, 3))
ax = axes[0]
ax.set_title("Within slice distance between cells with same barcode")
twin_ax = ax.twinx()
ax.hist(within_slice_distance_um, bins=np.arange(0, 501, 1))
n_far = np.sum(within_slice_distance_um > 500)
twin_ax.scatter(501, n_far, color="red", label=">500um")
twin_ax.set_ylim(0, n_far * 1.3)
ax.set_xlim(0, 501)
twin_ax.legend(loc="upper right")
ax.set_ylabel("Number of cells")
ax = axes[1]
ax.hist(within_slice_distance_um, bins=np.arange(0, 101, 1))
ax.set_ylabel("Number of cells")
ax.set_xlabel("Distance (um)")
fig.tight_layout()
plt.show()

In [None]:
closeby = rabies_cell_properties[
    rabies_cell_properties.closest_distance_px.map(np.nanmin) < (30 / px_size)
]
print(f"Number of cells with a close neighbor: {len(closeby)}")

# Get ARA coordinates of spots

In [None]:
# Rotate the ARA coordinates of the cells and of the spots into the slicing plane
from iss_analysis.registration import ara_registration

# The rotation is computed once, from the spot table, and the same transform is
# applied to both tables
transform = ara_registration.get_ara_to_slice_rotation_matrix(rab_spot_df)
rabies_cell_properties, rab_spot_df = (
    ara_registration.rotate_ara_coordinate_to_slice(table, transform=transform)
    for table in (rabies_cell_properties, rab_spot_df)
)

In [None]:
# Compare the raw and rotated ARA coordinates of the cells, one colour per slice.
# Columns: raw coordinate histogram, rotated coordinate histogram, and the
# difference (rotated - raw) as a function of the raw coordinate.

# `cm` is not needed here anymore but later cells of this notebook still use it
from matplotlib import cm

fig, axes = plt.subplots(3, 3, figsize=(10, 6))
step = 0.01
# One set of bin edges per coordinate (x, y, z)
bins = [
    np.arange(5, 9.2, step),
    np.arange(step, 8, step * 10),
    np.arange(4, 12, step * 10),
]
# Slice identifier "<chamber>_<roi:02d>", added to both tables (later cells rely on
# this column)
rab_spot_df["slice"] = (
    rab_spot_df["chamber"] + "_" + rab_spot_df["roi"].map(lambda x: f"{x:02d}")
)
rabies_cell_properties["slice"] = (
    rabies_cell_properties["chamber"]
    + "_"
    + rabies_cell_properties["roi"].map(lambda x: f"{x:02d}")
)
slices = sorted(rab_spot_df["slice"].unique())
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` accepts
# the same (name, lut) arguments
colors = plt.get_cmap("viridis", len(slices))
for islice, slice_name in enumerate(slices):
    cell_prop = rabies_cell_properties[rabies_cell_properties["slice"] == slice_name]
    color = colors(islice)
    for iax, coord in enumerate("xyz"):
        raw = cell_prop[f"ara_{coord}"]
        rotated = cell_prop[f"ara_{coord}_rot"]
        # Label with the slice actually plotted (not with leftover `chamber`/`roi`
        # globals from a previous cell)
        hist_kwargs = dict(
            alpha=0.5,
            label=slice_name,
            histtype="step",
            bins=bins[iax],
            color=color,
        )
        axes[iax, 0].hist(raw, **hist_kwargs)
        axes[iax, 1].hist(rotated, **hist_kwargs)
        axes[iax, 2].scatter(
            raw,
            rotated - raw,
            label=slice_name,
            s=5,
            alpha=0.5,
            color=color,
        )

# Axis labels do not depend on the slice: set them once
for iax, coord in enumerate("xyz"):
    for icol in range(3):
        axes[iax, icol].set_xlabel(f"ARA {coord}")
    for icol in range(2):
        axes[iax, icol].set_ylabel("Number of cells")
    axes[iax, 2].set_ylabel(f"Rotated {coord} - ARA {coord}")

axes[0, 0].set_title("ARA coordinates")
axes[0, 1].set_title("ARA coordinates rotated")
axes[0, 2].set_title("Difference")
fig.tight_layout()

In [None]:
# For every (cell, barcode) pair, distance in rotated ARA space to the closest other
# cell carrying the same barcode: once across all slices, once within the same slice
from tqdm import tqdm


def _nearest_other_distance(coords):
    """Return, for each row of `coords`, the distance to its 2nd nearest tree point.

    Non-finite coordinates are replaced by 0 before building the tree. The tree is
    queried with its own points and k=2, and the second column is returned.
    NOTE(review): what comes back when `coords` has a single row depends on the
    KDTree implementation imported earlier in the notebook -- confirm.
    """
    coords = np.where(np.isfinite(coords), coords, 0)
    distances, _neighbour_index = KDTree(coords).query(coords, k=2)
    return distances[:, 1]


rot_cols = ["ara_x_rot", "ara_y_rot", "ara_z_rot"]
all_barcodes = set().union(*rabies_cell_properties["all_barcodes"].values)
ncellbarcode = rabies_cell_properties.all_barcodes.map(len).sum()
ara_3d_distance = np.zeros(ncellbarcode)
ara_within_distance = np.zeros(ncellbarcode)
# `i` is the write position in both output arrays; it advances slice by slice
i = 0
for barcode in tqdm(all_barcodes):
    has_barcode = rabies_cell_properties["all_barcodes"].map(
        lambda cell_barcodes: barcode in cell_barcodes
    )
    df = rabies_cell_properties[has_barcode]
    # all slices together
    ara_3d_distance[i : i + len(df)] = _nearest_other_distance(df[rot_cols].values)
    # then each slice on its own
    for (chamber, roi), cell_prop in df.groupby(["chamber", "roi"]):
        stop = i + len(cell_prop)
        ara_within_distance[i:stop] = _nearest_other_distance(
            cell_prop[rot_cols].values
        )
        i = stop

In [None]:
# Overlay the 3D nearest-neighbour distances (converted with * 1000 and clipped to
# 500, so everything further away piles up in the last bin)
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
for distances, label in (
    (ara_3d_distance, "All slices"),
    (ara_within_distance, "Within slice"),
):
    ax.hist(
        np.clip(distances * 1000, 0, 500),
        bins=np.arange(0, 501, 5),
        label=label,
        alpha=0.5,
        cumulative=False,
    )
ax.set_xlabel("Distance between cells with same barcode (um)")
ax.set_ylabel("Number of cells")
ax.legend()
ax.set_ylim(0, 500)
ax.set_xlim(0, 500)

In [None]:
# Now look only at neighboring slices and project the 3d distance to the slice
from iss_analysis.registration import utils

# Two cells with the same barcode closer than this in the slice plane are recorded
# as "close". Rotated ARA coordinates are multiplied by 1000 to get um elsewhere in
# this notebook, so 30 / 1000 corresponds to 30 um.
max_distance = 30 / 1000
plane_cols = ["ara_y_rot", "ara_z_rot"]

# Both dictionaries map a cell (index label of `rabies_cell_properties`) to a list
# of [neighbour label, distance] pairs
close_within_slice = dict()
close_around_slice = dict()
ara_2d_distance = []
ara_within_2d_distance = []
# The surrounding slices only depend on (chamber, roi): compute them once per slice
# instead of once per barcode and slice
surrounding_slices_cache = dict()
for barcode in tqdm(all_barcodes):
    cells_df = rabies_cell_properties[
        rabies_cell_properties["all_barcodes"].apply(lambda x: barcode in x)
    ]
    for (chamber, roi), cell_prop in cells_df.groupby(["chamber", "roi"]):
        # within first
        ara_coords = cell_prop[plane_cols].values
        ara_coords[~np.isfinite(ara_coords)] = 0
        tree = KDTree(ara_coords)
        dist, idx = tree.query(ara_coords, k=2)
        ara_within_2d_distance.extend(dist[:, 1])
        closeby = np.where(dist[:, 1] < max_distance)[0]
        for c in closeby:
            # The tree is queried with its own points: column 0 is normally the
            # cell itself, column 1 its nearest neighbour. Store the neighbour only.
            close_within_slice.setdefault(cell_prop.index[c], []).append(
                [cell_prop.index[idx[c, 1]], dist[c, 1]]
            )

        # now to surrounding slices
        if (chamber, roi) not in surrounding_slices_cache:
            surrounding_rois = utils.get_surrounding_slices(
                chamber, roi, project, mouse, include_ref=False
            )
            surrounding_slices_cache[(chamber, roi)] = (
                surrounding_rois.chamber
                + "_"
                + surrounding_rois.roi.map(lambda x: f"{x:02d}")
            )
        surrounding_slices = surrounding_slices_cache[(chamber, roi)]
        surrounding_cells = cells_df[cells_df["slice"].isin(surrounding_slices)]
        if len(surrounding_cells) == 0:
            # No cell with this barcode in the neighbouring slices: nothing to build
            # a tree from. Record an infinite distance so that these cells still
            # count in `ara_2d_distance` (they end up in the last, clipped bin).
            ara_2d_distance.extend(np.full(len(ara_coords), np.inf))
            continue
        surrounding_coords = surrounding_cells[plane_cols].values
        surrounding_coords[~np.isfinite(surrounding_coords)] = 0
        tree = KDTree(surrounding_coords)
        dist, idx = tree.query(ara_coords, k=1)
        ara_2d_distance.extend(dist)
        closeby = np.where(dist < max_distance)[0]
        for c in closeby:
            close_around_slice.setdefault(cell_prop.index[c], []).append(
                [surrounding_cells.index[idx[c]], dist[c]]
            )

In [None]:
# In-plane distance to the closest same-barcode cell, in neighbouring slices and
# within the slice (converted with * 1000 and clipped to 200)
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
params = dict(
    bins=np.arange(0, 201, 3), alpha=1, cumulative=False, lw=2, histtype="step"
)
for label, distances in (
    ("Neighbouring slices", ara_2d_distance),
    ("Within slice", ara_within_2d_distance),
):
    ax.hist(np.clip(np.array(distances) * 1000, 0, 200), label=label, **params)
ax.set_xlabel("Distance between cells with same barcode (um)")
ax.set_ylabel("Number of cells")
ax.legend()
ax.set_ylim(0, 150)

In [None]:
# Number of cells that have at least one close same-barcode neighbour
for where, close_cells in (
    ("within", close_within_slice),
    ("around", close_around_slice),
):
    print(f"{len(close_cells)} cells with close neighbors {where} slice")

In [None]:
# List the cross-slice matches of the cells whose label contains "chamber_09_3"
for c, matches in close_around_slice.items():
    if "chamber_09_3" not in c:
        continue
    print(c, matches)

In [None]:
# First cell (in insertion order) that has a close neighbour in a surrounding slice
next(iter(close_around_slice))

# Plot example close cells

Find an example of cells that are close to each other and plot them.
For the first cell, and likewise for the second, we want the rotated ARA coordinates of the spots and of the cell mask,
as well as the raw slice for both cells, in a window around the cell.

In [None]:
# Hand-picked cell that has a close neighbour in a surrounding slice
example_cell = "chamber_09_3_14937"

# Each entry is a [neighbour label, distance] pair; only the first one is used
first_match = close_around_slice[example_cell][0]
neighbour, dst = first_match
print(f"Cell {example_cell} is {dst * 1000:.2f} um away from {neighbour}")

# Properties of the pair: row 0 is the example cell, row 1 its neighbour
cells = rabies_cell_properties.loc[[example_cell, neighbour]]
cells

In [None]:
# Plot the example pair.
# Top-left panel: both cells, their spots and the surrounding cells/spots of their
# slices in rotated ARA coordinates. Bottom row: the same surroundings in the x/y
# coordinates of each slice (left: first cell, right: second cell).
from iss_analysis import vis
from functools import partial

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
ax_ara = axes[0, 0]
# window size in um around the source to plot background
window = 200

source = cells.iloc[0]
# [-w, +w] offsets, in the unit of the rotated ARA coordinates (um / 1000)
window = np.array([-1, 1]) * window / 1000
center = source[["ara_y_rot", "ara_z_rot"]].values

# Crop a table to the window around the source cell
_get_spot = partial(
    vis.get_spot_part,
    xlim=center[0] + window,
    ylim=center[1] + window,
    xcol="ara_y_rot",
    ycol="ara_z_rot",
)

labels = ["Source", "Neighbour"]
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` returns
# the same colormap
colors = plt.get_cmap("Set2").colors[3:]
bg_colors = ["lightblue", "lightgrey"]
for ic, (cname, cell) in enumerate(cells.iterrows()):
    c, r = cell.chamber, cell.roi
    # cells of the same slice, cropped to the window
    bg_cells = _get_spot(
        rabies_cell_properties[
            (rabies_cell_properties.chamber == c) & (rabies_cell_properties.roi == r)
        ]
    )
    ax_ara.scatter(
        bg_cells["ara_y_rot"] * 1000,
        bg_cells["ara_z_rot"] * 1000,
        color=bg_colors[ic],
        s=100,
        label=f"{c} roi {r}",
        zorder=0,
        alpha=0.5,
    )

    # spots of the same slice, cropped to the window
    slice_spots = rab_spot_df[(rab_spot_df.chamber == c) & (rab_spot_df.roi == r)]
    bg_spots = _get_spot(slice_spots)
    ax_ara.scatter(
        bg_spots["ara_y_rot"] * 1000,
        bg_spots["ara_z_rot"] * 1000,
        color=bg_colors[ic],
        s=10,
        zorder=0,
        alpha=0.5,
    )

    ax_ara.scatter(
        cell["ara_y_rot"] * 1000,
        cell["ara_z_rot"] * 1000,
        color=colors[ic],
        label=f"{cname}",
        s=100,
    )
    # find spots of this cell. The search is restricted to the cell's own slice:
    # presumably mask ids are only unique within a chamber/roi, in which case
    # matching on `cell_mask` alone would also pick up spots of other slices
    # (TODO confirm; the restriction is harmless if ids are globally unique).
    spots = slice_spots[slice_spots["cell_mask"] == cell.cell_id]
    ax_ara.scatter(
        spots["ara_y_rot"] * 1000,
        spots["ara_z_rot"] * 1000,
        color=colors[ic],
        ec="k",
        label=f"{labels[ic]} spots",
        s=10,
    )

    # same surroundings in slice coordinates; colours cycle through tab20
    ax = axes[1, ic]
    ax.scatter(
        bg_cells.x,
        bg_cells.y,
        c=bg_cells.cell_id % 20,
        cmap="tab20",
        vmin=0,
        vmax=19,
        s=100,
        alpha=0.5,
    )
    ax.scatter(
        bg_spots.x,
        bg_spots.y,
        c=bg_spots.index % 20,
        s=10,
        alpha=0.5,
        cmap="tab20",
        vmin=0,
        vmax=19,
    )
    ax.scatter(cell["x"], cell["y"], color=colors[ic], s=100, marker="x")

for x in axes.flatten():
    x.axis("off")
ax_ara.legend(loc="upper right", ncol=2)
fig.tight_layout()

In [None]:
# Give every corrected barcode an integer index, in order of first appearance in
# `rab_spot_df` (used below to colour spots by barcode)
barcodes = list(rab_spot_df.corrected_bases.unique())
# Dictionary lookup instead of `barcodes.index(x)`, which rescans the whole list
# for every single spot
barcode_to_index = {barcode: index for index, barcode in enumerate(barcodes)}
rab_spot_df["bc_index"] = rab_spot_df.corrected_bases.map(barcode_to_index)
rab_spot_df.bc_index

In [None]:
# Plot, for the slice of each cell of the example pair, the spots coloured by
# barcode (rotated ARA coordinates), one panel per slice
from iss_analysis import vis
from functools import partial

fig, axes = plt.subplots(1, 2, figsize=(20, 10))

# window size in um around the source to plot background
window = 300

# [-w, +w] offsets, in the unit of the rotated ARA coordinates (um / 1000)
window = np.array([-1, 1]) * window / 1000
# NOTE(review): the window is centred 0.4 (i.e. 400 um with the * 1000 conversion
# used for plotting) away from the second cell along both axes -- presumably to
# look at a different region on purpose; confirm.
center = cells.iloc[1][["ara_y_rot", "ara_z_rot"]].values + 0.4

# Crop a table to the window
_get_spot = partial(
    vis.get_spot_part,
    xlim=center[0] + window,
    ylim=center[1] + window,
    xcol="ara_y_rot",
    ycol="ara_z_rot",
)

bg_colors = ["lightblue", "lightgrey"]
for ic, (cname, cell) in enumerate(cells.iterrows()):
    c, r = cell.chamber, cell.roi
    bg_cells = rabies_cell_properties[
        (rabies_cell_properties.chamber == c) & (rabies_cell_properties.roi == r)
    ]
    bg_cells = _get_spot(bg_cells)
    axes[ic].scatter(
        bg_cells["ara_y_rot"] * 1000,
        bg_cells["ara_z_rot"] * 1000,
        color=bg_colors[ic],
        s=100,
        label=f"{c} roi {r}",
        zorder=0,
        alpha=0.1,
    )

    bg_spots = rab_spot_df[(rab_spot_df.chamber == c) & (rab_spot_df.roi == r)]
    bg_spots = _get_spot(bg_spots)
    # colours cycle through tab20, so different barcodes can share a colour
    axes[ic].scatter(
        bg_spots["ara_y_rot"] * 1000,
        bg_spots["ara_z_rot"] * 1000,
        c=bg_spots.bc_index % 20,
        cmap="tab20",
        vmin=0,
        vmax=19,
        s=20,
        zorder=0,
        alpha=1,
    )


for x in axes.flatten():
    x.axis("on")
    x.grid(True)
    # Legend on the panels of this figure (`ax_ara` belongs to the previous figure)
    x.legend(loc="upper right")
fig.tight_layout()

# Register ara data

The aim is to register slices together, but that is hard to do directly. So we first
register to the ARA, rotate the ARA to align it with the slicing plane, and only then
register the slices together. `rab_spot_df` already has the `ara_x/y/z_rot` information.

We will register locally, around one cell of interest. First we subselect the data
in a window around that cell, then we register the two slices together.
For that we get the barcodes present in both slices, keep the
ones that have enough spots, and make a blurred version of the spot mask to register the two.


In [None]:
from iss_analysis.registration import ara_registration, utils

# add ARA coordinates to rabies cell properties
transform = ara_registration.get_ara_to_slice_rotation_matrix(spot_df=rab_spot_df)
rabies_cell_properties = ara_registration.rotate_ara_coordinate_to_slice(
    rabies_cell_properties, transform=transform
)

# select some random cell (arbitrary fixed row)
cell_of_interest = rabies_cell_properties.iloc[200]
ref_chamber = cell_of_interest.chamber
ref_roi = cell_of_interest.roi

# Slices around the reference one, reference included. Only the first two rows are
# kept: we will look only at the previous slice for now
# NOTE(review): this presumably assumes that the first two rows are the reference
# slice and the previous one -- confirm the row order of get_surrounding_slices.
surrounding_rois = utils.get_surrounding_slices(
    ref_chamber, ref_roi, project, mouse, include_ref=True
).iloc[:2]

In [None]:
from iss_analysis.registration import register_serial_sections


def _slice_name(roi_row):
    """Build the "<chamber>_<roi:02d>" slice name of one row of `surrounding_rois`."""
    return roi_row.chamber + "_" + f"{roi_row.roi:02d}"


# First row is used as reference slice, second row as target slice
ref_slice = _slice_name(surrounding_rois.iloc[0])
target_slice = _slice_name(surrounding_rois.iloc[1])

# Local registration of the two slices around the cell of interest
registration_params = dict(
    window_size=250,
    min_spots=5,
    max_barcode_number=500,
    gaussian_width=30,
    verbose=False,
)
(
    shifts,
    maxcorr,
    phase_corrs,
    spot_images,
    barcodes,
) = register_serial_sections.register_local_spots(
    spot_df=rab_spot_df,
    ref_slice=ref_slice,
    target_slice=target_slice,
    center_point=cell_of_interest[["ara_y_rot", "ara_z_rot"]].values,
    **registration_params,
)

In [None]:
# Loop on all cells of the reference slice and run the local registration around
# each of them, optionally in parallel
from tqdm import tqdm
from functools import partial
from multiprocessing import Pool

verbose = True
n_workers = 4
in_ref_slice = (rabies_cell_properties.chamber == ref_chamber) & (
    rabies_cell_properties.roi == ref_roi
)
cells_in_ref = rabies_cell_properties[in_ref_slice]
# NOTE(review): these two arrays are allocated but never filled in this cell; the
# registration output goes to `assignment_by_bc` instead.
shifts = np.zeros((len(cells_in_ref), 2))
maxcorrs = np.zeros(len(cells_in_ref))

# Everything but the centre point is fixed.
# NOTE(review): the centre coordinates are passed as the first positional argument
# below, whereas the single-cell call passes them as `center_point=` -- confirm
# that this matches the signature of `register_local_spots`.
reg_one_cell = partial(
    register_serial_sections.register_local_spots,
    spot_df=rab_spot_df,
    ref_slice=ref_slice,
    target_slice=target_slice,
    window_size=250,
    min_spots=5,
    max_barcode_number=500,
    gaussian_width=30,
    verbose=False,
    debug=False,
)
cell_coords = cells_in_ref[["ara_y_rot", "ara_z_rot"]].values
if n_workers == 1:
    # serial version, no process pool
    assignment_by_bc = [reg_one_cell(coords) for coords in cell_coords]
else:
    if verbose:
        print(f"Registering {len(cell_coords)} cells using {n_workers} workers")
    with Pool(n_workers) as pool:
        # the results must be consumed while the pool is still open
        progress = tqdm(pool.imap(reg_one_cell, cell_coords), total=len(cell_coords))
        assignment_by_bc = list(progress)

In [None]:
# Step-by-step version of the local registration: crop the spots of each selected
# ROI to a window around the cell of interest and list the barcodes found in both
window = 2000

cell_pos = cell_of_interest[["ara_y_rot", "ara_z_rot"]].values.astype(float)
# one [low, high] row per coordinate (y, then z)
half_width = window / 1000
win_around = np.stack([cell_pos - half_width, cell_pos + half_width], axis=1)

print(f"Cropping around {np.round(cell_pos,2)} with window of {window}um")

barcodes_by_roi = []
spots_by_roi = []
for _roi_label, roi_info in surrounding_rois.iterrows():
    in_roi = (rab_spot_df.chamber == roi_info.chamber) & (
        rab_spot_df.roi == roi_info.roi
    )
    spots = rab_spot_df[in_roi]
    for (low, high), coord in zip(win_around, "yz"):
        # bounds are inclusive on both sides
        spots = spots[spots[f"ara_{coord}_rot"].between(low, high)]
    spots_by_roi.append(spots)
    barcodes_by_roi.append(set(spots.corrected_bases.unique()))

# `spots` is the table of the last ROI of the loop
print(f"Found {len(spots)} spots in the surrounding slice")

barcodes = barcodes_by_roi[0] & barcodes_by_roi[1]
print(
    f"Found {len(barcodes)} barcodes in common (intersection of {len(barcodes_by_roi[0])} and {len(barcodes_by_roi[1])})"
)

In [None]:
# select the barcodes that are present in both slices in large numbers
all_spots = pd.concat(spots_by_roi)
spots = all_spots[all_spots["corrected_bases"].isin(barcodes)]
# number of spots per slice (rows) and barcode (columns), 0 where a pair is absent
spot_counts = spots.groupby(["slice", "corrected_bases"]).size()
bc_per_roi = spot_counts.unstack().fillna(0)
# rank barcodes by their count in the slice where they are rarest, keep those > 5
min_count = bc_per_roi.min(axis=0).sort_values(ascending=False)
best_barcodes = min_count[min_count > 5]

spots = spots[spots["corrected_bases"].isin(best_barcodes.index)]
print(
    f"Found {len(spots)} spots in the pair of slices with the selected {len(best_barcodes)} barcodes"
)

In [None]:
# make a spot image for each barcode and each slice
from tqdm import tqdm
from iss_preprocess.segment.spots import make_spot_image

gaussian_width = 30

# Bounding box of the selected spots in the slice plane; the far corner gets an
# extra margin of (1 + 20 * gaussian_width) / 1000
plane_cols = ["ara_y_rot", "ara_z_rot"]
origin = spots[plane_cols].min().values
corner = spots[plane_cols].max().values + (1 + gaussian_width * 20) / 1000
# image size: 1000 pixels per coordinate unit
output_shape = ((corner - origin) * 1000).astype(int)

# NOTE(review): allocated with np.empty, so an entry that is never written below
# would contain arbitrary values.
spot_images = np.empty((len(best_barcodes), 2, *output_shape), dtype="single")
for ibc, bc in tqdm(enumerate(best_barcodes.index), total=len(best_barcodes)):
    spots_per_slice = spots[spots["corrected_bases"] == bc].groupby("slice")
    for islice, (slice_name, slice_df) in enumerate(spots_per_slice):
        # rename to x, y for make_spot_image, relative to the origin and in pixels
        sp = pd.DataFrame(
            (slice_df[plane_cols].values - origin) * 1000, columns=["x", "y"]
        )
        spot_images[ibc, islice] = make_spot_image(
            sp, gaussian_width=gaussian_width, dtype="single", output_shape=output_shape
        )

In [None]:
# Show the spot images of the first `max2plot` barcodes with their spots on top.
# Each barcode uses two neighbouring panels (one per slice); barcodes 0-9 fill the
# two left columns, barcodes 10-19 the two right columns.
sz = 2
max2plot = 20
fig, axes = plt.subplots(10, 4, figsize=(sz * 4, sz * 10))
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` returns
# the same colormap
colors = plt.get_cmap("tab20").colors
# image extent (left, right, bottom, top), same * 1000 conversion as the spots
extent = [
    origin[0] * 1000,
    corner[0] * 1000,
    corner[1] * 1000,
    origin[1] * 1000,
]
for ibc, bc in enumerate(best_barcodes.index[:max2plot]):
    bc_df = spots[spots["corrected_bases"] == bc]
    for islice, (slice_name, slice_df) in enumerate(bc_df.groupby("slice")):
        ax = axes[ibc % 10, islice + ibc // 10 * 2]
        ax.imshow(spot_images[ibc, islice], cmap="Greys", extent=extent)
        ax.scatter(
            slice_df.ara_y_rot * 1000,
            slice_df.ara_z_rot * 1000,
            s=5,
            alpha=0.3,
            color=colors[ibc % 20],
        )
for x in axes.flatten():
    x.axis("equal")
    x.set_xticks([])
    x.set_yticks([])
fig.tight_layout()

In [None]:
# do phase correlation for each pair of spot images (one pair per barcode)

from image_tools.registration.phase_correlation import phase_correlation

n_barcodes = len(best_barcodes)
shifts = np.zeros((n_barcodes, 2))
max_corrs = np.zeros(n_barcodes)
phase_corrs = np.zeros((n_barcodes, *output_shape))
for ibc, image_pair in enumerate(tqdm(spot_images)):
    # NaNs are replaced by 0 before correlating; image 0 is the reference
    ref, target = (np.nan_to_num(image) for image in image_pair)
    shift, max_corr, phase_corr, _ = phase_correlation(ref, target, whiten=False)
    shifts[ibc] = shift
    max_corrs[ibc] = max_corr
    phase_corrs[ibc] = phase_corr

In [None]:
# plot all the phase correlations (at most `max2plot`, and at most one per panel)
fig, axes = plt.subplots(4, 5, figsize=(10, 8))
# Hide the frame of every panel, including those left empty when there are fewer
# barcodes than panels
for ax in axes.flat:
    ax.axis("off")
# zip stops at the shortest input: no IndexError with fewer than 20 barcodes
for ax, phase_corr in zip(axes.flat, phase_corrs[:max2plot]):
    ax.imshow(phase_corr, cmap="viridis")
fig.tight_layout()

In [None]:
sum_corr = phase_corrs.sum(axis=0)
# find the max and the corresponding shift
maxcorr = np.max(sum_corr)
argmax = np.array(np.unravel_index(np.argmax(sum_corr), sum_corr.shape))
# shift is relative to center of image
shifts = argmax - np.array(sum_corr.shape) // 2
print(f"Max correlation: {maxcorr} at shift {shifts}")

In [None]:
fig = plt.figure(figsize=(10, 5))

ax = fig.add_subplot(2, 2, 1)
ax.imshow(
    phase_corrs.sum(axis=0).T,
    cmap="viridis",
    extent=(
        -output_shape[1] // 2,
        output_shape[1] // 2,
        -output_shape[0] // 2,
        output_shape[0] // 2,
    ),
    origin="lower",
)
ax.scatter(*shifts, color="red", marker="x")
ax.set_title("Sum of phase correlations")


ax = fig.add_subplot(2, 2, 3)
ax.scatter(shifts[0], shifts[1], s=500, color="red", marker="x")
ax.scatter(shifts[:, 0], shifts[:, 1], s=max_corrs / max_corrs.max() * 10, alpha=0.5)
ax.set_aspect("equal")
ax.set_xlabel("Y shift (px)")
ax.set_ylabel("Z shift (px)")
ax.set_title(r"Shifts for each pair - dot size $\alpha$ corr. coeff.")

for i in range(2):
    sh = shifts[:, i]
    ax = plt.subplot(2, 2, 2 + i * 2)
    ax.hist(sh, bins=np.arange(sh.min(), sh.max(), 10), alpha=0.5)
    ax.set_xlabel(f"{['Y', 'Z'][i]} shifts")
    ax.set_ylabel("Number of barcodes")
    ax.axvline(shifts[i], color="red", label="Max correlation")
fig.tight_layout()

In [None]:
# for each spot_image pair, find the argmax of the ref
argmaxes = np.zeros((len(best_barcodes), 2))
for ipair, img_pair in enumerate(spot_images):
    argmaxes[ipair] = np.unravel_index(np.argmax(img_pair[0]), img_pair[0].shape)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for i in range(2):
    sc = axes[i].scatter(
        argmaxes[:, 0],
        argmaxes[:, 1],
        c=shifts[:, i],
        cmap="RdBu",
        vmin=shifts[i] - 100,
        vmax=shifts[i] + 100,
    )
    axes[i].set_aspect("equal")
    cb = fig.colorbar(sc, ax=axes[i])
    cb.ax.axhline(shifts[i], color="red", label="Max correlation")
    axes[i].set_xlabel("Y max of spot image (um)")
    axes[i].set_ylabel("Z max of spot image (um)")
    axes[i].set_title(f'Shift {["Y", "Z"][i]}')

# Try to register with anchor then rabies

- First step: use the anchor to register the slices together linearly.
- Second step: use the rabies data for the non-linear part.


In [None]:
# Pair of ROIs of the same chamber to register: `moving_roi` onto `fixed_roi`
chamber = "chamber_07"
fixed_roi, moving_roi = 6, 7

In [None]:
# Get the stitched stack of the reference acquisition for both ROIs
# (note that `data_path` now points to the chamber, not to the mouse folder)
data_path = f"{project}/{mouse}/{chamber}"
ops = issp.io.load_ops(data_path)
prefix = ops["reference_prefix"]
ref_ch = ops["reg2ref_reference_channels"]

# Both ROIs are stitched with exactly the same settings
stitch_params = dict(
    data_path=data_path,
    prefix=prefix,
    channels=ref_ch,
    projection="median",
    correct_illumination=True,
)
full_fixed_image, full_moving_image = (
    issp.pipeline.stitch_registered(roi=roi_id, **stitch_params)
    for roi_id in (fixed_roi, moving_roi)
)

print(full_fixed_image.shape, full_moving_image.shape)

In [None]:
# Downsample the images by a lot

import cv2

downsample = 20


def _downsample_first_channel(image, factor):
    """Resize channel 0 of `image` to 1/`factor` of its size (integer division)."""
    n_rows, n_cols = image.shape[0], image.shape[1]
    # cv2.resize takes the target size as (width, height)
    return cv2.resize(image[..., 0], (n_cols // factor, n_rows // factor))


fixed_image = _downsample_first_channel(full_fixed_image, downsample)
moving_image = _downsample_first_channel(full_moving_image, downsample)
print(fixed_image.shape, moving_image.shape)

In [None]:
# Downsampled fixed (left) and moving (right) images, same colour scale
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, image in zip(axes, (fixed_image, moving_image)):
    _ = ax.imshow(image, cmap="inferno", vmax=400)

In [None]:
# Pad the images to the same size
max_size = np.max([fixed_image.shape, moving_image.shape], axis=0)


def _pad_to_shape(image, target_shape):
    """Pad a 2D `image` at the end of each axis (np.pad default mode) up to `target_shape`."""
    extra_rows, extra_cols = target_shape - np.array(image.shape)
    return np.pad(image, [(0, extra_rows), (0, extra_cols)])


fixed_image = _pad_to_shape(fixed_image, max_size)
moving_image = _pad_to_shape(moving_image, max_size)

# Overlay before registration: fixed image in red, moving image in green
rgb = issp.vis.to_rgb(
    np.dstack([fixed_image, moving_image]),
    colors=[(1, 0, 0), (0, 1, 0)],
    vmax=[500, 500],
)
_ = plt.imshow(rgb)

In [None]:
# Affine registration of the downsampled moving image onto the fixed one
from dipy.align.transforms import AffineTransform2D
from dipy.align.imaffine import AffineRegistration

# not used by the affine step of this cell
sigma_diff = 10.0
radius = 10

# intensities clipped to [10, 300], used for the overlay
fixed_clipped, moving_clipped = (
    np.clip(image, 10, 300) for image in (fixed_image, moving_image)
)

# NOTE(review): the optimisation runs on the unclipped images while the overlay
# uses the clipped ones -- confirm this is intended.
# `transform` replaces the ARA rotation stored under the same name earlier.
transform = AffineTransform2D()
affreg = AffineRegistration(sigmas=[10, 4, 2])
affine = affreg.optimize(fixed_image, moving_image, transform, params0=None)

moving_affine_transformed = affine.transform(moving_clipped)

# Overlay after registration: fixed image in red, transformed moving image in green
stereo = issp.vis.to_rgb(
    np.dstack([fixed_clipped, moving_affine_transformed]),
    colors=[(1, 0, 0), (0, 1, 0)],
)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(stereo)
ax.set_axis_off()
fig.tight_layout()

In [None]:
# Non-linear (symmetric diffeomorphic) registration, starting from the affine
# alignment, for a 2 x 2 grid of CCMetric parameters
from dipy.align.imwarp import SymmetricDiffeomorphicRegistration
from dipy.align.metrics import CCMetric

# the next three values are not used as such: `clip_th` is never read and the
# other two are overwritten by the loop
clip_th = [20, 100]
sigma_diff = 5.0
radius = 5

fixed_clipped = np.clip(fixed_image, 10, 300)
moving_clipped = np.clip(moving_image, 10, 300)

# rows of the figure scan the radius, columns scan sigma_diff
param_values = [5, 20]
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for (ix, iy), ax in np.ndenumerate(axes):
    radius = param_values[ix]
    sigma_diff = param_values[iy]
    level_iters = [100, 50, 10]
    metric = CCMetric(2, sigma_diff, radius)
    sdr = SymmetricDiffeomorphicRegistration(metric, level_iters, ss_sigma_factor=0.5)

    # `mapping` is overwritten at every iteration: the last one is left after the loop
    mapping = sdr.optimize(
        static=fixed_clipped, moving=moving_clipped, prealign=affine.affine
    )

    warped = mapping.transform(moving_clipped)
    # overlay: fixed image in red, warped moving image in green
    rgb = issp.vis.to_rgb(
        np.dstack([fixed_clipped, warped]).astype(float),
        colors=[(1, 0, 0), (0, 1, 0)],
        vmax=[500, 500],
    )
    ax.imshow(rgb)
    ax.set_title(f"Radius {radius}, sigma_diff {sigma_diff}")
    ax.set_xticks([])
    ax.set_yticks([])
fig.tight_layout()

In [None]:
# Apply the mapping left by the last iteration of the parameter grid (radius 20,
# sigma_diff 20) to the first channel of the full-resolution moving image.
# NOTE(review): `mapping` was optimised on the downsampled and padded images;
# confirm that it can be applied as is to the full-resolution image.
warped_full = mapping.transform(full_moving_image[..., 0])

In [None]:
# Display the warped full-resolution image with the default colormap and scaling
plt.imshow(warped_full)