In [None]:
# Enable IPython's autoreload extension in mode 2 so that edited modules are
# reloaded before each cell runs (no kernel restart needed while developing).
%load_ext autoreload
%autoreload 2

# Find duplicates across ROIs

Look for cells that are detected on multiple consecutive ROIs.

In [None]:
import iss_preprocess as issp
import iss_analysis as issa
from iss_analysis.barcodes import barcodes as bar
from iss_analysis.barcodes.diagnostics import (
    plot_gmm_clusters,
    plot_error_along_sequence,
)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Identifiers of the dataset analysed in this notebook
project = "becalia_rabies_barseq"
mouse = "BRAC8498.3e"

# Name of the error-corrected barcode dataset (prefixed with the mouse name)
error_correction_ds_name = f"{mouse}_error_corrected_barcodes_16"
# "<project>/<mouse>" path handed to iss_preprocess to locate processed data
data_path = "/".join((project, mouse))
analysis_folder = issp.io.get_processed_path(data_path) / "analysis"

# Get rabies data

In [None]:
# Load the rabies spots, the per-cell barcode table and the cell properties.
# NOTE(review): valid_chambers=None / save_folder=None presumably mean "use all
# chambers" and "do not save" — confirm against get_barcode_in_cells.
barcode_outputs = issa.segment.get_barcode_in_cells(
    project,
    mouse,
    error_correction_ds_name,
    valid_chambers=None,
    save_folder=None,
    verbose=True,
)
rab_spot_df, rab_cells_barcodes, rabies_cell_properties = barcode_outputs

In [None]:
# find starter
# Currently disabled: the `if False` guard means nothing in this cell runs.
# When enabled, it reads the starter cell positions and matches them to the
# barcoded cells, replacing `rabies_cell_properties` with the matched table.
if False:
    starters_positions = issa.io.get_starter_cells(project, mouse)
    rabies_cell_properties = issa.segment.match_starter_to_barcodes(
        project,
        mouse,
        rabies_cell_properties,
        rab_spot_df,
        starters=starters_positions,
        redo=True,
    )
    rabies_cell_properties.head()

# Distance between cells

Look at the distance between rabies cells that share the same barcode within one ROI.

In [None]:
# Iterate on barcodes and on chamber/roi, use a KDTree to find the closest pair of
# cells with same barcode and build a full list
from scipy.spatial import KDTree

ncellbarcode = rabies_cell_properties.all_barcodes.apply(len).sum()
# Flat array with one nearest-neighbour distance per (cell, barcode) pair
within_slice_distance_px = np.zeros(ncellbarcode)
# Same distances stored per cell: one array aligned with the cell's `all_barcodes`
rabies_cell_properties["closest_distance_px"] = [
    np.zeros(len(x)) for x in rabies_cell_properties.all_barcodes
]
i = 0  # write position in `within_slice_distance_px`
for (roi, chamber), roi_df in rabies_cell_properties.groupby(["roi", "chamber"]):
    all_barcodes = set()
    for bc in roi_df["all_barcodes"].values:
        all_barcodes.update(bc)
    for bc in all_barcodes:
        df = roi_df[roi_df["all_barcodes"].apply(lambda x: bc in x)]
        # Position of `bc` inside each cell's barcode list. Converted to a plain
        # array so it is indexed by position: integer lookups on a label-indexed
        # Series rely on a positional fallback that recent pandas versions drop.
        bc_index = df.all_barcodes.apply(lambda x: x.index(bc)).to_numpy()
        if len(df) > 1:
            coords = df[["x", "y"]].values
            # k=2 because each query point is itself in the tree (distance 0);
            # the second column holds the distance to the closest other cell
            dist, idx = KDTree(coords).query(coords, k=2)
            nearest = dist[:, 1]
        else:
            # a single cell carries this barcode in the ROI: no neighbour
            nearest = np.full(1, np.nan)
        within_slice_distance_px[i : i + len(df)] = nearest
        for cell_name, pos, d in zip(df.index, bc_index, nearest):
            # the column stores arrays, so writing into the array updates the table
            rabies_cell_properties.at[cell_name, "closest_distance_px"][pos] = d
        i += len(df)

In [None]:
# Pick the chamber explicitly instead of relying on a loop variable left over
# from another cell.
# NOTE(review): one pixel size is applied to every cell — assumes all chambers
# share the same pixel size; confirm.
px_chamber = rabies_cell_properties["chamber"].iloc[0]
px_size = issp.io.get_pixel_size(f"{project}/{mouse}/{px_chamber}")
within_slice_distance_um = within_slice_distance_px * px_size

fig, axes = plt.subplots(2, 1, figsize=(6, 3))
# Top panel: 0-500 um histogram, plus a red marker counting everything beyond 500
ax = axes[0]
ax.set_title('Within slice distance between cells with same barcode')
twin_ax = ax.twinx()
ax.hist(within_slice_distance_um, bins=np.arange(0, 501, 1))
n_far = np.sum(within_slice_distance_um > 500)
twin_ax.scatter(501, n_far, color="red", label=">500um")
# keep a non-degenerate axis range when no distance exceeds 500 um
twin_ax.set_ylim(0, max(n_far, 1) * 1.3)
ax.set_xlim(0, 501)
twin_ax.legend(loc="upper right")
ax.set_ylabel("Number of cells")
# Bottom panel: zoom on the first 100 um
ax = axes[1]
ax.hist(within_slice_distance_um, bins=np.arange(0, 101, 1))
ax.set_ylabel("Number of cells")
ax.set_xlabel("Distance (um)")
fig.tight_layout()
plt.show()

In [None]:
# Smallest same-barcode distance of each cell (NaN entries ignored), in pixels
min_distance_px = rabies_cell_properties["closest_distance_px"].map(np.nanmin)
# 30 is divided by the pixel size to express the threshold in pixels
has_close_neighbor = min_distance_px < (30 / px_size)
closeby = rabies_cell_properties[has_close_neighbor]
print(f"Number of cells with a close neighbor: {len(closeby)}")

# Get ARA coordinates of spots

In [None]:
# Get transform to rotate ARA coordinates to slice
from iss_analysis.registration import ara_registration

# The rotation is estimated from the spots and then applied, with the same
# transform, to the cell table first and to the spot table second.
transform = ara_registration.get_ara_to_slice_rotation_matrix(rab_spot_df)
rabies_cell_properties, rab_spot_df = (
    ara_registration.rotate_ara_coordinate_to_slice(table, transform=transform)
    for table in (rabies_cell_properties, rab_spot_df)
)

In [None]:
from matplotlib import cm

fig, axes = plt.subplots(3, 3, figsize=(10, 6))
step = 0.01
# Histogram bins for the x, y and z ARA coordinates respectively
bins = [
    np.arange(5, 9.2, step),
    np.arange(step, 8, step * 10),
    np.arange(4, 12, step * 10),
]
# "<chamber>_<roi>" identifier (roi zero-padded to 2 digits), added to both tables
rab_spot_df["slice"] = (
    rab_spot_df["chamber"] + "_" + rab_spot_df["roi"].map(lambda x: f"{x:02d}")
)
rabies_cell_properties["slice"] = (
    rabies_cell_properties["chamber"]
    + "_"
    + rabies_cell_properties["roi"].map(lambda x: f"{x:02d}")
)
slices = sorted(rab_spot_df["slice"].unique())
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which newer Matplotlib removed
colors = plt.get_cmap("viridis", len(slices))
for islice, slice_name in enumerate(slices):
    cell_prop = rabies_cell_properties[rabies_cell_properties["slice"] == slice_name]
    # label every artist with the slice being drawn
    style = dict(label=slice_name, alpha=0.5, color=colors(islice))
    for iax, coord in enumerate("xyz"):
        raw = cell_prop[f"ara_{coord}"]
        rotated = cell_prop[f"ara_{coord}_rot"]
        # column 0: original coordinate, column 1: rotated, column 2: difference
        axes[iax, 0].hist(raw, histtype="step", bins=bins[iax], **style)
        axes[iax, 1].hist(rotated, histtype="step", bins=bins[iax], **style)
        axes[iax, 2].scatter(raw, rotated - raw, s=5, **style)

# Axis labels do not depend on the slice, so set them once
for iax, coord in enumerate("xyz"):
    for icol in range(3):
        axes[iax, icol].set_xlabel(f"ARA {coord}")
    for icol in range(2):
        axes[iax, icol].set_ylabel("Number of cells")
    axes[iax, 2].set_ylabel(f"Rotated {coord} - ARA {coord}")

axes[0, 0].set_title("ARA coordinates")
axes[0, 1].set_title("ARA coordinates rotated")
axes[0, 2].set_title("Difference")
fig.tight_layout()

In [None]:
# now calculate the 3d dimensions in the rotated ara coordinates
from tqdm import tqdm

def _nearest_other_cell_distance(table):
    """Return, for each row of `table`, the distance to its nearest other row.

    Distances are computed on the three rotated ARA coordinates. The tree is
    queried with k=2 because every point is its own first hit.
    NOTE(review): non-finite coordinates are replaced by 0, which places all
    such cells at the origin — confirm this is intended.
    """
    coords = table[["ara_x_rot", "ara_y_rot", "ara_z_rot"]].values
    coords[~np.isfinite(coords)] = 0
    neighbour_dist, _ = KDTree(coords).query(coords, k=2)
    return neighbour_dist[:, 1]


# Collect every barcode found in at least one cell
all_barcodes = set()
for cell_barcodes in rabies_cell_properties["all_barcodes"].values:
    all_barcodes.update(cell_barcodes)
ncellbarcode = rabies_cell_properties.all_barcodes.apply(len).sum()
# One entry per (cell, barcode) pair in both arrays
ara_3d_distance = np.zeros(ncellbarcode)
ara_within_distance = np.zeros(ncellbarcode)
offset = 0
for barcode in tqdm(all_barcodes):
    carries_barcode = rabies_cell_properties["all_barcodes"].apply(
        lambda x: barcode in x
    )
    df = rabies_cell_properties[carries_barcode]
    # nearest cell with this barcode anywhere in the brain
    ara_3d_distance[offset : offset + len(df)] = _nearest_other_cell_distance(df)
    # nearest cell with this barcode inside the same chamber/roi
    for _, cell_prop in df.groupby(["chamber", "roi"]):
        stop = offset + len(cell_prop)
        ara_within_distance[offset:stop] = _nearest_other_cell_distance(cell_prop)
        offset = stop

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
hist_kwargs = dict(bins=np.arange(0, 501, 5), alpha=0.5, cumulative=False)
for distances, label in (
    (ara_3d_distance, "All slices"),
    (ara_within_distance, "Within slice"),
):
    # scaled by 1000 and clipped at 500 so larger values pile up in the last bin
    ax.hist(np.clip(distances * 1000, 0, 500), label=label, **hist_kwargs)
ax.set_xlabel("Distance between cells with same barcode (um)")
ax.set_ylabel("Number of cells")
ax.legend()
ax.set_ylim(0, 500)
ax.set_xlim(0, 500)

In [None]:
# Now look only at neighboring slices and project the 3d distance to the slice
from iss_analysis.registration import utils

# Two cells closer than this are recorded as "close".
# NOTE(review): coordinates are multiplied by 1000 and labelled "um" when
# plotted, so this is presumably 30 um — confirm the unit.
close_threshold = 30 / 1000

# cell name -> list of [neighbour cell name, distance]
close_within_slice = dict()
close_around_slice = dict()
ara_2d_distance = []
ara_within_2d_distance = []
for barcode in tqdm(all_barcodes):
    cells_df = rabies_cell_properties[
        rabies_cell_properties["all_barcodes"].apply(lambda x: barcode in x)
    ]
    for (chamber, roi), cell_prop in cells_df.groupby(["chamber", "roi"]):
        # within first
        # Only the two in-plane coordinates are used: distances are projected
        # on the slice. Non-finite coordinates are set to 0.
        ara_coords = cell_prop[["ara_y_rot", "ara_z_rot"]].values
        ara_coords[~np.isfinite(ara_coords)] = 0
        tree = KDTree(ara_coords)
        # k=2: column 0 is the cell itself, column 1 its nearest other cell
        dist, idx = tree.query(ara_coords, k=2)
        ara_within_2d_distance.extend(dist[:, 1])
        for c in np.flatnonzero(dist[:, 1] < close_threshold):
            # idx[c, 1] (not idx[c]) so that only the neighbour is stored,
            # in the same [name, distance] format as close_around_slice
            close_within_slice.setdefault(cell_prop.index[c], []).append(
                [cell_prop.index[idx[c, 1]], dist[c, 1]]
            )

        # now to surrounding slices
        surrounding_rois = utils.get_surrounding_slices(
            chamber, roi, project, mouse, include_ref=False
        )
        surrounding_slices = (
            surrounding_rois.chamber
            + "_"
            + surrounding_rois.roi.map(lambda x: f"{x:02d}")
        )
        surrounding_cells = cells_df[cells_df["slice"].isin(surrounding_slices)]
        surrounding_coords = surrounding_cells[["ara_y_rot", "ara_z_rot"]].values
        surrounding_coords[~np.isfinite(surrounding_coords)] = 0
        tree = KDTree(surrounding_coords)
        # k=1: cells of this slice are not in the tree, the first hit is a neighbour
        dist, idx = tree.query(ara_coords, k=1)
        ara_2d_distance.extend(dist)
        for c in np.flatnonzero(dist < close_threshold):
            close_around_slice.setdefault(cell_prop.index[c], []).append(
                [surrounding_cells.index[idx[c]], dist[c]]
            )

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
params = dict(
    bins=np.arange(0, 201, 3), alpha=1, cumulative=False, lw=2, histtype="step"
)
for distances, label in (
    (ara_2d_distance, "Neighbouring slices"),
    (ara_within_2d_distance, "Within slice"),
):
    # scaled by 1000 and clipped at 200 so larger values land in the last bin
    ax.hist(np.clip(np.array(distances) * 1000, 0, 200), label=label, **params)
ax.set_xlabel("Distance between cells with same barcode (um)")
ax.set_ylabel("Number of cells")
ax.legend()
ax.set_ylim(0, 150)

In [None]:
# Number of cells with at least one close same-barcode cell, per comparison
for where, close_cells in (
    ("within", close_within_slice),
    ("around", close_around_slice),
):
    print(f"{len(close_cells)} cells with close neighbors {where} slice")

In [None]:
# Show the neighbouring-slice matches of cells whose name contains "chamber_09_3"
for cell_name, matches in close_around_slice.items():
    if "chamber_09_3" in cell_name:
        print(cell_name, matches)

In [None]:
# Name of the first cell recorded in close_around_slice
list(close_around_slice)[0]

# Plot example close cells

Find an example of cells that are close to each other and plot them.
We want the rotated ARA coordinates of the spots and the mask of the first cell, and the same for the second,
as well as the raw slice for both cells in a window around each cell.

In [None]:
example_cell = "chamber_09_3_14937"

# First neighbouring-slice match recorded for this cell: [cell name, distance]
first_match = close_around_slice[example_cell][0]
neighbour, dst = first_match
print(f"Cell {example_cell} is {dst * 1000:.2f} um away from {neighbour}")

# Two-row table: the example cell first, its neighbour second
cells = rabies_cell_properties.loc[[example_cell, neighbour]]
cells

In [None]:
# plot the ara part
from iss_analysis import vis
from functools import partial

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
ax_ara = axes[0, 0]
# window size in um around the source to plot background
window = 200

# [-w, +w] offsets, divided by 1000 to match the unit of the ara_*_rot columns
window = np.array([-1, 1]) * window / 1000
# the window is centred on the first cell (the source)
center = cells.iloc[0][["ara_y_rot", "ara_z_rot"]].values

# Crop a table to the window, using the rotated ARA y/z columns as x/y
_get_spot = partial(
    vis.get_spot_part,
    xlim=center[0] + window,
    ylim=center[1] + window,
    xcol="ara_y_rot",
    ycol="ara_z_rot",
)

labels = ["Source", "Neighbour"]
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which newer Matplotlib removed
colors = plt.get_cmap("Set2").colors[3:]
bg_colors = ["lightblue", "lightgrey"]
for ic, (cname, cell) in enumerate(cells.iterrows()):
    c, r = cell.chamber, cell.roi

    # --- top-left panel: both slices overlaid in rotated ARA coordinates ---
    bg_cells = rabies_cell_properties[
        (rabies_cell_properties.chamber == c) & (rabies_cell_properties.roi == r)
    ]
    bg_cells = _get_spot(bg_cells)
    ax_ara.scatter(
        bg_cells["ara_y_rot"] * 1000,
        bg_cells["ara_z_rot"] * 1000,
        color=bg_colors[ic],
        s=100,
        label=f"{c} roi {r}",
        zorder=0,
        alpha=0.5,
    )

    in_roi = (rab_spot_df.chamber == c) & (rab_spot_df.roi == r)
    bg_spots = _get_spot(rab_spot_df[in_roi])
    ax_ara.scatter(
        bg_spots["ara_y_rot"] * 1000,
        bg_spots["ara_z_rot"] * 1000,
        color=bg_colors[ic],
        s=10,
        zorder=0,
        alpha=0.5,
    )

    ax_ara.scatter(
        cell["ara_y_rot"] * 1000,
        cell["ara_z_rot"] * 1000,
        color=colors[ic],
        label=f"{cname}",
        s=100,
    )
    # find spots of this cell: the mask id is matched inside the cell's own
    # chamber/roi so that the same id used in another ROI cannot contribute
    spots = rab_spot_df[in_roi & (rab_spot_df["cell_mask"] == cell.cell_id)]
    ax_ara.scatter(
        spots["ara_y_rot"] * 1000,
        spots["ara_z_rot"] * 1000,
        color=colors[ic],
        ec="k",
        label=f"{labels[ic]} spots",
        s=10,
    )

    # --- bottom row: one panel per cell, in the x/y columns of the tables ---
    ax = axes[1, ic]
    ax.scatter(
        bg_cells.x,
        bg_cells.y,
        c=bg_cells.cell_id % 20,
        cmap='tab20',
        vmin=0,
        vmax=19,
        s=100,
        alpha=0.5,
    )
    ax.scatter(
        bg_spots.x,
        bg_spots.y,
        c=bg_spots.index % 20,
        s=10,
        alpha=0.5,
        cmap='tab20',
        vmin=0,
        vmax=19,
    )
    ax.scatter(cell["x"], cell["y"], color=colors[ic], s=100, marker='x')

for x in axes.flatten():
    x.axis('off')
ax_ara.legend(loc="upper right", ncol=2)
fig.tight_layout()

In [None]:
# Number the distinct corrected barcodes in order of first appearance. A dict
# lookup replaces a list search per spot, which scaled with spots x barcodes.
barcodes = list(rab_spot_df.corrected_bases.unique())
barcode_to_index = {bc: ibc for ibc, bc in enumerate(barcodes)}
rab_spot_df['bc_index'] = rab_spot_df.corrected_bases.map(barcode_to_index)
rab_spot_df.bc_index

In [None]:
# plot the ara part
from iss_analysis import vis
from functools import partial

fig, axes = plt.subplots(1, 2, figsize=(20, 10))

# window size in um around the source to plot background
window = 300

# [-w, +w] offsets, divided by 1000 to match the unit of the ara_*_rot columns
window = np.array([-1, 1]) * window / 1000
# NOTE(review): the window is centred on the second cell of `cells` shifted by a
# hard-coded 0.4 along both axes — confirm this offset is intended.
center = cells.iloc[1][["ara_y_rot", "ara_z_rot"]].values + 0.4

# Crop a table to the window, using the rotated ARA y/z columns as x/y
_get_spot = partial(
    vis.get_spot_part,
    xlim=center[0] + window,
    ylim=center[1] + window,
    xcol="ara_y_rot",
    ycol="ara_z_rot",
)

bg_colors = ["lightblue", "lightgrey"]
# One panel per cell of the pair, showing its own chamber/roi
for ic, (cname, cell) in enumerate(cells.iterrows()):
    c, r = cell.chamber, cell.roi
    bg_cells = rabies_cell_properties[
        (rabies_cell_properties.chamber == c) & (rabies_cell_properties.roi == r)
    ]
    bg_cells = _get_spot(bg_cells)
    axes[ic].scatter(
        bg_cells["ara_y_rot"] * 1000,
        bg_cells["ara_z_rot"] * 1000,
        color=bg_colors[ic],
        s=100,
        label=f"{c} roi {r}",
        zorder=0,
        alpha=0.1,
    )

    bg_spots = rab_spot_df[(rab_spot_df.chamber == c) & (rab_spot_df.roi == r)]
    bg_spots = _get_spot(bg_spots)
    # spots are coloured by barcode index, cycling through the 20 tab20 colours
    axes[ic].scatter(
        bg_spots["ara_y_rot"] * 1000,
        bg_spots["ara_z_rot"] * 1000,
        c=bg_spots.bc_index % 20,
        cmap='tab20',
        vmin=0,
        vmax=19,
        s=20,
        zorder=0,
        alpha=1,
    )


for x in axes.flatten():
    x.axis('on')
    x.grid(True)
    # legend on the axes of this figure
    x.legend(loc="upper right")
fig.tight_layout()

# Register ara data

The aim is to register slices together, but this is hard to do directly. We therefore first
register to the ARA, rotate the ARA to align with the slicing plane, and only then
register the slices. `rab_spot_df` already has the `ara_x/y/z_rot` information.

We will register locally, around one cell of interest. First we subselect the data
in a window around that cell, then we register the two slices together.
For that we get the barcodes present in both slices, keep the
ones that have enough spots, and make a blurred image of their spots to register the two.


In [None]:
# Arbitrary example: the cell at row 200 of the cell table
cell_of_interest = rabies_cell_properties.iloc[200]

ref_chamber = cell_of_interest.chamber
ref_roi = cell_of_interest.roi
# NOTE(review): include_ref=True presumably keeps the reference slice itself in
# the returned table — confirm against get_surrounding_slices.
surrounding_rois = utils.get_surrounding_slices(
    ref_chamber, ref_roi, project, mouse, include_ref=True
)

# we will look only at the previous slice for now
surrounding_rois = surrounding_rois.head(2)



In [None]:
from iss_analysis.registration import register_serial_sections
# "<chamber>_<roi>" names (roi zero-padded to 2 digits) of the first two rows
ref_slice, target_slice = (
    f"{roi_row.chamber}_{roi_row.roi:02d}"
    for roi_row in (surrounding_rois.iloc[0], surrounding_rois.iloc[1])
)
registration = register_serial_sections.register_local_spots(
    spot_df=rab_spot_df,
    ref_slice=ref_slice,
    target_slice=target_slice,
    center_point=cell_of_interest[["ara_y_rot", "ara_z_rot"]].values,
    window_size=250,
    min_spots=5,
    max_barcode_number=500,
    gaussian_width=30,
    verbose=True,
)
shift, maxcorr, phase_corrs, spot_images, barcodes = registration

In [None]:
window = 2000

cell_pos = cell_of_interest[["ara_y_rot", "ara_z_rot"]].values.astype(float)
# One row per coordinate (y, z): [position - window, position + window], with
# the window divided by 1000 to match the unit of the ara_*_rot columns
win_around = cell_pos[:, None] + (window / 1000) * np.array([-1, 1])

print(f"Cropping around {np.round(cell_pos,2)} with window of {window}um")

barcodes_by_roi = []
spots_by_roi = []
for _, rdf in surrounding_rois.iterrows():
    # spots of this chamber/roi ...
    in_roi = (rab_spot_df["chamber"] == rdf.chamber) & (rab_spot_df["roi"] == rdf.roi)
    spots = rab_spot_df[in_roi]
    # ... restricted to the window along y, then along z (bounds included)
    for coord, (low, high) in zip("yz", win_around):
        spots = spots[spots[f"ara_{coord}_rot"].between(low, high)]
    spots_by_roi.append(spots)
    barcodes_by_roi.append(set(spots.corrected_bases.unique()))

# `spots` is the table of the last slice processed by the loop
print(f"Found {len(spots)} spots in the surrounding slice")

barcodes = barcodes_by_roi[0] & barcodes_by_roi[1]
print(f"Found {len(barcodes)} barcodes in common (intersection of {len(barcodes_by_roi[0])} and {len(barcodes_by_roi[1])})") 


In [None]:
# select the barcodes that are present in both slices in large numbers
spots = pd.concat(spots_by_roi)
spots = spots[spots["corrected_bases"].isin(barcodes)]
# rows: slice, columns: barcode, values: number of spots (0 when absent)
bc_per_roi = (
    spots.groupby(["slice", "corrected_bases"])
    .size()
    .unstack("corrected_bases")
    .fillna(0)
)
# a barcode is scored by its spot count in the slice where it is rarest;
# only barcodes with more than 5 spots in every slice are kept
best_barcodes = (
    bc_per_roi.min(axis=0).sort_values(ascending=False).loc[lambda n: n > 5]
)

spots = spots[spots["corrected_bases"].isin(best_barcodes.index)]
print(f"Found {len(spots)} spots in the pair of slices with the selected {len(best_barcodes)} barcodes")

In [None]:
# make a spot image for each barcode
from tqdm import tqdm
from iss_preprocess.segment.spots import make_spot_image
gaussian_width = 30

yz_columns = ["ara_y_rot", "ara_z_rot"]
# Bounding box of the selected spots; the far corner is pushed out by a margin
# derived from the gaussian width
origin = spots[yz_columns].min().to_numpy()
corner = spots[yz_columns].max().to_numpy() + (1 + gaussian_width * 20) / 1000
# coordinates are multiplied by 1000 to give the image size
output_shape = ((corner - origin) * 1000).astype(int)

n_barcodes = len(best_barcodes)
# spot_images[barcode, slice] is one image. Slice index 0/1 follows the order
# in which groupby("slice") yields the groups.
spot_images = np.empty((n_barcodes, 2, *output_shape), dtype="single")
for ibc, bc in tqdm(enumerate(best_barcodes.index), total=n_barcodes):
    bc_df = spots[spots["corrected_bases"] == bc]
    for islice, (slice_name, slice_df) in enumerate(bc_df.groupby("slice")):
        # rename to x, y for make_spot_image, relative to the bounding-box origin
        sp = pd.DataFrame(
            (slice_df[yz_columns].values - origin) * 1000, columns=["x", "y"]
        )
        spot_images[ibc, islice] = make_spot_image(
            sp,
            gaussian_width=gaussian_width,
            dtype="single",
            output_shape=output_shape,
        )

In [None]:
sz = 2
max2plot = 20
# 10 rows x 4 columns: each barcode uses two neighbouring panels (one per
# slice); barcodes 0-9 fill the left pair of columns, 10-19 the right pair
fig, axes = plt.subplots(10, 4, figsize=(sz*4, sz*10))
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which newer Matplotlib removed
colors = plt.get_cmap("tab20").colors
# image extent, scaled by 1000 like the scattered spot coordinates
extent = [origin[0]*1000, corner[0]*1000, corner[1]*1000, origin[1]*1000]
for ibc, bc in enumerate(best_barcodes.index[:max2plot]):
    bc_df = spots[spots["corrected_bases"]==bc]
    row, first_col = ibc % 10, (ibc // 10) * 2
    for islice, (slice_name, slice_df) in enumerate(bc_df.groupby("slice")):
        panel = axes[row, first_col + islice]
        panel.imshow(spot_images[ibc, islice], cmap='Greys', extent=extent)
        panel.scatter(
            slice_df.ara_y_rot*1000,
            slice_df.ara_z_rot*1000,
            s=5,
            alpha=0.3,
            color=colors[ibc%20],
        )
for x in axes.flatten():
    x.axis('equal')
    x.set_xticks([])
    x.set_yticks([])
fig.tight_layout()

In [None]:
# do phase correlation for each pair

from image_tools.registration.phase_correlation import phase_correlation

n_barcodes = len(best_barcodes)
# Per barcode: estimated shift, peak correlation and full correlation map
shifts = np.zeros((n_barcodes, 2))
max_corrs = np.zeros(n_barcodes)
phase_corrs = np.zeros((n_barcodes, *output_shape))
for ibc in tqdm(range(n_barcodes)):
    # NaNs are replaced by 0 before correlating the two slices of this barcode
    ref, target = (np.nan_to_num(image) for image in spot_images[ibc])
    bc_shift, bc_max_corr, bc_phase_corr, _ = phase_correlation(
        ref, target, whiten=False
    )
    shifts[ibc] = bc_shift
    max_corrs[ibc] = bc_max_corr
    phase_corrs[ibc] = bc_phase_corr


In [None]:
# plot all the phase correlations
fig, axes = plt.subplots(4, 5, figsize=(10, 8))
# never index past the available barcodes or the available panels
n_plot = min(max2plot, len(phase_corrs), axes.size)
for panel, phase_corr in zip(axes.flat, phase_corrs[:n_plot]):
    panel.imshow(phase_corr, cmap='viridis')
# hide the frame of every panel, including those left empty
for panel in axes.flat:
    panel.axis('off')
fig.tight_layout()

In [None]:
# Pool the correlation maps of all barcodes and locate the peak of the sum
sum_corr = phase_corrs.sum(axis=0)
# find the max and the corresponding shift
maxcorr = sum_corr.max()
peak_flat_index = sum_corr.argmax()
argmax = np.array(np.unravel_index(peak_flat_index, sum_corr.shape))
# shift is relative to center of image
image_center = np.array(sum_corr.shape) // 2
shift = argmax - image_center
print(f"Max correlation: {maxcorr} at shift {shift}")

In [None]:
fig = plt.figure(figsize=(10, 5))

# --- summed correlation map with the selected shift ---
summed = phase_corrs.sum(axis=0)
n0, n1 = summed.shape
half0, half1 = n0 // 2, n1 // 2
ax = fig.add_subplot(2, 2, 1)
# The map is transposed, so its first axis runs along x. The extent places the
# centre of pixel k at k - n // 2, the convention used to compute `shift`, so
# the red cross lands on the peak even for non-square or odd-sized maps.
ax.imshow(
    summed.T,
    cmap='viridis',
    extent=(-half0 - 0.5, n0 - half0 - 0.5, -half1 - 0.5, n1 - half1 - 0.5),
    origin='lower',
)
ax.scatter(*shift, color='red', marker='x')
ax.set_title('Sum of phase correlations')

# --- per-barcode shifts, dot size scaled by the peak correlation ---
ax = fig.add_subplot(2, 2, 3)
ax.scatter(shift[0], shift[1], s=500, color='red', marker='x')
ax.scatter(shifts[:,0], shifts[:,1], s=max_corrs/max_corrs.max()*10, alpha=0.5)
ax.set_aspect('equal')
ax.set_xlabel('Y shift (px)')
ax.set_ylabel('Z shift (px)')
ax.set_title(r'Shifts for each pair - dot size $\alpha$ corr. coeff.')

# --- distribution of the per-barcode shifts along each axis ---
for i in range(2):
    sh = shifts[:,i]
    ax = fig.add_subplot(2, 2, 2 + i * 2)
    # extend the edges past the maximum: guarantees at least one bin and keeps
    # the largest shift inside the histogram
    ax.hist(sh, bins=np.arange(sh.min(), sh.max() + 10, 10), alpha=0.5)
    ax.set_xlabel(f"{['Y', 'Z'][i]} shifts")
    ax.set_ylabel("Number of barcodes")
    ax.axvline(shift[i], color='red', label='Max correlation')
    ax.legend()
fig.tight_layout()

In [None]:
# for each spot_image pair, find the argmax of the ref
# (the first image of each pair; stored as (index along axis 0, index along axis 1))
argmaxes = np.zeros((len(best_barcodes), 2))
for ipair, (ref_image, _target_image) in enumerate(spot_images):
    argmaxes[ipair] = np.unravel_index(ref_image.argmax(), ref_image.shape)

# One panel per shift component, coloured within +/-100 of the selected shift
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for i, (panel, axis_name) in enumerate(zip(axes, ("Y", "Z"))):
    sc = panel.scatter(
        argmaxes[:,0],
        argmaxes[:,1],
        c=shifts[:,i],
        cmap='RdBu',
        vmin=shift[i]-100,
        vmax=shift[i]+100,
    )
    panel.set_aspect('equal')
    cb = fig.colorbar(sc, ax=panel)
    # mark the selected shift on the colour bar
    cb.ax.axhline(shift[i], color='red', label='Max correlation')
    panel.set_xlabel('Y max of spot image (um)')
    panel.set_ylabel('Z max of spot image (um)')
    panel.set_title(f'Shift {axis_name}')