In [6]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import pdist, squareform
from pathlib import Path

In [8]:
# Every table and figure produced below is written into ./celltype_associated
# (created on first run; re-running this cell is a no-op).
current_dir = Path.cwd()
outdir = Path(current_dir, "celltype_associated")
outdir.mkdir(exist_ok=True, parents=True)

In [10]:
# --- Input locations -------------------------------------------------------
# Each path gets its own name; earlier versions reused `msi_csv` for two files
# and reused `msi_norm_hvg` / `rna_norm_hvg` for both a path and a DataFrame.
msi_raw_csv = "/storage/scratch1/2/yhao306/sma-brain/data-sm/V11L12-038_A1_MSI_raw_counts_spotRows.csv"
coord_csv = "/storage/scratch1/2/yhao306/sma-brain/data-sm/V11L12-038_A1_spot_coordinates.csv"
msi_csv = "../msi_norm.csv"
msi_norm_hvg_csv = "../msi_norm_hvg.csv"
rna_norm_hvg_csv = "../rna_norm_hvg.csv"
rna_csv = "../rna_norm.csv"
region_csv = "/storage/scratch1/2/yhao306/sma-brain/data-sm/sma/V11L12-038/V11L12-038_A1/output_data/V11L12-038_A1_RNA/outs/RegionLoupe.csv"

# --- Load ------------------------------------------------------------------
msi_df = pd.read_csv(msi_raw_csv, index_col=0)
coords_df = pd.read_csv(coord_csv, index_col=0)
msi_norm = pd.read_csv(msi_csv, index_col=0)
msi_norm_hvg = pd.read_csv(msi_norm_hvg_csv, index_col=0)
rna_norm_hvg = pd.read_csv(rna_norm_hvg_csv, index_col=0)
rna_norm = pd.read_csv(rna_csv, index_col=0)
msi_hvg_null = pd.read_csv("msi_null_hvg.csv", index_col=0)
region_df = pd.read_csv(region_csv, index_col=0)

# Append the "_1" suffix so region barcodes can be looked up by rna_norm's
# barcodes below (.loc raises if any RNA barcode has no region row).
region_df.index = region_df.index + "_1"
print(region_df)
common_barcodes = rna_norm.index
region_df = region_df.loc[common_barcodes]
print(region_df)

# --- Merge region labels, spot coordinates and deconvolution proportions ---
region_df.index.name = "barcode"
coords_df.index.name = "barcode"
coords_small = coords_df[["x", "y"]]
df_merged = region_df.join(coords_small, how="inner")
print(df_merged.head())

deconv = pd.read_csv("../deconvolution.tsv", sep="\t", index_col=0)
df_final = df_merged.join(deconv, how="inner")
# Name the index explicitly. The previous reset_index()/set_index("index")
# round-trip only worked when the join happened to drop the index name, and
# raised KeyError if the TSV's index was also called "barcode".
df_final.index.name = "barcode"

# Every column that is not a region label or a coordinate is a cell-type proportion.
cell_types = [c for c in df_final.columns if c not in ("RegionLoupe", "x", "y")]
assert cell_types, "deconvolution table contributed no cell-type columns"
prop_df = df_final[cell_types]
assert not prop_df.empty, "no barcodes shared by region, coordinate and deconvolution tables"
print(prop_df)

                     RegionLoupe
Barcode                         
AAACAAGTATCTCCCA-1_1          CP
AAACAGCTTTCAGAAG-1_1         CTX
AAACAGGGTCTATATT-1_1         NaN
AAACATTTCCCGGATT-1_1          cc
AAACCCGAACGAAATC-1_1          CP
...                          ...
TTGTTCAGTGTGCTAC-1_1         ACB
TTGTTCTAGATACGCT-1_1         CTX
TTGTTGTGTGTCAAGA-1_1         PAL
TTGTTTCCATACAACT-1_1         CTX
TTGTTTGTGTAAATTC-1_1         NaN

[2856 rows x 1 columns]
                     RegionLoupe
TGGGCACAAACAGAAC-1_1         NaN
TTGGAGTCTCCCTTCT-1_1         NaN
AATATCAAGGTCGGAT-1_1         CTX
CACCCTTTCCTCGCTC-1_1         CTX
CTGGTTCAACGCATCA-1_1         CTX
...                          ...
GCTCAACCTCTTAGAG-1_1         CTX
GGGCGAATTTCTCCAC-1_1         CTX
TGAAGTAGCTTACGGA-1_1         CTX
TACATCTTGTTTCTTG-1_1         CTX
CAGTAGATGATGTCCG-1_1         CTX

[2386 rows x 1 columns]
                     RegionLoupe   x  y
barcode                                
TGGGCACAAACAGAAC-1_1         NaN  46  8
TTGGA

In [4]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform

def _zscore_cols(A):
    mu = A.mean(axis=0, keepdims=True)
    sd = A.std(axis=0, ddof=0, keepdims=True)
    sd = np.where(sd == 0, 1.0, sd)
    return (A - mu) / sd

def _build_W(xy, l="dmin", norm="row"):
    D = squareform(pdist(xy, metric="euclidean"))
    d_nonzero = D[D > 0]
    if isinstance(l, str) and l == "dmin":
        if d_nonzero.size == 0:
            raise ValueError("there is no dmin.")
        lval = float(np.min(d_nonzero))
    else:
        lval = float(l)
    W = np.exp(-(D**2) / (2.0 * lval * lval))
    np.fill_diagonal(W, 1.0)
    if norm == "row":
        W = W / W.sum(axis=1, keepdims=True)
    elif norm == "global":
        W = W / W.sum()
    return W

def compute_sci(Y_df, prop_df, coords_df, normalized=True, l="dmin", w_norm="row"):
    """Compute the spatial cross-correlation index between proportion columns and Y columns.

    Barcodes are restricted to those present in all three frames. Both matrices
    are column-z-scored and the raw index is Cz^T (W Yz), where W comes from
    ``_build_W`` on the "x"/"y" columns of ``coords_df``.

    With ``normalized=True`` each (cell type, feature) entry is divided by
    sqrt(sum Cz^2 * sum Yz^2); a zero denominator is replaced by 1.

    Returns
    -------
    pandas.DataFrame
        Index = prop_df columns, columns = Y_df columns.

    Raises
    ------
    ValueError
        If the three frames share no barcodes.
    """
    shared = prop_df.index.intersection(Y_df.index).intersection(coords_df.index)
    if shared.empty:
        raise ValueError("prop_df / Y_df / coords_df do not have common barcodes。")

    props = prop_df.loc[shared].astype(float)
    feats = Y_df.loc[shared].astype(float)
    spot_xy = coords_df.loc[shared, ["x", "y"]].to_numpy()

    weights = _build_W(spot_xy, l=l, norm=w_norm)
    props_z = _zscore_cols(props.to_numpy())
    feats_z = _zscore_cols(feats.to_numpy())

    # Spatially lag Y first, then correlate with the cell-type proportions.
    cross = props_z.T @ (weights @ feats_z)

    if normalized:
        ss_props = np.sum(props_z**2, axis=0)
        ss_feats = np.sum(feats_z**2, axis=0)
        scale = np.sqrt(ss_props[:, None] * ss_feats[None, :])
        scale = np.where(scale == 0, 1.0, scale)
        cross = cross / scale

    return pd.DataFrame(cross, index=props.columns, columns=feats.columns)

def compute_sci_pair(msi_norm_hvg, msi_hvg_null, prop_df, coords_df,
                     normalized=True, l="dmin", w_norm="row"):
    """Run ``compute_sci`` on the observed and the null metabolite matrices.

    Both matrices are first restricted to the metabolite columns they share, and
    both runs use identical proportions, coordinates and weighting options.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        (SCI for ``msi_norm_hvg``, SCI for ``msi_hvg_null``).
    """
    shared_mets = msi_norm_hvg.columns.intersection(msi_hvg_null.columns)
    sci_opts = dict(normalized=normalized, l=l, w_norm=w_norm)
    results = [
        compute_sci(frame[shared_mets], prop_df, coords_df, **sci_opts)
        for frame in (msi_norm_hvg, msi_hvg_null)
    ]
    return tuple(results)


In [12]:
# Observed vs. null SCI on the HVG metabolites (Gaussian kernel, dmin bandwidth,
# row-normalised weights); both tables are saved for the significance test below.
SCI_real, SCI_null = compute_sci_pair(
    msi_norm_hvg, msi_hvg_null, prop_df, coords_df,
    normalized=True, l="dmin", w_norm="row",
)
SCI_real.to_csv(outdir / "SCI_celltype_vs_metabolite_real.csv")
SCI_null.to_csv(outdir / "SCI_celltype_vs_metabolite_null.csv")

In [15]:
import numpy as np
import pandas as pd

def _bh_fdr_rowwise(p_mat: np.ndarray) -> np.ndarray:
    p = np.asarray(p_mat, dtype=float)
    K, M = p.shape
    q = np.empty_like(p)
    for i in range(K):
        pi = p[i].copy()
        order = np.argsort(pi, kind="mergesort")
        ranks = np.arange(1, M + 1, dtype=float)
        qi = pi[order] * M / ranks
        qi = np.minimum.accumulate(qi[::-1])[::-1]
        out = np.empty_like(qi)
        out[order] = qi
        q[i] = np.clip(out, 0, 1)
    return q

def significant_by_ct_single_null(
    SCI_real: pd.DataFrame,
    SCI_null: pd.DataFrame,
    alpha: float = 0.05,
    two_sided: bool = True,
    save_prefix: str = "SCI_ct_metab_test"
):

    rows = SCI_real.index.intersection(SCI_null.index)
    cols = SCI_real.columns.intersection(SCI_null.columns)
    R = SCI_real.loc[rows, cols].astype(float)
    N = SCI_null.loc[rows, cols].astype(float)

    K, M = R.shape
    real_vals = R.to_numpy()   # K×M
    null_vals = N.to_numpy()   # K×M

    if two_sided:
        real_use = np.abs(real_vals)
        null_use = np.abs(null_vals)
    else:
        real_use = real_vals
        null_use = null_vals

    p_emp = np.empty_like(real_vals, dtype=float)
    for i in range(K):
        base = null_use[i]
        for j in range(M):
            p_emp[i, j] = (np.sum(base >= real_use[i, j]) + 1.0) / (M + 1.0)

    q_bh = _bh_fdr_rowwise(p_emp)


    long = []
    for i, ct in enumerate(rows):
        for j, met in enumerate(cols):
            long.append({
                "cell_type":  ct,
                "metabolite": met,
                "sci":        float(real_vals[i, j]),
                "p_emp":      float(p_emp[i, j]),
                "q_bh":       float(q_bh[i, j]),
            })
    long_df = pd.DataFrame(long)
    long_df["significant"] = long_df["q_bh"] < alpha

    per_ct = {}
    sig_rows = []
    for ct, sub in long_df.groupby("cell_type"):
        sig = sub[sub["significant"]].sort_values(["q_bh", "p_emp", "sci"])
        per_ct[ct] = sig
        if not sig.empty:
            sig_rows.append(sig)

    long_path = f"{save_prefix}_long.csv"
    long_df.to_csv(outdir/long_path, index=False)

    return long_df, per_ct

In [None]:
# Empirical test of every cell type x metabolite SCI against the single null run,
# then a look at the strongest oligodendrocyte associations.
long_df, per_ct = significant_by_ct_single_null(
    SCI_real,
    SCI_null,
    save_prefix="SCI_ct_metab_test",
    two_sided=True,
    alpha=0.05,
)
per_ct["Oligodendrocytes"].head(10)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Reuse the in-memory result when the test cell has run; otherwise reload the table
# significant_by_ct_single_null wrote (outdir / "<save_prefix>_long.csv"). The old
# fallback pointed at "SCI_ct_metab_singleNull_long.csv" in cwd, which is never written.
try:
    long_df
except NameError:
    long_df = pd.read_csv(outdir / "SCI_ct_metab_test_long.csv")

# Count significant metabolites per cell type, keeping cell types with zero hits.
all_cts = long_df["cell_type"].unique()
sig_counts = (
    long_df.assign(significant=long_df["significant"].astype(bool))
    .groupby("cell_type")["significant"]
    .sum()
    .reindex(all_cts, fill_value=0)
    .sort_values(ascending=False)
    .astype(int)
)

# Written next to every other artifact of this notebook.
sig_counts.to_csv(outdir / "sig_metabolite_count_per_ct.csv", header=["n_significant"])

fig, ax = plt.subplots(figsize=(12, 4.8))
positions = list(range(len(sig_counts)))
ax.bar(positions, sig_counts.values)
ax.set_xticks(positions)
ax.set_xticklabels(sig_counts.index, rotation=60, ha="right")
ax.set_ylabel("Number of significant metabolites")
ax.set_title("Significant metabolites per cell type (BH-FDR < 0.05)")
fig.tight_layout()
fig.savefig(outdir / "sig_metabolite_count_per_ct.png", dpi=200)
plt.show()