## Imports

In [2]:
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score, average_precision_score
import xgboost as xgb

# Feature groups (model input column names). The 'ot_mantis' model below
# consumes `ot + mantis`; `cc` is not used in this section.
ot = [
    'clin_ot',
    'hgmd',
    'gwas_credible_sets',
    'expression_atlas',
    'impc',
    'europepmc',
]
mantis = [
    'mantis',
]
# NOTE(review): presumably the genetic-association features that are unavailable
# for ultrarare phecodes (hence the 'ot_mantis' model) — confirm.
cc = [
    'cc_common_max_p',
    'cc_rare_max_p',
    'cc_rare_burden_max_p',
    'cc_ultrarare_max_p',
]

## Generate predictions using trained models

Here we use trained models to generate predictions for 120 ultrarare phecodes not included in our training set. 

Because genetic associations are not available for these phecodes, we use the version of RareGPS without genetic associations ("ot_mantis").

In [9]:
# Load the ultrarare G-P pairs, keeping identifiers, labels and the model features.
ur_columns = ['id', 'gene', 'phecode', 'indication', 'phase'] + ot + mantis
ur = pd.read_pickle('./Final/drugs_ot_ur.pkl')[ur_columns]

# Summary counts: unique pairs / genes / phecodes, and number of positive indications.
ur_summary = {
    'G-P pairs': ur['id'].nunique(),
    'Genes': ur['gene'].nunique(),
    'Phecodes': ur['phecode'].nunique(),
    'Indications': ur['indication'].sum(),
}
for summary_label, summary_value in ur_summary.items():
    print(summary_label, summary_value)

ur.head()

G-P pairs 169904
Genes 17998
Phecodes 121
Indications 958


Unnamed: 0,id,gene,phecode,indication,phase,clin_ot,hgmd,gwas_credible_sets,expression_atlas,impc,europepmc,mantis
4228,GI_552.11:LCN2,LCN2,GI_552.11,0,0.0,0.0,0.0,0.0,0.090278,0.0,0.0,
46586,GE_962.5:SRI,SRI,GE_962.5,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1253
87131,GE_972.7:USP24,USP24,GE_972.7,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1618
53026,GE_972.7:CLOCK,CLOCK,GE_972.7,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1065
40650,GE_981.8:SPATA6L,SPATA6L,GE_981.8,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0102


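Feature columns must be named exactly as above, because the prediction code below selects the model inputs by these column names (`ot + mantis`).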

The columns 'indication' and 'phase' are not necessary to generate predictions but are helpful to validate the accuracy of predictions. Please refer to the manuscript for details about how each feature was scored.

RareGPS models were trained using nested 5-fold cross-validation (5 outer folds × 5 inner folds). As a result, 25 models are applied to each G-P pair, and their 25 predictions are averaged to give a single prediction.

In [10]:
import os

# Apply every nested-CV model (outer fold x inner fold) to the ultrarare pairs,
# report per-model AUROC/AUPRC, then average the per-model predictions per G-P pair
# and save the result.
N_OUTER_FOLDS = 5
N_INNER_FOLDS = 5

model_names = ['ot_mantis']
feature_names = [ot+mantis]

for i, j in zip(model_names, feature_names):
    print('Model', i)

    X_ur = ur[j]
    y_ur = ur['indication']
    ids_ur = ur['id']

    # The input matrix is the same for every fold model, so build it once
    # rather than once per model.
    dmatrix_ur = xgb.DMatrix(X_ur, enable_categorical=True)

    fold_frames = []
    for fold in range(1, N_OUTER_FOLDS + 1):
        for inner_fold in range(1, N_INNER_FOLDS + 1):
            model = xgb.Booster()
            model.load_model(f"./GPS/Main/Models/{i}_{fold}_{inner_fold}.json")

            y_pred_ur = model.predict(dmatrix_ur)
            auroc_ur = roc_auc_score(y_ur, y_pred_ur)
            auprc_ur = average_precision_score(y_ur, y_pred_ur)
            print('Fold', fold, inner_fold, 'AUROC', round(auroc_ur, 2), 'AUPRC', round(auprc_ur, 2))

            fold_frames.append(pd.DataFrame({'id': ids_ur, 'prediction': y_pred_ur}))

    # Mean prediction per G-P pair across all fold models, re-annotated with labels.
    ur_predictions = (
        pd.concat(fold_frames, ignore_index=True)
        .groupby('id')['prediction'].mean()
        .reset_index()
        .merge(ur[['id', 'gene', 'phecode', 'indication', 'phase']], on='id')
    )

    # Create the output directory so a fresh checkout does not fail on save.
    out_dir = './Ultrarare/Predictions_trained'
    os.makedirs(out_dir, exist_ok=True)
    ur_predictions.to_pickle(f'{out_dir}/ur_predictions_{i}.pkl')


Model ot_mantis
Fold 1 1 AUROC 0.83 AUPRC 0.05
Fold 1 2 AUROC 0.83 AUPRC 0.05
Fold 1 3 AUROC 0.83 AUPRC 0.06
Fold 1 4 AUROC 0.83 AUPRC 0.06
Fold 1 5 AUROC 0.83 AUPRC 0.06
Fold 2 1 AUROC 0.83 AUPRC 0.06
Fold 2 2 AUROC 0.83 AUPRC 0.06
Fold 2 3 AUROC 0.82 AUPRC 0.06
Fold 2 4 AUROC 0.82 AUPRC 0.05
Fold 2 5 AUROC 0.82 AUPRC 0.05
Fold 3 1 AUROC 0.83 AUPRC 0.06
Fold 3 2 AUROC 0.83 AUPRC 0.05
Fold 3 3 AUROC 0.83 AUPRC 0.05
Fold 3 4 AUROC 0.83 AUPRC 0.05
Fold 3 5 AUROC 0.82 AUPRC 0.05
Fold 4 1 AUROC 0.82 AUPRC 0.05
Fold 4 2 AUROC 0.83 AUPRC 0.06
Fold 4 3 AUROC 0.83 AUPRC 0.06
Fold 4 4 AUROC 0.83 AUPRC 0.06
Fold 4 5 AUROC 0.83 AUPRC 0.06
Fold 5 1 AUROC 0.82 AUPRC 0.05
Fold 5 2 AUROC 0.82 AUPRC 0.06
Fold 5 3 AUROC 0.83 AUPRC 0.06
Fold 5 4 AUROC 0.82 AUPRC 0.05
Fold 5 5 AUROC 0.83 AUPRC 0.06


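Each of the 25 models achieves an AUROC of about 0.83 and an AUPRC of about 0.05–0.06. The AUPRC is low in absolute terms, but it is roughly 10× the proportion of cases (~0.006, shown below), which is the AUPRC expected from a random ranking.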

In [12]:
# Pooled metrics for the averaged predictions. Read from the same path the
# prediction cell writes to, so Restart & Run All evaluates freshly generated output.
ur_predictions = pd.read_pickle('./Ultrarare/Predictions_trained/ur_predictions_ot_mantis.pkl')
y_true = ur_predictions['indication']
y_score = ur_predictions['prediction']
print('AUROC', round(roc_auc_score(y_true, y_score), 4))
print('AUPRC', round(average_precision_score(y_true, y_score), 4))
# Case proportion = the AUPRC a random ranking would achieve (baseline for AUPRC).
print('Proportion of cases', round(y_true.mean(), 4))

AUROC 0.8337
AUPRC 0.0612
Proportion of cases 0.0056


Below are the 10 G-P pairs with the highest predicted scores, annotated with their phecode descriptions.

In [14]:
# Ten highest-scoring G-P pairs, annotated with phecode descriptions. Reads the
# file written by the prediction cell so results reflect the code that was run.
ur_predictions = pd.read_pickle('./Ultrarare/Predictions_trained/ur_predictions_ot_mantis.pkl')
ur_annotations = pd.read_excel('./Resources/ultrarare_phecode_list.xlsx')
(
    ur_predictions
    .merge(ur_annotations[['phecode', 'phecode_string']], on='phecode')
    .sort_values('prediction', ascending=False)
    .head(10)
)

Unnamed: 0,id,prediction,gene,phecode,indication,phase,phecode_string
79822,GE_965.4:LDLR,0.550051,LDLR,GE_965.4,0,0.0,Familial hypercholesterolemia*
51058,GE_962.4:SLC3A1,0.416906,SLC3A1,GE_962.4,0,0.0,Disturbances of sulphur-bearing amino-acid met...
135404,GE_981.3:MECP2,0.413124,MECP2,GE_981.3,1,0.5,Rett syndrome*
110664,GE_971.12:F9,0.409965,F9,GE_971.12,1,4.0,Hereditary factor IX disorder [Hemophilia B]
109561,GE_970.8:ALAS2,0.404861,ALAS2,GE_970.8,0,0.0,Hereditary sideroblastic anemia*
82702,GE_965.4:PCSK9,0.387524,PCSK9,GE_965.4,1,4.0,Familial hypercholesterolemia*
70186,GE_962.9:SLC3A1,0.382243,SLC3A1,GE_962.9,0,0.0,Disturbances of amino-acid transport
23268,BI_179.8:CD40LG,0.363928,CD40LG,BI_179.8,0,0.0,Immunodeficiency with increased IgM
71647,GE_964.6:SLC2A1,0.35919,SLC2A1,GE_964.6,0,0.0,Glucose transporter protein type 1 deficiency ...
22512,BI_172.3:PRF1,0.356295,PRF1,BI_172.3,0,0.0,Hemophagocytic syndromes
