In [1]:
import argparse
import logging
import os
import sys
import warnings
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from pathlib import Path
from typing import Tuple

import scanpy as sc
import yaml
from anndata import AnnData
from spatialdata import read_zarr

sys.path.insert(1, "/dss/dsshome1/0C/ra98gaq/Git/cellseg-benchmark")
from cellseg_benchmark.adata_utils import (
    filter_genes,
    filter_low_quality_cells,
    filter_spatial_outlier_cells,
    merge_adatas,
    normalize_counts,
    pca_umap_single,
)



In [2]:
# reload: pick up edits to cellseg_benchmark.adata_utils without restarting the kernel.
# NOTE(review): names imported with `from ... import` in the first cell (merge_adatas,
# filter_genes, ...) still point at the pre-reload function objects; only the names
# re-imported below are bound to the reloaded module.
from importlib import reload

import cellseg_benchmark.adata_utils as adu

reload(adu)
# Imported after the reload so the freshly reloaded function object is bound.
from cellseg_benchmark.adata_utils import (  # noqa: E402
    integration_harmony,
)

In [3]:
def _load_one(
    sample_dir: Path, seg_method: str, logger: logging.Logger | None = None
) -> Tuple[str, AnnData | None]:
    """Load one segmentation method's AnnData table from a sample's master sdata.

    Only the ``tables`` element of ``<sample_dir>/sdata_z3.zarr`` is read.

    Args:
        sample_dir: Sample directory containing ``sdata_z3.zarr``.
        seg_method: Segmentation method; its table is stored under ``adata_<seg_method>``.
        logger: Logger for the "missing table" warning. May be ``None`` (e.g. when
            called from worker processes), in which case nothing is logged.

    Returns:
        ``(sample_dir.name, adata)``, where ``adata`` is ``None`` if the sdata has
        no table for ``seg_method``.
    """
    sdata_path = sample_dir / "sdata_z3.zarr"
    table_key = f"adata_{seg_method}"
    sdata = read_zarr(sdata_path, selection=("tables",))
    if table_key not in sdata.tables.keys():
        if logger:
            # Report the key that was actually looked up and which sample lacks it.
            logger.warning(
                "Skipping %s for sample %s: no table %r in %s.",
                seg_method,
                sample_dir.name,
                table_key,
                sdata_path,
            )
        return sample_dir.name, None
    return sample_dir.name, sdata[table_key]


# Ignore UserWarnings whose message matches this regex (matched from the start of
# the message, hence the leading ".*").
warnings.filterwarnings(
    action="ignore",
    message=".*The table is annotating*",
    category=UserWarning,
)
# -1: let scanpy use all available cores for parallel steps.
sc.settings.n_jobs = -1

In [4]:
# Logger setup. Idempotent: loggers are process-wide singletons, so re-running this
# cell must not attach a second handler (otherwise every message is printed twice).
logger = logging.getLogger("integrate_adatas")
logger.setLevel(logging.INFO)
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s"))
    logger.addHandler(handler)

In [5]:
# CLI args. When run as a notebook, sys.argv is overridden with the run to reproduce.
# Other runs used with this notebook:
#   ["notebook", "SynergyLung", "Cellpose_1_Merlin"]
#   ["notebook", "foxf2", "Cellpose_1_nuclei_model"]
#   ["notebook", "foxf2", "Proseg_Cellpose_1_nuclei_model"]
#   ["notebook", "aging", "Cellpose_1_nuclei_model"]
sys.argv = ["notebook", "htra1", "Cellpose_1_nuclei_model"]

parser = argparse.ArgumentParser(
    description="Integrate adatas from a selected segmentation method."
)
parser.add_argument("cohort", help="Cohort name, e.g., 'foxf2'")
parser.add_argument(
    "seg_method", help="Segmentation method, e.g., 'Cellpose_1_nuclei_model'"
)
args = parser.parse_args()
args

Namespace(cohort='htra1', seg_method='Cellpose_1_nuclei_model')

In [6]:
# Root of the benchmark data. Defaults to the shared DSS location; set the
# CELLSEG_BENCHMARK_DIR environment variable to run against another copy.
base_path = Path(
    os.getenv(
        "CELLSEG_BENCHMARK_DIR", "/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark"
    )
)
samples_path = base_path / "samples"
# Output directory for this cohort / segmentation method (created if missing).
save_path = base_path / "analysis" / args.cohort / args.seg_method
save_path.mkdir(parents=True, exist_ok=True)

In [7]:
# Sample-level metadata (keyed by sample name) and the per-cohort exclusion lists.
# Files are opened with `with` so the handles are closed after parsing.
misc_path = base_path / "misc"
with open(misc_path / "sample_metadata.yaml") as fh:
    sample_metadata_file = yaml.safe_load(fh)
with open(misc_path / "samples_excluded.yaml") as fh:
    excluded = yaml.safe_load(fh)
sample_metadata_file

{'foxf2_s1_r0': {'cohort': 'foxf2',
  'slide': 1,
  'region': 0,
  'genotype': 'ECKO',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_0-ECKO000'},
 'foxf2_s1_r1': {'cohort': 'foxf2',
  'slide': 1,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_1-WT000'},
 'foxf2_s2_r1': {'cohort': 'foxf2',
  'slide': 2,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240322',
  'animal_id': '536',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240322_Foxf2-Slide02-cp-WT-PCKO/region_1-WT536'},
 'foxf2_s2_r2': {'cohort': 'foxf2',
  'slide': 2,
  'region': 2,
  'genotype': 'PCKO',
  '

In [8]:
# Samples excluded for this cohort. An empty YAML file loads as None, and a cohort key
# without entries maps to None; both are treated as "nothing excluded".
excluded_samples = set((excluded or {}).get(args.cohort) or [])
excluded_samples

set()

In [13]:
# Samples of the selected cohort (YAML order), minus the cohort's excluded samples.
yaml_samples = [
    name
    for name, meta in sample_metadata_file.items()
    if meta.get("cohort") == args.cohort and name not in excluded_samples
]

# Samples from other cohorts appended to a cohort as additional controls.
# htra1: 6/18m WT samples from the aging cohort.
EXTRA_CONTROL_SAMPLES = {
    "htra1": [
        "aging_s1_r0",
        "aging_s5_r1",
        "aging_s6_r0",
        "aging_s7_r2",
        "aging_s8_r2",
        "aging_s11_r0",
    ],
}
# Extra controls honour the exclusion list of any cohort that lists them, and are not
# added a second time if they already belong to the selection.
excluded_in_any_cohort = {
    sample for samples in (excluded or {}).values() for sample in (samples or [])
}
yaml_samples += [
    name
    for name in EXTRA_CONTROL_SAMPLES.get(args.cohort, [])
    if name not in excluded_in_any_cohort and name not in yaml_samples
]

In [14]:
# Samples selected for this run, in order (cohort samples first, then any appended extras).
yaml_samples

['htra1_s1_r1',
 'htra1_s3_r0',
 'htra1_s3_r1',
 'htra1_s4_r1',
 'htra1_s4_r2',
 'htra1_s6_r1',
 'htra1_s6_r2',
 'htra1_s7_r1',
 'aging_s1_r0',
 'aging_s5_r1',
 'aging_s6_r0',
 'aging_s7_r2',
 'aging_s8_r2',
 'aging_s11_r0']

In [15]:
logger.info("Loading data...")
# Keep only sample directories that contain a master sdata; report the others.
loads = []
for sample_name in yaml_samples:
    sample_dir = samples_path / sample_name
    if (sample_dir / "sdata_z3.zarr").exists():
        loads.append(sample_dir)
    else:
        logger.error("master sdata in %s not found.", sample_dir)

2025-11-21 13:28:32,987 [INFO]: Loading data...


In [16]:
# Sample directories that contain sdata_z3.zarr and will be loaded.
loads

[PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/htra1_s1_r1'),
 PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/htra1_s3_r0'),
 PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/htra1_s3_r1'),
 PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/htra1_s4_r1'),
 PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/htra1_s4_r2'),
 PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/htra1_s6_r1'),
 PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/htra1_s6_r2'),
 PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/htra1_s7_r1'),
 PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/aging_s1_r0'),
 PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/aging_s5_r1'),
 PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/aging_s6_r0'),
 PosixPath('/dss/dssf

In [17]:
#loads = loads[:-5]

In [18]:
%%time
max_workers = int(os.getenv("SLURM_CPUS_PER_TASK", 1))
loader = partial(_load_one, seg_method=args.seg_method, logger=None)
with ProcessPoolExecutor(max_workers=max_workers) as ex:
    results = list(ex.map(loader, loads))

CPU times: user 810 ms, sys: 2.88 s, total: 3.69 s
Wall time: 3min 12s


In [19]:
#%%time
#loader = partial(_load_one, seg_method=args.seg_method, logger=logger)
#results = [loader(p) for p in loads[:-5]]

In [20]:
# One (sample name, AnnData or None) pair per entry of `loads`, in the same order;
# None means the sample's sdata has no table for the selected segmentation method.
results

[('htra1_s1_r1',
  AnnData object with n_obs × n_vars = 49265 × 500
      obs: 'region', 'slide', 'spt_region', 'cell_type_incl_low_quality_revised', 'cell_type_mmc_incl_low_quality_clusters', 'cell_type_mmc_incl_low_quality', 'cell_type_incl_mixed_revised', 'cell_type_mmc_incl_mixed_clusters', 'cell_type_mmc_incl_mixed', 'cell_type_revised', 'cell_type_mmc_raw_clusters', 'cell_type_mmc_raw', 'cell_id', 'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized', 'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation', 'genotype', 'age_months', 'run_date', 'organism', 'cohort', 'sample', 'animal_id', 'condition'
      uns: 'sopa_attrs', 'spatialdata_attrs'
      obsm: 'intensities', 'spatial', 'spatial_microns', 'spatial_pixel'),
 ('htra1_s3_r0',
  AnnData object with n_obs × n_vars = 50429 × 500
      obs: 'region', 'slide', 'spt_region', 'cell_type_incl_low_quality_revised', 'cell_type_mmc_incl_low_quality_clusters', 'cell_type_mmc_incl_low_quality', 'cell_type_

In [21]:
# Keep the input (YAML) order and drop samples whose sdata has no table for this
# segmentation method. The loader ran with logger=None, so report the dropped
# samples here rather than discarding them silently.
skipped_samples = [name for name, sample_adata in results if sample_adata is None]
if skipped_samples:
    logger.warning(
        "Skipping %d sample(s) without table 'adata_%s': %s",
        len(skipped_samples),
        args.seg_method,
        ", ".join(skipped_samples),
    )
adata_list = [
    (name, sample_adata) for name, sample_adata in results if sample_adata is not None
]

In [25]:
# Samples that will be merged (those with a table for the selected method).
adata_list

[('htra1_s1_r1',
  AnnData object with n_obs × n_vars = 49265 × 500
      obs: 'region', 'slide', 'spt_region', 'cell_type_incl_low_quality_revised', 'cell_type_mmc_incl_low_quality_clusters', 'cell_type_mmc_incl_low_quality', 'cell_type_incl_mixed_revised', 'cell_type_mmc_incl_mixed_clusters', 'cell_type_mmc_incl_mixed', 'cell_type_revised', 'cell_type_mmc_raw_clusters', 'cell_type_mmc_raw', 'cell_id', 'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized', 'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation', 'genotype', 'age_months', 'run_date', 'organism', 'cohort', 'sample', 'animal_id', 'condition'
      uns: 'sopa_attrs', 'spatialdata_attrs'
      obsm: 'intensities', 'spatial', 'spatial_microns', 'spatial_pixel'),
 ('htra1_s3_r0',
  AnnData object with n_obs × n_vars = 50429 × 500
      obs: 'region', 'slide', 'spt_region', 'cell_type_incl_low_quality_revised', 'cell_type_mmc_incl_low_quality_clusters', 'cell_type_mmc_incl_low_quality', 'cell_type_

In [23]:
import pandas as pd

In [24]:
# Temporary fix (originally only for aging_s11_r0): for every sample whose name contains
# "aging", overwrite its sample-level .obs columns with the values from
# sample_metadata.yaml (plus "sample" and "condition"), stored as constant string categoricals.
for idx, (sample_name, sample_adata) in enumerate(adata_list):
    if "aging" not in sample_name:
        continue
    meta = sample_metadata_file[sample_name]
    condition = f"{meta['genotype']}_{meta['age_months']}"
    obs_values = dict(meta, sample=sample_name, condition=condition)
    for key, value in obs_values.items():
        sample_adata.obs[key] = pd.Categorical([str(value)] * len(sample_adata))
    adata_list[idx] = (sample_name, sample_adata)

In [26]:
#adata_tmp = next((a for n, a in adata_list if n == "aging_s11_r0"), None)

In [27]:
# Merge the per-sample AnnData objects into one (QC plots written to <save_path>/plots),
# then drop the per-sample list to free memory.
adata = merge_adatas(
    adata_list,
    save_path=save_path / "plots",
    plot_qc_stats=True,
    seg_method=args.seg_method,
    logger=logger,
)
del adata_list

2025-11-21 13:44:40,123 [INFO]: Merging adatas of Cellpose_1_nuclei_model
100%|██████████| 14/14 [00:00<00:00, 69.37it/s]
  utils.warn_names_duplicates("obs")
2025-11-21 13:44:41,489 [INFO]: Cellpose_1_nuclei_model: #cells=668675, #samples=14
2025-11-21 13:44:41,495 [INFO]: Plotting QC results
  fig.tight_layout()
  fig.tight_layout()
  fig.tight_layout()


In [28]:
# Cells per sample (NaN counted as its own entry).
adata.obs.loc[:, "sample"].value_counts(dropna=False)

sample
aging_s1_r0     56230
aging_s8_r2     55444
htra1_s4_r2     54216
htra1_s4_r1     52608
htra1_s6_r1     51711
htra1_s3_r0     50429
aging_s6_r0     49699
htra1_s1_r1     49265
aging_s7_r2     45655
htra1_s7_r1     44114
htra1_s3_r1     42369
htra1_s6_r2     40957
aging_s5_r1     40329
aging_s11_r0    35649
Name: count, dtype: int64

In [29]:
# Cells per condition (genotype_age), NaN counted as its own entry.
adata.obs.loc[:, "condition"].value_counts(dropna=False)

condition
TG_6     147151
KO_6     145850
WT_6     142214
WT_18    140792
KO_18     92668
Name: count, dtype: int64

In [30]:
# Cells per cohort, NaN counted as its own entry.
adata.obs.loc[:, "cohort"].value_counts(dropna=False)

cohort
htra1    385669
aging    283006
Name: count, dtype: int64

In [31]:
# Cells per age (months), NaN counted as its own entry.
adata.obs.loc[:, "age_months"].value_counts(dropna=False)

age_months
6     435215
18    233460
Name: count, dtype: int64

In [32]:
import pandas as pd

In [33]:
# Sanity check: each sample should map to exactly one age.
pd.crosstab(index=adata.obs["age_months"], columns=adata.obs["sample"], dropna=False)

sample,aging_s11_r0,aging_s1_r0,aging_s5_r1,aging_s6_r0,aging_s7_r2,aging_s8_r2,htra1_s1_r1,htra1_s3_r0,htra1_s3_r1,htra1_s4_r1,htra1_s4_r2,htra1_s6_r1,htra1_s6_r2,htra1_s7_r1
age_months,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
18,35649,0,0,49699,0,55444,0,0,0,0,0,51711,40957,0
6,0,56230,40329,0,45655,0,49265,50429,42369,52608,54216,0,0,44114


In [34]:
# workaround to fix adata.obs formatting ###############
# adata.obs["sample"] = adata.obs["sample"].str.replace(
#    rf"^{args.cohort}_(\d+)_(\d+)$", rf"{args.cohort}_s\1_r\2", regex=True
# )
# adata.obs.drop(columns=["spt_region"], inplace=True, errors="ignore")
# adata.obs["condition"] = (
#    adata.obs["genotype"].astype(str) + "_" + adata.obs["age_months"].astype(str)
# )

In [35]:
# .obs columns after merging.
adata.obs.keys()

Index(['region', 'slide', 'cell_type_incl_low_quality_revised',
       'cell_type_mmc_incl_low_quality_clusters',
       'cell_type_mmc_incl_low_quality', 'cell_type_incl_mixed_revised',
       'cell_type_mmc_incl_mixed_clusters', 'cell_type_mmc_incl_mixed',
       'cell_type_revised', 'cell_type_mmc_raw_clusters', 'cell_type_mmc_raw',
       'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized',
       'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation',
       'genotype', 'age_months', 'run_date', 'organism', 'cohort', 'sample',
       'animal_id', 'condition', 'n_counts', 'n_genes', 'path'],
      dtype='object')

In [36]:
# Use micron coordinates as .obsm["spatial"] when available. (The previous
# `.get("spatial_microns", adata.obsm["spatial"])` evaluated its default eagerly and
# raised KeyError whenever "spatial" was missing, even if "spatial_microns" existed.)
if "spatial_microns" in adata.obsm:
    adata.obsm["spatial"] = adata.obsm["spatial_microns"]
adata = filter_spatial_outlier_cells(
    adata,
    data_dir=str(base_path),
    sample_metadata_file=sample_metadata_file,
    save_path=save_path / "plots",
    logger=logger,
)

2025-11-21 13:45:49,617 [INFO]: # total cells before filtering spatial outliers: 668675
2025-11-21 13:45:49,618 [INFO]: # spatial_outlier:       20942
2025-11-21 13:45:51,049 [INFO]: # total cells after filtering spatial outliers:  647733


In [37]:
#args.cohort == "SynergyLung"

In [38]:
# `min_counts` override for filter_low_quality_cells; None keeps that function's own
# default (noted as 25 by the original author).
min_counts = None
if "vpt_3D" in args.seg_method:
    min_counts = 10  # smaller cell sizes with vpt_3D
elif args.cohort == "SynergyLung":
    min_counts = 15  # more lenient for initial analysis

In [39]:
plots_dir = save_path / "plots"
# Only pass min_counts when an override was chosen above.
low_quality_kwargs = {} if min_counts is None else {"min_counts": min_counts}

# Cell QC -> gene filtering -> normalisation; each result replaces `adata`.
adata = filter_low_quality_cells(
    adata, save_path=plots_dir, logger=logger, **low_quality_kwargs
)
adata = filter_genes(adata, save_path=plots_dir, logger=logger)
adata = normalize_counts(
    adata, save_path=plots_dir, seg_method=args.seg_method, logger=logger
)

2025-11-21 13:47:05,197 [INFO]: # total cells before filtering low-quality or volume outliers: 647733
2025-11-21 13:47:05,198 [INFO]: # low_quality_cell:             15854
2025-11-21 13:47:05,198 [INFO]: # volume_outlier_cell:       347
2025-11-21 13:47:10,217 [INFO]: # total cells after filtering low-quality or volume outliers:  631532
2025-11-21 13:47:10,218 [INFO]: # genes before filtering: 500
2025-11-21 13:47:12,102 [INFO]: # genes after filtering: 500
2025-11-21 13:47:13,790 [INFO]: Normalizing counts...
2025-11-21 13:47:14,951 [INFO]: Cells before/after outlier removal during normalization: 631532 -> 618900
  return dispatch(args[0].__class__)(*args, **kw)


In [40]:
# Dimensionality reduction (PCA, then UMAP); plots written to <save_path>/plots.
adata = pca_umap_single(adata, logger=logger, save_path=save_path / "plots")

2025-11-21 13:47:26,771 [INFO]: Dimensionality reduction: PCA
2025-11-21 13:48:11,538 [INFO]: Dimensionality reduction: UMAP with n_neighbors=20, n_pcs=50


In [44]:
# NOTE(review): installing a package mid-notebook (integration_harmony was already
# imported in an earlier cell) is not reproducible on Restart & Run All; declare a
# pinned harmonypy in the environment spec instead, or at least move this to the top.
!mamba install bioconda::harmonypy -y


Looking for: ['bioconda::harmonypy']

[?25l[2K[0G[+] 0.0s
[2K[1A[2K[0Gbioconda/linux-64 (check zst)                      Checked  0.1s
[?25h[?25l[2K[0G[+] 0.0s
[2K[1A[2K[0Gbioconda/noarch (check zst)                        Checked  0.0s
[?25h[?25l[2K[0G[+] 0.0s
[2K[1A[2K[0G[+] 0.1s
conda-forge/linux-64 [90m━━━━━━━━━━━━━━━━━━━━━━━[0m  60.9kB /  48.9MB @ 998.1kB/s  0.1s
conda-forge/noarch   [90m╸[0m[33m━━━━━━━━━━━━━━━╸[0m[90m━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.1s
bioconda/linux-64    [33m━━━━━━━━━━━━╸[0m[90m━━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.1s
bioconda/noarch      [90m━━━━━━━╸[0m[33m━━━━━━━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.1s[2K[1A[2K[1A[2K[1A[2K[1A[2K[0G[+] 0.2s
conda-forge/linux-64 [90m━━━━━━━━━━━━━━━━━━━━━━━[0m   2.1MB /  48.9MB @  12.3MB/s  0.2s
conda-forge/noarch   [90m━━━━━━━━━━━━━━━━━━━━━━━[0m 207.5kB /  23.4MB @   1.3MB/s  0.2s
bioconda/linux-64    ━━━╸[90m━━━━━━━━━━━━━━━━━━━[0m   1.0MB /   5.3

In [45]:
# .obs columns after filtering (QC flag columns added).
adata.obs.keys()

Index(['region', 'slide', 'cell_type_incl_low_quality_revised',
       'cell_type_mmc_incl_low_quality_clusters',
       'cell_type_mmc_incl_low_quality', 'cell_type_incl_mixed_revised',
       'cell_type_mmc_incl_mixed_clusters', 'cell_type_mmc_incl_mixed',
       'cell_type_revised', 'cell_type_mmc_raw_clusters', 'cell_type_mmc_raw',
       'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized',
       'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation',
       'genotype', 'age_months', 'run_date', 'organism', 'cohort', 'sample',
       'animal_id', 'condition', 'n_counts', 'n_genes', 'path',
       'spatial_outlier', 'low_quality_cell', 'volume_outlier_cell'],
      dtype='object')

In [46]:
%%time
adata = integration_harmony(
    adata,
    batch_key="slide",
    save_path=save_path / "plots",
    logger=logger,
)

2025-11-21 14:00:00,520 [INFO]: Integration: Run Harmony
2025-11-21 14:00:00,964 - harmonypy - INFO - Computing initial centroids with sklearn.KMeans...
2025-11-21 14:00:51,976 - harmonypy - INFO - sklearn.KMeans initialization complete.
2025-11-21 14:00:54,115 - harmonypy - INFO - Iteration 1 of 10
2025-11-21 14:04:13,668 - harmonypy - INFO - Iteration 2 of 10
2025-11-21 14:07:27,864 - harmonypy - INFO - Iteration 3 of 10
2025-11-21 14:10:42,705 - harmonypy - INFO - Iteration 4 of 10
2025-11-21 14:13:57,830 - harmonypy - INFO - Iteration 5 of 10
2025-11-21 14:17:11,394 - harmonypy - INFO - Converged after 5 iterations
2025-11-21 14:17:11,474 [INFO]: Integration: Compute neighbors and UMAP


CPU times: user 2h 8min 14s, sys: 2min 8s, total: 2h 10min 23s
Wall time: 37min 46s


In [47]:
logger.info("Saving integrated object...")
output_path = save_path / "adatas"
output_path.mkdir(parents=True, exist_ok=True)

# Make sure a string-typed "fov" column exists before writing (empty strings if absent).
adata.obs["fov"] = adata.obs.get("fov", "")
adata.obs["fov"] = adata.obs["fov"].astype(str)

adata.write(output_path / "adata_integrated.h5ad.gz", compression="gzip")

2025-11-21 14:37:46,611 [INFO]: Saving integrated object...


In [48]:
# Layers stored on the integrated object.
adata.layers

Layers with keys: counts, volume_norm, volume_log1p_norm, zscore, librarysize_log1p_norm

In [10]:
###############################################
# Load the saved integrated object post hoc (e.g. in a fresh kernel) instead of recomputing.
# NOTE(review): the outputs shown below (996164 cells x 451 genes, different .obs columns)
# do not match the object built above, so this cell was evidently run for a different
# cohort/seg_method than the `args` displayed earlier in this notebook.
output_path = save_path / "adatas"
integrated_file = output_path / "adata_integrated.h5ad.gz"
adata = sc.read_h5ad(integrated_file)

In [11]:
# (n_obs, n_vars) of the loaded object.
(adata.n_obs, adata.n_vars)

(996164, 451)

In [12]:
# Summary of the loaded object (obs / var / obsm / varm / layers keys).
adata

AnnData object with n_obs × n_vars = 996164 × 451
    obs: 'fov', 'volume', 'center_x', 'center_y', 'min_x', 'min_y', 'max_x', 'max_y', 'anisotropy', 'transcript_count', 'perimeter_area_ratio', 'Txnip_raw', 'Txnip_high_pass', 'Fth1_raw', 'Fth1_high_pass', 'DAPI_raw', 'DAPI_high_pass', 'Scgb1a1_raw', 'Scgb1a1_high_pass', 'Sftpc_raw', 'Sftpc_high_pass', 'PolyT_raw', 'PolyT_high_pass', 'Ifitm3_raw', 'Ifitm3_high_pass', 'region', 'slide', 'dataset_id', 'cells_region', 'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized', 'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation', 'condition', 'run_date', 'organism', 'cohort', 'sample', 'n_counts', 'n_genes', 'Col1_raw', 'Col1_high_pass', 'spatial_outlier', 'low_quality_cell', 'volume_outlier_cell'
    var: 'n_cells'
    obsm: 'X_pca', 'X_pca_harmony', 'X_umap_20_50', 'X_umap_harmony_20_50', 'X_umap_harmony_20_50_3d', 'blank', 'spatial', 'spatial_microns', 'spatial_pixel'
    varm: 'PCs'
    layers: 'counts', 'libr

In [32]:
#######
# spatial plots post-hoc plotting
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc


def _plot_flag(flag, data=None, n_cols=3, max_cells=int(1.5e5)):
    """Scatter-plot every sample's cells in space, coloured by ``data.obs[flag]``.

    One panel per sample (order of first appearance in ``.obs["sample"]``), using
    ``.obsm["spatial"]`` as x/y coordinates. Samples with more than ``max_cells``
    cells are randomly subsampled (``random_state=42``) to keep the plot light.

    Colouring:
        * boolean or numeric columns (e.g. ``"spatial_outlier"``): truthy cells are
          drawn grey, falsy cells lightgrey;
        * any other column (strings / categoricals, e.g. ``"cohort"``): one colour
          per category, shared across panels, with a figure legend. Such columns
          used to be tested for truthiness, which painted every cell grey.

    Args:
        flag: Name of the ``.obs`` column to colour by.
        data: AnnData to plot. Defaults to the notebook-level ``adata``.
        n_cols: Number of panel columns.
        max_cells: Per-sample subsampling threshold.

    Raises:
        KeyError: If ``flag`` or ``"sample"`` is not a column of ``data.obs``.
    """
    from matplotlib.lines import Line2D

    if data is None:
        data = adata  # notebook-level object, as before

    flag_values = data.obs[flag]
    is_binary = pd.api.types.is_bool_dtype(flag_values) or pd.api.types.is_numeric_dtype(
        flag_values
    )
    palette = None
    if not is_binary:
        # Build the palette from the whole column so colours agree across panels.
        cmap = plt.get_cmap("tab20")
        categories = sorted(flag_values.astype(str).unique())
        palette = {cat: cmap(i % cmap.N) for i, cat in enumerate(categories)}

    samples = data.obs["sample"].unique()
    n_samples = len(samples)
    n_rows = math.ceil(n_samples / n_cols)

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 4 * n_rows),
        squeeze=False,
        sharex=False,
        sharey=False,
    )

    for idx, sample in enumerate(samples):
        ax = axes[idx // n_cols][idx % n_cols]

        sample_data = data[data.obs["sample"] == sample]
        if sample_data.n_obs > max_cells:
            sample_data = sc.pp.subsample(
                sample_data, n_obs=max_cells, random_state=42, copy=True
            )

        coords = sample_data.obsm["spatial"]
        if palette is None:
            colors = np.where(sample_data.obs[flag].values, "grey", "lightgrey")
        else:
            colors = np.array([palette[v] for v in sample_data.obs[flag].astype(str)])

        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=colors,
            # Smaller markers for denser samples, clamped to [0.3, 0.7].
            s=max(0.3, min(0.7, 30000 / len(coords))),
            alpha=0.75,
            edgecolors="none",
        )
        ax.set_title(sample, fontsize=10)

        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Hide unused subplots
    for ax in axes.flat[n_samples:]:
        ax.set_visible(False)

    fig.suptitle(f"{flag.replace('_', ' ').title()} by sample", fontsize=14, y=1)
    fig.tight_layout()
    if palette is not None:
        # Added after tight_layout so the legend sits outside the panel grid.
        handles = [
            Line2D([], [], marker="o", linestyle="", color=color, label=category)
            for category, color in palette.items()
        ]
        fig.legend(
            handles=handles,
            title=flag,
            loc="center left",
            bbox_to_anchor=(1.0, 0.5),
            frameon=False,
        )
    # To save instead of showing:
    # fig.savefig(save_path / "plots" / f"{flag}_by_sample.png", dpi=200, bbox_inches="tight")
    # plt.close(fig)
    # plt.show() takes no figure argument (backends accept only keyword `block`).
    plt.show()

In [None]:
# Spatial overview of all samples, by cohort.
_plot_flag(flag="cohort")