In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import decoupler as dc
import numpy as np
import seaborn as sns
import matplotlib as mpl
from matplotlib.colors import Normalize, to_hex
from matplotlib.patches import Patch


# read the data

In [None]:
# Load the harmonised mouse and human atlases (mouse first, then human).
mouse_all, human_all = (
    ad.read_zarr(path)
    for path in (
        '/mnt/ssd/atlases/Mouse_Atlas_Harmonised.zarr',
        '/mnt/ssd/atlases/Human_Atlas_Harmonised.zarr',
    )
)

In [None]:
# Keep only malignant cells (any Level_4 label containing "Malignant").
# na=False treats missing labels as non-malignant; otherwise the mask would contain
# NaN and could not be used to index the AnnData.
mouse_mal = mouse_all[mouse_all.obs.Level_4.str.contains("Malignant", na=False)].copy()
human_mal = human_all[human_all.obs.Level_4.str.contains("Malignant", na=False)].copy()

In [None]:
# Hallmark gene sets (human symbols) from decoupler. The next cell reads its
# `source` (gene-set name) and `target` (gene) columns.
hallmark_h = dc.op.hallmark(organism='human')

In [None]:
# Map each hallmark gene set to its unique target genes, keeping first-seen order.
# This is a single grouped pass instead of re-filtering the whole table once per gene set.
hm_dict = {
    source: targets.unique().tolist()
    for source, targets in hallmark_h.groupby('source', sort=False, observed=True)['target']
}

In [None]:
# Upper-case the mouse gene symbols into var["h"] so they can be matched to human symbols.
for adata_obj in (mouse_mal, mouse_all):
    adata_obj.var["h"] = [gene.upper() for gene in adata_obj.var_names]

In [None]:
# Re-index mouse genes by the upper-cased symbols.
# NOTE(review): distinct mouse symbols that differ only by case would collide here.
for adata_obj in (mouse_mal, mouse_all):
    adata_obj.var_names = adata_obj.var["h"]


In [None]:
# Restrict all four objects to genes measured in both species.
# NOTE(review): `intersect` was used here without being defined anywhere in the visible
# notebook. It is assumed to be the shared mouse/human gene set — confirm.
intersect = mouse_all.var_names.intersection(human_all.var_names)
mouse_mal = mouse_mal[:, intersect].copy()
mouse_all = mouse_all[:, intersect].copy()
human_mal = human_mal[:, intersect].copy()
# Previously sliced from human_mal, which replaced the full human atlas with malignant cells only.
human_all = human_all[:, intersect].copy()

In [None]:
# NOTE(review): this cell read `mouse_mal = mouse.copy()`, but `mouse`/`human` are never
# defined, and the assignment would discard the malignant, gene-restricted subsets built
# above. Later cells refer to `mouse`/`human`, so the direction is reversed here to alias
# those subsets — confirm this was the intent.
mouse = mouse_mal.copy()
human = human_mal.copy()

# pathway overlap

In [None]:
from typing import Dict, List, Literal, Optional

# anndata is required by gene_set_coexpression_from_anndatas: fail early with an
# actionable message if it is missing.
try:
    from anndata import AnnData
except ImportError as exc:
    raise ImportError("This function requires `anndata`. Install via `pip install anndata`.") from exc

# Correlation measures accepted by gene_set_coexpression_from_anndatas.
Correlation = Literal["pearson", "spearman"]

def gene_set_coexpression_from_anndatas(
    adata_true: AnnData,
    adata_recon: AnnData,
    geneset_dict: Dict[str, List[str]],
    *,
    overlap_threshold: int = 5,
    min_cells: int = 20,
    threshold: float = 0.0,
    correlation: Correlation = "pearson",
) -> Dict[str, float]:
    """
    Compute gene-set coexpression similarity between 'true' and 'reconstructed' AnnData objects.

    For each gene set, a gene-by-gene correlation matrix (cells as observations) is
    computed separately in both objects, and the two matrices are compared entry-wise.

    Parameters
    ----------
    adata_true, adata_recon : AnnData
        AnnData with X shaped (n_cells, n_genes). Gene names must be in .var_names.
        Only ``.X`` and ``.var_names`` are read.
    geneset_dict : dict[str, list[str]]
        Mapping of gene-set name -> list of gene symbols (matching .var_names).
    overlap_threshold : int
        Minimum overlapping (and expressed) genes required to score a gene set.
        At least 2 genes are always required.
    min_cells : int
        A gene counts as "expressed" if it's non-zero in at least this many cells
        (required in BOTH datasets).
    threshold : float
        Pairs with |corr_true| < threshold are ignored (set to NaN) when averaging.
    correlation : {"pearson","spearman"}
        Correlation measure used for coexpression. Spearman uses average ranks for ties.

    Returns
    -------
    dict
        {geneset_name: score, ..., "average": mean_of_all_set_scores}
        where score = mean over the upper triangle (diagonal excluded) of:
            1 - |C_true - C_recon| / 2
        within that gene set.
        - Gene sets with too few usable genes are omitted.
        - NaN scores are skipped when forming "average".
        - If no genes are shared, or none pass the expression filter, only
          {"average": nan} is returned.

    Raises
    ------
    ValueError
        If ``correlation`` is not "pearson" or "spearman" and a gene set is scored.
    """
    import scipy.sparse as sp
    from scipy.stats import rankdata

    # --- helpers --------------------------------------------------------------
    def _to_dense_float(M) -> np.ndarray:
        # Densify sparse input; np.asarray also turns np.matrix / array views into ndarray.
        return M.toarray().astype(float) if sp.issparse(M) else np.asarray(M, dtype=float)

    def _corr_mat(M: np.ndarray, kind: Correlation) -> np.ndarray:
        if M.shape[1] <= 1:
            return np.ones((1, 1), dtype=float)
        if kind == "pearson":
            data = M
        elif kind == "spearman":
            # Column-wise average ranks (tied values share their mean rank) in one
            # vectorised call. The previous per-column loop scanned every cell once per
            # distinct value, i.e. quadratic in the number of cells.
            data = rankdata(M, method="average", axis=0)
        else:
            raise ValueError("correlation must be 'pearson' or 'spearman'")
        # Constant columns give NaN correlations; silence the accompanying warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            return np.corrcoef(data, rowvar=False)

    # --- align genes by name --------------------------------------------------
    genes_true = np.array(adata_true.var_names, dtype=str)
    genes_recon = np.array(adata_recon.var_names, dtype=str)
    shared = np.intersect1d(genes_true, genes_recon, assume_unique=False)

    # Only genes that occur in some gene set can ever be scored. Restricting to them
    # before densifying keeps memory proportional to the gene-set universe rather than
    # the full transcriptome. The per-gene expression filter is unaffected.
    universe = {g for genes in geneset_dict.values() for g in genes if isinstance(g, str)}
    shared = shared[np.isin(shared, list(universe))]

    if shared.size == 0:
        return {"average": np.nan}

    # index maps (the last occurrence wins if var_names contain duplicates)
    idx_true = {g: i for i, g in enumerate(genes_true)}
    idx_recon = {g: i for i, g in enumerate(genes_recon)}
    t_cols = np.array([idx_true[g] for g in shared], dtype=int)
    r_cols = np.array([idx_recon[g] for g in shared], dtype=int)

    X_t = _to_dense_float(adata_true.X[:, t_cols])
    X_r = _to_dense_float(adata_recon.X[:, r_cols])

    # --- filter to expressed genes in BOTH matrices ---------------------------
    expr_t = (X_t != 0).sum(axis=0) >= min_cells
    expr_r = (X_r != 0).sum(axis=0) >= min_cells
    keep = expr_t & expr_r
    if not np.any(keep):
        return {"average": np.nan}

    X_t = X_t[:, keep]
    X_r = X_r[:, keep]
    name_to_kept = {g: i for i, g in enumerate(shared[keep])}

    # --- score each gene set --------------------------------------------------
    results: Dict[str, float] = {}
    min_genes = max(2, overlap_threshold)

    for gset_name, genes in geneset_dict.items():
        idxs = sorted({name_to_kept[g] for g in genes if g in name_to_kept})
        if len(idxs) < min_genes:
            continue

        C_true = _corr_mat(X_t[:, idxs], correlation)
        C_recon = _corr_mat(X_r[:, idxs], correlation)

        # Similarity in [0, 1]: 1 for identical correlations, 0 for r = +1 vs r = -1.
        sim = 1.0 - np.abs(C_true - C_recon) / 2.0
        # Ignore pairs whose 'true' correlation magnitude is below `threshold`.
        sim[np.abs(C_true) < threshold] = np.nan

        # Upper triangle without the diagonal: each gene pair is counted once.
        tri = sim[np.triu_indices(sim.shape[0], k=1)]
        # The explicit all-NaN check avoids numpy's "Mean of empty slice" warning.
        results[gset_name] = float(np.nanmean(tri)) if np.isfinite(tri).any() else np.nan

    scores = np.asarray(list(results.values()), dtype=float)
    results["average"] = float(np.nanmean(scores)) if np.isfinite(scores).any() else np.nan
    return results


In [None]:
# Accumulates {"<cell type>_<model>": score dict} from the scoring loop below.
results = dict()

In [None]:
# Hallmark coexpression similarity of each mouse model vs human, per malignant subtype.
# Iterates the malignant mouse subset directly; the former `mouse` name was never defined.
# A subtype absent in human yields {"average": nan}, via the function's empty-input path.
models = ['orthotopic', 'endogenous']
for ct in mouse_mal.obs["Level_4"].unique():
    input_m = mouse_mal[mouse_mal.obs.Level_4 == ct]
    input_h = human_mal[human_mal.obs.Level_4 == ct]
    for md in models:
        ad_m = input_m[input_m.obs.Model == md]
        # Mouse cells are passed as `adata_true`, human cells as `adata_recon`.
        results[f"{ct}_{md}"] = gene_set_coexpression_from_anndatas(ad_m, input_h, hm_dict)

In [None]:
# Collapse each per-gene-set result dict to its overall "average" score.
results_to_plot = {cell_type: values['average'] for cell_type, values in results.items()}

# Split "<cell type>_<model>" keys into one {cell type: score} dict per mouse model.
orthotopic = {}
endogenous = {}
suffix_targets = {'_orthotopic': orthotopic, '_endogenous': endogenous}

for key, value in results_to_plot.items():
    for suffix, target in suffix_targets.items():
        if key.endswith(suffix):
            # Slice off only the trailing suffix. str.replace would also delete any
            # earlier occurrence inside the cell-type name.
            target[key[:-len(suffix)]] = value
            break

In [None]:
# One row per cell type with both model scores, best orthotopic score first.
df = (
    pd.DataFrame(
        [(cell_type, orthotopic[cell_type], endogenous[cell_type]) for cell_type in orthotopic],
        columns=['cell_type', 'Orthotopic', 'Endogenous'],
    )
    .sort_values('Orthotopic', ascending=False)
)

In [None]:
from textwrap import wrap  # used for the tick labels; it was never imported before

x = np.arange(len(df))
bar_width = 0.35

fig, ax = plt.subplots(figsize=(18, 6))
ax.bar(x - bar_width / 2, df['Orthotopic'], width=bar_width, color='steelblue', label='Orthotopic')
ax.bar(x + bar_width / 2, df['Endogenous'], width=bar_width, color='orange', label='Endogenous')

# Drop the shared "Malignant Cell - " prefix and wrap long names over several lines.
wrapped_labels = ["\n".join(wrap(lbl.replace("Malignant Cell - ", ""), 15)) for lbl in df["cell_type"]]
ax.set_xticks(x)
ax.set_xticklabels(wrapped_labels, rotation=0, fontsize=9)
ax.set_ylabel("Average score", fontsize=11)
ax.set_title("Pathway Averages by Cell Type: Orthotopic vs Endogenous", fontsize=13, pad=10)
# Scores are 1 - |dC|/2, i.e. within [0, 1]. The axis is zoomed to [0.8, 1.0], so any
# bar below 0.8 is clipped.
ax.set_ylim(0.8, 1.0)

ax.grid(False)
ax.legend(frameon=False, bbox_to_anchor=(1.02, 1), loc='upper left')

fig.tight_layout(pad=2)
# bbox_inches="tight" keeps the legend (placed outside the axes) in the saved image.
fig.savefig(
    "/mnt/kkf2/Cell/AG-Saur/KKF2/Daniele/pdac_atlas_figures/figure6/pathway_correlation.png",
    dpi=300,
    bbox_inches="tight",
)
plt.show()


# Cell-type Wasserstein (Sinkhorn) distance

In [None]:
import ot

In [None]:
def sinkhorn_divergence(X, Y, reg=0.05, unbalanced=False, tau=1.0):
    """
    Entropic optimal-transport discrepancy between the rows of X and Y.

    Both point clouds get uniform weights. The ground cost for every term is the
    squared Euclidean distance.

    Parameters
    ----------
    X, Y : array-like of shape (n, d) and (m, d)
        Point clouds in the same feature space.
    reg : float
        Entropic regularisation strength.
    unbalanced : bool
        If True, return the unbalanced entropic OT cost
        (``ot.unbalanced.sinkhorn_unbalanced2``) without self-term debiasing. That
        value is not a divergence.
    tau : float
        Marginal relaxation passed as the 5th positional argument of
        ``sinkhorn_unbalanced2``. Only used when ``unbalanced`` is True.

    Returns
    -------
    Debiased Sinkhorn divergence ``OT(X,Y) - 0.5*OT(X,X) - 0.5*OT(Y,Y)``, or the
    unbalanced OT cost when ``unbalanced`` is True.

    Raises
    ------
    ValueError
        If X or Y has no rows (uniform weights would divide by zero).
    """
    if X.shape[0] == 0 or Y.shape[0] == 0:
        raise ValueError("sinkhorn_divergence needs at least one row in X and in Y")

    a = np.ones(X.shape[0]) / X.shape[0]
    b = np.ones(Y.shape[0]) / Y.shape[0]

    def _cost(P, Q):
        # One cost definition for all three terms. Previously the self terms used
        # ot.dist(...)**2, whose default metric is already squared, giving a 4th-power cost.
        return ot.dist(P, Q, metric='sqeuclidean')

    M = _cost(X, Y)
    if unbalanced:
        return ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg, tau)

    # Debiased Sinkhorn divergence: OT(X,Y) - 0.5*OT(X,X) - 0.5*OT(Y,Y)
    ot_xy = ot.sinkhorn2(a, b, M, reg)
    ot_xx = ot.sinkhorn2(a, a, _cost(X, X), reg)
    ot_yy = ot.sinkhorn2(b, b, _cost(Y, Y), reg)
    return ot_xy - 0.5 * ot_xx - 0.5 * ot_yy


In [None]:
def sinkhorn_w2_gene_space(X, Y, reg=0.01, method="sinkhorn"):
    """
    Debiased Sinkhorn divergence between two cell populations in gene space.

    Rows (cells) get uniform weights and the ground cost is ||x - y||^2. The result is
    therefore an entropic approximation of W2^2, not W2 itself.

    Parameters
    ----------
    X, Y : array of shape (n, g) and (m, g)
        Cell-by-gene matrices with the same gene order; intended to be z-scored.
    reg : float
        Entropic regularisation passed to ``ot.sinkhorn2``.
    method : str
        Solver name passed to ``ot.sinkhorn2``. ``"sinkhorn_log"`` is POT's log-domain
        variant, which may be needed when costs are large relative to ``reg``
        (e.g. un-normalised counts).

    Returns
    -------
    ``OT(X,Y) - 0.5*OT(X,X) - 0.5*OT(Y,Y)``, i.e. half of ``2*OT(X,Y) - OT(X,X) - OT(Y,Y)``.
    """
    a = np.ones(X.shape[0]) / X.shape[0]
    b = np.ones(Y.shape[0]) / Y.shape[0]
    # Squared Euclidean cost, computed directly rather than as euclidean**2.
    M_xy = ot.dist(X, Y, metric='sqeuclidean')
    M_xx = ot.dist(X, X, metric='sqeuclidean')
    M_yy = ot.dist(Y, Y, metric='sqeuclidean')

    ot_xy = ot.sinkhorn2(a, b, M_xy, reg, method=method)
    ot_xx = ot.sinkhorn2(a, a, M_xx, reg, method=method)
    ot_yy = ot.sinkhorn2(b, b, M_yy, reg, method=method)
    return ot_xy - 0.5 * ot_xx - 0.5 * ot_yy

def subsample_equal(X, Y, max_cells=2000, rng=0):
    """
    Draw the same number of rows from X and Y, without replacement.

    The sample size is k = min(max_cells, n_rows(X), n_rows(Y)). Rows of X are drawn
    first, then rows of Y, from a single ``np.random.RandomState(rng)`` stream, so the
    selection is reproducible for a given seed.

    Returns
    -------
    (X_sub, Y_sub), or (None, None) when k < 10.
    """
    state = np.random.RandomState(rng)
    n_x, n_y = X.shape[0], Y.shape[0]
    n_keep = min(max_cells, n_x, n_y)
    if n_keep >= 10:
        rows_x = state.choice(n_x, size=n_keep, replace=False)
        rows_y = state.choice(n_y, size=n_keep, replace=False)
        return X[rows_x], Y[rows_y]
    return None, None

In [None]:
# Malignant subtypes present in the human data; each is compared against mouse below.
cell_types = list(human_mal.obs['Level_4'].unique())


In [None]:
# Debiased Sinkhorn divergence between human and mouse cells of each malignant subtype,
# computed on equal-size random subsamples.
# NOTE(review): the 'counts' layer is passed, although sinkhorn_w2_gene_space documents
# z-scored inputs. With raw counts the squared distances may be large relative to
# reg=0.1 and plain Sinkhorn can underflow — confirm the intended layer/normalisation.
rows = []
for ct in cell_types:
    print(f"processing {ct}")
    X = human_mal[human_mal.obs['Level_4'] == ct].layers['counts']
    Y = mouse_mal[mouse_mal.obs['Level_4'] == ct].layers['counts']

    # Subsample first (row indexing works for dense and sparse matrices alike), then
    # densify only the selected cells rather than the whole subtype.
    Xs, Ys = subsample_equal(X, Y, max_cells=2000, rng=0)
    if Xs is None:
        # Fewer than 10 cells on one side (e.g. subtype missing in mouse): record NaN
        # instead of passing None into the OT solver.
        rows.append((ct, np.nan))
        continue
    Xs = Xs.toarray() if hasattr(Xs, "toarray") else np.asarray(Xs)
    Ys = Ys.toarray() if hasattr(Ys, "toarray") else np.asarray(Ys)

    D_sinkhorn = float(sinkhorn_w2_gene_space(Xs, Ys, reg=0.1))
    rows.append((ct, D_sinkhorn))


In [None]:
# Show the per-subtype divergences as a labelled table (NaN = too few cells to compare).
pd.DataFrame(rows, columns=["cell_type", "sinkhorn_divergence"])

# Cell-type composition analysis

In [None]:
def celltype_composition(adata, group_col='Sample_ID', celltype_col='Level_4', observed=True):
    """
    Per-group cell-type composition as row-normalised fractions.

    Parameters
    ----------
    adata : AnnData-like
        Object whose ``.obs`` DataFrame holds ``group_col`` and ``celltype_col``.
    group_col : str
        Column defining the groups (rows of the result), e.g. samples.
    celltype_col : str
        Column holding the cell-type labels (columns of the result).
    observed : bool
        Passed to ``DataFrame.groupby``. With categorical columns, ``True`` keeps only
        combinations that actually occur. Categories left over after subsetting an
        AnnData would otherwise produce all-zero rows and therefore 0/0 = NaN fractions.

    Returns
    -------
    pandas.DataFrame
        Rows = groups, columns = cell types; each row sums to 1.
    """
    counts = (
        adata.obs
        .groupby([group_col, celltype_col], observed=observed)
        .size()
        .unstack(fill_value=0)
    )
    return counts.div(counts.sum(axis=1), axis=0)

# NOTE(review): this cell referenced undefined `mouse`/`human`. The malignant subsets are
# used here, matching the other cross-species cells. Switch to mouse_all/human_all if
# whole-tissue composition was intended.
mouse_comp = celltype_composition(mouse_mal)
human_comp = celltype_composition(human_mal)
# Union of cell types seen in either species (missing types become 0). It is sorted so
# the column order is reproducible; iterating a plain set of strings is not.
common_cts = sorted(set(mouse_comp.columns) | set(human_comp.columns))
mouse_comp = mouse_comp.reindex(columns=common_cts, fill_value=0)
human_comp = human_comp.reindex(columns=common_cts, fill_value=0)

In [None]:
# Bulk-level AnnData (the filename suggests a pseudobulk of all cells — confirm).
# Its obs supplies Dataset/Dataset_ID, Cluster_Names and the EMT score used below.
bulk_path = "/mnt/storage/Shrey/PDAC_Downstream/ps_adata_all_cells_bulk_filtered.h5ad"
bulk = sc.read(bulk_path)

In [None]:
def make_sample_id(sid, ds):
    """
    Remove a leading dataset prefix from a sample id.

    The prefix ``ds + "_"`` is tried first, then bare ``ds``. Ids starting with neither
    are returned unchanged.
    """
    for prefix in (ds + "_", ds):
        if sid.startswith(prefix):
            return sid[len(prefix):]
    return sid

# Strip the dataset prefix from each bulk sample id, so the ids can later be looked up
# against the single-cell Sample_IDs. The previous code read the undefined `bulk_obs`.
bulk.obs["Sample_ID"] = [
    make_sample_id(sid, ds) for sid, ds in zip(bulk.obs["Dataset"].index.map(bulk.obs["Dataset_ID"]) if False else bulk.obs["Dataset_ID"], bulk.obs["Dataset"])
]
# Composition-cluster name -> numeric Leiden label. Unmapped names become NaN.
# NOTE(review): 'Epi-High' and 'Acinar-Like' both map to '1' — confirm the merge is intended.
mapping = {
    'Epi-High': '1',
    'Acinar-Like': '1',
    'Hypoxia/Senescence-High': '2',
    'EMT-Start': '3',
    'Tip': '4',
    'Mes-High': '5'
}
bulk.obs['leiden_cell_comp'] = bulk.obs['Cluster_Names'].map(mapping)

In [None]:
from scipy.spatial.distance import jensenshannon

# Jensen–Shannon distance between the cell-type composition of every mouse sample (rows)
# and every human sample (columns). Both composition tables were reindexed to the same
# columns, so rows can be compared position-wise.
mouse_profiles = mouse_comp.to_numpy(dtype=float)
human_profiles = human_comp.to_numpy(dtype=float)
dist_mat = pd.DataFrame(
    [[jensenshannon(m_row, h_row) for h_row in human_profiles] for m_row in mouse_profiles],
    index=mouse_comp.index,
    columns=human_comp.index,
)

# Mouse rows: experimental model of each mouse sample.
model_map = dict(zip(mouse_all.obs.Sample_ID, mouse_all.obs.Model))
dist_mat["Model"] = dist_mat.index.map(model_map)

# Human columns: Leiden composition cluster and mesenchymal/ECM score from the bulk data.
leiden_map = dict(zip(bulk.obs.Sample_ID, bulk.obs.leiden_cell_comp))
emt_map = dict(zip(bulk.obs.Sample_ID, bulk.obs.mesenchymal_ecm_markers_score))
leiden_row = pd.Series(dist_mat.columns.map(leiden_map), index=dist_mat.columns, name="leiden")
emt_row = pd.Series(dist_mat.columns.map(emt_map), index=dist_mat.columns, name="EMT")
dist_mat_l = pd.concat([dist_mat, leiden_row.to_frame().T, emt_row.to_frame().T])

# Keep only human samples that have a Leiden label. This also drops the "Model" column,
# whose Leiden entry is NaN.
dist_mat_l = dist_mat_l.loc[:, ~dist_mat_l.loc["leiden"].isna()]
# Restrict dist_mat to the same human samples, plus the Model column.
dist_mat = dist_mat.loc[:, list(dist_mat_l.columns) + ["Model"]]

# The last rows hold the Leiden / EMT annotations.
dist_mat_l.tail()

In [None]:
# Inputs for the clustermap, derived from the distance-matrix cell:
#   mat           - numeric JS distances, mouse samples (rows) x annotated human samples (cols)
#   mouse_meta    - Model label per mouse sample
#   leiden_series - Leiden composition cluster per human sample
#   emt_series    - mesenchymal/ECM marker score per human sample
# NOTE(review): these four names were used but never defined in the visible notebook.
# They are rebuilt here from dist_mat / leiden_map / emt_map — confirm.
mat = dist_mat.drop(columns="Model").astype(float)
# Hierarchical clustering cannot handle NaN distances (e.g. from an all-zero composition).
mat = mat.dropna(axis=0, how="any").dropna(axis=1, how="any")
mouse_meta = dist_mat[["Model"]]
leiden_series = pd.Series(leiden_map)
emt_series = pd.Series(emt_map)

MISSING_COLOR = "#D3D3D3"
HM_VMIN, HM_VMAX = 0.25, 0.75  # shared by the heatmap and its colour bar

model_colors = {"endogenous": "#1f77b4", "orthotopic": "#ff7f0e"}
row_colors = (
    mouse_meta.loc[mat.index, "Model"]
    .astype(str)
    .map(model_colors)
    .fillna(MISSING_COLOR)
)

leiden_vals = leiden_series.reindex(mat.columns)
unique_leiden = sorted(leiden_vals.dropna().astype(str).unique())
leiden_palette = {
    lab: to_hex(col)
    for lab, col in zip(unique_leiden, sns.color_palette("tab10", len(unique_leiden)))
}
col_colors_leiden = pd.Series(
    leiden_vals.astype(str).map(leiden_palette).fillna(MISSING_COLOR).to_numpy(),
    index=mat.columns,
    name="Leiden",
)

emt_vals = pd.to_numeric(emt_series.reindex(mat.columns), errors="coerce")
norm_emt = Normalize(vmin=0.0, vmax=0.15)
cmap_emt = sns.mpl_palette("viridis", as_cmap=True)
col_colors_emt = pd.Series(
    [to_hex(cmap_emt(norm_emt(v))) if np.isfinite(v) else MISSING_COLOR for v in emt_vals],
    index=mat.columns,
    name="EMT",
)

col_colors_df = pd.DataFrame({"Leiden": col_colors_leiden, "EMT": col_colors_emt}, index=mat.columns)

g = sns.clustermap(
    mat,
    row_cluster=True,
    col_cluster=True,
    row_colors=row_colors,
    col_colors=col_colors_df,
    cmap="coolwarm",
    vmin=HM_VMIN, vmax=HM_VMAX,
    figsize=(12, 8),
    cbar_pos=None,
    dendrogram_ratio=(.15, .15),
    colors_ratio=(.03, .03)
)

# Colour bars and legend sit to the right of the heatmap (figure x >= 1), so they no
# longer overlap it. bbox_inches="tight" below keeps them in the saved image.
sm_emt = mpl.cm.ScalarMappable(norm=norm_emt, cmap=cmap_emt)
sm_emt.set_array([])
cax_emt = g.fig.add_axes([1.00, 0.25, 0.015, 0.5])  # [left, bottom, width, height]
cbar_emt = g.fig.colorbar(sm_emt, cax=cax_emt, orientation="vertical")
# The label is derived from the norm (it previously said 0–0.3 while the scale is 0–0.15).
cbar_emt.set_label(f"EMT score ({norm_emt.vmin:g}–{norm_emt.vmax:g})", fontsize=9)

sm_hm = mpl.cm.ScalarMappable(norm=Normalize(vmin=HM_VMIN, vmax=HM_VMAX), cmap="coolwarm")
sm_hm.set_array([])
cax_hm = g.fig.add_axes([1.08, 0.25, 0.015, 0.5])
cbar_hm = g.fig.colorbar(sm_hm, cax=cax_hm, orientation="vertical")
cbar_hm.set_label("Jensen–Shannon distance", fontsize=9)

handles = [Patch(facecolor=c, label=f"Model: {k}") for k, c in model_colors.items()]
handles += [Patch(facecolor=c, label=f"Leiden {k}") for k, c in leiden_palette.items()]
g.fig.legend(handles=handles, loc="center left", bbox_to_anchor=(1.16, 0.5), frameon=False, title="Annotations")

g.ax_heatmap.set_xticklabels([])
g.ax_heatmap.set_yticklabels([])
g.ax_heatmap.tick_params(left=False, bottom=False)
g.ax_heatmap.set_xlabel("")
g.ax_heatmap.set_ylabel("")
g.fig.savefig(
    "/mnt/kkf2/Cell/AG-Saur/KKF2/Daniele/pdac_atlas_figures/figure6/cellcomp_corr.png",
    dpi=300,
    bbox_inches="tight",
)

plt.show()


# DGE overlap

In [None]:
def _sanitize_obsm_varm(a: ad.AnnData):
    """
    Move 1-D entries of ``a.obsm`` / ``a.varm`` into ``a.obs`` / ``a.var``, in place.

    Per-cell vectors found in ``.obsm`` become ``.obs`` columns, and per-gene vectors in
    ``.varm`` become ``.var`` columns, under the same key.

    Only 1-D entries are moved. Entries with 2 or more dimensions are left in place;
    previously any non-2-D array was ravelled, which for >2-D arrays produced a column
    whose length no longer matched the number of cells/genes.

    NOTE(review): an existing ``obs``/``var`` column with the same key is overwritten.
    """
    for attr, annot_attr in (("obsm", "obs"), ("varm", "var")):
        mapping = getattr(a, attr)
        bad_keys = []
        for k, v in list(mapping.items()):
            # Look through pandas objects to their values before checking dimensionality.
            arr = v.values if hasattr(v, "values") else v
            ndim = getattr(arr, "ndim", None)
            if ndim is None:  # e.g. plain lists
                ndim = np.ndim(arr)
            if ndim == 1:
                bad_keys.append(k)
        for k in bad_keys:
            getattr(a, annot_attr)[k] = np.asarray(mapping[k]).ravel()
            del mapping[k]

def _align_obsm_keys(a1: ad.AnnData, a2: ad.AnnData):
    """Delete, in place, every ``.obsm`` entry that is not present in both objects."""
    shared_keys = set(a1.obsm.keys()).intersection(a2.obsm.keys())
    for adata_obj in (a1, a2):
        extra = [key for key in adata_obj.obsm.keys() if key not in shared_keys]
        for key in extra:
            del adata_obj.obsm[key]

In [None]:
# Tidy obsm/varm on both objects and keep only the obsm keys they share, ahead of
# the concatenation below.
for adata_obj in (mouse_mal, human_mal):
    _sanitize_obsm_varm(adata_obj)
_align_obsm_keys(mouse_mal, human_mal)

In [None]:
# Stack human then mouse malignant cells; obs['Organism'] records each cell's source.
all_together = ad.concat(
    [human_mal, mouse_mal],
    keys=['human', 'mouse'],
    label='Organism',
)

In [None]:
# Display a summary of the concatenated human + mouse AnnData.
all_together

In [None]:
# Per-cell model label: mouse cells keep their `Model` annotation, human cells get 'human'.
# Built positionally from the `Organism` key added by ad.concat, so it does not rely on
# obs_names being unique across species. Index alignment fails, or mislabels cells,
# when human and mouse barcodes collide.
is_mouse = (all_together.obs['Organism'] == 'mouse').to_numpy()
assert is_mouse.sum() == mouse_mal.n_obs, "unexpected number of mouse cells after concat"
model_labels = np.full(all_together.n_obs, 'human', dtype=object)
model_labels[is_mouse] = mouse_mal.obs['Model'].astype(str).to_numpy()
all_together.obs['Model'] = pd.Categorical(model_labels)
all_together.obs['Model'].value_counts()

In [None]:
# Differential expression of each mouse model against human malignant cells, on the
# 'log_norm' layer. Both runs write to the same default result key, so each table is
# extracted right after its run. Tables are collected in dge_list, which was previously
# left empty.
dge_list = {}
for model_name in ('endogenous', 'orthotopic'):
    sc.tl.rank_genes_groups(
        all_together, groups=[model_name], reference='human', groupby='Model', layer='log_norm'
    )
    dge_list[model_name] = sc.get.rank_genes_groups_df(all_together, group=model_name)
endog_df = dge_list['endogenous']
ortho_df = dge_list['orthotopic']

In [None]:
# Inner join of both DE tables on gene name; overlapping columns get _endog / _ortho suffixes.
merged = pd.merge(endog_df, ortho_df, on="names", suffixes=("_endog", "_ortho"))

In [None]:
from scipy.stats import pearsonr, spearmanr

# Agreement between the endogenous-vs-human and orthotopic-vs-human DE scores per gene.
pearson_corr, _ = pearsonr(merged["scores_endog"], merged["scores_ortho"])
# Previously scores_endog was correlated with itself, which always gives 1.
spearman_corr, _ = spearmanr(merged["scores_endog"], merged["scores_ortho"])
print(f"Pearson r = {pearson_corr:.2f}, Spearman ρ = {spearman_corr:.2f}")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# One colour scale for the logFC hue. It is passed both to the scatter (hue_norm) and
# to the colour bar, so the bar really describes the point colours. Previously the
# scatter auto-scaled its hue while the bar showed -10..10.
norm = plt.Normalize(vmin=-10, vmax=10)

fig, ax = plt.subplots(figsize=(6, 6))

# --- Scatter plot ---
sns.scatterplot(
    data=merged,
    x="scores_endog",
    y="scores_ortho",
    hue="logfoldchanges_endog",
    hue_norm=norm,
    s=20,
    palette="coolwarm",
    edgecolor="black",
    linewidth=0.1,
    alpha=0.9,
    legend=False,
    ax=ax,
)

# --- Linear fit drawn on top of the points ---
sns.regplot(
    data=merged,
    x="scores_endog",
    y="scores_ortho",
    scatter=False,
    color="black",
    line_kws={"lw": .5, "ls": "--", "alpha": 0.5, "zorder": 10},
    ax=ax,
)

# Reference lines at zero
ax.axhline(0, color='grey', lw=0.2, ls='--', zorder=1)
ax.axvline(0, color='grey', lw=0.2, ls='--', zorder=1)

# Titles and labels
ax.set_title(f"Correlation of DE genes\nPearson r = {pearson_corr:.2f}", fontsize=12)
ax.set_xlabel("Scores (endogenous vs human)", fontsize=11)
ax.set_ylabel("Scores (orthotopic vs human)", fontsize=11)

# --- Colorbar only (the hue legend is disabled above) ---
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label("logFC", rotation=270, labelpad=15)

sns.despine(ax=ax)
fig.tight_layout()

fig.savefig(
    "/mnt/kkf2/Cell/AG-Saur/KKF2/Daniele/pdac_atlas_figures/figure6/deg_scores_corr.png",
    dpi=300,
    bbox_inches="tight"
)

plt.show()


### Ortho vs Endo

In [None]:
# Rank genes between mouse models (grouped by obs "Model") on the 'log_norm' layer.
# NOTE(review): `mouse_tme` is not defined in the visible notebook. It must be created
# beforehand (presumably the non-malignant mouse cells) — confirm.
sc.tl.rank_genes_groups(
    mouse_tme,
    "Model",
    layer="log_norm",
)

In [None]:
# Plot the top-40 DE genes per mouse model from the rank_genes_groups run above. That
# run was on `mouse_tme`; the cell previously plotted `mouse_all`, on which no
# rank_genes_groups call appears in this notebook.
sc.pl.rank_genes_groups_matrixplot(
    mouse_tme,
    values_to_plot="logfoldchanges",
    cmap="RdBu_r",
    min_logfoldchange=2,
    n_genes=40,
)

In [None]:
# MSigDB resource from decoupler. The commented-out alternative is the hallmark-only
# call, so this presumably spans more collections — confirm.
# NOTE(review): this rebinds `hallmark_h`, which earlier held the hallmark table used
# to build `hm_dict`. Re-running those earlier cells after this one would silently use
# MSigDB instead; consider a distinct name.
hallmark_h = dc.op.resource('MSigDB') # hallmark(organism='human')

In [None]:
# Rename the MSigDB columns to decoupler's source/target network schema and drop
# repeated (gene set, gene) pairs.
net = (
    hallmark_h
    .rename(columns={'genesymbol': 'target', 'geneset': 'source'})
    .drop_duplicates(subset=['source', 'target'])
)


In [None]:
# Score `mouse_tme` against every gene set in `net` with decoupler's ULM method. The next
# cell reads the result back under the obsm key "score_ulm" (presumably written here).
dc.mt.ulm(data=mouse_tme, net=net)

In [None]:
# Extract the ULM scores into their own AnnData for plotting. They were computed on
# `mouse_tme` in the previous cell; `adata` was never defined.
score = dc.pp.get_obsm(adata=mouse_tme, key="score_ulm")

In [None]:
# Mean ULM score per mouse model. standard_scale="var" min-max scales every gene set to
# [0, 1] across groups (it does not z-score), so the colour bar is labelled accordingly.
sc.pl.matrixplot(
    adata=score,
    var_names=score.var_names,
    groupby="Model",
    dendrogram=True,
    standard_scale="var",
    colorbar_title="Min-max scaled scores",
    cmap="RdBu_r",
)