# Register serial sections

Use rabies barcodes to register serial sections.

In [None]:
# IPython autoreload extension: re-import modified modules (such as
# iss_analysis) before each cell runs, so library edits are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
# Dataset identifiers used throughout the notebook.
project = "becalia_rabies_barseq"
mouse = "BRAC8498.3e"
# Derived from `mouse` so switching mouse does not leave a stale dataset name.
error_correction_ds_name = f"{mouse}_error_corrected_barcodes_10"

## Run serial section registration



In [None]:
from iss_analysis.registration import register_serial_sections
from iss_analysis.io import get_sections_info

# Registration settings. To reload the data, set reload=True and use_slurm=False
serial_registration_options = dict(
    correlation_window_size=500,
    min_spots=10,
    max_barcode_number=50,
    gaussian_width=20,
    n_workers=20,
    verbose=True,
    use_slurm=False,
    reload=True,
    slice_window=(-1, 3),
)
serial_registration = register_serial_sections.register_all_serial_sections(
    project=project,
    mouse=mouse,
    error_correction_ds_name=error_correction_ds_name,
    **serial_registration_options,
)
# Per-section metadata (chamber, roi, ...) for the same mouse
section_infos = get_sections_info(project, mouse)

In [None]:
# Load the big dataframe of cells for the configured project/mouse
import pandas as pd
import flexiznam as flz

# Build the path from `project` and `mouse` rather than hard-coding them, so the
# notebook loads the data of the mouse configured at the top.
df_file = flz.get_processed_path(f"{project}/{mouse}/analysis/cell_barcode_df.pkl")

full_df = pd.read_pickle(df_file)
# Add a slice id which is the concatenation of chamber and roi
full_df["slice_id"] = full_df["chamber"].astype(str) + "_" + full_df["roi"].astype(str)
print(f"Loaded {len(full_df)} cells")
# Keep only cells that were assigned a main barcode
barcoded_cells = full_df.query("main_barcode.notna()").copy()
print(f"Found {len(barcoded_cells)} barcoded cells")

In [None]:
# Add slice number to rabies cell properties: each cell gets the index label of
# the section_infos row whose (chamber, roi) matches its own; unmatched cells
# keep NaN.
import numpy as np

barcoded_cells["slice"] = np.nan
for section in section_infos[["chamber", "roi"]].itertuples():
    in_section = barcoded_cells["chamber"].eq(section.chamber) & barcoded_cells[
        "roi"
    ].eq(section.roi)
    barcoded_cells.loc[in_section, "slice"] = section.Index

In [None]:
# Slice whose registration onto its next neighbour is inspected below
reference_slice = 20
serial_registration[reference_slice]["next"].iloc[:5]

In [None]:
from iss_analysis.registration import ara_registration

# Add rotated ARA coordinates (presumably the ara_*_rot columns used by the
# cells below — confirm against rotate_ara_coordinate_to_slice)
barcoded_cells = ara_registration.rotate_ara_coordinate_to_slice(barcoded_cells)
barcoded_cells.iloc[:5]

In [None]:
# Register the next slice
import numpy as np
import matplotlib.pyplot as plt

from scipy.interpolate import RBFInterpolator

# The slice right after the reference, and its registration onto the reference
next_slice = reference_slice + 1
res = serial_registration[next_slice - 1]["next"]


def get_interpolators(res, threshold=400, smoothing=10):
    """Fit smooth 2D interpolators of the pairwise registration shifts.

    Args:
        res: DataFrame with one row per cell, holding the cell position in
            ``ara_y_rot`` / ``ara_z_rot`` and the measured ``shift_y`` /
            ``shift_z``.
        threshold: rows whose shift norm is not strictly below this value
            (including NaN norms) are discarded as outliers.
        smoothing: smoothing parameter forwarded to
            ``scipy.interpolate.RBFInterpolator``.

    Returns:
        ``(z_shift_interpolator, y_shift_interpolator)``: two fitted
        ``RBFInterpolator`` objects, evaluated on ``(ara_y_rot, ara_z_rot)``
        points.

        NOTE(review): the first (``z``) interpolator is fitted on ``shift_y``
        and the second (``y``) on ``shift_z``. Confirm this swap is intended.
    """
    shift_table = res[["shift_y", "shift_z"]].to_numpy()
    is_inlier = np.linalg.norm(shift_table, axis=1) < threshold
    inlier_labels = res.index[is_inlier]
    positions = res.loc[inlier_labels, ["ara_y_rot", "ara_z_rot"]]
    inlier_shifts = shift_table[is_inlier]

    fitted = [
        RBFInterpolator(positions, inlier_shifts[:, column], smoothing=smoothing)
        for column in (0, 1)
    ]
    return fitted[0], fitted[1]


z_shift_interpolator, y_shift_interpolator = get_interpolators(res)

# Cells of the next slice, with a copy of their rotated ARA coordinates that
# receives the registration shifts.
cell_next = barcoded_cells[barcoded_cells["slice"] == next_slice].copy()
cell_next["ara_y_rot_serial"] = cell_next["ara_y_rot"]
cell_next["ara_z_rot_serial"] = cell_next["ara_z_rot"]
cell_coords = cell_next[["ara_y_rot", "ara_z_rot"]].values
# Shifts are in um while coordinates are in mm, hence the division by 1000.
z_shift = z_shift_interpolator(cell_coords) / 1000
y_shift = y_shift_interpolator(cell_coords) / 1000
cell_next["ara_z_rot_serial"] += z_shift
cell_next["ara_y_rot_serial"] += y_shift

cell_ref = barcoded_cells[barcoded_cells["slice"] == reference_slice]

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# Left: next slice before and after registration
axes[0].scatter(
    cell_next["ara_y_rot"],
    cell_next["ara_z_rot"],
    c="r",
    s=1,
    label=f"Slice {next_slice}",
)
axes[0].scatter(
    cell_next["ara_y_rot_serial"],
    cell_next["ara_z_rot_serial"],
    c="b",
    s=1,
    label=f"Slice {next_slice} registered to {reference_slice}",
)
# Right: registered next slice overlaid on the reference slice
axes[1].scatter(
    cell_next["ara_y_rot_serial"],
    cell_next["ara_z_rot_serial"],
    c="b",
    s=1,
    label=f"Slice {next_slice} registered",
)
axes[1].scatter(
    cell_ref["ara_y_rot"],
    cell_ref["ara_z_rot"],
    c="g",
    s=1,
    label=f"Slice {reference_slice}",
)
for x in axes:
    x.legend()
    x.set_xlim(0.5, 6)
    x.set_ylim(6, 10)
    x.set_aspect("equal")
# Lay out only after limits and aspect are final
plt.tight_layout()

In [None]:
# Sample every pairwise registration on a common regular (y, z) grid and store
# the result as an (n_slices, n_y, n_z, 2) array: last axis 0 = y shift,
# 1 = z shift. Each slice holds the shift towards its neighbour on the
# reference side; the reference slice row stays zero.
reference_slice = 10
slices = sorted(serial_registration)
# Some cells that do not match the atlas end up in 0,0: ignore them for the
# lower bound of the grid.
valid = barcoded_cells["ara_y_rot"] > 0
yrange = np.array(
    [barcoded_cells["ara_y_rot"][valid].min(), barcoded_cells["ara_y_rot"].max()]
)
valid = barcoded_cells["ara_z_rot"] > 0
zrange = np.array(
    [barcoded_cells["ara_z_rot"][valid].min(), barcoded_cells["ara_z_rot"].max()]
)
# Convert from mm to grid units of 10 um
yrange = (yrange * 100).astype(int)
zrange = (zrange * 100).astype(int)

ybins = np.arange(yrange[0], yrange[1] + 2)
zbins = np.arange(zrange[0], zrange[1] + 2)
zz, yy = np.meshgrid(zbins, ybins)
# One (y, z) row per grid point, back in mm
pos2interpolate = np.column_stack([yy.ravel(), zz.ravel()]) / 100

shift_matrix = np.zeros((len(slices), len(ybins), len(zbins), 2))
registrations_to_sample = [
    (islice, serial_registration[islice - 1]["next"])
    for islice in range(reference_slice + 1, max(slices) + 1)
] + [
    (islice, serial_registration[islice + 1]["previous"])
    for islice in range(reference_slice)
]
for islice, res in registrations_to_sample:
    z_shift_interpolator, y_shift_interpolator = get_interpolators(res)
    y_shifts = y_shift_interpolator(pos2interpolate)
    z_shifts = z_shift_interpolator(pos2interpolate)
    shift_matrix[islice, :, :, 0] = y_shifts.reshape(zz.shape)
    shift_matrix[islice, :, :, 1] = z_shifts.reshape(zz.shape)


In [None]:
plt.subplot(3, 1, 1)

# Inspect the shifts at one grid position (indices into the y and z bins)
ypx, zpx = 250, 300
plt.plot(shift_matrix[:, ypx, zpx, 0], "o", label="Y")
plt.plot(shift_matrix[:, ypx, zpx, 1], "o", label="Z")

# Running median of the pairwise shifts over a window of `w` slices. In the
# middle the window is [i - w//2, i + w//2). At the edges it is slid, not
# shrunk, so it keeps `w` slices whenever at least `w` exist; with fewer slices
# it covers all of them. The clamping avoids a negative start index.
w = 6
n_slices = shift_matrix.shape[0]
med_shift = np.zeros_like(shift_matrix)
for i in range(n_slices):
    b = max(min(i - w // 2, n_slices - w), 0)
    e = min(b + w, n_slices)
    med_shift[i] = np.nanmedian(shift_matrix[b:e, ...], axis=0)

plt.plot(med_shift[:, ypx, zpx, 0], color="C0", label="Running median")
plt.plot(med_shift[:, ypx, zpx, 1], color="C1")
plt.legend(loc="upper right")
plt.ylabel("Shifts (um)")

# Cumulative shifts, accumulated outwards from the reference slice
plt.subplot(3, 1, 2)
cumulshift = np.zeros_like(shift_matrix)
cumulshift[reference_slice:] = np.cumsum(shift_matrix[reference_slice:], axis=0)
cumulshift[:reference_slice] = np.cumsum(
    shift_matrix[:reference_slice][::-1], axis=0
)[::-1]
plt.plot(cumulshift[:, ypx, zpx, 0], "o", label="Cumulative Y shift (um)")
plt.plot(cumulshift[:, ypx, zpx, 1], "o", label="Cumulative Z shift (um)")
plt.legend(loc="upper right")
plt.ylabel("Cumulative\nshift (um)")

# Cumulative running-median shift, plus each slice's own (local) shift
plt.subplot(3, 1, 3)
cumulmed = np.zeros_like(shift_matrix)
cumulmed[reference_slice:] = np.cumsum(med_shift[reference_slice:], axis=0)
cumulmed[:reference_slice] = np.cumsum(med_shift[:reference_slice][::-1], axis=0)[
    ::-1
]
plt.plot(
    shift_matrix[:, ypx, zpx, 0] + cumulmed[:, ypx, zpx, 0],
    "o",
    label="cumulative median + local Y shift (um)",
)
plt.plot(
    shift_matrix[:, ypx, zpx, 1] + cumulmed[:, ypx, zpx, 1],
    "o",
    label="cumulative median + local Z shift (um)",
)
plt.plot(cumulmed[:, ypx, zpx, 0], "C0")
plt.plot(cumulmed[:, ypx, zpx, 1], "C1")
plt.legend(loc="upper right")
plt.ylabel("cumulative\nmedian + local")
plt.xlabel("Slice number")


In [None]:
# NOTE(review): cumulshift is itself built with np.cumsum over slices, so this
# displays a second cumulative sum — confirm this is what is meant.
cumulshift[reference_slice:].cumsum(axis=0)

In [None]:
# Largest z grid bin (grid units of 10 um)
np.max(zz)

In [None]:
# [min, max] of ara_z_rot used for the shift grid, in units of 10 um
zrange

In [None]:
# [min, max] of ara_y_rot used for the shift grid, in units of 10 um
yrange

In [None]:
slices = sorted(serial_registration)

# Serially registered coordinates start from the rotated ARA coordinates and
# accumulate the pairwise shifts that bring each slice onto the reference slice.
barcoded_cells["ara_y_rot_serial"] = barcoded_cells["ara_y_rot"].copy()
barcoded_cells["ara_z_rot_serial"] = barcoded_cells["ara_z_rot"].copy()


def _apply_pairwise_shift(slice_ids, z_interp, y_interp):
    """Add one pairwise registration shift to all cells of ``slice_ids``.

    Both shifts are evaluated at the same current coordinates, before either is
    applied. Interpolated shifts are in um and coordinates in mm, hence /1000.
    """
    in_slices = barcoded_cells["slice"].isin(list(slice_ids))
    if not in_slices.any():
        return
    coords = barcoded_cells.loc[
        in_slices, ["ara_y_rot_serial", "ara_z_rot_serial"]
    ].to_numpy()
    y_shift = y_interp(coords) / 1000
    z_shift = z_interp(coords) / 1000
    barcoded_cells.loc[in_slices, "ara_y_rot_serial"] += y_shift
    barcoded_cells.loc[in_slices, "ara_z_rot_serial"] += z_shift


# Slices after the reference. Walk from the last slice back towards the
# reference, so each slice first receives its own shift (evaluated in its own
# coordinates), then the shifts of the slices between it and the reference.
for islice in range(max(slices), reference_slice, -1):
    res = serial_registration[islice - 1]["next"]
    z_shift_interpolator, y_shift_interpolator = get_interpolators(res)
    print("Applying shift for slice", islice)
    _apply_pairwise_shift(
        range(islice, max(slices) + 1), z_shift_interpolator, y_shift_interpolator
    )

# Same for slices before the reference, walking from the first slice towards it
for islice in range(0, reference_slice):
    res = serial_registration[islice + 1]["previous"]
    z_shift_interpolator, y_shift_interpolator = get_interpolators(res)
    print("Applying shift for slice", islice)
    _apply_pairwise_shift(
        range(0, islice + 1), z_shift_interpolator, y_shift_interpolator
    )

In [None]:
import matplotlib.pyplot as plt

# Cell positions before and after serial registration, coloured by slice index.
# The colour scale spans the observed slice range instead of a hard-coded count.
scatter_style = dict(
    cmap="viridis",
    s=10,
    alpha=0.5,
    vmin=0,
    vmax=barcoded_cells["slice"].max(),
)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (axes[0], "ara_y_rot", "ara_z_rot", "Original"),
    (axes[1], "ara_y_rot_serial", "ara_z_rot_serial", "Registered"),
)
for ax, y_col, z_col, title in panels:
    ax.scatter(
        barcoded_cells[y_col],
        barcoded_cells[z_col],
        c=barcoded_cells["slice"],
        **scatter_style,
    )
    ax.set_title(title)
    ax.set_xlim(1, 6)
    ax.set_ylim(4, 10)
    ax.set_aspect("equal")
plt.tight_layout()

In [None]:
# First find barcodes that are in only one VISp starter cell
starter_cells = barcoded_cells.query("is_starter == True")
starter_cells = starter_cells.query("cortical_area == 'VISp'")
starter2bc = starter_cells.all_barcodes.explode()
bc_cnt = starter2bc.value_counts()
unique_barcodes = bc_cnt[bc_cnt == 1].index
print(
    f"Found {len(unique_barcodes)} barcodes present in only 1 starter out of {len(bc_cnt)} in {len(starter_cells)} starter cells"
)

# Every (cell, barcode) pair where the barcode is one of the unique ones
cell2bc = barcoded_cells.all_barcodes.explode()
unique_bc_cells = cell2bc[cell2bc.isin(unique_barcodes)]
print(
    f"Found {len(unique_bc_cells)} cells with unique barcodes (including the starters)"
)

barcoded_cells["is_unique_bc"] = False
barcoded_cells.loc[unique_bc_cells.index, "is_unique_bc"] = True
barcoded_cells["unique_bc"] = [set() for _ in range(len(barcoded_cells))]
# Find the corresponding starter index too (barcode -> starter row)
bc2starter = starter2bc.reset_index().set_index("all_barcodes")
barcoded_cells["starter_ids"] = [set() for _ in range(len(barcoded_cells))]
for cell_id, bc in unique_bc_cells.items():
    # .loc returns the stored set object, so .add updates the column in place
    barcoded_cells.loc[cell_id, "unique_bc"].add(bc)
    barcoded_cells.loc[cell_id, "starter_ids"].add(bc2starter.loc[bc].mask_uid)

# A cell gets a starter_id only when its unique barcodes point to exactly one
# starter. Build the column in one go instead of writing ids into a float NaN
# column, which triggers an incompatible-dtype upcast.
barcoded_cells["starter_id"] = [
    next(iter(starter_ids)) if len(starter_ids) == 1 else np.nan
    for starter_ids in barcoded_cells["starter_ids"]
]

In [None]:
from iss_preprocess.io.load import get_pixel_size

px_size = get_pixel_size(data_path=f"{project}/{mouse}/chamber_08")

# For every barcode unique to one VISp starter, measure the distance between the
# starter and the non-starter VISp cells sharing that barcode. This is done with
# the rotated ARA coordinates ("Original") and with the serially registered
# ones ("Registered").

# Index labels of the cells carrying each barcode, computed once instead of
# re-scanning cell2bc for every barcode.
cells_by_barcode = cell2bc.groupby(cell2bc).groups

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
bins = np.arange(0, 1000, 20)
coords_to_plot = [
    ["ara_x_rot", "ara_y_rot", "ara_z_rot"],
    ["ara_x_rot", "ara_y_rot_serial", "ara_z_rot_serial"],
]
labels = ["Original", "Registered"]
for coords_col, label in zip(coords_to_plot, labels):
    dst = dict(within=[], next_prev=[], all=[])
    for bc in unique_barcodes:
        presynaptic = barcoded_cells.loc[cells_by_barcode[bc]].query(
            "is_starter == False"
        )
        # only v1
        presynaptic = presynaptic.query("cortical_area == 'VISp'")
        starter = bc2starter.loc[bc].mask_uid
        starter_prop = starter_cells.loc[starter]
        starter_slice = starter_prop["slice"]

        bc_cells_this_slice = presynaptic.query("slice == @starter_slice")
        bc_cells_prev_or_next = presynaptic.query(
            "slice == @starter_slice - 1 or slice == @starter_slice + 1"
        )

        start_coord = starter_prop[coords_col].values.astype(float)
        for key, group in (
            ("within", bc_cells_this_slice),
            ("next_prev", bc_cells_prev_or_next),
            ("all", presynaptic),
        ):
            dst[key].append(
                np.linalg.norm(
                    group[coords_col].values.astype(float) - start_coord, axis=1
                )
            )
    for ax, (key, dstara) in zip(axes, dst.items()):
        # Coordinates are in mm; distances are plotted in um. Guard against no
        # unique barcodes, where np.hstack([]) would raise.
        dstara = np.hstack(dstara) * 1000 if dstara else np.array([])
        ax.hist(dstara, bins=bins, alpha=0.5, label=label)
        ax.set_title(key)
        ax.set_xlabel("Distance (um)")
        ax.set_ylabel("Count")
axes[-1].legend()
fig.tight_layout()

In [None]:
# Random permutation of np.arange(10), drawn from NumPy's global random state
import numpy as np

shuffled_array = np.random.permutation(10)
print(shuffled_array)

In [None]:
rs = np.array([1, 5, 10, 20, 40, 80])
depth = [5, 20, 40, 80, 160, 320]

# Element-wise ratio of rs to depth
optic_flow = np.divide(rs, depth)
optic_flow

In [None]:
import matplotlib.pyplot as plt

# NOTE(review): dstara_rot, dstara_serial, dstara_all and dstara_next_prev are
# not defined by any cell above (the previous cell collects distances in `dst`),
# so this cell raises NameError as written. TODO: confirm it is stale or rebuild
# these arrays.
fig, ax = plt.subplots(1, 4, figsize=(15, 4))
bins = np.arange(0, 1000, 20)
ax[0].hist(dstara_rot, bins=bins)
ax[1].hist(np.array(dstara_serial), bins=bins)
ax[2].hist(np.array(dstara_all), bins=bins)
ax[3].hist(np.array(dstara_next_prev), bins=bins)
ax[0].set_title("ARA, within slice")
ax[1].set_title("ARA space, within slice")
ax[2].set_title("ARA space, all slices")
ax[3].set_title("ARA space, previous and next slices only")

for x in ax:
    x.set_xlabel("Distance to starter (um)")

# Must be called: the bare attribute access did nothing
fig.tight_layout()