In [1]:
import argparse
import logging
import os
import sys
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Tuple

import scanpy as sc
import yaml
from anndata import AnnData
from spatialdata import read_zarr

sys.path.insert(1, "/dss/dsshome1/0C/ra98gaq/Git/cellseg-benchmark")
from cellseg_benchmark.adata_utils import (
    filter_genes,
    filter_low_quality_cells,
    filter_spatial_outlier_cells,
    integration_harmony,
    merge_adatas,
    normalize_counts,
    pca_umap_single,
)


def _load_one(
    sample_dir: Path, seg_method: str, logger: logging.Logger | None
) -> Tuple[str, AnnData | None]:
    """Load the AnnData table of one segmentation method from a sample's master sdata.

    Only the ``tables`` element of ``sample_dir / "sdata_z3.zarr"`` is read.

    Args:
        sample_dir: Sample directory that contains ``sdata_z3.zarr``.
        seg_method: Segmentation method name; the table is looked up under the
            key ``adata_<seg_method>``.
        logger: Logger used to report a missing table. A falsy value (e.g.
            ``None``) silences the report.

    Returns:
        ``(sample_dir.name, adata)``. ``adata`` is ``None`` when the sdata has
        no table stored under ``adata_<seg_method>``.
    """
    table_key = f"adata_{seg_method}"
    sdata = read_zarr(sample_dir / "sdata_z3.zarr", selection=("tables",))
    if table_key not in sdata.tables.keys():
        if logger:
            # Report the key that was actually looked up and the affected sample.
            logger.warning(
                f"Skipping {seg_method} for {sample_dir.name}. No such key: {table_key}"
            )
        return sample_dir.name, None
    return sample_dir.name, sdata[table_key]


# Silence UserWarnings whose message matches the regex below.
warnings.filterwarnings(
    action="ignore",
    message=".*The table is annotating*",
    category=UserWarning,
)

# -1 is scanpy's documented value for "use all available cores".
sc.settings.n_jobs = -1

# Logger setup
logger = logging.getLogger("integrate_adatas")
logger.setLevel(logging.INFO)
# logging.getLogger returns the same object for the same name, so re-running this
# cell would otherwise stack another StreamHandler and duplicate every log line.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s")
    )
    logger.addHandler(handler)



In [2]:
# CLI args
# A notebook has no real command line, so the arguments are emulated via sys.argv
# (first entry is the program name, followed by cohort and segmentation method).
sys.argv = ["notebook", "foxf2", "Cellpose_1_nuclei_model"]
# sys.argv = ["notebook", "foxf2", "Proseg_Cellpose_1_nuclei_model"]

parser = argparse.ArgumentParser(
    description="Integrate adatas from a selected segmentation method."
)
# Both arguments are positional and required.
for arg_name, arg_help in (
    ("cohort", "Cohort name, e.g., 'foxf2'"),
    ("seg_method", "Segmentation method, e.g., 'Cellpose_1_nuclei_model'"),
):
    parser.add_argument(arg_name, help=arg_help)
args = parser.parse_args()
args

Namespace(cohort='foxf2', seg_method='Cellpose_1_nuclei_model')

In [3]:
# Project root; inputs (samples, misc) and outputs (analysis) are resolved below it.
base_path = Path("/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark")
samples_path = base_path.joinpath("samples")
# Results are grouped by cohort and segmentation method.
save_path = base_path.joinpath("analysis", args.cohort, args.seg_method)
save_path.mkdir(parents=True, exist_ok=True)

In [4]:
# Load the sample metadata and the exclusion list from <base_path>/misc.
# `with` closes each file handle deterministically instead of leaking it.
misc_path = base_path / "misc"
with open(misc_path / "sample_metadata.yaml") as fh:
    # NOTE: despite its name this holds the parsed YAML content, not a file path.
    sample_metadata_file = yaml.safe_load(fh)
with open(misc_path / "samples_excluded.yaml") as fh:
    # An empty YAML document parses to None; fall back to an empty mapping so
    # the later `excluded.get(...)` lookup keeps working.
    excluded = yaml.safe_load(fh) or {}
sample_metadata_file

{'foxf2_s1_r0': {'cohort': 'foxf2',
  'slide': 1,
  'region': 0,
  'genotype': 'ECKO',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_0-ECKO000'},
 'foxf2_s1_r1': {'cohort': 'foxf2',
  'slide': 1,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_1-WT000'},
 'foxf2_s2_r1': {'cohort': 'foxf2',
  'slide': 2,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240322',
  'animal_id': '536',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240322_Foxf2-Slide02-cp-WT-PCKO/region_1-WT536'},
 'foxf2_s2_r2': {'cohort': 'foxf2',
  'slide': 2,
  'region': 2,
  'genotype': 'PCKO',
  '

In [5]:
# Names excluded for the selected cohort; a cohort without an entry yields an empty set.
excluded_for_cohort = excluded.get(args.cohort, [])
excluded_samples = set(excluded_for_cohort)
excluded_samples

{'foxf2_s2_r0', 'foxf2_s3_r0', 'foxf2_s3_r1'}

In [6]:
# Load sdata
logger.info("Loading data...")
# Collect the sample directories of this cohort that are not excluded and have a
# master sdata. glob() order is arbitrary, so sort to get the same sample order
# (and hence the same merged AnnData row order) on every run.
loads = []
for sample_dir in sorted(samples_path.glob(f"{args.cohort}*")):
    if sample_dir.name in excluded_samples:
        continue
    if not (sample_dir / "sdata_z3.zarr").exists():
        logger.error("master sdata in %s not found.", sample_dir)
        continue
    loads.append(sample_dir)

# Load the tables in parallel, one worker per allocated SLURM CPU (1 outside SLURM).
# Results are collected in submission order rather than completion order to keep
# the sample order deterministic. An exception raised inside a worker is re-raised
# by .result() and aborts the cell.
with ProcessPoolExecutor(max_workers=int(os.getenv("SLURM_CPUS_PER_TASK", 1))) as ex:
    futures = [ex.submit(_load_one, p, args.seg_method, logger) for p in loads]
    adata_list = [f.result() for f in futures]
# drop samples that have no AnnData table for this segmentation method
# (_load_one returned None), e.g. aging s8 r1 Cellpose 2 Transcripts
adata_list = [(x, y) for x, y in adata_list if y is not None]

2025-09-17 09:48:03,935 [INFO]: Loading data...


In [7]:
# Merge and process
# Combine the per-sample (name, AnnData) pairs into one object; QC plots are
# written to the "plots" folder of this run.
plots_dir = save_path / "plots"
adata = merge_adatas(
    adata_list,
    seg_method=args.seg_method,
    save_path=plots_dir,
    plot_qc_stats=True,
    logger=logger,
)
# The per-sample list is not used again below.
del adata_list

2025-09-17 09:53:16,658 [INFO]: Merging adatas of Cellpose_1_nuclei_model
100%|██████████| 13/13 [00:00<00:00, 40.39it/s]
  utils.warn_names_duplicates("obs")
2025-09-17 09:53:17,985 [INFO]: Cellpose_1_nuclei_model: # of cells: 595581, # of samples: 13
2025-09-17 09:53:18,010 [INFO]: Cellpose_1_nuclei_model: # of cells: 595581, # of samples: 13
2025-09-17 09:53:18,017 [INFO]: Plotting QC results
  fig.tight_layout()
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[key] = c
  df[k

In [8]:
# Workaround to fix adata.obs formatting.
# Rewrite sample names of the form "<cohort>_<digits>_<digits>" to
# "<cohort>_s<digits>_r<digits>"; names that do not match are left unchanged.
sample_pattern = rf"^{args.cohort}_(\d+)_(\d+)$"
sample_replacement = rf"{args.cohort}_s\1_r\2"
adata.obs["sample"] = adata.obs["sample"].str.replace(
    sample_pattern, sample_replacement, regex=True
)

# Remove the "spt_region" column when present (no error if it is missing).
adata.obs.drop(columns=["spt_region"], inplace=True, errors="ignore")

# condition = "<genotype>_<age_months>", e.g. "WT_6".
genotype_str = adata.obs["genotype"].astype(str)
age_str = adata.obs["age_months"].astype(str)
adata.obs["condition"] = genotype_str + "_" + age_str

In [9]:
# Sanity check: number of obs rows per sample name.
adata.obs["sample"].value_counts()

sample
foxf2_s5_r1    58868
foxf2_s5_r0    56940
foxf2_s4_r0    49649
foxf2_s2_r1    47686
foxf2_s6_r2    47030
foxf2_s1_r0    46814
foxf2_s2_r2    45477
foxf2_s6_r1    44109
foxf2_s1_r1    42513
foxf2_s7_r0    42204
foxf2_s7_r1    41277
foxf2_s6_r0    37142
foxf2_s4_r1    35872
Name: count, dtype: int64

In [10]:
# Sanity check: number of obs rows per condition.
adata.obs["condition"].value_counts()

condition
WT_6      179433
PCKO_6    152066
ECKO_6    141554
GLKO_6    122528
Name: count, dtype: int64

In [11]:
# List the obs columns available on the merged object.
adata.obs.columns

Index(['region', 'slide', 'cell_type_mmc_incl_low_quality_revised',
       'cell_type_mmc_incl_low_quality_clusters',
       'cell_type_mmc_incl_low_quality', 'cell_type_mmc_incl_mixed_revised',
       'cell_type_mmc_incl_mixed_clusters', 'cell_type_mmc_incl_mixed',
       'cell_type_mmc_raw_revised', 'cell_type_mmc_raw_clusters',
       'cell_type_mmc_raw', 'area', 'volume_sum', 'volume_final',
       'num_z_planes', 'size_normalized', 'surface_to_volume_ratio',
       'sphericity', 'solidity', 'elongation', 'genotype', 'age_months',
       'run_date', 'animal_id', 'organism', 'cohort', 'sample', 'n_counts',
       'n_genes', 'condition'],
      dtype='object')

In [12]:
# Use the "spatial_microns" coordinates as "spatial" when they exist; otherwise
# keep "spatial" as it is. A membership test is used instead of
# `.get(key, default)` because the default expression is evaluated eagerly and
# would raise a KeyError whenever "spatial" itself is missing.
if "spatial_microns" in adata.obsm:
    adata.obsm["spatial"] = adata.obsm["spatial_microns"]

# NOTE: `sample_metadata_file` is the parsed metadata mapping, not a path.
adata = filter_spatial_outlier_cells(
    adata,
    data_dir=str(base_path),
    sample_metadata_file=sample_metadata_file,
    save_path=save_path / "plots",
    logger=logger,
)

2025-09-17 09:54:24,996 [INFO]: # total cells before filtering spatial outliers: 595581
2025-09-17 09:54:24,997 [INFO]: # spatial_outlier:       13221
2025-09-17 09:54:26,290 [INFO]: # total cells after filtering spatial outliers:  582360


In [13]:
# Low-quality cell filtering. Methods whose name contains "vpt_3D" are a special
# case and get min_counts=10; all other methods use the function's defaults.
plots_dir = save_path / "plots"
low_quality_kwargs = {"save_path": plots_dir, "logger": logger}
if "vpt_3D" in args.seg_method:
    logger.info(
        "Segmentation method contains 'vpt_3D': applying low-quality cell filtering with min_counts=10 due to smaller cell sizes."
    )
    low_quality_kwargs["min_counts"] = 10
adata = filter_low_quality_cells(adata, **low_quality_kwargs)

# Gene filtering, then count normalization.
adata = filter_genes(adata, save_path=plots_dir, logger=logger)

adata = normalize_counts(
    adata,
    save_path=plots_dir,
    seg_method=args.seg_method,
    logger=logger,
)

2025-09-17 09:54:39,681 [INFO]: # total cells before filtering low-quality or volume outliers: 582360
2025-09-17 09:54:39,681 [INFO]: # low_quality_cell:             5003
2025-09-17 09:54:39,682 [INFO]: # volume_outlier_cell:       66
2025-09-17 09:54:44,426 [INFO]: # total cells after filtering low-quality or volume outliers:  577291
2025-09-17 09:54:44,427 [INFO]: # genes before filtering: 500
2025-09-17 09:54:46,475 [INFO]: # genes after filtering: 500
2025-09-17 09:54:48,020 [INFO]: Normalizing counts...
2025-09-17 09:54:49,426 [INFO]: Cells before/after outlier removal during normalization: 577291 -> 565745
  return dispatch(args[0].__class__)(*args, **kw)


In [14]:
# Dimensionality reduction on the merged, not yet integrated data; plots go to the "plots" folder.
adata = pca_umap_single(adata, save_path=save_path / "plots", logger=logger)

2025-09-17 09:55:01,040 [INFO]: Dimensionality reduction: PCA
2025-09-17 09:55:37,665 [INFO]: Dimensionality reduction: UMAP with {'n_neighbors': 20, 'n_pcs': 50, 'key': 'X_umap_20_50'}


In [16]:
# Installs harmonypy from bioconda into the environment at run time (-y: no prompt).
# NOTE(review): unpinned install in the middle of the analysis; consider declaring
# harmonypy (with a version) in the environment specification instead.
!mamba install bioconda::harmonypy -y


Looking for: ['bioconda::harmonypy']

[?25l[2K[0G[+] 0.0s
[2K[1A[2K[0G[+] 0.1s
bioconda/linux-64 (check zst) [33m━━━━━━━━━━━━━━━[0m   0.0 B @  ??.?MB/s Checking  0.1s[2K[1A[2K[0Gbioconda/linux-64 (check zst)                       Checked  0.1s
[?25h[?25l[2K[0G[+] 0.0s
bioconda/noarch (check zst) [33m━━━━━━━━━━━━╸[0m[90m━━━━[0m   0.0 B @  ??.?MB/s Checking  0.0s[2K[1A[2K[0Gbioconda/noarch (check zst)                         Checked  0.0s
[?25h[?25l[2K[0G[+] 0.0s
conda-forge/linux-64 [33m━━━━━━━━━━━━━━╸[0m[90m━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.0s[2K[1A[2K[0G[+] 0.1s
conda-forge/linux-64 [90m━━━━━━━━━━━━━━━━━━━━━━━[0m  60.7kB /  47.0MB @   1.2MB/s  0.1s
conda-forge/noarch   [33m━━━━━━━━━━━━━╸[0m[90m━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.1s
bioconda/linux-64    [33m━━━━━━━╸[0m[90m━━━━━━━━━━━━━━━[0m   0.0 B /  ??.?MB @  ??.?MB/s  0.1s
bioconda/noarch      [90m━╸[0m[33m━━━━━━━━━━━━━━━╸[0m[90m━━━━━[0m   0.0 B /  ??.?MB @  

In [17]:
# Harmony integration with the slide as batch variable; plots go to the "plots" folder.
adata = integration_harmony(
    adata, batch_key="slide", save_path=save_path / "plots", logger=logger
)

2025-09-17 10:09:10,435 [INFO]: Integration: Run Harmony
2025-09-17 10:09:10,853 - harmonypy - INFO - Computing initial centroids with sklearn.KMeans...
2025-09-17 10:09:57,750 - harmonypy - INFO - sklearn.KMeans initialization complete.
2025-09-17 10:09:59,675 - harmonypy - INFO - Iteration 1 of 10
2025-09-17 10:12:56,241 - harmonypy - INFO - Iteration 2 of 10
2025-09-17 10:15:53,019 - harmonypy - INFO - Iteration 3 of 10
2025-09-17 10:18:49,596 - harmonypy - INFO - Iteration 4 of 10
2025-09-17 10:21:37,827 - harmonypy - INFO - Iteration 5 of 10
2025-09-17 10:22:31,845 - harmonypy - INFO - Iteration 6 of 10
2025-09-17 10:23:25,674 - harmonypy - INFO - Iteration 7 of 10
2025-09-17 10:24:19,556 - harmonypy - INFO - Iteration 8 of 10
2025-09-17 10:25:13,462 - harmonypy - INFO - Converged after 8 iterations
2025-09-17 10:25:13,527 [INFO]: Integration: Compute neighbors and UMAP


In [18]:
import logging
from os.path import join
from typing import Tuple

import matplotlib.pyplot as plt
import scanpy as sc
from anndata import AnnData

# Imported under an alias on purpose: a bare `Path` here would rebind the
# `pathlib.Path` imported in the first cell, so re-running any earlier cell that
# calls `Path(...)` would fail. The name is currently unused in this cell.
from matplotlib.path import Path as MplPath  # noqa: F401


def _plot_integration_comparison(
    adata: AnnData,
    save_path: str,
    umap_key: str,
    batch_key: str,
    point_size_factor: int = 320000,
) -> None:
    """Plot unintegrated vs. Harmony-integrated embeddings side by side and save them.

    Draws a 4x2 grid: the left column shows the embedding obtained by removing
    ``"_harmony"`` from ``umap_key``, the right column shows ``umap_key`` itself.
    Each row is colored by one obs column (sample, slide, condition,
    ``cell_type_mmc_raw_revised``). The figure is written to
    ``<save_path>/UMAP_integrated_harmony2.png`` and then closed.

    Args:
        adata: Object holding both embeddings and the obs columns listed above.
        save_path: Existing output directory (anything ``os.path.join`` accepts).
        umap_key: Basis name of the integrated embedding, e.g.
            ``"X_umap_harmony_20_50"``.
        batch_key: Batch variable used for integration; only used in the
            right-hand column title. A falsy value omits it from the title.
        point_size_factor: Point size is ``point_size_factor / adata.shape[0]``,
            so points shrink as the number of observations grows.
    """
    # NOTE(review): the negative wspace presumably pulls the two columns together
    # once the axes are forced to an equal aspect ratio - confirm visually.
    fig, axes = plt.subplots(
        4, 2, figsize=(17, 22), gridspec_kw={"hspace": 0.01, "wspace": -0.54}
    )

    # Column titles, positioned in figure coordinates.
    fig.text(0.40, 0.89, "Unintegrated", fontsize=16, ha="center")
    fig.text(
        0.62,
        0.89,
        f"Integrated (Harmony{f'; by {batch_key}' if batch_key else ''})",
        fontsize=16,
        ha="center",
    )

    # (obs column used for coloring, row label)
    plot_configs = [
        ("sample", "Sample"),
        ("slide", "Slide"),
        ("condition", "Condition"),
        ("cell_type_mmc_raw_revised", "Cell Type"),
    ]

    unintegrated_key = umap_key.replace("_harmony", "")
    point_size = point_size_factor / adata.shape[0]

    for i, (color_key, label) in enumerate(plot_configs):
        # Unintegrated (legend suppressed; the right-hand panel carries it)
        sc.pl.embedding(
            adata,
            basis=unintegrated_key,
            color=color_key,
            show=False,
            ax=axes[i, 0],
            size=point_size,
            legend_loc=None,
            title="",
        )

        # Integrated
        sc.pl.embedding(
            adata,
            basis=umap_key,
            color=color_key,
            show=False,
            ax=axes[i, 1],
            size=point_size,
            legend_loc="right margin",
            title="",
        )

        # Row label, rotated and placed just left of the first column.
        axes[i, 0].text(
            -0.05,
            0.5,
            label,
            transform=axes[i, 0].transAxes,
            fontsize=14,
            rotation=90,
            va="center",
            ha="center",
        )

    for ax in axes.ravel():
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlabel("")
        ax.set_ylabel("")

    # Save and close the figure created above explicitly instead of relying on
    # pyplot's implicit "current figure".
    fig.savefig(
        join(save_path, "UMAP_integrated_harmony2.png"), dpi=150, bbox_inches="tight"
    )
    plt.close(fig)

In [25]:
# Write the before/after integration figure to the "plots" folder, using a smaller
# point_size_factor (200000) than the function default.
_plot_integration_comparison(
    adata,
    save_path=save_path / "plots",
    umap_key="X_umap_harmony_20_50",
    batch_key="slide",
    point_size_factor=200000,
)

# To be continued here
- TODO: move the `_plot_integration_comparison` plotting code above into the script.

In [20]:
# Save result
logger.info("Saving integrated object...")

# Guarantee that obs has a "fov" column and that it is string-typed before writing.
has_fov = "fov" in adata.obs.columns
if not has_fov:
    adata.obs["fov"] = ""
adata.obs["fov"] = adata.obs["fov"].astype(str)

output_path = save_path.joinpath("adatas")
output_path.mkdir(parents=True, exist_ok=True)
adata.write(output_path.joinpath("adata_integrated.h5ad.gz"), compression="gzip")
logger.info("Done.")

2025-09-17 10:58:27,720 [INFO]: Saving integrated object...
2025-09-17 11:00:46,394 [INFO]: Done.


In [4]:
%%time
# Reload the integrated object from the location the save cell writes to.
# NOTE(review): relies on `logger` and `save_path` from earlier cells being defined.
logger.info("Loading data...")
adata = sc.read_h5ad(save_path / "adatas" / "adata_integrated.h5ad.gz")

2025-09-04 11:44:15,699 [INFO]: Loading data...


CPU times: user 35.6 s, sys: 4.15 s, total: 39.8 s
Wall time: 42.9 s


In [5]:
# Display the output directory used by this run.
save_path

PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/analysis/foxf2/Cellpose_1_nuclei_model')

In [14]:
# Remove the "spt_region" column when present (no error if it is missing).
# NOTE(review): same statement as in the obs-formatting workaround cell; presumably
# repeated here for the reloaded object - confirm whether it is still needed.
adata.obs.drop(columns=["spt_region"], inplace=True, errors="ignore")