## Instructions for Using the Script and Data
__Data Location__: /scratch/shared_data_new/xDTD_data_for_stephanie/RF_prediction

__Description__: 
This script analyzes the prediction performance of the Random Forest classification component of the KGML-xDTD model.
There are three csv files corresponding to different versions of models: v2.8.0.1, v2.8.3, and v2.8.6.
Random Forest prediction scores were generated and extracted via `xDTD_RandomForest_Instructions` (Jupyter Notebook script).

Please refer to the code below for an example of how to extract data and create visualization:

## Load Packages

In [None]:
import os
os.chdir('/scratch/shared_data_new/xDTD_data_for_stephanie/RF_prediction')

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.stats as stats

# Load the Random Forest prediction-score tables, one CSV per model version
# (v2.8.0.1, v2.8.3, v2.8.6). Paths are relative to the current working directory.
df1, df2, df3 = map(
    pd.read_csv,
    ('v2.8.0.1_RF_pairs.csv', 'v2.8.3_RF_pairs.csv', 'v2.8.6_RF_pairs.csv'),
)

## Start of Prediction Performance Analysis:

In [None]:
# Drug-disease pairs to evaluate, written as (drug_id, disease_id) tuples
pairs = [('PUBCHEM.COMPOUND:456255', 'MONDO:0021680'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005247'),('PUBCHEM.COMPOUND:2082', 'MONDO:0019444'),
         ('PUBCHEM.COMPOUND:4583', 'MONDO:0024355'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005972'),('PUBCHEM.COMPOUND:9782', 'MONDO:0011786'),
         ('PUBCHEM.COMPOUND:456255', 'MONDO:0020920'),('PUBCHEM.COMPOUND:9782', 'MONDO:0006608'),('PUBCHEM.COMPOUND:38103', 'MONDO:0002258'),
         ('PUBCHEM.COMPOUND:4485', 'MONDO:0001134'),('PUBCHEM.COMPOUND:51039', 'MONDO:0005230'),('PUBCHEM.COMPOUND:9782', 'MONDO:0005185'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0000870'),('PUBCHEM.COMPOUND:2726', 'MONDO:0011295'),('CHEMBL.COMPOUND:CHEMBL1481', 'MONDO:0012818'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0006554'),('PUBCHEM.COMPOUND:2662', 'MONDO:0011923'),('PUBCHEM.COMPOUND:2710', 'MONDO:0005324'),
         ('PUBCHEM.COMPOUND:5743', 'MONDO:0004967'),('PUBCHEM.COMPOUND:456255', 'MONDO:0005545'),('PUBCHEM.COMPOUND:5591', 'MONDO:0010894'),
         ('PUBCHEM.COMPOUND:38103', 'MONDO:0009813'),('PUBCHEM.COMPOUND:456255', 'MONDO:0005970'),('PUBCHEM.COMPOUND:5743', 'MONDO:0006670'),
         ('PUBCHEM.COMPOUND:126941', 'MONDO:0018874'),('PUBCHEM.COMPOUND:441199', 'MONDO:0006816'),('PUBCHEM.COMPOUND:60149', 'MONDO:0011307'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0005306'),('PUBCHEM.COMPOUND:5073', 'MONDO:0012054'),('PUBCHEM.COMPOUND:2082', 'MONDO:0005996'),
         ('PUBCHEM.COMPOUND:456255', 'MONDO:0005229'),('PUBCHEM.COMPOUND:4583', 'MONDO:0004264'),('PUBCHEM.COMPOUND:9782', 'MONDO:0008728'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0007243'),('PUBCHEM.COMPOUND:3878', 'MONDO:0005384'),('PUBCHEM.COMPOUND:5503', 'MONDO:0011667'),
         ('PUBCHEM.COMPOUND:3394', 'MONDO:0005178'),('PUBCHEM.COMPOUND:5591', 'MONDO:0007453'),('PUBCHEM.COMPOUND:2726', 'MONDO:0013498'),
         ('PUBCHEM.COMPOUND:4912', 'MONDO:0012893'),('PUBCHEM.COMPOUND:4921', 'MONDO:0008143'),('PUBCHEM.COMPOUND:5754', 'MONDO:0005984'),
         ('PUBCHEM.COMPOUND:31307', 'MONDO:0019558'),('PUBCHEM.COMPOUND:3478', 'MONDO:0011668'),('PUBCHEM.COMPOUND:5591', 'MONDO:0013242'),
         ('PUBCHEM.COMPOUND:2726', 'MONDO:0011280'),('PUBCHEM.COMPOUND:5503', 'MONDO:0014674'),('PUBCHEM.COMPOUND:9782', 'MONDO:0013108'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0007817'),('PUBCHEM.COMPOUND:5743', 'MONDO:0005615'),('PUBCHEM.COMPOUND:5865', 'MONDO:0012105'),
         ('PUBCHEM.COMPOUND:38103', 'MONDO:0001039'),('PUBCHEM.COMPOUND:126941', 'MONDO:0008903'),('PUBCHEM.COMPOUND:9782', 'MONDO:0010273'),
         ('PUBCHEM.COMPOUND:5282411', 'MONDO:0005709'),('PUBCHEM.COMPOUND:3690', 'MONDO:0004948'),('PUBCHEM.COMPOUND:5865', 'MONDO:0000607'),
         ('PUBCHEM.COMPOUND:3446', 'MONDO:0000675'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005297'),('PUBCHEM.COMPOUND:2082', 'MONDO:0005654'),
         ('PUBCHEM.COMPOUND:4583', 'MONDO:0005619'),('PUBCHEM.COMPOUND:444008', 'MONDO:0008159'),('PUBCHEM.COMPOUND:5865', 'MONDO:0011849'),
         ('PUBCHEM.COMPOUND:5743', 'MONDO:0018479'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005246'),('PUBCHEM.COMPOUND:126941', 'MONDO:0004638'),
         ('PUBCHEM.COMPOUND:9782', 'MONDO:0015898'),('PUBCHEM.COMPOUND:5865', 'MONDO:0004126'),('PUBCHEM.COMPOUND:5388906', 'MONDO:0005861'),
         ('PUBCHEM.COMPOUND:2726', 'MONDO:0005618'),('PUBCHEM.COMPOUND:5743', 'MONDO:0021166'),('PUBCHEM.COMPOUND:6741', 'MONDO:0006545'),
         ('PUBCHEM.COMPOUND:51039', 'MONDO:0004652'),('PUBCHEM.COMPOUND:5743', 'MONDO:0008558'),('PUBCHEM.COMPOUND:9782', 'MONDO:0005083'),
         ('PUBCHEM.COMPOUND:5282493', 'MONDO:0005480'),('PUBCHEM.COMPOUND:5865', 'MONDO:0001713'),('PUBCHEM.COMPOUND:9782', 'MONDO:0001405'),
         ('CHEMBL.COMPOUND:CHEMBL1481', 'MONDO:0011955'),('PUBCHEM.COMPOUND:5503', 'MONDO:0005148'),('PUBCHEM.COMPOUND:3823', 'MONDO:0001461'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0005554'),('PUBCHEM.COMPOUND:5865', 'MONDO:0004956'),('PUBCHEM.COMPOUND:9782', 'MONDO:0006547'),
         ('PUBCHEM.COMPOUND:3878', 'MONDO:0010826'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005945'),('PUBCHEM.COMPOUND:9782', 'MONDO:0008729'),
         ('PUBCHEM.COMPOUND:3823', 'MONDO:0005915'),('PUBCHEM.COMPOUND:5770', 'MONDO:0011294'),('PUBCHEM.COMPOUND:5578', 'MONDO:0024330'),
         ('PUBCHEM.COMPOUND:6216', 'MONDO:0004980'),('PUBCHEM.COMPOUND:3823', 'MONDO:0005672'),('PUBCHEM.COMPOUND:2726', 'MONDO:0011498'),
         ('PUBCHEM.COMPOUND:5073', 'MONDO:0011552'),('PUBCHEM.COMPOUND:3878', 'MONDO:0016532'),('PUBCHEM.COMPOUND:135398735', 'MONDO:0012268'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0015614'),('PUBCHEM.COMPOUND:5754', 'MONDO:0014241'),('PUBCHEM.COMPOUND:4583', 'MONDO:0018076'),
         ('PUBCHEM.COMPOUND:5743', 'MONDO:0006042'),('PUBCHEM.COMPOUND:2082', 'MONDO:0001103'),('PUBCHEM.COMPOUND:441401', 'MONDO:0024313'),
         ('PUBCHEM.COMPOUND:60750', 'MONDO:0024879'),('PUBCHEM.COMPOUND:4583', 'MONDO:0005124'),('PUBCHEM.COMPOUND:5865', 'MONDO:0009539'),
         ('CHEMBL.COMPOUND:CHEMBL1481', 'MONDO:0011072'),('PUBCHEM.COMPOUND:5865', 'MONDO:0005556'),('PUBCHEM.COMPOUND:451668', 'MONDO:0005374'),
         ('PUBCHEM.COMPOUND:2554', 'MONDO:0008414'),('PUBCHEM.COMPOUND:5865', 'MONDO:0001198'),('PUBCHEM.COMPOUND:5865', 'MONDO:0019203'),
         ('PUBCHEM.COMPOUND:5743', 'MONDO:0009971'),('PUBCHEM.COMPOUND:5865', 'MONDO:0001509'),('PUBCHEM.COMPOUND:9782', 'MONDO:0013294'),
         ('PUBCHEM.COMPOUND:5324346', 'MONDO:0001247'),('PUBCHEM.COMPOUND:5754', 'MONDO:0004773'),('PUBCHEM.COMPOUND:135398745', 'MONDO:0002009'),
         ('PUBCHEM.COMPOUND:135398735', 'MONDO:0000369'),('PUBCHEM.COMPOUND:5865', 'MONDO:0012935'),('PUBCHEM.COMPOUND:5073', 'MONDO:0013089'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0008725'),('PUBCHEM.COMPOUND:441314', 'MONDO:0013240'),('PUBCHEM.COMPOUND:60750', 'MONDO:0021040'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0010762'),('PUBCHEM.COMPOUND:6741', 'MONDO:0043455'),('PUBCHEM.COMPOUND:9782', 'MONDO:0006548'),
         ('PUBCHEM.COMPOUND:3151', 'MONDO:0002268'),('PUBCHEM.COMPOUND:441401', 'MONDO:0017776'),('PUBCHEM.COMPOUND:2576', 'MONDO:0001830'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0005377'),('PUBCHEM.COMPOUND:6509979', 'MONDO:0011598'),('CHEMBL.COMPOUND:CHEMBL1481', 'MONDO:0011363'),
         ('PUBCHEM.COMPOUND:9782', 'MONDO:0013893'),('PUBCHEM.COMPOUND:3478', 'MONDO:0012348'),('PUBCHEM.COMPOUND:6741', 'MONDO:0013730'),
         ('PUBCHEM.COMPOUND:51039', 'MONDO:0005892'),('PUBCHEM.COMPOUND:5865', 'MONDO:0019127'),('PUBCHEM.COMPOUND:5754', 'MONDO:0004857'),
         ('PUBCHEM.COMPOUND:126941', 'MONDO:0005340'),('PUBCHEM.COMPOUND:5865', 'MONDO:0005093'),('PUBCHEM.COMPOUND:2726', 'MONDO:0013506'),
         ('PUBCHEM.COMPOUND:4173', 'MONDO:0002154'),('PUBCHEM.COMPOUND:4173', 'MONDO:0006669'),('PUBCHEM.COMPOUND:2165', 'MONDO:0005920'),
         ('PUBCHEM.COMPOUND:4659569', 'MONDO:0014604'),('PUBCHEM.COMPOUND:3394', 'MONDO:0008383'),('PUBCHEM.COMPOUND:5311066', 'MONDO:0006569'),
         ('PUBCHEM.COMPOUND:92727', 'MONDO:0005109'),('PUBCHEM.COMPOUND:3823', 'MONDO:0001648'),('PUBCHEM.COMPOUND:5865', 'MONDO:0002081'),
         ('PUBCHEM.COMPOUND:6741', 'MONDO:0043789'),('PUBCHEM.COMPOUND:5865', 'MONDO:0012318'),('PUBCHEM.COMPOUND:5503', 'MONDO:0011027'),
         ('PUBCHEM.COMPOUND:5865', 'MONDO:0008219')]

### Extract Prediction Probability Scores of Drug-Disease Pairs 

In [None]:
from tqdm.notebook import tqdm
from time import sleep

def filter_dataframe(df, pairs):
    """Collect the rows of ``df`` that match each requested drug-disease pair.

    Args:
        df: DataFrame that has ``drug_id`` and ``disease_id`` columns.
        pairs: Iterable of ``(drug_id, disease_id)`` tuples to look up.

    Returns:
        A list of DataFrames, one per pair that has at least one matching
        row, in the order the pairs were given. Pairs without any match are
        skipped, so the list can be shorter than ``pairs`` (or empty).
    """
    # Materialise once: the pairs are walked more than once below, which
    # would exhaust a generator.
    pairs = list(pairs)

    # Narrow ``df`` a single time to rows whose drug AND disease each occur
    # somewhere in ``pairs``. This is only a coarse pre-filter (it can keep
    # drug/disease combinations that were not requested); the exact per-pair
    # match below is unchanged, but now scans this small subset instead of
    # the full table for every pair.
    wanted_drugs = [drug_id for drug_id, _ in pairs]
    wanted_diseases = [disease_id for _, disease_id in pairs]
    candidates = df[df['drug_id'].isin(wanted_drugs) & df['disease_id'].isin(wanted_diseases)]

    extracted_elements = []
    for drug_id, disease_id in tqdm(pairs, desc='Filtering rows'):
        # Rows where both identifiers equal the current pair
        matched_rows = candidates[(candidates['drug_id'] == drug_id) & (candidates['disease_id'] == disease_id)]

        # Keep only pairs that were actually found
        if not matched_rows.empty:
            extracted_elements.append(matched_rows)

    return extracted_elements

# Run the pair lookup against each model version's table. Every result is a
# list of DataFrames, one per pair that was found in that table.
df1_filt, df2_filt, df3_filt = (
    filter_dataframe(frame, pairs) for frame in (df1, df2, df3)
)

In [None]:
# Stack the per-pair matches of each version into a single table with a fresh
# 0..n-1 index, and discard the 'tn_score' / 'unknown_score' columns.
# Order: df1_pairs -> v2.8.0.1, df2_pairs -> v2.8.3, df3_pairs -> v2.8.6
df1_pairs, df2_pairs, df3_pairs = (
    pd.concat(matches, ignore_index=True).drop(columns=['tn_score', 'unknown_score'])
    for matches in (df1_filt, df2_filt, df3_filt)
)

In [None]:
# Display the extracted prediction scores as a DataFrame (notebook cell output)
df1_pairs # v2.8.0.1

In [None]:
# Display the extracted prediction scores as a DataFrame (notebook cell output)
df2_pairs # v2.8.3

In [None]:
# Display the extracted prediction scores as a DataFrame (notebook cell output)
df3_pairs # v2.8.6

#### Compile the data model versions 2.8.0.1, 2.8.3 & 2.8.6

In [None]:
# Inner-join the three per-version tables on the (disease_id, drug_id) key, so
# only pairs present in all three versions remain. Columns shared by the first
# two tables receive the '_df1' / '_df2' suffixes; the plots below read the
# third table's score from the unsuffixed 'tp_score' column.
merged_df = (
    df1_pairs
    .merge(df2_pairs, on=['disease_id', 'drug_id'], how='inner', suffixes=('_df1', '_df2'))
    .merge(df3_pairs, on=['disease_id', 'drug_id'], how='inner')
)
merged_df

### Scatter Plot of Drug-Disease Pairs Prediction

In [None]:
# Scatter plot comparing the TP score of each model version, restricted to the
# first 100 rows of merged_df. The x coordinate is the row index, which stands
# in for the individual drug-disease pair.

## First 100 rows (drug-disease pairs) ##
df_100 = merged_df.head(100)

# Create scatter plot
plt.figure(figsize=(10, 6))

# One series per model version. All legend labels use the same 'v<version>'
# form (previously two of them were written 'KG...' and one 'v...').
plt.scatter(df_100.index, df_100['tp_score_df1'], color='#DC267F', marker='o', label='v2.8.0.1')
plt.scatter(df_100.index, df_100['tp_score_df2'], color='#648FFF', marker='s', label='v2.8.3')
plt.scatter(df_100.index, df_100['tp_score'], color='#FFB000', marker='D', label='v2.8.6')

# Adding title and labels
#plt.title('RF Prediction Probability Performance')
plt.xlabel('Disease ID')
plt.ylabel('Probability (TP Score)')
# Fixed tick positions every 10 rows, with the labels rotated, in one call
plt.xticks([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100], rotation=45)
plt.legend(loc='lower right')
plt.tight_layout()

# Save the figure to the current working directory
plt.savefig('Fig3.png', dpi=1000)

# Display plot
plt.show()

In [None]:
## All rows of merged_df (154 drug-disease pairs were requested) ##
plt.figure(figsize=(10, 6))

# One scatter series per model version; the row index is used as the x
# coordinate, acting as a proxy for disease_id.
for score_col, series_color, series_marker, version in (
    ('tp_score_df1', '#DC267F', 'o', 'v2.8.0.1'),
    ('tp_score_df2', '#648FFF', 's', 'v2.8.3'),
    ('tp_score', '#FFB000', 'D', 'v2.8.6'),
):
    plt.scatter(merged_df.index, merged_df[score_col],
                color=series_color, marker=series_marker, label=version)

# Title, axis labels, and rotated ticks every 10 rows from 0 to 150
plt.title('Scatter plot of tp_scores vs disease_id')
plt.xlabel('Disease ID')
plt.ylabel('Tp Scores')
plt.xticks(list(range(0, 151, 10)), rotation=45)
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
## First 10 rows of merged_df ##
df_10 = merged_df.head(10)

# Create scatter plot
plt.figure(figsize=(10, 6))

# One scatter series per model version, plotted against the row index
for score_col, series_color, series_marker, version in (
    ('tp_score_df1', '#DC267F', 'o', 'v2.8.0.1'),
    ('tp_score_df2', '#648FFF', 's', 'v2.8.3'),
    ('tp_score', '#FFB000', 'D', 'v2.8.6'),
):
    plt.scatter(df_10.index, df_10[score_col],
                color=series_color, marker=series_marker, label=version)

# Title, axis labels, and one tick per row (0-9)
plt.title('Scatter plot of tp_scores vs disease_id')
plt.xlabel('Disease ID')
plt.ylabel('Tp Scores')
plt.xticks(rotation=45)
plt.xticks(list(range(10)))
#plt.yticks([0.80,0.85,0.90,0.95,1])
plt.legend(loc='lower left')
plt.tight_layout()
plt.show()

In [None]:
# Box plot showing the distribution of TP-score prediction probabilities for
# each of the three model versions.

plt.figure(figsize=(8, 6), dpi=300)  # Adjust the DPI for high resolution

# Creating the box plot
# NOTE(review): newer Matplotlib releases rename `labels` to `tick_labels`
# (see the boxplot API docs); `labels` is kept here so older versions keep
# working -- confirm against the installed version.
boxplot = plt.boxplot([merged_df['tp_score_df1'], merged_df['tp_score_df2'], merged_df['tp_score']],
                      labels=['v2.8.0.1', 'v2.8.3', 'v2.8.6'],
                      patch_artist=True,  # Draw boxes as patches so they can be filled
                      medianprops=dict(color='black'))  # Median line color

# One fill color per model version, matching the scatter plots
colors = ['#DC267F', '#648FFF', '#FFB000']
for patch, color in zip(boxplot['boxes'], colors):
    patch.set_facecolor(color)  # Box color
    patch.set_edgecolor('black')  # Box edge color

for whisker, cap in zip(boxplot['whiskers'], boxplot['caps']):
    whisker.set_color('black')  # Whisker color
    whisker.set_linewidth(1.5)  # Whisker line width
    cap.set_color('black')  # Cap color

# Outliers are drawn as markers without a connecting line, so the line
# `color` property alone does not change how they look. Set the marker edge
# and face colors explicitly so the outliers really are red.
for flier in boxplot['fliers']:
    flier.set(marker='o', markeredgecolor='red', markerfacecolor='red', alpha=0.5)

# Adding title and labels
plt.title('RF Prediction Probability Performance', fontsize=16)  # Increase font size for title
plt.ylabel('Probability (TP Score)', fontsize=14)  # Increase font size for ylabel
plt.xticks(fontsize=12)  # Increase font size for x-axis labels
plt.yticks(fontsize=12)  # Increase font size for y-axis labels

# Add gridlines
plt.grid(True)

# Adjust the layout for better spacing
plt.tight_layout()

# Save the plot as a high-quality image
#plt.savefig('RF_Prediction_Probability_Performance_Boxplot.png', dpi=300)

plt.show()
