In [None]:
import hail as hl
import pyspark
from hail.plot import show
import bokeh
from bokeh.plotting import output_file, save
import umap
import pandas
import seaborn
from matplotlib import pyplot as plt
import os
from gnomad.sample_qc.ancestry import pc_project
from gnomad.sample_qc.ancestry import assign_population_pcs
import hdbscan

#initiate hail
# Stop any previous Hail session first (the original already called hl.stop() before
# the first hl.init, so this is safe on a fresh kernel).
hl.stop()
# getOrCreate() reuses a live SparkContext instead of raising "Cannot run multiple
# SparkContexts at once" when this cell is executed a second time.
sc = pyspark.SparkContext.getOrCreate()

############################
# specify temporary folder #
############################
tmp_dir = "file:///lustre/scratch126/teams/hgi/users/vo3/tmp/"

hl.init(sc=sc, tmp_dir=tmp_dir, default_reference="GRCh38")
hl.plot.output_notebook()
tmp_file = "/lustre/scratch123/projects/gnh_industry/Genes_and_Health_2024_03_54k/qc/plots/tmp.html"#temp file for plots

In [None]:
def run_pca(mt: hl.MatrixTable, n: int):
    """Run HWE-normalized PCA on ``mt.GT`` and attach per-variant allele frequency to the loadings.

    ``pc_projection`` later calls ``pc_project(..., af_location='pca_af')``, so the
    loadings table must carry a ``pca_af`` field.

    Args:
        mt: matrix table with a ``GT`` entry field.
        n: number of principal components to compute.

    Returns:
        (eigenvalues, scores table, loadings table annotated with ``pca_af``).
    """
    eigenvalues, scores_ht, loadings_ht = hl.hwe_normalized_pca(mt.GT, k=n, compute_loadings=True)

    # Mean alternate-allele count per variant divided by 2 (assumes diploid calls).
    af_ht = mt.annotate_rows(pca_af=hl.agg.mean(mt.GT.n_alt_alleles()) / 2).rows()
    loadings_with_af = loadings_ht.annotate(pca_af=af_ht[loadings_ht.key].pca_af)

    return eigenvalues, scores_ht, loadings_with_af

def pc_projection(mt: hl.MatrixTable, pca_scores: hl.Table, pca_loadings: hl.Table):
    """Project the samples of ``mt`` onto existing PCs and append them to ``pca_scores``.

    Args:
        mt: samples to project (e.g. related samples left out of the PCA fit).
        pca_scores: scores table of the samples used to fit the PCA.
        pca_loadings: loadings table with ``loadings`` and ``pca_af`` fields
            (as produced by ``run_pca``).

    Returns:
        A table with the rows of ``pca_scores`` followed by the projected samples.
    """
    projected_scores = pc_project(
        mt,
        pca_loadings,
        loading_location='loadings',
        af_location='pca_af',
    )
    return pca_scores.union(projected_scores)

def UMAP_run(pca: hl.Table, df: pandas.DataFrame, n: int, random_state=None):
    """Embed the first ``n`` PC scores with UMAP and annotate each sample from ``df``.

    Args:
        pca: Hail table with a sample field ``s`` and an array field ``scores``.
        df: per-sample annotations, joined on its ``Sample_ID`` column.
        n: number of leading PCs fed to UMAP.
        random_state: passed to ``umap.UMAP``; the default ``None`` keeps the
            previous (unseeded, non-reproducible) behaviour.

    Returns:
        DataFrame with the UMAP coordinates in integer-named columns (0, 1, ...),
        the sample ID in ``s`` and the columns of ``df`` left-joined on ``Sample_ID``.
    """
    # Materialise the Hail table once: scores and sample IDs then come from the
    # same rows, and the (possibly lazy, expensive) pipeline is not evaluated twice.
    pca_df = pca.to_pandas()
    scores = pandas.DataFrame.from_records(pca_df['scores'])

    embedding = umap.UMAP(random_state=random_state).fit_transform(scores[range(n)])

    sns_data = pandas.DataFrame(embedding)
    # tolist() yields plain Python strings, exactly like Table.s.collect() did.
    sns_data['s'] = pca_df['s'].tolist()
    return pandas.merge(sns_data, df, left_on='s', right_on='Sample_ID', how='left')

def filter_mt(pcamt: hl.MatrixTable, min_call_rate: float = 0.99,
              min_p_hwe: float = 1e-5, min_maf: float = 0.05):
    """Build three progressively stricter variant sets for PCA.

    Args:
        pcamt: input matrix table (rows keyed by locus/alleles).
        min_call_rate: minimum variant call rate for sets 2 and 3.
        min_p_hwe: minimum Hardy-Weinberg p-value for sets 2 and 3.
        min_maf: minimum minor allele frequency for set 3.

    Returns:
        (filtered1, filtered2, filtered3):
          1. palindromic (A/T, T/A, C/G, G/C) variants removed, no other filters;
          2. set 1 plus call-rate and HWE filters (row field ``variant_QC_Hail`` added);
          3. set 2 plus a minor-allele-frequency filter.
    """
    #1 Remove palindromes, no other filters
    # hl.is_strand_ambiguous covers exactly the four allele pairs the original tested.
    pcamt_filtered1 = pcamt.filter_rows(
        hl.is_strand_ambiguous(pcamt.alleles[0], pcamt.alleles[1]), keep=False
    )

    #2 Remove palindromes and filter without MAF
    mt_vqc = hl.variant_qc(pcamt_filtered1, name='variant_QC_Hail')
    vqc = mt_vqc.variant_QC_Hail
    pcamt_filtered2 = mt_vqc.filter_rows(
        (vqc.call_rate >= min_call_rate) &
        (vqc.p_value_hwe >= min_p_hwe)
    )

    #3 Remove palindromes and filter with MAF
    # Use the *minor* allele frequency. The previous AF[1] test kept variants whose
    # alt allele is nearly fixed (alt AF close to 1), which have a low MAF.
    pcamt_filtered3 = pcamt_filtered2.filter_rows(
        hl.min(pcamt_filtered2.variant_QC_Hail.AF) >= min_maf
    )

    return pcamt_filtered1, pcamt_filtered2, pcamt_filtered3


In [None]:
#folders with hail matrices and tables
mtdir="file:///lustre/scratch123/projects/gnh_industry/Genes_and_Health_2024_03_54k/qc/matrixtables/"
# Same folder as a plain path, for Python's open(), which does not understand file:// URIs.
mtdir2=mtdir.replace('file://', '')

#files
pruned_mt_file=mtdir+"mt_ldpruned.mt"#result of 2-sample_qc/2-prune_related_samples.py
samples_to_remove_file=mtdir+"mt_related_samples_to_remove.ht"#result of 2-sample_qc/2-prune_related_samples.py

#files to save information about PCA of unrelated samples
pca_scores_file_UR = mtdir+"pca_scores.unrelated.ht"
pca_loadings_file_UR = mtdir+"pca_loadings.unrelated.ht"
pca_evals_file_UR = mtdir2+"pca_evals.unrelated.txt"

#file to save information about union of PCA of unrelated samples and projected related samples
union_pca_scores_file= mtdir+"pca_scores.union.ht"

#matrix with 1000G samples (mtdir already ends with '/')
kg_mt_file = mtdir + "kg_wes_regions.mt"

#file with information about samples from 1000 Genomes
pops_file = "file:///lustre/scratch123/projects/gnh_industry/Genes_and_Health_2024_03_54k/qc/igsr_samples.tsv"

In [None]:
# Load the LD-pruned matrix without the 'callrate', 'f_stat' and 'is_female' fields
# (not needed here), plus the table of related samples to exclude from the PCA fit.
pruned_mt = hl.read_matrix_table(pruned_mt_file).drop('callrate', 'f_stat', 'is_female')
related_samples_to_remove = hl.read_table(samples_to_remove_file)

# WARNING!

In [None]:
###########################################
#  Back up the population assignments     #
#  before this notebook overwrites them   #
###########################################
pop_file=mtdir+"pop_assignments.ht"#result of 2-sample_qc/3-population_pca_prediction.py
super_pop_file=mtdir+"pop_assignments.superpop.ht"#table with assigned superpopulation

# The last cell of this notebook rewrites pop_file with the new SAS sub-population
# labels. On the first run, the original (superpopulation) table is copied to
# super_pop_file; every later run reads that copy. This replaces the old manual
# "comment / uncomment" step, and an existing backup is never overwritten.
if hl.hadoop_exists(super_pop_file):
    pop_ht = hl.read_table(super_pop_file)
else:
    # checkpoint writes the backup and returns a table read back from it, so pop_ht
    # no longer depends on pop_file, which is overwritten at the end.
    pop_ht = hl.read_table(pop_file).checkpoint(super_pop_file)

In [None]:
#Remove samples not assigned to the SAS superpopulation.
# Only rows without a known_pop are candidates, so samples that carry a known
# population label are never removed here.
is_non_sas = (pop_ht.pop != 'SAS') & (hl.is_missing(pop_ht.known_pop))
non_sas_to_remove = pop_ht.filter(is_non_sas)
pca_mt = pruned_mt.filter_cols(
    hl.is_defined(non_sas_to_remove[pruned_mt.col_key]),
    keep=False,
)

In [None]:
###############################################
# optional step to add SAS samples from 1000G #
###############################################

# 1000 Genomes sample metadata keyed by sample name.
cohorts_pop = hl.import_table(pops_file, delimiter="\t").key_by('Sample name')

# Keep only GT entries. union_cols needs identical entry/column schemas
# (presumably pca_mt carries only GT — confirm).
kg_mt = hl.read_matrix_table(kg_mt_file)
kg_mt = kg_mt.select_entries(kg_mt.GT)
kg_mt = kg_mt.annotate_cols(known_pop=cohorts_pop[kg_mt.s]['Superpopulation code'])

# 1000G SAS samples only; the helper field is dropped again before the union.
kg_sas_mt = kg_mt.filter_cols(kg_mt.known_pop == 'SAS').drop('known_pop')

# Append the 1000G SAS samples and re-derive known_pop, which is defined only for
# samples listed in the 1000G metadata.
mt_joined = pca_mt.union_cols(kg_sas_mt)
mt_joined = mt_joined.annotate_cols(known_pop=cohorts_pop[mt_joined.s]['Superpopulation code'])


In [None]:
#apply the three variant filters (see filter_mt) to the cohort-only matrix
(pca_mt_filtered1,
 pca_mt_filtered2,
 pca_mt_filtered3) = filter_mt(pcamt=pca_mt)

In [None]:
###################################################
# optional apply 3 filters to the matrix with SAS #
###################################################
(pca_mt_kg_filtered1,
 pca_mt_kg_filtered2,
 pca_mt_kg_filtered3) = filter_mt(pcamt=mt_joined)

In [None]:
#separate unrelated (UR, used to fit the PCA) and related (R, projected later) samples
def _split_by_relatedness(mt):
    """Return (unrelated, related) column subsets of mt according to related_samples_to_remove."""
    is_related = hl.is_defined(related_samples_to_remove[mt.col_key])
    return mt.filter_cols(is_related, keep=False), mt.filter_cols(is_related)

pca_mt1_UR, pca_mt1_R = _split_by_relatedness(pca_mt_filtered1)
pca_mt2_UR, pca_mt2_R = _split_by_relatedness(pca_mt_filtered2)
pca_mt3_UR, pca_mt3_R = _split_by_relatedness(pca_mt_filtered3)


In [None]:
############################################################
# optional separation of ELGH sampls and 1000G SAS samples #
############################################################
def _split_kg_matrix(mt):
    """Split a joined matrix into (unrelated cohort, related, unrelated known_pop) subsets.

    Columns with a defined known_pop are the samples found in the 1000G metadata.
    """
    is_related = hl.is_defined(related_samples_to_remove[mt.col_key])
    unrelated = mt.filter_cols(is_related, keep=False)
    with_known_pop = unrelated.filter_cols(hl.is_defined(unrelated.known_pop))
    cohort_only = unrelated.filter_cols(hl.is_missing(unrelated.known_pop))
    return cohort_only, mt.filter_cols(is_related), with_known_pop

pca_mt1_kg_UR, pca_mt1_kg_R, pca_mt1_sas = _split_kg_matrix(pca_mt_kg_filtered1)
pca_mt2_kg_UR, pca_mt2_kg_R, pca_mt2_sas = _split_kg_matrix(pca_mt_kg_filtered2)
pca_mt3_kg_UR, pca_mt3_kg_R, pca_mt3_sas = _split_kg_matrix(pca_mt_kg_filtered3)

In [None]:
#Run PCA (10 PCs) on the unrelated samples of each filter set
print("PCA1")
pca_evals1, pca_scores1, pca_loadings1 = run_pca(mt=pca_mt1_UR, n=10)

print("PCA2")
pca_evals2, pca_scores2, pca_loadings2 = run_pca(mt=pca_mt2_UR, n=10)

print("PCA3")
pca_evals3, pca_scores3, pca_loadings3 = run_pca(mt=pca_mt3_UR, n=10)

In [None]:
# save tmp PCA files: scores/loadings as Hail tables, eigenvalues one per line
# open() needs a plain path, so strip the file:// scheme from the Hail tmp_dir URI.
tmp_dir2 = tmp_dir.replace('file://', '')

tmp_pca_outputs = [
    (pca_scores1, pca_loadings1, pca_evals1, "pca_scores1.ht", "pca_loadings1.ht", "pca_evals1.txt"),
    (pca_scores2, pca_loadings2, pca_evals2, "pca_scores2.ht", "pca_loadings2.ht", "pca_evals2.txt"),
    (pca_scores3, pca_loadings3, pca_evals3, "pca_scores3.ht", "pca_loadings3.ht", "pca_evals3.txt"),
]
for scores_ht, loadings_ht, evals, scores_name, loadings_name, evals_name in tmp_pca_outputs:
    pca_scores_fileX = tmp_dir + scores_name
    pca_loadings_fileX = tmp_dir + loadings_name
    pca_evals_fileX = tmp_dir2 + evals_name

    scores_ht.write(pca_scores_fileX, overwrite=True)
    loadings_ht.write(pca_loadings_fileX, overwrite=True)
    with open(pca_evals_fileX, 'w') as evals_fh:
        evals_fh.writelines(str(val) + "\n" for val in evals)


In [None]:
############################################################################
# optional PCA for ELGH samples which were combined with 1000G SAS samples #
############################################################################
# 10 PCs on the unrelated cohort samples of each joined filter set
print("PCA1")
pca_kg_evals1, pca_kg_scores1, pca_kg_loadings1 = run_pca(mt=pca_mt1_kg_UR, n=10)

print("PCA2")
pca_kg_evals2, pca_kg_scores2, pca_kg_loadings2 = run_pca(mt=pca_mt2_kg_UR, n=10)

print("PCA3")
pca_kg_evals3, pca_kg_scores3, pca_kg_loadings3 = run_pca(mt=pca_mt3_kg_UR, n=10)

In [None]:
###########################################################################################
# optional save tmp PCA files for ELGH samples which were combined with 1000G SAS samples #
###########################################################################################
# tmp_dir2 (plain-path tmp_dir) is defined in the non-1000G save cell above.
tmp_pca_kg_outputs = [
    (pca_kg_scores1, pca_kg_loadings1, pca_kg_evals1, "pca_scores1_1kg.ht", "pca_loadings1_1kg.ht", "pca_evals1_1kg.txt"),
    (pca_kg_scores2, pca_kg_loadings2, pca_kg_evals2, "pca_scores2_1kg.ht", "pca_loadings2_1kg.ht", "pca_evals2_1kg.txt"),
    (pca_kg_scores3, pca_kg_loadings3, pca_kg_evals3, "pca_scores3_1kg.ht", "pca_loadings3_1kg.ht", "pca_evals3_1kg.txt"),
]
for scores_ht, loadings_ht, evals, scores_name, loadings_name, evals_name in tmp_pca_kg_outputs:
    pca_scores_fileX = tmp_dir + scores_name
    pca_loadings_fileX = tmp_dir + loadings_name
    pca_evals_fileX = tmp_dir2 + evals_name

    scores_ht.write(pca_scores_fileX, overwrite=True)
    loadings_ht.write(pca_loadings_fileX, overwrite=True)
    with open(pca_evals_fileX, 'w') as evals_fh:
        evals_fh.writelines(str(val) + "\n" for val in evals)

In [None]:
#make data frames for plots: PC1-PC5 plus sample ID, one frame per filter set.
# Each Hail table is converted once, so 's' comes from exactly the same rows as the
# scores (a separate .collect() evaluated the table a second time).
pca_scores1_df = pca_scores1.to_pandas()
sns_data1 = pandas.DataFrame.from_records(pca_scores1_df.scores)[range(5)]
sns_data1['s'] = pca_scores1_df.s.tolist()

pca_scores2_df = pca_scores2.to_pandas()
sns_data2 = pandas.DataFrame.from_records(pca_scores2_df.scores)[range(5)]
sns_data2['s'] = pca_scores2_df.s.tolist()

pca_scores3_df = pca_scores3.to_pandas()
sns_data3 = pandas.DataFrame.from_records(pca_scores3_df.scores)[range(5)]
sns_data3['s'] = pca_scores3_df.s.tolist()

In [None]:
############################################################
# optional for ELGH sampls and 1000G SAS samples       #
############################################################
# Convert each Hail table once so 's' is aligned with the score rows.
pca_kg_scores1_df = pca_kg_scores1.to_pandas()
sns_data1_kg = pandas.DataFrame.from_records(pca_kg_scores1_df.scores)[range(5)]
sns_data1_kg['s'] = pca_kg_scores1_df.s.tolist()

pca_kg_scores2_df = pca_kg_scores2.to_pandas()
sns_data2_kg = pandas.DataFrame.from_records(pca_kg_scores2_df.scores)[range(5)]
sns_data2_kg['s'] = pca_kg_scores2_df.s.tolist()

pca_kg_scores3_df = pca_kg_scores3.to_pandas()
sns_data3_kg = pandas.DataFrame.from_records(pca_kg_scores3_df.scores)[range(5)]
sns_data3_kg['s'] = pca_kg_scores3_df.s.tolist()

In [None]:
# Self-reported ethnicity per sample; joined later on 'Sample_ID' and plotted via 'SR_Ethnicity'.
# TODO: placeholder path — point ethnic_file at the real TSV before running.
ethnic_file = '/path/to/file/with_SR_wthnicity.tsv'
ethnic_df = pandas.read_csv(ethnic_file, delimiter='\t')

In [None]:
#annotate the PCA frames with self-reported ethnicity (left join keeps every PCA sample)
# Not idempotent: a second run adds suffixed duplicate columns; rebuild sns_data1-3 first.
sns_data1 = sns_data1.merge(ethnic_df, how='left', left_on='s', right_on='Sample_ID')
sns_data2 = sns_data2.merge(ethnic_df, how='left', left_on='s', right_on='Sample_ID')
sns_data3 = sns_data3.merge(ethnic_df, how='left', left_on='s', right_on='Sample_ID')

In [None]:
##################################################
# optional for ELGH sampls and 1000G SAS samples #
##################################################
# Same left join as above; also not idempotent on re-run.
sns_data1_kg = sns_data1_kg.merge(ethnic_df, how='left', left_on='s', right_on='Sample_ID')
sns_data2_kg = sns_data2_kg.merge(ethnic_df, how='left', left_on='s', right_on='Sample_ID')
sns_data3_kg = sns_data3_kg.merge(ethnic_df, how='left', left_on='s', right_on='Sample_ID')

In [None]:
#create plots

In [None]:
# Filter set 1: pairwise scatter of PC1-PC5 coloured by self-reported ethnicity.
p = seaborn.pairplot(data=sns_data1, vars=sns_data1.columns[:5], hue='SR_Ethnicity', plot_kws=dict(s=2))
plt.show()

In [None]:
# Filter set 2: pairwise scatter of PC1-PC5 coloured by self-reported ethnicity.
p = seaborn.pairplot(data=sns_data2, vars=sns_data2.columns[:5], hue='SR_Ethnicity', plot_kws=dict(s=2))
plt.show()

In [None]:
# Filter set 3: pairwise scatter of PC1-PC5 coloured by self-reported ethnicity.
p = seaborn.pairplot(data=sns_data3, vars=sns_data3.columns[:5], hue='SR_Ethnicity', plot_kws=dict(s=2))
plt.show()

In [None]:
#optional: PC1-PC5 for filter set 1 of the matrix joined with 1000G SAS
# vars are taken from sns_data1_kg itself (previously borrowed from sns_data1).
p = seaborn.pairplot(sns_data1_kg, vars=sns_data1_kg.columns[range(5)], hue='SR_Ethnicity',  plot_kws={'s':2})
plt.show()

In [None]:
#optional: PC1-PC5 for filter set 2 of the matrix joined with 1000G SAS
# vars are taken from sns_data2_kg itself (previously borrowed from sns_data1).
p = seaborn.pairplot(sns_data2_kg, vars=sns_data2_kg.columns[range(5)], hue='SR_Ethnicity',  plot_kws={'s':2})
plt.show()

In [None]:
#optional: PC1-PC5 for filter set 3 of the matrix joined with 1000G SAS
# vars are taken from sns_data3_kg itself (previously borrowed from sns_data1).
p = seaborn.pairplot(sns_data3_kg, vars=sns_data3_kg.columns[range(5)], hue='SR_Ethnicity',  plot_kws={'s':2})
plt.show()

In [None]:
#choose preferred filtration: bind the generic *_UR / *_R names to one filter set (here: 1)
pca_evals_UR, pca_scores_UR, pca_loadings_UR = pca_evals1, pca_scores1, pca_loadings1
pca_mt_R = pca_mt1_R

# NOTE(review): these replace the output paths defined in the paths cell with "_test2"
# names, so pca_*.unrelated.* and pca_scores.union.ht are not written by this notebook.
# Confirm this is intended.
pca_scores_file_UR = mtdir + "pca_scores.unrelated_test2.ht"
pca_loadings_file_UR = mtdir + "pca_loadings.unrelated_test2.ht"
pca_evals_file_UR = mtdir2 + "pca_evals.unrelated_test2.txt"
union_pca_scores_file = mtdir + "pca_scores.union_test2.ht"

pca_scores_UR.write(pca_scores_file_UR, overwrite=True)
pca_loadings_UR.write(pca_loadings_file_UR, overwrite=True)
with open(pca_evals_file_UR, 'w') as evals_fh:
    evals_fh.writelines(str(val) + "\n" for val in pca_evals_UR)


In [None]:
##################################################################
# optional selection of SAS matrix according to filter selection #
##################################################################
# Use the same filter set as chosen above (set 1 -> pca_mt1_sas).
SAS_mt = pca_mt1_sas

In [None]:
#project related samples onto the PCs of the unrelated samples and append them
union_PCA_scores = pc_projection(
    mt=pca_mt_R,
    pca_scores=pca_scores_UR,
    pca_loadings=pca_loadings_UR,
)

In [None]:
##########################################
# optional projection and union with SAS #
##########################################
# Rebuild the unrelated+related union here instead of extending union_PCA_scores in
# place, so re-running this cell does not append the SAS samples a second time.
union_PCA_scores = pc_projection(
    SAS_mt,
    pc_projection(pca_mt_R, pca_scores_UR, pca_loadings_UR),
    pca_loadings_UR,
)

In [None]:
#save the combined PCA scores (unrelated fit + projected samples)
union_PCA_scores.write(output=union_pca_scores_file, overwrite=True)

In [None]:
#create a dataframe for a plot: PC1-PC5, sample ID and self-reported ethnicity
# Convert once. union_PCA_scores is a lazy pipeline (includes pc_project), so calling
# to_pandas() and then s.collect() computed the projection twice.
union_scores_df = union_PCA_scores.to_pandas()
sns_data = pandas.DataFrame.from_records(union_scores_df.scores)[range(5)]
sns_data['s'] = union_scores_df.s.tolist()
sns_data = pandas.merge(sns_data, ethnic_df, left_on='s', right_on='Sample_ID', how='left')

In [None]:
######################################
# optional annotation if added 1000G #
######################################
# Samples without a self-reported ethnicity are labelled 'SAS'. This is meant for the
# 1000G SAS samples. NOTE(review): any cohort sample missing SR_Ethnicity is relabelled too.
sns_data['SR_Ethnicity'] = sns_data.SR_Ethnicity.fillna(value='SAS')

In [None]:
#make a plot: PC1-PC5 of all samples in union_PCA_scores, coloured by self-reported ethnicity
p = seaborn.pairplot(data=sns_data, vars=sns_data.columns[:5], hue='SR_Ethnicity', plot_kws=dict(s=2))
#plt.savefig("PCA_scores_pairplot.Self_reported_ethnicity.png")
plt.show()

In [None]:
###################################################
# optional get population names from 1000 genomes #
###################################################
# Plain-path copy of pops_file (pandas does not take the file:// form).
pops_file2 = "/lustre/scratch123/projects/gnh_industry/Genes_and_Health_2024_03_54k/qc/igsr_samples.tsv"
kgpop = (
    pandas.read_csv(pops_file2, sep='\t')
    .loc[:, ["Sample name", "Population name"]]
    .rename(columns={'Sample name': 's', 'Population name': 'Known_pop'})
)

In [None]:
################################################################
# optional annotate df with population names from 1000 genomes #
################################################################
# Samples not in the 1000G table are labelled 'ELGH'. Not idempotent on re-run.
sns_data = sns_data.merge(kgpop, how='left', on='s')
sns_data['Known_pop'] = sns_data.Known_pop.fillna(value='ELGH')

In [None]:
###################################################
# optional plot with names from 1000 genomes #
###################################################
p = seaborn.pairplot(data=sns_data, vars=sns_data.columns[:5], hue='Known_pop', plot_kws=dict(s=2))
#plt.savefig("PCA_scores_pairplot.Self_reported_ethnicity.png")
plt.show()

In [None]:
'''
################################################
# The next steps can be repeated several times #
#         with different number of PCs         #
################################################
'''

In [None]:
#make dataframe to divide samples into populations
#n = number of leading PCs used as UMAP input
sns_data_UM = UMAP_run(pca=union_PCA_scores, df=ethnic_df, n=3)

In [None]:
#######################################
# optional annotation of the datframe #
#######################################
# Same 1000G population join as for sns_data; non-1000G samples become 'ELGH'.
sns_data_UM = sns_data_UM.merge(kgpop, how='left', on='s')
sns_data_UM['Known_pop'] = sns_data_UM.Known_pop.fillna(value='ELGH')

In [None]:
#umap plot (UMAP columns 0 and 1) with self-reported ethnicity
p = seaborn.scatterplot(x=0, y=1, hue='SR_Ethnicity', data=sns_data_UM)
#plt.savefig("PCA_UMAP.Self_reported_ethnicity.png")
plt.show()

In [None]:
##################################################
# optional plot with ethnicity from 1000 genomes #
##################################################
p = seaborn.scatterplot(x=0, y=1, hue='Known_pop', data=sns_data_UM)
#plt.savefig("PCA_UMAP.ethnicity_from_1K_genomes.png")
plt.show()

In [None]:
#identify clusters with HDBSCAN on the UMAP coordinates (columns 0 and 1)
clusterer = hdbscan.HDBSCAN(min_samples=1000, min_cluster_size=1000)
sns_data_cluster = sns_data_UM.loc[:, [0, 1]]
clusterer.fit(sns_data_cluster)
sns_data_UM['label'] = clusterer.labels_

In [None]:
#umap plot coloured by HDBSCAN cluster label
p = seaborn.scatterplot(x=0, y=1, hue='label', data=sns_data_UM)
#plt.savefig("PCA_UMAP.nPCs.OG_clusters.png")
plt.show()

In [None]:
#Modify clusters: hand-assign labels from UMAP coordinates
# (renamed below: 0 -> pakistani, 1 -> bangladeshi, -1 -> other-sas).
# NOTE(review): the first two rules give every sample a label, so the HDBSCAN labels
# copied from sns_data_UM are fully overwritten. The thresholds only fit this particular
# embedding; UMAP_run does not fix a random seed by default, so re-check them after
# re-running UMAP. Later rules override earlier ones.
sns_data_UM2 = sns_data_UM.copy()
umap_x = sns_data_UM2[0]
umap_y = sns_data_UM2[1]
sns_data_UM2.loc[umap_x < 4, 'label'] = 0
sns_data_UM2.loc[umap_x >= 4, 'label'] = 1
sns_data_UM2.loc[(umap_x > 2.5) & (umap_y < -4.3), 'label'] = 1
sns_data_UM2.loc[(umap_x > 2.3) & (umap_y > 15), 'label'] = 1  # was duplicated; applying it once is equivalent
sns_data_UM2.loc[(umap_x < 0) & (umap_y < -5), 'label'] = -1
sns_data_UM2.loc[(umap_x < 2.5) & (umap_y < 17) & (umap_y > 13), 'label'] = -1
sns_data_UM2.loc[(umap_x < 1) & (umap_y < 14) & (umap_y > 10), 'label'] = 0

p = seaborn.scatterplot(data=sns_data_UM2,
                       x=0, y=1, hue='label')
#plt.savefig("UMAP.no_palindrome.3PCs.clusters.png")
plt.show()

In [None]:
#Rename clusters: numeric labels (as strings) -> population names
cluster_names = {'1': 'bangladeshi', '0': 'pakistani', '-1': 'other-sas'}
sns_data_UM2['label'] = sns_data_UM2['label'].astype(str).replace(cluster_names)

In [None]:
#umap plot coloured by the renamed population labels
p = seaborn.scatterplot(x=0, y=1, hue='label', data=sns_data_UM2)
#plt.savefig("UMAP.no_palindrome.3PCs.clusters2.png")
plt.show()

In [None]:
#combine SAS cluster labels with the non-SAS samples removed earlier
# Build a new frame instead of overwriting sns_data_UM2: the in-place drop made a
# second run of this cell fail with KeyError (columns already gone).
sas_labels_df = sns_data_UM2.drop(columns=[0, 1, "Sample_ID", "SR_Ethnicity"])
non_sas_df = non_sas_to_remove.to_pandas()
non_sas = non_sas_df.drop(columns=["known_pop", "pca_scores", "prob_AFR", "prob_AMR", "prob_EAS", "prob_EUR", "prob_SAS", "evaluation_sample", "training_sample"])
non_sas['pop'] = 'non-sas'
# ignore_index gives a fresh RangeIndex (same as the former reset_index + drop 'index').
pop_df = pandas.concat([sas_labels_df, non_sas], ignore_index=True)
# SAS rows keep their cluster label; non-SAS rows take 'non-sas' from the pop column.
pop_df['label'] = pop_df['label'].fillna(pop_df['pop'])
pop_df = pop_df.drop(columns=['pop'])

In [None]:
#create a dataframe with information from pop_ht and the new population annotation
popht_df = pop_ht.to_pandas().merge(pop_df, on='s', how='left')
# Samples with a new label take it as 'pop'; all others keep their original pop.
popht_df['pop'] = popht_df['label'].fillna(popht_df['pop'])
popht_df = popht_df.drop(columns=['label'])

#WARNING!

In [None]:
###########################################
#                Be careful!              #
#            You rewrite a table          #
#   that was created by previous script.  #
#   The backup pop_assignments.superpop.ht#
#   must exist before running this cell.  #
###########################################
assert hl.hadoop_exists(super_pop_file), "backup " + super_pop_file + " missing; refusing to overwrite pop_assignments.ht"
# One row per sample; duplicates would mean the merge with pop_df multiplied samples.
assert popht_df['s'].is_unique, "duplicate sample IDs in popht_df"

#save new population information in hail table
# key='s' restores the sample key: Table.from_pandas otherwise returns an unkeyed table,
# and keyed lookups such as pop_ht[mt.col_key] (used earlier in this notebook) would fail
# when the rewritten file is read back.
ht_pop = hl.Table.from_pandas(popht_df, key='s')
pop_file=mtdir+"pop_assignments.ht"
ht_pop.write(pop_file, overwrite=True)