In [1]:
import argparse
import logging
import os
import sys
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
from pathlib import Path
from typing import Tuple

import scanpy as sc
import yaml
from anndata import AnnData
from spatialdata import read_zarr

sys.path.insert(1, "/dss/dsshome1/0C/ra98gaq/Git/cellseg-benchmark")
from cellseg_benchmark.adata_utils import (
    filter_genes,
    filter_low_quality_cells,
    filter_spatial_outlier_cells,
    integration_harmony,
    merge_adatas,
    normalize_counts,
    pca_umap_single,
)



In [2]:
# Reload cellseg_benchmark.adata_utils so that edits to the module are picked
# up without restarting the kernel, then re-bind the helper names (including
# integration_harmony_new) so they point at the reloaded definitions.
from importlib import reload

import cellseg_benchmark.adata_utils as adu

reload(adu)
from cellseg_benchmark.adata_utils import (
    filter_genes,
    filter_low_quality_cells,
    filter_spatial_outlier_cells,
    integration_harmony,
    integration_harmony_new,
    merge_adatas,
    normalize_counts,
    pca_umap_single,
)

In [3]:
def _load_one(
    sample_dir: Path, seg_method: str, logger: logging.Logger | None = None
) -> Tuple[str, AnnData | None]:
    """Load the AnnData table of one segmentation method from a master sdata.

    Args:
        sample_dir: Sample directory that contains ``sdata_z3.zarr``.
        seg_method: Segmentation method name. The table is looked up under
            the key ``adata_{seg_method}``.
        logger: Optional logger. When None, a missing table is skipped
            without any message.

    Returns:
        ``(sample_dir.name, adata)``; ``adata`` is None when the sdata holds
        no table for ``seg_method``.
    """
    table_key = f"adata_{seg_method}"
    # Only the tables are read; other sdata elements are not needed here.
    sdata = read_zarr(sample_dir / "sdata_z3.zarr", selection=("tables",))
    if table_key not in sdata.tables:
        if logger is not None:
            # Report the key that was actually looked up and the affected sample.
            logger.warning(
                "Skipping %s for sample %s. No such key: %s",
                seg_method,
                sample_dir.name,
                table_key,
            )
        return sample_dir.name, None
    return sample_dir.name, sdata[table_key]
    
# Suppress UserWarnings whose message matches "The table is annotating ...".
warnings.filterwarnings(
    action="ignore",
    message=".*The table is annotating*",
    category=UserWarning,
)
# NOTE(review): -1 presumably means "use all available cores" in scanpy — confirm.
sc.settings.n_jobs = -1

In [4]:
# Logger setup
# logging.getLogger returns the same object for the same name, so re-running
# this cell used to attach one more StreamHandler each time and every message
# was then printed several times. Attach a handler only if none exists yet.
logger = logging.getLogger("integrate_adatas")
logger.setLevel(logging.INFO)
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s"))
    logger.addHandler(handler)

In [5]:
# CLI args
# The notebook emulates a command-line invocation by overriding sys.argv
# before parsing. Alternative configurations used previously:
#sys.argv = ["notebook", "SynergyLung", "Cellpose_1_Merlin"]
#sys.argv = ["notebook", "foxf2", "Cellpose_1_nuclei_model"]
#sys.argv = ["notebook", "foxf2", "Proseg_Cellpose_1_nuclei_model"]
sys.argv = ["notebook", "aging", "Cellpose_1_nuclei_model"]

parser = argparse.ArgumentParser(
    description="Integrate adatas from a selected segmentation method."
)
# Two required positionals: which cohort, and which segmentation method.
for arg_name, arg_help in (
    ("cohort", "Cohort name, e.g., 'foxf2'"),
    ("seg_method", "Segmentation method, e.g., 'Cellpose_1_nuclei_model'"),
):
    parser.add_argument(arg_name, help=arg_help)
args = parser.parse_args()
args

Namespace(cohort='aging', seg_method='Cellpose_1_nuclei_model')

In [6]:
# Project root on the shared file system; all other paths derive from it.
base_path = Path("/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark")
samples_path = base_path.joinpath("samples")
# Results are grouped per cohort and per segmentation method.
save_path = base_path.joinpath("analysis", args.cohort, args.seg_method)
save_path.mkdir(parents=True, exist_ok=True)

In [7]:
# Load the sample metadata and the per-cohort exclusion list from
# <base_path>/misc. Path.read_text() closes each file right after reading
# (the previous bare open() never closed its handles), and `or {}` turns an
# empty YAML document, which safe_load parses as None, into an empty dict.
sample_metadata_file, excluded = (
    yaml.safe_load((base_path / "misc" / f).read_text()) or {}
    for f in ["sample_metadata.yaml", "samples_excluded.yaml"]
)
sample_metadata_file

{'foxf2_s1_r0': {'cohort': 'foxf2',
  'slide': 1,
  'region': 0,
  'genotype': 'ECKO',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_0-ECKO000'},
 'foxf2_s1_r1': {'cohort': 'foxf2',
  'slide': 1,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_1-WT000'},
 'foxf2_s2_r1': {'cohort': 'foxf2',
  'slide': 2,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240322',
  'animal_id': '536',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240322_Foxf2-Slide02-cp-WT-PCKO/region_1-WT536'},
 'foxf2_s2_r2': {'cohort': 'foxf2',
  'slide': 2,
  'region': 2,
  'genotype': 'PCKO',
  '

In [8]:
# Samples excluded for the selected cohort. Both a missing exclusion mapping
# and a cohort key with a null/empty value yield an empty set instead of
# raising.
excluded_samples = set((excluded or {}).get(args.cohort) or [])
excluded_samples

{'aging_s5_r0'}

In [11]:
# Collect, in YAML order, the samples of the selected cohort that are not excluded.
yaml_samples = []
for sample_name, sample_meta in sample_metadata_file.items():
    if sample_meta.get("cohort") != args.cohort:
        continue
    if sample_name in excluded_samples:
        continue
    yaml_samples.append(sample_name)
yaml_samples

['aging_s1_r0',
 'aging_s5_r1',
 'aging_s5_r2',
 'aging_s6_r0',
 'aging_s7_r2',
 'aging_s8_r0',
 'aging_s8_r1',
 'aging_s8_r2',
 'aging_s10_r0',
 'aging_s10_r1',
 'aging_s10_r2',
 'aging_s11_r0',
 'aging_s11_r1',
 'aging_s11_r2',
 'aging_s12_r0']

In [12]:
logger.info("Loading data...")
# Keep only samples whose master sdata exists on disk; report the others.
loads = []
for name in yaml_samples:
    p = samples_path / name
    if not (p / "sdata_z3.zarr").exists():
        logger.error("master sdata in %s not found.", p)
        continue
    loads.append(p)

# One worker per CPU granted by SLURM; a single worker outside of SLURM.
max_workers = int(os.getenv("SLURM_CPUS_PER_TASK", 1))
# The workers are called without a logger, so they cannot report skips themselves.
loader = partial(_load_one, seg_method=args.seg_method, logger=None)

# ex.map returns results in input order, so the YAML order is preserved.
with ProcessPoolExecutor(max_workers=max_workers) as ex:
    results = list(ex.map(loader, loads))

# Report, from the parent process, every sample that had no table for this
# segmentation method; otherwise such samples would be dropped silently.
skipped = [name for name, loaded in results if loaded is None]
if skipped:
    logger.warning(
        "No 'adata_%s' table in %d sample(s), skipping: %s",
        args.seg_method,
        len(skipped),
        ", ".join(skipped),
    )

# keep YAML order, drop Nones
adata_list = [(name, adata) for name, adata in results if adata is not None]

2025-10-22 16:06:11,762 [INFO]: Loading data...


In [13]:
# temp fix for aging_s11_r0
#for i, (name, ad) in enumerate(adata_list):
#if name == "aging_s11_r0":
        #ad.obs["region"] = "0"
    #ad.obs["region"] = ad.obs["region"].astype(str).astype("category")
    #adata_list[i] = (name, ad)

In [24]:
# Debug: pick the AnnData of one sample for inspection (None if it was not loaded).
# NOTE(review): `adata` is rebound to the merged object by the merge cell below.
adata = next((a for n, a in adata_list if n == "aging_s11_r2"), None)

In [25]:
# Display the per-cell metadata table of the selected sample.
adata.obs

Unnamed: 0,region,slide,spt_region,cell_type_incl_low_quality_revised,cell_type_mmc_incl_low_quality_clusters,cell_type_mmc_incl_low_quality,cell_type_incl_mixed_revised,cell_type_mmc_incl_mixed_clusters,cell_type_mmc_incl_mixed,cell_type_revised,...,surface_to_volume_ratio,sphericity,solidity,elongation,animal_id,run_date,organism,cohort,sample,condition
aaaaaaaa-1,2,11,boundaries_Cellpose_1_nuclei_model,Neurons-Glut,Neurons-Glut,Neurons-Glut,Neurons-Glut,Neurons-Glut,Neurons-Glut,Neurons-Glut,...,0.378232,0.905776,0.987729,0.032425,888,20250526,mouse,aging,aging_s11_r2,WT_24
aaaaaaab-1,2,11,boundaries_Cellpose_1_nuclei_model,Neurons-Dopa,Neurons-Gaba,Neurons-Gaba,Neurons-Dopa,Neurons-Gaba,Neurons-Gaba,Neurons-Glyc-Gaba,...,0.530033,0.897108,1.000000,0.260327,888,20250526,mouse,aging,aging_s11_r2,WT_24
aaaaaaac-1,2,11,boundaries_Cellpose_1_nuclei_model,Neurons-Glut,Neurons-Glut,Neurons-Glut,Neurons-Glut,Neurons-Glut,Neurons-Glut,Neurons-Glut,...,0.697929,0.959499,1.000000,0.052645,888,20250526,mouse,aging,aging_s11_r2,WT_24
aaaaaaad-1,2,11,boundaries_Cellpose_1_nuclei_model,Astrocytes,Astrocytes,Neurons-Glut,Astrocytes,Astrocytes,Neurons-Glut,Astrocytes,...,0.285597,0.974852,1.000000,0.104153,888,20250526,mouse,aging,aging_s11_r2,WT_24
aaaaaaae-1,2,11,boundaries_Cellpose_1_nuclei_model,ECs,ECs,Neurons-Glut,ECs,ECs,Neurons-Glut,ECs,...,0.309734,0.959109,1.000000,0.155527,888,20250526,mouse,aging,aging_s11_r2,WT_24
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
aaaamcbn-1,2,11,boundaries_Cellpose_1_nuclei_model,Oligodendrocytes,Oligodendrocytes,Oligodendrocytes,Oligodendrocytes,Oligodendrocytes,Oligodendrocytes,Oligodendrocytes,...,0.290311,0.948982,0.995501,0.057180,888,20250526,mouse,aging,aging_s11_r2,WT_24
aaaamcbo-1,2,11,boundaries_Cellpose_1_nuclei_model,Oligodendrocytes,Choroid-Plexus,Astrocytes,Oligodendrocytes,Oligodendrocytes,Astrocytes,Oligodendrocytes,...,0.303056,0.918315,0.989862,0.064829,888,20250526,mouse,aging,aging_s11_r2,WT_24
aaaamcbp-1,2,11,boundaries_Cellpose_1_nuclei_model,Microglia,Microglia,Microglia,Microglia,Microglia,Microglia,Microglia,...,0.306707,0.924112,0.991330,0.147285,888,20250526,mouse,aging,aging_s11_r2,WT_24
aaaamcca-1,2,11,boundaries_Cellpose_1_nuclei_model,Pericytes,Microglia,Astrocytes,Pericytes,Microglia,Astrocytes,Pericytes,...,0.215834,0.837129,0.959902,0.074454,888,20250526,mouse,aging,aging_s11_r2,WT_24


In [26]:
# Merge and process
adata = merge_adatas(
    adata_list,
    seg_method=args.seg_method,
    logger=logger,
    plot_qc_stats=True,
    save_path=save_path / "plots",
)
del adata_list

2025-10-22 16:53:30,030 [INFO]: Merging adatas of Cellpose_1_nuclei_model
100%|██████████| 15/15 [00:00<00:00, 68.59it/s]
  utils.warn_names_duplicates("obs")
2025-10-22 16:53:31,484 [INFO]: Cellpose_1_nuclei_model: #cells=719901, #samples=15
2025-10-22 16:53:31,497 [INFO]: Plotting QC results
  fig.tight_layout()
  fig.tight_layout()
  fig.tight_layout()


In [40]:
# Sanity check: number of cells per sample after merging (missing values included).
adata.obs["sample"].value_counts(dropna=False)

sample
aging_s1_r0     56230
aging_s8_r2     55444
aging_s10_r2    51587
aging_s10_r1    50706
aging_s6_r0     49699
aging_s11_r2    49698
aging_s11_r1    49384
aging_s8_r1     48405
aging_s10_r0    47920
aging_s12_r0    46963
aging_s5_r2     46360
aging_s8_r0     45872
aging_s7_r2     45655
aging_s5_r1     40329
aging_s11_r0    35649
Name: count, dtype: int64

In [39]:
# Sanity check: number of cells per condition (missing values included).
adata.obs["condition"].value_counts(dropna=False)

condition
WT_3     150698
WT_24    147002
WT_6     142214
WT_18    140792
WT_12    139195
Name: count, dtype: int64

In [38]:
# Sanity check: every cell should belong to the selected cohort.
adata.obs["cohort"].value_counts(dropna=False)

cohort
aging    719901
Name: count, dtype: int64

In [36]:
# Sanity check of the age annotation. NOTE(review): the stored output shows NaN
# for 481628 of 719901 cells, i.e. `age_months` is missing for most samples.
adata.obs["age_months"].value_counts(dropna=False)

age_months
NaN     481628
6.0     142214
18.0     49699
12.0     46360
Name: count, dtype: int64

In [32]:
import pandas as pd

In [34]:
# Cross-tabulate age against sample to see which samples lack `age_months`
# (they appear in the NaN row of the table).
pd.crosstab(adata.obs["age_months"], adata.obs["sample"], dropna=False)

sample,aging_s10_r0,aging_s10_r1,aging_s10_r2,aging_s11_r0,aging_s11_r1,aging_s11_r2,aging_s12_r0,aging_s1_r0,aging_s5_r1,aging_s5_r2,aging_s6_r0,aging_s7_r2,aging_s8_r0,aging_s8_r1,aging_s8_r2
age_months,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
6.0,0,0,0,0,0,0,0,56230,40329,0,0,45655,0,0,0
12.0,0,0,0,0,0,0,0,0,0,46360,0,0,0,0,0
18.0,0,0,0,0,0,0,0,0,0,0,49699,0,0,0,0
,47920,50706,51587,35649,49384,49698,46963,0,0,0,0,0,45872,48405,55444


In [13]:
# workaround to fix adata.obs formatting ###############
#adata.obs["sample"] = adata.obs["sample"].str.replace(
#    rf"^{args.cohort}_(\d+)_(\d+)$", rf"{args.cohort}_s\1_r\2", regex=True
#)
#adata.obs.drop(columns=["spt_region"], inplace=True, errors="ignore")
#adata.obs["condition"] = (
#    adata.obs["genotype"].astype(str) + "_" + adata.obs["age_months"].astype(str)
#)

In [14]:
# List the available per-cell metadata columns.
adata.obs.columns

Index(['fov', 'volume', 'center_x', 'center_y', 'min_x', 'min_y', 'max_x',
       'max_y', 'anisotropy', 'transcript_count', 'perimeter_area_ratio',
       'Txnip_raw', 'Txnip_high_pass', 'Fth1_raw', 'Fth1_high_pass',
       'DAPI_raw', 'DAPI_high_pass', 'Scgb1a1_raw', 'Scgb1a1_high_pass',
       'Sftpc_raw', 'Sftpc_high_pass', 'PolyT_raw', 'PolyT_high_pass',
       'Ifitm3_raw', 'Ifitm3_high_pass', 'region', 'slide', 'dataset_id',
       'cells_region', 'area', 'volume_sum', 'volume_final', 'num_z_planes',
       'size_normalized', 'surface_to_volume_ratio', 'sphericity', 'solidity',
       'elongation', 'condition', 'run_date', 'organism', 'cohort', 'sample',
       'n_counts', 'n_genes', 'Col1_raw', 'Col1_high_pass'],
      dtype='object')

In [15]:
# Use micron coordinates as the "spatial" embedding when they are available.
# The previous one-liner passed adata.obsm["spatial"] as the .get() default;
# that default is evaluated eagerly, so it raised KeyError whenever "spatial"
# was absent, even if "spatial_microns" existed.
if "spatial_microns" in adata.obsm:
    adata.obsm["spatial"] = adata.obsm["spatial_microns"]
elif "spatial" not in adata.obsm:
    raise KeyError(
        "adata.obsm contains neither 'spatial_microns' nor 'spatial'; "
        "spatial outlier filtering needs cell coordinates."
    )
adata = filter_spatial_outlier_cells(
    adata,
    data_dir=str(base_path),
    sample_metadata_file=sample_metadata_file,  ############
    save_path=save_path / "plots",
    logger=logger,
)

2025-10-21 12:44:39,075 [INFO]: [SynergyLung_s1_r0] Missing polygon file: /dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/SynergyLung_s1_r0/cell_outlier_coordinates.csv. Draw outliers in 10X Explorer if necessary.
2025-10-21 12:44:40,020 [INFO]: [SynergyLung_s2_r0] Missing polygon file: /dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/SynergyLung_s2_r0/cell_outlier_coordinates.csv. Draw outliers in 10X Explorer if necessary.
2025-10-21 12:44:40,899 [INFO]: [SynergyLung_s3_r0] Missing polygon file: /dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/SynergyLung_s3_r0/cell_outlier_coordinates.csv. Draw outliers in 10X Explorer if necessary.
2025-10-21 12:44:41,815 [INFO]: [SynergyLung_s3_r1] Missing polygon file: /dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/SynergyLung_s3_r1/cell_outlier_coordinates.csv. Draw outliers in 10X Explorer if necessary.
2025-10-21 12:44:42,821 [INFO]: [SynergyLung_s4_r0] Missing polygon file: /dss/dssfs03/p

In [22]:
# Check which cohort is active. NOTE(review): the stored output (True) stems
# from a SynergyLung run, while the CLI cell above currently selects "aging";
# the execution counts of this notebook are out of order.
args.cohort == 'SynergyLung'

True

In [26]:
# Minimum-count threshold for the low-quality cell filter:
#   * vpt_3D methods: 10, due to smaller cell sizes
#   * SynergyLung cohort: 15, more lenient for initial analysis
#   * otherwise None, i.e. keep the helper's default (25 according to the original note)
min_counts = (
    10
    if "vpt_3D" in args.seg_method
    else 15
    if args.cohort == "SynergyLung"
    else None
)

In [27]:
# Show the threshold selected above (None means the helper's default is used).
min_counts

15

In [28]:
plots_dir = save_path / "plots"

# Pass min_counts only when a specific threshold was chosen; otherwise the
# helper keeps its own default.
qc_overrides = {} if min_counts is None else {"min_counts": min_counts}
adata = filter_low_quality_cells(
    adata, save_path=plots_dir, **qc_overrides, logger=logger
)

# Gene-level filtering, then count normalisation.
adata = filter_genes(adata, save_path=plots_dir, logger=logger)
adata = normalize_counts(
    adata, save_path=plots_dir, seg_method=args.seg_method, logger=logger
)

2025-10-21 12:47:55,598 [INFO]: # total cells before filtering low-quality or volume outliers: 1226739
2025-10-21 12:47:55,599 [INFO]: # low_quality_cell:             208207
2025-10-21 12:47:55,600 [INFO]: # volume_outlier_cell:       2173
2025-10-21 12:48:04,943 [INFO]: # total cells after filtering low-quality or volume outliers:  1016494
2025-10-21 12:48:04,944 [INFO]: # genes before filtering: 451
2025-10-21 12:48:07,888 [INFO]: # genes after filtering: 451
2025-10-21 12:48:14,323 [INFO]: Normalizing counts...
2025-10-21 12:48:16,720 [INFO]: Cells before/after outlier removal during normalization: 1016494 -> 996164
  return dispatch(args[0].__class__)(*args, **kw)


In [29]:
# PCA followed by UMAP on the un-integrated data (the logged run reports
# n_neighbors=20, n_pcs=50); figures go to <save_path>/plots.
adata = pca_umap_single(adata, save_path=save_path / "plots", logger=logger)

2025-10-21 12:48:56,046 [INFO]: Dimensionality reduction: PCA
2025-10-21 12:55:04,226 [INFO]: Dimensionality reduction: UMAP with n_neighbors=20, n_pcs=50


In [35]:
# NOTE(review): this installs harmonypy into the active environment from the
# middle of the notebook, without a version pin. Consider declaring it in the
# environment specification instead so that a fresh run is reproducible.
!mamba install bioconda::harmonypy -y


Looking for: ['bioconda::harmonypy']

[?25l[2K[0G[+] 0.0s
[2K[1A[2K[0G[+] 0.1s
bioconda/linux-64 (check zst) [33m━━━━━━━━━━━━━╸[0m[90m━[0m   0.0 B @  ??.?MB/s Checking  0.1s[2K[1A[2K[0Gbioconda/linux-64 (check zst)                       Checked  0.1s
[?25h[?25l[2K[0G[+] 0.0s
bioconda/noarch (check zst) [90m━━━━━━━╸[0m[33m━━━━━━━━━[0m   0.0 B @  ??.?MB/s Checking  0.0s[2K[1A[2K[0Gbioconda/noarch (check zst)                         Checked  0.1s
[?25h[?25l[2K[0G[+] 0.0s
conda-forge/linux-64 [90m━━━━━━━━━━━━╸[0m[33m━━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.0s[2K[1A[2K[0G[+] 0.1s
conda-forge/linux-64 [90m━━━━━━━━━━━━━━━━━━━━━━━[0m  14.1kB /  48.0MB @ 232.4kB/s  0.1s
conda-forge/noarch   [90m━━━━━━━━━━╸[0m[33m━━━━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.1s
bioconda/linux-64    [90m━━━━━━╸[0m[33m━━━━━━━━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.1s
bioconda/noarch      [33m━━━━━━━━━━━━╸[0m[90m━━━━━━━━━━[0m   0.0 B /  ??.?MB @  

In [36]:
# Save result
# Intermediate checkpoint written before Harmony integration, so that the long
# integration step can be restarted from disk.
logger.info("Saving integrated object...")
output_path = save_path / "adatas"
output_path.mkdir(parents=True, exist_ok=True)
temp_file = output_path / "adata_TEMP.h5ad.gz"
adata.write(temp_file, compression="gzip")
logger.info("Done.")

2025-10-21 14:27:17,915 [INFO]: Saving integrated object...


ERROR! Session/line number was not unique in database. History logging moved to new session 14


2025-10-21 14:29:34,581 [INFO]: Done.


In [11]:
# Rebuild and display the path of the temporary checkpoint file
# (execution counts restart here, presumably after a kernel restart).
output_path = save_path / "adatas"
output_path / "adata_TEMP.h5ad.gz"

PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/analysis/SynergyLung/Cellpose_1_Merlin/adatas/adata_TEMP.h5ad.gz')

In [12]:
# Reload the intermediate checkpoint written by the temporary save cell.
adata = sc.read_h5ad(output_path / "adata_TEMP.h5ad.gz")

In [13]:
# Inspect the reloaded object (dimensions and available slots).
adata

AnnData object with n_obs × n_vars = 996164 × 451
    obs: 'fov', 'volume', 'center_x', 'center_y', 'min_x', 'min_y', 'max_x', 'max_y', 'anisotropy', 'transcript_count', 'perimeter_area_ratio', 'Txnip_raw', 'Txnip_high_pass', 'Fth1_raw', 'Fth1_high_pass', 'DAPI_raw', 'DAPI_high_pass', 'Scgb1a1_raw', 'Scgb1a1_high_pass', 'Sftpc_raw', 'Sftpc_high_pass', 'PolyT_raw', 'PolyT_high_pass', 'Ifitm3_raw', 'Ifitm3_high_pass', 'region', 'slide', 'dataset_id', 'cells_region', 'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized', 'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation', 'condition', 'run_date', 'organism', 'cohort', 'sample', 'n_counts', 'n_genes', 'Col1_raw', 'Col1_high_pass', 'spatial_outlier', 'low_quality_cell', 'volume_outlier_cell'
    var: 'n_cells'
    uns: 'X_umap_20_50', 'pca'
    obsm: 'X_pca', 'X_umap_20_50', 'blank', 'spatial', 'spatial_microns', 'spatial_pixel'
    varm: 'PCs'
    layers: 'counts', 'librarysize_log1p_norm', 'volume_log1p_no

In [14]:
# Harmony integration with the slide as batch variable, followed by neighbours
# and UMAP according to the log output. Long-running: the logged run took
# about 50 minutes for roughly one million cells.
adata = integration_harmony_new(
    adata,
    batch_key="slide",
    save_path=save_path / "plots",
    logger=logger,
)

2025-10-21 14:36:24,818 [INFO]: Integration: Run Harmony
2025-10-21 14:36:27,278 - harmonypy - INFO - Computing initial centroids with sklearn.KMeans...
2025-10-21 14:39:50,480 - harmonypy - INFO - sklearn.KMeans initialization complete.
2025-10-21 14:40:02,229 - harmonypy - INFO - Iteration 1 of 10
2025-10-21 14:55:51,017 - harmonypy - INFO - Iteration 2 of 10
2025-10-21 15:11:31,129 - harmonypy - INFO - Iteration 3 of 10
2025-10-21 15:27:01,647 - harmonypy - INFO - Converged after 3 iterations
2025-10-21 15:27:01,907 [INFO]: Integration: Compute neighbors and UMAP


In [None]:
# Save result
logger.info("Saving integrated object...")
output_path = save_path / "adatas"
output_path.mkdir(parents=True, exist_ok=True)

# Make sure a string-typed "fov" column exists before writing
# (presumably required for a clean h5ad export — TODO confirm).
fov_missing = "fov" not in adata.obs.columns
if fov_missing:
    adata.obs["fov"] = ""
adata.obs["fov"] = adata.obs["fov"].astype(str)

integrated_file = output_path / "adata_integrated.h5ad.gz"
adata.write(integrated_file, compression="gzip")
logger.info("Done.")

2025-10-21 17:26:50,297 [INFO]: Saving integrated object...


In [10]:
###############################################
# load adata post hoc
# Entry point for later sessions: read the integrated object back from disk
# instead of re-running the pipeline above.
output_path = save_path / "adatas"
adata = sc.read_h5ad(output_path / "adata_integrated.h5ad.gz")

In [11]:
# Number of cells and genes in the integrated object.
adata.shape

(996164, 451)

In [12]:
# Inspect the integrated object (the Harmony embeddings should be listed in obsm).
adata

AnnData object with n_obs × n_vars = 996164 × 451
    obs: 'fov', 'volume', 'center_x', 'center_y', 'min_x', 'min_y', 'max_x', 'max_y', 'anisotropy', 'transcript_count', 'perimeter_area_ratio', 'Txnip_raw', 'Txnip_high_pass', 'Fth1_raw', 'Fth1_high_pass', 'DAPI_raw', 'DAPI_high_pass', 'Scgb1a1_raw', 'Scgb1a1_high_pass', 'Sftpc_raw', 'Sftpc_high_pass', 'PolyT_raw', 'PolyT_high_pass', 'Ifitm3_raw', 'Ifitm3_high_pass', 'region', 'slide', 'dataset_id', 'cells_region', 'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized', 'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation', 'condition', 'run_date', 'organism', 'cohort', 'sample', 'n_counts', 'n_genes', 'Col1_raw', 'Col1_high_pass', 'spatial_outlier', 'low_quality_cell', 'volume_outlier_cell'
    var: 'n_cells'
    obsm: 'X_pca', 'X_pca_harmony', 'X_umap_20_50', 'X_umap_harmony_20_50', 'X_umap_harmony_20_50_3d', 'blank', 'spatial', 'spatial_microns', 'spatial_pixel'
    varm: 'PCs'
    layers: 'counts', 'libr

In [32]:
#######
# spatial plots post-hoc plotting
import math
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
def _plot_flag(flag, data=None):
    """Draw one spatial scatter panel per sample, shaded by a per-cell flag.

    Args:
        flag: Name of a column in ``.obs``. Cells with a truthy value are
            drawn in "grey", all other cells in "lightgrey".
        data: AnnData object to plot. Defaults to the notebook-level
            ``adata``, which was the only option before.

    Raises:
        KeyError: If ``flag`` is not a column of ``.obs``.
        ValueError: If ``.obs["sample"]`` contains no samples.
    """
    if data is None:
        data = adata  # notebook-level object, as in the original version
    if flag not in data.obs.columns:
        # Fail before any figure is created instead of midway through the loop.
        raise KeyError(f"'{flag}' is not a column of .obs")

    n_cols = 3
    max_cells = int(1.5e5)  # per-sample cap to keep plotting fast
    samples = data.obs["sample"].unique()
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError("No samples found in .obs['sample']; nothing to plot.")
    n_rows = math.ceil(n_samples / n_cols)

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 4 * n_rows),
        squeeze=False,
        sharex=False,
        sharey=False,
    )

    for idx, sample in enumerate(samples):
        ax = axes[idx // n_cols][idx % n_cols]

        mask = data.obs["sample"] == sample
        sample_data = data[mask]

        # Randomly subsample very large samples (fixed seed for reproducibility).
        if sample_data.n_obs > max_cells:
            sample_data = sc.pp.subsample(
                sample_data, n_obs=max_cells, random_state=42, copy=True
            )

        coords = sample_data.obsm["spatial"]

        outliers = sample_data.obs[flag].values
        colors = pd.Categorical(
            np.where(outliers, "grey", "lightgrey"), categories=["lightgrey", "grey"]
        )

        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=colors,
            # Marker size shrinks with the number of cells, clamped to [0.3, 0.7].
            s=max(0.3, min(0.7, 30000 / len(coords))),
            alpha=0.75,
            edgecolors="none",
        )
        ax.set_title(sample, fontsize=10)

        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Hide unused subplots
    for ax in axes.flat[n_samples:]:
        ax.set_visible(False)

    fig.suptitle(f"{flag.replace('_', ' ').title()} by sample", fontsize=14, y=1)
    fig.tight_layout()
    #fig.savefig(join(save_path, fname), dpi=200, bbox_inches="tight")
    #plt.close(fig)
    # plt.show() accepts only the keyword `block`; a figure must not be passed.
    plt.show()

In [None]:
# NOTE(review): "cohort" is a label column, not a boolean flag. _plot_flag
# shades cells by the truthiness of the column, so every cell with a non-empty
# cohort string gets the same colour. A flag such as "spatial_outlier",
# "low_quality_cell" or "volume_outlier_cell" was probably intended — confirm.
_plot_flag("cohort")