In [1]:
import argparse
import logging
import os
import sys
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Tuple

import scanpy as sc
import yaml
from anndata import AnnData
from spatialdata import read_zarr

sys.path.insert(1, "/dss/dsshome1/0C/ra98gaq/Git/cellseg-benchmark")
from cellseg_benchmark.adata_utils import (
    filter_genes,
    filter_low_quality_cells,
    filter_spatial_outlier_cells,
    integration_harmony,
    merge_adatas,
    normalize_counts,
    pca_umap_single,
)


def _load_one(
    sample_dir: Path, seg_method: str, logger: logging.Logger
) -> Tuple[str, AnnData | None]:
    """Load the AnnData table of one segmentation method from a sample's master sdata.

    Only the ``tables`` selection of ``<sample_dir>/sdata_z3.zarr`` is read.

    Args:
        sample_dir: Sample directory that contains ``sdata_z3.zarr``.
        seg_method: Segmentation method name; the table is looked up under
            the key ``adata_<seg_method>``.
        logger: Logger used to report a missing table. A falsy value
            disables the warning.

    Returns:
        ``(sample_dir.name, adata)``. ``adata`` is ``None`` when the sdata
        holds no table for ``seg_method``.
    """
    table_key = f"adata_{seg_method}"
    sdata = read_zarr(sample_dir / "sdata_z3.zarr", selection=("tables",))
    if table_key not in sdata.tables.keys():
        if logger:
            # Report the key that was actually looked up and the sample it is
            # missing from, so the warning can be traced to one zarr store.
            logger.warning(
                f"Skipping {seg_method} for {sample_dir.name}. No such key: {table_key}"
            )
        return sample_dir.name, None
    return sample_dir.name, sdata[table_key]


# Suppress UserWarnings whose message matches "The table is annotating".
warnings.filterwarnings(
    action="ignore",
    message=".*The table is annotating*",
    category=UserWarning,
)

# scanpy job setting; -1 presumably means "use all available cores" — confirm.
sc.settings.n_jobs = -1

# Logger setup
logger = logging.getLogger("integrate_adatas")
logger.setLevel(logging.INFO)
# logging.getLogger() returns the same object for the same name, so re-running
# this cell would otherwise stack one more StreamHandler per run and every
# message would be emitted multiple times. Attach a handler only once.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s")
    )
    logger.addHandler(handler)



In [2]:
# CLI args
# The notebook emulates a command-line call by overriding sys.argv before
# parsing. Swap in one of the commented lines to run another configuration.
sys.argv = ["notebook", "foxf2", "Cellpose_1_nuclei_model"]
# sys.argv = ["notebook", "foxf2", "Proseg_Cellpose_1_nuclei_model"]
# sys.argv = ["notebook", "aging", "Cellpose_1_nuclei_model"]

parser = argparse.ArgumentParser(
    description="Integrate adatas from a selected segmentation method."
)
# Two required positional arguments, registered in order.
for arg_name, arg_help in (
    ("cohort", "Cohort name, e.g., 'foxf2'"),
    ("seg_method", "Segmentation method, e.g., 'Cellpose_1_nuclei_model'"),
):
    parser.add_argument(arg_name, help=arg_help)
args = parser.parse_args()
args

Namespace(cohort='foxf2', seg_method='Cellpose_1_nuclei_model')

In [3]:
# Project root; per-run results go to analysis/<cohort>/<seg_method> below it.
base_path = Path("/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark")
samples_path = base_path.joinpath("samples")
save_path = base_path.joinpath("analysis", args.cohort, args.seg_method)
# Create the result directory (and any missing parents) up front.
save_path.mkdir(parents=True, exist_ok=True)

In [4]:
# Read the sample metadata and the exclusion list from <base_path>/misc.
# Path.read_text() opens and closes each file itself, so no file handle is
# left open (a bare open() passed to safe_load is never closed explicitly).
sample_metadata_file, excluded = (
    yaml.safe_load((base_path / "misc" / yaml_name).read_text())
    for yaml_name in ("sample_metadata.yaml", "samples_excluded.yaml")
)  ###############
sample_metadata_file

{'foxf2_s1_r0': {'cohort': 'foxf2',
  'slide': 1,
  'region': 0,
  'genotype': 'ECKO',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_0-ECKO000'},
 'foxf2_s1_r1': {'cohort': 'foxf2',
  'slide': 1,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_1-WT000'},
 'foxf2_s2_r1': {'cohort': 'foxf2',
  'slide': 2,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240322',
  'animal_id': '536',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240322_Foxf2-Slide02-cp-WT-PCKO/region_1-WT536'},
 'foxf2_s2_r2': {'cohort': 'foxf2',
  'slide': 2,
  'region': 2,
  'genotype': 'PCKO',
  '

In [5]:
# Samples of the selected cohort that must be skipped.
# The `or` fallbacks cover two YAML edge cases that would otherwise raise:
# an empty file loads as None, and a cohort key without entries has value None.
excluded_samples = set((excluded or {}).get(args.cohort) or [])  ###############
excluded_samples

{'foxf2_s2_r0', 'foxf2_s3_r0', 'foxf2_s3_r1'}

In [6]:
# Load sdata
logger.info("Loading data...")
loads = []
# Sort the glob result: directory iteration order is arbitrary, and the order
# of `loads` determines the order of `adata_list` below.
for sample_dir in sorted(samples_path.glob(f"{args.cohort}*")):  ###############
    if sample_dir.name in excluded_samples:
        continue
    if not (sample_dir / "sdata_z3.zarr").exists():
        logger.error("master sdata in %s not found.", sample_dir)
        continue
    loads.append(sample_dir)

# Load the samples in parallel. Results are collected in submission order
# (not completion order), so `adata_list` has the same sample order on every
# run; presumably that order carries through to the merged object — confirm
# against merge_adatas.
n_workers = int(os.getenv("SLURM_CPUS_PER_TASK", 1))
with ProcessPoolExecutor(max_workers=n_workers) as ex:
    futures = [ex.submit(_load_one, p, args.seg_method, logger) for p in loads]
    adata_list = [f.result() for f in futures]
# drop samples where AnnData could not be loaded (missing or invalid), e.g. aging s8 r1 Cellpose 2 Transcripts
adata_list = [(x, y) for x, y in adata_list if y is not None]

2025-09-18 14:06:26,540 [INFO]: Loading data...


In [7]:
# temp fix for aging_s11_r0: its "region" column is overwritten with "0".
# For every sample, "region" is then normalised to a string-valued categorical.
for idx in range(len(adata_list)):
    sample_name, sample_adata = adata_list[idx]
    if sample_name == "aging_s11_r0":
        sample_adata.obs["region"] = "0"
    region_as_str = sample_adata.obs["region"].astype(str)
    sample_adata.obs["region"] = region_as_str.astype("category")
    adata_list[idx] = (sample_name, sample_adata)

In [8]:
# Merge and process
# Merge the per-sample objects into a single AnnData; QC plotting is switched
# on and save_path / "plots" is handed over as the save location.
merge_kwargs = dict(
    seg_method=args.seg_method,
    logger=logger,
    plot_qc_stats=True,
    save_path=save_path / "plots",
)
adata = merge_adatas(adata_list, **merge_kwargs)
# Drop the notebook's reference to the per-sample list once it is merged.
del adata_list

2025-09-18 14:19:42,431 [INFO]: Merging adatas of Cellpose_1_nuclei_model
100%|██████████| 13/13 [00:00<00:00, 16.55it/s]
  utils.warn_names_duplicates("obs")
2025-09-18 14:19:45,980 [INFO]: Cellpose_1_nuclei_model: # of cells: 595581, # of samples: 13
2025-09-18 14:19:46,026 [INFO]: Cellpose_1_nuclei_model: # of cells: 595581, # of samples: 13
2025-09-18 14:19:46,038 [INFO]: Plotting QC results
  fig.tight_layout()
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[k

In [9]:
# Sanity check: number of cells per sample in the merged object.
adata.obs["sample"].value_counts()

sample
foxf2_5_1    58868
foxf2_5_0    56940
foxf2_4_0    49649
foxf2_2_1    47686
foxf2_6_2    47030
foxf2_1_0    46814
foxf2_2_2    45477
foxf2_6_1    44109
foxf2_1_1    42513
foxf2_7_0    42204
foxf2_7_1    41277
foxf2_6_0    37142
foxf2_4_1    35872
Name: count, dtype: int64

In [10]:
# workaround to fix adata.obs formatting ###############
# Sample IDs of the form "<cohort>_<a>_<b>" are rewritten to "<cohort>_s<a>_r<b>".
sample_pattern = rf"^{args.cohort}_(\d+)_(\d+)$"
sample_replacement = rf"{args.cohort}_s\1_r\2"
adata.obs["sample"] = adata.obs["sample"].str.replace(
    sample_pattern, sample_replacement, regex=True
)
# Remove the "spt_region" column when present (errors="ignore" tolerates absence).
adata.obs.drop(columns=["spt_region"], inplace=True, errors="ignore")
# condition = "<genotype>_<age_months>"
genotype_str = adata.obs["genotype"].astype(str)
age_str = adata.obs["age_months"].astype(str)
adata.obs["condition"] = genotype_str + "_" + age_str

In [11]:
# Sanity check: number of cells per "<genotype>_<age_months>" condition.
adata.obs["condition"].value_counts()

condition
WT_6      179433
PCKO_6    152066
ECKO_6    141554
GLKO_6    122528
Name: count, dtype: int64

In [12]:
# Sanity check: number of cells per age group.
adata.obs["age_months"].value_counts()

age_months
6    595581
Name: count, dtype: int64

In [15]:
# Inspect the "region" column (string-valued categorical).
# NOTE(review): the stored output of this cell has 582360 rows, i.e. it was
# executed after the spatial-outlier filtering cell that appears further down;
# on a top-to-bottom run the length shown here will differ.
adata.obs["region"]

aaaaaaaa-1      0
aaaaaaab-1      0
aaaaaaac-1      0
aaaaaaad-1      0
aaaaaaae-1      0
               ..
aaaakmei-1-7    1
aaaakmej-1-7    1
aaaakmek-1-7    1
aaaakmel-1-7    1
aaaakmem-1-7    1
Name: region, Length: 582360, dtype: category
Categories (3, object): ['0', '1', '2']

In [13]:
# List the per-cell annotation columns available at this point.
adata.obs.columns

Index(['region', 'slide', 'cell_type_mmc_incl_low_quality_revised',
       'cell_type_mmc_incl_low_quality_clusters',
       'cell_type_mmc_incl_low_quality', 'cell_type_mmc_incl_mixed_revised',
       'cell_type_mmc_incl_mixed_clusters', 'cell_type_mmc_incl_mixed',
       'cell_type_mmc_raw_revised', 'cell_type_mmc_raw_clusters',
       'cell_type_mmc_raw', 'area', 'volume_sum', 'volume_final',
       'num_z_planes', 'size_normalized', 'surface_to_volume_ratio',
       'sphericity', 'solidity', 'elongation', 'genotype', 'age_months',
       'run_date', 'animal_id', 'organism', 'cohort', 'sample', 'n_counts',
       'n_genes', 'condition'],
      dtype='object')

In [14]:
# Use the micron coordinates as "spatial" when they are available.
# A membership test is used instead of `.get(key, default)` because the
# default expression `adata.obsm["spatial"]` would be evaluated eagerly and
# raise KeyError when "spatial" is absent, even if "spatial_microns" exists.
if "spatial_microns" in adata.obsm:
    adata.obsm["spatial"] = adata.obsm["spatial_microns"]
adata = filter_spatial_outlier_cells(
    adata,
    data_dir=str(base_path),
    sample_metadata_file=sample_metadata_file,  ############
    save_path=save_path / "plots",
    logger=logger,
)

2025-09-18 14:21:46,861 [INFO]: # total cells before filtering spatial outliers: 595581
2025-09-18 14:21:46,864 [INFO]: # spatial_outlier:       13221
2025-09-18 14:21:50,871 [INFO]: # total cells after filtering spatial outliers:  582360


In [16]:
# Low-quality cell filtering.
# Special case: segmentation methods containing "vpt_3D" are filtered with a
# lower threshold (min_counts=10); all other methods use the function default.
low_quality_kwargs = {"save_path": save_path / "plots", "logger": logger}
if "vpt_3D" in args.seg_method:
    logger.info(
        "Segmentation method contains 'vpt_3D': applying low-quality cell filtering with min_counts=10 due to smaller cell sizes."
    )
    low_quality_kwargs["min_counts"] = 10
adata = filter_low_quality_cells(adata, **low_quality_kwargs)

# Gene filtering, then count normalisation.
adata = filter_genes(adata, logger=logger, save_path=save_path / "plots")

adata = normalize_counts(
    adata,
    seg_method=args.seg_method,
    save_path=save_path / "plots",
    logger=logger,
)

2025-09-18 14:22:11,222 [INFO]: # total cells before filtering low-quality or volume outliers: 582360
2025-09-18 14:22:11,225 [INFO]: # low_quality_cell:             5003
2025-09-18 14:22:11,226 [INFO]: # volume_outlier_cell:       66
2025-09-18 14:22:24,973 [INFO]: # total cells after filtering low-quality or volume outliers:  577291
2025-09-18 14:22:24,977 [INFO]: # genes before filtering: 500
2025-09-18 14:22:30,010 [INFO]: # genes after filtering: 500
2025-09-18 14:22:32,768 [INFO]: Normalizing counts...
2025-09-18 14:22:36,390 [INFO]: Cells before/after outlier removal during normalization: 577291 -> 565745
  return dispatch(args[0].__class__)(*args, **kw)


In [17]:
# Dimensionality reduction on the unintegrated data (the log reports PCA
# followed by UMAP); save_path / "plots" is passed as the save location.
adata = pca_umap_single(adata, save_path=save_path / "plots", logger=logger)

2025-09-18 14:23:04,314 [INFO]: Dimensionality reduction: PCA
2025-09-18 14:24:16,306 [INFO]: Dimensionality reduction: UMAP with {'n_neighbors': 20, 'n_pcs': 50, 'key': 'X_umap_20_50'}


In [18]:
# NOTE(review): installs harmonypy into the active environment at run time,
# without a version pin and in the middle of the analysis. Presumably needed
# by integration_harmony below — better declared (pinned) in the environment
# specification so a fresh Restart & Run All does not depend on this cell.
!mamba install bioconda::harmonypy -y


Looking for: ['bioconda::harmonypy']

[?25l[2K[0G[+] 0.0s
[2K[1A[2K[0G[+] 0.1s
bioconda/linux-64 (check zst) [90m━━╸[0m[33m━━━━━━━━━━━━[0m   0.0 B @  ??.?MB/s Checking  0.1s[2K[1A[2K[0Gbioconda/linux-64 (check zst)                       Checked  0.1s
[?25h[?25l[2K[0G[+] 0.0s
bioconda/noarch (check zst) [90m━━━━━╸[0m[33m━━━━━━━━━━━[0m   0.0 B @  ??.?MB/s Checking  0.0s[2K[1A[2K[0Gbioconda/noarch (check zst)                         Checked  0.0s
[?25h[?25l[2K[0G[+] 0.0s
conda-forge/linux-64 [33m━━━━━━━━━━╸[0m[90m━━━━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.0s[2K[1A[2K[0G[+] 0.1s
conda-forge/linux-64 [90m━━━━━━━━━━━━━━━━━━━━━━━[0m  60.4kB /  47.0MB @ 951.1kB/s  0.1s
conda-forge/noarch   [90m━━╸[0m[33m━━━━━━━━━━━━━━━╸[0m[90m━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.1s
bioconda/linux-64    [33m━━━━━━━━━━━╸[0m[90m━━━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.1s
bioconda/noarch      [90m━━╸[0m[33m━━━━━━━━━━━━━━━╸[0m[90m━━━━[0m   0

In [None]:
# Harmony integration with "slide" as the batch key; save_path / "plots" is
# passed as the save location.
adata = integration_harmony(
    adata,
    batch_key="slide",
    save_path=save_path / "plots",
    logger=logger,
)

2025-09-18 14:50:58,705 [INFO]: Integration: Run Harmony
2025-09-18 14:51:00,192 - harmonypy - INFO - Computing initial centroids with sklearn.KMeans...


In [None]:
# Save result
logger.info("Saving integrated object...")
# Make sure a string-typed "fov" column exists before writing.
if "fov" not in adata.obs:
    adata.obs["fov"] = ""
fov_as_str = adata.obs["fov"].astype(str)
adata.obs["fov"] = fov_as_str
# Write the object to <save_path>/adatas, creating the directory if needed.
output_path = save_path / "adatas"
output_path.mkdir(parents=True, exist_ok=True)
integrated_file = output_path / "adata_integrated.h5ad.gz"
adata.write(integrated_file, compression="gzip")
logger.info("Done.")

In [None]:
################

In [33]:
from os.path import join
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import scanpy as sc
from anndata import AnnData


def plot_integration_comparison(
    adata: AnnData,
    save_path: str,
    umap_key: str,
    batch_key: Optional[str] = None,
    color_keys: Optional[Sequence[tuple[str, str]]] = None,
    point_size_factor: int = 150000,
    dpi: int = 200,
    filename: str = "UMAP_integrated_harmony.png",
) -> None:
    """Plot before/after-integration UMAPs in two columns and save the figure.

    One row is drawn per colour key: the left column shows the unintegrated
    embedding, the right column the integrated one (with the legend).

    Args:
        adata: Object holding both embeddings in ``adata.obsm``.
        save_path: Directory the figure is written to; it is not created here.
        umap_key: ``obsm`` key of the integrated embedding. The unintegrated
            key is derived by removing ``"_harmony"`` from it.
        batch_key: Optional name appended to the right column header.
        color_keys: ``(obs_column, row_label)`` pairs. Defaults to sample,
            slide, condition and ``cell_type_mmc_raw_revised``. Pairs whose
            column is missing from ``adata.obs`` are skipped; if none remain
            the function returns without plotting.
        point_size_factor: Marker size is ``point_size_factor / n_obs``.
        dpi: Resolution of the saved figure.
        filename: File name of the saved figure inside ``save_path``.

    Raises:
        KeyError: If the integrated or the derived unintegrated embedding is
            missing from ``adata.obsm``.
    """
    cfg = color_keys or [
        ("sample", "Sample"),
        ("slide", "Slide"),
        ("condition", "Condition"),
        ("cell_type_mmc_raw_revised", "Cell Type"),
    ]
    cfg = [(k, lbl) for k, lbl in cfg if k in adata.obs.columns]
    if not cfg:
        return

    # Derive the unintegrated key once and name the missing key(s) in the error.
    raw_key = umap_key.replace("_harmony", "")
    missing = [
        key for key in dict.fromkeys((raw_key, umap_key)) if key not in adata.obsm
    ]
    if missing:
        raise KeyError(f"Requested embedding(s) not found in adata.obsm: {missing}")

    nrows = len(cfg)
    # squeeze=False always yields a 2-D (nrows, 2) axes array, also for one row.
    fig, axes = plt.subplots(
        nrows,
        2,
        figsize=(7.5, 4 * nrows),
        gridspec_kw={"wspace": 0.02, "hspace": 0.02},
        squeeze=False,
    )

    fig.subplots_adjust(
        left=0.06, right=0.98, top=0.90, bottom=0.06, wspace=0.02, hspace=0.02
    )

    # Column headers, centred above the first row of axes.
    left_box = axes[0, 0].get_position(fig)
    right_box = axes[0, 1].get_position(fig)
    top_y = max(left_box.y1, right_box.y1) + 0.005
    fig.text(
        left_box.x0 + left_box.width / 2,
        top_y,
        "Unintegrated",
        ha="center",
        va="bottom",
        fontsize=13,
    )
    fig.text(
        right_box.x0 + right_box.width / 2,
        top_y,
        f"Integrated (Harmony{f'; by {batch_key}' if batch_key else ''})",
        ha="center",
        va="bottom",
        fontsize=13,
    )

    # Guard against an empty object so the size computation cannot divide by zero.
    point_size = point_size_factor / max(adata.n_obs, 1)

    # Plot rows
    for i, (obs_key, row_label) in enumerate(cfg):
        sc.pl.embedding(
            adata,
            basis=raw_key,
            color=obs_key,
            ax=axes[i, 0],
            show=False,
            size=point_size,
            legend_loc=None,
            title="",
        )
        sc.pl.embedding(
            adata,
            basis=umap_key,
            color=obs_key,
            ax=axes[i, 1],
            show=False,
            size=point_size,
            legend_loc="right margin",
            legend_fontsize=8,
            title="",
        )
        # Row label, rotated and placed left of the first column.
        axes[i, 0].annotate(
            row_label,
            xy=(-0.06, 0.5),
            xycoords="axes fraction",
            va="center",
            ha="right",
            rotation=90,
            fontsize=11,
        )

    for ax in axes.ravel():
        # Square panels; fall back to equal aspect where set_box_aspect is unavailable.
        if hasattr(ax, "set_box_aspect"):
            ax.set_box_aspect(1)
        else:
            ax.set_aspect("equal", adjustable="datalim")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")

    # Save and close this figure explicitly instead of relying on pyplot's
    # implicit "current figure".
    fig.savefig(
        join(save_path, filename), dpi=dpi, bbox_inches="tight", pad_inches=0.08
    )
    plt.close(fig)

In [34]:
# Draw the unintegrated vs. Harmony-integrated comparison for the
# "X_umap_harmony_20_50" embedding and save it under save_path / "plots".
plot_integration_comparison(
    adata,
    save_path=save_path / "plots",
    umap_key="X_umap_harmony_20_50",
    batch_key="slide",
    point_size_factor=150000,
)

In [6]:
###############################################
# load adata post hoc
# Reloads the integrated object written by the save cell
# (<save_path>/adatas/adata_integrated.h5ad.gz) so that later cells can be
# run without repeating the pipeline above.
output_path = save_path / "adatas"
adata = sc.read_h5ad(output_path / "adata_integrated.h5ad.gz")

In [7]:
# Shape of the reloaded object.
adata.shape

(565745, 500)