## Instructions for Using the Models and Data
Adapted from Chunyu Ma's xDTD training pipeline

__Data Location__: /scratch/shared_data_new/xDTD_data_for_stephanie

__Description__: 
There are three folders corresponding to different versions of models: v2.8.0.1, v2.8.3, and v2.8.6. Each folder has a similar structure with a `data` subfolder containing the nodes and edges information of the knowledge graph used in that version, and a `models` subfolder containing the trained models and necessary files. 


Please refer to the code below for an example of how to use data and models:

# Installation

In [None]:
# Prerequisite: install conda first, then create the conda environments from the
# environment files under envs/ (one `conda env create` call per file).
conda env create -f envs/graphsage_p2.7env.yml
conda env create -f envs/xDTD_training_pipeline_env.yml

**Troubleshooting Tip**: If the `conda` command does not work, try installing Mamba and replace `conda` with `mamba` to create/configure and activate the environment.

In [None]:
# Activate the 'xDTD_training_pipeline' conda environment
conda activate xDTD_training_pipeline

In [None]:
from tqdm.notebook import tqdm
from time import sleep

### Load Packages

In [None]:
import os, sys
import numpy as np
import pickle
import joblib
import pandas as pd
import csv
# Raise the csv module's maximum allowed field size to 100,000,000
# (presumably because some TSV fields exceed the default limit — confirm).
csv.field_size_limit(100000000)

### Input Functions

In [None]:
import csv
from tqdm import tqdm

def read_tsv_file(file_path, encoding='utf-8'):
    """Read a tab-separated file into a list of rows.

    Args:
        file_path: Path of the TSV file to read.
        encoding: Text encoding used to open the file.

    Returns:
        A list with one entry per row of the file, in file order; each entry
        is the list of fields produced by ``csv.reader`` for that row.
    """
    rows = []
    with open(file_path, newline='', encoding=encoding) as handle:
        # tqdm only wraps the reader to show progress while rows are consumed.
        progress = tqdm(csv.reader(handle, delimiter='\t'), desc=f"Reading {file_path}")
        for record in progress:
            rows.append(record)
    return rows

def load_index(input_path, encoding='utf-8'):
    """Load a two-column, tab-separated index file into lookup dictionaries.

    Each line must hold exactly two tab-separated fields; the first is used as
    the entry name and the second is ignored. An entry's id is the zero-based
    position of its line in the file.

    Args:
        input_path: Path of the index file.
        encoding: Text encoding used to open the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale encoding
            (matching ``read_tsv_file``).

    Returns:
        A ``(name_to_id, id_to_name)`` tuple of dictionaries.

    Raises:
        ValueError: If a line does not contain exactly two tab-separated
            fields after surrounding whitespace is stripped.
    """
    name_to_id, id_to_name = {}, {}
    with open(input_path, encoding=encoding) as f:
        # Read everything up front so tqdm can be given the total line count.
        lines = f.readlines()
    for index, line in tqdm(enumerate(lines), total=len(lines), desc=f"Loading index from {input_path}"):
        fields = line.strip().split('\t')
        if len(fields) != 2:
            # Same exception type as the original tuple-unpacking failure,
            # but with enough context to locate the offending line.
            raise ValueError(
                f"{input_path}, line {index + 1}: expected 2 tab-separated "
                f"fields, got {len(fields)}"
            )
        name = fields[0]
        name_to_id[name] = index
        id_to_name[index] = name
    return name_to_id, id_to_name

### Set Up Parameters

In [None]:
# Model version to analyze; this is also the name of the folder that is read below.
version = 'v2.8.0.1'  # to analyze a different model, change the version (e.g. 'v2.8.3', 'v2.8.6')
# The version folder is resolved relative to the current working directory,
# so run this from the directory that contains the version folders.
dir_path = os.path.join(os.getcwd(), version)

### Load Node and Edge Info 

In [None]:
## Node info: the header file's first row gives the column names, nodes_c.tsv the rows
nodes_c_header = read_tsv_file(os.path.join(dir_path, 'data', 'nodes_c_header.tsv'))[0]
nodes_c = read_tsv_file(os.path.join(dir_path, 'data', 'nodes_c.tsv'))
# Index each node row by its first column; the remaining columns become the value
node_to_info = dict((row[0], row[1:]) for row in nodes_c)

## Edge info: same layout as the node files (header file + data file)
edges_c_header = read_tsv_file(os.path.join(dir_path, 'data', 'edges_c_header.tsv'))[0]
edges_c = read_tsv_file(os.path.join(dir_path, 'data', 'edges_c.tsv'))

### Read Embeddings

In [None]:
# Embedding array stored with the trained model, plus the entity index that maps
# each entity name to its line position in entity2freq.txt.
node_embeddings = np.load(os.path.join(dir_path, 'models', 'RF_model_3class', 'entity_embeddings.npy'))
entity2id, id2entity = load_index(os.path.join(dir_path, 'models', 'RF_model_3class', 'entity2freq.txt'))
# Map entity name -> its entry in node_embeddings. The entity at index 0 is left out
# (NOTE(review): presumably a padding/dummy entry — confirm).
entity_embeddings_dict = dict(
    (entity, node_embeddings[row]) for entity, row in entity2id.items() if row != 0
)

### Load Random Forest model

In [None]:
## Note: Make sure scikit-learn is installed (NOTE(review): presumably needed to deserialize the model — confirm).
# To check whether scikit-learn is installed, run: python -c "import sklearn; print(sklearn.__version__)"
# SECURITY NOTE: joblib.load is pickle-based and can execute arbitrary code while loading;
# only load model files that come from a trusted source.

RF_model = joblib.load(os.path.join(dir_path, 'models', 'RF_model_3class', 'RF_model.pt'))

### Query drug-disease pairs of interest

In [None]:
# Drug-disease pairs of interest, each given as a (drug CURIE, disease CURIE) tuple.
# Both identifiers of every pair are looked up in entity_embeddings_dict when the
# feature matrix is built, so each must be present there.
pairs = [('PUBCHEM.COMPOUND:456255', 'MONDO:0021680'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005247'),('PUBCHEM.COMPOUND:2082', 'MONDO:0019444'),
         ('PUBCHEM.COMPOUND:4583', 'MONDO:0024355'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005972'),('PUBCHEM.COMPOUND:9782', 'MONDO:0011786'),
         ('PUBCHEM.COMPOUND:456255', 'MONDO:0020920'),('PUBCHEM.COMPOUND:9782', 'MONDO:0006608'),('PUBCHEM.COMPOUND:38103', 'MONDO:0002258'),
         ('PUBCHEM.COMPOUND:4485', 'MONDO:0001134'),('PUBCHEM.COMPOUND:51039', 'MONDO:0005230'),('PUBCHEM.COMPOUND:9782', 'MONDO:0005185'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0000870'),('PUBCHEM.COMPOUND:2726', 'MONDO:0011295'),('CHEMBL.COMPOUND:CHEMBL1481', 'MONDO:0012818'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0006554'),('PUBCHEM.COMPOUND:2662', 'MONDO:0011923'),('PUBCHEM.COMPOUND:2710', 'MONDO:0005324'),
         ('PUBCHEM.COMPOUND:5743', 'MONDO:0004967'),('PUBCHEM.COMPOUND:456255', 'MONDO:0005545'),('PUBCHEM.COMPOUND:5591', 'MONDO:0010894'),
         ('PUBCHEM.COMPOUND:38103', 'MONDO:0009813'),('PUBCHEM.COMPOUND:456255', 'MONDO:0005970'),('PUBCHEM.COMPOUND:5743', 'MONDO:0006670'),
         ('PUBCHEM.COMPOUND:126941', 'MONDO:0018874'),('PUBCHEM.COMPOUND:441199', 'MONDO:0006816'),('PUBCHEM.COMPOUND:60149', 'MONDO:0011307'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0005306'),('PUBCHEM.COMPOUND:5073', 'MONDO:0012054'),('PUBCHEM.COMPOUND:2082', 'MONDO:0005996'),
         ('PUBCHEM.COMPOUND:456255', 'MONDO:0005229'),('PUBCHEM.COMPOUND:4583', 'MONDO:0004264'),('PUBCHEM.COMPOUND:9782', 'MONDO:0008728'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0007243'),('PUBCHEM.COMPOUND:3878', 'MONDO:0005384'),('PUBCHEM.COMPOUND:5503', 'MONDO:0011667'),
         ('PUBCHEM.COMPOUND:3394', 'MONDO:0005178'),('PUBCHEM.COMPOUND:5591', 'MONDO:0007453'),('PUBCHEM.COMPOUND:2726', 'MONDO:0013498'),
         ('PUBCHEM.COMPOUND:4912', 'MONDO:0012893'),('PUBCHEM.COMPOUND:4921', 'MONDO:0008143'),('PUBCHEM.COMPOUND:5754', 'MONDO:0005984'),
         ('PUBCHEM.COMPOUND:31307', 'MONDO:0019558'),('PUBCHEM.COMPOUND:3478', 'MONDO:0011668'),('PUBCHEM.COMPOUND:5591', 'MONDO:0013242'),
         ('PUBCHEM.COMPOUND:2726', 'MONDO:0011280'),('PUBCHEM.COMPOUND:5503', 'MONDO:0014674'),('PUBCHEM.COMPOUND:9782', 'MONDO:0013108'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0007817'),('PUBCHEM.COMPOUND:5743', 'MONDO:0005615'),('PUBCHEM.COMPOUND:5865', 'MONDO:0012105'),
         ('PUBCHEM.COMPOUND:38103', 'MONDO:0001039'),('PUBCHEM.COMPOUND:126941', 'MONDO:0008903'),('PUBCHEM.COMPOUND:9782', 'MONDO:0010273'),
         ('PUBCHEM.COMPOUND:5282411', 'MONDO:0005709'),('PUBCHEM.COMPOUND:3690', 'MONDO:0004948'),('PUBCHEM.COMPOUND:5865', 'MONDO:0000607'),
         ('PUBCHEM.COMPOUND:3446', 'MONDO:0000675'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005297'),('PUBCHEM.COMPOUND:2082', 'MONDO:0005654'),
         ('PUBCHEM.COMPOUND:4583', 'MONDO:0005619'),('PUBCHEM.COMPOUND:444008', 'MONDO:0008159'),('PUBCHEM.COMPOUND:5865', 'MONDO:0011849'),
         ('PUBCHEM.COMPOUND:5743', 'MONDO:0018479'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005246'),('PUBCHEM.COMPOUND:126941', 'MONDO:0004638'),
         ('PUBCHEM.COMPOUND:9782', 'MONDO:0015898'),('PUBCHEM.COMPOUND:5865', 'MONDO:0004126'),('PUBCHEM.COMPOUND:5388906', 'MONDO:0005861'),
         ('PUBCHEM.COMPOUND:2726', 'MONDO:0005618'),('PUBCHEM.COMPOUND:5743', 'MONDO:0021166'),('PUBCHEM.COMPOUND:6741', 'MONDO:0006545'),
         ('PUBCHEM.COMPOUND:51039', 'MONDO:0004652'),('PUBCHEM.COMPOUND:5743', 'MONDO:0008558'),('PUBCHEM.COMPOUND:9782', 'MONDO:0005083'),
         ('PUBCHEM.COMPOUND:5282493', 'MONDO:0005480'),('PUBCHEM.COMPOUND:5865', 'MONDO:0001713'),('PUBCHEM.COMPOUND:9782', 'MONDO:0001405'),
         ('CHEMBL.COMPOUND:CHEMBL1481', 'MONDO:0011955'),('PUBCHEM.COMPOUND:5503', 'MONDO:0005148'),('PUBCHEM.COMPOUND:3823', 'MONDO:0001461'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0005554'),('PUBCHEM.COMPOUND:5865', 'MONDO:0004956'),('PUBCHEM.COMPOUND:9782', 'MONDO:0006547'),
         ('PUBCHEM.COMPOUND:3878', 'MONDO:0010826'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005945'),('PUBCHEM.COMPOUND:9782', 'MONDO:0008729'),
         ('PUBCHEM.COMPOUND:3823', 'MONDO:0005915'),('PUBCHEM.COMPOUND:5770', 'MONDO:0011294'),('PUBCHEM.COMPOUND:5578', 'MONDO:0024330'),
         ('PUBCHEM.COMPOUND:6216', 'MONDO:0004980'),('PUBCHEM.COMPOUND:3823', 'MONDO:0005672'),('PUBCHEM.COMPOUND:2726', 'MONDO:0011498'),
         ('PUBCHEM.COMPOUND:5073', 'MONDO:0011552'),('PUBCHEM.COMPOUND:3878', 'MONDO:0016532'),('PUBCHEM.COMPOUND:135398735', 'MONDO:0012268'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0015614'),('PUBCHEM.COMPOUND:5754', 'MONDO:0014241'),('PUBCHEM.COMPOUND:4583', 'MONDO:0018076'),
         ('PUBCHEM.COMPOUND:5743', 'MONDO:0006042'),('PUBCHEM.COMPOUND:2082', 'MONDO:0001103'),('PUBCHEM.COMPOUND:441401', 'MONDO:0024313'),
         ('PUBCHEM.COMPOUND:60750', 'MONDO:0024879'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005124'),('PUBCHEM.COMPOUND:5865', 'MONDO:0009539'),
         ('CHEMBL.COMPOUND:CHEMBL1481', 'MONDO:0011072'),('PUBCHEM.COMPOUND:5865', 'MONDO:0005556'),('PUBCHEM.COMPOUND:451668', 'MONDO:0005374'),
         ('PUBCHEM.COMPOUND:2554', 'MONDO:0008414'),('PUBCHEM.COMPOUND:5865', 'MONDO:0001198'),('PUBCHEM.COMPOUND:5865', 'MONDO:0019203'),
         ('PUBCHEM.COMPOUND:5743', 'MONDO:0009971'),('PUBCHEM.COMPOUND:5865', 'MONDO:0001509'),('PUBCHEM.COMPOUND:9782', 'MONDO:0013294'),
         ('PUBCHEM.COMPOUND:5324346', 'MONDO:0001247'),('PUBCHEM.COMPOUND:5754', 'MONDO:0004773'),('PUBCHEM.COMPOUND:135398745', 'MONDO:0002009'),
         ('PUBCHEM.COMPOUND:135398735', 'MONDO:0000369'),('PUBCHEM.COMPOUND:5865', 'MONDO:0012935'),('PUBCHEM.COMPOUND:5073', 'MONDO:0013089'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0008725'),('PUBCHEM.COMPOUND:441314', 'MONDO:0013240'),('PUBCHEM.COMPOUND:60750', 'MONDO:0021040'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0010762'),('PUBCHEM.COMPOUND:6741', 'MONDO:0043455'),('PUBCHEM.COMPOUND:9782', 'MONDO:0006548'),
         ('PUBCHEM.COMPOUND:3151', 'MONDO:0002268'),('PUBCHEM.COMPOUND:441401', 'MONDO:0017776'),('PUBCHEM.COMPOUND:2576', 'MONDO:0001830'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0005377'),('PUBCHEM.COMPOUND:6509979', 'MONDO:0011598'),('CHEMBL.COMPOUND:CHEMBL1481', 'MONDO:0011363'),
         ('PUBCHEM.COMPOUND:9782', 'MONDO:0013893'),('PUBCHEM.COMPOUND:3478', 'MONDO:0012348'),('PUBCHEM.COMPOUND:6741', 'MONDO:0013730'),
         ('PUBCHEM.COMPOUND:51039', 'MONDO:0005892'),('PUBCHEM.COMPOUND:5865', 'MONDO:0019127'),('PUBCHEM.COMPOUND:5754', 'MONDO:0004857'),
         ('PUBCHEM.COMPOUND:126941', 'MONDO:0005340'),('PUBCHEM.COMPOUND:5865', 'MONDO:0005093'),('PUBCHEM.COMPOUND:2726', 'MONDO:0013506'),
         ('PUBCHEM.COMPOUND:4173', 'MONDO:0002154'),('PUBCHEM.COMPOUND:4173', 'MONDO:0006669'),('PUBCHEM.COMPOUND:2165', 'MONDO:0005920'),
         ('PUBCHEM.COMPOUND:4659569', 'MONDO:0014604'),('PUBCHEM.COMPOUND:3394', 'MONDO:0008383'),('PUBCHEM.COMPOUND:5311066', 'MONDO:0006569'),
         ('PUBCHEM.COMPOUND:92727', 'MONDO:0005109'),('PUBCHEM.COMPOUND:3823', 'MONDO:0001648'),('PUBCHEM.COMPOUND:5865', 'MONDO:0002081'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0043789'),('PUBCHEM.COMPOUND:5865', 'MONDO:0012318'),('PUBCHEM.COMPOUND:5503', 'MONDO:0011027'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0008219')]

### Predict Probabilities: Drug-Disease pairs

In [None]:
# Feature matrix: one row per pair, the drug embedding followed by the disease embedding.
X = np.vstack([
    np.hstack((entity_embeddings_dict[drug], entity_embeddings_dict[disease]))
    for drug, disease in pairs
])
# Class probabilities predicted by the random forest, one row per pair.
res_temp = RF_model.predict_proba(X)
# Place the pair identifiers next to their probabilities and label the columns.
# NOTE(review): the score labels assume predict_proba's columns are ordered
# (true negative, true positive, unknown) — confirm against RF_model.classes_.
res = pd.concat((pd.DataFrame(pairs), pd.DataFrame(res_temp)), axis='columns')
res.columns = ['drug_id', 'disease_id', 'tn_score', 'tp_score', 'unknown_score']

In [None]:
res_temp  # display the raw predict_proba output (one row of class probabilities per pair)

In [None]:
# Display the combined table: pair identifiers alongside their predicted scores
res

In [None]:
# Export the result as a CSV file (without the DataFrame index column).
# The file name is derived from `version` so that predictions made with a
# different model version are neither mislabeled as v2.8.0.1 nor overwrite
# that file; for version 'v2.8.0.1' the path is the same as before.
f_path = os.path.join('/scratch/shared_data_new/xDTD_data_for_stephanie/RF_prediction',
                      f'{version}_RF_pairs.csv')
res.to_csv(f_path, index=False)