In [None]:
%load_ext autoreload
%autoreload 2

# Find duplicates across ROIs

Look for cells that are detected on multiple consecutive ROIs

In [None]:
import iss_preprocess as issp
import iss_analysis as issa
from iss_analysis.barcodes import barcodes as bar
from iss_analysis.barcodes.diagnostics import (
    plot_gmm_clusters,
    plot_error_along_sequence,
)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Dataset identifiers used by every loader call in this notebook.
project = "becalia_rabies_barseq"
mouse = "BRAC8498.3e"

# Name of the error-corrected barcode dataset passed to `get_barcode_in_cells`.
error_correction_ds_name = "BRAC8498.3e_error_corrected_barcodes_16"
# Mouse-level data path. NOTE: `data_path` is reassigned to per-chamber paths
# by the coordinate-image check loop further down.
data_path = f"{project}/{mouse}"
analysis_folder = issp.io.get_processed_path(data_path) / "analysis"

# Get rabies data

In [None]:
# Load the barcode spots, the per-cell barcode table and the cell properties
# for the whole mouse (`valid_chambers=None`).
# NOTE(review): `redo=True` presumably forces recomputation rather than using
# a cached result, which makes this cell slow on every run -- confirm against
# `issa.segment.get_barcode_in_cells`.
(
    rab_spot_df,
    rab_cells_barcodes,
    rabies_cell_properties,
) = issa.segment.get_barcode_in_cells(
    project,
    mouse,
    error_correction_ds_name,
    valid_chambers=None,
    save_folder=None,
    verbose=True,
    redo=True,
)

In [None]:
# Display the `analysis/mcherry_cells` path under the processed data folder
# (only builds the path, nothing is loaded).
issp.io.get_processed_path(f"{project}/{mouse}") / "analysis" / "mcherry_cells"

In [None]:
# Match starter (mCherry) cells to barcoded rabies cells. `rabies_cell_properties`
# is replaced by the returned table.
# NOTE(review): units of `max_starter_distance` and the exact meaning of
# `min_spot_number` are not visible here -- check `match_starter_to_barcodes`.
rabies_cell_properties, mcherry_cell_properties = (
    issa.segment.match_starter_to_barcodes(
        project,
        mouse,
        rabies_cell_properties,
        rab_spot_df,
        mcherry_cells=None,
        redo=True,
        verbose=False,
        max_starter_distance=5,
        min_spot_number=4,
    )
)

# Register slices

In [None]:
from iss_analysis.registration import ara_registration
from iss_analysis.registration import utils

# Add rotated ARA coordinates to the rabies cell properties. The rotation is
# derived from the spots and then applied to the cell table.
transform = ara_registration.get_ara_to_slice_rotation_matrix(spot_df=rab_spot_df)
rabies_cell_properties = ara_registration.rotate_ara_coordinate_to_slice(
    rabies_cell_properties, transform=transform
)

# Reference slice used by the following cells.
# Focus on the slice that Alex registered
ref_chamber, ref_roi = "chamber_09", 3

In [None]:
# Reload spots and cell properties, this time with `add_ara_properties=True`.
# NOTE: this rebinds `rab_spot_df` and `rabies_cell_properties`, discarding the
# tables produced by the cells above (starter matching and rotated coordinates).
if error_correction_ds_name is None:
    raise ValueError("error_correction_ds_name must be provided")
(
    rab_spot_df,
    _,
    rabies_cell_properties,
) = issa.segment.get_barcode_in_cells(
    project,
    mouse,
    error_correction_ds_name,
    valid_chambers=None,
    save_folder=None,
    verbose=False,
    add_ara_properties=True,
)


In [None]:
# Try to load the ARA coordinate image of ROIs 1-10 in chambers 07-10 and
# report the ones that raise IOError.
# NOTE: this loop overwrites the notebook-level `data_path`, `chamber` and `roi`;
# after it runs `data_path` points at the last chamber, not at the mouse folder.
for chamber in [f"chamber_{i:02d}" for i in range(7, 11)]:
    data_path = f"{project}/{mouse}/{chamber}"
    for roi in range(1, 11):
        try:
            # The returned image is discarded: only loadability is tested.
            issp.pipeline.ara_registration.load_coordinate_image(data_path, roi, full_scale=False)
        except IOError:
            print(f"Empty {chamber} {roi}")

In [None]:
# Disabled alternative: use gene spots in place of rabies barcode spots, with
# the gene name standing in for the barcode sequence. Flip to True to enable.
if False:
    rab_spot_df = issa.io.get_genes_spots(project, mouse, add_ara_info=True, reload=False, verbose=False)
    rab_spot_df['corrected_bases'] = rab_spot_df['gene']

In [None]:
# Work on a copy of the spots and select those of the reference chamber/ROI.
spot_df = rab_spot_df.copy()
# `chamber` / `roi` are reset here because earlier loops leave them on their last value.
chamber = ref_chamber
roi = ref_roi
spots = spot_df.query(f"chamber == '{chamber}' and roi == {roi}")
# Displayed as the cell output: (number of spots, number of columns).
spots.shape

In [None]:
# exclude spots that are out of the brain
spot_df = spot_df.query("area_id > 0")
# NOTE(review): `all_planes` is never filled in the visible notebook.
all_planes = []

# Fit a plane to the ARA coordinates of the reference-slice spots.
spots = spot_df.query(f"chamber == '{chamber}' and roi == {roi}")
ara_coords = spots[[f"ara_{i}" for i in "xyz"]].values
# Drop spots with a NaN in any of the three coordinates before fitting.
ara_coords = ara_coords[~np.any(np.isnan(ara_coords), axis=1)]
plane_coeffs = utils.fit_plane_to_points(ara_coords)
print(plane_coeffs)

In [None]:
# add rotated ara coordinates
transform = ara_registration.get_ara_to_slice_rotation_matrix(spot_df=rab_spot_df)
rabies_cell_properties = ara_registration.rotate_ara_coordinate_to_slice(
    rabies_cell_properties, transform=transform
)
# find cells in the ref slice, we will iterate on them
cells_in_ref = rabies_cell_properties.query(
    f"chamber == '{ref_chamber}' and roi == {ref_roi}"
)

# Table of the slices around the reference one, reference included.
surrounding_rois = utils.get_surrounding_slices(
    ref_chamber, ref_roi, project, mouse, include_ref=True
)
# Build a single "slice" key ("<chamber>_<roi, zero-padded to 2 digits>") so we
# do not have to group by chamber and roi every time.
surrounding_rois["slice"] = (
    surrounding_rois.chamber
    + "_"
    + surrounding_rois.roi.map(lambda x: f"{int(x):02d}")
)

# now do previous and next slice, if they exist
ref_slice = f"{ref_chamber}_{ref_roi:02d}"
# Row describing the reference slice; its `absolute_section` is used to tell
# previous from next slices.
ref_slice_df = surrounding_rois.query("slice == @ref_slice").iloc[0]
res_befaft = dict()

In [None]:
# Label each surrounding slice as "previous" or "next" relative to the reference.
# NOTE: no registration happens here, the loop only prints. `slice_df` and `name`
# stay bound to the LAST row of `surrounding_rois` and the following cells use
# that leftover `slice_df` as the target slice.
for islice, slice_df in surrounding_rois.iterrows():
    if slice_df.slice == ref_slice:
        # not register to self
        continue
    if slice_df.absolute_section < ref_slice_df.absolute_section:
        name = "previous"
    else:
        name = "next"
    print(f"Registering {name} slice: {slice_df.slice}")
    

In [None]:
from skimage.transform import warp_polar
from skimage.registration import phase_cross_correlation
from skimage.filters import window, difference_of_gaussians
from scipy.fft import fft2, fftshift


def estimate_rotation_and_scale(
    fixed, moving, dog=(5, 20), estimate_scale=True, debug=False
):
    """Estimate the rotation (and optionally the scale) between two images.

    Both images are band-pass filtered with a difference of Gaussians, windowed
    with a Hann window, and the magnitudes of their shifted FFTs are resampled
    to polar (or log-polar) coordinates. Phase cross-correlation between the
    two warped spectra gives a row shift, converted to an angle, and a column
    shift, converted to a scale factor.

    Args:
        fixed: Reference image (2D array).
        moving: Image to compare to `fixed`; must have the same shape.
        dog: `(low_sigma, high_sigma)` passed to `difference_of_gaussians`.
        estimate_scale: If True, use log-polar warping and estimate the scale.
            If False, use linear polar warping and return a scale of 1.
        debug: If True, plot the FFT magnitudes and their warped versions and
            print the recovered values.

    Returns:
        tuple: `(recovered_angle, shift_scale)`. The angle is in degrees; its
        sign follows the row shift returned by `phase_cross_correlation`.

    Raises:
        ValueError: If `fixed` and `moving` do not have the same shape.
    """
    if np.shape(fixed) != np.shape(moving):
        raise ValueError(
            "fixed and moving must have the same shape, got "
            f"{np.shape(fixed)} and {np.shape(moving)}"
        )

    # First, band-pass filter both images
    image = difference_of_gaussians(fixed, dog[0], dog[1])
    rts_image = difference_of_gaussians(moving, dog[0], dog[1])

    # window images (each with a window matching its own shape)
    wimage = image * window("hann", image.shape)
    rts_wimage = rts_image * window("hann", rts_image.shape)

    # work with shifted FFT magnitudes
    image_fs = np.abs(fftshift(fft2(wimage)))
    rts_fs = np.abs(fftshift(fft2(rts_wimage)))

    # Create (log-)polar transformed FFT magnitude images and register them
    shape = image_fs.shape
    radius = shape[0] // 8  # only take lower frequencies
    scaling = "log" if estimate_scale else "linear"
    warped_image_fs = warp_polar(
        image_fs, radius=radius, output_shape=shape, order=0, scaling=scaling
    )
    warped_rts_fs = warp_polar(
        rts_fs, radius=radius, output_shape=shape, order=0, scaling=scaling
    )

    # The FFT magnitude is symmetric: half of the angular range is enough
    warped_image_fs = warped_image_fs[: shape[0] // 2, :]
    warped_rts_fs = warped_rts_fs[: shape[0] // 2, :]
    shifts, _, _ = phase_cross_correlation(
        warped_image_fs, warped_rts_fs, upsample_factor=10, normalization=None
    )

    # Use translation parameters to calculate rotation and scaling parameters
    shiftr, shiftc = shifts[:2]
    # The warped rows span 360 degrees over shape[0] samples
    recovered_angle = (360 / shape[0]) * shiftr
    if estimate_scale:
        # Columns are log-spaced radii: a column shift is a multiplicative scale
        klog = shape[1] / np.log(radius)
        shift_scale = np.exp(shiftc / klog)
    else:
        shift_scale = 1

    if debug:
        fig, axes = plt.subplots(2, 2, figsize=(8, 8))
        ax = axes.ravel()
        center = np.array(shape) // 2
        # Zoom on the low-frequency disc actually used for the registration
        zoom = (
            slice(center[0] - radius, center[0] + radius),
            slice(center[1] - radius, center[1] + radius),
        )
        ax[0].set_title("Original Image FFT\n(magnitude; zoomed)")
        ax[0].imshow(image_fs[zoom], cmap="magma")
        ax[1].set_title("Modified Image FFT\n(magnitude; zoomed)")
        ax[1].imshow(rts_fs[zoom], cmap="magma")
        ax[2].set_title("Log-Polar-Transformed\nOriginal FFT")
        ax[2].imshow(warped_image_fs, cmap="magma")
        ax[3].set_title("Log-Polar-Transformed\nModified FFT")
        ax[3].imshow(warped_rts_fs, cmap="magma")
        plt.show()

        print(f"Recovered value for cc rotation: {recovered_angle}")
        print()
        print(f"Recovered value for scaling difference: {shift_scale}")
    return recovered_angle, shift_scale

In [None]:
from iss_analysis.registration import register_serial_sections
# Register the spots around one cell of the reference slice to the target
# slice and show the fixed image, the moving image and the phase correlation.
# NOTE: `slice_df` is whatever row the previous loop ended on.
print(ref_slice)
print(slice_df.slice)
fig = plt.figure(figsize=(10, 10))
for iwin, window_size in enumerate([500]):
    min_spots = 10
    max_barcode_number= 50
    gaussian_width= 5
    cell_coords = cells_in_ref[["ara_y_rot", "ara_z_rot"]].values
    # Arbitrary example cell: the 11th cell of the reference slice.
    coords = cell_coords[10]
    out = issa.registration.register_serial_sections.register_local_spots(coords,
        spot_df=rab_spot_df,
        ref_slice=ref_slice,
        target_slice=slice_df.slice,
        window_size=window_size,
        min_spots=min_spots,
        max_barcode_number=max_barcode_number,
        gaussian_width=gaussian_width,
        verbose=20,
        debug=True,
    )
    # NOTE(review): 8 values are unpacked here, but a later cell unpacks 6 from
    # the same function with `debug=True` -- one of the two is out of date.
    shift, maxcorr, nbarcodes, shifts, max_corrs, phase_corrs, spot_images, best_barcodes = out

    # One row of three panels per window size.
    ax =plt.subplot(3, 3, 1 + iwin *3)
    plt.imshow(spot_images[0,0])
    ax.set_title('Fixed')
    plt.ylabel(f"Window size: {window_size}um")
    ax=plt.subplot(3, 3, 2 + iwin *3)
    ax.set_title('Moving')
    plt.imshow(spot_images[0,1])
    plt.subplot(3, 3, 3 + iwin *3)
    plt.imshow(phase_corrs[0])
    plt.title(f"Shift: {np.round(shift)}")
for x in fig.axes:
    x.set_xticks([])
    x.set_yticks([])
    

In [None]:
# Estimate rotation and scale between the first fixed/moving image pair.
angle, scale = estimate_rotation_and_scale(spot_images[0,0], spot_images[0,1], debug=True)

In [None]:
# Peak value of the fixed image, useful to choose `vmax` for the plots below.
spot_images[0,0].max()

In [None]:
# Show fixed (left) and moving (right) images side by side, clipped at vmax=10.
fig = plt.figure(figsize=(15, 7))
plt.subplot(1,2,1)
plt.imshow(spot_images[0,0], origin='lower', vmax=10)
plt.subplot(1,2,2)
plt.imshow(spot_images[0,1], origin='lower', vmax=10)

# Outline the same 100x100 box on both panels; a patch can only belong to one
# axes, so a new Rectangle is created for each.
for x in fig.axes:
    rec = plt.Rectangle((300, 400), 100, 100, edgecolor='r', facecolor='none', alpha=0.2)
    x.add_patch(rec)


In [None]:
# estimate shift on the transformed image
from image_tools.similarity_transforms import transform_image
# Apply the estimated rotation and scale to the moving image first.
transformed = transform_image(spot_images[0,1], angle=angle, scale=scale)

# NOTE: this rebinds the notebook-level `shifts`.
shifts, err, phase_diff,  = phase_cross_correlation(spot_images[0,0], transformed, upsample_factor=10, normalization=None)
shifted = transform_image(transformed, shift=shifts)

# Overlay fixed and registered moving images as two colour channels.
# NOTE(review): three colours are passed for a two-channel stack -- presumably
# the extra colour is ignored by `issp.vis.to_rgb`; confirm.
stacked = np.dstack([spot_images[0,0], shifted])
rgb = issp.vis.to_rgb(stacked, colors=[(1,0,0), (0,1,0), (0,0,1)])
plt.figure(figsize=(10, 10))
plt.imshow(rgb)
plt.axis('off')

In [None]:


# Compare four red/green overlays to check the sign convention of the angle:
#   1. raw fixed vs raw moving
#   2. fixed vs moving rotated by +angle
#   3. fixed vs moving rotated by -angle
#   4. fixed rotated by -angle vs raw moving
fig = plt.figure(figsize=(10, 10))

plt.subplot(2, 2, 1)
st = np.dstack([spot_images[0,0], spot_images[0,1]])
rgb = issp.vis.to_rgb(st, colors=[(1,0,0), (0,1,0)])
plt.imshow(rgb)

# (index of the image to transform, sign of the angle, transformed image first?)
for panel, (src, sign, transformed_first) in enumerate(
    [(1, 1, False), (1, -1, False), (0, -1, True)], start=2
):
    plt.subplot(2, 2, panel)
    transformed = transform_image(spot_images[0, src], angle=sign * angle, scale=scale)
    untouched = spot_images[0, 1 - src]
    if transformed_first:
        st = np.dstack([transformed, untouched])
    else:
        st = np.dstack([untouched, transformed])
    rgb = issp.vis.to_rgb(st, colors=[(1,0,0), (0,1,0)])
    plt.imshow(rgb)

for x in fig.axes:
    x.axis('off')
plt.tight_layout()

In [None]:
from iss_analysis.registration import register_serial_sections
# Disabled: full registration of one reference section through the library
# entry point. Flip to True to run it; the chamber and ROI are hard-coded here
# rather than taken from `ref_chamber` / `ref_roi`.
if False:
    res = register_serial_sections.register_single_section(
        project=project,
        mouse=mouse,
        use_rabies=False,
        error_correction_ds_name=error_correction_ds_name,
        ref_chamber="chamber_09",
        ref_roi=3,
        window_size=300,
        min_spots=10,
        max_barcode_number=50,
        gaussian_width=30,
        n_workers=1,
        verbose=True,
        use_slurm=False,
    )

In [None]:
from iss_analysis.vis import diagnostics
# Diagnostic figure for the registration around a single reference cell,
# using the rabies barcode spots.
ref_slice = 'chamber_09_03'
target_slice = 'chamber_09_04'
rabies_cell_properties['slice'] = (
    rabies_cell_properties.chamber
    + '_'
    + rabies_cell_properties.roi.map("{:02d}".format)
)
ref_cells = rabies_cell_properties[rabies_cell_properties['slice'] == ref_slice]
# Arbitrary example: the 101st cell of the reference slice
cell_info = ref_cells.iloc[100]

# Registration parameters, defined once and passed by name below
window_size = 300
min_spots = 10
max_barcode_number = 50
gaussian_width = 30

fig = diagnostics.check_serial_registration(
    cell_info,
    ref_slice,
    target_slice,
    rab_spot_df,
    rabies_cell_properties,
    window_size=window_size,
    min_spots=min_spots,
    max_barcode_number=max_barcode_number,
    gaussian_width=gaussian_width,
    shifts_to_use=None,
)


In [None]:
# Load the gene spots for the whole mouse, with ARA information attached.
genes_spot_df = issa.io.get_genes_spots(project, mouse, add_ara_info=True, reload=False, verbose=False)


In [None]:
# Pool all gene spots under a single pseudo-barcode so that they can be fed to
# the barcode-based registration helpers.
genes_spot_df['corrected_bases'] = 'GENES'
genes = list(genes_spot_df.corrected_bases.unique())
gene_to_id = {gene: igene for igene, gene in enumerate(genes)}
genes_spot_df['barcode_id'] = genes_spot_df['corrected_bases'].map(gene_to_id)

In [None]:
# Compare gene spots (blue) and rabies spots (red) around one cell, in raw
# pixel coordinates (left) and in rotated ARA coordinates (right).
rab_sp = rab_spot_df.query("chamber == 'chamber_09' and roi == 3")
gene_sp = genes_spot_df.query("chamber == 'chamber_09' and roi == 3")
rabies_cell_properties['slice'] = rabies_cell_properties.chamber + '_' + rabies_cell_properties.roi.map(lambda x: f"{x:02d}")
ref_cells = rabies_cell_properties.query(f"slice == '{ref_slice}'")
cell_info = ref_cells.iloc[250]

plt.figure(figsize=(10, 5))
# 'xy' is indexed character by character, so xy[0] == 'x' and xy[1] == 'y'.
for i, xy in enumerate(['xy', ['ara_y_rot', 'ara_z_rot']]):
    plt.subplot(1, 2, 1 + i)
    plt.scatter(gene_sp[xy[0]], gene_sp[xy[1]], c='b', s=1,alpha=0.1)
    plt.scatter(rab_sp[xy[0]], rab_sp[xy[1]], c='r', s=1, alpha=0.2)
    c_coords = cell_info[[xy[0], xy[1]]].values
    plt.scatter(cell_info[xy[0]], cell_info[xy[1]], c='k', s=10)
    plt.axis('equal')
    # Half-width of the view: 1000 for the raw panel, 0.2 for the ARA panel
    # (presumably pixels and mm respectively -- TODO confirm units).
    w = np.array([-0.2,0.2]) if i else np.array([-1000,1000])
    plt.xlim(*(w+ c_coords[0]))
    plt.ylim(*(w+ c_coords[1]))


In [None]:
# Whole-slice view in rotated ARA coordinates: gene spots in blue, rabies spots in red.
plt.scatter(gene_sp.ara_y_rot, gene_sp.ara_z_rot, c='b', s=1,alpha=0.1)
plt.scatter(rab_sp.ara_y_rot, rab_sp.ara_z_rot, c='r', s=1, alpha=0.2)

In [None]:
from iss_analysis.vis import diagnostics
# debug plot around one cell
# Same diagnostic as for the rabies spots, but driven by the gene spots and
# with a fixed, manually chosen shift.
ref_slice = 'chamber_09_03'
target_slice = 'chamber_09_04'
rabies_cell_properties['slice'] = rabies_cell_properties.chamber + '_' + rabies_cell_properties.roi.map(lambda x: f"{x:02d}")
ref_cells = rabies_cell_properties.query(f"slice == '{ref_slice}'")
cell_info = ref_cells.iloc[250]

window_size=300
min_spots=10
max_barcode_number=50
gaussian_width=10

# NOTE: `window_size` is passed as a literal below; it currently matches the
# variable above but will not follow it if the variable is edited.
fig = diagnostics.check_serial_registration(cell_info, ref_slice, target_slice, genes_spot_df, rabies_cell_properties,
                                            window_size=300,

    min_spots=min_spots,
    max_barcode_number=max_barcode_number,
    gaussian_width=gaussian_width,
    spots_kwargs=dict(s=5, alpha=0.2),
    shifts_to_use=[-14.5, -73.5])

fig.set_size_inches(20, 10)


# Run registration and cell distance

For each cell, we register the spots around the cell to the next slice (n+1) and to the one after (n+2).
We then shift the candidate cells accordingly and find the closest cell sharing a barcode.

In [None]:
# add barcode_id for plotting
# Ids follow the order of first appearance of each barcode. The lookup table is
# built once: calling `list.index` inside `map` rescans the barcode list for
# every spot.
barcodes = list(rab_spot_df.corrected_bases.unique())
barcode_to_id = {bc: ibc for ibc, bc in enumerate(barcodes)}
rab_spot_df['barcode_id'] = rab_spot_df.corrected_bases.map(barcode_to_id)


In [None]:
from tqdm import tqdm

# Registration parameters
window_size = 300
min_spots = 10
max_barcode_number = 50
min_barcode_number = 1
gaussian_width = 30

section_infos = issa.io.get_sections_info(project, mouse)
section_infos["slice"] = (
    section_infos["chamber"] + "_" + section_infos["roi"].map(lambda x: f"{x:02d}")
)
ref_chamber, ref_roi = "chamber_09", 2
ref_sec = section_infos.query("chamber == @ref_chamber and roi == @ref_roi").iloc[0]
ref_slice = ref_sec.slice
print(ref_slice)
# Sections to process: the reference itself, the next one and the one after
todo = (ref_sec.absolute_section + np.arange(3)).astype(int)

rabies_cell_properties['slice'] = rabies_cell_properties['chamber'] + "_" + rabies_cell_properties['roi'].map(lambda x: f"{x:02d}")
# The spot table only gets its "slice" column in a later cell; create it here
# if needed so that this cell also runs on a fresh kernel.
if "slice" not in rab_spot_df.columns:
    rab_spot_df["slice"] = (
        rab_spot_df["chamber"] + "_" + rab_spot_df["roi"].map(lambda x: f"{x:02d}")
    )
cells = rabies_cell_properties.query("chamber == @ref_chamber and roi == @ref_roi")
ref_spots = rab_spot_df.query("slice == @ref_slice")

# Outputs, one row per reference cell and one column per section in `todo`.
# In `dist2closest`: NaN = registration not done / failed, inf = no candidate.
nbcs_all = np.zeros((cells.shape[0], len(todo)))* np.nan
shifts_all = np.zeros((cells.shape[0], len(todo), 2))* np.nan
max_corrs_all = np.zeros((cells.shape[0], len(todo)))* np.nan
dist2closest = np.zeros((cells.shape[0], len(todo))) * np.nan
closest_target = [['NaN'] * cells.shape[0] for _ in range(len(todo))]
for sec_id in todo:
    print(f"Do section {sec_id}")
    matching_sections = section_infos.query("absolute_section == @sec_id")
    if matching_sections.empty:
        # e.g. the reference is one of the last sections: leave the column as NaN
        print(f"Section {sec_id} not found, skipping")
        continue
    slice_df = matching_sections.iloc[0]
    isec = sec_id - todo[0]  # column of the output arrays for this section
    # All cells of the target slice; does not depend on the reference cell
    slice_cells = rabies_cell_properties.query("slice == @slice_df.slice")
    for ic, (_, cell_info) in tqdm(enumerate(cells.iterrows()), total=cells.shape[0]):
        # Get candidate cells that might be the same: they share at least one
        # barcode and lie within 2 * window_size of the reference cell
        cell_barcodes = cell_info.all_barcodes
        cells_target = slice_cells[
            slice_cells.all_barcodes.map(lambda x: any([bc in cell_barcodes for bc in x]))
        ]
        distance = np.sqrt((cells_target.ara_y_rot - cell_info.ara_y_rot) ** 2 + (cells_target.ara_z_rot - cell_info.ara_z_rot) ** 2)
        # window_size is in um, ARA coordinates in mm
        cells_target = cells_target[distance < window_size * 2 / 1000]
        if cells_target.shape[0] == 0:
            dist2closest[ic, isec] = np.inf
            continue

        shifts, maxcorr, n_bcs = (
            issa.registration.register_serial_sections.register_local_spots(
                center_point=(cell_info.ara_y_rot, cell_info.ara_z_rot),
                spot_df=rab_spot_df,
                ref_slice=ref_slice,
                target_slice=slice_df.slice,
                window_size=window_size,
                min_spots=min_spots,
                max_barcode_number=max_barcode_number,
                gaussian_width=gaussian_width,
                verbose=False,
                debug=False,
            )
        )
        nbcs_all[ic, isec] = n_bcs
        shifts_all[ic, isec] = shifts
        max_corrs_all[ic, isec] = maxcorr

        if n_bcs < min_barcode_number:
            continue
        if any(np.isnan(shifts)):
            print(f"Cell {ic} has NaN shift")
            continue
        # shifts the target cell ara_y and ara_z
        shifted_targets = cells_target.copy()
        if slice_df.slice == ref_slice:
            # remove the reference cell from the target cells if it is in the same slice
            shifted_targets = shifted_targets.drop(cell_info.name)
        if len(shifted_targets) == 0:
            dist2closest[ic, isec] = np.inf
            continue
        # NOTE(review): shifts[1] is applied to y and shifts[0] to z, with a plus
        # sign, whereas the scatter-plot cell further down subtracts shifts[0]
        # from y and shifts[1] from z -- confirm which convention is right.
        shifted_targets.ara_y_rot += shifts[1] / 1000
        shifted_targets.ara_z_rot += shifts[0] / 1000
        # compute the distance between the target cells and the reference cell
        distance = np.sqrt((shifted_targets.ara_y_rot - cell_info.ara_y_rot) ** 2 + (shifted_targets.ara_z_rot - cell_info.ara_z_rot) ** 2)
        # select the closest cell
        best_target = shifted_targets.loc[distance.idxmin()]
        closest_target[isec][ic] = best_target.name
        dist2closest[ic, isec] = distance.min()

In [None]:
# Cumulative distribution of the distance to the closest same-barcode cell,
# for the reference slice itself, the next slice and the one after.
labels = ['within slice', 'next slice', 'skip one slice']
fig = plt.figure(figsize=(5, 3))
ax = fig.add_subplot(111)
for iax, label in enumerate(labels):
    # Distances are stored in mm and plotted in um.
    # NOTE(review): with `density=True` the curve is normalised by the values
    # that fall inside the bins; cells with inf/NaN or >= 1000 um are presumably
    # left out of the denominator -- confirm before reading the y axis as a
    # proportion of all cells.
    ax.hist(dist2closest[:,iax]*1000, bins=np.arange(0, 1000, 1), label=label, alpha=0.5, histtype='step', lw=3, cumulative=True,
            density=True)
ax.set_xlim(0, 100)
ax.set_ylim(0, 0.4)
ax.legend(loc='upper right')
ax.set_xlabel('Distance to closest cell (um)')
ax.set_ylabel('Proportion of cells')

In [None]:
# Summary for the next slice (column 1): NaN = registration failed,
# inf = no candidate cell, < 15 um = close match.
ncells = dist2closest.shape[0]
print(f"{ncells} cells")
reg_fail = np.isnan(dist2closest).sum(axis=0)
no_target = np.isinf(dist2closest).sum(axis=0)
has_close_target = ((dist2closest * 1000) < 15).sum(axis=0)

print(f"{no_target[1]}/{ncells} have no neighbour in the next slice.")
left = ncells - no_target[1]
print(f"Registration failed for {reg_fail[1]}/{left} cells.")
left -= reg_fail[1]
print(f"{has_close_target[1]}/{left} have a close target in the next slice.")

# Raw counts for all three sections
for counts in (reg_fail, no_target, has_close_target):
    print(counts)

In [None]:
# Cells whose distance to the next slice is NaN (registration not available)
d = dist2closest[:, 1]
bad = np.isnan(d)
n_bad = bad.sum()
print(f"Bad cells: {n_bad}")
print(f"Good cells: {bad.size - n_bad}")

In [None]:
def _crop_around_cell(spots):
    """Keep the spots strictly inside the square window centred on `cell_info`."""
    half_width = window_size / 1000  # window_size is in um, coordinates in mm
    y_center, z_center = cell_info.ara_y_rot, cell_info.ara_z_rot
    inside = (
        (spots.ara_y_rot > y_center - half_width)
        & (spots.ara_y_rot < y_center + half_width)
        & (spots.ara_z_rot > z_center - half_width)
        & (spots.ara_z_rot < z_center + half_width)
    )
    return spots[inside]


# Spots of the reference and target slices around the last processed cell
# (`cell_info` and `slice_df` are left over from the registration loop)
spot_ref = _crop_around_cell(rab_spot_df.query("slice == @ref_slice"))
bc_ref = spot_ref.corrected_bases.unique()
spot_target = _crop_around_cell(rab_spot_df.query("slice == @slice_df.slice"))
bc_target = spot_target.corrected_bases.unique()

# Keep only the barcodes present in both windows
common_bc = set(bc_ref).intersection(bc_target)
print(f"{len(common_bc)} common barcodes")
common_bc = list(common_bc)
spot_ref = spot_ref.query("corrected_bases in @common_bc")
spot_target = spot_target.query("corrected_bases in @common_bc")

# ... and among those, the ones with more than 10 spots in each slice
nspots_ber_cb = spot_ref.groupby("corrected_bases").size()[common_bc]
nspots_bct_cb = spot_target.groupby("corrected_bases").size()[common_bc]
min_nspots = np.minimum(nspots_ber_cb.values, nspots_bct_cb.values)
valid = min_nspots > 10
common_bc = np.array(common_bc)[valid]
spot_ref = spot_ref.query("corrected_bases in @common_bc")
spot_target = spot_target.query("corrected_bases in @common_bc")
print(f"{len(spot_ref)} spots in reference slice and {len(spot_target)} in target slice")

In [None]:
# Scatter plots of reference (squares) and target (circles) spots, coloured by
# barcode, before and after applying the estimated shift.
ms = 20
ref_kw = dict(cmap="tab20", marker="s", s=ms, alpha=0.5, c=spot_ref.barcode_id % 20)
target_kw = dict(cmap="tab20", marker="o", s=ms, alpha=0.5, c=spot_target.barcode_id % 20)

# Target coordinates with the shift removed (shift in um, coordinates in mm).
# NOTE(review): the registration loop applies shifts[1] to y and shifts[0] to z,
# with the opposite sign -- confirm which convention is right.
target_y_shifted = spot_target.ara_y_rot - shifts[0] / 1000
target_z_shifted = spot_target.ara_z_rot - shifts[1] / 1000

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# Top row: the cell and the reference spots
for i in range(3):
    axes[0, i].plot(cell_info.ara_y_rot, cell_info.ara_z_rot, "o", color='k')
    axes[0, i].scatter(spot_ref.ara_y_rot, spot_ref.ara_z_rot, **ref_kw)
# Bottom row: raw target, shifted target, reference + shifted target
axes[1, 2].scatter(spot_ref.ara_y_rot, spot_ref.ara_z_rot, **ref_kw)
axes[1, 0].scatter(spot_target.ara_y_rot, spot_target.ara_z_rot, **target_kw)
for shifted_ax in (axes[1, 1], axes[1, 2], axes[0, 2]):
    shifted_ax.scatter(target_y_shifted, target_z_shifted, **target_kw)

# Same 0.2-wide view centred on the cell for every panel
w = 0.2
for x in axes.flatten():
    x.set_aspect("equal")
    x.set_xlim(cell_info.ara_y_rot - w, cell_info.ara_y_rot + w)
    x.set_ylim(cell_info.ara_z_rot - w, cell_info.ara_z_rot + w)
    x.set_xticks([])
    x.set_yticks([])
fig.tight_layout()

In [None]:
# Re-run the registration around the last processed cell with `debug=True` to
# get the per-barcode images and phase correlations.
# NOTE(review): six values are unpacked here, whereas an earlier cell unpacks
# eight from the same function with `debug=True` -- one of the two is out of date.
shifts, maxcorr, n_bcs, phase_corrs, spot_images, best_barcodes = (
    issa.registration.register_serial_sections.register_local_spots(
        center_point=(cell_info.ara_y_rot, cell_info.ara_z_rot),
        spot_df=rab_spot_df,
        ref_slice=ref_slice,
        target_slice=slice_df.slice,
        window_size=window_size,
        min_spots=min_spots,
        max_barcode_number=max_barcode_number,
        gaussian_width=gaussian_width,
        verbose=True,
        debug=True,
    )
)


In [None]:
# One row per entry of `spot_images`: first image, second image, phase correlation.
# `squeeze=False` keeps `axes` two-dimensional even when there is a single row,
# so that `axes[row, col]` always works.
fig, axes = plt.subplots(len(spot_images), 3, figsize=(5, 50), squeeze=False)
for row, data in enumerate(spot_images):
    axes[row, 0].imshow(data[0])
    axes[row, 1].imshow(data[1])
    axes[row, 2].imshow(phase_corrs[row])
for x in axes.flatten():
    x.set_aspect("equal")
    x.set_xticks([])
    x.set_yticks([])
fig.tight_layout()

In [None]:
# Spots carrying the main barcode of the cell, in the reference and target slices
spots_ref = rab_spot_df.query("slice == @ref_slice and corrected_bases == @cell_info.main_barcode")
spots_target = rab_spot_df.query("slice == @slice_df.slice and corrected_bases == @cell_info.main_barcode")
# Only the last expression of a cell is displayed: show both shapes together
spots_ref.shape, spots_target.shape

In [None]:
# First image pair and its phase correlation; the bottom-right panel is left empty.
fig, ax = plt.subplots(2, 2, figsize=(5, 5))
ax[0, 0].imshow(spot_images[0, 0])
ax[0, 1].imshow(spot_images[0, 1])
ax[1, 0].imshow(phase_corrs[0])

In [None]:
# Gather the per-cell registration results into a table and plot histograms.
# NOTE(review): `assignment_by_bc` is not defined anywhere in the visible part
# of this notebook; this cell relies on state from a removed or later cell and
# will fail on Restart & Run All.
res = pd.DataFrame(
    columns=["shift_y", "shift_z", "maxcorr", "n_barcodes"],
    # Each entry is flattened into one row of four values.
    data=np.vstack([np.hstack(a) for a in assignment_by_bc]),
)

import seaborn as sns

fig = plt.figure(figsize=(8, 2))
for iax, col in enumerate(["shift_y", "shift_z", "n_barcodes"]):
    ax = fig.add_subplot(1, 3, iax + 1)
    sns.histplot(res, x=col, ax=ax, bins=20)

fig.tight_layout()

In [None]:
# Pairwise relationships between shifts, correlation and number of barcodes.
sns.pairplot(res)

In [None]:
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

# Spatial maps of the registration results: one column per measure, top row in
# rotated ARA coordinates, bottom row in raw image coordinates.
fig = plt.figure(figsize=(12, 5))

# NOTE(review): `assignment_by_bc` is not defined in the visible notebook.
toplot = np.vstack([np.hstack(a) for a in assignment_by_bc])
labels = ["Shift Y (um)", "Shift Z (um)", "Max Corr", "N barcodes used"]
cmaps = ["RdBu", "RdBu", "viridis", "viridis"]
# NOTE(review): 0.2 is a hard-coded factor, presumably the pixel size in um
# (see the commented `get_pixel_size` line below) -- confirm.
xycoords = (cells_in_ref[["y", "x"]].values * 0.2) / 1000
ara_coords = cells_in_ref[["ara_z_rot", "ara_y_rot"]].values
# px_size = issp.io.get_pixel_size(data_path)
for iax, label in enumerate(labels):
    for ic, data in enumerate([ara_coords, xycoords]):
        ax = fig.add_subplot(2, 4, 4 * ic + iax + 1, facecolor="k")
        ax_divider = make_axes_locatable(ax)
        # Add an Axes to the right of the main Axes.
        cax = ax_divider.append_axes("right", size="7%", pad="2%")
        # Robust colour limits; shift maps use limits symmetric around zero.
        vmin, vmax = np.nanpercentile(toplot[:, iax], [1, 99])
        if iax < 2:
            vminmax = max(abs(vmin), abs(vmax))
            vmin, vmax = -vminmax, vminmax
        sc = ax.scatter(
            data[:, 0],
            data[:, 1],
            c=toplot[:, iax],
            s=1,
            cmap=cmaps[iax],
            vmin=vmin,
            vmax=vmax,
            alpha=1,
        )
        ax.set_xlabel("Raw Y (mm)" if ic else "ARA Z (mm)")
        ax.set_ylabel("Raw X (mm)" if ic else "ARA Y (mm)")
        cb = plt.colorbar(sc, cax=cax)
        cb.set_label(label)
        ax.set_aspect("equal")
        # Hard-coded view limits; the ARA panel uses an inverted y axis (5 -> 1.8).
        if not ic:
            ax.set_ylim(5, 1.8)
            ax.set_xlim(6, 10)
        else:
            ax.set_ylim(1, 3.5)
            ax.set_xlim(0.2, 4)
plt.tight_layout()

In [None]:
# Display the registration results table.
res

# End register slices

In [None]:
# find starter
# Disabled: match starters using explicitly loaded starter positions.
# NOTE(review): the result is bound to a single name here, whereas the call
# near the top of the notebook unpacks two return values -- update before
# re-enabling.
if False:
    starters_positions = issa.io.get_starter_cells(project, mouse)
    rabies_cell_properties = issa.segment.match_starter_to_barcodes(
        project,
        mouse,
        rabies_cell_properties,
        rab_spot_df,
        mcherry_cells=starters_positions,
        redo=True,
    )
    rabies_cell_properties.head()

# Distance between cells

Look at the distance between rabies cells that share a barcode within a single ROI.

In [None]:
# Iterate on barcodes and on chamber/roi, use a KDTree to find the closest pair of
# cells with same barcode and build a full list
from scipy.spatial import KDTree

# One entry per (cell, barcode) pair
ncellbarcode = rabies_cell_properties.all_barcodes.apply(len).sum()
within_slice_distance_px = np.zeros(ncellbarcode)
# Per-cell array holding, for each of its barcodes, the distance to the closest
# cell of the same ROI carrying that barcode (NaN if there is none)
rabies_cell_properties["closest_distance_px"] = [
    np.zeros(len(x)) for x in rabies_cell_properties.all_barcodes
]
i = 0
for (roi, chamber), roi_df in rabies_cell_properties.groupby(["roi", "chamber"]):
    all_barcodes = set()
    for bc in roi_df["all_barcodes"].values:
        all_barcodes.update(bc)
    for bc in all_barcodes:
        df = roi_df[roi_df["all_barcodes"].apply(lambda x: bc in x)]
        # Position of this barcode in each cell's barcode list. The Series is
        # indexed by cell label, so it must be read positionally with `.iloc`:
        # `bc_index[j]` would look up the *label* j.
        bc_index = df.all_barcodes.apply(lambda x: x.index(bc))
        if len(df) > 1:
            coords = df[["x", "y"]].values
            tree = KDTree(coords)
            # k=2: the first neighbour of each point is the point itself
            dist, idx = tree.query(coords, k=2)
            within_slice_distance_px[i : i + len(df)] = dist[:, 1]
            for j in range(len(df)):
                # The cell holds a numpy array which is modified in place
                rabies_cell_properties.loc[df.index[j], "closest_distance_px"][
                    bc_index.iloc[j]
                ] = dist[j, 1]
        else:
            # Single cell with this barcode in the ROI: no neighbour
            within_slice_distance_px[i] = np.nan
            rabies_cell_properties.loc[df.index[0], "closest_distance_px"][
                bc_index.iloc[0]
            ] = np.nan
        i += len(df)

In [None]:
# Convert the within-slice distances to um and plot their distribution.
# NOTE: `chamber` is the value left over from the last iteration of the loop
# in the previous cell; the same pixel size is applied to all chambers.
px_size = issp.io.get_pixel_size(f"{project}/{mouse}/{chamber}")
within_slice_distance_um = within_slice_distance_px * px_size

fig, axes = plt.subplots(2, 1, figsize=(6, 3))
# Top: 0-500 um in 1 um bins, plus a red dot on a twin axis counting the
# distances above 500 um.
ax = axes[0]
ax.set_title("Within slice distance between cells with same barcode")
twin_ax = ax.twinx()
ax.hist(within_slice_distance_um, bins=np.arange(0, 501, 1))
# NaN entries (barcodes seen in a single cell) compare False and are not counted.
n_far = np.sum(within_slice_distance_um > 500)
twin_ax.scatter(501, n_far, color="red", label=">500um")
twin_ax.set_ylim(0, n_far * 1.3)
ax.set_xlim(0, 501)
twin_ax.legend(loc="upper right")
ax.set_ylabel("Number of cells")
# Bottom: zoom on the first 100 um.
ax = axes[1]
ax.hist(within_slice_distance_um, bins=np.arange(0, 101, 1))
ax.set_ylabel("Number of cells")
ax.set_xlabel("Distance (um)")
fig.tight_layout()
plt.show()

In [None]:
# Cells whose closest same-barcode neighbour (over all of their barcodes) is
# within 30 um; `np.nanmin` ignores barcodes that have no neighbour.
closeby = rabies_cell_properties[
    rabies_cell_properties.closest_distance_px.map(np.nanmin) < (30 / px_size)
]
print(f"Number of cells with a close neighbor: {len(closeby)}")

# Get ARA coordinates of spots

In [None]:
# Get transform to rotate ARA coordinates to slice
from iss_analysis.registration import ara_registration

# The transform is derived from the spots and applied to both the cell table
# and the spot table.
transform = ara_registration.get_ara_to_slice_rotation_matrix(rab_spot_df)
rabies_cell_properties = ara_registration.rotate_ara_coordinate_to_slice(
    rabies_cell_properties, transform=transform
)
rab_spot_df = ara_registration.rotate_ara_coordinate_to_slice(
    rab_spot_df, transform=transform
)

In [None]:
from matplotlib import cm

# Per-slice distributions of the ARA coordinates before (left) and after
# (middle) rotation, and the difference between the two (right).
fig, axes = plt.subplots(3, 3, figsize=(10, 6))
step = 0.01
# One set of bin edges per coordinate (x, y, z)
bins = [
    np.arange(5, 9.2, step),
    np.arange(step, 8, step * 10),
    np.arange(4, 12, step * 10),
]
rab_spot_df["slice"] = (
    rab_spot_df["chamber"] + "_" + rab_spot_df["roi"].map(lambda x: f"{x:02d}")
)
rabies_cell_properties["slice"] = (
    rabies_cell_properties["chamber"]
    + "_"
    + rabies_cell_properties["roi"].map(lambda x: f"{x:02d}")
)
slices = sorted(rab_spot_df["slice"].unique())
# `matplotlib.cm.get_cmap` is removed in recent Matplotlib; `pyplot.get_cmap`
# takes the same (name, lut) arguments.
colors = plt.get_cmap("viridis", len(slices))
# `slice_name` rather than `slice`, to avoid shadowing the builtin
for islice, slice_name in enumerate(slices):
    cell_prop = rabies_cell_properties[rabies_cell_properties["slice"] == slice_name]
    for iax, coord in enumerate("xyz"):
        axes[iax, 0].hist(
            cell_prop[f"ara_{coord}"],
            alpha=0.5,
            label=slice_name,
            histtype="step",
            bins=bins[iax],
            color=colors(islice),
        )
        axes[iax, 1].hist(
            cell_prop[f"ara_{coord}_rot"],
            alpha=0.5,
            label=slice_name,
            histtype="step",
            bins=bins[iax],
            color=colors(islice),
        )
        axes[iax, 2].scatter(
            cell_prop[f"ara_{coord}"],
            cell_prop[f"ara_{coord}_rot"] - cell_prop[f"ara_{coord}"],
            label=slice_name,
            s=5,
            alpha=0.5,
            color=colors(islice),
        )

# Axis labels do not depend on the slice: set them once
for iax, coord in enumerate("xyz"):
    for i in range(3):
        axes[iax, i].set_xlabel(f"ARA {coord}")
    axes[iax, 2].set_ylabel(f"Rotated {coord} - ARA {coord}")
for i in range(2):
    for j in range(3):
        axes[j, i].set_ylabel("Number of cells")

axes[0, 0].set_title("ARA coordinates")
axes[0, 1].set_title("ARA coordinates rotated")
axes[0, 2].set_title("Difference")
fig.tight_layout()

In [None]:
# now calculate the 3d dimensions in the rotated ara coordinates
from tqdm import tqdm

# Set of every barcode found in at least one cell
all_barcodes = set()
for bc in rabies_cell_properties["all_barcodes"].values:
    all_barcodes.update(bc)
# One entry per (cell, barcode) pair
ncellbarcode = rabies_cell_properties.all_barcodes.apply(len).sum()
ara_3d_distance = np.zeros(ncellbarcode)
ara_within_distance = np.zeros(ncellbarcode)
i = 0
for barcode in tqdm(all_barcodes):
    df = rabies_cell_properties[
        rabies_cell_properties["all_barcodes"].apply(lambda x: barcode in x)
    ]
    # Closest same-barcode cell across all slices
    ara_coords = df[["ara_x_rot", "ara_y_rot", "ara_z_rot"]].values
    # NOTE(review): non-finite coordinates are replaced by 0, so cells without
    # ARA coordinates all land on the origin and appear at distance 0 from
    # each other -- confirm this is intended.
    ara_coords[~np.isfinite(ara_coords)] = 0
    tree = KDTree(ara_coords)
    # k=2: the first neighbour of each point is the point itself
    dist, idx = tree.query(ara_coords, k=2)
    ara_3d_distance[i : i + len(df)] = dist[:, 1]
    # Closest same-barcode cell within each slice. `i` advances by the size of
    # each group; the two output arrays are therefore filled in different cell
    # orders (frame order vs group order) and should not be compared row by row.
    for (chamber, roi), cell_prop in df.groupby(["chamber", "roi"]):
        ara_coords = cell_prop[["ara_x_rot", "ara_y_rot", "ara_z_rot"]].values
        ara_coords[~np.isfinite(ara_coords)] = 0
        tree = KDTree(ara_coords)
        dist, idx = tree.query(ara_coords, k=2)
        ara_within_distance[i : i + len(cell_prop)] = dist[:, 1]
        i += len(cell_prop)

In [None]:
# Histogram of the distance to the closest same-barcode cell, across all slices
# and within a slice. Distances are converted from mm to um and clipped at
# 500 um, so the last bin also collects everything further away.
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
ax.hist(
    np.clip(ara_3d_distance * 1000, 0, 500),
    bins=np.arange(0, 501, 5),
    label="All slices",
    alpha=0.5,
    cumulative=False,
)
ax.hist(
    np.clip(ara_within_distance * 1000, 0, 500),
    bins=np.arange(0, 501, 5),
    alpha=0.5,
    label="Within slice",
    cumulative=False,
)
ax.set_xlabel("Distance between cells with same barcode (um)")
ax.set_ylabel("Number of cells")
ax.legend()
# Fixed y range: taller bars are cut off.
ax.set_ylim(0, 500)
ax.set_xlim(0, 500)

In [None]:
# Now look only at neighboring slices and project the 3d distance to the slice
from iss_analysis.registration import utils

# Maps from a source cell label to a list of [neighbour label(s), distance]
close_within_slice = dict()
close_around_slice = dict()
ara_2d_distance = []
ara_within_2d_distance = []
for barcode in tqdm(all_barcodes):
    cells_df = rabies_cell_properties[
        rabies_cell_properties["all_barcodes"].apply(lambda x: barcode in x)
    ]
    for (chamber, roi), cell_prop in cells_df.groupby(["chamber", "roi"]):
        # within first
        ara_coords = cell_prop[["ara_y_rot", "ara_z_rot"]].values
        # Non-finite coordinates are replaced by 0 (all such cells land on the origin).
        ara_coords[~np.isfinite(ara_coords)] = 0
        tree = KDTree(ara_coords)
        # k=2: the first neighbour of each point is the point itself
        dist, idx = tree.query(ara_coords, k=2)
        ara_within_2d_distance.extend(dist[:, 1])
        # Closer than 30 um (coordinates are in mm)
        closeby = np.where(dist[:, 1] < 30 / 1000)[0]
        for c in closeby:
            csource = cell_prop.index[c]
            if csource not in close_within_slice:
                close_within_slice[csource] = []
            # NOTE(review): `idx[c]` holds two positions (the cell itself and its
            # neighbour), so two labels are stored here next to a single
            # distance; `idx[c, 1]` was presumably intended -- confirm.
            close_within_slice[csource].append([cell_prop.index[idx[c]], dist[c, 1]])

        # now to surrounding slices
        # NOTE: this lookup depends only on (chamber, roi) but is repeated for
        # every barcode.
        surrounding_rois = utils.get_surrounding_slices(
            chamber, roi, project, mouse, include_ref=False
        )
        surrounding_slices = (
            surrounding_rois.chamber
            + "_"
            + surrounding_rois.roi.map(lambda x: f"{x:02d}")
        )
        surrounding_cells = cells_df[cells_df["slice"].isin(surrounding_slices)]
        surrounding_coords = surrounding_cells[["ara_y_rot", "ara_z_rot"]].values
        surrounding_coords[~np.isfinite(surrounding_coords)] = 0
        tree = KDTree(surrounding_coords)
        # k=1: the query points are not in this tree, so the first hit is a real neighbour
        dist, idx = tree.query(ara_coords, k=1)
        ara_2d_distance.extend(dist)
        closeby = np.where(dist < 30 / 1000)[0]
        for c in closeby:
            csource = cell_prop.index[c]
            if csource not in close_around_slice:
                close_around_slice[csource] = []
            close_around_slice[csource].append(
                [surrounding_cells.index[idx[c]], dist[c]]
            )

In [None]:
# Step histograms of the in-plane nearest-neighbour distances; values are scaled
# by 1000 for the um axis and everything above 200 is clipped into the last bin.
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
params = dict(
    bins=np.arange(0, 201, 3), alpha=1, cumulative=False, lw=2, histtype="step"
)
for distances, series_label in (
    (ara_2d_distance, "Neighbouring slices"),
    (ara_within_2d_distance, "Within slice"),
):
    ax.hist(np.clip(np.array(distances) * 1000, 0, 200), label=series_label, **params)
ax.set_xlabel("Distance between cells with same barcode (um)")
ax.set_ylabel("Number of cells")
ax.legend()
ax.set_ylim(0, 150)

In [None]:
# number of cells that have at least one close same-barcode neighbour
n_close_within = len(close_within_slice)
n_close_around = len(close_around_slice)
print(f"{n_close_within} cells with close neighbors within slice")
print(f"{n_close_around} cells with close neighbors around slice")

In [None]:
# list the close pairs whose source cell id contains "chamber_09_3"
for cell_name, close_neighbours in close_around_slice.items():
    if "chamber_09_3" in cell_name:
        print(cell_name, close_neighbours)

In [None]:
# first cell id recorded in close_around_slice
list(close_around_slice)[0]

# Plot example close cells

Find an example of cells that are close to each other and plot them.
We want the rotated ARA coordinates of the spots and the mask of the first cell, and the same for the second cell,
as well as the raw slice for both cells in a window around the cell.

In [None]:
example_cell = "chamber_09_3_14937"

# take the first close neighbour recorded for the example cell
close_pairs = close_around_slice[example_cell]
neighbour, dst = close_pairs[0]
print(f"Cell {example_cell} is {dst * 1000:.2f} um away from {neighbour}")

# properties of both cells, source first
cells = rabies_cell_properties.loc[[example_cell, neighbour]]
cells

In [None]:
# Plot the example pair: rotated ARA coordinates of both slices overlaid in the
# top-left panel, and each cell in the raw coordinates of its own ROI in the
# bottom row.
from iss_analysis import vis
from functools import partial

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
ax_ara = axes[0, 0]
# window size in um around the source to plot background
window = 200

# [-w, +w] offsets; divided by 1000 because the ara_*_rot columns are multiplied
# by 1000 below to be displayed in um
window = np.array([-1, 1]) * window / 1000
center = cells.iloc[0][["ara_y_rot", "ara_z_rot"]].values

# restrict a dataframe to the window around the source cell
# (presumably a rectangular crop on xcol/ycol - see vis.get_spot_part)
_get_spot = partial(
    vis.get_spot_part,
    xlim=center[0] + window,
    ylim=center[1] + window,
    xcol="ara_y_rot",
    ycol="ara_z_rot",
)

labels = ["Source", "Neighbour"]
# plt.get_cmap instead of cm.get_cmap: `cm` is not imported at the top of the
# notebook and matplotlib.cm.get_cmap was removed in Matplotlib 3.9
colors = plt.get_cmap("Set2").colors[3:]
bg_colors = ["lightblue", "lightgrey"]
for ic, (cname, cell) in enumerate(cells.iterrows()):
    c, r = cell.chamber, cell.roi
    # all cells of this ROI that fall in the window
    bg_cells = rabies_cell_properties[
        (rabies_cell_properties.chamber == c) & (rabies_cell_properties.roi == r)
    ]
    bg_cells = _get_spot(bg_cells)
    ax_ara.scatter(
        bg_cells["ara_y_rot"] * 1000,
        bg_cells["ara_z_rot"] * 1000,
        color=bg_colors[ic],
        s=100,
        label=f"{c} roi {r}",
        zorder=0,
        alpha=0.5,
    )

    # all spots of this ROI that fall in the window
    in_roi = (rab_spot_df.chamber == c) & (rab_spot_df.roi == r)
    bg_spots = _get_spot(rab_spot_df[in_roi])
    ax_ara.scatter(
        bg_spots["ara_y_rot"] * 1000,
        bg_spots["ara_z_rot"] * 1000,
        color=bg_colors[ic],
        s=10,
        zorder=0,
        alpha=0.5,
    )

    ax_ara.scatter(
        cell["ara_y_rot"] * 1000,
        cell["ara_z_rot"] * 1000,
        color=colors[ic],
        label=f"{cname}",
        s=100,
    )
    # find spots of this cell; also restricted to the cell's own ROI so that a
    # mask id reused in another chamber/ROI cannot contribute spots
    spots = rab_spot_df[in_roi & (rab_spot_df["cell_mask"] == cell.cell_id)]
    ax_ara.scatter(
        spots["ara_y_rot"] * 1000,
        spots["ara_z_rot"] * 1000,
        color=colors[ic],
        ec="k",
        label=f"{labels[ic]} spots",
        s=10,
    )

    # raw slice coordinates of the same cells and spots
    ax = axes[1, ic]
    ax.scatter(
        bg_cells.x,
        bg_cells.y,
        c=bg_cells.cell_id % 20,
        cmap="tab20",
        vmin=0,
        vmax=19,
        s=100,
        alpha=0.5,
    )
    ax.scatter(
        bg_spots.x,
        bg_spots.y,
        c=bg_spots.index % 20,
        s=10,
        alpha=0.5,
        cmap="tab20",
        vmin=0,
        vmax=19,
    )
    ax.scatter(cell["x"], cell["y"], color=colors[ic], s=100, marker="x")

for x in axes.flatten():
    x.axis("off")
ax_ara.legend(loc="upper right", ncol=2)
fig.tight_layout()

In [None]:
# Give every distinct corrected barcode an integer id, in order of first
# appearance. A dict lookup replaces `barcodes.index(x)`, which rescanned the
# whole list for every spot.
barcodes = list(rab_spot_df.corrected_bases.unique())
barcode_to_index = {bc: ibc for ibc, bc in enumerate(barcodes)}
rab_spot_df["bc_index"] = rab_spot_df.corrected_bases.map(barcode_to_index)
rab_spot_df.bc_index

In [None]:
# Plot both slices side by side in rotated ARA coordinates, with the spots
# coloured by barcode identity (`bc_index` modulo the 20 colours of tab20).
from iss_analysis import vis
from functools import partial

fig, axes = plt.subplots(1, 2, figsize=(20, 10))

# window size in um around the source to plot background
window = 300

# [-w, +w] offsets; divided by 1000 because the ara_*_rot columns are multiplied
# by 1000 below to be displayed in um
window = np.array([-1, 1]) * window / 1000
# NOTE(review): the window is centred on the second cell of `cells` shifted by
# 0.4 along both axes; the offset looks like a manual choice - confirm.
center = cells.iloc[1][["ara_y_rot", "ara_z_rot"]].values + 0.4

# restrict a dataframe to the window
# (presumably a rectangular crop on xcol/ycol - see vis.get_spot_part)
_get_spot = partial(
    vis.get_spot_part,
    xlim=center[0] + window,
    ylim=center[1] + window,
    xcol="ara_y_rot",
    ycol="ara_z_rot",
)

bg_colors = ["lightblue", "lightgrey"]
for ic, (cname, cell) in enumerate(cells.iterrows()):
    c, r = cell.chamber, cell.roi
    # cells of this ROI inside the window
    bg_cells = rabies_cell_properties[
        (rabies_cell_properties.chamber == c) & (rabies_cell_properties.roi == r)
    ]
    bg_cells = _get_spot(bg_cells)
    axes[ic].scatter(
        bg_cells["ara_y_rot"] * 1000,
        bg_cells["ara_z_rot"] * 1000,
        color=bg_colors[ic],
        s=100,
        label=f"{c} roi {r}",
        zorder=0,
        alpha=0.1,
    )

    # spots of this ROI inside the window
    bg_spots = rab_spot_df[(rab_spot_df.chamber == c) & (rab_spot_df.roi == r)]
    bg_spots = _get_spot(bg_spots)
    axes[ic].scatter(
        bg_spots["ara_y_rot"] * 1000,
        bg_spots["ara_z_rot"] * 1000,
        c=bg_spots.bc_index % 20,
        cmap="tab20",
        vmin=0,
        vmax=19,
        s=20,
        zorder=0,
        alpha=1,
    )


for x in axes.flatten():
    x.axis("on")
    x.grid(True)
    # legend on the axes of this figure (the labelled artists live here, not on
    # `ax_ara`, which belongs to the previous figure)
    x.legend(loc="upper right")
fig.tight_layout()

# Register ara data

The aim is to register slices to each other, but this is hard to do directly. So we first
register to the ARA, rotate the ARA to align with the slicing plane, and only then
register the slices to each other. `rab_spot_df` already has the `ara_x/y/z_rot` information.

We will register locally, around one cell of interest. First we subselect the data
in a window around that cell, then we register the two slices together.
For that we get the barcodes present in both slices, keep the
ones that have enough dots, and make a blurred version of the mask to register the two.


In [None]:
from iss_analysis.registration import ara_registration
from iss_analysis.registration import utils

# add the rotated ARA coordinates to the rabies cell properties
transform = ara_registration.get_ara_to_slice_rotation_matrix(spot_df=rab_spot_df)
rabies_cell_properties = ara_registration.rotate_ara_coordinate_to_slice(
    rabies_cell_properties, transform=transform
)

# arbitrary cell (row 200) used as the centre of the local registration
cell_of_interest = rabies_cell_properties.iloc[200]
ref_chamber = cell_of_interest.chamber
ref_roi = cell_of_interest.roi

# keep only the first two rows of the surrounding slices
# (reference and previous slice, according to the original comment)
surrounding_rois = utils.get_surrounding_slices(
    ref_chamber, ref_roi, project, mouse, include_ref=True
).iloc[:2]

In [None]:
from iss_analysis.registration import register_serial_sections


def _slice_name(roi_row):
    """Build the `<chamber>_<roi, two digits>` name of one row of `surrounding_rois`."""
    return roi_row.chamber + "_" + f"{roi_row.roi:02d}"


# first row is used as reference, second row as target
ref_slice = _slice_name(surrounding_rois.iloc[0])
target_slice = _slice_name(surrounding_rois.iloc[1])

# local registration of the two slices around the cell of interest
registration_kwargs = dict(
    spot_df=rab_spot_df,
    ref_slice=ref_slice,
    target_slice=target_slice,
    center_point=cell_of_interest[["ara_y_rot", "ara_z_rot"]].values,
    window_size=250,
    min_spots=5,
    max_barcode_number=500,
    gaussian_width=30,
    verbose=False,
)
shifts, maxcorr, phase_corrs, spot_images, barcodes = (
    register_serial_sections.register_local_spots(**registration_kwargs)
)

In [None]:
# Loop on all cells and run the registration
# Runs register_local_spots once per cell of the reference ROI, centred on that
# cell's (ara_y_rot, ara_z_rot) position, serially or in a multiprocessing Pool.
from tqdm import tqdm
from functools import partial
from multiprocessing import Pool

verbose = True
n_workers = 4
cells_in_ref = rabies_cell_properties.query(
    f"chamber == '{ref_chamber}' and roi == {ref_roi}"
)
# NOTE(review): `shifts` and `maxcorrs` are allocated here but never filled in
# this cell; the registration outputs end up in `assignment_by_bc` only.
shifts = np.zeros((len(cells_in_ref), 2))
maxcorrs = np.zeros(len(cells_in_ref))

# Same registration parameters as the single-cell call, with everything but the
# centre point bound in advance.
reg_one_cell = partial(
    register_serial_sections.register_local_spots,
    spot_df=rab_spot_df,
    ref_slice=ref_slice,
    target_slice=target_slice,
    window_size=250,
    min_spots=5,
    max_barcode_number=500,
    gaussian_width=30,
    verbose=False,
    debug=False,
)
cell_coords = cells_in_ref[["ara_y_rot", "ara_z_rot"]].values
# NOTE(review): map/imap pass each coordinate pair as the first POSITIONAL
# argument. This only works if `center_point` is the first parameter of
# register_local_spots; if `spot_df` comes first it collides with the keyword
# bound above - verify against the function signature.
# NOTE(review): the name `assignment_by_bc` looks inherited from other code; it
# holds one register_local_spots result per cell.
if n_workers == 1:
    assignment_by_bc = list(map(reg_one_cell, cell_coords))
else:
    if verbose:
        print(f"Registering {len(cell_coords)} cells using {n_workers} workers")
    with Pool(n_workers) as pool:
        # imap keeps the results in the order of cell_coords
        assignment_by_bc = list(
            tqdm(
                pool.imap(reg_one_cell, cell_coords),
                total=len(cell_coords),
            )
        )

In [None]:
window = 2000

cell_pos = cell_of_interest[["ara_y_rot", "ara_z_rot"]].values.astype(float)
# one [low, high] row per coordinate (y then z)
win_around = cell_pos[:, None] + np.array([-1, 1]) * window / 1000

print(f"Cropping around {np.round(cell_pos,2)} with window of {window}um")

# spots and barcode sets of each surrounding ROI, restricted to the window
barcodes_by_roi = []
spots_by_roi = []
for _, rdf in surrounding_rois.iterrows():
    spots = rab_spot_df.query(f"chamber == '{rdf.chamber}' and roi == {rdf.roi}")
    for coord, w in zip("yz", win_around):
        spots = spots.query(f"ara_{coord}_rot >= {w[0]} and ara_{coord}_rot <= {w[1]}")
    spots_by_roi.append(spots)
    barcodes_by_roi.append(set(spots.corrected_bases.unique()))

# `spots` is the selection of the last ROI of the loop
print(f"Found {len(spots)} spots in the surrounding slice")

# barcodes seen in both of the first two ROIs
barcodes = barcodes_by_roi[0] & barcodes_by_roi[1]
print(
    f"Found {len(barcodes)} barcodes in common (intersection of {len(barcodes_by_roi[0])} and {len(barcodes_by_roi[1])})"
)

In [None]:
# Keep the shared barcodes that have more than 5 spots in every slice
spots = pd.concat(spots_by_roi).query("corrected_bases in @barcodes")
# slices x barcodes table of spot counts (0 where a barcode is absent)
spot_counts = spots.groupby(["slice", "corrected_bases"]).size()
bc_per_roi = spot_counts.unstack().fillna(0)
# rank barcodes by their count in the slice where they are rarest
best_barcodes = bc_per_roi.min(axis=0).sort_values(ascending=False)
best_barcodes = best_barcodes.loc[best_barcodes > 5]
# not the last expression of the cell, so this table is evaluated but not shown
bc_per_roi[best_barcodes.index]

spots = spots.query("corrected_bases in @best_barcodes.index")
print(
    f"Found {len(spots)} spots in the pair of slices with the selected {len(best_barcodes)} barcodes"
)

In [None]:
# make a spot image for each barcode
# Output: spot_images[barcode, slice] is a blurred image of the spots of one
# barcode in one slice, on a common grid of 1 pixel per 1/1000 coordinate unit.
from tqdm import tqdm
from iss_preprocess.segment.spots import make_spot_image

gaussian_width = 30

# common frame for all images: lower corner of the selected spots, and upper
# corner padded by (1 + 20 * gaussian_width) pixels
origin = np.array([spots.ara_y_rot.min(), spots.ara_z_rot.min()])
corner = (
    np.array([spots.ara_y_rot.max(), spots.ara_z_rot.max()])
    + (1 + gaussian_width * 20) / 1000
)
# NOTE(review): output_shape is ordered (ara_y extent, ara_z extent) while the
# spots are passed with x=ara_y, y=ara_z - confirm that this matches the axis
# order make_spot_image expects.
output_shape = ((corner - origin) * 1000).astype(int)

# zeros rather than np.empty, so that a (barcode, slice) pair without spots can
# never expose uninitialised memory
spot_images = np.zeros((len(best_barcodes), 2, *output_shape), dtype="single")
for ibc, bc in tqdm(enumerate(best_barcodes.index), total=len(best_barcodes)):
    bc_df = spots[spots["corrected_bases"] == bc]
    # `slice_name` rather than `slice`, to avoid shadowing the builtin for the
    # rest of the notebook.
    # NOTE(review): groupby sorts its keys, so islice follows the sorted slice
    # names, not necessarily the (ref_slice, target_slice) order - confirm.
    for islice, (slice_name, slice_df) in enumerate(bc_df.groupby("slice")):
        # rename to x, y for make_spot_image
        sp = pd.DataFrame(
            slice_df[["ara_y_rot", "ara_z_rot"]].values - origin, columns=["x", "y"]
        )
        sp *= 1000
        img = make_spot_image(
            sp, gaussian_width=gaussian_width, dtype="single", output_shape=output_shape
        )
        spot_images[ibc, islice] = img

In [None]:
# Show the blurred spot image of each selected barcode, with the spots overlaid.
# Layout: 10 rows x 4 columns, i.e. two barcodes per row and, for each barcode,
# one column per slice.
sz = 2
max2plot = 20
fig, axes = plt.subplots(10, 4, figsize=(sz * 4, sz * 10))
# plt.get_cmap instead of cm.get_cmap: `cm` is not imported at the top of the
# notebook and matplotlib.cm.get_cmap was removed in Matplotlib 3.9
colors = plt.get_cmap("tab20").colors
for ibc, bc in enumerate(best_barcodes.index):
    if ibc >= max2plot:
        break
    bc_df = spots[spots["corrected_bases"] == bc]
    # `slice_name` rather than `slice`, to avoid shadowing the builtin
    for islice, (slice_name, slice_df) in enumerate(bc_df.groupby("slice")):
        ax = axes[ibc % 10, islice + ibc // 10 * 2]
        ax.imshow(
            spot_images[ibc, islice],
            cmap="Greys",
            # (left, right, bottom, top) of the image, scaled by 1000 like the spots
            extent=[
                origin[0] * 1000,
                corner[0] * 1000,
                corner[1] * 1000,
                origin[1] * 1000,
            ],
        )
        ax.scatter(
            slice_df.ara_y_rot * 1000,
            slice_df.ara_z_rot * 1000,
            s=5,
            alpha=0.3,
            color=colors[ibc % 20],
        )
for x in axes.flatten():
    x.axis("equal")
    x.set_xticks([])
    x.set_yticks([])
fig.tight_layout()

In [None]:
# Phase-correlate the two slice images of every selected barcode

from image_tools.registration.phase_correlation import phase_correlation

n_barcodes = len(best_barcodes)
# per barcode: estimated shift, peak correlation and the full correlation map
shifts = np.zeros((n_barcodes, 2))
max_corrs = np.zeros(n_barcodes)
phase_corrs = np.zeros((n_barcodes, *output_shape))
for ibc, image_pair in enumerate(tqdm(spot_images)):
    # NaNs are replaced by 0 before correlating
    ref, target = (np.nan_to_num(img) for img in image_pair)
    correlation_output = phase_correlation(ref, target, whiten=False)
    # only the first three outputs are kept
    shifts[ibc], max_corrs[ibc], phase_corrs[ibc] = correlation_output[:3]

In [None]:
# plot all the phase correlations
# One panel per barcode, at most `max2plot` and at most the 20 panels of the
# grid. zip stops at the shortest input, so fewer than 20 barcodes no longer
# raises an IndexError.
fig, axes = plt.subplots(4, 5, figsize=(10, 8))
for ax in axes.flatten():
    # hide the frame of every panel, including the ones left empty
    ax.axis("off")
for ax, phase_corr in zip(axes.flatten(), phase_corrs[:max2plot]):
    ax.imshow(phase_corr, cmap="viridis")
fig.tight_layout()

In [None]:
sum_corr = phase_corrs.sum(axis=0)
# find the max and the corresponding shift
maxcorr = np.max(sum_corr)
argmax = np.array(np.unravel_index(np.argmax(sum_corr), sum_corr.shape))
# shift is relative to center of image
shifts = argmax - np.array(sum_corr.shape) // 2
print(f"Max correlation: {maxcorr} at shift {shifts}")

In [None]:
fig = plt.figure(figsize=(10, 5))

ax = fig.add_subplot(2, 2, 1)
ax.imshow(
    phase_corrs.sum(axis=0).T,
    cmap="viridis",
    extent=(
        -output_shape[1] // 2,
        output_shape[1] // 2,
        -output_shape[0] // 2,
        output_shape[0] // 2,
    ),
    origin="lower",
)
ax.scatter(*shifts, color="red", marker="x")
ax.set_title("Sum of phase correlations")


ax = fig.add_subplot(2, 2, 3)
ax.scatter(shifts[0], shifts[1], s=500, color="red", marker="x")
ax.scatter(shifts[:, 0], shifts[:, 1], s=max_corrs / max_corrs.max() * 10, alpha=0.5)
ax.set_aspect("equal")
ax.set_xlabel("Y shift (px)")
ax.set_ylabel("Z shift (px)")
ax.set_title(r"Shifts for each pair - dot size $\alpha$ corr. coeff.")

for i in range(2):
    sh = shifts[:, i]
    ax = plt.subplot(2, 2, 2 + i * 2)
    ax.hist(sh, bins=np.arange(sh.min(), sh.max(), 10), alpha=0.5)
    ax.set_xlabel(f"{['Y', 'Z'][i]} shifts")
    ax.set_ylabel("Number of barcodes")
    ax.axvline(shifts[i], color="red", label="Max correlation")
fig.tight_layout()

In [None]:
# for each spot_image pair, find the argmax of the ref
argmaxes = np.zeros((len(best_barcodes), 2))
for ipair, img_pair in enumerate(spot_images):
    argmaxes[ipair] = np.unravel_index(np.argmax(img_pair[0]), img_pair[0].shape)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for i in range(2):
    sc = axes[i].scatter(
        argmaxes[:, 0],
        argmaxes[:, 1],
        c=shifts[:, i],
        cmap="RdBu",
        vmin=shifts[i] - 100,
        vmax=shifts[i] + 100,
    )
    axes[i].set_aspect("equal")
    cb = fig.colorbar(sc, ax=axes[i])
    cb.ax.axhline(shifts[i], color="red", label="Max correlation")
    axes[i].set_xlabel("Y max of spot image (um)")
    axes[i].set_ylabel("Z max of spot image (um)")
    axes[i].set_title(f'Shift {["Y", "Z"][i]}')