In [1]:
import argparse
import logging
import os
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from os.path import exists
from pathlib import Path
from typing import Tuple

import scanpy as sc
import yaml
from anndata import AnnData
from spatialdata import read_zarr
import sys

sys.path.insert(1, "/dss/dsshome1/0C/ra98gaq/Git/cellseg-benchmark")
from cellseg_benchmark.adata_utils import (
    merge_adatas,
    filter_low_quality_cells,
    filter_spatial_outlier_cells,
    filter_genes,
    normalize_counts,
    pca_umap_single,
    integration_harmony
)

def _load_one(
    sample_dir: Path, seg_method: str, logger: logging.Logger | None = None
) -> Tuple[str, AnnData | None]:
    """Load the AnnData table of one segmentation method from a sample's master sdata.

    Only the ``tables`` element of ``<sample_dir>/sdata_z3.zarr`` is read.

    Args:
        sample_dir: Sample directory containing ``sdata_z3.zarr``.
        seg_method: Segmentation method name; the table key is ``adata_<seg_method>``.
        logger: Optional logger; when given, a warning is emitted if the table is missing.

    Returns:
        ``(sample_name, adata)`` where ``sample_name`` is ``sample_dir.name`` and
        ``adata`` is ``None`` when the sdata has no ``adata_<seg_method>`` table.
    """
    table_key = f"adata_{seg_method}"
    sdata = read_zarr(sample_dir / "sdata_z3.zarr", selection=("tables",))
    if table_key not in sdata.tables.keys():
        if logger:
            # Name the sample and the exact key so skipped samples can be traced.
            logger.warning(
                "Skipping %s for sample %s. No such table key: %s",
                seg_method,
                sample_dir.name,
                table_key,
            )
        return sample_dir.name, None
    return sample_dir.name, sdata[table_key]

# Silence the "The table is annotating ..." UserWarning (presumably emitted when
# reading sdata tables).
warnings.filterwarnings("ignore", ".*The table is annotating.*", UserWarning)

# -1 = use all available cores (scanpy settings convention).
sc.settings.n_jobs = -1

# Logger setup. The handler is only attached once, so re-running this cell does not
# add a second StreamHandler (which would print every message twice).
logger = logging.getLogger("integrate_adatas")
logger.setLevel(logging.INFO)
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s"))
    logger.addHandler(handler)



In [2]:
# CLI args. In the notebook the arguments are passed explicitly to parse_args instead
# of overwriting the kernel's sys.argv; swap the list to analyse another method.
cli_args = ["foxf2", "Cellpose_1_nuclei_model"]
# cli_args = ["foxf2", "Proseg_Cellpose_1_nuclei_model"]

parser = argparse.ArgumentParser(
    description="Integrate adatas from a selected segmentation method."
)
parser.add_argument("cohort", help="Cohort name, e.g., 'foxf2'")
parser.add_argument(
    "seg_method", help="Segmentation method, e.g., 'Cellpose_1_nuclei_model'"
)
args = parser.parse_args(cli_args)
args

Namespace(cohort='foxf2', seg_method='Cellpose_1_nuclei_model')

In [3]:
# Inputs are read from <base>/samples; outputs go to <base>/analysis/<cohort>/<seg_method>.
base_path = Path("/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark")
samples_path = base_path.joinpath("samples")
save_path = base_path.joinpath("analysis", args.cohort, args.seg_method)
save_path.mkdir(exist_ok=True, parents=True)

In [4]:
# `sample_metadata_file` is passed to filter_spatial_outlier_cells below, but it was
# never defined in this notebook (presumably left over in the kernel), so a fresh
# Restart & Run All raised NameError here. Load it explicitly: it is a mapping
# sample name -> metadata (cohort, slide, region, genotype, age_months, ..., path).
# NOTE(review): location assumed to sit next to samples_excluded.yaml — TODO confirm.
with open(base_path / "misc" / "sample_metadata.yaml") as f:
    sample_metadata_file = yaml.safe_load(f)
sample_metadata_file

{'foxf2_s1_r0': {'cohort': 'foxf2',
  'slide': 1,
  'region': 0,
  'genotype': 'ECKO',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_0-ECKO000'},
 'foxf2_s1_r1': {'cohort': 'foxf2',
  'slide': 1,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240229',
  'animal_id': '000',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240229_Foxf2-Slide01-cp-WT-ECKO/region_1-WT000'},
 'foxf2_s2_r1': {'cohort': 'foxf2',
  'slide': 2,
  'region': 1,
  'genotype': 'WT',
  'age_months': 6,
  'run_date': '20240322',
  'animal_id': '536',
  'organism': 'mouse',
  'path': '/dss/dssfs03/pn52re/pn52re-dss-0000/202402-Foxf2/merfish_output/20240322_Foxf2-Slide02-cp-WT-PCKO/region_1-WT536'},
 'foxf2_s2_r2': {'cohort': 'foxf2',
  'slide': 2,
  'region': 2,
  'genotype': 'PCKO',
  '

In [None]:
# Samples listed for this cohort in samples_excluded.yaml are skipped.
with open(base_path / "misc" / "samples_excluded.yaml") as f: ###############
    excluded = yaml.safe_load(f) or {}  # an empty YAML file loads as None
# `or []` also covers a cohort key that is present with no entries (loads as None).
excluded_samples = set(excluded.get(args.cohort) or [])

# Load sdata
logger.info("Loading data...")
loads = []
# Sorted so the sample order (and hence the merge order) does not depend on
# filesystem listing order.
for sample_dir in sorted(samples_path.glob(f"{args.cohort}*")): ###############
    if sample_dir.name in excluded_samples:
        continue
    if not (sample_dir / "sdata_z3.zarr").exists():
        logger.error("master sdata in %s not found.", sample_dir)
        continue
    loads.append(sample_dir)

# Load tables in parallel. Results are collected in submission order (not completion
# order) so that adata_list is reproducible across runs.
with ProcessPoolExecutor(max_workers=int(os.getenv("SLURM_CPUS_PER_TASK", 1))) as ex:
    futures = [ex.submit(_load_one, p, args.seg_method, logger) for p in loads]
    adata_list = [f.result() for f in futures]
# Drop samples whose table for this segmentation method is missing (_load_one returns
# None for them), e.g. aging s8 r1 Cellpose 2 Transcripts.
adata_list = [(name, ad) for name, ad in adata_list if ad is not None]
if not adata_list:
    raise RuntimeError(
        f"No AnnData loaded for cohort {args.cohort!r} and segmentation method {args.seg_method!r}."
    )

2025-09-03 15:52:41,037 [INFO]: Loading data...


In [None]:
# Merge the per-sample AnnData objects into one via merge_adatas; with
# plot_qc_stats=True it is asked to write QC plots under save_path / "plots".
adata = merge_adatas(
    adata_list,
    save_path=save_path / "plots",
    plot_qc_stats=True,
    logger=logger,
    seg_method=args.seg_method,
)
# The per-sample list is no longer needed once merged.
del adata_list

In [None]:
# workaround to fix adata.obs formatting ###############
# Rename samples "<cohort>_<slide>_<region>" to "<cohort>_s<slide>_r<region>" (the
# naming used in the sample metadata, e.g. "foxf2_s1_r0"). Already-renamed values do not
# match the pattern, so re-running this cell is harmless. The cohort comes from the CLI,
# so it is regex-escaped in the pattern and inserted via a callable (no backreference
# parsing of the cohort text).
import re

adata.obs["sample"] = adata.obs["sample"].str.replace(
    rf"^{re.escape(args.cohort)}_(\d+)_(\d+)$",
    lambda m: f"{args.cohort}_s{m.group(1)}_r{m.group(2)}",
    regex=True,
)
if "spt_region" in adata.obs.columns:
    del adata.obs["spt_region"]
# Combined genotype/age label, e.g. "WT_6".
adata.obs["condition"] = (
    adata.obs["genotype"].astype(str) + "_" + adata.obs["age_months"].astype(str)
)

In [None]:
# Number of cells per sample.
adata.obs["sample"].value_counts(dropna=True)

In [None]:
# Number of cells per condition (genotype_age).
adata.obs["condition"].value_counts(dropna=True)

In [None]:
# Per-cell annotation columns currently available.
adata.obs.keys()

In [None]:
# Use the "spatial_microns" coordinates (presumably micron-scaled) as the spatial basis
# when present; otherwise keep "spatial" unchanged. Checked explicitly rather than with
# obsm.get("spatial_microns", obsm["spatial"]), whose default argument is evaluated
# eagerly and raises KeyError when only "spatial_microns" exists.
if "spatial_microns" in adata.obsm:
    adata.obsm["spatial"] = adata.obsm["spatial_microns"]
adata = filter_spatial_outlier_cells(
    adata,
    data_dir=str(base_path),
    sample_metadata_file=sample_metadata_file, ############
    save_path=save_path / "plots",
    logger=logger,
)

In [None]:
# Low-quality cell filtering. vpt_3D segmentations yield smaller cells, so for them the
# threshold is lowered to min_counts=VPT3D_MIN_COUNTS; all other methods use
# filter_low_quality_cells' own default.
VPT3D_MIN_COUNTS = 10
low_quality_kwargs = {}
if "vpt_3D" in args.seg_method:
    logger.info(
        f"Segmentation method contains 'vpt_3D': applying low-quality cell filtering with min_counts={VPT3D_MIN_COUNTS} due to smaller cell sizes."
    )
    low_quality_kwargs["min_counts"] = VPT3D_MIN_COUNTS
adata = filter_low_quality_cells(
    adata,
    save_path=save_path / "plots",
    logger=logger,
    **low_quality_kwargs,
)

adata = filter_genes(adata, save_path=save_path / "plots", logger=logger)

adata = normalize_counts(
    adata, save_path=save_path / "plots", seg_method=args.seg_method, logger=logger
)

In [None]:
# Embedding of the merged, normalised data before batch correction (presumably
# PCA + UMAP, per the helper's name); plots go to save_path / "plots".
adata = pca_umap_single(adata, logger=logger, save_path=save_path / "plots")

In [None]:
!mamba install bioconda::harmonypy -y

In [None]:
# Harmony integration with "slide" as the batch key; plots go to save_path / "plots".
adata = integration_harmony(
    adata,
    logger=logger,
    save_path=save_path / "plots",
    batch_key="slide",  ################
)

In [39]:
import logging
import math
import re
from os.path import isfile, join
from typing import List, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import seaborn as sns
from anndata import AnnData, concat
# Aliased: a bare `Path` here would shadow pathlib.Path from the first cell, so
# re-running an earlier cell such as `base_path = Path("/dss/...")` would call
# matplotlib's Path instead of pathlib's.
from matplotlib.path import Path as MplPath
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tqdm import tqdm
from cellseg_benchmark._constants import cell_type_colors

def _plot_integration_comparison(
    adata: AnnData,
    save_path: str,
    umap_key: str,
    batch_key: str,
    point_size_factor: int = 320000,
    plot_configs: List[Tuple[str, str]] | None = None,
    filename: str = "UMAP_integrated_harmony2.png",
) -> None:
    """Plot unintegrated vs. Harmony-integrated embeddings side by side and save the figure.

    One row per ``(color_key, row_label)`` pair. The left column uses the basis
    ``umap_key`` with ``"_harmony"`` removed; the right column uses ``umap_key``.
    Only the right column carries a legend.

    Args:
        adata: AnnData holding both embeddings and the colour keys.
        save_path: Directory the PNG is written to.
        umap_key: Basis of the integrated embedding, e.g. ``"X_umap_harmony_20_50"``.
        batch_key: Batch key used for Harmony; only shown in the column header
            (omitted from the header when empty).
        point_size_factor: Point size is ``point_size_factor / n_cells``.
        plot_configs: ``(color_key, row_label)`` pairs. Defaults to sample, slide,
            condition and ``cell_type_mmc_raw_revised``. Keys found neither in
            ``adata.obs`` nor in ``adata.var_names`` are skipped with a warning.
        filename: Output file name inside ``save_path``.

    Raises:
        ValueError: If none of the requested colour keys is present in ``adata``.
    """
    import warnings

    if plot_configs is None:
        plot_configs = [
            ("sample", "Sample"),
            ("slide", "Slide"),
            ("condition", "Condition"),
            ("cell_type_mmc_raw_revised", "Cell Type"),
        ]

    # Keep only keys scanpy can resolve (obs columns or gene names); otherwise one
    # missing key would abort the whole figure inside sc.pl.embedding.
    rows = []
    for color_key, label in plot_configs:
        if color_key in adata.obs.columns or color_key in adata.var_names:
            rows.append((color_key, label))
        else:
            warnings.warn(
                f"Skipping '{color_key}': not found in adata.obs or adata.var_names.",
                stacklevel=2,
            )
    if not rows:
        raise ValueError("None of the requested colour keys are present in adata.")

    n_rows = len(rows)
    # 5.5 in per row reproduces the original 17 x 22 in figure for the default 4 rows.
    fig, axes = plt.subplots(
        n_rows,
        2,
        figsize=(17, 5.5 * n_rows),
        squeeze=False,
        gridspec_kw={"hspace": 0.01, "wspace": -0.54},
    )

    fig.text(0.40, 0.89, "Unintegrated", fontsize=16, ha="center")
    fig.text(
        0.62,
        0.89,
        f"Integrated (Harmony{f'; by {batch_key}' if batch_key else ''})",
        fontsize=16,
        ha="center",
    )

    # Same point size in both columns, scaled inversely with the number of cells.
    point_size = point_size_factor / adata.shape[0]
    # The unintegrated embedding is assumed to use the same key without "_harmony"
    # (e.g. X_umap_harmony_20_50 -> X_umap_20_50).
    unintegrated_basis = umap_key.replace("_harmony", "")

    for i, (color_key, label) in enumerate(rows):
        # Unintegrated
        sc.pl.embedding(
            adata,
            basis=unintegrated_basis,
            color=color_key,
            show=False,
            ax=axes[i, 0],
            size=point_size,
            legend_loc=None,
            title="",
        )

        # Integrated
        sc.pl.embedding(
            adata,
            basis=umap_key,
            color=color_key,
            show=False,
            ax=axes[i, 1],
            size=point_size,
            legend_loc="right margin",
            title="",
        )

        # Row label to the left of the unintegrated panel.
        axes[i, 0].text(
            -0.05, 0.5, label, transform=axes[i, 0].transAxes,
            fontsize=14, rotation=90, va="center", ha="center",
        )

    for ax in axes.ravel():
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlabel("")
        ax.set_ylabel("")

    fig.savefig(join(save_path, filename), dpi=150, bbox_inches="tight")
    # Close this specific figure (not just the "current" one) so repeated calls
    # do not accumulate open figures.
    plt.close(fig)

In [40]:
# Unintegrated vs. Harmony-integrated UMAP comparison (batch key "slide"), saved under save_path / "plots".
_plot_integration_comparison(
    adata,
    save_path=save_path / "plots",
    umap_key="X_umap_harmony_20_50",
    batch_key="slide",
    point_size_factor=300000,
)

# TBC (to be continued) here
- TODO: move the plotting code above (`_plot_integration_comparison`) into the integration script.

In [None]:
# Save result: write the integrated object to save_path/adatas/adata_integrated.h5ad.gz.
# compression="gzip" compresses the HDF5 datasets inside the .h5ad file; the file is
# still read with sc.read_h5ad. "fov" is created empty if missing and cast to str.
logger.info("Saving integrated object...")
adata.obs["fov"] = (
    adata.obs["fov"].astype(str) if "fov" in adata.obs.columns else ""
)
output_path = save_path / "adatas"
output_path.mkdir(exist_ok=True, parents=True)
adata.write(output_path / "adata_integrated.h5ad.gz", compression="gzip")
logger.info("Done.")

In [4]:
%%time
logger.info("Loading data...")
adata = sc.read_h5ad(save_path / "adatas" / "adata_integrated.h5ad.gz")

2025-09-04 11:44:15,699 [INFO]: Loading data...


CPU times: user 35.6 s, sys: 4.15 s, total: 39.8 s
Wall time: 42.9 s


In [5]:
# Show this run's output directory (<base>/analysis/<cohort>/<seg_method>).
save_path

PosixPath('/dss/dssfs03/pn52re/pn52re-dss-0001/cellseg-benchmark/analysis/foxf2/Cellpose_1_nuclei_model')

In [14]:
# Remove the "spt_region" column from the reloaded object if it is present (no-op otherwise).
if "spt_region" in adata.obs.columns:
    del adata.obs["spt_region"]
