# Scalable pipeline for computing LD matrices for large-sample phenotypes

## Aim

To extract the summary statistics and genotypes for specific genomic regions and calculate their LD matrices.

## Pre-requisites

### Two ways to use this pipeline on csglogin

`export PATH=/home/yh3455/miniconda3/bin:$PATH`

### Or install the following packages in your environment

Make sure you install the prerequisites before running this notebook:

```
pip install scipy
pip install torch
pip install dask
pip install liftover
pip install pandas-plink
pip install bgen-reader
pip install xxhash
pip install pyyaml
```

### Input

- `--region-file`, a file listing the regions to extract
    - Each locus is represented by one line in the region file with 3 columns: chr, start, and end, e.g. `7 27723990 28723990`
- `--geno-path`, the path of a genotype inventory, which lists the paths of all genotype files in `bgen` or `plink` format.
    - The inventory is a file with 2 columns: `chr genotype_file_chr.ext`.
    - The first column is the chromosome ID; the second column is the genotype file for that chromosome.
    - A chromosome ID of 0 means that the genotype file contains the genotypes for all chromosomes.
- `--pheno-path`, the path of a phenotype file. Only used when a single genotype dataset is provided. If `None`, only `pld` will be calculated.
    - The phenotype file should have a column named `IID`, which holds the sample ID.
- `--sumstats-path`, the path of the GWAS file, including all summary statistics (e.g., $\hat{\beta}$, $SE(\hat{\beta})$ and p-values)
    - The summary statistics should contain at least these columns: `chrom, pos, ref, alt, snp_id, bhat, sbhat, p`
- `--unrelated-samples`, the path of a file of unrelated samples with a column named `IID`. If `None`, all samples will be considered unrelated.
- `--cwd`, the path of the output directory


- `--imp-geno-path`, the path of a genotype inventory for the imputed data, which lists the paths of all genotype files in `bgen` or `plink` format.
    - The inventory is a file with 2 columns: `chr genotype_file_chr.ext`.
    - The first column is the chromosome ID; the second column is the genotype file for that chromosome.
    - A chromosome ID of 0 means that the genotype file contains the genotypes for all chromosomes.
- `--imp-sumstats-path`, the path of the GWAS file for the imputed data, including all summary statistics (e.g., $\hat{\beta}$, $SE(\hat{\beta})$ and p-values)
    - The summary statistics should contain at least these columns: `chrom, pos, ref, alt, snp_id, bhat, sbhat, p`
- `--imp-ref`, the reference genome of the imputed genotypes, needed when it differs from that of the exome genotypes. If `None`, the two genotype datasets are assumed to use the same reference genome.
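
As a minimal sketch of how a genotype inventory like the one described above can be consulted, the snippet below parses a two-column inventory and picks the file for a region's chromosome, falling back to the `0` entry when the chromosome is not listed. The helper names (`load_inventory`, `genotype_file_for`) and the file names are hypothetical; only the lookup rule is taken from the workflow's extraction step.

```python
import io

def load_inventory(handle):
    """Parse a two-column genotype inventory: `chr genotype_file`."""
    inventory = {}
    for line in handle:
        fields = line.split()
        if not fields:
            continue  # skip blank lines
        chrom, geno_file = fields[0], fields[1]
        inventory[chrom] = geno_file
    return inventory

def genotype_file_for(inventory, chrom):
    """Return the file for `chrom`, or the all-chromosome entry `0` if it is not listed."""
    chrom = str(chrom)
    if chrom.startswith("chr"):
        chrom = chrom[3:]  # accept both `chr5` and `5`
    if chrom in inventory:
        return inventory[chrom]
    return inventory["0"]

# Hypothetical inventory contents, for illustration only
example = io.StringIO("5 exome_chr5.bed\n0 all_chromosomes.bgen\n")
inventory = load_inventory(example)
print(genotype_file_for(inventory, "chr5"))
print(genotype_file_for(inventory, 7))
```

If neither the chromosome nor `0` is present, the lookup raises a `KeyError`, which is usually preferable to silently reading the wrong file.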

### Output
- `rg_stat`, the regional summary statistics
    - The row name is the variant ID.
    - It should contain at least the following columns: `CHR, BP, SNP, ALT, REF, BETA, SE, Z, P`.
- `rg_geno`, the regional genotypes
    - The row name is the variant ID, which should match the row names of `rg_stat`.
    - The column name is the sample's IID, sorted in the same order as the samples in the phenotype file.
- `pld`, the regional approximate population LD calculated from unrelated individuals
- `sld`, the regional approximate sample LD calculated from unrelated individuals in a phenotype.
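
The `Z` column is derived from the p-value and the sign of the effect estimate (the `p2z` helper in the workflow code): $z = \mathrm{sign}(\hat{\beta})\,|\Phi^{-1}(p/2)|$ for a two-sided p-value. Below is a minimal single-variant sketch of that conversion. It uses the standard library's `statistics.NormalDist` in place of `scipy.stats.norm`, a substitution made here only to keep the example dependency-free, and `p_to_z` is a hypothetical name.

```python
from statistics import NormalDist

def p_to_z(pval, beta, two_sided=True):
    """Convert a p-value to a z-score carrying the sign of `beta`."""
    if two_sided:
        pval = pval / 2  # probability mass in one tail
    # inv_cdf requires 0 < pval < 1 and raises StatisticsError otherwise
    z = abs(NormalDist().inv_cdf(pval))
    return -z if beta < 0 else z

z_pos = p_to_z(0.05, 0.3)   # positive effect
z_neg = p_to_z(0.05, -0.3)  # negative effect, same p-value
print(z_pos, z_neg)
```

P-values that underflow to exactly 0 cannot be converted this way; when the standard error is available, `BETA / SE` is an alternative that avoids the problem.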

## Workflow code

[global]
# Work directory where output will be saved to
parameter: cwd = path
# Region specifications
parameter: region_file = path
# Genotype file inventory
parameter: geno_path = path
# Phenotype path
parameter: pheno_path = path
# Sample file path, for bgen format
parameter: bgen_sample_path = path('.')
# Path to summary stats file
parameter: sumstats_path = path
# Path to summary stats format configuration
parameter: format_config_path = path('.')
# Path to samples of unrelated individuals
parameter: unrelated_samples = path
# Number of tasks to run in each job on cluster
parameter: job_size = int
# Imputed genotype file inventory
parameter: imp_geno_path = path
# Path to summary stats file for the imputed genotypes
parameter: imp_sumstats_path = path
# The reference genome of imputed genotype data
parameter: imp_ref = str

fail_if(not region_file.is_file(), msg = 'Cannot find regions to extract. Please specify them using ``--region-file`` option.')
# Load all regions of interest. Each item in the list will be a region: (chr, start, end)
regions = list(set([tuple(x.strip().split()) for x in open(region_file).readlines() if x.strip()]))
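
The region loading above keeps every field as a string and accepts any number of columns. The sketch below shows a slightly stricter, hypothetical variant (`parse_regions` is not part of the workflow) that converts coordinates to integers and rejects malformed lines, which may help catch problems in a region file before jobs are submitted.

```python
def parse_regions(lines):
    """Return the unique (chr, start, end) regions found in `lines`, sorted."""
    regions = set()
    for lineno, line in enumerate(lines, start=1):
        fields = line.split()
        if not fields:
            continue  # skip blank lines
        if len(fields) != 3:
            raise ValueError(f"line {lineno}: expected 3 columns, got {len(fields)}")
        chrom, start, end = fields[0], int(fields[1]), int(fields[2])
        if start > end:
            raise ValueError(f"line {lineno}: start {start} is after end {end}")
        regions.add((chrom, start, end))  # the set removes duplicate regions
    return sorted(regions)

example_lines = [
    "7 27723990 28723990\n",
    "\n",
    "7 27723990 28723990\n",   # duplicate, kept once
    "chr5 272741 1213528\n",
]
regions = parse_regions(example_lines)
print(regions)
```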

[default_1 (export utils script)]
depends: Py_Module('torch'), Py_Module('numpy'), Py_Module('pandas'), Py_Module('dask'), Py_Module('scipy'), Py_Module('bgen_reader'), Py_Module('pandas_plink'), Py_Module('liftover'), Py_Module('xxhash')
parameter: scan_window = 500000
output: f'{cwd:a}/utils.py'
report: expand = '${ }', output=f'{cwd:a}/utils.py'
    import torch
    import numpy as np
    import pandas as pd
    from scipy.stats import norm
    import dask.array as da
    import dask.dataframe as dd
    from bgen_reader import open_bgen  
    from pandas_plink import read_plink
    from liftover import get_lifter
    from xxhash import xxh32 as xxh

    #functions to read sumstats
    def read_sumstat(file, config_file):
        # Try gzip first, then fall back to plain text
        try:
            sumstats = pd.read_csv(file, compression='gzip', header=0, sep='\t', quotechar='"')
        except Exception:
            sumstats = pd.read_csv(file, header=0, sep='\t', quotechar='"')
        if config_file is not None:
            import yaml
            # The config maps the standard column names (keys) to the column names used in the file (values)
            with open(config_file, 'r') as fh:
                config = yaml.safe_load(fh)
            try:
                sumstats = sumstats.loc[:,list(config.values())]
            except KeyError:
                raise ValueError(f'According to {config_file}, input summary statistics should have the following columns: {list(config.values())}.')
            sumstats.columns = list(config.keys())
        sumstats['SNP'] = sumstats.SNP.apply(shorten_id)
        sumstats['CHR'] = sumstats.CHR.astype(int)
        sumstats['POS'] = sumstats.POS.astype(int)
        return sumstats

    def read_regenie(file):
        # Try gzip first, then fall back to plain text
        try:
            sumstats = pd.read_csv(file, compression='gzip', header=0, sep='\t', quotechar='"')
        except Exception:
            sumstats = pd.read_csv(file, header=0, sep='\t', quotechar='"')
        # Variant ID is chr<CHR>:<POS>:<REF>:<ALT>; bracket assignment also works when the SNP column does not exist yet
        sumstats['SNP'] = 'chr'+sumstats.CHR.astype(str) + ':' + sumstats.POS.astype(str) + ':' + sumstats.REF.astype(str) + ':' + sumstats.ALT.astype(str)
        sumstats['CHR'] = sumstats.CHR.astype(int)
        sumstats['POS'] = sumstats.POS.astype(int)
        return sumstats

    #util functions
    def shorten_id(x):
        return x if len(x) < 30 else f"{x.split('_')[0]}_{xxh(x).hexdigest()}"

    def regional_stats(sumstats, region):
        ss = sumstats[(sumstats.CHR == region[0]) & (sumstats.POS >= region[1]) & (sumstats.POS <= region[2])].copy()
        ss['Z'] = list(p2z(ss.P,ss.BETA))
        return ss


    def p2z(pval,beta,twoside=True):
        if twoside:
            pval = pval/2
        z=np.abs(norm.ppf(pval))
        ind=beta<0
        z[ind]=-z[ind]
        return z


    #functions to read genotype data
    def read_bgen(file, sample_file=None):
        bg = open_bgen(file,verbose=False)
        snp,aa0,aa1 = [],[],[]
        for c,p,alleles in zip(bg.chromosomes,bg.positions,bg.allele_ids):
            a0,a1 = alleles.split(',')
            aa0.append(a0)
            aa1.append(a1)
            snp.append(':'.join(['chr'+str(int(c)),str(p),a1,a0]))
        bg.bim = pd.DataFrame({'chrom':bg.chromosomes.astype(int),'snp':snp,'pos':bg.positions,'a0':aa0,'a1':aa1})
        if sample_file is not None:
            fam = pd.read_csv(sample_file, header=0, delim_whitespace=True, quotechar='"',skiprows=1)
            fam.columns = ['fid','iid','missing','sex']
            bg.fam = fam
        return bg

    def read_pl(file):
        (bim,fam,bed) = read_plink(file, verbose=False)
        geno = bed
        geno.bim = bim
        geno.fam = fam
        return geno

    def read_geno(geno_file):
        if geno_file.endswith('.bed'):
            geno = read_pl(geno_file[:-4])
        elif geno_file.endswith('.bgen'):
            sample_file = geno_file.replace('.bgen', '.sample')
            geno = read_bgen(geno_file,sample_file)
        else:
            raise ValueError('Please provide the genotype files in PLINK binary format or BGEN format')
        return geno

    #The function to keep the variants of geno that are in the sumstats (or not in them, if notin=True)
    def geno_in_stat(geno,stat,notin=False):
        bim = geno.bim
        fam = geno.fam
        idx = bim.snp.isin(stat.SNP)
        if notin:
            idx = ~idx
        else:
            if sum(idx)!=stat.shape[0]:
                print("Warning: the number of matched genotype variants differs from the number of sumstats variants")
        if isinstance(geno,da.core.Array):
            geno = geno[idx,:]
        else:
            int_idx = list(idx[idx].index)
            geno = bgen2dask(geno,int_idx,step=500).T
        geno.bim = bim[idx]
        geno.fam = fam
        return geno

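    #The function to convert bgen to a dask array
    def bgen2dask(bgen,index,step=500):
        genos = []
        n = len(index)
        for i in range(0,n,step):
            # read the genotype probabilities of up to `step` variants: sample x variant x genotype class
            onecode_geno = bgen.read(index[i:min(n,i+step)])
            # hard-call each genotype as the index of its most probable class
            geno = onecode_geno.argmax(axis=2).astype(np.int8)
            genos.append(da.from_array(geno))
        return(da.concatenate(genos,axis=1))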

    #The function to liftover bim
    def bim_liftover(bim,chainmap):
        new_bim = []
        for c,p,a0,a1 in zip(bim.chrom,bim.pos,bim.a0,bim.a1):
            new_c,new_p,_ = chainmap[int(c)][p][0]
            snp = ':'.join([new_c,str(new_p),a0,a1])
            new_bim.append([int(new_c[3:]),snp,new_p,a0,a1])
        new_bim = pd.DataFrame(new_bim,columns=['chrom','snp','pos','a0','a1'])
        return new_bim

    #The function to keep only the samples of geno that are listed in unr (the unrelated samples)
    def geno_in_unr(geno,unr):
        bim = geno.bim
        fam = geno.fam
        idx = fam.iid.astype(str).isin(unr.IID.astype(str))
        geno = geno[:,idx]
        geno.bim = bim
        geno.fam = fam[idx]
        return geno

    #functions to calculate LD matrix
    def geno_LD(x,y=None,step=100):
        if y is None:
            dd = dask_corr(x,step)
            return(dict2mat(dd))
        else:
            dd = dask_corr_pair(x,y,step)
            return(dict2mat_pair(dd))

    def dask_corr(genos,step=100):
        #sample by snps (normalized)
        nsample = genos.shape[0]
        nsnp = genos.shape[1]
        da_corr = {}
        for i in range(0,nsnp,step):
            da_corr[i] = {}
            geno_i = genos[:,i:min(i+step,nsnp)].compute().astype(np.float64)
            geno_i = (geno_i - np.nanmean(geno_i,axis=0)[None,:])/np.nanstd(geno_i,axis=0)[None,:]
            geno_i = torch.from_numpy(geno_i)
            geno_i[torch.isnan(geno_i)] = 0
            chunk_i = da.from_array((torch.matmul(geno_i.T,geno_i)/nsample).numpy())
            da_corr[i][i]=chunk_i
            for j in range(i+step,nsnp,step):
                geno_j = genos[:,j:min(j+step,nsnp)].compute().astype(np.float64)
                geno_j = (geno_j - np.nanmean(geno_j,axis=0)[None,:])/np.nanstd(geno_j,axis=0)[None,:]
                geno_j = torch.from_numpy(geno_j)
                geno_j[torch.isnan(geno_j)] = 0
                cor_ij = da.from_array((torch.matmul(geno_i.T,geno_j)/nsample).numpy())
                da_corr[i][j]=cor_ij
        return da_corr

    def dict2mat(dd):
        da_mat=[]
        for i in dd.keys():
            rowi = []
            for j in dd.keys():
                if i>j:
                    rowi.append(dd[j][i].T)
                else:
                    rowi.append(dd[i][j])
            rowi = da.concatenate(rowi,axis=1)
            da_mat.append(rowi)
        return(da.concatenate(da_mat,axis=0))

    def dask_corr_pair(genos,pgenos,step=100):
        #sample by snps (normalized)
        nsample = genos.shape[0]
        nsnp = genos.shape[1]
        psample = pgenos.shape[0]
        psnp = pgenos.shape[1]
        if nsample != psample:
            raise ValueError(f"Sample sizes do not match: {nsample} vs {psample}")
        da_corr = {}
        for i in range(0,nsnp,step):
            da_corr[i] = {}
            geno_i = genos[:,i:min(i+step,nsnp)].compute().astype(np.float64)
            geno_i = (geno_i - np.nanmean(geno_i,axis=0)[None,:])/np.nanstd(geno_i,axis=0)[None,:]
            geno_i = torch.from_numpy(geno_i)
            geno_i[torch.isnan(geno_i)] = 0
            for j in range(0,psnp,step):
                geno_j = pgenos[:,j:min(j+step,psnp)].compute().astype(np.float64)
                geno_j = (geno_j - np.nanmean(geno_j,axis=0)[None,:])/np.nanstd(geno_j,axis=0)[None,:]
                geno_j = torch.from_numpy(geno_j)
                geno_j[torch.isnan(geno_j)] = 0
                cor_ij = da.from_array((torch.matmul(geno_i.T,geno_j)/nsample).numpy())
                da_corr[i][j]=cor_ij
        return da_corr

    def dict2mat_pair(dd):
        da_mat=[]
        for i in dd.keys():
            rowi = []
            for j in dd[0].keys():
                rowi.append(dd[i][j])
            rowi = da.concatenate(rowi,axis=1)
            da_mat.append(rowi)
        return(da.concatenate(da_mat,axis=0))

    def main(region,geno_path,sumstats_path,pheno_path,unr_path,imp_geno_path,imp_sumstats_path,imp_ref,output_sumstats,output_LD):

        print('1. Preprocess sumstats (regenie format) and extract it from a region')
        if pheno_path is not None:
            # Load phenotype file
            pheno = pd.read_csv(pheno_path, header=0, delim_whitespace=True, quotechar='"')
        if unr_path is not None:
            # Load unrelated sample file
            unr = pd.read_csv(unr_path, header=0, delim_whitespace=True, quotechar='"')  
        # Load the file of summary statistics and standardize it.
        exome_sumstats = read_regenie(sumstats_path)
        exome_geno = read_geno(geno_path)
        print('1.1. Region extraction')
        exome_sumstats = regional_stats(exome_sumstats,region)
        exome_geno = geno_in_stat(exome_geno,exome_sumstats)

        if imp_geno_path is not None:
            #two genotype data
            imput_sumstats = read_regenie(imp_sumstats_path)
            imput_geno = read_geno(imp_geno_path)   
            if imp_ref is None:
                imput_sumstats = regional_stats(imput_sumstats,region)
                imput_geno = geno_in_stat(imput_geno,imput_sumstats)
            else:
                print('1.2. LiftOver the region')
                from liftover import get_lifter
                hg38toimpref = get_lifter('hg38',imp_ref)
                imp_start = hg38toimpref[region[0]][region[1]][0][1]
                imp_end = hg38toimpref[region[0]][region[2]][0][1]
                imp_region = [region[0],imp_start,imp_end]
                imput_sumstats = regional_stats(imput_sumstats,imp_region)
                imput_geno = geno_in_stat(imput_geno,imput_sumstats)
                print('1.3. Regional SNPs Liftover')
                impreftohg38 = get_lifter(imp_ref,'hg38') # the reverse direction of hg38toimpref
                imput_geno.bim = bim_liftover(imput_geno.bim,impreftohg38)
                imput_sumstats.POS = list(imput_geno.bim.pos)
                imput_sumstats.SNP = list(imput_geno.bim.snp)
            print('1.4. Get exome-unique sumstats and geno, and combine sumstats')
            exome_unique_snp_idx = ~exome_sumstats.SNP.isin(imput_sumstats.SNP)
            exome_sumstats_diff = exome_sumstats[exome_unique_snp_idx]
            sumstats = pd.concat([exome_sumstats_diff,imput_sumstats])
            exome_geno = geno_in_stat(exome_geno,imput_sumstats,notin=True)
        else:
            #one genotype data
            sumstats = exome_sumstats

        print('2. Remove related samples')
        if unr_path is not None:
            exome_geno = geno_in_unr(exome_geno,unr)
            if imp_geno_path is not None:
                imput_geno = geno_in_unr(imput_geno,unr)
        else:
            print('Warning: No file of unrelated samples was provided. All samples are included in computing the LD matrix')

        if pheno_path is not None:
            pass #sld and pld

        print('3. Calculate LD matrix')
        if imp_geno_path is None:
            cor_da = geno_LD(exome_geno.T)
        else:
            xx = geno_LD(exome_geno.T)
            yy = geno_LD(imput_geno.T,step=500)

            imput_fam = imput_geno.fam
            imput_fam.index = list(imput_fam.iid.astype(str))
            imput_fam['i'] = list(range(imput_fam.shape[0]))
            imput_fam_comm = imput_fam.loc[list(exome_geno.fam.iid.astype(str))]
            imput_geno_comm=imput_geno[:,list(imput_fam_comm.i)]
            xy = geno_LD(exome_geno.T,imput_geno_comm.T,step=500)
            cor_da = da.concatenate([da.concatenate([xx,xy],axis=1),da.concatenate([xy.T,yy],axis=1)],axis=0)

        print('4. Output sumstats and LD matrix')
        index = list(sumstats.SNP.apply(shorten_id))
        sumstats.SNP = index
        sumstats.index = list(range(sumstats.shape[0]))
        sumstats.to_csv(output_sumstats, sep = "\t", header = True, index = True)

        corr = cor_da.compute()
        np.fill_diagonal(corr, 1)
        corr = pd.DataFrame(corr, columns=index)
        corr.to_csv(output_LD, sep = "\t", header = True, index = False)
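
The LD functions above avoid forming one large genotype matrix product by standardizing blocks of `step` variants and computing one correlation block at a time; only the upper-triangular blocks are computed and the rest are filled in by transposition. Below is a simplified NumPy-only sketch of that idea (no dask or torch). `standardize` and `chunked_ld` are illustrative names, not functions of `utils.py`, and the genotypes are synthetic.

```python
import numpy as np

def standardize(block):
    """Center and scale each variant (column); non-finite values become 0."""
    g = np.asarray(block, dtype=np.float64)
    with np.errstate(invalid="ignore", divide="ignore"):
        z = (g - np.nanmean(g, axis=0)) / np.nanstd(g, axis=0)
    z[~np.isfinite(z)] = 0.0  # missing calls and monomorphic variants contribute 0
    return z

def chunked_ld(genotypes, step=100):
    """LD (Pearson correlation) of a samples x variants array, block by block."""
    g = np.asarray(genotypes)
    n_sample, n_snp = g.shape
    ld = np.empty((n_snp, n_snp))
    for i in range(0, n_snp, step):
        zi = standardize(g[:, i:i + step])
        for j in range(i, n_snp, step):  # upper-triangular blocks only
            zj = zi if j == i else standardize(g[:, j:j + step])
            block = zi.T @ zj / n_sample
            ld[i:i + step, j:j + step] = block
            ld[j:j + step, i:i + step] = block.T  # mirror into the lower triangle
    np.fill_diagonal(ld, 1.0)
    return ld

rng = np.random.default_rng(0)
geno = rng.integers(0, 3, size=(120, 7))  # synthetic 0/1/2 genotypes
ld = chunked_ld(geno, step=3)
print(ld.shape)
```

With complete data this matches `np.corrcoef(geno, rowvar=False)`. With missing genotypes it is only an approximation, because missing entries are set to 0 after standardization while the divisor remains the full sample size.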

## Extract data

This step runs in parallel for all loci listed in the region file (via `for_each`).
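
Each region writes to its own directory named `chr_start_end` under `cwd`. The following sketch reproduces that naming scheme outside SoS, assuming that the `:bn` format specifier corresponds to the file's base name with its last extension removed; `region_outputs` is a hypothetical helper and the paths are made up.

```python
from pathlib import PurePosixPath

def region_outputs(cwd, sumstats_path, region):
    """Build the per-region output paths used by the extraction step."""
    chrom, start, end = region
    tag = f"{chrom}_{start}_{end}"
    stem = PurePosixPath(sumstats_path).stem  # base name without the last extension
    prefix = PurePosixPath(cwd) / tag / f"{stem}_{tag}"
    return {
        "sumstats": f"{prefix}.sumstats.gz",
        "genotype": f"{prefix}.genotype.gz",
        "pld": f"{prefix}.pre_pop_ld.pickle",
        "sld": f"{prefix}.pre_sample_ld.pickle",
    }

outputs = region_outputs("/tmp/out", "/data/pheno.snp_stats.gz", ("7", 27723990, 28723990))
for name, path in outputs.items():
    print(name, path)
```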

[default_2 (extract genotypes)]
depends: f'{cwd:a}/utils.py'
input: geno_path, pheno_path, sumstats_path, unrelated_samples, imp_geno_path,imp_sumstats_path,imp_ref, for_each = 'regions'
output: sumstats = f'{cwd:a}/{_regions[0]}_{_regions[1]}_{_regions[2]}/{sumstats_path:bn}_{_regions[0]}_{_regions[1]}_{_regions[2]}.sumstats.gz',
        genotype = f'{cwd:a}/{_regions[0]}_{_regions[1]}_{_regions[2]}/{sumstats_path:bn}_{_regions[0]}_{_regions[1]}_{_regions[2]}.genotype.gz',
        pld = f'{cwd:a}/{_regions[0]}_{_regions[1]}_{_regions[2]}/{sumstats_path:bn}_{_regions[0]}_{_regions[1]}_{_regions[2]}.pre_pop_ld.pickle',
        sld = f'{cwd:a}/{_regions[0]}_{_regions[1]}_{_regions[2]}/{sumstats_path:bn}_{_regions[0]}_{_regions[1]}_{_regions[2]}.pre_sample_ld.pickle'
task: trunk_workers = 1, trunk_size = job_size, walltime = '4h', mem = '60G', cores = 1, tags = f'{step_name}_{_output[0]:bn}'
python: expand = '${ }', input = f'{cwd:a}/utils.py', stderr = f'{_output[0]:n}.stderr', stdout = f'{_output[0]:n}.stdout'
    

    import os
    # output path files that we will need in our final version
    output_sumstats = ${_output['sumstats']:r}
    output_genotype = ${_output['genotype']:r}
    output_pld = ${_output['pld']:r}
    output_sld = ${_output['sld']:r}

    # this general path is used to create other temporary files that we need to calculate the ld matrices later on
    cwd = os.getcwd()
    output_general = '${cwd}/${_regions[0]}_${_regions[1]}_${_regions[2]}/${sumstats_path:bn}_${_regions[0]}_${_regions[1]}_${_regions[2]}'

    input_sample_path = ${bgen_sample_path:r}
    input_geno_path = ${_input[0]:r}
    input_pheno_path = ${_input[1]:r}
    input_sumstats_path = ${_input[2]:r}
    input_unrelated_samples = ${_input[3]:r}
    imp_geno_path = ${_input[4]:r}
    imp_sumstats_path = ${_input[5]:r}
    imp_ref =  ${_input[6]:r}
    
    input_format_config = ${format_config_path:r} if ${format_config_path.is_file()} else None

    
    # Load genotype file for the region of interest
    geno_inventory = dict([x.strip().split() for x in open(${_input[0]:r}).readlines() if x.strip()])
    chrom = "${_regions[0]}"
    if chrom.startswith('chr'):
        chrom = chrom[3:]
    if chrom not in geno_inventory:
        geno_file = geno_inventory['0']
    else:
        geno_file = geno_inventory[chrom]

    print(geno_file, input_sumstats_path, input_pheno_path, input_unrelated_samples,output_sumstats, output_pld)


    if not os.path.isfile(geno_file):
        # relative path
        if not os.path.isfile('${_input[0]:ad}/' + geno_file):
            raise ValueError(f"Cannot find genotype file {geno_file}")
        else:
            geno_file = '${_input[0]:ad}/' + geno_file


    region = (int(chrom), ${_regions[1]}, ${_regions[2]})

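    imput_pheno_path = None
    # NOTE: the imputed-genotype inputs are reset to None here, so main() only runs its single-genotype branch
    imp_geno_path,imp_sumstats_path,imp_ref = None,None,None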
    print(region, geno_file, input_sumstats_path, input_pheno_path, input_unrelated_samples,imp_geno_path,imp_sumstats_path,imp_ref,
                                output_sumstats, output_pld)
    main(region, geno_file, input_sumstats_path, input_pheno_path, input_unrelated_samples,imp_geno_path,imp_sumstats_path,imp_ref,
                                output_sumstats, output_pld)

region = [5,272741,1213528-900000]
geno_path = 'MWE_region_extraction/ukb23156_c5.merged.filtered.5_272741_1213528.bed'
sumstats_path = 'MWE_region_extraction/090321_UKBB_Hearing_aid_f3393_expandedwhite_6436cases_96601ctrl_PC1_2_f3393.regenie.snp_stats'
pheno_path = None
unr_path = 'MWE_region_extraction/UKB_genotypedatadownloaded083019.090221_sample_variant_qc_final_callrate90.filtered.extracted.white_europeans.filtered.092821_ldprun_unrelated.filtered.prune.txt'
imp_geno_path = 'MWE_region_extraction/ukb_imp_chr5_v3_05_272856_1213643.bgen'
imp_sumstats_path = 'MWE_region_extraction/100521_UKBB_Hearing_aid_f3393_expandedwhite_15601cases_237318ctrl_500k_PC1_PC2_f3393.regenie.snp_stats'
imp_ref = 'hg19'

output_sumstats = 'test.snp_stats'
output_LD = 'test_corr.csv'

main(region,geno_path,sumstats_path,pheno_path,unr_path,imp_geno_path,imp_sumstats_path,imp_ref,output_sumstats,output_LD)
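
When both exome and imputed genotypes are supplied, as in the call above, `main` assembles the final matrix from three pieces: the exome LD, the imputed LD, and the cross LD computed on the imputed samples reordered to match the exome sample IDs. The following is a simplified NumPy sketch of that assembly with synthetic genotypes; `standardize` and `merged_ld` are hypothetical names, and dask and torch are left out for brevity.

```python
import numpy as np

def standardize(genotypes):
    """Center and scale each variant (column); non-finite values become 0."""
    g = np.asarray(genotypes, dtype=np.float64)
    with np.errstate(invalid="ignore", divide="ignore"):
        z = (g - np.nanmean(g, axis=0)) / np.nanstd(g, axis=0)
    z[~np.isfinite(z)] = 0.0
    return z

def merged_ld(exome, exome_iids, imput, imput_iids):
    """Both genotype arrays are samples x variants; the IID lists label their rows."""
    exome = np.asarray(exome)
    imput = np.asarray(imput)
    # Rows of the imputed data that match the exome samples, in exome order.
    # A KeyError here means an exome sample is absent from the imputed data.
    row_of = {iid: row for row, iid in enumerate(imput_iids)}
    common_rows = [row_of[iid] for iid in exome_iids]

    zx = standardize(exome)
    zy = standardize(imput)
    zy_common = standardize(imput[common_rows])

    xx = zx.T @ zx / zx.shape[0]          # exome x exome
    yy = zy.T @ zy / zy.shape[0]          # imputed x imputed (all imputed samples)
    xy = zx.T @ zy_common / zx.shape[0]   # exome x imputed (shared samples)

    ld = np.block([[xx, xy], [xy.T, yy]])
    np.fill_diagonal(ld, 1.0)
    return ld

# Synthetic data: 50 samples, 3 exome variants, 2 imputed variants,
# with the imputed samples stored in a shuffled order.
rng = np.random.default_rng(1)
exome = rng.integers(0, 3, size=(50, 3))
imput_in_exome_order = rng.integers(0, 3, size=(50, 2))
perm = rng.permutation(50)
exome_iids = [f"s{k}" for k in range(50)]
imput_iids = [f"s{k}" for k in perm]
imput = imput_in_exome_order[perm]

ld = merged_ld(exome, exome_iids, imput, imput_iids)
print(ld.shape)
```

When the two datasets contain different sample sets, the three blocks are estimated on different samples, so the merged matrix is an approximation and is not guaranteed to be positive semi-definite.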

    sos run /home/yh3455/Github/bioworkflows/GWAS/LD_merged_exo_imp.ipynb default \
        --cwd /home/yh3455/Github/bioworkflows/GWAS/test \
        --region-file /home/dmc2245/UKBiobank/results/LD_clumping/092321_f3393_200Kexomes/090321_UKBB_Hearing_aid_f3393_expandedwhite_6436cases_96601ctrl_PC1_2_f3393.regenie.snp_stats.clumped_region \
        --pheno-path /home/dmc2245/UKBiobank/phenotype_files/hearing_impairment/090321_UKBB_Hearing_aid_f3393_expandedwhite_6436cases_96601ctrl_PC1_2.tsv \
        --geno-path /home/dmc2245/UKBiobank/data/exome_files/project_VCF/072721_run/plink/092321_UKBB_qc_exome_geno_path.txt \
        --sumstats-path /home/dmc2245/UKBiobank/results/REGENIE_results/results_exome_data/090921_f3393_hearing_aid_200K/*.snp_stats.gz \
        --unrelated-samples /home/dmc2245/UKBiobank/results/083021_PCA_results/090221_ldprun_unrelated/cache/UKB_genotypedatadownloaded083019.090221_sample_variant_qc_final_callrate90.filtered.extracted.europeans.filtered.090221_ldprun_unrelated.filtered.prune.txt \
        -- \
        --job-size 1