# Long range connections

In [None]:
# Automatically reload edited modules (e.g. brisc, iss_preprocess) before each
# cell is executed, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Load data

Load the cell/barcode dataframe, then add flatmap-projected coordinates and retinotopic (azimuth/elevation) values.

In [None]:
# Set path to lab folder
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.font_manager as fm

# Register Arial from the shared lab resources and make it the default font,
# including math text. Type 42 (TrueType) keeps text editable in exported PDFs.
arial_font_path = "/nemo/lab/znamenskiyp/home/shared/resources/fonts/arial.ttf"  # update path as needed
fm.fontManager.addfont(arial_font_path)
arial_prop = fm.FontProperties(fname=arial_font_path)
plt.rcParams.update(
    {
        "font.family": arial_prop.get_name(),
        "mathtext.default": "regular",  # make math mode also Arial
    }
)
matplotlib.rcParams["pdf.fonttype"] = 42

# This is the path to lab-znamenskiyp, the folder which contains home/shared/presentation
lab_folder_path = Path("Z:/")
save_fig = True
figname = "fig6_long_range"

In [None]:
from iss_preprocess.io import get_processed_path
from brisc.manuscript_analysis import load

# Dataset identifiers and output location
error_correction_ds_name = "BRAC8498.3e_error_corrected_barcodes_26"
processed_path = get_processed_path("becalia_rabies_barseq/BRAC8498.3e/analysis")
save_path = lab_folder_path / "home/shared/presentations/becalick_2025"

# Load the per-cell barcode table, restricted to the listed valid areas
load_kwargs = dict(
    areas_to_empty=["fiber tracts", "outside"],
    valid_areas=["Isocortex", "TH"],
    distance_threshold=150,
    error_correction_ds_name=error_correction_ds_name,
)
cell_barcode_df = load.load_cell_barcode_data(processed_path, **load_kwargs)

# Keep only cells with at least one barcode
has_barcode = cell_barcode_df["unique_barcodes"].map(len) > 0
barcoded_cells = cell_barcode_df.loc[has_barcode].copy()
print(f"{len(barcoded_cells)} cells with a unique barcode")

## Get azimuth and elevation

Also compute flatmap and top-view coordinates for each barcoded cell.

In [None]:
# project barcoded cells on flatmap
from brisc.manuscript_analysis.flatmap_projection import compute_flatmap_coors

# Flatmap coordinates are used for all cortical maps. The azimuth/elevation map
# is given on a top view, not flatmap, so top-view coordinates are added as well.
flat_coors = compute_flatmap_coors(barcoded_cells)
for dim, axis_name in enumerate(("x", "y", "z")):
    barcoded_cells[f"flatmap_{axis_name}"] = flat_coors[:, dim]

top_coors = compute_flatmap_coors(barcoded_cells, projection="top")
for dim, axis_name in enumerate(("x", "y", "z")):
    barcoded_cells[f"{axis_name}_top"] = top_coors[:, dim]

In [None]:
# Add azimuth and elevation
import numpy as np
from cricksaw_analysis import atlas_utils

ara_elevation, ara_azimuth = atlas_utils.get_ara_retinotopic_map()

# Cells without top-view coordinates keep NaN retinotopy
for retino_col in ("elevation", "azimuth"):
    barcoded_cells[retino_col] = np.nan
cells_with_top = barcoded_cells.dropna(subset=["x_top", "y_top"])
top_pix = cells_with_top[["x_top", "y_top"]].values.astype(int)
# The retinotopic maps are indexed as [y_top, x_top]
pix_rows, pix_cols = top_pix[:, 1], top_pix[:, 0]
barcoded_cells.loc[cells_with_top.index, "elevation"] = ara_elevation[pix_rows, pix_cols]
barcoded_cells.loc[cells_with_top.index, "azimuth"] = ara_azimuth[pix_rows, pix_cols]

In [None]:
# Find azimuth on flatmap
# The ARA retinotopic maps are defined on the top view. To show them on the dorsal
# flatmap: top-view pixel -> 3D surface voxel (top_proj.view_lookup) -> flatmap
# pixel (flat_proj.view_lookup).

from brisc.manuscript_analysis.flatmap_projection import get_projector
top_proj = get_projector("top")
flat_proj = get_projector("flatmap_dorsal")

# streamlines and lookup are only defined on left
topmidline = 1140//2
left_ara_az_top = ara_azimuth[:, :topmidline]
left_ara_el_top = ara_elevation[:, :topmidline]
# (row, col) of every left top-view pixel with a defined azimuth
voxels = np.vstack(np.where(~np.isnan(left_ara_az_top))).T

# Flat index of each pixel in the 2D top view (raveled as (col, row) = (x, y))
voxel_inds = np.ravel_multi_index((voxels[:, 1], voxels[:, 0]), top_proj.view_size)

# And the corresponding 3d surface voxel
# sort the other col since we reverse the search
order = top_proj.view_lookup[:, 0].argsort()
top_lookup = top_proj.view_lookup[order, :]
topind = top_lookup[:, 0].searchsorted(voxel_inds)
topsurf = top_lookup[topind, 1]

# For each top-view voxel, find the flatmap entry for its surface voxel
order = flat_proj.view_lookup[:, 1].argsort()
flat_lookup = flat_proj.view_lookup[order, :]
index = flat_lookup[:, 1].searchsorted(topsurf)
# searchsorted returns len(flat_lookup) when topsurf is beyond the table: those
# voxels have no flatmap position and are skipped entirely (they must not be
# written at flat pixel (0, 0), where they would corrupt the interpolation).
valid = index < flat_lookup.shape[0]
flat_x, flat_y = np.unravel_index(flat_lookup[index[valid], 0], flat_proj.view_size)
valid_voxels = voxels[valid]

# Populate flatmap with azimuth/elevation values; NaN where no voxel maps
flat_az = np.full(flat_proj.view_size[::-1], np.nan)
flat_el = np.full(flat_proj.view_size[::-1], np.nan)
# Explicit loop so that, when several voxels land on the same flat pixel, the
# last one wins deterministically (fancy-index assignment does not guarantee it)
for fx, fy, (tx, ty) in zip(flat_x, flat_y, valid_voxels):
    flat_az[fy, fx] = left_ara_az_top[tx, ty]
    flat_el[fy, fx] = left_ara_el_top[tx, ty]

# Interpolate missing points
from scipy.interpolate import griddata

# Geometry of the final two-hemisphere flatmap (independent of the map itself)
view_space_for_other_hemisphere = 110 # constant from ccf_streamlines
max_x = flat_proj.view_size[0] - view_space_for_other_hemisphere
final_shape = [flat_proj.view_size[1], 2 * max_x]

flat_retino = {}
for ret_axis, half_map in (("azimuth", flat_az), ("elevation", flat_el)):
    known = ~np.isnan(half_map)
    # Every pixel of the map, in row-major order
    all_pixels = np.indices(half_map.shape).reshape(2, -1).T
    # Bicubic interpolation from the defined pixels onto the full grid
    filled = griddata(
        np.argwhere(known), half_map[known], all_pixels, method="cubic"
    ).reshape(half_map.shape)

    # Left hemisphere as computed, right hemisphere as its mirror image
    left_half = filled[:, :max_x]
    flat_final = np.full(final_shape, np.nan)
    flat_final[:, :max_x] = left_half
    flat_final[:, max_x:] = left_half[:, ::-1]
    flat_retino[ret_axis] = flat_final
    
    

In [None]:
# Sanity check: retinotopy on the top view (left column) vs. the flatmap
# (right column); azimuth on the top row, elevation on the bottom row.
retino_panels = [
    # (subplot index, image, (vmin, vmax), drawn on top view?, x limits, y limits)
    (1, ara_azimuth, (0, 60), True, (0, 550), (1100, 500)),
    (2, flat_retino["azimuth"], (0, 60), False, (0, 1200), (1400, 500)),
    (3, ara_elevation, (-20, 20), True, (0, 550), (1100, 500)),
    (4, flat_retino["elevation"], (-20, 20), False, (0, 1200), (1400, 500)),
]
for panel_idx, panel_img, (low, high), on_top_view, panel_xlim, panel_ylim in retino_panels:
    plt.subplot(2, 2, panel_idx)
    plt.imshow(panel_img, vmin=low, vmax=high)
    projection_kwargs = dict(ara_projection="top") if on_top_view else {}
    atlas_utils.plot_flatmap(
        plt.gca(),
        hemisphere="both",
        area_colors={},
        alpha=0,
        ccf_streamlines_folder=None,
        **projection_kwargs,
    )
    plt.xlim(*panel_xlim)
    plt.ylim(*panel_ylim)

## Attribute starter to all cells

Cells whose barcodes match more than one starter cell are excluded.

In [None]:
# For each presynaptic cell, add a starter_id
# A presynaptic cell is assigned to a starter when they share at least one
# barcode. Cells sharing barcodes with several starters are ambiguous: they are
# listed in `pres_with_multi_starter` and keep a NaN starter_id.
starter_cells = barcoded_cells.query("is_starter == True")
presynaptic = barcoded_cells.query("is_starter == False")

starter_of_cell = {}
pres_with_multi_starter = []
for cell_id, pres_series in presynaptic.iterrows():
    starters = starter_cells[
        starter_cells["unique_barcodes"].apply(
            lambda barcodes: len(pres_series.unique_barcodes & barcodes) > 0
        )
    ]
    if len(starters) == 1:
        starter_of_cell[cell_id] = starters.index[0]
    elif len(starters) > 1:
        pres_with_multi_starter.append(cell_id)
    else:
        raise ValueError(
            f"Presynaptic cell {cell_id} shares no barcode with any starter cell"
        )

# Build the whole column at once: the cell is idempotent on re-run and the
# column exists even if no cell has a unique starter. Cells absent from the
# dict (starters and ambiguous cells) get NaN.
barcoded_cells["starter_id"] = barcoded_cells.index.map(starter_of_cell)

print(f"{len(pres_with_multi_starter)} cells with multiple starters")
print(
    f"{len(barcoded_cells) - barcoded_cells.starter_id.isna().sum()} cells with 1 starters"
)

In [None]:
# Reduce that to starter in V1
# Keep the starter cells located in VISp, then every cell whose unique starter
# is one of them.
starter_cells = barcoded_cells.query("is_starter == True")
v1_starter_cells = starter_cells.query("cortical_area == 'VISp'").copy()
print(f"Found {len(v1_starter_cells)}/{len(starter_cells)} V1 starter cells")

has_v1_starter = barcoded_cells["starter_id"].isin(v1_starter_cells.index)
cell_with_v1_starter = barcoded_cells.loc[has_v1_starter].copy()
print(f"... and {len(cell_with_v1_starter)} cells with starter in V1")

# Calculate average map

For each position in the flatmap, compute a Gaussian-weighted average of the starter ML position of nearby presynaptic cells.

In [None]:
# Function to calculate_weighted_average

from numba import njit
import numpy as np


@njit
def calculate_weighted_average(x: float, y: float, starter_pos, xy, sigma):
    """
    Gaussian-weighted average of ``starter_pos`` around the point (x, y).

    Sample ``i`` sits at ``(xy[0, i], xy[1, i])`` and carries the value
    ``starter_pos[i]``. Its weight is ``exp(-d_i**2 / (2 * sigma**2))``, where
    ``d_i`` is its distance to (x, y).

    Args:
        x (float): X coordinate of the evaluation point.
        y (float): Y coordinate of the evaluation point.
        starter_pos (np.array): value carried by each sample, shape (N,).
        xy (np.array): sample positions, shape (2, N) (row 0 = x, row 1 = y),
            in the same units as x, y and sigma.
        sigma (float): standard deviation of the Gaussian kernel.

    Returns:
        tuple: (weighted average, total weight). The average is NaN when the
        total weight is not > 0 (no samples, samples too far away for the
        kernel to be non-zero in floating point, or NaN inputs).
    """
    # Squared distances directly: the kernel only needs d**2
    sq_distances = (x - xy[0, :]) ** 2 + (y - xy[1, :]) ** 2
    weights = np.exp(-sq_distances / (2 * sigma**2))
    total_weight = np.sum(weights)
    if total_weight > 0:
        weighted_average = np.sum(starter_pos * weights) / total_weight
    else:
        weighted_average = np.nan
    return weighted_average, total_weight


@njit
def gaussian_smooth_2d(xx, yy, starter_pos, presy_xy, sigma):
    """
    Gaussian-weighted average of ``starter_pos`` at every point of a grid.

    Calls ``calculate_weighted_average`` once per grid point.

    Args:
        xx (np.array): x coordinates of the grid (e.g. from ``np.meshgrid``).
        yy (np.array): y coordinates of the grid, same shape as ``xx``.
        starter_pos (np.array): value carried by each sample, shape (N,).
        presy_xy (np.array): sample positions, shape (2, N) (row 0 = x,
            row 1 = y), in the same units as ``xx``/``yy``.
        sigma (float): standard deviation of the Gaussian kernel, in the same
            units as the coordinates.

    Returns:
        tuple: (img, weights), two float arrays with the shape of ``xx``.
        ``img`` holds the weighted average (NaN where the total weight is
        not > 0) and ``weights`` the total kernel weight at each grid point.
    """
    n_rows, n_cols = xx.shape
    img = np.empty(xx.shape, dtype=float)
    weights = np.empty(xx.shape, dtype=float)
    for row in range(n_rows):
        for col in range(n_cols):
            img[row, col], weights[row, col] = calculate_weighted_average(
                xx[row, col], yy[row, col], starter_pos, presy_xy, sigma
            )
    return img, weights

In [None]:
# CONSTANTS
# Flatmap window used for the cortical maps (flatmap pixel coordinates)
xlim = (150, 1050)
ylim = (810, 1330)
# Flatmap resolution: 10 um per pixel, converted so that outputs are in mm
scale = 10 / 1000
# Standard deviation of the smoothing kernel, in microns
sigma = 200
# Colour range (mm) and colormap for starter ML position
clims = (-1, 1)
cmap = "turbo_r"

In [None]:
# Calculate the smooth img for cortex
import skimage.morphology as morphology

presynaptic_cells = cell_with_v1_starter.query("is_starter == False")
# Flatmap position and retinotopy of each presynaptic cell, as (2, n_cells) arrays
presy_xy = presynaptic_cells[["flatmap_x", "flatmap_y"]].values.T
presy_azel = presynaptic_cells[["azimuth", "elevation"]].values.T
# Flatmap x of the starter of each presynaptic cell
starter_pos = v1_starter_cells.loc[presynaptic_cells.starter_id, "flatmap_x"].values

# Drop cells whose own position or whose starter's position is undefined
valid = ~(np.isnan(presy_xy).any(axis=0) | np.isnan(starter_pos))
presy_xy, presy_azel = presy_xy[:, valid], presy_azel[:, valid]
starter_pos = starter_pos[valid]

# We have registered the left hemisphere to the right, so flip the ML pos. Express it
# relative to center of starter as it is somewhat arbitrary on the flatmap
center_abs = np.nanmean(starter_pos)


def rel_pos(x):
    """Flatmap x position(s) relative to the mean starter position, sign-flipped.

    The sign is flipped because the left hemisphere was registered onto the
    right one. Uses the module-level ``center_abs``.
    """
    offset = x - center_abs
    return -offset


relative_starter_pos = rel_pos(starter_pos)

# `scale` (flatmap resolution in mm per pixel) comes from the constants cell and
# is not redefined here, so the two cannot drift apart.
pixel_size = 10 / 1000  # spacing of the smoothing grid, in mm
step = pixel_size / scale  # same spacing, in flatmap pixels
xx, yy = np.meshgrid(
    np.arange(xlim[0], xlim[1] + 1.1 * step, step),
    np.arange(ylim[0], ylim[1] + 1.1 * step, step),
)
# divide sigma by 10 as atlas is 10um /px
ctx_img, total_weights = gaussian_smooth_2d(
    xx, yy, relative_starter_pos * scale, presy_xy, sigma=sigma / 10
)

# Make a clipping mask to hide value outside of data range
mask = np.zeros(xx.shape, dtype=bool)
presy_index = (
    (presy_xy - np.array([xlim[0], ylim[0]])[:, None]) * scale / pixel_size
).astype(int)
# Only mark cells inside the grid: cells outside the plotted window would
# otherwise raise IndexError (too large) or silently wrap around (negative).
in_grid = (
    (presy_index[0] >= 0)
    & (presy_index[0] < mask.shape[1])
    & (presy_index[1] >= 0)
    & (presy_index[1] < mask.shape[0])
)
mask[presy_index[1, in_grid], presy_index[0, in_grid]] = True
ctx_mask = morphology.convex_hull_image(mask).astype(float)

In [None]:
import shapely
from shapely import MultiPoint
import cv2

# Replace the mask with a tighter concave hull of the presynaptic cells,
# computed in smoothing-grid pixel coordinates.
grid_origin = np.array([xlim[0], ylim[0]])[:, None]
cell_grid_xy = ((presy_xy - grid_origin) * scale / pixel_size).astype(int)
multi_point = MultiPoint(cell_grid_xy.T)
poly = shapely.concave_hull(multi_point, ratio=0.1)
# Hull vertices as an (n_vertices, 2) array of (x, y)
points = np.column_stack(poly.boundary.coords.xy)

hull_polygons = np.array([points]).astype(np.int32)
ctx_mask = cv2.fillPoly(np.zeros(xx.shape, dtype=float), hull_polygons, color=1)

In [None]:
# Make a smooth version of running average with gaussian KDE
def gauss_kde_1d(x_out, x_data, data_value, sigma):
    """1D Gaussian kernel-weighted moving average.

    Args:
        x_out: position(s) at which to evaluate the average, shape (K,) or scalar.
        x_data: positions of the data points, shape (N,).
        data_value: value of each data point, shape (N,). Any array-like
            (list, numpy array, pandas Series) is accepted.
        sigma: standard deviation of the Gaussian kernel, same units as x.

    Returns:
        tuple: (average, total_weights), both shaped like ``x_out``. The
        average is NaN where the total weight is not > 0 (no data within
        numerical reach of the kernel), without emitting a 0/0 warning.
    """
    x_out = np.asarray(x_out, dtype=float)
    x_data = np.asarray(x_data, dtype=float)
    data_value = np.asarray(data_value, dtype=float)
    distances = x_out - x_data[:, None]  # (N, K): data along rows
    weights = np.exp(-(distances**2) / (2 * sigma**2))
    total_weights = np.sum(weights, axis=0)
    weighted_sum = data_value @ weights
    average = np.full(np.shape(total_weights), np.nan)
    np.divide(weighted_sum, total_weights, out=average, where=total_weights > 0)
    return average, total_weights


x_calc = np.arange(xlim[0], xlim[1], 1)
pres_vs_start_kde, pres_vs_start_weight = gauss_kde_1d(
    x_calc, presy_xy[0], relative_starter_pos, sigma / 10
)

# Calculate bootstrap CI for ML position
from tqdm import tqdm

seed = 22
n_boot = 1000

rand_gen = np.random.default_rng(seed)
n_samples = len(v1_starter_cells)

# Per-presynaptic-cell data, computed once outside the bootstrap loop.
# Row of each presynaptic cell's starter within v1_starter_cells:
presy_starter_row = v1_starter_cells.index.get_indexer(presynaptic_cells["starter_id"])
assert (presy_starter_row >= 0).all(), "presynaptic cell whose starter is not in V1"
all_presy_x = presynaptic_cells["flatmap_x"].to_numpy()
all_start_x = v1_starter_cells["flatmap_x"].to_numpy()[presy_starter_row]
all_valid = ~np.isnan(all_presy_x) & ~np.isnan(all_start_x)

print(f"Resampling starter cells {n_boot} times")
shuffled = []
for iboot in range(n_boot):
    bootstrap_sample = v1_starter_cells.sample(
        n=n_samples, replace=True, random_state=rand_gen
    )
    # A starter drawn k times must contribute its presynaptic cells k times;
    # selecting with `isin` would count it once and turn the bootstrap into a
    # subsample of unique starters.
    n_draws = np.bincount(
        v1_starter_cells.index.get_indexer(bootstrap_sample.index), minlength=n_samples
    )
    repeats = np.where(all_valid, n_draws[presy_starter_row], 0)
    shuffled.append(
        np.vstack([np.repeat(all_presy_x, repeats), np.repeat(all_start_x, repeats)])
    )

# Smooth KDE for each bootstrap sample
print("Calculating KDE")
shuffled_kde = np.empty((n_boot, len(x_calc)))
for iboot, boot in tqdm(enumerate(shuffled), total=n_boot):
    shuffled_kde[iboot], _ = gauss_kde_1d(x_calc, boot[0], rel_pos(boot[1]), sigma / 10)

# 95% bootstrap interval of the smoothed starter position
conf_int = np.percentile(shuffled_kde, [2.5, 97.5], axis=0)
median_position = np.nanmedian(relative_starter_pos)
mean_position = np.nanmean(relative_starter_pos)
print(f"Mean starter position is {mean_position:.1f}")

## Figure

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Use the flag from the configuration cell so it actually controls saving
savefig = save_fig
rasterized_scatter = True  # rasterize dense scatter plots in the vector output
# NOTE(review): this overrides the `save_path` built from `lab_folder_path` in the
# data-loading cell; keep the two in sync when switching machines.
# save_path = Path("Z:/home/shared/presentations/becalick_2025")
save_path = Path("/nemo/lab/znamenskiyp/home/shared/presentations/becalick_2025")

fontsize_dict = {"title": 8, "label": 7, "tick": 6, "legend": 6}
pad_dict = {"label": 1, "tick": 1, "legend": 5}
cm = 1 / 2.54  # inches per cm
fig = plt.figure(figsize=(8.8 * cm, 20 * cm), dpi=150, frameon=True)

if True:
    # Temporary ax to mark the shape of the figure since jupyter crops
    ax_frame = fig.add_axes([0, 0, 1, 1])
    ax_frame.set_xticks([])
    ax_frame.set_yticks([])

# Define axes and plot atlas borders
# (rects are [left, bottom, width, height] in figure fraction)
ax_presyn = fig.add_axes([0.07, 0.73, 0.85, 0.3])
ax_img = fig.add_axes([0.07, 0.45, 0.85, 0.3])
ax_starter = fig.add_axes([0.07, 0.66, 0.4, 0.27])  # overlaps the left of ax_presyn
ax_retino = fig.add_axes([0.07, 0.33, 0.4, 0.27])  # overlaps the lower left of ax_img
retcax = fig.add_axes([0.49, 0.4, 0.017, 0.08])  # azimuth colorbar
cax = fig.add_axes([0.75, 0.73, 0.15, 0.01], facecolor='w')  # starter ML colorbar
ax_shuffle = fig.add_axes([0.18, 0.15, 0.8, 0.20])
ax_azi = fig.add_axes([0.18, 0.05, 0.8, 0.09])

# Flatmap area borders: right hemisphere only, except for the retinotopy inset
flatmap_axes = [
    (ax_starter, "right"),
    (ax_presyn, "right"),
    (ax_img, "right"),
    (ax_retino, "both"),
]
for flat_ax, hemi in flatmap_axes:
    atlas_utils.plot_flatmap(
        flat_ax,
        hemisphere=hemi,
        area_colors={},
        alpha=0,
        ccf_streamlines_folder=None,
    )
# Plot starter cells
# Grey background: all presynaptic cells. Foreground: V1 starter cells coloured by
# their ML position relative to the starter centre, in mm (rel_pos * scale).
if True:
    ax_starter.scatter(
        *presy_xy,
        color="#555555",
        marker=".",
        s=0.1,
        alpha=0.4,
        rasterized=rasterized_scatter,
    )
    v1_xy = v1_starter_cells[["flatmap_x", "flatmap_y"]].values.T
    sc = ax_starter.scatter(
        *v1_xy,
        c=rel_pos(v1_xy[0]) * scale,
        ec="w",
        cmap=cmap,
        s=4,
        marker="o",
        alpha=0.8,
        linewidths=0.08,
        clim=clims,
        rasterized=rasterized_scatter,
    )
    # 100 flatmap px at 10 um/px = 1 mm scale bar
    scale_bar = plt.Rectangle([680, 1150], 100, 10, color="k")
    ax_starter.add_artist(scale_bar)
    # Grey frame around the inset
    for spine in ax_starter.spines.values():
        spine.set_edgecolor("gray")
        spine.set_linewidth(1)

# Add retino plot
# Azimuth of the ARA retinotopic map, projected on the flatmap (both hemispheres)
if True:
    im = ax_retino.imshow(flat_retino['azimuth'], cmap="rainbow_r", vmin=-10, vmax=70)
    # 1 mm scale bar (100 px at 10 um/px)
    scale_bar = plt.Rectangle([350, 1250], 100, 10, color="k")
    ax_retino.add_artist(scale_bar)
    cb = plt.colorbar(im, cax=retcax, orientation='vertical')
    cb.set_label('Azimuth\n(degrees)', fontsize=fontsize_dict['label'])
    retcax.set_yticks([0,50], labels=[0,50], fontsize=fontsize_dict['tick'])
    # Grey frame around the inset
    for spine in ax_retino.spines.values():
        spine.set_edgecolor("gray")
        spine.set_linewidth(1)
    if False:
        # not particularly useful. Remove for now
        # (would outline the cortical data mask ctx_mask on the retinotopy inset)
        ax_retino.contour(ctx_mask, levels=[0.5], colors='k', linestyles='--', linewidths=0.5,
                             extent=[max_x-xlim[0], max_x-xlim[1], ylim[0], ylim[1]])

# Plot presynaptic cells
# Each presynaptic cell is coloured by the relative ML position (mm) of its starter
if True:
    ax_presyn.scatter(
        *presy_xy,
        c=relative_starter_pos * scale,
        cmap=cmap,
        s=1,
        marker=".",
        alpha=0.3,
        clim=clims,
        rasterized=rasterized_scatter,
    )
    # 1 mm scale bar (100 px at 10 um/px)
    scale_bar = plt.Rectangle([160, 1280], 100, 10, color="k")
    ax_presyn.add_artist(scale_bar)

# Plot smoothed image
# Gaussian-smoothed starter ML position (mm) of the presynaptic cells. Opacity grows
# with the total kernel weight (saturating at 50; NOTE(review): threshold looks
# empirical) and is zero outside ctx_mask.
# NOTE(review): the grid spans xlim[0]..xlim[1]+1 (see np.arange above) while
# `extent` spans xlim[0]..xlim[1], so the image is stretched by ~1 px overall.
if True:
    img = ax_img.imshow(
        ctx_img * ctx_mask,
        cmap=cmap,
        alpha=np.clip(total_weights / 50, 0, 1) * ctx_mask,
        origin="lower",
        extent=[xlim[0], xlim[1], ylim[0], ylim[1]],
        vmin=clims[0],
        vmax=clims[1],
    )
    # 1 mm scale bar (100 px at 10 um/px)
    scale_bar = plt.Rectangle([160, 1280], 100, 10, color="k")
    ax_img.add_artist(scale_bar)
    # add colorbar
    fig.colorbar(img, cax=cax, orientation="horizontal")
    cax.tick_params(
        axis="both",
        which="major",
        labelsize=fontsize_dict["tick"],
    )
    # white background for colorbars
    rec = plt.Rectangle((80, 750), 400, 200, color='w')
    art = ax_img.add_artist(rec)
    art.set_zorder(100)
    # white background for inset retino
    rec = plt.Rectangle((400, 1200), 600, 500, color='w')
    art = ax_img.add_artist(rec)
    art.set_zorder(100)
    cax.set_xlabel("Starter ML position (mm)", fontsize=fontsize_dict["label"])

# Plot shuffle
# Bootstrap panel: each dot is a presynaptic cell (x: its own relative ML position,
# y: its starter's relative ML position, both in mm). Purple line: smoothed
# average; shaded band: 2.5-97.5 percentile bootstrap interval; dashed line:
# mean starter position.
if True:
    plt.sca(ax_shuffle)
    plt.scatter(
        rel_pos(presy_xy[0]) * scale,
        relative_starter_pos * scale,
        marker="o",
        ec="w",
        linewidths=0.1,
        color="k",
        s=3,
        alpha=0.3,
        rasterized=rasterized_scatter,
        clip_on=False,
    )
    plt.axhline(mean_position * scale, color="k", ls="--", lw=1.5)
    plt.fill_between(
        rel_pos(x_calc) * scale,
        conf_int[0] * scale,
        conf_int[1] * scale,
        color="darkorchid",
        alpha=0.4,
        linewidth=0,
    )
    plt.plot(
        rel_pos(x_calc) * scale, pres_vs_start_kde * scale, color="darkorchid", lw=2
    )
    # x tick labels are hidden here; the shared x axis is labelled on ax_azi below
    plt.xticks([-4, 0, 4], fontsize=fontsize_dict["tick"], labels=[])
    plt.xlim(-4.5, 4.5)
    plt.ylim([-1,1])
    plt.yticks([-1, 0, 1], fontsize=fontsize_dict["tick"])
    plt.ylabel("Starter ML position (mm)", fontsize=fontsize_dict["label"])

# Add azimuth
# Smoothed receptive-field azimuth of presynaptic cells vs. their ML position
# (cells with undefined azimuth are excluded)
if True:
    plt.sca(ax_azi)
    valid = ~np.isnan(presy_azel[0])
    azi_kde, w = gauss_kde_1d(
        x_calc, presy_xy[0, valid], presy_azel[0, valid], sigma / 10
    )
    ax_azi.plot(rel_pos(x_calc) * scale, azi_kde, color="k", lw=2)
    plt.xlim(-4.5, 4.5)
    plt.xticks([-4, 0, 4], fontsize=fontsize_dict["tick"])
    plt.xlabel("Presynaptic ML position (mm)", fontsize=fontsize_dict["label"])
    plt.ylim(0, 60)
    plt.yticks([0, 30, 60], fontsize=fontsize_dict["tick"])
    plt.ylabel("Receptive field\nazimuth (degrees)", fontsize=fontsize_dict["label"])

# clean up axes
if True:
    # Line plots: keep only left/bottom spines (no bottom spine on the upper panel)
    for ax in [ax_azi, ax_shuffle]:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
    ax_shuffle.spines['bottom'].set_visible(False)
    ax_shuffle.set_xticks([])
    # Flatmap panels: no ticks, both axes reversed (x decreasing to the right,
    # y increasing downwards) over the analysis window
    for ax in [ax_presyn, ax_starter, ax_img, ax_retino]:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()
        ax.set_xlim(xlim[::-1])
        ax.set_ylim(ylim[::-1])
    # Insets get their own zoom; axis switched back on so the grey spines show
    ax_starter.set_xlim([800, 400])
    ax_starter.set_ylim([1200, 950])
    ax_starter.set_axis_on()
    # ax_retino shows both hemispheres and keeps a non-reversed x axis
    ax_retino.set_xlim([300, 900])
    ax_retino.set_ylim([1300, 850])
    ax_retino.set_axis_on()
if savefig:
    fig.savefig(save_path / f"{figname}.pdf", format="pdf", dpi=600)
    fig.savefig(save_path / f"{figname}.png", format="png")
    print(f"Figure saved in {save_path / figname}.pdf")

# Supplementary analysis

Thalamus: this analysis would be informative, but the dataset only covers the very posterior edge of the thalamus, which is not enough for a meaningful analysis.

In [None]:
# Same for LGd
import brainglobe_atlasapi as bga

th_presy = presynaptic_cells.query("area_acronym_ancestor_rank1 == 'TH'")
print(f"{len(th_presy)} cells in thalamus")
th_presy_xy = th_presy[["ara_z", "ara_y"]].values.T / scale  # make them in 10s of um
# Use starter positions relative to the starter centre, as for the cortical map:
# absolute flatmap x (several mm once multiplied by `scale`) would saturate the
# colour range `clims` for every cell.
th_starter_pos = rel_pos(
    v1_starter_cells.loc[th_presy.starter_id, "flatmap_x"].values
)
# Drop cells with an undefined position or starter position: a single NaN would
# make the total kernel weight NaN and blank the whole smoothed image.
th_valid = ~(np.isnan(th_presy_xy).any(axis=0) | np.isnan(th_starter_pos))
th_presy_xy = th_presy_xy[:, th_valid]
th_starter_pos = th_starter_pos[th_valid]

th_xlim = (650, 950)
th_ylim = (250, 450)
th_xx, th_yy = np.meshgrid(
    np.arange(th_xlim[0], th_xlim[1], step),
    np.arange(th_ylim[0], th_ylim[1], step),
)
# 75 um Gaussian kernel, expressed in 10 um units
th_img, th_total_weights = gaussian_smooth_2d(
    th_xx,
    th_yy,
    th_starter_pos * scale,
    th_presy_xy,
    sigma=75 / 10,
)


# Make a clipping mask to hide value outside of data range
mask = np.zeros(th_xx.shape, dtype=bool)
th_presy_index = (
    (th_presy_xy - np.array([th_xlim[0], th_ylim[0]])[:, None]) / step
).astype(int)
# Only mark cells inside the grid: cells outside the window would otherwise
# raise IndexError (too large) or silently wrap around (negative).
th_in_grid = (
    (th_presy_index[0] >= 0)
    & (th_presy_index[0] < mask.shape[1])
    & (th_presy_index[1] >= 0)
    & (th_presy_index[1] < mask.shape[0])
)
mask[th_presy_index[1, th_in_grid], th_presy_index[0, th_in_grid]] = True
th_mask = morphology.convex_hull_image(mask).astype(bool)

# Get a coronal slice of atlas for plotting
atlas = bga.bg_atlas.BrainGlobeAtlas("allen_mouse_10um")
# NOTE(review): assumes ara_x is in mm (x100 -> 10 um slice index); confirm.
plane = np.nanmedian(th_presy.ara_x) * 100
cor_atlas = np.array(atlas.annotation[int(plane), :, :])


def _draw_coronal_borders(target_ax):
    """Draw the area borders of the coronal atlas slice ``cor_atlas`` on ``target_ax``."""
    atlas_utils.plot_borders_and_areas(
        target_ax,
        label_img=cor_atlas,
        areas_to_plot=[],
        color_kwargs=dict(colors="Gray", alpha=0.2),
        border_kwargs=dict(linewidths=1, colors="k", zorder=-1),
        label_atlas=atlas,
        get_descendants=True,
        plot_borders=True,
        label_filled_areas=False,
    )


# plot thalamus: individual cells (left) and smoothed map (right)
fig = plt.figure(figsize=(10, 7))
ax_sc = plt.subplot(1, 2, 1, aspect="equal")
ax_th = plt.subplot(1, 2, 2, aspect="equal")
_draw_coronal_borders(ax_sc)
_draw_coronal_borders(ax_th)

# Smoothed starter position, faded where the total kernel weight is below 1 and
# hidden outside the convex hull of the thalamic cells
im = ax_th.imshow(
    th_img * th_mask,
    cmap=cmap,
    alpha=np.clip(th_total_weights / 1, 0, 1) * th_mask,
    origin="lower",
    extent=[th_xlim[0], th_xlim[1], th_ylim[0], th_ylim[1]],
    vmin=clims[0],
    vmax=clims[1],
)

ax_sc.scatter(
    *th_presy_xy,
    c=th_starter_pos * scale,
    cmap=cmap,
    s=50,
    marker="o",
    ec="w",
    linewidths=1,
    alpha=0.8,
    clim=clims,
    rasterized=rasterized_scatter,
)
for ax in (ax_sc, ax_th):
    ax.set_ylim(th_ylim[::-1])
    ax.set_axis_off()
    ax.set_xlim(th_xlim)
# plt.colorbar(im)