# Generate PRS with PRSedm
## Updated 11/8/24
Install the package via pip or Anaconda, or install it locally from the workspace bucket as shown below.

### Install PRSedm from the bucket (optional)
You *must* restart the kernel after installing.

In [None]:
!gsutil cp $WORKSPACE_BUCKET/scoreprs/<prsedm>.whl
!pip install <prsedm>.whl --force-reinstall

### Set up the workspace and load the Hail VDS
This section must be run with the "Hail Genomics Analysis" environment (dataproc).

In [None]:
from datetime import datetime
import os
import pandas as pd
import math
import numpy as np
import pysam
import prsedm
import warnings

#Local workspace directory, and the workspace bucket URI read from the environment
#(bucket is None when WORKSPACE_BUCKET is not set)
workspace = "/home/jupyter/workspaces/<myworkspace>"
bucket = os.environ.get('WORKSPACE_BUCKET')

### Generate a list of required variants

In [None]:
#Retrieve variant list from package
df = prsedm.get_snp_db('t1dgrs2-sharp24')
df = df.drop_duplicates().reset_index(drop=True)

#Build a list of variant positions to extract
pos_col = 'position_hg38'  #change to hg19 if required

#Iterate the two columns directly instead of using DataFrame.iterrows():
#iterrows() builds a Series per row and may upcast integer positions to
#float (which would render as e.g. "chr1:123.0-124.0"). int() keeps the
#coordinates integral regardless of the column dtype.
#Each interval is written as "chr<contig>:<pos>-<pos + 1>".
variantIntervals = [
    f"chr{contig}:{int(pos)}-{int(pos) + 1}"
    for contig, pos in zip(df['contig_id'], df[pos_col])
]
print(f"Number of unique variants to extract: {len(variantIntervals)}")

### Use Hail to retrieve variants from the VDS into VCF

In [None]:
#Start Hail on GRCh38 and open the VDS; warnings raised while doing so are hidden
import hail as hl

with warnings.catch_warnings():
    warnings.simplefilter(action="ignore")
    hl.init(idempotent=True)
    hl.default_reference('GRCh38')

    #Location of the comprehensive srWGS VDS is taken from the environment
    vds_path = os.environ.get('WGS_VDS_PATH')
    vds = hl.vds.read_vds(vds_path)

#Query the VDS in batches, since requesting >1200 regions at once causes a crash
print("Retrieve chunks from VDS and densify...")
chunk_size = 1000
mt_subsets = []
for chunk_no, offset in enumerate(range(0, len(variantIntervals), chunk_size), start=1):
    print(f"Processing chunk: {chunk_no}")

    #Parse this batch of "chr:start-end" strings into Hail locus intervals
    loci = [
        hl.parse_locus_interval(region)
        for region in variantIntervals[offset:offset + chunk_size]
    ]

    #Restrict the VDS to the batch, split multi-allelic sites, then densify
    vds_subset = hl.vds.split_multi(hl.vds.filter_intervals(vds, loci))
    mt_subsets.append(hl.vds.to_dense_mt(vds_subset))

#Combine all
print("Combining retrieved chunks...")
if not mt_subsets:
    #Fail with a clear message instead of an IndexError on mt_subsets[0]
    raise ValueError("No chunks were retrieved from the VDS; nothing to combine.")

#union_rows accepts any number of datasets, so take the union in one call
#rather than building a nested chain of pairwise unions.
mt = hl.MatrixTable.union_rows(*mt_subsets)

#Post-process the merged matrix table
#NOTE(review): each chunk was already passed through hl.vds.split_multi before
#densifying; confirm that this second split is still required.
print("Process merged MT...")
mt = hl.split_multi_hts(mt).drop('FT')

#Run variant QC and expose its allele frequencies as info.AF
mt = hl.variant_qc(mt)
mt = mt.annotate_rows(
    info=hl.struct(AF=mt.variant_qc.AF),
)

#Write the result to the workspace bucket as a block-gzipped VCF
print("Export to VCF and store in bucket...")
hl.export_vcf(mt, f'{bucket}/temp.vcf.bgz')

### Copy imputation reference data if required

In [None]:
# Copy the TOPMED reference data (BravoFreeze8 directory) from the workspace
# bucket into the current working directory
!gsutil cp -r $WORKSPACE_BUCKET/BravoFreeze8 .
# Index the data: run tabix on every .vcf.gz file in BravoFreeze8
# (-f forces an existing index to be overwritten)
!for f in BravoFreeze8/*.vcf.gz;do tabix -f $f;done

### Call PRSedm to generate PRS
It is recommended to switch to the 'General Analysis Environment' for this step.

In [None]:
# Generate GRS
# Inputs: the imputation reference directory and the genotype VCF, both
# expected under the local workspace directory.
# NOTE(review): the Hail export step writes temp.vcf.bgz to the bucket;
# presumably that file has to be copied into the workspace under the name
# used below before running this cell. Confirm.
ref_dir = f"{workspace}/BravoFreeze8/"
vcf = f"{workspace}/<mygenotypes>.vcf.bgz"

# Keyword arguments passed to prsedm.gen_dm
gen_dm_kwargs = dict(
    vcf=vcf,
    col="GT",
    build="hg38",
    prsflags="t1dgrs2-sharp24",
    impute=1,
    ref_dir=ref_dir,
    norm=1,
    parallel=1,
    ntasks=16,
    batch_size=1,
)

# Time the scoring run
start_time = datetime.now()
output = prsedm.gen_dm(**gen_dm_kwargs)
end_time = datetime.now()
elapsed_seconds = (end_time - start_time).total_seconds()
print(f"Execution time: {elapsed_seconds:.2f} seconds")

# Store the result in the bucket
output.to_csv(f"{bucket}/prsedm_result.csv", index=False)