# Long range connections

In [None]:
# Automatically reload imported modules before executing code, so edits to their
# source files are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

## Load data

Load the cell barcode dataframe, keep the cells that have at least one unique barcode, and add projected coordinates as well as retinotopy (azimuth and elevation).

In [None]:
# Set path to lab folder
import os
from pathlib import Path

# This is the path to lab-znamenskiyp, the folder which contains home/shared/presentation.
# Override it with the LAB_FOLDER_PATH environment variable on machines where the lab
# share is not mounted as Z:/ (the default is unchanged).
lab_folder_path = Path(os.environ.get("LAB_FOLDER_PATH", "Z:/"))
save_fig = True  # write the main figure to disk as PDF and PNG
figname = "fig6_long_range"  # base file name of the saved figure

In [None]:
from iss_preprocess.io import get_processed_path
from brisc.manuscript_analysis import load

# Input: pickled cell/barcode table. Output folder: where the figures are saved.
processed_path = get_processed_path(
    "becalia_rabies_barseq/BRAC8498.3e/analysis/cell_barcode_df.pkl"
)
save_path = lab_folder_path / "home/shared/presentations/becalick_2025"

load_kwargs = dict(
    areas_to_empty=["fiber tracts", "outside"],
    valid_areas=["Isocortex", "TH"],
    distance_threshold=150,
)
cell_barcode_df = load.load_cell_barcode_data(processed_path, **load_kwargs)

# Keep only cells carrying at least one unique barcode
n_unique_barcodes = cell_barcode_df["unique_barcodes"].map(len)
barcoded_cells = cell_barcode_df.loc[n_unique_barcodes > 0].copy()
print(f"{len(barcoded_cells)} cells with a unique barcode")

## Get azimuth and elevation

We also add flatmap and top-view coordinates for each barcoded cell.

In [None]:
# Project barcoded cells on the flatmap, then on the top view. The azimuth/elevation
# map is given on a top view (not on the flatmap), so both projections are stored.
from brisc.exploratory_analysis.plot_summary_for_all_bc import compute_flatmap_coors

flat_coors = compute_flatmap_coors(barcoded_cells)
for dim, column in enumerate(["flatmap_x", "flatmap_y", "flatmap_z"]):
    barcoded_cells[column] = flat_coors[:, dim]

top_coors = compute_flatmap_coors(barcoded_cells, projection="top")
for dim, column in enumerate(["x_top", "y_top", "z_top"]):
    barcoded_cells[column] = top_coors[:, dim]

In [None]:
# Look up azimuth and elevation of each cell in the ARA retinotopic maps
import numpy as np
from cricksaw_analysis import atlas_utils

ara_elevation, ara_azimuth = atlas_utils.get_ara_retinotopic_map()

barcoded_cells["elevation"] = np.nan
barcoded_cells["azimuth"] = np.nan
# Only cells with defined top-view coordinates can be looked up; the others stay NaN
located_cells = barcoded_cells.dropna(subset=["x_top", "y_top"])
# Truncate to integer pixel indices; the maps are indexed as [y_top, x_top]
top_cols = located_cells["x_top"].values.astype(int)
top_rows = located_cells["y_top"].values.astype(int)
barcoded_cells.loc[located_cells.index, "elevation"] = ara_elevation[top_rows, top_cols]
barcoded_cells.loc[located_cells.index, "azimuth"] = ara_azimuth[top_rows, top_cols]


## Attribute starter to all cells

Presynaptic cells whose barcodes match more than one starter cell are excluded: they are not assigned a starter.

In [None]:
# For each presynaptic cell, add a starter_id: the index of the single starter cell it
# shares at least one barcode with. Cells matching several starters stay unassigned.
starter_cells = barcoded_cells.query("is_starter == True")
presynaptic = barcoded_cells.query("is_starter == False")

# (Re)initialise the column: re-running this cell then starts from a clean state, and
# the column exists even if no presynaptic cell gets assigned a starter.
barcoded_cells["starter_id"] = None

pres_with_multi_starter = []
for cell_id, pres_series in presynaptic.iterrows():
    shares_barcode = starter_cells["unique_barcodes"].apply(
        lambda barcodes: len(pres_series.unique_barcodes & barcodes) > 0
    )
    starters = starter_cells[shares_barcode]
    if len(starters) == 1:
        barcoded_cells.loc[cell_id, 'starter_id'] = starters.index[0]
    elif len(starters) > 1:
        pres_with_multi_starter.append(cell_id)
    else:
        raise ValueError(
            f"Presynaptic cell {cell_id} shares no barcode with any starter cell"
        )
n_single_starter = barcoded_cells["starter_id"].notna().sum()
print(f'{len(pres_with_multi_starter)} cells with multiple starters')
print(f'{n_single_starter} cells with 1 starters')

In [None]:
# Restrict the analysis to starter cells in V1 (VISp) and to the cells attributed to them
starter_cells = barcoded_cells.query("is_starter == True")
in_visp = starter_cells["cortical_area"] == "VISp"
v1_starter_cells = starter_cells.loc[in_visp].copy()
print(f"Found {len(v1_starter_cells)}/{len(starter_cells)} V1 starter cells")

# Only cells that were given a starter_id above can match a V1 starter
has_v1_starter = barcoded_cells["starter_id"].isin(v1_starter_cells.index)
cell_with_v1_starter = barcoded_cells.loc[has_v1_starter].copy()
print(f"... and {len(cell_with_v1_starter)} cells with starter in V1")

## Calculate average map

For each position on the flatmap, compute a Gaussian-weighted average of the (relative) ML position of the starter cells of nearby presynaptic cells.

In [None]:
# Functions to calculate a Gaussian-weighted average of a value on a grid
import numpy as np

try:
    from numba import njit
except ImportError:  # numba is optional: fall back to plain (much slower) Python
    def njit(func):
        return func


@njit
def calculate_weighted_average(x: float, y: float, starter_pos, xy, sigma):
    """
    Gaussian-weighted average of starter_pos around the point (x, y).

    Data point i, located at (xy[0, i], xy[1, i]), gets the weight
    exp(-d_i**2 / (2 * sigma**2)), where d_i is its distance to (x, y).

    Args:
        x (float): X coordinate of the point where the average is evaluated.
        y (float): Y coordinate of the point where the average is evaluated.
        starter_pos (np.array): 1D array of values to average, one per data point.
        xy (np.array): 2 x N array of data point positions (row 0: x, row 1: y).
        sigma (float): Standard deviation of the Gaussian, in the units of xy.

    Returns:
        The weighted average (NaN if all weights are zero) and the total weight.

    """
    dx = x - xy[0, :]
    dy = y - xy[1, :]
    # Use the squared distance directly: no need for a sqrt that is squared again
    weights = np.exp(-(dx * dx + dy * dy) / (2 * sigma**2))
    total_weight = np.sum(weights)
    if total_weight > 0:
        weighted_average = np.sum(starter_pos * weights) / total_weight
    else:
        weighted_average = np.nan
    return weighted_average, total_weight


@njit
def gaussian_smooth(xx, yy, starter_pos, presy_xy, sigma):
    """
    Calculate the Gaussian-weighted average of starter_pos on each pixel of a grid.

    Args:
        xx (np.array): x component of meshgrid
        yy (np.array): y component of meshgrid, same shape as xx
        starter_pos (np.array): 1D array of values to average, one per data point.
        presy_xy (np.array): 2 x N array of data point positions (row 0: x, row 1: y).
        sigma (float): Standard deviation of the Gaussian, in the units of xx and yy.

    Returns:
        (np.array, np.array): weighted average and total weight on each grid pixel,
        both with the shape of xx.

    """
    img = np.empty(xx.shape, dtype=float)
    weights = np.empty(xx.shape, dtype=float)
    for i in range(xx.shape[0]):
        for j in range(xx.shape[1]):
            img[i, j], weights[i, j] = calculate_weighted_average(
                xx[i, j], yy[i, j], starter_pos, presy_xy, sigma
            )
    return img, weights

In [None]:
# CONSTANTS (flatmap coordinates are in pixels of 10 um)

xlim = (100, 1150)  # flatmap x range shown and analysed (pixels)
ylim = (800, 1400)  # flatmap y range shown and analysed (pixels)
scale = 10 / 1000  # mm per flatmap pixel: 10 um per pixel, output in mm
sigma = 200  # width of the Gaussian smoothing kernel, in microns
clims = (-1, 1)  # colour limits for the starter ML position (mm)
cmap = "turbo_r"

In [None]:
# Calculate the smoothed starter-position map for the cortex (flatmap)
import skimage.morphology as morphology

presynaptic_cells = cell_with_v1_starter.query("is_starter == False")
presy_xy = presynaptic_cells[['flatmap_x', 'flatmap_y']].values.T
presy_azel = presynaptic_cells[['azimuth', 'elevation']].values.T
# Flatmap x position of the starter each presynaptic cell is attributed to
starter_pos = v1_starter_cells.loc[presynaptic_cells.starter_id, 'flatmap_x'].values
# Remove cells with a NaN presynaptic or starter position
valid = (np.isnan(presy_xy).sum(axis=0) + np.isnan(starter_pos)) == 0
presy_xy = presy_xy[:, valid]
starter_pos = starter_pos[valid]
presy_azel = presy_azel[:, valid]
# We have registered the left hemisphere to the right, so flip the ML pos. Express it
# relative to the mean starter position as the origin is somewhat arbitrary on the flatmap
center_abs = np.nanmean(starter_pos)


def rel_pos(x):
    """Flipped flatmap x position relative to the mean starter position (pixels)."""
    return -(x - center_abs)


relative_starter_pos = rel_pos(starter_pos)

# `scale` (mm per flatmap pixel) is defined in the constants cell
pixel_size = 100
grid_step = pixel_size / 100  # grid spacing in flatmap pixels (1 px = 10 um)
xx, yy = np.meshgrid(
    np.arange(xlim[0], xlim[1], grid_step),
    np.arange(ylim[0], ylim[1], grid_step),
)

# Divide sigma by 10 as the atlas is 10 um/px
ctx_img, total_weights = gaussian_smooth(
    xx, yy, relative_starter_pos * scale, presy_xy, sigma=sigma / 10
)

# Make a clipping mask to hide values outside of the data range: mark the grid pixel of
# every presynaptic cell, then take the convex hull. Cells outside the plotted window
# are skipped, as negative indices would otherwise silently wrap around.
presy_index = np.floor(
    (presy_xy - np.array([xlim[0], ylim[0]])[:, None]) / grid_step
).astype(int)
in_grid = (
    (presy_index[0] >= 0)
    & (presy_index[0] < xx.shape[1])
    & (presy_index[1] >= 0)
    & (presy_index[1] < xx.shape[0])
)
mask = np.zeros(xx.shape, dtype=bool)
mask[presy_index[1, in_grid], presy_index[0, in_grid]] = True
ctx_mask = morphology.convex_hull_image(mask).astype(bool)


In [None]:
# Make a smooth version of the running average with a Gaussian kernel
def gauss_kde_1d(x_out, x_data, data_value, sigma):
    """1D weighted average with a Gaussian kernel.

    Args:
        x_out (np.array): 1D positions at which the average is evaluated.
        x_data (np.array): 1D positions of the data points.
        data_value (np.array): values of the data points, same length as x_data.
        sigma (float): standard deviation of the kernel, in the units of x.

    Returns:
        (np.array, np.array): weighted average and total weight at each x_out.
    """
    distances = x_out - x_data[:, None]  # shape (len(x_data), len(x_out))
    weights = np.exp(-distances**2 / (2 * sigma**2))
    total_weights = np.sum(weights, axis=0)
    return (data_value[None, :] @ weights)[0] / total_weights, total_weights


x_calc = np.arange(xlim[0], xlim[1], 1)
pres_vs_start_kde, pres_vs_start_weight = gauss_kde_1d(
    x_calc, presy_xy[0], relative_starter_pos, sigma / 10
)

# Calculate bootstrap CI for ML position: resample starter cells with replacement,
# each drawn starter bringing along all its presynaptic cells
from tqdm import tqdm

seed = 22
n_boot = 1000

rand_gen = np.random.default_rng(seed)
n_samples = len(v1_starter_cells)
# Positions (in presynaptic_cells) of the presynaptic cells of each starter
presy_positions_by_starter = presynaptic_cells.groupby("starter_id").indices
no_cells = np.array([], dtype=int)

print(f'Resampling starter cells {n_boot} times')
shuffled = []
for iboot in range(n_boot):
    bootstrap_sample = v1_starter_cells.sample(n=n_samples, replace=True, random_state=rand_gen)
    # A starter drawn k times must contribute its presynaptic cells k times. Selecting
    # with `isin` would collapse the duplicates and break the bootstrap.
    positions = np.concatenate(
        [no_cells]
        + [presy_positions_by_starter.get(sid, no_cells) for sid in bootstrap_sample.index]
    )
    presy = presynaptic_cells.iloc[positions]
    presy_x = presy['flatmap_x'].values
    start_x = v1_starter_cells.loc[presy.starter_id, 'flatmap_x'].values
    valid = ~np.isnan(presy_x) & ~np.isnan(start_x)
    shuffled.append(np.vstack([presy_x[valid], start_x[valid]]))

# Smooth KDE for each bootstrap resample
print('Calculating KDE')
shuffled_kde = np.empty((n_boot, len(x_calc)))
for iboot, boot in tqdm(enumerate(shuffled), total=n_boot):
    shuffled_kde[iboot], _ = gauss_kde_1d(x_calc, boot[0], rel_pos(boot[1]), sigma / 10)

# 95% percentile confidence interval across bootstrap resamples
conf_int = np.percentile(shuffled_kde, [2.5, 97.5], axis=0)
median_position = np.nanmedian(relative_starter_pos)
mean_position = np.nanmean(relative_starter_pos)
print(f'Mean starter position is {mean_position:.1f}')

## Figure

In [None]:
# Main figure: starter positions (inset), presynaptic cells coloured by starter ML position,
# smoothed map, starter vs presynaptic ML position with bootstrap CI, and azimuth profile
import matplotlib.pyplot as plt

fontsize_dict = {"title": 8, "label": 7, "tick": 6, "legend": 6}
pad_dict = {"label": 1, "tick": 1, "legend": 5}  # NOTE(review): not used in this cell
aspect_ratio = np.diff(xlim)[0]/np.diff(ylim)[0]
cm = 1 / 2.54  # inches per cm, to size the figure in cm
fig = plt.figure(figsize=(8.8 * cm, 8.8 / aspect_ratio * cm), dpi=300)


# Define axes and plot atlas borders
ax_presyn = fig.add_axes([0.1, 0.45, 0.4, 0.4])
ax_starter = fig.add_axes([0.07 , 0.48, 0.18, 0.18]) # bottom left (inset overlapping ax_presyn)
ax_img = fig.add_axes([0.1, 0.05, 0.4, 0.4])
cax = fig.add_axes([0.1, 0.1, 0.15, 0.01]) # horizontal (sits inside ax_img)

ax_shuffle = fig.add_axes([0.6, 0.4, 0.3, 0.4])
ax_azi = fig.add_axes([0.6, 0.1, 0.3, 0.27])

for ax in [ax_starter, ax_presyn, ax_img]:
    atlas_utils.plot_flatmap(
        ax,
        hemisphere="right",
        area_colors={},
        alpha=0,
        ccf_streamlines_folder=None,
    )

# Plot starter cell positions, on top of the presynaptic cells drawn in grey
ax_starter.scatter(*presy_xy, color="#555555", marker='.', s=0.1, alpha=0.4)
v1_xy = v1_starter_cells[['flatmap_x', 'flatmap_y']].values.T
# Starters coloured by their own relative ML position (mm), same colour scale as the other panels
sc = ax_starter.scatter(*v1_xy, c=rel_pos(v1_xy[0])*scale, ec='w', 
                        cmap=cmap, s=4, marker='o', alpha=0.8, linewidths=0.08, clim=clims)
# Scale bar: 100 flatmap pixels = 1 mm
scale_bar = plt.Rectangle([780, 1200], 100, 10, color='k')
ax_starter.add_artist(scale_bar)
for spine in ax_starter.spines.values():
    spine.set_edgecolor('gray')
    spine.set_linewidth(1)

# Plot presynaptic cell positions, coloured by the ML position of their starter (mm)
ax_presyn.scatter(*presy_xy, c=relative_starter_pos *scale,  cmap=cmap, s=1, marker='.', alpha=0.3, clim=clims)
scale_bar = plt.Rectangle([120, 1300], 100, 10, color='k')  # 1 mm
ax_presyn.add_artist(scale_bar)

# Plot smoothed image, restricted to the convex hull of the data; pixels with a total
# Gaussian weight below 20 are faded out (alpha = weight / 20, clipped to [0, 1])
img = ax_img.imshow(ctx_img * ctx_mask, cmap=cmap, alpha=np.clip(total_weights/20,0,1)*ctx_mask,     origin="lower",
    extent=[xlim[0], xlim[1], ylim[0], ylim[1]],
    vmin=clims[0],
    vmax=clims[1],
)
scale_bar = plt.Rectangle([120, 1300], 100, 10, color='k')  # 1 mm
ax_img.add_artist(scale_bar)

# add colorbar
fig.colorbar(img, cax=cax, orientation='horizontal')
cax.tick_params(
        axis="both",
        which="major",
        labelsize=fontsize_dict['tick'],
    )
cax.set_xlabel('Starter ML position (mm)', fontsize=fontsize_dict['label'])

# Plot starter vs presynaptic ML position: individual cells, mean starter position
# (dashed), smoothed relation and its bootstrap 95% CI (2.5-97.5 percentiles)
plt.sca(ax_shuffle)
plt.scatter(rel_pos(presy_xy[0]) * scale, relative_starter_pos *scale, marker='o', 
            ec='w', linewidths=0.1, color='k', s=3, alpha=0.3)
plt.axhline(mean_position *scale, color='k', ls='--', lw=1.5)
plt.fill_between(rel_pos(x_calc)* scale, conf_int[0] *scale, conf_int[1] *scale, color='darkorchid', alpha=0.4, linewidth=0)
plt.plot(rel_pos(x_calc)* scale, pres_vs_start_kde *scale, color='darkorchid', lw=2)
# x tick labels are hidden here; the same x range is labelled on ax_azi below
plt.xticks([-5,0,5], fontsize=fontsize_dict['tick'], labels=[])
plt.xlim(-5, 5)
plt.yticks([-1,0,1], fontsize=fontsize_dict['tick'])
plt.ylabel('Starter ML position (mm)', fontsize=fontsize_dict['label'])

# Add azimuth: Gaussian-smoothed receptive field azimuth of presynaptic cells with a
# defined azimuth, as a function of their ML position
plt.sca(ax_azi)
valid = ~np.isnan(presy_azel[0])
azi_kde, w = gauss_kde_1d(x_calc, presy_xy[0, valid], presy_azel[0, valid], sigma/10)
ax_azi.plot(rel_pos(x_calc)* scale, azi_kde, color='k', lw=2)
plt.xlim(-5, 5)
plt.xticks([-5,0,5], fontsize=fontsize_dict['tick'])
plt.xlabel('Presynaptic ML position (mm)', fontsize=fontsize_dict['label'])
plt.ylim(0, 60)
plt.yticks([0,30,60], fontsize=fontsize_dict['tick'])
plt.ylabel('Receptive field\nazimuth (degrees)', fontsize=fontsize_dict['label'])

for ax in [ax_azi, ax_shuffle]:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Flatmap panels: no ticks or axis, both axes reversed
# (presumably to match the flipped ML orientation — confirm)
for ax in [ax_presyn, ax_starter, ax_img]:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_axis_off()
    ax.set_xlim(xlim[::-1])
    ax.set_ylim(ylim[::-1])

# Zoom the inset on the starter cells and turn its (grey) frame back on; ticks were removed above
ax_starter.set_ylim([1250, 900])
ax_starter.set_xlim([900, 300])
ax_starter.set_axis_on()


# Tick labels in mm for the smoothed map
# NOTE(review): ax_img's axis was turned off above, so these ticks are likely not rendered
xt = np.arange(100, 1200, 300)
ax_img.set_xticks(xt)
_=ax_img.set_xticklabels([int(x) for x in xt*scale], fontsize=fontsize_dict['tick'])

if save_fig:
    fig.savefig(save_path / f"{figname}.pdf", format="pdf")
    fig.savefig(save_path / f"{figname}.png", format="png")


In [None]:
# Display the folder where the figure files are saved
save_path

# Supplementary analysis

Thalamus: this analysis would be interesting, but the data only cover the very posterior edge of the thalamus, which is not enough to draw meaningful conclusions.

In [None]:
# Same analysis for the thalamus (all presynaptic cells in TH), on a coronal atlas slice
import brainglobe_atlasapi as bga

th_presy = presynaptic_cells.query("area_acronym_ancestor_rank1 == 'TH'")
print(f'{len(th_presy)} cells in thalamus')
th_presy_xy = th_presy[['ara_z', 'ara_y']].values.T / scale  # make them in 10s of um
th_starter_pos = v1_starter_cells.loc[th_presy.starter_id, 'flatmap_x'].values
# Remove cells with a NaN position or starter position, as for the cortex: a single NaN
# would otherwise make the whole smoothed image NaN
th_valid = ~np.isnan(th_presy_xy).any(axis=0) & ~np.isnan(th_starter_pos)
th_presy_xy = th_presy_xy[:, th_valid]
th_starter_pos = th_starter_pos[th_valid]
# Flip and centre the starter positions as for the cortex (`rel_pos`), so that colours
# match `clims`, which are limits on the relative ML position in mm
th_relative_starter_pos = rel_pos(th_starter_pos)

th_xlim = (650, 950)
th_ylim = (250, 450)
th_step = pixel_size / 100  # grid spacing, in 10 um units
th_xx, th_yy = np.meshgrid(
    np.arange(th_xlim[0], th_xlim[1], th_step),
    np.arange(th_ylim[0], th_ylim[1], th_step),
)
# sigma of 75 um, divided by 10 as coordinates are in 10 um units
th_img, th_total_weights = gaussian_smooth(
    th_xx, th_yy, th_relative_starter_pos * scale, th_presy_xy, sigma=75 / 10
)

# Make a clipping mask to hide values outside of the data range. Only cells inside the
# plotted window are marked, as negative indices would otherwise silently wrap around.
th_presy_index = np.floor(
    (th_presy_xy - np.array([th_xlim[0], th_ylim[0]])[:, None]) / th_step
).astype(int)
th_in_grid = (
    (th_presy_index[0] >= 0)
    & (th_presy_index[0] < th_xx.shape[1])
    & (th_presy_index[1] >= 0)
    & (th_presy_index[1] < th_xx.shape[0])
)
mask = np.zeros(th_xx.shape, dtype=bool)
mask[th_presy_index[1, th_in_grid], th_presy_index[0, th_in_grid]] = True
th_mask = morphology.convex_hull_image(mask).astype(bool)

# Get a coronal slice of atlas for plotting, at the median ara_x of the thalamic cells
# (x100, presumably mm -> 10 um voxel index, the same factor as 1 / scale)
atlas = bga.bg_atlas.BrainGlobeAtlas("allen_mouse_10um")
plane = np.nanmedian(th_presy.ara_x) * 100
cor_atlas = np.array(atlas.annotation[int(plane), :, :])


# plot thalamus: individual cells (left) and smoothed map (right)
fig = plt.figure(figsize=(10, 7))
ax_sc = plt.subplot(1, 2, 1, aspect='equal')
ax_th = plt.subplot(1, 2, 2, aspect='equal')
for ax in [ax_sc, ax_th]:
    atlas_utils.plot_borders_and_areas(
        ax,
        label_img=cor_atlas,
        areas_to_plot=[],
        color_kwargs=dict(colors='Gray', alpha=0.2),
        border_kwargs=dict(linewidths=1, colors='k', zorder=-1),
        label_atlas=atlas,
        get_descendants=True,
        plot_borders=True,
        label_filled_areas=False,
    )

im = ax_th.imshow(
    th_img * th_mask,
    cmap=cmap,
    alpha=np.clip(th_total_weights, 0, 1) * th_mask,
    origin="lower",
    extent=[th_xlim[0], th_xlim[1], th_ylim[0], th_ylim[1]],
    vmin=clims[0],
    vmax=clims[1],
)

ax_sc.scatter(
    *th_presy_xy,
    c=th_relative_starter_pos * scale,
    cmap=cmap, s=50, marker='o', ec='w', linewidths=1, alpha=0.8, clim=clims,
)
for ax in [ax_sc, ax_th]:
    ax.set_ylim(th_ylim[::-1])
    ax.set_axis_off()
    ax.set_xlim(th_xlim)