In [1]:
import logging
from pathlib import Path
import anndata as ad
import numpy as np
import pandas as pd
import pandas_plink as pp
import scanpy as sc
from sklearn.preprocessing import StandardScaler
from omegaconf import OmegaConf
from tqdm.auto import trange
from time import time
from anndata.utils import asarray
from sklearn.preprocessing import quantile_transform as sk_quantile_transform

import numpy as np
import scipy.linalg as la
import scipy.stats as st


log = logging.getLogger(__name__)

DATA = Path("/home/lollo/Work/hackathon/data/Yazar_OneK1K") 

In [2]:
# DATA
def column_normalize(X):
    assert X.ndim == 2  # noqa: PLR2004
    return (X - X.mean(0)) / (X.std(0) * np.sqrt(X.shape[1]))


def load_gene_annotations():
    log.warning("Loading gene annotations")

    ann_path = DATA / "onek1k_gene_annotations.pk"
    if ann_path.exists():
        return pd.read_pickle(ann_path)
    else:
        gene_anns = sc.queries.biomart_annotations(
            "hsapiens",
            ["ensembl_gene_id", "start_position", "end_position", "chromosome_name", "strand"],
            host="grch37.ensembl.org",
        ).set_index("ensembl_gene_id")
        gene_anns = gene_anns.rename(
            columns={
                "start_position": "egene_start",
                "end_position": "egene_end",
                "chromosome_name": "chrom",
                "strand": "egene_strand",
            }
        )
        gene_anns["tss"] = np.where(
            gene_anns["egene_strand"] == 1, gene_anns["egene_start"], gene_anns["egene_end"]
        )
        gene_anns.to_pickle(ann_path)
    return gene_anns


def process_sdata(cfg, sdata):
    log.warning("Filtering single cell AnnData")

    # filter cell type
    sdata = sdata[sdata.obs.cell_label.isin(cfg.celltype_filter)].copy()

    # require min cells per individual
    cells_per_ind = sdata.obs.groupby("individual", observed=True).size()
    keep_indis = cells_per_ind[cells_per_ind >= cfg.min_cells_per_ind].index
    sdata = sdata[sdata.obs.individual.isin(keep_indis)]

    # filter non-standard genes
    gene_anns = load_gene_annotations()
    keep_genes = gene_anns.loc[sdata.var.index, "chrom"].isin(map(str, range(1, 23)))
    sdata = sdata[:, keep_genes]

    # collect all info in var
    sdata.var = sdata.var.merge(
        gene_anns[gene_anns.columns.difference(sdata.var.columns)],
        how="left",
        left_index=True,
        right_index=True,
    )
    sdata.var = sdata.var.rename(columns={"GeneSymbol": "egene_symbol"})
    sdata.var.index.name = "egene_id"

    # drop unnecessary HVG columns
    columns_to_drop = ["features", "highly_variable", "means", "dispersions", "dispersions_norm"]
    sdata.var = sdata.var.drop(columns=columns_to_drop, errors="ignore")

    # make sure genes are ordered by chrom
    sdata = sdata[:, sdata.var.chrom.sort_values().index]
    
    # lorenzo begin: computing pca
    if "X_pca" not in sdata.obsm.keys():
        sc.pp.pca(sdata)
    # lorenzo end
    sdata.obsm["X_pca"] = column_normalize(sdata.obsm["X_pca"][:, : cfg.n_cellstate_pcs])

    return sdata


def compute_mdatas(cfg, sdata):
    log.warning("Creating pseudobulk mdatas")

    obs = sdata.obs.groupby("individual", observed=True)[["sex", "age"]].first()
    obs.index = obs.index.astype(str)
    obs["age_std"] = StandardScaler().fit_transform(obs.age.values.reshape(-1, 1))
    obs[["sex1", "sex2"]] = np.eye(2)[(obs.sex.values - 1)]

    mdatas = {}
    keep_genes = np.full(sdata.shape[1], True)
    for ct in ["all"] + list(sdata.obs.cell_label.unique()):
        _sdata = sdata if ct == "all" else sdata[sdata.obs.cell_label == ct]

        mdata = sc.get.aggregate(_sdata, by="individual", func="mean")
        mdata.X = mdata.layers["mean"]
        keep_genes &= (mdata.X > 0).mean(axis=0) > cfg.min_frac_expressed
        mdata.obs = obs.loc[mdata.obs.index]

        sc.pp.highly_variable_genes(mdata, n_top_genes=5000)
        sc.tl.pca(mdata, n_comps=cfg.n_expr_batch_pcs)
        mdata.obsm["X_pca"] = column_normalize(mdata.obsm["X_pca"])

        mdatas[ct] = mdata
    for ct in list(mdatas.keys()):
        mdatas[ct] = mdatas[ct][:, keep_genes]
    return mdatas


def load_F_and_gene_cis_snps(cfg, egene_ann, mdata, gdata=None):
    if gdata is None or gdata.uns["chrom"] != egene_ann.chrom:
        # UNCOMMENT ME
        # cfg["bed_file_path"] = cfg["bed_file_path_template"].format(chrom=egene_ann.chrom)
        # bedfile = str(DATA / cfg.bed_file_path)
        # pcfile = str(DATA / cfg.genetic_pc_path)
        # gdata = read_plink(bedfile, pcfile, num_pcs=cfg.n_genetic_pcs)[mdata.obs.index]
        # gdata.uns["chrom"] = egene_ann.chrom
        # lorenzo begin: 
        bedfile = Path("/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/plink/")
        bedfile = str(bedfile / cfg["bed_file_path_template"].format(chrom=egene_ann.chrom))
        pcfile = Path("/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/prunedir")
        pcfile = Path("/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/filter_vcf_r08")
        pcfile = str(pcfile / "chr{chrom:}.dose.filtered.R2_0.8.filtered.pruned".format(chrom=egene_ann.chrom))
        print(f"{pcfile=}, {bedfile=}")
        gdata = read_plink(bedfile)[mdata.obs.index]
        gdata.uns["chrom"] = egene_ann.chrom
        # lorenzo end

    F = np.concatenate(
        [
            mdata.obs[["sex1", "sex2", "age_std"]],
            gdata.obs,
            mdata.obsm["X_pca"][:, : cfg.n_expr_batch_pcs],
        ],
        axis=1,
    )
    F[:, 2:] = column_normalize(F[:, 2:])

    cis_start = max(egene_ann.egene_start - cfg.cis_margin, 0)
    cis_end = egene_ann.egene_end + cfg.cis_margin
    keep_snps = (gdata.var.pos > cis_start) & (gdata.var.pos < cis_end)

    if keep_snps.sum() == 0:
        log.warning("Skipping %s: no cis SNPs", egene_ann.name)
        return gdata, None, None

    cis_gdata = gdata[:, keep_snps]
    G = pd.DataFrame(cis_gdata.X.compute(), columns=cis_gdata.var.index, index=cis_gdata.obs.index)
    return gdata, F, G


def read_plink(bfile, pcfile=None, num_pcs=None):
    # load geno
    bim, fam, bed = pp.read_plink(bfile)
    del bim["cm"]
    del bim["i"]
    bim = bim.set_index("snp")
    fam = fam.set_index("iid")[[]]

    if pcfile is not None:
        # read pcs
        usecols = None if num_pcs is None else np.arange(num_pcs + 2)
        dfpc = pd.read_csv(pcfile, sep=" ", header=None, usecols=usecols)
        dfpc.columns = ["fid", "iid"] + [f"PC{i+1}" for i in range(dfpc.shape[1] - 2)]
        del dfpc["fid"]
        dfpc = dfpc.set_index("iid")
        assert (dfpc.index == fam.index).all(), "Individuals in PC files not matching"
        fam = pd.concat([fam, dfpc], axis=1)

    return ad.AnnData(bed.T, obs=fam, var=bim)


In [3]:
# MODELS
class GWAS:
    r"""
    Linear model for univariate association testing
    between `P` phenotypes and `S` inputs (`P`x`S` tests)

    Parameters
    ----------
    Y : (`N`, `P`) ndarray
        outputs
    F : (`N`, `K`) ndarray
        covariates. If not specified, an intercept is assumed.
    """

    def __init__(self, Y, F=None):
        if F is None:
            F = np.ones((Y.shape[0], 1))
        self.Y = Y
        self.F = F
        self.df = Y.shape[0] - F.shape[1]
        self._fit_null()

    def _fit_null(self):
        """Internal functon. Fits the null model"""
        self.FY = np.dot(self.F.T, self.Y)
        self.FF = np.dot(self.F.T, self.F)
        self.YY = np.einsum("ip,ip->p", self.Y, self.Y)
        # calc beta_F0 and s20
        self.A0i = la.inv(self.FF)
        self.beta_F0 = np.dot(self.A0i, self.FY)
        self.s20 = (self.YY - np.einsum("kp,kp->p", self.FY, self.beta_F0)) / self.df

    def process(self, G, verbose=False):
        r"""
        Fit genotypes one-by-one.

        Parameters
        ----------
        G : (`N`, `S`) ndarray
            inputs
        verbose : bool
            verbose flag.
        """
        t0 = time()
        # precompute some stuff
        GY = np.dot(G.T, self.Y)
        GG = np.einsum("ij,ij->j", G, G)
        FG = np.dot(self.F.T, G)

        # Let us denote the inverse of Areml as
        # Ainv = [[A0i + m mt / n, m], [mT, n]]
        A0iFG = np.dot(self.A0i, FG)
        n = 1.0 / (GG - np.einsum("ij,ij->j", FG, A0iFG))
        M = -n * A0iFG
        self.beta_F = (
            self.beta_F0[:, None, :]
            + np.einsum("ks,sp->ksp", M, np.dot(M.T, self.FY)) / n[None, :, None]
        )
        self.beta_F += np.einsum("ks,sp->ksp", M, GY)
        self.beta_g = np.einsum("ks,kp->sp", M, self.FY)
        self.beta_g += n[:, None] * GY

        # sigma
        self.s2 = self.YY - np.einsum("kp,ksp->sp", self.FY, self.beta_F)
        self.s2 -= GY * self.beta_g
        self.s2 /= self.df

        # dlml and pvs
        self.lrt = -self.df * np.log(self.s2 / self.s20)
        self.pv = st.chi2(1).sf(self.lrt)

        t1 = time()
        if verbose:
            print("Tested for %d variants in %.2f s" % (G.shape[1], t1 - t0))

    def getPv(self):
        """
        Get pvalues

        Returns
        -------
        pv : ndarray
        """
        return self.pv

    def getBetaSNP(self):
        """
        get effect size SNPs

        Returns
        -------
        beta : ndarray
        """
        return self.beta_g

    def getLRT(self):
        """
        get lik ratio test statistics

        Returns
        -------
        lrt : ndarray
        """
        return self.lrt

    def getBetaSNPste(self):
        """
        get standard errors on betas

        Returns
        -------
        beta_ste : ndarray
        """
        beta = self.getBetaSNP()
        pv = self.getPv()
        z = np.sign(beta) * np.sqrt(st.chi2(1).isf(pv))
        ste = beta / z
        return ste


def quantile_transform(
    x: np.ndarray,
    seed: int = 1,
) -> np.ndarray:
    """
    Gaussian quantile transform for values in a pandas Series.    :param x: Input pandas Series.
    :type x: pd.Series
    :param seed: Random seed.
    :type seed: int
    :return: Transformed Series.
    :rtype: pd.Series    .. note::
        “nan” values are kept
    """
    np.random.seed(seed)
    x_transform = x.copy()
    if isinstance(x_transform, pd.Series):
        x_transform = x_transform.to_numpy()
    is_nan = np.isnan(x_transform)
    n_quantiles = np.sum(~is_nan)
    x_transform[~is_nan] = sk_quantile_transform(
        x_transform[~is_nan].reshape([-1, 1]),
        n_quantiles=n_quantiles,
        subsample=n_quantiles,
        output_distribution="normal",
        copy=True,
    )[:, 0]
    return x_transform

def run_gwas(*, egene, mdata, F, G, suffix=None):
    # E_pb = gaussianize(asarray(mdata[:, [egene]].X) + 1e-4 * np.random.randn(mdata.shape[0], 1))
    E_pb = quantile_transform(asarray(mdata[:, [egene]].X) + 1e-4 * np.random.randn(mdata.shape[0], 1))
    gwas = GWAS(Y=E_pb, F=F)
    gwas.process(G=G)

    snp_idx = gwas.getPv().argmin()

    def get_top_snp(arr):
        return arr.ravel()[snp_idx].item()

    rdict = {
        "snp": G.columns[snp_idx],
        "pv": get_top_snp(gwas.getPv()),
        "beta": get_top_snp(gwas.getBetaSNP()),
        "beta_ste": get_top_snp(gwas.getBetaSNPste()),
        "stat": get_top_snp(gwas.getLRT()),
    }

    if suffix is not None:
        rdict = {f"{k}_{suffix}": v for k, v in rdict.items()}
    return rdict

In [4]:
# RUNS
def run_eqtl_mapping(cfg, mdatas, sdata):
    log.warning("Running GWAS")
    result_list = []

    offset = int(cfg.offset) * cfg.n_genes
    gdata = None
    for gene_idx in trange(offset, offset + cfg.n_genes, desc="Genes"):
        if gene_idx >= sdata.var.shape[0]:
            break
        egene_ann = sdata.var.iloc[gene_idx]
        egene = egene_ann.name

        gdata, F, G = load_F_and_gene_cis_snps(cfg, egene_ann, mdatas["all"], gdata)

        test_results = (
            {"egene": egene, "n_cis_snps": 0}
            if G is None
            else run_tests_on_gene(cfg, egene, sdata, mdatas, F, G)
        )
        result_list.append({**test_results, "gene_idx": gene_idx})

    rdf = pd.DataFrame(result_list)


# @catch_exceptions_as_dict
def run_tests_on_gene(cfg, egene, sdata, mdatas, F, G):
    res = {"egene": egene, "n_cis_snps": G.shape[1]}
    # run_gwas(egene=egene, mdata=mdatas["all"], F=F, G=G, suffix="all", out=res)
    run_gwas(egene=egene, mdata=mdatas["all"], F=F, G=G, suffix="all")

    if cfg.run_coloc:
        run_coloc(
            egene=egene,
            mdata1=mdatas["CD8 ET"],
            mdata2=mdatas["NK"],
            F=F,
            G=G,
            suffix="betw",
            out=res,
        )
    Gt = G[[res["snp_all"]]]
    for ct, mdata in mdatas.items():
        ppct = ct.replace(" ", "_").lower()

        if ct != "all":
            if cfg.run_locus:
                run_gwas(egene=egene, mdata=mdata, F=F, G=G, suffix=f"{ppct}_locus", out=res)

            if cfg.run_coloc:
                run_coloc(
                    egene=egene, mdata1=mdatas["all"], mdata2=mdata, F=F, G=G, suffix=ppct, out=res
                )

            run_gwas(egene=egene, mdata=mdata, F=F, G=Gt, suffix=ppct, out=res)

            if cfg.run_ivpblrt:
                run_pb_lrt(cfg=cfg, egene=egene, mdata=mdata, F=F, G=Gt, suffix=ppct, out=res)

        if cfg.run_ivgwas:
            run_gwas_iv(egene=egene, mdata=mdata, F=F, G=Gt, suffix=ppct, out=res)

    if cfg.run_mil:
        fit_cis_mixmil(cfg=cfg, egene=egene, sdata=sdata, F=F, G=Gt, out=res)
    return res


In [5]:
# scdata_path = "/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_cohort_gene_expression_matrix_14_celltypes.h5ad.gz"
scdata_path = "/home/lollo/Work/hackathon/data/Yazar_OneK1K/debug_OneK1K_cohort_gene_expression_matrix_14_celltypes.h5ad"

conf_yaml = """
# process_sdata config
celltype_filter: ["CD4 ET", ]
min_cells_per_ind: 10
n_cellstate_pcs: 50

# compute_mdatas config
min_frac_expressed: 0.1

# run_eqtl mapping config
offset: 100
n_genes: 100

# load_F_and_gene_cis_snps config
bed_file_path_template: "chr{chrom:}.dose.filtered.R2_0.8"
# genetic_pc_path: "/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/plink"
genetic_pc_path: "/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/pcdir/wgs.dose.filtered.R2_0.8"
n_genetic_pcs: 20
n_expr_batch_pcs: 15
cis_margin: 1_000_000

# run_test_on_gene config
run_coloc: False
run_locus: False
run_iivpblrt: False
run_iv_gwas: False
run_mil: False
"""

conf = OmegaConf.create(conf_yaml)

In [6]:

log.warning("Starting cis pipeline")
# sdata = sc.read_h5ad(RUNS / cfg.sc_adata_path)
sdata = sc.read_h5ad(scdata_path)

log.warning("Loaded single cell AnnData with shape %s", sdata.shape)

sdata = process_sdata(conf, sdata)

mdatas = compute_mdatas(conf, sdata)

sdata = sdata[:, sdata.var_names.isin(mdatas["all"].var_names)]


Starting cis pipeline
Loaded single cell AnnData with shape (25908, 32738)
Filtering single cell AnnData
Loading gene annotations
  adata.obsm[key_obsm] = X_pca
Creating pseudobulk mdatas
  return x * y
  var = mean_sq - mean**2
  var = mean_sq - mean**2
  return x * y
  var = mean_sq - mean**2
  var = mean_sq - mean**2


In [7]:
# offset = 100
# n_genes = 100
# n_genetic_pcs = 15

# for gene_idx in trange(offset, offset + n_genes, desc="Genes"):
#     if gene_idx >= sdata.var.shape[0]:
#         break
#     egene_ann = sdata.var.iloc[gene_idx]
#     egene = egene_ann.name
#     # lorenzo begin: 
#     bedfile = Path("/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/plink/")
#     bedfile = str(bedfile / "chr{chrom:}.dose.filtered.R2_0.8".format(chrom=egene_ann.chrom))
#     # pcfile = Path("/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/prunedir")
#     # pcfile = Path("/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/filter_vcf_r08")
#     # pcfile = str(pcfile / "chr{chrom:}.dose.filtered.R2_0.8.filtered.pruned".format(chrom=egene_ann.chrom))
#     # pcfile = ("/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/pcdir/wgs.dose.filtered.R2_0.8.filtered.pruned.fam")
#     pcfile = ("/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/pcdir/wgs.dose.filtered.R2_0.8.filtered.pruned")
#     # pcfile = ("/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/pcdir/wgs.dose.filtered.R2_0.8.filtered.pruned.bnosex")
#     # pcfile = ("/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/pcdir/wgs.dose.filtered.R2_0.8.filtered.pruned.bim")
#     print(f"{pcfile=}, {bedfile=}")
#     # lorenzo end
#     # gdata = read_plink(bedfile, pcfile, num_pcs=n_genetic_pcs)[mdata["all"].obs.index]
#     gdata = read_plink(bedfile, num_pcs=n_genetic_pcs)[mdatas["all"].obs.index]

In [8]:

log.warning("Run eqtl mapping on adata\n %s", str(sdata))
run_eqtl_mapping(conf, mdatas, sdata)

logging.warning("Finished cis pipeline")

Run eqtl mapping on adata
 View of AnnData object with n_obs × n_vars = 1332 × 11910
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'pool', 'individual', 'percent.mt', 'latent', 'nCount_SCT', 'nFeature_SCT', 'cell_type', 'cell_label', 'sex', 'age'
    var: 'egene_symbol', 'chrom', 'egene_end', 'egene_start', 'egene_strand', 'tss'
    uns: 'pca'
    obsm: 'X_pca'
    varm: 'PCs'
Running GWAS


Genes:   0%|          | 0/100 [00:00<?, ?it/s]



pcfile='/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/filter_vcf_r08/chr6.dose.filtered.R2_0.8.filtered.pruned', bedfile='/home/lollo/Work/hackathon/data/Yazar_OneK1K/OneK1K_imputation_post_qc_r2_08/plink/chr6.dose.filtered.R2_0.8'


Mapping files: 100%|██████████| 3/3 [01:12<00:00, 24.02s/it]
  gdata.uns["chrom"] = egene_ann.chrom


KeyError: 'snp_all'