In [1]:
import argparse
import logging
import os
import sys
import warnings
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from pathlib import Path
from typing import Tuple

import scanpy as sc
import yaml
from anndata import AnnData
from spatialdata import read_zarr

sys.path.insert(1, "/dss/dsshome1/0C/ra98gaq/Git/cellseg-benchmark")
from cellseg_benchmark.adata_utils import (
    filter_genes,
    filter_low_quality_cells,
    filter_spatial_outlier_cells,
    merge_adatas,
    normalize_counts,
    pca_umap_single,
)



In [2]:
# reload
from importlib import reload

import cellseg_benchmark.adata_utils as adu

reload(adu)
from cellseg_benchmark.adata_utils import (
    integration_harmony,
) # noqa: E402

In [3]:
def _load_one(
    sample_dir: Path, seg_method: str, logger: logging.Logger | None
) -> Tuple[str, AnnData | None]:
    """Load the AnnData table of one segmentation method from a sample's master sdata.

    The store ``<sample_dir>/sdata_z3.zarr`` is opened with
    ``selection=("tables",)`` and the table ``adata_<seg_method>`` is looked up.

    Args:
        sample_dir: Sample directory containing ``sdata_z3.zarr``. Its ``name``
            is returned as the sample identifier.
        seg_method: Segmentation method name; the table key is
            ``adata_<seg_method>``.
        logger: Logger used to report a missing table. May be ``None`` (or any
            falsy value), in which case a missing table is skipped silently.

    Returns:
        ``(sample_name, adata)``; ``adata`` is ``None`` when the table is absent.
    """
    table_key = f"adata_{seg_method}"
    sdata = read_zarr(sample_dir / "sdata_z3.zarr", selection=("tables",))
    if table_key not in sdata.tables:
        if logger:
            # Name the sample and the key that was actually looked up, so the
            # message is actionable when many samples are loaded.
            logger.warning(
                "Skipping %s for sample %s. No such key: %s",
                seg_method,
                sample_dir.name,
                table_key,
            )
        return sample_dir.name, None
    return sample_dir.name, sdata[table_key]


# Silence the "The table is annotating ..." UserWarning (presumably raised while
# reading the zarr tables — confirm the emitting library).
# The message is a regex matched from the start of the warning text, so no
# trailing wildcard is needed; the former trailing "*" only made the final "g"
# optional instead of acting as a wildcard.
warnings.filterwarnings(
    "ignore", message=r".*The table is annotating", category=UserWarning
)
# NOTE(review): -1 is presumably "use all available cores" — confirm against the
# installed scanpy version.
sc.settings.n_jobs = -1

In [4]:
# Logger setup
logger = logging.getLogger("integrate_adatas")
logger.setLevel(logging.INFO)
# logging.getLogger() returns the same logger object on every call, so re-running
# this cell would otherwise attach one more StreamHandler each time and every
# message would be printed multiple times.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s")
    )
    logger.addHandler(handler)

In [5]:
# CLI args: the command line is emulated inside the notebook by overriding
# sys.argv before parsing. Uncomment exactly one line to switch the run.
# sys.argv = ["notebook", "SynergyLung", "Cellpose_1_Merlin"]
sys.argv = ["notebook", "foxf2", "Cellpose_1_nuclei_model"]
# sys.argv = ["notebook", "foxf2", "Proseg_Cellpose_1_nuclei_model"]
# sys.argv = ["notebook", "aging", "Cellpose_1_nuclei_model"]

parser = argparse.ArgumentParser(
    description="Integrate adatas from a selected segmentation method."
)
# Positional arguments in the order they are expected on the command line.
positional_help = {
    "cohort": "Cohort name, e.g., 'foxf2'",
    "seg_method": "Segmentation method, e.g., 'Cellpose_1_nuclei_model'",
}
for arg_name, arg_help in positional_help.items():
    parser.add_argument(arg_name, help=arg_help)
args = parser.parse_args()
args

Namespace(cohort='foxf2', seg_method='Cellpose_1_nuclei_model')

In [6]:
# Root of the benchmark data. Defaults to the shared DSS location; set the
# CELLSEG_BENCHMARK_DIR environment variable to run against another copy
# without editing the notebook.
base_path = Path(
    os.getenv(
        "CELLSEG_BENCHMARK_DIR",
        "/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark",
    )
)
samples_path = base_path / "samples"
# Outputs of this notebook are grouped per cohort and segmentation method.
save_path = base_path / "analysis" / args.cohort / args.seg_method
save_path.mkdir(parents=True, exist_ok=True)

In [7]:
# Load the sample metadata and the exclusion list from <base_path>/misc.
# Path.read_text() opens and closes each file itself; a bare open() inside the
# generator would leave both handles open until garbage collection.
sample_metadata_file, excluded = (
    yaml.safe_load((base_path / "misc" / f).read_text())
    for f in ["sample_metadata.yaml", "samples_excluded.yaml"]
)
sample_metadata_file

{'foxf2_s1_r0': {'cohort': 'foxf2',
  'slide': 1,
  'region': 0,
  'genotype': 'ECKO',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_0-ECKO000'},
 'foxf2_s1_r1': {'cohort': 'foxf2',
  'slide': 1,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_1-WT000'},
 'foxf2_s2_r1': {'cohort': 'foxf2',
  'slide': 2,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240322',
  'animal_id': '536',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240322_Foxf2-Slide02-cp-WT-PCKO/region_1-WT536'},
 'foxf2_s2_r2': {'cohort': 'foxf2',
  'slide': 2,
  'region': 2,
  'genotype': 'PCKO',
  '

In [8]:
# Samples excluded for the selected cohort. Falls back to an empty set when the
# exclusion file is empty (safe_load returns None) or when the cohort entry is
# missing or null, instead of raising.
excluded_samples = set((excluded or {}).get(args.cohort) or [])
excluded_samples

{'foxf2_s2_r0', 'foxf2_s3_r0', 'foxf2_s3_r1'}

In [9]:
# Names of all samples of the selected cohort, in the iteration order of the
# metadata mapping, leaving out the excluded ones.
yaml_samples = [
    sample_name
    for sample_name in sample_metadata_file
    if sample_metadata_file[sample_name].get("cohort") == args.cohort
    and sample_name not in excluded_samples
]
yaml_samples

['foxf2_s1_r0',
 'foxf2_s1_r1',
 'foxf2_s2_r1',
 'foxf2_s2_r2',
 'foxf2_s4_r0',
 'foxf2_s4_r1',
 'foxf2_s5_r0',
 'foxf2_s5_r1',
 'foxf2_s6_r0',
 'foxf2_s6_r1',
 'foxf2_s6_r2',
 'foxf2_s7_r0',
 'foxf2_s7_r1']

In [10]:
%%time
logger.info("Loading data...")
loads = []
for name in yaml_samples:
    p = samples_path / name
    if not (p / "sdata_z3.zarr").exists():
        logger.error("master sdata in %s not found.", p)
        continue
    loads.append(p)

max_workers = int(os.getenv("SLURM_CPUS_PER_TASK", 1))
loader = partial(_load_one, seg_method=args.seg_method, logger=None)

with ProcessPoolExecutor(max_workers=max_workers) as ex:
    results = list(ex.map(loader, loads))

# keep YAML order, drop Nones
adata_list = [(name, adata) for name, adata in results if adata is not None]

2025-11-03 13:40:05,347 [INFO]: Loading data...


CPU times: user 1.77 s, sys: 3.84 s, total: 5.61 s
Wall time: 9min 37s


In [17]:
# Inspect the loaded (sample name, AnnData) pairs.
adata_list

[('foxf2_s1_r0',
  AnnData object with n_obs × n_vars = 46814 × 500
      obs: 'region', 'slide', 'spt_region', 'cell_type_incl_low_quality_revised', 'cell_type_mmc_incl_low_quality_clusters', 'cell_type_mmc_incl_low_quality', 'cell_type_incl_mixed_revised', 'cell_type_mmc_incl_mixed_clusters', 'cell_type_mmc_incl_mixed', 'cell_type_revised', 'cell_type_mmc_raw_clusters', 'cell_type_mmc_raw', 'cell_id', 'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized', 'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation', 'genotype', 'age_months', 'condition', 'run_date', 'animal_id', 'organism', 'cohort', 'sample'
      uns: 'sopa_attrs', 'spatialdata_attrs'
      obsm: 'Ovrlpy_stats', 'ficture_area', 'ficture_means_weight', 'ficture_vars_weight', 'intensities', 'spatial', 'spatial_microns', 'spatial_pixel'),
 ('foxf2_s1_r1',
  AnnData object with n_obs × n_vars = 42513 × 500
      obs: 'region', 'slide', 'spt_region', 'cell_type_incl_low_quality_revised', 'cell_typ

In [11]:
import pandas as pd

In [12]:
# temp fix: rewrite the metadata columns in .obs for every sample whose name
# contains "aging" (the commented-out predecessor of the condition targeted
# only aging_s11_r0).
for i, (n, ad) in enumerate(adata_list):
    if "aging" not in n:
        continue
    m = sample_metadata_file[n]
    # All YAML fields of the sample plus the derived "sample"/"condition" labels.
    obs_values = dict(m)
    obs_values["sample"] = n
    obs_values["condition"] = f"{m['genotype']}_{m['age_months']}"
    for k, v in obs_values.items():
        # One constant, string-valued categorical column per metadata field.
        ad.obs[k] = pd.Categorical([str(v)] * len(ad))
    adata_list[i] = (n, ad)

In [18]:
#adata_tmp = next((a for n, a in adata_list if n == "aging_s11_r0"), None)

In [19]:
# Merge and process
adata = merge_adatas(
    adata_list,
    seg_method=args.seg_method,
    logger=logger,
    plot_qc_stats=True,
    save_path=save_path / "plots",
)
# Release the per-sample objects. `results` from the loading cell references
# the same AnnData objects as `adata_list`, so dropping only `adata_list`
# would not let the memory be reclaimed.
del adata_list
results = None

2025-11-03 13:58:40,433 [INFO]: Merging adatas of Cellpose_1_nuclei_model
100%|██████████| 13/13 [00:00<00:00, 37.17it/s]
  utils.warn_names_duplicates("obs")
2025-11-03 13:58:42,599 [INFO]: Cellpose_1_nuclei_model: #cells=595581, #samples=13
2025-11-03 13:58:42,606 [INFO]: Plotting QC results
  fig.tight_layout()
  fig.tight_layout()
  fig.tight_layout()


In [20]:
# Sanity check: number of cells per sample after merging (NaN shown as its own row).
adata.obs["sample"].value_counts(dropna=False)

sample
foxf2_s5_r1    58868
foxf2_s5_r0    56940
foxf2_s4_r0    49649
foxf2_s2_r1    47686
foxf2_s6_r2    47030
foxf2_s1_r0    46814
foxf2_s2_r2    45477
foxf2_s6_r1    44109
foxf2_s1_r1    42513
foxf2_s7_r0    42204
foxf2_s7_r1    41277
foxf2_s6_r0    37142
foxf2_s4_r1    35872
Name: count, dtype: int64

In [21]:
# Sanity check: number of cells per condition label (NaN shown as its own row).
adata.obs["condition"].value_counts(dropna=False)

condition
WT_6      179433
PCKO_6    152066
ECKO_6    141554
GLKO_6    122528
Name: count, dtype: int64

In [22]:
# Sanity check: all cells should carry the selected cohort label.
adata.obs["cohort"].value_counts(dropna=False)

cohort
foxf2    595581
Name: count, dtype: int64

In [23]:
# Sanity check: number of cells per age group (NaN shown as its own row).
adata.obs["age_months"].value_counts(dropna=False)

age_months
6    595581
Name: count, dtype: int64

In [24]:
import pandas as pd

In [25]:
# Cross-tabulate age group against sample to spot samples with mixed or missing ages.
pd.crosstab(adata.obs["age_months"], adata.obs["sample"], dropna=False)

sample,foxf2_s1_r0,foxf2_s1_r1,foxf2_s2_r1,foxf2_s2_r2,foxf2_s4_r0,foxf2_s4_r1,foxf2_s5_r0,foxf2_s5_r1,foxf2_s6_r0,foxf2_s6_r1,foxf2_s6_r2,foxf2_s7_r0,foxf2_s7_r1
age_months,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
6,46814,42513,47686,45477,49649,35872,56940,58868,37142,44109,47030,42204,41277


In [26]:
# workaround to fix adata.obs formatting ###############
# adata.obs["sample"] = adata.obs["sample"].str.replace(
#    rf"^{args.cohort}_(\d+)_(\d+)$", rf"{args.cohort}_s\1_r\2", regex=True
# )
# adata.obs.drop(columns=["spt_region"], inplace=True, errors="ignore")
# adata.obs["condition"] = (
#    adata.obs["genotype"].astype(str) + "_" + adata.obs["age_months"].astype(str)
# )

In [27]:
# List the .obs columns present after merging.
adata.obs.columns

Index(['region', 'slide', 'cell_type_incl_low_quality_revised',
       'cell_type_mmc_incl_low_quality_clusters',
       'cell_type_mmc_incl_low_quality', 'cell_type_incl_mixed_revised',
       'cell_type_mmc_incl_mixed_clusters', 'cell_type_mmc_incl_mixed',
       'cell_type_revised', 'cell_type_mmc_raw_clusters', 'cell_type_mmc_raw',
       'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized',
       'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation',
       'genotype', 'age_months', 'condition', 'run_date', 'animal_id',
       'organism', 'cohort', 'sample', 'n_counts', 'n_genes'],
      dtype='object')

In [28]:
# Use the micron coordinates as the "spatial" embedding whenever they exist.
# An explicit membership test is used because a .get() with
# adata.obsm["spatial"] as its default evaluates that lookup eagerly and raises
# KeyError for objects that only carry "spatial_microns".
if "spatial_microns" in adata.obsm:
    adata.obsm["spatial"] = adata.obsm["spatial_microns"]
adata = filter_spatial_outlier_cells(
    adata,
    data_dir=str(base_path),
    sample_metadata_file=sample_metadata_file,
    save_path=save_path / "plots",
    logger=logger,
)

2025-11-03 14:02:10,233 [INFO]: # total cells before filtering spatial outliers: 595581
2025-11-03 14:02:10,234 [INFO]: # spatial_outlier:       13221
2025-11-03 14:02:12,517 [INFO]: # total cells after filtering spatial outliers:  582360


In [29]:
#args.cohort == "SynergyLung"

In [30]:
# Minimum-count threshold handed to the low-quality cell filter. None means
# "do not override", i.e. the filter's own default is used.
min_counts = None  # default = 25
if "vpt_3D" in args.seg_method:  # min_counts=10 due to smaller cell sizes
    min_counts = 10
elif args.cohort == "SynergyLung":  # more lenient for initial analysis
    min_counts = 15

In [31]:
plot_dir = save_path / "plots"
# Forward min_counts only when it was overridden, so that the default of
# filter_low_quality_cells applies otherwise.
qc_overrides = {} if min_counts is None else {"min_counts": min_counts}

# Cell-level QC filter.
adata = filter_low_quality_cells(
    adata, save_path=plot_dir, logger=logger, **qc_overrides
)

# Gene-level filter.
adata = filter_genes(adata, save_path=plot_dir, logger=logger)

# Count normalisation.
adata = normalize_counts(
    adata, save_path=plot_dir, seg_method=args.seg_method, logger=logger
)

2025-11-03 14:02:27,441 [INFO]: # total cells before filtering low-quality or volume outliers: 582360
2025-11-03 14:02:27,442 [INFO]: # low_quality_cell:             5003
2025-11-03 14:02:27,443 [INFO]: # volume_outlier_cell:       66
2025-11-03 14:02:35,263 [INFO]: # total cells after filtering low-quality or volume outliers:  577291
2025-11-03 14:02:35,264 [INFO]: # genes before filtering: 500
2025-11-03 14:02:38,177 [INFO]: # genes after filtering: 500
2025-11-03 14:02:40,144 [INFO]: Normalizing counts...
2025-11-03 14:02:42,293 [INFO]: Cells before/after outlier removal during normalization: 577291 -> 565745
  return dispatch(args[0].__class__)(*args, **kw)


In [32]:
# Dimensionality reduction (PCA, then UMAP) on the merged data, before the
# Harmony integration further down.
adata = pca_umap_single(adata, save_path=save_path / "plots", logger=logger)

2025-11-03 14:02:59,230 [INFO]: Dimensionality reduction: PCA
2025-11-03 14:03:37,636 [INFO]: Dimensionality reduction: UMAP with n_neighbors=20, n_pcs=50


In [33]:
#!mamba install bioconda::harmonypy -y

In [34]:
# List the .obs columns again; the QC steps are expected to have added flag columns.
adata.obs.columns

Index(['region', 'slide', 'cell_type_incl_low_quality_revised',
       'cell_type_mmc_incl_low_quality_clusters',
       'cell_type_mmc_incl_low_quality', 'cell_type_incl_mixed_revised',
       'cell_type_mmc_incl_mixed_clusters', 'cell_type_mmc_incl_mixed',
       'cell_type_revised', 'cell_type_mmc_raw_clusters', 'cell_type_mmc_raw',
       'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized',
       'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation',
       'genotype', 'age_months', 'condition', 'run_date', 'animal_id',
       'organism', 'cohort', 'sample', 'n_counts', 'n_genes',
       'spatial_outlier', 'low_quality_cell', 'volume_outlier_cell'],
      dtype='object')

In [35]:
# Harmony integration with the slide as batch variable.
# NOTE(review): this was the slowest cell of the logged run (about 50 minutes
# between its first log line and the next cell's).
adata = integration_harmony(
    adata,
    batch_key="slide",
    save_path=save_path / "plots",
    logger=logger,
)

2025-11-03 14:18:33,977 [INFO]: Integration: Run Harmony
2025-11-03 14:18:34,734 - harmonypy - INFO - Computing initial centroids with sklearn.KMeans...
2025-11-03 14:19:29,114 - harmonypy - INFO - sklearn.KMeans initialization complete.
2025-11-03 14:19:32,269 - harmonypy - INFO - Iteration 1 of 10
2025-11-03 14:23:34,227 - harmonypy - INFO - Iteration 2 of 10
2025-11-03 14:27:35,563 - harmonypy - INFO - Iteration 3 of 10
2025-11-03 14:30:53,071 - harmonypy - INFO - Iteration 4 of 10
2025-11-03 14:32:30,295 - harmonypy - INFO - Iteration 5 of 10
2025-11-03 14:33:44,819 - harmonypy - INFO - Iteration 6 of 10
2025-11-03 14:34:59,317 - harmonypy - INFO - Iteration 7 of 10
2025-11-03 14:36:13,848 - harmonypy - INFO - Iteration 8 of 10
2025-11-03 14:37:28,436 - harmonypy - INFO - Iteration 9 of 10
2025-11-03 14:38:42,973 - harmonypy - INFO - Converged after 9 iterations
2025-11-03 14:38:43,076 [INFO]: Integration: Compute neighbors and UMAP


In [36]:
logger.info("Saving integrated object...")
output_path = save_path / "adatas"
output_path.mkdir(parents=True, exist_ok=True)

# Make sure a string-typed "fov" column exists before writing: create it empty
# when absent, then cast whatever is there to str.
if "fov" not in adata.obs:
    adata.obs["fov"] = ""
adata.obs["fov"] = adata.obs["fov"].astype(str)

adata.write(output_path / "adata_integrated.h5ad.gz", compression="gzip")

2025-11-03 15:08:12,008 [INFO]: Saving integrated object...


In [10]:
###############################################
# load adata post hoc
# Re-read the integrated object from the location the save cell writes to.
# The execution counter restarts at 10 here, so this part was run in a later
# kernel session.
# NOTE(review): the shape and .obs columns in the saved outputs below do not
# match the foxf2 run logged above (565745 cells after normalization); they
# look like leftovers from a different cohort/run — re-execute before relying
# on them.
output_path = save_path / "adatas"
adata = sc.read_h5ad(output_path / "adata_integrated.h5ad.gz")

In [11]:
# (n_cells, n_genes) of the reloaded object.
adata.shape

(996164, 451)

In [12]:
# Overview of the fields stored in the reloaded object.
adata

AnnData object with n_obs × n_vars = 996164 × 451
    obs: 'fov', 'volume', 'center_x', 'center_y', 'min_x', 'min_y', 'max_x', 'max_y', 'anisotropy', 'transcript_count', 'perimeter_area_ratio', 'Txnip_raw', 'Txnip_high_pass', 'Fth1_raw', 'Fth1_high_pass', 'DAPI_raw', 'DAPI_high_pass', 'Scgb1a1_raw', 'Scgb1a1_high_pass', 'Sftpc_raw', 'Sftpc_high_pass', 'PolyT_raw', 'PolyT_high_pass', 'Ifitm3_raw', 'Ifitm3_high_pass', 'region', 'slide', 'dataset_id', 'cells_region', 'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized', 'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation', 'condition', 'run_date', 'organism', 'cohort', 'sample', 'n_counts', 'n_genes', 'Col1_raw', 'Col1_high_pass', 'spatial_outlier', 'low_quality_cell', 'volume_outlier_cell'
    var: 'n_cells'
    obsm: 'X_pca', 'X_pca_harmony', 'X_umap_20_50', 'X_umap_harmony_20_50', 'X_umap_harmony_20_50_3d', 'blank', 'spatial', 'spatial_microns', 'spatial_pixel'
    varm: 'PCs'
    layers: 'counts', 'libr

In [32]:
#######
# spatial plots post-hoc plotting
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc


def _plot_flag(flag, n_cols=3, max_cells=150_000):
    """Scatter the cells of every sample in space, shaded by an ``.obs`` flag.

    Reads the notebook-global ``adata``. One panel is drawn per unique value of
    ``adata.obs["sample"]``; cells for which ``adata.obs[flag]`` is truthy are
    drawn in "grey", all others in "lightgrey". The figure is shown, not saved.

    Args:
        flag: Name of the ``adata.obs`` column used for shading.
        n_cols: Number of panel columns in the figure grid.
        max_cells: Samples with more cells than this are randomly subsampled
            to ``max_cells`` (fixed ``random_state=42``) before plotting.
    """
    samples = adata.obs["sample"].unique()
    n_samples = len(samples)
    n_rows = math.ceil(n_samples / n_cols)

    # squeeze=False keeps `axes` two-dimensional even for a single row/column.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 4 * n_rows),
        squeeze=False,
        sharex=False,
        sharey=False,
    )

    for idx, sample in enumerate(samples):
        ax = axes[idx // n_cols][idx % n_cols]

        mask = adata.obs["sample"] == sample
        sample_data = adata[mask]

        # Cap the number of plotted points per panel.
        if sample_data.n_obs > max_cells:
            sample_data = sc.pp.subsample(
                sample_data, n_obs=max_cells, random_state=42, copy=True
            )

        coords = sample_data.obsm["spatial"]

        outliers = sample_data.obs[flag].values
        colors = pd.Categorical(
            np.where(outliers, "grey", "lightgrey"), categories=["lightgrey", "grey"]
        )

        ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=colors,
            # Marker size shrinks with the number of points, clamped to [0.3, 0.7].
            s=max(0.3, min(0.7, 30000 / len(coords))),
            alpha=0.75,
            edgecolors="none",
        )
        ax.set_title(sample, fontsize=10)

        # Bare panels: no ticks, no frame.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Hide unused subplots
    for ax in axes.flat[n_samples:]:
        ax.set_visible(False)

    fig.suptitle(f"{flag.replace('_', ' ').title()} by sample", fontsize=14, y=1)
    fig.tight_layout()
    # fig.savefig(join(save_path, fname), dpi=200, bbox_inches="tight")
    # plt.close(fig)
    # pyplot.show() displays all open figures and takes no figure argument; a
    # positional argument is misread as a backend option or rejected.
    plt.show()

In [None]:
# NOTE(review): "cohort" is not one of the boolean flag columns
# (spatial_outlier, low_quality_cell, volume_outlier_cell). Non-empty strings
# are truthy, so presumably every cell ends up in the "flagged" colour and this
# simply plots all cells per sample — confirm that this is intended.
_plot_flag("cohort")