# scRNA-seq batch integration and cell-type annotation
Related to Supplementary Fig. 1.
## Overview

**Inputs (under `./data/`):** `GSE149614.h5ad`, `GSE151530.h5ad`, `GSE156625.h5ad`, `skrx2fz79n.h5ad` (Mendeley ID: skrx2fz79n).  
Each input must contain **raw counts** in `.X` (or provide `layers['counts']`), and must include a per-cell **sample identity** column in `.obs` (configured below).

**Outputs (under `./outputs/`):**
- Integrated AnnData: `./outputs/HCC_integrated_harmony.h5ad`
- Key figures: `./outputs/figures/umap_*.pdf`, `./outputs/figures/dotplot_markers_*.pdf`
- Cluster markers (Wilcoxon): `./outputs/rank_genes_groups_leiden_res1p2.csv`
- Run metadata: `./outputs/config.json`, `./outputs/filtering_summary.json`


In [None]:
from __future__ import annotations

import json
import random
import warnings
from pathlib import Path

import numpy as np
import pandas as pd

import anndata as ad
import scanpy as sc

# Suppress every UserWarning-category warning for the rest of the session
# (no message/module pattern, so this is equivalent to a simple filter).
warnings.filterwarnings("ignore", category=UserWarning)
# Shared matplotlib defaults applied through scanpy for all figures below.
sc.set_figure_params(frameon=False, dpi=120)

# =========================
# CONFIG
# =========================
# Single source of run parameters; every later cell reads its settings from
# this dict (and the sample-key / annotation cells write resolved values back).
CONFIG = {
    # -- reproducibility --
    "random_seed": 0,                         # seeds `random` and NumPy below

    # -- .obs column names --
    "dataset_key": "dataset",                 # label column created by ad.concat
    "sample_key": "sample",                   # per-cell sample identity column

    # -- QC thresholds --
    "min_genes_per_cell": 300,                # inclusive lower bound
    "max_genes_per_cell": 6000,               # inclusive upper bound
    "max_pct_mito": 20.0,                     # cells must be strictly below (%)
    "min_cells_per_gene": 3,                  # genes kept if detected in >= N cells

    # -- normalization --
    "target_sum": 10_000.0,                   # per-cell library size before log1p

    # -- feature selection / dimensionality reduction --
    "n_hvg": 2000,
    "hvg_flavor": "seurat_v3",
    "n_pcs": 50,                              # PCs computed by PCA
    "harmony_n_pcs": 20,                      # leading PCs handed to Harmony

    # -- graph / clustering --
    "n_neighbors": 15,
    "leiden_resolution": 1.2,

    # -- marker discovery --
    "rank_genes_method": "wilcoxon",

    # -- doublet detection --
    "run_scrublet": True,
    "remove_predicted_doublets": True,
    "scrublet_expected_doublet_rate": 0.06,

    # -- post-clustering curation --
    "exclude_clusters": [],                   # e.g. ["7","13"]
    "cluster_to_celltype": {},                # e.g. {"0": "T", "1": "Myeloid", ...}
}

# .obs column names tried, in this order, when CONFIG["sample_key"] is absent.
SAMPLE_KEY_FALLBACKS = ["S_ID", "orig.ident", "patient", "donor", "sample_id"]

# Input / output locations (relative to the working directory); the output
# folders are created up front so later cells can write without checks.
DATA_DIR = Path("./data")
OUT_DIR = Path("./outputs")
FIG_DIR = OUT_DIR / "figures"
for _out_path in (OUT_DIR, FIG_DIR):
    _out_path.mkdir(parents=True, exist_ok=True)

# Dataset label -> .h5ad path; the labels become the values of the
# CONFIG["dataset_key"] column after concatenation.
INPUTS = {
    _name: DATA_DIR / f"{_name}.h5ad"
    for _name in ("GSE151530", "GSE156625", "GSE149614", "skrx2fz79n")
}

# Seed the Python and NumPy global RNGs from the configured seed.
SEED = int(CONFIG["random_seed"])
random.seed(SEED)
np.random.seed(SEED)

# File-name-safe tag for the Leiden resolution ("." replaced by "p").
RES_TAG = str(CONFIG["leiden_resolution"]).replace(".", "p")

# Report the versions of the core libraries and probe the two packages that
# later cells import lazily (harmonypy, scrublet).
print("scanpy:", sc.__version__)
print("anndata:", ad.__version__)

# Both probes are best-effort: a failure is reported, not raised, because the
# cells that actually need the package raise their own ImportError.
# `except Exception` is kept broad on purpose — importing a package with
# compiled dependencies can fail with errors other than ImportError.
try:
    import harmonypy
    import harmonypy as hm
except Exception as exc:
    print(
        "harmonypy: not available (required for Harmony integration)",
        f"[{type(exc).__name__}: {exc}]",
    )
else:
    # getattr mirrors the scrublet probe below, so a build without
    # `__version__` is reported as "unknown" rather than "not available".
    print("harmonypy:", getattr(harmonypy, "__version__", "unknown"))

try:
    import scrublet
    import scrublet as scr
except Exception as exc:
    print(
        "scrublet: not available (optional; required if CONFIG['run_scrublet']=True)",
        f"[{type(exc).__name__}: {exc}]",
    )
else:
    print("scrublet:", getattr(scrublet, "__version__", "unknown"))

print("DATA_DIR:", DATA_DIR.resolve())
print("OUT_DIR :", OUT_DIR.resolve())


## Data loading

In [None]:
# Load every input, harmonize the per-dataset annotations that the pipeline
# depends on, then concatenate on shared genes.
adatas = []
keys = []

missing_files = [str(p) for p in INPUTS.values() if not p.exists()]
if missing_files:
    raise FileNotFoundError(
        "Missing input .h5ad file(s) under ./data/. Expected:\n- "
        + "\n- ".join(missing_files)
    )

# Candidate sample-identity columns: the configured key first, then fallbacks.
_sample_key = CONFIG["sample_key"]
_sample_candidates = [_sample_key] + [c for c in SAMPLE_KEY_FALLBACKS if c != _sample_key]

for ds, path in INPUTS.items():
    a = sc.read_h5ad(path)
    a.var_names_make_unique()

    # With join="inner", ad.concat keeps only the .obs columns and layers that
    # exist under the same name in EVERY dataset. Harmonize both here so a
    # dataset-specific column name (or a missing counts layer) does not cause
    # the information to be dropped for all datasets.
    if "counts" not in a.layers:
        a.layers["counts"] = a.X.copy()

    found = next((c for c in _sample_candidates if c in a.obs.columns), None)
    if found is None:
        raise KeyError(
            f"Dataset '{ds}': no sample identity column found. Tried {_sample_candidates}.\n"
            f"Available obs columns (first 40): {list(a.obs.columns)[:40]}"
        )
    if found != _sample_key:
        print(f"INFO: [{ds}] sample_key '{_sample_key}' not found; copying fallback '{found}' into it.")
        # Stored as string categories so the merged column has one dtype even
        # when datasets encode sample IDs differently.
        a.obs[_sample_key] = a.obs[found].astype(str).astype("category")

    adatas.append(a)
    keys.append(ds)

adata = ad.concat(
    adatas,
    join="inner",                     # intersection of genes (and of obs columns)
    label=CONFIG["dataset_key"],      # dataset-of-origin column
    keys=keys,
    index_unique="-",                 # cell barcodes get a "-<dataset>" suffix
    merge="same",
)

# Guarantee that downstream steps can read raw counts from layers["counts"];
# when the layer is absent, .X is taken to hold the raw counts.
if "counts" not in adata.layers:
    adata.layers["counts"] = adata.X.copy()

# Resolve the sample-identity column: use the configured name when present,
# otherwise the first fallback that exists; fail loudly when none does.
configured_key = CONFIG["sample_key"]
obs_columns = adata.obs.columns
if configured_key in obs_columns:
    sample_key = configured_key
else:
    sample_key = next((c for c in SAMPLE_KEY_FALLBACKS if c in obs_columns), None)
    if sample_key is None:
        cols = list(obs_columns)
        raise KeyError(
            f"Missing sample identity column '{CONFIG['sample_key']}' in adata.obs.\n"
            f"Available obs columns (first 40): {cols[:40]}"
        )
    print(f"INFO: sample_key '{configured_key}' not found; using fallback '{sample_key}'.")

# Persist the resolved name so every later cell reads the same column.
CONFIG["sample_key"] = sample_key

# Lightweight sanity checks
dataset_sizes = adata.obs[CONFIG["dataset_key"]].value_counts().to_dict()
n_samples = adata.obs[CONFIG["sample_key"]].nunique()
print("Merged AnnData:", adata)
print("Datasets:", dataset_sizes)
print("Sample key:", CONFIG["sample_key"], "| n_samples =", n_samples)


## Workflow

1) Concatenate datasets on shared genes (inner join).  
2) QC filtering: 300–6,000 detected genes per cell and <20% mitochondrial fraction; remove genes detected in <3 cells.  
3) Normalize counts to 10,000 per cell and log1p-transform.  
4) Select HVGs (n=2,000) using Seurat v3 procedure on raw counts with dataset-of-origin as batch covariate.  
5) PCA → Harmony integration (harmonypy) on the first 20 PCs with sample identity as batch covariate → neighbors/UMAP/Leiden (**resolution = 1.2**).  
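6) Doublet detection with Scrublet (run per sample on the raw counts immediately after the QC filter of step 2, i.e. before normalization) and optional removal of predicted doublets.  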
7) Marker discovery (Wilcoxon) and **manual major cell-type annotation** by inspecting cluster markers/dotplots and mapping Leiden clusters to broad lineages (`adata.obs['majortype']`).  
8) Export marker tables, key figures, and the final integrated `.h5ad` object + run metadata.


In [None]:
# =========================
# QC, normalization, HVGs, PCA, Harmony integration, UMAP, Leiden
# =========================
from scipy import sparse

# --- QC ---
# Per-cell metrics are computed from the raw-count layer: library size,
# number of detected genes, and percentage of counts in "MT-" genes
# (prefix matched case-insensitively).
counts = adata.layers["counts"]
is_mito = adata.var_names.str.upper().str.startswith("MT-").to_numpy()
counts_are_sparse = sparse.issparse(counts)
detected = counts > 0

if counts_are_sparse:
    # Sparse reductions return 2-D matrices; flatten them to 1-D vectors.
    total_counts = np.asarray(counts.sum(axis=1)).ravel()
    n_genes = np.asarray(detected.sum(axis=1)).ravel()
else:
    total_counts = counts.sum(axis=1)
    n_genes = detected.sum(axis=1)

if not is_mito.any():
    # No mitochondrial genes among the shared genes: fraction is zero.
    mt_counts = np.zeros(adata.n_obs)
elif counts_are_sparse:
    mt_counts = np.asarray(counts[:, is_mito].sum(axis=1)).ravel()
else:
    mt_counts = counts[:, is_mito].sum(axis=1)

adata.obs["total_counts"] = total_counts
adata.obs["n_genes_by_counts"] = n_genes
# Denominator is clamped to 1 so empty cells yield 0 % instead of NaN.
adata.obs["pct_counts_mt"] = mt_counts / np.maximum(total_counts, 1) * 100.0

# Cell filter: detected genes within [min, max] (inclusive) and
# mitochondrial percentage strictly below the cap.
genes_ok = adata.obs["n_genes_by_counts"].between(
    CONFIG["min_genes_per_cell"], CONFIG["max_genes_per_cell"]
)
mito_ok = adata.obs["pct_counts_mt"] < CONFIG["max_pct_mito"]
n_cells_before = adata.n_obs
adata = adata[genes_ok & mito_ok].copy()
print(f"QC filter cells: {n_cells_before} -> {adata.n_obs}")

# Gene filter: keep genes detected in at least `min_cells_per_gene` of the
# cells that survived the cell filter.
counts = adata.layers["counts"]
cells_per_gene = (counts > 0).sum(axis=0)
if sparse.issparse(counts):
    cells_per_gene = np.asarray(cells_per_gene).ravel()
n_genes_before = adata.n_vars
adata = adata[:, cells_per_gene >= int(CONFIG["min_cells_per_gene"])].copy()
print(f"Filter genes (min_cells={CONFIG['min_cells_per_gene']}): {n_genes_before} -> {adata.n_vars}")

# --- Scrublet ---
# Doublets are scored independently per sample on the raw counts; results are
# stored in .obs["doublet_score"] / .obs["predicted_doublet"].
if CONFIG["run_scrublet"]:
    try:
        import scrublet as scr
    except Exception as e:
        raise ImportError("Scrublet is required when CONFIG['run_scrublet']=True.") from e

    MIN_CELLS_FOR_SCRUBLET = 50  # samples below this size are not scored

    Xc = adata.layers["counts"]
    # Defaults for cells that are never scored: NaN score, not a doublet.
    scores = np.full(adata.n_obs, np.nan, dtype=float)
    preds = np.zeros(adata.n_obs, dtype=bool)

    sample_key = CONFIG["sample_key"]
    # Materialize the label array once instead of once per sample.
    sample_labels = adata.obs[sample_key].to_numpy()
    for s in adata.obs[sample_key].unique():
        idx = np.flatnonzero(sample_labels == s)
        if idx.size < MIN_CELLS_FOR_SCRUBLET:
            # too small to run robustly; mark as non-doublet by default
            print(
                f"INFO: Scrublet skipped for sample '{s}' "
                f"({idx.size} cells < {MIN_CELLS_FOR_SCRUBLET}); cells kept as non-doublets."
            )
            continue

        scrub = scr.Scrublet(
            Xc[idx],
            expected_doublet_rate=float(CONFIG["scrublet_expected_doublet_rate"]),
            random_state=SEED,  # tie the doublet simulation to the run seed
        )
        sc_score, sc_pred = scrub.scrub_doublets()
        scores[idx] = sc_score
        if sc_pred is None:
            # Scrublet returns no calls when it cannot pick a score threshold
            # automatically; keep the scores but leave the cells unflagged.
            print(
                f"WARNING: Scrublet found no automatic threshold for sample '{s}'; "
                "scores stored, no cells flagged as doublets."
            )
            continue
        preds[idx] = sc_pred

    adata.obs["doublet_score"] = scores
    adata.obs["predicted_doublet"] = preds

    if CONFIG["remove_predicted_doublets"]:
        n0 = adata.n_obs
        adata = adata[~adata.obs["predicted_doublet"].to_numpy()].copy()
        print(f"Remove predicted doublets: {n0} -> {adata.n_obs}")

# --- Normalization ---
# Scale each cell's raw counts to CONFIG["target_sum"] and apply log1p; the
# result is stored as layers["log1p"] and becomes the working matrix .X.
raw_counts = adata.layers["counts"]
target = float(CONFIG["target_sum"])

if sparse.issparse(raw_counts):
    raw_counts = raw_counts.tocsr()
    per_cell_total = np.asarray(raw_counts.sum(axis=1)).ravel()
    # Clamp to 1 so cells with zero counts do not divide by zero.
    size_factor = target / np.maximum(per_cell_total, 1)
    # `multiply` returns a new matrix, so transforming its stored values
    # in place leaves the raw-count layer untouched.
    Xn = raw_counts.multiply(size_factor[:, None]).tocsr()
    Xn.data = np.log1p(Xn.data)
else:
    per_cell_total = raw_counts.sum(axis=1)
    size_factor = target / np.maximum(per_cell_total, 1)
    Xn = np.log1p(raw_counts * size_factor[:, None])

adata.layers["log1p"] = Xn
adata.X = adata.layers["log1p"]

# --- HVG selection ---
# Highly variable genes are selected on the raw-count layer with the dataset
# of origin as batch key.
hvg_kwargs = {
    "flavor": str(CONFIG["hvg_flavor"]),
    "n_top_genes": int(CONFIG["n_hvg"]),
    "batch_key": CONFIG["dataset_key"],
}
try:
    sc.pp.highly_variable_genes(adata, layer="counts", **hvg_kwargs)
except TypeError:
    # Fallback path: presumably for scanpy versions whose function does not
    # accept `layer` — TODO confirm. Run on a copy whose .X holds the counts
    # and transfer only the resulting flag.
    counts_view = adata.copy()
    counts_view.X = counts_view.layers["counts"]
    sc.pp.highly_variable_genes(counts_view, **hvg_kwargs)
    adata.var["highly_variable"] = counts_view.var["highly_variable"].values

# NOTE: the object is subset to the HVGs here, so every later step that reads
# `adata` (marker ranking, `g in adata.var_names` marker lookup, dotplots)
# sees only these genes.
hvg_mask = adata.var["highly_variable"]
adata = adata[:, hvg_mask].copy()
print("HVGs retained:", adata.n_vars)

# --- PCA ---
# SEED is passed to every stochastic step below (previously only UMAP got it),
# so changing CONFIG["random_seed"] affects the whole embedding consistently.
sc.pp.pca(adata, n_comps=int(CONFIG["n_pcs"]), svd_solver="arpack", random_state=SEED)

# --- Harmony integration ---
try:
    import harmonypy as hm
except Exception as e:
    raise ImportError("harmonypy is required for Harmony integration.") from e

# Harmony corrects the leading `harmony_n_pcs` PCs using the sample column as
# the batch variable.
Z = adata.obsm["X_pca"][:, : int(CONFIG["harmony_n_pcs"])]
meta = adata.obs[[CONFIG["sample_key"]]].copy()
ho = hm.run_harmony(
    Z,
    meta,
    vars_use=[CONFIG["sample_key"]],
    max_iter_harmony=50,
    random_state=SEED,
)

# .obsm needs cells on the first axis. The original code always transposed
# Z_corr; orient by shape instead so a harmonypy release that already returns
# (n_cells, n_pcs) also works — NOTE(review): confirm against installed version.
Z_corr = np.asarray(ho.Z_corr)
if Z_corr.shape[0] != adata.n_obs:
    Z_corr = Z_corr.T
if Z_corr.shape[0] != adata.n_obs:
    raise ValueError(
        f"Unexpected Harmony output shape {Z_corr.shape} for {adata.n_obs} cells."
    )
adata.obsm["X_pca_harmony"] = Z_corr

# --- Neighbors / UMAP / Leiden on the Harmony-corrected embedding ---
sc.pp.neighbors(
    adata,
    n_neighbors=int(CONFIG["n_neighbors"]),
    use_rep="X_pca_harmony",
    random_state=SEED,
)
sc.tl.umap(adata, random_state=SEED)
cluster_key = "leiden"
sc.tl.leiden(
    adata,
    resolution=float(CONFIG["leiden_resolution"]),
    key_added=cluster_key,
    random_state=SEED,
)

# Optional curation: drop whole clusters listed in CONFIG["exclude_clusters"].
if CONFIG["exclude_clusters"]:
    n0 = adata.n_obs
    adata = adata[~adata.obs[cluster_key].isin(list(CONFIG["exclude_clusters"]))].copy()
    print(f"Exclude clusters {CONFIG['exclude_clusters']}: {n0} -> {adata.n_obs}")

print("Computed:", "X_pca_harmony", "|", "UMAP", "|", cluster_key, f"(res={CONFIG['leiden_resolution']})", "| n_clusters =", adata.obs[cluster_key].nunique())
print("Next:", "rank_genes_groups → manual cluster→majortype mapping (see below).")


## Results & exports

This section exports (i) cluster marker tables, (ii) key UMAPs and marker dotplots, (iii) the manual cluster→cell-type mapping, and (iv) the final integrated `.h5ad` object plus run metadata under `./outputs/`.


In [None]:
# =========================
# Differential markers per Leiden cluster (Wilcoxon) + export
# =========================
cluster_key = "leiden"

# Rank genes for every cluster with the configured test.
sc.tl.rank_genes_groups(adata, groupby=cluster_key, method=str(CONFIG["rank_genes_method"]))

# Cluster order: categorical order when available, otherwise sorted labels.
cluster_col = adata.obs[cluster_key]
if hasattr(cluster_col, "cat"):
    groups = list(cluster_col.cat.categories)
else:
    groups = sorted(cluster_col.unique())


def _markers_for(group):
    """Return the ranked-gene table of one cluster with a leading 'cluster' column."""
    table = sc.get.rank_genes_groups_df(adata, group=group)
    table.insert(0, "cluster", group)
    return table


# One long table: all clusters stacked, index reset.
markers_df = pd.concat([_markers_for(g) for g in groups], axis=0, ignore_index=True)

markers_csv = OUT_DIR / f"rank_genes_groups_{cluster_key}_res{RES_TAG}.csv"
markers_df.to_csv(markers_csv, index=False)

print("Saved:", markers_csv)


In [None]:
# =========================
# UMAPs + marker dotplots + manual cell-type annotation
# =========================
import matplotlib.pyplot as plt

cluster_key = "leiden"  # Leiden at resolution = CONFIG["leiden_resolution"] (here: 1.2)


def save_umap(color, fname, **kwargs):
    """Draw a UMAP of the global `adata` and save it under FIG_DIR.

    Parameters
    ----------
    color
        Passed through as `color=` to `sc.pl.umap`.
    fname
        File name (relative to FIG_DIR) of the saved figure.
    **kwargs
        Forwarded unchanged to `sc.pl.umap`.
    """
    destination = FIG_DIR / fname
    sc.pl.umap(adata, color=color, show=False, **kwargs)
    # savefig/close act on matplotlib's current figure, i.e. the one just drawn.
    plt.savefig(destination, bbox_inches="tight")
    plt.close()

# --- UMAPs (QC/integration outputs) ---
save_umap([CONFIG["dataset_key"]], "umap_by_dataset.pdf", wspace=0.4)
save_umap([CONFIG["sample_key"]], "umap_by_sample.pdf", wspace=0.4)
save_umap([cluster_key], "umap_by_cluster.pdf", wspace=0.4)
if "predicted_doublet" in adata.obs.columns:
    save_umap(["predicted_doublet"], "umap_by_predicted_doublet.pdf", wspace=0.4)

# --- Canonical markers for broad lineages (used for manual annotation) ---
marker_genes_dict = {
    "B cell": ["CD79A", "CD79B", "MS4A1"],
    "Plasma": ["CD38", "XBP1"],
    "T cell": ["CD3D", "CD3E", "CD8A", "CD8B"],
    "Myeloid": ["CD163", "CD68", "LYZ"],
    "Epithelial": ["EPCAM", "KRT18", "KRT8"],
    "Endothelial": ["VWF", "PECAM1", "ENG"],
    "Fibroblast": ["DCN", "COL1A1", "COL1A2", "LUM"],
}

# Only genes present in adata.var_names can be plotted. Report what is
# dropped instead of discarding it silently, so an incomplete dotplot is
# traceable to the gene set actually available in `adata`.
missing_markers = {
    lineage: [g for g in genes if g not in adata.var_names]
    for lineage, genes in marker_genes_dict.items()
}
missing_markers = {k: v for k, v in missing_markers.items() if v}
if missing_markers:
    print("WARNING: marker genes absent from adata.var_names (skipped):", missing_markers)

marker_genes_dict = {k: [g for g in v if g in adata.var_names] for k, v in marker_genes_dict.items()}
marker_genes_dict = {k: v for k, v in marker_genes_dict.items() if len(v) > 0}

# Dotplot by Leiden clusters (res=1.2)
if marker_genes_dict:
    sc.pl.dotplot(
        adata,
        marker_genes_dict,
        groupby=cluster_key,
        standard_scale="var",
        show=False,
    )
    plt.savefig(FIG_DIR / f"dotplot_markers_by_{cluster_key}.pdf", bbox_inches="tight")
    plt.close()
else:
    # Nothing to plot: skip rather than fail on an empty gene list.
    print(f"WARNING: no canonical marker genes available; skipped dotplot by {cluster_key}.")

# =========================
# Manual annotation (cluster → broad lineage)
# =========================

# Built-in cluster → lineage mapping, used when CONFIG["cluster_to_celltype"]
# is empty. NOTE(review): cluster IDs are only meaningful for the clustering
# they were assigned on — re-check this table whenever the clustering changes.
merge_dict = {
    '0':  'Fibro',
    '1':  'T&NK',
    '2':  'T&NK',
    '3':  'T&NK',
    '4':  'Endo',
    '5':  'Endo',
    '6':  'Endo',
    '7':  'Hepato',
    '8':  'Hepato',
    '9':  'Hepato',
    '10': 'Myeloid',
    '11': 'Plasma',
    '12': 'Myeloid',
    '13': 'Myeloid',
    '14': 'Hepato',
    '15': 'Hepato',
    '16': 'Hepato',
    '17': 'Hepato',
    '18': 'Myeloid',
    '19': 'T&NK',
    '20': 'T&NK',
    '21': 'T&NK',
    '22': 'T&NK',
    '23': 'T&NK',
    '24': 'T&NK',
    '25': 'T&NK',
    '26': 'Myeloid',
    '27': 'Hepato',
    '28': 'T&NK',
    '29': 'Myeloid',
    '30': 'Myeloid',
    '31': 'B',
    '32': 'Hepato',
    '33': 'Hepato',
    '34': 'T&NK',
    '35': 'Hepato',
    '36': 'Hepato',
    '37': 'T&NK',
    '38': 'Myeloid',
    '39': 'Hepato',
    '40': 'Hepato',
    '41': 'T&NK',
}

# A mapping supplied through CONFIG takes precedence over the built-in table
# (previously it was overwritten unconditionally). Keys are normalized to
# strings because cluster labels are compared as strings below.
_user_mapping = CONFIG.get("cluster_to_celltype") or {}
if _user_mapping:
    merge_dict = {str(k): str(v) for k, v in _user_mapping.items()}

# Store mapping into CONFIG for provenance
CONFIG["cluster_to_celltype"] = merge_dict

raw_cluster = adata.obs[cluster_key].astype(str)
mapped = raw_cluster.map(merge_dict)

# Make incomplete mappings visible: clusters without an entry end up as
# "Unassigned", which is easy to miss in a figure legend.
unmapped_clusters = sorted(raw_cluster[mapped.isna()].unique(), key=lambda c: (len(c), c))
if unmapped_clusters:
    print(
        f"WARNING: {len(unmapped_clusters)} cluster(s) have no majortype mapping "
        f"and are labelled 'Unassigned': {unmapped_clusters}"
    )

adata.obs["majortype"] = mapped.fillna("Unassigned").astype("category")

# Save a UMAP colored by manual majortype labels
save_umap(["majortype"], "umap_by_majortype.pdf", wspace=0.4)

# Dotplot by manual majortype labels (+ dendrogram for readability)
try:
    sc.tl.dendrogram(adata, groupby="majortype")
    dendro = True
except Exception as exc:
    # The dendrogram is cosmetic; fall back to an unordered dotplot but say why.
    print(f"INFO: dendrogram skipped ({type(exc).__name__}: {exc}); plotting without it.")
    dendro = False

if marker_genes_dict:
    sc.pl.dotplot(
        adata,
        marker_genes_dict,
        groupby="majortype",
        dendrogram=dendro,
        standard_scale="var",
        show=False,
    )
    plt.savefig(FIG_DIR / "dotplot_markers_by_majortype.pdf", bbox_inches="tight")
    plt.close()
else:
    # Nothing to plot: skip rather than fail on an empty gene list.
    print("WARNING: no canonical marker genes available; skipped dotplot by majortype.")

print("Saved figures to:", FIG_DIR)
print("majortype categories:", list(adata.obs["majortype"].cat.categories))


In [None]:
# =========================
# exports
# =========================

# Integrated object first: it is the expensive artifact, so it is on disk
# before the small metadata files are attempted.
out_h5ad = OUT_DIR / "HCC_integrated_harmony.h5ad"
adata.write_h5ad(out_h5ad, compression="gzip")
print("Saved:", out_h5ad)

# Run metadata listed in the Overview. CONFIG at this point holds the resolved
# sample key and the cluster → cell-type mapping that was applied.
config_json = OUT_DIR / "config.json"
config_json.write_text(json.dumps(CONFIG, indent=2), encoding="utf-8")
print("Saved:", config_json)

# Filtering summary: thresholds plus the state of the final object. The
# per-step before/after counts are only printed by the earlier cells, not
# stored, so they cannot be reported from here.
_threshold_keys = (
    "min_genes_per_cell",
    "max_genes_per_cell",
    "max_pct_mito",
    "min_cells_per_gene",
    "run_scrublet",
    "remove_predicted_doublets",
    "exclude_clusters",
)
filtering_summary = {
    "thresholds": {k: CONFIG[k] for k in _threshold_keys},
    "n_cells_final": int(adata.n_obs),
    "n_genes_final": int(adata.n_vars),
    "cells_per_dataset_final": {
        str(name): int(count)
        for name, count in adata.obs[CONFIG["dataset_key"]].value_counts().items()
    },
}
summary_json = OUT_DIR / "filtering_summary.json"
summary_json.write_text(json.dumps(filtering_summary, indent=2), encoding="utf-8")
print("Saved:", summary_json)