# Polygenic risk score (PRS) calculation from whole-genome sequencing (WGS) data

In [None]:
#set up Hail requirements on instance
import os
os.system('wget https://github.com/adoptium/temurin21-binaries/releases/download/jdk-21.0.4%2B7/OpenJDK21U-jre_x64_linux_hotspot_21.0.4_7.tar.gz')
!tar -xvzf OpenJDK21U-jre_x64_linux_hotspot_21.0.4_7.tar.gz
!pip install pyspark
# Set the JAVA_HOME environment variable
os.environ['JAVA_HOME'] = '/opt/notebooks/jdk-21.0.4+7-jre' 
os.environ['PATH'] = f"{os.environ['JAVA_HOME']}/bin:" + os.environ['PATH']

In [None]:
#set up HAIL
!pip install hail --force-reinstall
import hail as hl

hl.init(
    master='local[*]',  # Use all available cores
    spark_conf={
        'spark.executor.memory': '4g',  # Adjust memory as needed
        'spark.driver.memory': '4g'      # Adjust memory as needed
    }
)
hl.default_reference('GRCh38')

In [None]:
# Download the WGS block .vcf.gz file(s) that contain the model SNPs into the instance.
# NOTE(review): '<file_name.vcf.gz>' is a placeholder - substitute the real DNAnexus
# file path(s); the import loop reads files named like <file_name_c{chromosome}_b*_v1.vcf.gz>.
import os
dx_status = os.system('dx download <file_name.vcf.gz>')
if dx_status != 0:
    # os.system does not raise on failure; stop here instead of letting every
    # chromosome fail later with a less obvious Hail import error.
    raise RuntimeError(f'dx download failed (exit status {dx_status})')

In [None]:
#Download PRSedm into instance
import os
os.system('wget https://files.pythonhosted.org/packages/6e/d9/ebd00d933502674a1072f226bc429e5092ab365941262def01f4cffdbb44/prsedm-1.0.0-py3-none-any.whl')
!pip install prsedm
import prsedm

In [None]:
import cryptography
import OpenSSL
from datetime import datetime


# Fetch the model's SNP table once: it does not depend on the chromosome, so there
# is no need to re-query and re-deduplicate it inside the loop.
# NOTE(review): SNPs are extracted here for 't1dgrs2-sharp24', while the scoring cell
# calls prsedm.gen_dm with 't1dgrs2-luckett25' - confirm both models share the same
# SNP set, otherwise scored SNPs may be missing from the merged VCF.
snp_df = prsedm.get_snp_db('t1dgrs2-sharp24').drop_duplicates().reset_index(drop=True)
pos_col = 'position_hg38'  # change to hg19 if required
chunk_size = 1000  # max intervals passed to a single hl.filter_intervals call

for chromosome in range(1, 23):  # autosomes 1-22; '*' in the path matches every block number
    # Model SNPs on this chromosome that have a position in the chosen build.
    # contig_id is compared as text because it is spliced into 'chr{contig_id}' strings.
    on_chrom = snp_df['contig_id'].astype(str) == str(chromosome)
    positions = snp_df.loc[on_chrom, pos_col].dropna()
    # One single-base interval [pos, pos + 1) per distinct position. De-duplicating stops
    # a site listed twice in the model from landing in two chunks and being duplicated by
    # union_rows; int() keeps float-typed columns from rendering as '123.0'.
    unique_positions = list(dict.fromkeys(int(pos) for pos in positions))
    variantIntervals = [f"chr{chromosome}:{pos}-{pos + 1}" for pos in unique_positions]
    if not variantIntervals:
        print(f"No model SNPs on chromosome {chromosome}; skipping.")
        continue
    print(f"Number of unique positions to extract on chromosome {chromosome}: {len(variantIntervals)}")

    vcf_file = f'<file_name_c{chromosome}_b*_v1.vcf.gz>'
    print(f"Importing {vcf_file} into Hail...")
    try:
        # Import the chromosome's block VCF(s) into Hail
        mt = hl.import_vcf(vcf_file, reference_genome='GRCh38', force_bgz=True)
        print(f"{vcf_file} imported successfully.")

        # Filter in chunks in case a very large number of intervals is requested,
        # splitting multi-allelic sites in each filtered subset.
        print(f"Retrieve chunks from chromosome {chromosome} VCF and densify...")
        mt_subsets = []
        for start_idx in range(0, len(variantIntervals), chunk_size):
            chunk = variantIntervals[start_idx:start_idx + chunk_size]
            print(f"Processing chunk {start_idx // chunk_size + 1} for chromosome {chromosome}...")
            vcf_filtered = hl.filter_intervals(mt, [hl.parse_locus_interval(x) for x in chunk])
            mt_subsets.append(hl.split_multi_hts(vcf_filtered))

        # Combine chunks (never empty: variantIntervals is non-empty here)
        print(f"Combining retrieved chunks for chromosome {chromosome}...")
        combined_mt = mt_subsets[0]
        for mt_n in mt_subsets[1:]:
            combined_mt = combined_mt.union_rows(mt_n)

        # Recompute allele frequencies from the genotypes. variant_qc.AF has one entry per
        # allele including the reference, whereas VCF INFO/AF is per ALT allele, so drop
        # element 0. INFO is replaced wholesale so only AF is exported.
        print(f"Processing merged MT for chromosome {chromosome}...")
        combined_mt = hl.variant_qc(combined_mt)
        combined_mt = combined_mt.annotate_rows(info=hl.struct(AF=combined_mt.variant_qc.AF[1:]))

        # Export to a block-gzipped VCF
        print(f"Exporting chromosome {chromosome} to VCF...")
        start = datetime.now()
        output_vcf = f'./chr{chromosome}_data_temp.vcf.bgz'
        hl.export_vcf(combined_mt, output_vcf)
        print(f"Export for chromosome {chromosome} took {(datetime.now() - start).total_seconds():.2f} seconds")

    except Exception as e:
        # Best effort per chromosome: report the failure and continue with the next one.
        print(f"Error importing or processing chromosome {chromosome} VCF: {e}")

In [None]:
# Merge the per-chromosome VCFs exported by the Hail loop into one bgzipped VCF.
# Inputs are discovered from disk instead of being listed by hand, so the merge picks
# up exactly the chromosomes that produced output, including chromosomes exported in
# parts (e.g. chr12_data_temp_1.vcf.bgz ... chr12_data_temp_4.vcf.bgz).
# NOTE(review): do not keep both chrN_data_temp.vcf.bgz and chrN_data_temp_<k>.vcf.bgz
# for the same chromosome, or its records will be concatenated twice.
import glob
import re
import subprocess


def _chrom_sort_key(path):
    """Natural-order key: chr2 sorts before chr10, and part _2 before part _10."""
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', path)]


chrom_vcfs = sorted(glob.glob('chr*_data_temp*.vcf.bgz'), key=_chrom_sort_key)
if not chrom_vcfs:
    raise FileNotFoundError('No chr*_data_temp*.vcf.bgz files found to merge')
print(f"Merging {len(chrom_vcfs)} files: {', '.join(chrom_vcfs)}")
# -n: naive concatenation without recompression; -Oz: bgzipped VCF output
subprocess.run(
    ['bcftools', 'concat', '-n', '-Oz', *chrom_vcfs, '-o', 'merged_wgs_data.vcf.gz'],
    check=True,
)

In [None]:
#Index file
!tabix -fp vcf merged_wgs_data.vcf.gz

In [None]:
#Download reference TOPMED data using command for individual account and then 
#Index files
!for f in reference.vcf.gz;do tabix -f $f;done

In [None]:
# Generate the genetic risk score (GRS) from the merged, indexed WGS VCF.
# NOTE(review): SNPs were extracted for 't1dgrs2-sharp24' but this scores
# 't1dgrs2-luckett25' - confirm the intended model and that all its SNPs are present.
merged_vcf = 'merged_wgs_data.vcf.gz'  # output of the bcftools concat cell
# Placeholder: set to the indexed TOPMed reference file(s).
# NOTE(review): confirm whether gen_dm's refvcf expects a single path, a directory or a pattern.
topmed_ref_vcf = '<path to TOPMED reference files>'
if topmed_ref_vcf.startswith('<'):
    raise ValueError('Set topmed_ref_vcf to the TOPMed reference VCF path before scoring')

output = prsedm.gen_dm(
    vcf=merged_vcf,
    col="GT",
    build="hg38",
    prsflags="t1dgrs2-luckett25",
    impute=1,
    refvcf=topmed_ref_vcf,
    norm=1,
    ntasks=16,
    parallel=1,
    batch_size=1,
)
# Save results
output.to_csv("prsedm_result.csv", index=False)