___
# Polygenic Risk Score Accuracy is Dependent on Local Ancestry
___

## Import necessary packages and code

In [None]:
# packages
import seaborn as sns
import matplotlib.pyplot as plt
import sys
from scipy import stats

In [None]:
# code
sys.path.insert(0,"/Users/taylorcavazos/repos/Local_Ancestry_PRS/code/")
from plot_correlations import *
from plot_af_ld import *

## Contents

In [None]:
# To do

## Background

Polygenic risk scores (PRS) built from European GWAS perform poorly in non-European populations; however, our best-powered GWAS are of European populations. Moving forward, we therefore need not only larger, more diverse population datasets, but also statistical approaches for generalizing these scores.

![](https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs41588-019-0379-x/MediaObjects/41588_2019_379_Fig3_HTML.png)

Performance appears to decrease as the proportion of European ancestry decreases. This led me to ask whether PRS accuracy may depend on local ancestry proportions within ancestry groups.

## Simulation setup

![](images/methods.jpeg)

____
# Simulation PRS Accuracy

## Pearson's correlation of European-derived PRS in African individuals by percent European ancestry at PRS variants

Here we hold the p-value and $r^2$ thresholds constant at 0.01 and 0.2, respectively.

In [None]:
# Correlation of the European-weighted PRS vs. percent European ancestry at PRS variants, for every
# simulated parameter combination (helper comes from the star-imported project modules).
plot_correlation_all_params_eur_weights()
plt.show()

## We can improve PRS accuracy by using local ancestry specific weights 

In [None]:
# Same view, comparing weighting schemes (presumably European, African and local-ancestry-specific
# weights — TODO confirm against the helper in the star-imported project modules).
plot_correlation_all_params_all_weights()
plt.show()

#### To move forward with the analysis, I will limit to parameters $m=1000$ and $h^2=0.5$. Although all parameter combinations show the same trend, this one is likely a closer approximation of true disease biology: hundreds of SNPs have been identified as causal for many common diseases, and this doesn't include the rare variants that are likely yet to be discovered (and that are causal in my simulation).

In [None]:
# Summary table of PRS correlations across simulations; the cells below read its m, h2, weight,
# test_EUR_corr and ADMIX_low_eur_corr columns.
# NOTE(review): `data` is rebound to a different DataFrame in the causal-variant section, so
# re-running the cells that follow this one after that section silently uses the wrong table.
data = load_all_weight_summary_data()

In [None]:
# Two panels for m=1000, h2=0.5: European weights (left) and all weighting schemes (right).
fig, ax = plt.subplots(1, 2, figsize=(30, 15))
panel_plotters = (plot_correlation_single_eur_weights, plot_correlation_single_all_weights)
for plot_fn, panel in zip(panel_plotters, ax):
    plot_fn(data, panel, m=1000, h2=0.5)
plt.tight_layout()
plt.show()

In [None]:
# Correlations for the m=1000, h2=0.5 simulations, split by weighting scheme.
# The parameter mask is built once instead of being repeated in every selection.
fixed_params = (data["m"] == 1000) & (data["h2"] == 0.5)


def _corr_for(weight, column):
    """Return `column` for the fixed-parameter rows whose weighting scheme equals `weight`."""
    return data.loc[fixed_params & (data["weight"] == weight), column]


eur_weight_ceu_only = _corr_for("European", "test_EUR_corr")

afr_weight_ceu_low = _corr_for("African", "ADMIX_low_eur_corr")
eur_weight_ceu_low = _corr_for("European", "ADMIX_low_eur_corr")
LA_weight_ceu_low = _corr_for("Local ancestry \nspecific", "ADMIX_low_eur_corr")

In [None]:
# Independent two-sample Student's t-test (scipy default: equal variances assumed) comparing
# local-ancestry-specific vs. African weights on the ADMIX_low_eur_corr correlations.
stats.ttest_ind(a=LA_weight_ceu_low, b=afr_weight_ceu_low)

___

# Impact of decreasing African sample size

In [None]:
# PRS correlation as the African (YRI) GWAS sample size decreases (project helper from the
# star-imported modules; relies on inline rendering since plt.show() is not called here).
plot_correlation_decreasing_yri_allPops()

# Main Takeaways
#### (1) Local ancestry matters

#### (2) Using ancestry specific weights isn't enough to achieve similar accuracy as PRS in Europeans

___
# Exploration of Simulation PRS — Why doesn't the European PRS generalize across populations?

## Allele frequency

In [None]:
# Allele-frequency comparison for the section above (project helper `plot_maf_bins` from the
# star-imported modules).
plot_maf_bins()

## Linkage disequilibrium

In [7]:
import msprime
import numpy as np
import pandas as pd
import tqdm
import threading
import math
import seaborn as sns
import matplotlib.pyplot as plt
import operator

from multiprocessing import Pool
import time
import pickle

import random
from collections import defaultdict
from itertools import chain
from operator import methodcaller

In [2]:
# Load the unfiltered ("nofilt") GWAS tree sequences for YRI and CEU.
# NOTE(review): msprime.load is the legacy (pre-1.0) API; newer code uses tskit.load — confirm the
# installed msprime version.
tree_yri = msprime.load("../data/sim2/trees/tree_YRI_GWAS_nofilt.hdf")
tree_ceu = msprime.load("../data/sim2/trees/tree_CEU_GWAS_nofilt.hdf")

In [3]:
def return_LD_dict(tree_LD):
    """Index a tree sequence's mutations by site and return a site-filtered simplified copy.

    Parameters
    ----------
    tree_LD : tree sequence
        Object exposing ``mutations()`` (items with ``id``, ``site`` and ``position``) and
        ``simplify(filter_sites=...)``.

    Returns
    -------
    tuple
        ``(simplified_tree, var2mut, mut2var, positions)``:

        - ``var2mut`` maps site -> mutation id. When a site carries several mutations, the last
          one seen wins.
        - ``mut2var`` maps mutation id -> site.
        - ``positions`` maps site -> mutation position.
    """
    site_to_mut, mut_to_site, site_pos = {}, {}, {}
    for mutation in tree_LD.mutations():
        site = mutation.site
        mut_to_site[mutation.id] = site
        site_to_mut[site] = mutation.id
        site_pos[site] = mutation.position

    simplified = tree_LD.simplify(filter_sites=True)
    return simplified, site_to_mut, mut_to_site, site_pos

In [4]:
# Site -> position map for CEU. Only pos_ceu is used later in this notebook; the simplified tree
# and the mutation/site maps are not.
tree_ceu_LD, var2mut_ceu, mut2var_ceu, pos_ceu = return_LD_dict(tree_ceu)

In [5]:
# Site -> position map for YRI. Only pos_yri (its key set) is used later in this notebook; the
# simplified tree and the mutation/site maps are not.
tree_yri_LD, var2mut_yri, mut2var_yri, pos_yri = return_LD_dict(tree_yri)

In [6]:
# Site ids that carry a mutation in both the CEU and the YRI trees.
# NOTE(review): this matches on site *ids*, not positions — assumes both trees share one site
# numbering (plausible for unfiltered trees; confirm).
vars_to_check = list(set(pos_ceu).intersection(pos_yri))

In [8]:
# Shared site ids in ascending order; get_dist walks this list by index.
list_keys = sorted(vars_to_check)

def get_dist_bin(dist):
    """Map a pairwise distance in kb to a 1-based 5-kb bin index.

    Bin k covers [5*(k-1), 5*k) kb for k = 1..10; bin 1 also absorbs any value below 5.
    Distances >= 50 kb (and NaN) return -1, meaning "outside every bin".
    """
    for dist_bin, upper_kb in enumerate(range(5, 55, 5), start=1):
        if dist < upper_kb:
            return dist_bin
    return -1


def get_dist(i):
    """Bin every pair (list_keys[i], list_keys[j]), j > i, lying within 50 kb, by distance.

    Reads the module-level ``list_keys`` (sorted site ids) and ``pos_ceu`` (site id -> position)
    so that it stays a picklable top-level function for ``Pool.map``.

    Parameters
    ----------
    i : int
        Index into ``list_keys`` of the anchor variant.

    Returns
    -------
    dict[int, list[tuple]]
        Keys 1..10 (5-kb bins, see ``get_dist_bin``), each mapped to its OWN list of
        ``(list_keys[i], list_keys[j])`` pairs.
    """
    # One independent list per bin. The previous ``dict.fromkeys(range(1, 11), [])`` aliased a
    # single list across all ten keys, so every bin ended up holding every pair.
    new_dict = {dist_bin: [] for dist_bin in range(1, 11)}
    var_i = list_keys[i]
    pos_i = pos_ceu.get(var_i)
    # Stops at the first variant more than 50 kb away. This assumes positions increase along
    # list_keys (site ids ordered by position) — TODO confirm; otherwise close pairs further
    # along the list are missed.
    for j in range(i + 1, len(list_keys)):
        var_j = list_keys[j]
        dist_bp = abs(pos_ceu.get(var_j) - pos_i)
        if dist_bp > 50e3:
            break
        dist_bin = get_dist_bin(dist_bp / 1000)
        # A pair exactly 50 kb apart passes the guard but falls outside every bin (-1).
        # The previous ``new_dict.get(-1).append`` crashed on None here.
        if dist_bin != -1:
            new_dict[dist_bin].append((var_i, var_j))
    return new_dict
                       
# Bin the nearby pairs for every shared variant in parallel (one get_dist call per index).
# The context manager shuts the workers down even if map() raises. Previously the pool was
# only close()d and was never cleaned up on error.
# NOTE(review): workers must inherit get_dist/list_keys/pos_ceu from the notebook, which needs
# the "fork" start method; under "spawn" (macOS default since Python 3.8) notebook-defined
# functions cannot be pickled — confirm the environment.
with Pool(processes=8) as pool:
    pairwise_dists = pool.map(get_dist, range(len(list_keys)))

In [9]:
# Concatenate the per-variant {bin: [pairs]} dicts into one {bin: [all pairs]} mapping,
# preserving the order in which variants and bins were produced.
pairwise_dists_comb = defaultdict(list)
for per_variant_bins in pairwise_dists:
    for bin_id, bin_pairs in per_variant_bins.items():
        pairwise_dists_comb[bin_id].extend(bin_pairs)

In [10]:
# Draw a reproducible random subset of variant pairs from each distance bin.
# The frame is built once from a list of records. The previous DataFrame.append loop was
# quadratic, and DataFrame.append was removed in pandas 2.0.
N_PAIRS_PER_BIN = 100
LD_PAIR_SEED = 42
_pair_rng = random.Random(LD_PAIR_SEED)

pair_records = []
for bin_id in sorted(pairwise_dists_comb):
    bin_pairs = pairwise_dists_comb[bin_id]
    # random.sample raises ValueError if a bin holds fewer than N pairs; take what exists instead.
    n_draw = min(N_PAIRS_PER_BIN, len(bin_pairs))
    for var1, var2 in _pair_rng.sample(bin_pairs, n_draw):
        pair_records.append((var1, var2, bin_id))

dist_inds_df = pd.DataFrame(pair_records, columns=["var1", "var2", "dist_bin"])

In [11]:
# r^2 calculators over the full GWAS trees loaded earlier (not the simplified tree_*_LD copies).
# NOTE(review): find_ld passes site ids (keys of pos_ceu, i.e. mut.site) as the get_r2 indexes —
# confirm these match the index space LdCalculator expects for these trees.
ld_ceu = msprime.LdCalculator(tree_ceu)
ld_yri = msprime.LdCalculator(tree_yri)

def find_ld(ind):
    """Return ``{ind: (CEU r^2, YRI r^2)}`` for the variant pair in row ``ind`` of ``dist_inds_df``.

    Reads the module-level ``dist_inds_df``, ``ld_ceu`` and ``ld_yri`` so that it can be
    dispatched with ``Pool.map``. The ids are cast with ``int`` because the frame's columns
    are not guaranteed to be integer-typed.
    """
    site_a = int(dist_inds_df.loc[ind, "var1"])
    site_b = int(dist_inds_df.loc[ind, "var2"])
    r2_pair = tuple(calc.get_r2(site_a, site_b) for calc in (ld_ceu, ld_yri))
    return {ind: r2_pair}

In [None]:
# Compute CEU/YRI r^2 for the sampled pairs, bin by bin. `overall_result` holds one list of
# {row index: (CEU r^2, YRI r^2)} dicts per distance bin.
# A single pool is reused for all bins. Previously a fresh 8-process pool was created per bin
# and never joined. The context manager also shuts the workers down if map() raises.
overall_result = []
with Pool(processes=8) as pool:
    for key, val in dist_inds_df.groupby("dist_bin").groups.items():
        overall_result.append(pool.map(find_ld, val))

In [None]:
# Merge the per-row {index: (CEU r^2, YRI r^2)} dicts from every bin into one mapping and write
# them back onto dist_inds_df.
# - chain.from_iterable works when bins have different sizes. np.array(...).flatten() on that
#   ragged nested list raises (NumPy >= 1.24) or yields unmergeable lists.
# - dict.update avoids rebuilding the merged dict on every iteration.
# - overall_result is no longer overwritten, so re-running this cell is safe.
overall_dicts = {}
for row_ld in chain.from_iterable(overall_result):
    overall_dicts.update(row_ld)

for key, val in overall_dicts.items():
    dist_inds_df.loc[key, "CEU_r2"] = val[0]
    dist_inds_df.loc[key, "YRI_r2"] = val[1]

In [None]:
# Mean r^2 per distance bin. Only the r^2 columns are averaged: var1/var2 are site ids, which are
# meaningless to average and may be object-typed, which groupby().mean() can reject on pandas >= 2.
mean_mat = dist_inds_df.groupby("dist_bin")[["CEU_r2", "YRI_r2"]].mean().reset_index()

In [None]:
# Replace bin indexes 1..10 with each bin's upper edge in kb (bin k covers [5(k-1), 5k) kb per
# get_dist_bin). Assumes all 10 bins are present in ascending order (groupby sorts its keys),
# otherwise the lengths will not match.
mean_mat["dist_bin"] = np.arange(5,55,5)

In [None]:
# LD decay: mean r^2 between sampled variant pairs vs. distance, CEU vs. YRI.
# Uses the explicit Axes API and adds a legend, axis labels and a title; previously the two lines
# were distinguishable only by colour.
fig_ld, ax_ld = plt.subplots(figsize=(10, 5))
sns.lineplot(x="dist_bin", y="CEU_r2", data=mean_mat, color="blue", label="CEU", ax=ax_ld)
sns.lineplot(x="dist_bin", y="YRI_r2", data=mean_mat, color="red", label="YRI", ax=ax_ld)
ax_ld.set_xticks(np.arange(5, 55, 5))
ax_ld.set(
    xlabel="Distance between variants (kb, bin upper edge)",
    ylabel="Mean $r^2$",
    title="LD decay with distance in CEU vs. YRI",
)
ax_ld.legend()
plt.show()

___
# Deep dive into causal variants

(1) If causal variants are present in the summary statistics, do they have the same effect size? What about the p-values?
* Make plot of one simulation (heatmap of some sort)
* Come up with plot for summarizing all simulations

In [None]:
# Reload the GWAS trees so this section can run on its own (same files as the LD section).
tree_yri = msprime.load("../data/sim2/trees/tree_YRI_GWAS_nofilt.hdf")
tree_ceu = msprime.load("../data/sim2/trees/tree_CEU_GWAS_nofilt.hdf")

# Simulation parameters for this deep dive. They also select the summary-statistics files below.
# The previous code referenced `m` without defining it anywhere in the notebook, which is a
# NameError on a fresh kernel.
M_CAUSAL = 1000
H2 = 0.5

# NOTE(review): assumes the simulation placed its M_CAUSAL causal variants at these evenly spaced
# site indexes — confirm against the simulation code.
causal_vars = np.linspace(0, tree_ceu.num_sites, M_CAUSAL, dtype=int, endpoint=False)

yri_sumstats = pd.read_csv(
    f"../data/sim2/emp_prs/yri_comm_maf_0.01_sum_stats_m_{M_CAUSAL}_h2_{H2}.txt",
    index_col=0, sep="\t",
)
ceu_sumstats = pd.read_csv(
    f"../data/sim2/emp_prs/comm_maf_0.01_sum_stats_m_{M_CAUSAL}_h2_{H2}.txt",
    index_col=0, sep="\t",
)

In [None]:
# Treat a reported OR of 0 as "no effect" (OR = 1) so that the log(OR) used for the heatmap
# colours stays finite.
for sumstats in (yri_sumstats, ceu_sumstats):
    sumstats.loc[sumstats["OR"] == 0, "OR"] = 1

In [None]:
# YRI summary statistics restricted to the causal variants. reindex inserts all-NaN rows for
# causal variants absent from the file, and dropna removes them.
# NOTE(review): the rename assumes the file has exactly two columns, OR then p-value — confirm.
causal_pres_yri = yri_sumstats.reindex(causal_vars).dropna()
causal_pres_yri.columns = ["OR_YRI","PVAL_YRI"]

In [None]:
# CEU summary statistics restricted to the causal variants. reindex inserts all-NaN rows for
# causal variants absent from the file, and dropna removes them.
# NOTE(review): the rename assumes the file has exactly two columns, OR then p-value — confirm.
causal_pres_ceu =ceu_sumstats.reindex(causal_vars).dropna()
causal_pres_ceu.columns = ["OR_CEU","PVAL_CEU"]

In [None]:
# Causal variants present in both populations' summary statistics, plus per-population -log10
# p-values (columns CEU / YRI) that become the heatmap dot sizes.
# NOTE(review): this rebinds `data`, which earlier held the PRS correlation summary table.
data = (
    pd.concat([causal_pres_ceu, causal_pres_yri], sort=False, axis=1)
    .dropna()
    .assign(
        CEU=lambda frame: -1 * np.log10(frame["PVAL_CEU"]),
        YRI=lambda frame: -1 * np.log10(frame["PVAL_YRI"]),
    )
)

In [None]:
# Long format: one row per (variant, population). `variable` holds "CEU"/"YRI" and `value` holds
# the -log10 p-value.
# NOTE(review): assumes the sumstats index (index_col=0) is named "var_id" — confirm.
data_long = data.reset_index().melt(id_vars="var_id",value_vars=["CEU","YRI"],)

In [None]:
def heatmap(x, y, size, color, size_scale=5, figsize=(36, 5)):
    """Draw a dot-matrix "heatmap" with one marker per (x, y) category pair.

    Parameters
    ----------
    x, y : pandas.Series
        Categorical coordinates; each distinct value becomes one tick, in sorted order.
    size : pandas.Series
        Per-point value mapped linearly to marker area (``size * size_scale``).
    color : pandas.Series
        Per-point value mapped to the ``"bwr"`` palette through seaborn's ``hue``. Seaborn scales
        numeric hue over the data range, so white is not necessarily 0.
    size_scale : float, optional
        Multiplier from ``size`` to marker area (default 5, the previous hard-coded value).
    figsize : tuple, optional
        Figure size in inches (default ``(36, 5)``, the previous hard-coded value).

    Returns
    -------
    matplotlib.axes.Axes
        The axes drawn on, so callers can add titles or labels. Previously nothing was returned.
    """
    _, ax = plt.subplots(figsize=figsize)

    # Map each category label to an integer grid coordinate, in sorted order.
    x_labels = sorted(x.unique())
    y_labels = sorted(y.unique())
    if not x_labels or not y_labels:
        # Nothing to place; also avoids max() of an empty sequence when setting limits.
        return ax
    x_to_num = {label: pos for pos, label in enumerate(x_labels)}
    y_to_num = {label: pos for pos, label in enumerate(y_labels)}

    sns.scatterplot(
        x=x.map(x_to_num),
        y=y.map(y_to_num),
        s=size * size_scale,  # marker area proportional to `size`
        marker='o',  # circular markers
        hue=color,
        palette="bwr",
        ax=ax,
        legend=False,
    )

    # Major ticks carry the category labels...
    ax.set_xticks([x_to_num[v] for v in x_labels])
    ax.set_xticklabels(x_labels, rotation=45, horizontalalignment='right', fontsize=18)
    ax.set_yticks([y_to_num[v] for v in y_labels])
    ax.set_yticklabels(y_labels)

    # ...while the grid is drawn on minor ticks offset by half a cell, so grid lines fall
    # between markers rather than through them.
    ax.grid(False, 'major')
    ax.grid(True, 'minor')
    ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
    ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True)

    ax.set_xlim([-0.5, max(x_to_num.values()) + 0.5])
    ax.set_ylim([-0.5, max(y_to_num.values()) + 0.5])
    return ax

In [None]:
# Colour value per long-format row: that population's odds ratio for the variant.
# Vectorized lookup replaces the previous row-by-row .loc loop, which enlarged the frame cell by
# cell. Rows whose `variable` is not "CEU" take the YRI OR, exactly as before.
data_long["color"] = np.where(
    data_long["variable"] == "CEU",
    data_long["var_id"].map(data["OR_CEU"]),
    data_long["var_id"].map(data["OR_YRI"]),
)

In [None]:
# Seaborn style for the heatmap below. This is a global setting that also affects every later figure.
sns.set_style("darkgrid")

In [None]:
# Causal-variant overview: x = variant, y = population.
# - Dot size = -log10 p-value.
# - Colour = log(OR), so 0 corresponds to OR = 1 (zero ORs were set to 1 earlier to avoid log(0)).
# NOTE(review): the "bwr" palette is scaled to the data range, so white is not necessarily log(OR)=0.
heatmap(data_long["var_id"],data_long["variable"],data_long["value"],np.log(data_long["color"]))
plt.show()

___
# Next Steps

### Test different SNP selection approaches
* African, European (done), or Meta selected SNPs

___