In [1]:
import argparse
import logging
import os
import sys
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Tuple

import scanpy as sc
import yaml
from anndata import AnnData
from spatialdata import read_zarr

sys.path.insert(1, "/dss/dsshome1/0C/ra98gaq/Git/cellseg-benchmark")
from cellseg_benchmark.adata_utils import (
    filter_genes,
    filter_low_quality_cells,
    filter_spatial_outlier_cells,
    integration_harmony,
    merge_adatas,
    normalize_counts,
    pca_umap_single,
)



In [2]:
# reload
from importlib import reload
import cellseg_benchmark.adata_utils as adu
reload(adu)
from cellseg_benchmark.adata_utils import (
    filter_genes,
    filter_low_quality_cells,
    filter_spatial_outlier_cells,
    integration_harmony,


    integration_harmony_new,

    merge_adatas,
    normalize_counts,
    pca_umap_single,
)

In [4]:
def _load_one(
    sample_dir: Path, seg_method: str, logger: logging.Logger | None = None
) -> Tuple[str, AnnData | None]:
    """Load the ``adata_<seg_method>`` table from one sample's master SpatialData.

    Only the ``tables`` element of ``<sample_dir>/sdata_z3.zarr`` is read.

    Args:
        sample_dir: Sample directory containing ``sdata_z3.zarr``.
        seg_method: Segmentation method; the table key is ``adata_<seg_method>``.
        logger: Optional logger used to report a missing table. May be ``None``
            (e.g. when this function runs inside a worker process).

    Returns:
        ``(sample_name, adata)`` where ``sample_name`` is ``sample_dir.name`` and
        ``adata`` is ``None`` when the sample has no table for ``seg_method``.
    """
    table_key = f"adata_{seg_method}"
    sdata = read_zarr(sample_dir / "sdata_z3.zarr", selection=("tables",))
    if table_key not in sdata.tables:
        if logger is not None:
            logger.warning(
                f"[{sample_dir.name}] Skipping {seg_method}. No such key: {table_key}"
            )
        return sample_dir.name, None
    return sample_dir.name, sdata[table_key]


# Silence spatialdata's "The table is annotating ..." UserWarnings.
warnings.filterwarnings("ignore", ".*The table is annotating.*", UserWarning)
# Let scanpy use all available cores.
sc.settings.n_jobs = -1

In [5]:
# Logger setup. This is idempotent: re-running the cell does not attach a second
# handler, which would otherwise print every message twice.
logger = logging.getLogger("integrate_adatas")
logger.setLevel(logging.INFO)
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s")
    )
    logger.addHandler(handler)

In [6]:
# CLI args. They are parsed from an explicit list, so the kernel's global sys.argv
# is left untouched.
CLI_ARGS = ["SynergyLung", "Cellpose_1_Merlin"]
# Other configurations used previously:
#   ["foxf2", "Cellpose_1_nuclei_model"]
#   ["foxf2", "Proseg_Cellpose_1_nuclei_model"]
#   ["aging", "Cellpose_1_nuclei_model"]

parser = argparse.ArgumentParser(
    description="Integrate adatas from a selected segmentation method."
)
parser.add_argument("cohort", help="Cohort name, e.g., 'foxf2'")
parser.add_argument(
    "seg_method", help="Segmentation method, e.g., 'Cellpose_1_nuclei_model'"
)
args = parser.parse_args(CLI_ARGS)
args

Namespace(cohort='SynergyLung', seg_method='Cellpose_1_Merlin')

In [7]:
# Root of the benchmark store, the per-sample inputs, and the output folder
# for this cohort / segmentation method (created if absent).
base_path = Path("/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark")
samples_path = base_path.joinpath("samples")
save_path = base_path.joinpath("analysis", args.cohort, args.seg_method)
save_path.mkdir(exist_ok=True, parents=True)

In [8]:
def _load_yaml(path: Path):
    """Parse one YAML file with ``yaml.safe_load``, closing the file afterwards.

    An empty file (``safe_load`` -> ``None``) yields ``{}`` so callers can rely
    on dict methods such as ``.get`` and ``.items``.
    """
    with open(path, encoding="utf-8") as fh:
        return yaml.safe_load(fh) or {}


sample_metadata_file = _load_yaml(base_path / "misc" / "sample_metadata.yaml")
excluded = _load_yaml(base_path / "misc" / "samples_excluded.yaml")
sample_metadata_file

{'foxf2_s1_r0': {'cohort': 'foxf2',
  'slide': 1,
  'region': 0,
  'genotype': 'ECKO',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_0-ECKO000'},
 'foxf2_s1_r1': {'cohort': 'foxf2',
  'slide': 1,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_1-WT000'},
 'foxf2_s2_r1': {'cohort': 'foxf2',
  'slide': 2,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240322',
  'animal_id': '536',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240322_Foxf2-Slide02-cp-WT-PCKO/region_1-WT536'},
 'foxf2_s2_r2': {'cohort': 'foxf2',
  'slide': 2,
  'region': 2,
  'genotype': 'PCKO',
  '

In [8]:
# Samples listed under this cohort in samples_excluded.yaml. A missing key or a
# null entry (e.g. "cohort:" with no items) both mean "none excluded".
excluded_samples = set(excluded.get(args.cohort) or [])
excluded_samples

set()

In [9]:
# Samples of the selected cohort in YAML order, minus this cohort's excluded samples.
# (Filtering against `excluded` itself would test membership in the dict's cohort
# keys, so no sample would ever be excluded.)
yaml_samples = [
    name
    for name, meta in sample_metadata_file.items()
    if meta.get("cohort") == args.cohort and name not in excluded_samples
]
yaml_samples

['SynergyLung_s1_r0',
 'SynergyLung_s2_r0',
 'SynergyLung_s3_r0',
 'SynergyLung_s3_r1',
 'SynergyLung_s4_r0',
 'SynergyLung_s4_r1',
 'SynergyLung_s5_r0',
 'SynergyLung_s5_r1']

In [7]:
from functools import partial

logger.info("Loading data...")
# Sample directories (YAML order) whose master sdata exists.
loads = []
for name in yaml_samples:
    sample_dir = samples_path / name
    if not (sample_dir / "sdata_z3.zarr").exists():
        logger.error("master sdata in %s not found.", sample_dir)
        continue
    loads.append(sample_dir)

max_workers = int(os.getenv("SLURM_CPUS_PER_TASK", 1))
# The logger is not handed to the worker processes, so _load_one cannot report a
# missing table itself; such samples are reported below from the main process.
loader = partial(_load_one, seg_method=args.seg_method, logger=None)

with ProcessPoolExecutor(max_workers=max_workers) as ex:
    # Executor.map yields results in input order, so YAML order is preserved.
    results = list(ex.map(loader, loads))

for name, sample_adata in results:
    if sample_adata is None:
        logger.warning(
            "[%s] Skipping %s. No such key: adata_%s",
            name,
            args.seg_method,
            args.seg_method,
        )

# keep YAML order, drop Nones
adata_list = [(name, adata) for name, adata in results if adata is not None]
if not adata_list:
    raise RuntimeError(
        f"No sample of cohort {args.cohort!r} provides adata_{args.seg_method}."
    )

2025-10-21 12:41:44,858 [INFO]: Loading data...


In [8]:
# temp fix for aging_s11_r0
#for i, (name, ad) in enumerate(adata_list):
#if name == "aging_s11_r0":
        #ad.obs["region"] = "0"
    #ad.obs["region"] = ad.obs["region"].astype(str).astype("category")
    #adata_list[i] = (name, ad)

In [9]:
# Combine the per-sample AnnData objects into one and write QC plots to
# <save_path>/plots; the per-sample list is released afterwards to save memory.
adata = merge_adatas(
    adata_list,
    save_path=save_path / "plots",
    plot_qc_stats=True,
    seg_method=args.seg_method,
    logger=logger,
)
del adata_list

2025-10-21 12:42:02,532 [INFO]: Merging adatas of Cellpose_1_Merlin
100%|██████████| 8/8 [00:06<00:00,  1.24it/s]
2025-10-21 12:42:10,035 [INFO]: Cellpose_1_Merlin: #cells=1226739, #samples=8
2025-10-21 12:42:10,037 [INFO]: Plotting QC results
  fig.tight_layout()
  fig.tight_layout()
  fig.tight_layout()
  return fn(*args_all, **kw)
  return fn(*args_all, **kw)
  return fn(*args_all, **kw)
  return fn(*args_all, **kw)
  return fn(*args_all, **kw)
  return fn(*args_all, **kw)
  return fn(*args_all, **kw)
  return fn(*args_all, **kw)


In [10]:
adata.obs["sample"].value_counts()

sample
SynergyLung_s2_r0    172011
SynergyLung_s4_r0    162189
SynergyLung_s5_r0    160754
SynergyLung_s3_r0    158137
SynergyLung_s4_r1    157625
SynergyLung_s5_r1    151416
SynergyLung_s3_r1    151232
SynergyLung_s1_r0    113375
Name: count, dtype: int64

In [11]:
adata.obs["condition"].value_counts()

condition
control    653091
stroke     573648
Name: count, dtype: int64

In [12]:
adata.obs["cohort"].value_counts()

cohort
SynergyLung    1226739
Name: count, dtype: int64

In [13]:
# workaround to fix adata.obs formatting ###############
#adata.obs["sample"] = adata.obs["sample"].str.replace(
#    rf"^{args.cohort}_(\d+)_(\d+)$", rf"{args.cohort}_s\1_r\2", regex=True
#)
#adata.obs.drop(columns=["spt_region"], inplace=True, errors="ignore")
#adata.obs["condition"] = (
#    adata.obs["genotype"].astype(str) + "_" + adata.obs["age_months"].astype(str)
#)

In [14]:
adata.obs.columns

Index(['fov', 'volume', 'center_x', 'center_y', 'min_x', 'min_y', 'max_x',
       'max_y', 'anisotropy', 'transcript_count', 'perimeter_area_ratio',
       'Txnip_raw', 'Txnip_high_pass', 'Fth1_raw', 'Fth1_high_pass',
       'DAPI_raw', 'DAPI_high_pass', 'Scgb1a1_raw', 'Scgb1a1_high_pass',
       'Sftpc_raw', 'Sftpc_high_pass', 'PolyT_raw', 'PolyT_high_pass',
       'Ifitm3_raw', 'Ifitm3_high_pass', 'region', 'slide', 'dataset_id',
       'cells_region', 'area', 'volume_sum', 'volume_final', 'num_z_planes',
       'size_normalized', 'surface_to_volume_ratio', 'sphericity', 'solidity',
       'elongation', 'condition', 'run_date', 'organism', 'cohort', 'sample',
       'n_counts', 'n_genes', 'Col1_raw', 'Col1_high_pass'],
      dtype='object')

In [15]:
# Use micron coordinates as the "spatial" embedding when available. An explicit
# membership test is used because obsm.get(k, obsm["spatial"]) evaluates its
# fallback eagerly and raises KeyError whenever "spatial" itself is absent.
if "spatial_microns" in adata.obsm:
    adata.obsm["spatial"] = adata.obsm["spatial_microns"]
adata = filter_spatial_outlier_cells(
    adata,
    data_dir=str(base_path),
    sample_metadata_file=sample_metadata_file,
    save_path=save_path / "plots",
    logger=logger,
)

2025-10-21 12:44:39,075 [INFO]: [SynergyLung_s1_r0] Missing polygon file: /dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/SynergyLung_s1_r0/cell_outlier_coordinates.csv. Draw outliers in 10X Explorer if necessary.
2025-10-21 12:44:40,020 [INFO]: [SynergyLung_s2_r0] Missing polygon file: /dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/SynergyLung_s2_r0/cell_outlier_coordinates.csv. Draw outliers in 10X Explorer if necessary.
2025-10-21 12:44:40,899 [INFO]: [SynergyLung_s3_r0] Missing polygon file: /dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/SynergyLung_s3_r0/cell_outlier_coordinates.csv. Draw outliers in 10X Explorer if necessary.
2025-10-21 12:44:41,815 [INFO]: [SynergyLung_s3_r1] Missing polygon file: /dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/samples/SynergyLung_s3_r1/cell_outlier_coordinates.csv. Draw outliers in 10X Explorer if necessary.
2025-10-21 12:44:42,821 [INFO]: [SynergyLung_s4_r0] Missing polygon file: /dss/dssfs03/p

In [22]:
args.cohort == 'SynergyLung'

True

In [26]:
# Minimum counts per cell for filter_low_quality_cells:
#   10   for vpt_3D segmentations (smaller cell sizes)
#   15   for SynergyLung (more lenient for initial analysis)
#   None otherwise, i.e. keep the filter's own default (noted as 25; confirm in adata_utils)
min_counts = (
    10
    if "vpt_3D" in args.seg_method
    else (15 if args.cohort == "SynergyLung" else None)
)

In [27]:
min_counts

15

In [28]:
# 1) Drop low-quality / volume-outlier cells; min_counts is forwarded only when set.
adata = filter_low_quality_cells(
    adata,
    save_path=save_path / "plots",
    **({} if min_counts is None else {"min_counts": min_counts}),
    logger=logger,
)

# 2) Drop genes via filter_genes.
adata = filter_genes(adata, logger=logger, save_path=save_path / "plots")

# 3) Normalize counts for this segmentation method.
adata = normalize_counts(
    adata,
    save_path=save_path / "plots",
    seg_method=args.seg_method,
    logger=logger,
)

2025-10-21 12:47:55,598 [INFO]: # total cells before filtering low-quality or volume outliers: 1226739
2025-10-21 12:47:55,599 [INFO]: # low_quality_cell:             208207
2025-10-21 12:47:55,600 [INFO]: # volume_outlier_cell:       2173
2025-10-21 12:48:04,943 [INFO]: # total cells after filtering low-quality or volume outliers:  1016494
2025-10-21 12:48:04,944 [INFO]: # genes before filtering: 451
2025-10-21 12:48:07,888 [INFO]: # genes after filtering: 451
2025-10-21 12:48:14,323 [INFO]: Normalizing counts...
2025-10-21 12:48:16,720 [INFO]: Cells before/after outlier removal during normalization: 1016494 -> 996164
  return dispatch(args[0].__class__)(*args, **kw)


In [29]:
adata = pca_umap_single(adata, save_path=save_path / "plots", logger=logger)

2025-10-21 12:48:56,046 [INFO]: Dimensionality reduction: PCA
2025-10-21 12:55:04,226 [INFO]: Dimensionality reduction: UMAP with n_neighbors=20, n_pcs=50


In [35]:
!mamba install bioconda::harmonypy -y


Looking for: ['bioconda::harmonypy']

[?25l[2K[0G[+] 0.0s
[2K[1A[2K[0G[+] 0.1s
bioconda/linux-64 (check zst) [33m━━━━━━━━━━━━━╸[0m[90m━[0m   0.0 B @  ??.?MB/s Checking  0.1s[2K[1A[2K[0Gbioconda/linux-64 (check zst)                       Checked  0.1s
[?25h[?25l[2K[0G[+] 0.0s
bioconda/noarch (check zst) [90m━━━━━━━╸[0m[33m━━━━━━━━━[0m   0.0 B @  ??.?MB/s Checking  0.0s[2K[1A[2K[0Gbioconda/noarch (check zst)                         Checked  0.1s
[?25h[?25l[2K[0G[+] 0.0s
conda-forge/linux-64 [90m━━━━━━━━━━━━╸[0m[33m━━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.0s[2K[1A[2K[0G[+] 0.1s
conda-forge/linux-64 [90m━━━━━━━━━━━━━━━━━━━━━━━[0m  14.1kB /  48.0MB @ 232.4kB/s  0.1s
conda-forge/noarch   [90m━━━━━━━━━━╸[0m[33m━━━━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.1s
bioconda/linux-64    [90m━━━━━━╸[0m[33m━━━━━━━━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.1s
bioconda/noarch      [33m━━━━━━━━━━━━╸[0m[90m━━━━━━━━━━[0m   0.0 B /  ??.?MB @  

In [36]:
# Checkpoint the filtered, normalized and embedded object *before* Harmony, so
# integration can be restarted from disk. The log message states that this is
# not yet the integrated object.
logger.info("Saving pre-integration object...")
output_path = save_path / "adatas"
output_path.mkdir(parents=True, exist_ok=True)
adata.write(output_path / "adata_TEMP.h5ad.gz", compression="gzip")
logger.info("Done.")

2025-10-21 14:27:17,915 [INFO]: Saving integrated object...


ERROR! Session/line number was not unique in database. History logging moved to new session 14


2025-10-21 14:29:34,581 [INFO]: Done.


In [11]:
output_path = save_path / "adatas"
output_path / "adata_TEMP.h5ad.gz"

PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/analysis/SynergyLung/Cellpose_1_Merlin/adatas/adata_TEMP.h5ad.gz')

In [12]:
adata = sc.read_h5ad(output_path / "adata_TEMP.h5ad.gz")

In [13]:
adata

AnnData object with n_obs × n_vars = 996164 × 451
    obs: 'fov', 'volume', 'center_x', 'center_y', 'min_x', 'min_y', 'max_x', 'max_y', 'anisotropy', 'transcript_count', 'perimeter_area_ratio', 'Txnip_raw', 'Txnip_high_pass', 'Fth1_raw', 'Fth1_high_pass', 'DAPI_raw', 'DAPI_high_pass', 'Scgb1a1_raw', 'Scgb1a1_high_pass', 'Sftpc_raw', 'Sftpc_high_pass', 'PolyT_raw', 'PolyT_high_pass', 'Ifitm3_raw', 'Ifitm3_high_pass', 'region', 'slide', 'dataset_id', 'cells_region', 'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized', 'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation', 'condition', 'run_date', 'organism', 'cohort', 'sample', 'n_counts', 'n_genes', 'Col1_raw', 'Col1_high_pass', 'spatial_outlier', 'low_quality_cell', 'volume_outlier_cell'
    var: 'n_cells'
    uns: 'X_umap_20_50', 'pca'
    obsm: 'X_pca', 'X_umap_20_50', 'blank', 'spatial', 'spatial_microns', 'spatial_pixel'
    varm: 'PCs'
    layers: 'counts', 'librarysize_log1p_norm', 'volume_log1p_no

In [14]:
adata = integration_harmony_new(
    adata,
    batch_key="slide",
    save_path=save_path / "plots",
    logger=logger,
)

2025-10-21 14:36:24,818 [INFO]: Integration: Run Harmony
2025-10-21 14:36:27,278 - harmonypy - INFO - Computing initial centroids with sklearn.KMeans...
2025-10-21 14:39:50,480 - harmonypy - INFO - sklearn.KMeans initialization complete.
2025-10-21 14:40:02,229 - harmonypy - INFO - Iteration 1 of 10
2025-10-21 14:55:51,017 - harmonypy - INFO - Iteration 2 of 10
2025-10-21 15:11:31,129 - harmonypy - INFO - Iteration 3 of 10
2025-10-21 15:27:01,647 - harmonypy - INFO - Converged after 3 iterations
2025-10-21 15:27:01,907 [INFO]: Integration: Compute neighbors and UMAP


In [None]:
# Persist the Harmony-integrated object.
logger.info("Saving integrated object...")

# Make sure a "fov" column exists (empty strings if absent) and is string-typed
# before writing. Presumably this avoids mixed-type h5ad write errors; confirm.
has_fov = "fov" in adata.obs.columns
if not has_fov:
    adata.obs["fov"] = ""
adata.obs["fov"] = adata.obs["fov"].astype(str)
del has_fov

output_path = save_path.joinpath("adatas")
output_path.mkdir(exist_ok=True, parents=True)
adata.write(output_path.joinpath("adata_integrated.h5ad.gz"), compression="gzip")
logger.info("Done.")

2025-10-21 17:26:50,297 [INFO]: Saving integrated object...


In [None]:
#######
# spatial plots after filtering
from cellseg_benchmark.adata_utils import _plot_flag

In [None]:
################

In [33]:
from os.path import join
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import scanpy as sc
from anndata import AnnData


def plot_integration_comparison(
    adata: AnnData,
    save_path: str | os.PathLike,
    umap_key: str,
    batch_key: Optional[str] = None,
    color_keys: Optional[Sequence[tuple[str, str]]] = None,
    point_size_factor: int = 150000,
    dpi: int = 200,
    filename: str = "UMAP_integrated_harmony.png",
) -> None:
    """Save a two-column grid of UMAPs: unintegrated (left) vs. integrated (right).

    One row is drawn per ``(obs_key, row_label)`` pair whose ``obs_key`` exists in
    ``adata.obs``; pairs with missing columns are skipped, and if none remain
    nothing is plotted. The unintegrated embedding key is derived from
    ``umap_key`` by removing ``"_harmony"`` (e.g. ``X_umap_harmony_20_50`` ->
    ``X_umap_20_50``).

    Args:
        adata: AnnData holding both embeddings in ``adata.obsm``.
        save_path: Output directory (str or path-like); created if missing.
        umap_key: ``adata.obsm`` key of the integrated embedding.
        batch_key: Batch column used for integration; only shown in the header.
        color_keys: ``(obs_key, row_label)`` pairs. Defaults to sample, slide,
            condition and cell_type_revised.
        point_size_factor: Point size is ``point_size_factor / adata.n_obs``.
        dpi: Resolution of the saved image.
        filename: Output file name inside ``save_path``.

    Raises:
        KeyError: If either embedding is missing from ``adata.obsm``.
    """
    cfg = color_keys or [
        ("sample", "Sample"),
        ("slide", "Slide"),
        ("condition", "Condition"),
        ("cell_type_revised", "Cell Type"),
    ]
    # Keep only rows whose annotation column is present.
    cfg = [(k, lbl) for k, lbl in cfg if k in adata.obs.columns]
    if not cfg:
        return

    base_key = umap_key.replace("_harmony", "")
    missing = [key for key in (base_key, umap_key) if key not in adata.obsm]
    if missing:
        raise KeyError(f"Requested embedding(s) not found in adata.obsm: {missing}")

    point_size = point_size_factor / adata.n_obs
    nrows = len(cfg)
    # squeeze=False keeps `axes` 2-D (nrows x 2) even when there is a single row.
    fig, axes = plt.subplots(
        nrows,
        2,
        figsize=(7.5, 4 * nrows),
        squeeze=False,
        gridspec_kw={"wspace": 0.02, "hspace": 0.02},
    )
    try:
        fig.subplots_adjust(
            left=0.06, right=0.98, top=0.90, bottom=0.06, wspace=0.02, hspace=0.02
        )

        # Column headers, placed just above the first row of panels.
        left_box = axes[0, 0].get_position(original=True)
        right_box = axes[0, 1].get_position(original=True)
        top_y = max(left_box.y1, right_box.y1) + 0.005
        fig.text(
            left_box.x0 + left_box.width / 2,
            top_y,
            "Unintegrated",
            ha="center",
            va="bottom",
            fontsize=13,
        )
        fig.text(
            right_box.x0 + right_box.width / 2,
            top_y,
            f"Integrated (Harmony{f'; by {batch_key}' if batch_key else ''})",
            ha="center",
            va="bottom",
            fontsize=13,
        )

        # One row per annotation; only the right panel carries the legend.
        for i, (obs_key, row_label) in enumerate(cfg):
            sc.pl.embedding(
                adata,
                basis=base_key,
                color=obs_key,
                ax=axes[i, 0],
                show=False,
                size=point_size,
                legend_loc=None,
                title="",
            )
            sc.pl.embedding(
                adata,
                basis=umap_key,
                color=obs_key,
                ax=axes[i, 1],
                show=False,
                size=point_size,
                legend_loc="right margin",
                legend_fontsize=8,
                title="",
            )
            axes[i, 0].annotate(
                row_label,
                xy=(-0.06, 0.5),
                xycoords="axes fraction",
                va="center",
                ha="right",
                rotation=90,
                fontsize=11,
            )

        # Square, unlabeled panels (set_box_aspect exists in matplotlib >= 3.3).
        for ax in axes.ravel():
            if hasattr(ax, "set_box_aspect"):
                ax.set_box_aspect(1)
            else:
                ax.set_aspect("equal", adjustable="datalim")
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xlabel("")
            ax.set_ylabel("")

        os.makedirs(save_path, exist_ok=True)
        fig.savefig(
            join(save_path, filename), dpi=dpi, bbox_inches="tight", pad_inches=0.08
        )
    finally:
        # Release this figure even if plotting or saving fails.
        plt.close(fig)

In [34]:
plot_integration_comparison(
    adata,
    save_path=save_path / "plots",
    umap_key="X_umap_harmony_20_50",
    batch_key="slide",
    point_size_factor=150000,
)

In [9]:
###############################################
# load adata post hoc
output_path = save_path / "adatas"
adata = sc.read_h5ad(output_path / "adata_integrated.h5ad.gz")

In [10]:
adata.shape

(996164, 451)

In [12]:
adata

AnnData object with n_obs × n_vars = 996164 × 451
    obs: 'fov', 'volume', 'center_x', 'center_y', 'min_x', 'min_y', 'max_x', 'max_y', 'anisotropy', 'transcript_count', 'perimeter_area_ratio', 'Txnip_raw', 'Txnip_high_pass', 'Fth1_raw', 'Fth1_high_pass', 'DAPI_raw', 'DAPI_high_pass', 'Scgb1a1_raw', 'Scgb1a1_high_pass', 'Sftpc_raw', 'Sftpc_high_pass', 'PolyT_raw', 'PolyT_high_pass', 'Ifitm3_raw', 'Ifitm3_high_pass', 'region', 'slide', 'dataset_id', 'cells_region', 'area', 'volume_sum', 'volume_final', 'num_z_planes', 'size_normalized', 'surface_to_volume_ratio', 'sphericity', 'solidity', 'elongation', 'condition', 'run_date', 'organism', 'cohort', 'sample', 'n_counts', 'n_genes', 'Col1_raw', 'Col1_high_pass', 'spatial_outlier', 'low_quality_cell', 'volume_outlier_cell'
    var: 'n_cells'
    obsm: 'X_pca', 'X_pca_harmony', 'X_umap_20_50', 'X_umap_harmony_20_50', 'X_umap_harmony_20_50_3d', 'blank', 'spatial', 'spatial_microns', 'spatial_pixel'
    varm: 'PCs'
    layers: 'counts', 'libr