# Register serial sections

Use rabies barcodes to register serial sections.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Notebook configuration: project and mouse are passed to the registration and
# section-info loaders below.
project = "becalia_rabies_barseq"
mouse = "BRAC8498.3e"
# Name of the error-corrected barcode dataset passed to the registration
error_correction_ds_name = "BRAC8498.3e_error_corrected_barcodes_26"

## Run serial section registration



In [None]:
from iss_analysis.registration import register_serial_sections
from iss_analysis.io import get_sections_info

# All registration settings live in one dictionary so they are easy to tweak.
# To reload the data, set reload=True and use_slurm=False
registration_kwargs = dict(
    project=project,
    mouse=mouse,
    error_correction_ds_name=error_correction_ds_name,
    correlation_window_size=500,
    min_spots=10,
    max_barcode_number=50,
    gaussian_width=20,
    n_workers=20,
    verbose=True,
    use_slurm=False,
    reload=True,
    slice_window=(-1, 3),
)
# Indexed below as serial_registration[slice]["next" or "previous"]
serial_registration = register_serial_sections.register_all_serial_sections(
    **registration_kwargs
)
# One row per section, with at least the `chamber` and `roi` columns used below
section_infos = get_sections_info(project, mouse)

In [None]:
# Load big dataframe
import pandas as pd
import flexiznam as flz

# Build the path from the notebook configuration so that changing `project`
# or `mouse` at the top of the notebook also changes which dataframe is loaded.
df_file = flz.get_processed_path(f"{project}/{mouse}/analysis/cell_barcode_df.pkl")

# NOTE: read_pickle executes arbitrary code, only load files you trust
full_df = pd.read_pickle(df_file)
# Add a slice id which is the concatenation of chamber and roi
full_df["slice_id"] = full_df["chamber"].astype(str) + "_" + full_df["roi"].astype(str)
print(f"Loaded {len(full_df)} cells")
# Keep only the cells that have a main barcode
barcoded_cells = full_df[full_df["main_barcode"].notna()].copy()
print(f"Found {len(barcoded_cells)} barcoded cells")

In [None]:
# Add slice number to rabies cell properties
import numpy as np

# Label every barcoded cell with the index of its section in `section_infos`.
# Cells whose (chamber, roi) pair is not listed there keep NaN.
barcoded_cells["slice"] = np.nan
section_keys = zip(
    section_infos.index, section_infos["chamber"], section_infos["roi"]
)
for slice_index, chamber, roi in section_keys:
    in_section = barcoded_cells["chamber"].eq(chamber) & barcoded_cells["roi"].eq(roi)
    barcoded_cells.loc[in_section, "slice"] = slice_index

In [None]:
# Slice relative to which the shifts of the other slices are accumulated below
reference_slice = 20
# Preview the registration table between the reference slice and the next one
serial_registration[reference_slice]["next"].head()

In [None]:
# add rotated ara coordinates
from iss_analysis.registration import ara_registration

# Get the rotation matrix from the cells, then apply it to their coordinates.
# The matrix is kept because it is inverted later to go back to ARA space.
ara2slicingplane_rotation = ara_registration.get_ara_to_slice_rotation_matrix(
    barcoded_cells
)
barcoded_cells = ara_registration.rotate_ara_coordinate_to_slice(
    barcoded_cells,
    transform=ara2slicingplane_rotation,
)
barcoded_cells.head()

In [None]:
# Register the next slice
import numpy as np
import matplotlib.pyplot as plt

from scipy.interpolate import RBFInterpolator

# Slice whose cells are shifted onto the reference slice in this cell
next_slice = reference_slice + 1
# Registration table between the reference slice and the slice after it
res = serial_registration[reference_slice]["next"]


def get_interpolators(res, threshold=400, smoothing=10):
    """Fit smooth 2D interpolators of the shifts measured between two slices.

    Rows are selected positionally (not by index label), so a non-unique index
    in ``res`` cannot desynchronise the coordinates from the shifts.

    Args:
        res (pd.DataFrame): Registration table with the columns ``shift_y``,
            ``shift_z``, ``ara_y_rot`` and ``ara_z_rot``.
        threshold (float, optional): Rows whose shift amplitude (Euclidean norm
            of ``shift_y`` and ``shift_z``) is not strictly below this value
            are discarded. Rows with NaN shifts or non-finite coordinates are
            discarded too. Defaults to 400.
        smoothing (float, optional): Smoothing parameter forwarded to
            ``RBFInterpolator``. Defaults to 10.

    Returns:
        tuple: ``(z_shift_interpolator, y_shift_interpolator)``, two
            ``RBFInterpolator`` taking ``(ara_y_rot, ara_z_rot)`` positions.
    """
    shifts = res[["shift_y", "shift_z"]].to_numpy(dtype=float)
    coords = res[["ara_y_rot", "ara_z_rot"]].to_numpy(dtype=float)
    shift_ampl = np.linalg.norm(shifts, axis=1)
    # NaN amplitudes compare False and are therefore dropped as well
    valid = (shift_ampl < threshold) & np.isfinite(coords).all(axis=1)
    shifts = shifts[valid]
    coords = coords[valid]
    # NOTE(review): the "z" interpolator is fitted on `shift_y` (column 0) and
    # the "y" interpolator on `shift_z` (column 1). This mapping is unchanged
    # from the previous version; confirm it matches the axis convention of the
    # registration output before relying on the names.
    z_shift_interpolator = RBFInterpolator(coords, shifts[:, 0], smoothing=smoothing)
    y_shift_interpolator = RBFInterpolator(coords, shifts[:, 1], smoothing=smoothing)
    return z_shift_interpolator, y_shift_interpolator


z_shift_interpolator, y_shift_interpolator = get_interpolators(res)

# Find cells in the next slice
cell_next = barcoded_cells[barcoded_cells["slice"] == next_slice].copy()
cell_coords = cell_next[["ara_y_rot", "ara_z_rot"]].values
# Interpolated shifts are divided by 1000 before being added to the
# coordinates (um to mm, as in the per-slice transform further down)
z_shift = z_shift_interpolator(cell_coords) / 1000
y_shift = y_shift_interpolator(cell_coords) / 1000
cell_next["ara_y_rot_serial"] = cell_next["ara_y_rot"] + y_shift
cell_next["ara_z_rot_serial"] = cell_next["ara_z_rot"] + z_shift
cell_ref = barcoded_cells[barcoded_cells["slice"] == reference_slice]

# Labels follow the selected slices instead of hard-coding slice numbers
next_label = f"Slice {next_slice}"
ref_label = f"Slice {reference_slice}"

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# Left: next slice before (red) and after (blue) registration
axes[0].scatter(
    cell_next["ara_y_rot"], cell_next["ara_z_rot"], c="r", s=1, label=next_label
)
axes[0].scatter(
    cell_next["ara_y_rot_serial"],
    cell_next["ara_z_rot_serial"],
    c="b",
    s=1,
    label=f"{next_label} registered to {reference_slice}",
)
# Right: registered next slice (blue) over the reference slice (green)
axes[1].scatter(
    cell_next["ara_y_rot_serial"],
    cell_next["ara_z_rot_serial"],
    c="b",
    s=1,
    label=f"{next_label} registered",
)
axes[1].scatter(
    cell_ref["ara_y_rot"], cell_ref["ara_z_rot"], c="g", s=1, label=ref_label
)
for x in axes:
    x.legend()
    x.set_xlim(0.5, 6)
    x.set_ylim(6, 10)
    x.set_aspect("equal")
# Compute the layout once limits and aspect are final
plt.tight_layout()

In [None]:
# Make a 3d image of shifts
# Use the ara_y_rot and ara_z_rot of each images and create a nslice x ny x nz x 2
# matrix of y and z shifts
reference_slice = 20
slices = list(sorted(serial_registration.keys()))
# Some cells that do not match the atlas end up in 0,0, ignore those
y_rot = barcoded_cells["ara_y_rot"]
z_rot = barcoded_cells["ara_z_rot"]
yrange = np.array([y_rot[y_rot > 0].min(), y_rot.max()])
zrange = np.array([z_rot[z_rot > 0].min(), z_rot.max()])
# make into bins of 10 microns, not mm
yrange = (yrange * 100).astype(int)
zrange = (zrange * 100).astype(int)

ybins = np.arange(yrange[0], yrange[1] + 2)
zbins = np.arange(zrange[0], zrange[1] + 2)
zz, yy = np.meshgrid(zbins, ybins)
# Grid positions converted back to the units of `ara_*_rot`
pos2interpolate = np.vstack([yy.flatten(), zz.flatten()]).T / 100


def _shift_grids(res):
    """Evaluate the interpolators fitted on `res` on the grid, as (ny, nz, 2)."""
    z_shift_interpolator, y_shift_interpolator = get_interpolators(res)
    y_shifts = y_shift_interpolator(pos2interpolate).reshape(zz.shape)
    z_shifts = z_shift_interpolator(pos2interpolate).reshape(zz.shape)
    return np.stack([y_shifts, z_shifts], axis=-1)


# The first axis is indexed by slice number, so size it from the largest slice
# number rather than from the number of registered slices
shift_matrix = np.zeros((max(slices) + 1, len(ybins), len(zbins), 2))
# Slices after the reference: shift relative to the previous slice
for slice_index in range(reference_slice + 1, max(slices) + 1):
    shift_matrix[slice_index] = _shift_grids(
        serial_registration[slice_index - 1]["next"]
    )
# Slices before the reference: shift relative to the following slice
for slice_index in range(0, reference_slice):
    shift_matrix[slice_index] = _shift_grids(
        serial_registration[slice_index + 1]["previous"]
    )


In [None]:
fig = plt.figure(figsize=(10, 7))
ax_raw, ax_cumul, ax_filtered = fig.subplots(3, 1)

# Accumulate the slice-to-slice shifts away from the reference slice, in both
# directions
cumulshift = np.zeros_like(shift_matrix)
cumulshift[reference_slice:] = np.cumsum(shift_matrix[reference_slice:], axis=0)
shifts_before = np.flip(shift_matrix[:reference_slice], axis=0)
cumulshift[:reference_slice] = np.flip(np.cumsum(shifts_before, axis=0), axis=0)

# Running median over `window` slices, with the window clamped at both ends
window = 6
half_window = window // 2
n_slices = shift_matrix.shape[0]
med_shift = np.zeros_like(shift_matrix)
for i in range(n_slices):
    if i < half_window:
        start, stop = 0, window
    elif i + half_window >= n_slices:
        start, stop = n_slices - window, n_slices
    else:
        start, stop = i - half_window, i + half_window
    med_shift[i] = np.nanmedian(cumulshift[start:stop, ...], axis=0)

# What remains once the slow drift (running median) is removed
filtered_shifts = cumulshift - med_shift

# Grid position at which the shifts are displayed
ypx, zpx = 150, 300
ax_raw.plot(shift_matrix[:, ypx, zpx, 0], "o", label="Y")
ax_raw.plot(shift_matrix[:, ypx, zpx, 1], "o", label="Z")
ax_raw.set_ylabel("Shifts (um)")

ax_cumul.plot(cumulshift[:, ypx, zpx, 0], "o", label="Cumulative Y shift (um)")
ax_cumul.plot(cumulshift[:, ypx, zpx, 1], "o", label="Cumulative Z shift (um)")
ax_cumul.plot(med_shift[:, ypx, zpx, 0], color="C0", label="Running median")
ax_cumul.plot(med_shift[:, ypx, zpx, 1], color="C1")
ax_cumul.set_ylabel("Cumulative\nshift (um)")

ax_filtered.plot(filtered_shifts[:, ypx, zpx, 0], "o", label="Y")
ax_filtered.plot(filtered_shifts[:, ypx, zpx, 1], "o", label="Z")
ax_filtered.set_ylabel("Filtered shift\n(cumulative shift - rolling median)")
ax_filtered.set_xlabel("Slice number")

for ax in fig.axes:
    ax.axhline(0, color="k", alpha=0.5)


In [None]:
# Transform each cell with the filtered shifts
from scipy.interpolate import RegularGridInterpolator

# Grid nodes (10 um bins) on which `filtered_shifts` was evaluated
points = (ybins, zbins)
interpolator_kwargs = dict(method="cubic", bounds_error=False, fill_value=0)

# Cells that belong to no listed section keep NaN registered coordinates
barcoded_cells["ara_y_rot_serial"] = np.nan
barcoded_cells["ara_z_rot_serial"] = np.nan
for slice_index in range(filtered_shifts.shape[0]):
    # Get all cells for the chosen slice_index
    chamber_slice, roi_slice = section_infos.loc[slice_index, ["chamber", "roi"]]
    cells_this_slice_mask = (barcoded_cells["chamber"] == chamber_slice) & (
        barcoded_cells["roi"] == roi_slice
    )
    if not cells_this_slice_mask.any():
        # Nothing to shift in this slice
        continue

    # Create a simple 2d interpolator per shift component for this slice
    y_shift_interpolator_2d = RegularGridInterpolator(
        points, filtered_shifts[slice_index, :, :, 0], **interpolator_kwargs
    )
    z_shift_interpolator_2d = RegularGridInterpolator(
        points, filtered_shifts[slice_index, :, :, 1], **interpolator_kwargs
    )

    # Get the rotated ARA coordinates for interpolation, converted to 10 um bins
    coords_to_interpolate = (
        barcoded_cells.loc[cells_this_slice_mask, ["ara_y_rot", "ara_z_rot"]].values
        * 100
    )

    # Interpolate the shifts at the cell locations and convert um to mm
    interpolated_y_shifts = y_shift_interpolator_2d(coords_to_interpolate) / 1000
    interpolated_z_shifts = z_shift_interpolator_2d(coords_to_interpolate) / 1000

    # Store the registered coordinates
    barcoded_cells.loc[cells_this_slice_mask, "ara_y_rot_serial"] = (
        barcoded_cells.loc[cells_this_slice_mask, "ara_y_rot"] + interpolated_y_shifts
    )
    barcoded_cells.loc[cells_this_slice_mask, "ara_z_rot_serial"] = (
        barcoded_cells.loc[cells_this_slice_mask, "ara_z_rot"] + interpolated_z_shifts
    )
# x has not changed, copy it for convenience
barcoded_cells["ara_x_rot_serial"] = barcoded_cells["ara_x_rot"]

# Add a version rotated back in the ARA
# NOTE(review): assumes the forward rotation is a right-multiplication of the
# (x, y, z) row vectors by `ara2slicingplane_rotation` -- confirm against
# `rotate_ara_coordinate_to_slice`
rotated_coords = barcoded_cells[[f"ara_{i}_rot_serial" for i in "xyz"]].values
re_ara_coords = rotated_coords @ np.linalg.inv(ara2slicingplane_rotation)
for i, col in enumerate("xyz"):
    barcoded_cells[f"ara_{col}_serial"] = re_ara_coords[:, i]

In [None]:
# Disabled alternative (never executed): apply the raw slice-to-slice shifts
# cumulatively instead of using the median-filtered shifts.
# NOTE(review): in both loops the y coordinate is updated before the z shift is
# evaluated, so the z shift is interpolated at the already-shifted y position.
# Check this is intended before re-enabling the cell.
if False:
    # cumulative shift application
    slices = list(sorted(serial_registration.keys()))

    # Start from the unregistered coordinates
    barcoded_cells["ara_y_rot_serial"] = barcoded_cells["ara_y_rot"].copy()
    barcoded_cells["ara_z_rot_serial"] = barcoded_cells["ara_z_rot"].copy()

    # Slices after the reference, using the registration to the previous slice
    for slice_index in np.arange(max(slices), reference_slice, -1)[::-1]:
        res = serial_registration[slice_index - 1]["next"]
        z_shift_interpolator, y_shift_interpolator = get_interpolators(res)
        print("Applying shift for slice", slice_index)
        # Shift all rabies cells
        # (every slice from slice_index to the last one receives this shift)
        for slice2shift in range(max(slices), slice_index - 1, -1):
            chamber, roi = section_infos.loc[slice2shift, ["chamber", "roi"]]
            cell_this_slice = (barcoded_cells["chamber"] == chamber) & (
                barcoded_cells["roi"] == roi
            )
            barcoded_cells.loc[cell_this_slice, "ara_y_rot_serial"] += (
                y_shift_interpolator(
                    barcoded_cells.loc[
                        cell_this_slice, ["ara_y_rot_serial", "ara_z_rot_serial"]
                    ]
                )
                / 1000
            )
            barcoded_cells.loc[cell_this_slice, "ara_z_rot_serial"] += (
                z_shift_interpolator(
                    barcoded_cells.loc[
                        cell_this_slice, ["ara_y_rot_serial", "ara_z_rot_serial"]
                    ]
                )
                / 1000
            )

    # same but using previous slice
    for slice_index in range(0, reference_slice):
        res = serial_registration[slice_index + 1]["previous"]
        z_shift_interpolator, y_shift_interpolator = get_interpolators(res)
        print("Applying shift for slice", slice_index)
        # Shift all rabies cells
        # (every slice from the first one to slice_index receives this shift)
        for slice2shift in range(0, slice_index + 1):
            chamber, roi = section_infos.loc[slice2shift, ["chamber", "roi"]]
            cell_this_slice = (barcoded_cells["chamber"] == chamber) & (
                barcoded_cells["roi"] == roi
            )
            barcoded_cells.loc[cell_this_slice, "ara_y_rot_serial"] += (
                y_shift_interpolator(
                    barcoded_cells.loc[
                        cell_this_slice, ["ara_y_rot_serial", "ara_z_rot_serial"]
                    ]
                )
                / 1000
            )
            barcoded_cells.loc[cell_this_slice, "ara_z_rot_serial"] += (
                z_shift_interpolator(
                    barcoded_cells.loc[
                        cell_this_slice, ["ara_y_rot_serial", "ara_z_rot_serial"]
                    ]
                )
                / 1000
            )

In [None]:
import matplotlib.pyplot as plt

# Side-by-side view of all barcoded cells before and after serial registration,
# coloured by slice number
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
scatter_kwargs = dict(
    c=barcoded_cells["slice"],
    cmap="viridis",
    s=10,
    alpha=0.5,
    vmin=0,
    vmax=39,
)
panels = (
    ("Original", "ara_y_rot", "ara_z_rot"),
    ("Registered", "ara_y_rot_serial", "ara_z_rot_serial"),
)
for ax, (title, y_col, z_col) in zip(axes, panels):
    ax.scatter(barcoded_cells[y_col], barcoded_cells[z_col], **scatter_kwargs)
    ax.set_title(title)
    ax.set_xlim(1, 6)
    ax.set_ylim(4, 10)
    ax.set_aspect("equal")
plt.tight_layout()

In [None]:
# First find barcodes that are in only one starter
# Starter cells located in VISp
starter_cells = barcoded_cells.query("is_starter == True and cortical_area == 'VISp'")
# One row per (starter, barcode) pair
starter2bc = starter_cells.all_barcodes.explode()
bc_cnt = starter2bc.value_counts()
# Barcodes found in exactly one of these starters
unique_barcodes = bc_cnt.index[(bc_cnt == 1).to_numpy()]
print(
    f"Found {len(unique_barcodes)} barcodes present in only 1 starter out of {len(bc_cnt)} in {len(starter_cells)} starter cells"
)

# One row per (cell, barcode) pair, restricted to the unique barcodes
cell2bc = barcoded_cells.all_barcodes.explode()
unique_bc_cells = cell2bc[cell2bc.isin(unique_barcodes)]
print(
    f"Found {len(unique_bc_cells)} cells with unique barcodes (including the starters)"
)

barcoded_cells["is_unique_bc"] = barcoded_cells.index.isin(unique_bc_cells.index)
barcoded_cells["unique_barcodes"] = [set() for _ in range(len(barcoded_cells))]
# Lookup table from barcode to the starter carrying it
bc2starter = starter2bc.reset_index().set_index("all_barcodes")
barcoded_cells["starter_ids"] = [set() for _ in range(len(barcoded_cells))]
# The sets stored in the dataframe are filled in place
for cell_uid, bc in unique_bc_cells.items():
    barcoded_cells.loc[cell_uid, "unique_barcodes"].add(bc)
    barcoded_cells.loc[cell_uid, "starter_ids"].add(bc2starter.loc[bc].mask_uid)

# Cells linked to exactly one starter get that starter as `starter_id`
barcoded_cells["starter_id"] = np.nan
for cell_uid, linked_starters in barcoded_cells["starter_ids"].items():
    if len(linked_starters) != 1:
        continue
    (only_starter,) = linked_starters
    barcoded_cells.loc[cell_uid, "starter_id"] = only_starter

In [None]:
from iss_preprocess.io.load import get_pixel_size

# Build the data path from the notebook configuration instead of hard-coding it
# NOTE(review): `px_size` is not used in the rest of this cell
px_size = get_pixel_size(data_path=f"{project}/{mouse}/chamber_08")

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
bins = np.arange(0, 1000, 20)
coords_to_plot = [
    ["ara_x_rot", "ara_y_rot", "ara_z_rot"],
    ["ara_x_rot", "ara_y_rot_serial", "ara_z_rot_serial"],
]
labels = ["Original", "Registered"]


def _distances_to_starter(cells, start_coord, coords_col):
    """Return the Euclidean distance of every row of `cells` to `start_coord`."""
    return np.linalg.norm(
        cells[coords_col].values.astype(float) - start_coord, axis=1
    )


# The starter/presynaptic pairing does not depend on the coordinate set, so it
# is computed once and reused for both histograms
starter_presynaptic_pairs = []
for bc in unique_barcodes:
    bc_cells = cell2bc[cell2bc == bc].index
    # only non-starter cells in V1
    presynaptic = barcoded_cells.loc[bc_cells].query(
        "is_starter == False and cortical_area == 'VISp'"
    )
    starter = bc2starter.loc[bc].mask_uid
    starter_presynaptic_pairs.append((starter_cells.loc[starter], presynaptic))

for label, coords_col in zip(labels, coords_to_plot):
    dst = dict(within=[], next_prev=[], all=[])
    for starter_prop, presynaptic in starter_presynaptic_pairs:
        starter_slice = starter_prop["slice"]
        pre_slice = presynaptic["slice"]
        subsets = dict(
            within=presynaptic[pre_slice == starter_slice],
            next_prev=presynaptic[
                (pre_slice == starter_slice - 1) | (pre_slice == starter_slice + 1)
            ],
            all=presynaptic,
        )
        start_coord = starter_prop[coords_col].values.astype(float)
        for key, cells in subsets.items():
            dst[key].append(_distances_to_starter(cells, start_coord, coords_col))
    for iax, (key, dstara) in enumerate(dst.items()):
        # Scaled by 1000 for the "Distance (um)" axis
        dstara = np.hstack(dstara) * 1000
        axes[iax].hist(dstara, bins=bins, alpha=0.5, label=label)
        axes[iax].set_title(f"{key} slice")
        axes[iax].set_xlabel("Distance (um)")
        axes[iax].set_ylabel("Count")
axes[-1].legend()
fig.tight_layout()

In [None]:
starter_cells = barcoded_cells.query("is_starter == True")
presynaptic = barcoded_cells.query("is_starter == False")
# Coordinate systems in which the presynaptic positions are expressed relative
# to their starter
coordinate_sets = {
    "ara": ["ara_x", "ara_y", "ara_z"],
    "rot": ["ara_x_rot", "ara_y_rot", "ara_z_rot"],
    "rot_serial": ["ara_x_rot_serial", "ara_y_rot_serial", "ara_z_rot_serial"],
    "ara_serial": ["ara_x_serial", "ara_y_serial", "ara_z_serial"],
}
relative = {name: [] for name in coordinate_sets}
# Number of starters without any presynaptic cell
n = 0
for starter, starter_prop in starter_cells.iterrows():
    linked = presynaptic.starter_ids.map(lambda ids: starter in ids)
    pres = presynaptic[linked]
    if pres.empty:
        n += 1
        continue
    for name, cols in coordinate_sets.items():
        relative[name].append(pres[cols].values - starter_prop[cols].values)
# Stack all starters together; the subtraction above may yield object arrays,
# hence the explicit cast
rel_ara = np.vstack(relative["ara"]).astype(float)
rel_coord = np.vstack(relative["rot"]).astype(float)
rel_coord_serial = np.vstack(relative["rot_serial"]).astype(float)
rel_ara_serial = np.vstack(relative["ara_serial"]).astype(float)


In [None]:
# Relative presynaptic positions (third vs second coordinate) in each of the
# four coordinate systems
fig, axes = plt.subplots(1, 4, figsize=(10, 5))
relative_sets = (rel_coord, rel_coord_serial, rel_ara, rel_ara_serial)
for ax, relative_xyz in zip(axes, relative_sets):
    ax.scatter(relative_xyz[:, 2], relative_xyz[:, 1], alpha=0.01, color="k")
    ax.set_aspect("equal")
    # Inverted y axis
    ax.set_ylim(1, -1)
    ax.set_xlim(-1, 1)

In [None]:
# Starter-to-presynaptic distance in each coordinate system
rel_distance_ara = np.linalg.norm(rel_ara, axis=1)
rel_distance = np.linalg.norm(rel_coord, axis=1)
rel_distance_serial = np.linalg.norm(rel_coord_serial, axis=1)
rel_distance_ara_serial = np.linalg.norm(rel_ara_serial, axis=1)

plt.figure(figsize=(10, 3))
labelled_distances = (
    (rel_distance_ara, "ARA"),
    (rel_distance, "ARA rot"),
    (rel_distance_serial, "Serial registered"),
    (rel_distance_ara_serial, "Serial registered in ARA"),
)
for distances, label in labelled_distances:
    plt.hist(distances, bins=np.arange(0, 4, 0.02), histtype="step", label=label)
plt.legend(loc="upper right")
plt.xlim(0, 4)
plt.show()

In [None]:
barcoded_cells.columns

In [None]:
from brisc.manuscript_analysis import distance_between_cells as dist_cells
from brisc.manuscript_analysis.flatmap_projection import compute_flatmap_coors

# Flatmap projection, first with the default columns and then with the
# serial-registered ones (`_serial` suffix)
for suffix, flatmap_kwargs in (("", {}), ("_serial", {"col_suffix": "_serial"})):
    flat_coors = compute_flatmap_coors(barcoded_cells, **flatmap_kwargs)
    for i_axis, axis_name in enumerate("xyz"):
        barcoded_cells[f"flatmap_{axis_name}{suffix}"] = flat_coors[:, i_axis]

# Relative presynaptic coordinates and distances on the flatmap
relative_presyn_coords_flatmap, distancess_flatmap = (
    dist_cells.determine_presynaptic_distances(
        barcoded_cells, col_prefix="flatmap_", col_suffix=""
    )
)
relative_presyn_coords_flatmap_serial, distancess_flatmap_serial = (
    dist_cells.determine_presynaptic_distances(
        barcoded_cells, col_prefix="flatmap_", col_suffix="_serial"
    )
)

In [None]:

ax = plt.subplot(2, 1, 1)
# Display options for the relative-coordinate scatter
relative_plot_kwargs = dict(
    s=1,
    alpha=0.05,
    color="black",
    label_fontsize=5,
    tick_fontsize=5,
    coors_to_plot=(0, 2),
    lims=((-5, 5), (-1, 1)),
    labels=("", ""),
)
dist_cells.plot_relative_coors(
    relative_presyn_coords_flatmap / 100, ax=ax, **relative_plot_kwargs
)


In [None]:
# Barcode-shuffled null distributions, without and with serial registration
null_kwargs = dict(N_iter=20, col_prefix="flatmap_")
all_shuffled_distances_flatmap, _ = dist_cells.create_barcode_shuffled_nulls_parallel(
    barcoded_cells, col_suffix="", **null_kwargs
)
all_shuffled_distances_flatmap_serial, _ = (
    dist_cells.create_barcode_shuffled_nulls_parallel(
        barcoded_cells, col_suffix="_serial", **null_kwargs
    )
)

In [None]:
import seaborn as sns

plt.figure(figsize=(10, 3))
# (values, legend label, colour, extra line options) for each curve; the
# shuffled curves use the first shuffle iteration and a dotted line
kde_curves = (
    (relative_presyn_coords_flatmap[:, 0], "Original", "black", {}),
    (all_shuffled_distances_flatmap[0][:, 0], "Shuffle", "black", {"linestyle": ":"}),
    (relative_presyn_coords_flatmap_serial[:, 0], "Registered", "darkorchid", {}),
    (
        all_shuffled_distances_flatmap_serial[0][:, 0],
        "Shuffle",
        "darkorchid",
        {"linestyle": ":"},
    ),
)
# Same curves with two kernel bandwidth adjustments, one per panel
for iax, bw in enumerate([1, 0.2]):
    ax_kdeplot = plt.subplot(1, 2, 1 + iax)
    for values, label, color, line_options in kde_curves:
        sns.kdeplot(
            values / 100,
            label=label,
            ax=ax_kdeplot,
            color=color,
            linewidth=0.9,
            bw_adjust=bw,
            **line_options,
        )
    ax_kdeplot.set_xlabel("Relative M-L location (mm)", fontsize=10)
    ax_kdeplot.set_ylabel("Density", fontsize=10)
# Legend on the last panel only
ax_kdeplot.legend(loc="upper right")


In [None]:
# Same distance computation as for the flatmap, in ARA coordinates
relative_presyn_coords, distancess = dist_cells.determine_presynaptic_distances(
    barcoded_cells, col_prefix="ara_", col_suffix=""
)
relative_presyn_coords_serial, distancess_serial = (
    dist_cells.determine_presynaptic_distances(
        barcoded_cells, col_prefix="ara_", col_suffix="_serial"
    )
)

In [None]:
distancess.shape

In [None]:
bins

In [None]:

# Distance thresholds and the volume of the sphere of that radius (radius
# divided by 1000 to match the mm-3 density label)
bins = np.arange(0.5, 4000, 1)
volume = 4 / 3 * np.pi * (bins / 1000) ** 3

fig, axes = plt.subplots(2, 2, figsize=(6, 5))

n_starters = barcoded_cells.is_starter.sum()
colors = ["black", "darkorchid"]
for color, dist in zip(colors, (distancess, distancess_serial)):
    # Number of cells closer than each threshold
    borders = np.sort(dist).searchsorted(bins)
    cells_per_coarse_bin = np.diff(borders[::20]) / n_starters
    density = borders / volume / n_starters
    # Both rows show the same curves; only the x limits differ
    for ax_count, ax_density in axes:
        ax_count.plot(
            bins[::20][:-1], cells_per_coarse_bin, color=color, drawstyle="steps-post"
        )
        ax_density.plot(bins, density, color=color)
for ax_count, ax_density in axes:
    ax_count.set_ylabel("# presynaptic cell\nper starter")
    ax_density.set_ylabel("Cell density (cell.mm-3)")
    ax_count.set_xlabel("Distance (um)")
    ax_density.set_xlabel("Distance (um)")
axes[1, 0].set_xlim(0, 1000)
axes[1, 1].set_xlim(0, 200)
fig.tight_layout()

In [None]:
# NOTE(review): scratch cell. `valid_barcodes` is only assigned inside the
# per-layer loop of a later cell, so this raises NameError when the notebook is
# run from top to bottom.
# Flag the cells sharing at least one barcode with `valid_barcodes`
valid_starter = barcoded_cells.unique_barcodes.map(lambda x: len(valid_barcodes.intersection(x))>0)
# Despite the name, only the non-starter cells are kept
valid_starter = barcoded_cells.loc[valid_starter].query('is_starter==False')
valid_starter

In [None]:
# NOTE(review): scratch cell. `rel_coords_flatmap` and `layer` are only assigned
# in the per-layer loop of a later cell, and all columns must have the same
# length as `valid_starter` for this to work.
df = pd.DataFrame(dict(ml=rel_coords_flatmap[:,0]*10, ap=rel_coords_flatmap[:,1]*10, dv=rel_coords_flatmap[:,2]*10, starter_layer=[layer]*len(valid_starter), pre_layer=valid_starter.area_acronym.values))
df

In [None]:
rel_coords_flatmap.shape

In [None]:
valid_starter.shape

In [None]:
starter_cells = barcoded_cells.query('is_starter==True')
l23 = starter_cells.query('area_acronym== "VISp2/3"')
l5 = starter_cells.query('area_acronym== "VISp5"')
print(f"Got {len(starter_cells)} starter cells with {len(l23)} in layer 2/3 and {len(l5)} in layer 5")


def _relative_coords_frame(rel_coords, starter_layer):
    """Tabulate relative coordinates (scaled by 10) with their starter layer."""
    return pd.DataFrame(
        dict(
            ml=rel_coords[:, 0] * 10,
            ap=rel_coords[:, 1] * 10,
            dv=rel_coords[:, 2] * 10,
            starter_layer=[starter_layer] * len(rel_coords),
        )
    )


# Per-layer tables are collected in lists and concatenated once at the end:
# concatenating onto an empty DataFrame is deprecated and can change dtypes
rel_layer_frames = []
rel_layer_serial_frames = []
rel_coords_by_layer = dict()
PLOT = False
if PLOT:
    fig, axes = plt.subplots(2, 2, figsize=(20, 5))
for ilayer, layer in enumerate(['VISp2/3', 'VISp5']):
    starter_layer = starter_cells.query(f'area_acronym== "{layer}"')
    # Unique barcodes of the starters of this layer
    valid_barcodes = set().union(*starter_layer.unique_barcodes)
    # keep only the cells (starter or not) carrying one of those barcodes
    valid_starter = barcoded_cells.unique_barcodes.map(
        lambda x: len(valid_barcodes.intersection(x)) > 0
    )
    cells_this_layer = barcoded_cells.loc[valid_starter]

    rel_coords_flatmap, dst_flatmap = dist_cells.determine_presynaptic_distances(
        cells_this_layer, col_prefix="flatmap_", col_suffix='', subtract_z=False
    )
    rel_coords_flatmap_serial, dst_flatmap_serial = (
        dist_cells.determine_presynaptic_distances(
            cells_this_layer,
            col_prefix="flatmap_",
            col_suffix='_serial',
            subtract_z=False,
        )
    )
    rel_layer_frames.append(_relative_coords_frame(rel_coords_flatmap, layer))
    rel_layer_serial_frames.append(
        _relative_coords_frame(rel_coords_flatmap_serial, layer)
    )
    # [original, serial-registered] relative coordinates for this layer
    rel_coords_by_layer[layer] = [rel_coords_flatmap, rel_coords_flatmap_serial]
    if PLOT:
        kw = dict(
            s=5,
            alpha=0.1,
            lims=((-5, 5), (1.2, 0)),
            label_fontsize=10,
            tick_fontsize=10,
            coors_to_plot=(0, 2),
            labels=("", ""),
        )
        # Top row: original coordinates
        dist_cells.plot_relative_coors(
            rel_coords_flatmap / 100,
            ax=axes[0, ilayer],
            color="black",
            **kw
        )
        sns.kdeplot(x=rel_coords_flatmap[:, 0] / 100, y=rel_coords_flatmap[:, 2] / 100, ax=axes[0, ilayer], color='black')
        # Bottom row: serial-registered coordinates
        dist_cells.plot_relative_coors(
            rel_coords_flatmap_serial / 100,
            ax=axes[1, ilayer],
            color="darkorchid",
            **kw
        )
        sns.kdeplot(x=rel_coords_flatmap_serial[:, 0] / 100, y=rel_coords_flatmap_serial[:, 2] / 100, ax=axes[1, ilayer], color='purple')
rel_layer = pd.concat(rel_layer_frames, ignore_index=True)
rel_layer_serial = pd.concat(rel_layer_serial_frames, ignore_index=True)
if PLOT:
    fig.tight_layout()

In [None]:
# Joint ml/dv density per starter layer (original coordinates), restricted to
# the displayed ml range
central = np.abs(rel_layer["ml"]) < 1000
data = rel_layer.loc[central]
joint_kwargs = dict(x="ml", y="dv", hue="starter_layer")
g = sns.JointGrid(
    data=data, ylim=(1200, 0), xlim=(-1000, 1000), ratio=3, **joint_kwargs
)
g.plot_joint(sns.kdeplot, zorder=0, levels=6, bw_adjust=0.7)
g.plot_marginals(sns.kdeplot, bw_adjust=0.5, common_norm=False)

In [None]:
# Joint ml/dv density per starter layer (serial-registered coordinates),
# restricted to the displayed ml range
central = np.abs(rel_layer_serial["ml"]) < 1000
data = rel_layer_serial.loc[central]
joint_kwargs = dict(x="ml", y="dv", hue="starter_layer")
g = sns.JointGrid(
    data=data, ylim=(1200, 0), xlim=(-1000, 1000), ratio=3, **joint_kwargs
)
g.plot_joint(sns.kdeplot, zorder=0, levels=6, bw_adjust=0.7)
g.plot_marginals(sns.kdeplot, bw_adjust=0.5, common_norm=False)

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# One panel per starter layer plus a third panel overlaying both KDEs
fig, axes = plt.subplots(1, 3, figsize=(15, 20))
which = 0  # 0 for non registered, 1 for registered
col = ['#a1c9f4', '#8de5a1']
kde_kwargs = dict(levels=[0.25, 0.5, 0.75, 0.9], bw_adjust=0.8)
hist_bins = [np.arange(-1000, 1001, 100), np.arange(0, 1201, 100)]
for ilayer, (layer, rel_pair) in enumerate(rel_coords_by_layer.items()):
    layer_ax = axes[ilayer]
    layer_color = col[ilayer]
    rel = rel_pair[which]
    # Keep the central cells and scale to the units of the axis limits
    central = np.abs(rel[:, 0]) < 100
    rel = np.array(rel[central]) * 10
    x, y = rel[:, 0], rel[:, 2]
    sns.histplot(x=x, y=y, bins=hist_bins, pthresh=.01, cmap="mako", ax=layer_ax)
    sns.kdeplot(x=x, y=y, ax=layer_ax, color=layer_color, **kde_kwargs)
    sns.scatterplot(x=x, y=y, s=10, color=".15", ax=layer_ax, alpha=0.5)

    # Marginal distribution of x on top of the layer panel
    divider = make_axes_locatable(layer_ax)
    ax_marginal_top = divider.append_axes("top", size="15%", pad=0.05)
    sns.kdeplot(ax=ax_marginal_top, x=x, color=layer_color, bw_adjust=0.5)
    # Overlay of all layers in the last panel
    sns.kdeplot(x=x, y=y, linewidths=2, ax=axes[2], color=layer_color, **kde_kwargs)

for ax in axes.flatten():
    ax.set_aspect('equal')
    ax.set_xlim(-1000, 1000)
    ax.set_ylim(1200, 0)

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# One panel per starter layer plus a third panel overlaying both KDEs
fig, axes = plt.subplots(1, 3, figsize=(15, 20))
which = 1  # 0 for non registered, 1 for registered
col = ['#a1c9f4', '#8de5a1']
kde_kwargs = dict(levels=[0.25, 0.5, 0.75, 0.9], bw_adjust=0.8)
hist_bins = [np.arange(-1000, 1001, 100), np.arange(0, 1201, 100)]
for ilayer, (layer, rel_pair) in enumerate(rel_coords_by_layer.items()):
    layer_ax = axes[ilayer]
    layer_color = col[ilayer]
    rel = rel_pair[which]
    # Keep the central cells and scale to the units of the axis limits
    central = np.abs(rel[:, 0]) < 100
    rel = np.array(rel[central]) * 10
    x, y = rel[:, 0], rel[:, 2]
    sns.histplot(x=x, y=y, bins=hist_bins, pthresh=.01, cmap="mako", ax=layer_ax)
    sns.kdeplot(x=x, y=y, ax=layer_ax, color=layer_color, **kde_kwargs)
    sns.scatterplot(x=x, y=y, s=10, color=".15", ax=layer_ax, alpha=0.5)

    # Marginal distribution of x on top of the layer panel
    divider = make_axes_locatable(layer_ax)
    ax_marginal_top = divider.append_axes("top", size="15%", pad=0.05)
    sns.kdeplot(ax=ax_marginal_top, x=x, color=layer_color, bw_adjust=0.5)
    # Overlay of all layers in the last panel
    sns.kdeplot(x=x, y=y, linewidths=2, ax=axes[2], color=layer_color, **kde_kwargs)

for ax in axes.flatten():
    ax.set_aspect('equal')
    ax.set_xlim(-1000, 1000)
    ax.set_ylim(1200, 0)