In [1]:
from math import sqrt
from ipywidgets import IntProgress
from IPython.display import display
from ipywidgets import FileUpload
from IPython.display import display, clear_output
from ipywidgets import BoundedFloatText, widgets, VBox, HBox
import plotly.express as px
import numpy as np
import pandas as pd
from pycaret.regression import *
import os
from sklearn.metrics import r2_score, mean_squared_error, root_mean_squared_error

## Proffinity ML Prediction Module

In [2]:
# Shared state filled in by the widget callbacks: the selected input CSV,
# the selected model name, and the stem used when naming output files.
content = dict(input="", model="", outname="")

In [3]:
# Output area that hosts the model selector and its "load ..." messages;
# it is added to the `scene` layout rather than displayed on its own.
out = widgets.Output(layout={'border': '1px solid black','width':'1000px'})
#out

In [4]:
def load_ipywiget_model(widget):
    """Record the model picked in the model selector and report it in `out`.

    Parameters
    ----------
    widget : mapping
        Change notification from the selector's ``value`` trait;
        ``widget['new']`` holds the currently selected file names.

    The first selected file name, minus its final extension, is stored in
    ``content['model']``. An empty selection leaves the stored model untouched.
    """
    selected = list(widget['new'])
    if not selected:
        # Clearing the selection also fires the observer; keep the previous
        # model instead of failing with IndexError on an empty list.
        return
    # Strip only the final extension so model names that contain dots
    # (e.g. "model.v2.pkl") are not truncated at the first dot.
    content['model'] = os.path.splitext(selected[0])[0]
    with out:
        display("load "+content['model'])

In [5]:
def on_button_clicked(b):
    """Show a selector listing the ``.pkl`` files of the working directory.

    Parameters
    ----------
    b :
        Argument supplied by the button's click handler; not used.

    The selector replaces whatever `out` currently shows, and every change of
    its selection is forwarded to `load_ipywiget_model`.
    """
    model_dir = './'
    pickle_files = []
    for entry in os.listdir(model_dir):
        if entry.endswith('.pkl'):
            pickle_files.append(entry)

    selector = widgets.SelectMultiple(
        options=pickle_files,
        description='models',
        disabled=False,
        layout={'height': '100px', 'width': '40%'},
    )
    selector.observe(load_ipywiget_model, names='value')

    with out:
        clear_output()
        display(selector)

In [6]:
# "Load Model" button: a click runs on_button_clicked, which lists the
# available model files inside `out`.
button = widgets.Button(description="Load Model")
button.on_click(on_button_clicked)
#button

In [7]:
# Output area that hosts the input-file selector and the preview of the
# loaded CSV; it is added to the `scene` layout.
out2 = widgets.Output(layout={'border':'1px solid black','width':'1000px'})
#out2

In [8]:
def load_input(widget):
    """Record the input CSV picked in the input selector and preview it in `out2`.

    Parameters
    ----------
    widget : mapping
        Change notification from the selector's ``value`` trait;
        ``widget['new']`` holds the currently selected file names.

    The first selected file name is stored in ``content['input']`` and the
    same name without its final extension in ``content['outname']``. An empty
    selection leaves both untouched.
    """
    selected = list(widget['new'])
    if not selected:
        # Clearing the selection also fires the observer; keep the previous
        # input instead of failing with IndexError on an empty list.
        return

    input_name = selected[0]
    content['input'] = input_name
    # Strip only the final extension so dotted file names keep their full
    # stem and different inputs do not map onto the same output name.
    content['outname'] = os.path.splitext(input_name)[0]
    df = pd.read_csv(content["input"], header=0, index_col=0)

    with out2:
        display("load "+content['input'])
        display(df)

In [9]:
def on_button2_clicked(b):
    """Show a selector listing the ``ppi_index_*.csv`` files of the working directory.

    Parameters
    ----------
    b :
        Argument supplied by the button's click handler; not used.

    The selector replaces whatever `out2` currently shows, and every change
    of its selection is forwarded to `load_input`.
    """
    input_dir = './'
    candidate_inputs = []
    for entry in os.listdir(input_dir):
        if entry.endswith('.csv') and entry.startswith('ppi_index_'):
            candidate_inputs.append(entry)

    selector = widgets.SelectMultiple(
        options=candidate_inputs,
        description='inputs',
        disabled=False,
        layout={'height': '100px', 'width': '40%'},
    )
    selector.observe(load_input, names='value')

    with out2:
        clear_output()
        display(selector)

In [10]:
# "Load Inputs" button: a click runs on_button2_clicked, which lists the
# candidate input CSV files inside `out2`.
button2 = widgets.Button(description="Load Inputs")
button2.on_click(on_button2_clicked)
#button2

In [11]:
# Output area for the prediction figures, metrics and "save to ..." messages;
# it is added to the `scene` layout.
out3 = widgets.Output(layout={'border':'1px solid black','width':'1000px'})
#out3

In [12]:
def on_button3_clicked(b):
    """Predict for the selected input and plot it against the SKEMPIv2 reference.

    Parameters
    ----------
    b :
        Argument supplied by the button's click handler; not used.

    Loads the model named in ``content['model']``, predicts for the CSV named
    in ``content['input']``, writes the ``kd`` and ``prediction_label``
    columns to ``<outname>_pred.csv``, and shows in `out3` the combined
    table, a scatter plot, and the r2 / mse of the input predictions.
    """
    if not content['model'] or not content['input']:
        # Without both selections load_model/read_csv would fail inside the
        # callback and the user would see nothing at all.
        with out3:
            clear_output()
            display("select a model and an input file before running the prediction")
        return

    saved_ml = load_model(content['model'])
    df_test = pd.read_csv(content["input"], header=0, index_col=0)
    pred_test = predict_model(saved_ml, data=df_test)

    # .copy() gives an independent frame, so adding the 'data' column below
    # does not trigger pandas' SettingWithCopyWarning.
    pred_test_output = pred_test[['kd', 'prediction_label']].copy()
    # Saved before 'data' is added, so the file holds only the two columns.
    pred_test_output.to_csv(content["outname"]+"_pred.csv")
    pred_test_output['data'] = content['outname']

    pred_ground = pd.read_csv('./background/ppi_index_extract_skempiv2_pred.csv', header=0, index_col=0)
    pred_ground_output = pred_ground[['kd', 'prediction_label']].copy()
    pred_ground_output['data'] = 'skempiv2-ref'

    pred_output = pd.concat([pred_ground_output, pred_test_output])

    fig = px.scatter(pred_output, x="kd", y="prediction_label", color="data", opacity=0.6)
    # Metrics are computed on the input predictions only, not the reference.
    r2 = r2_score(pred_test_output['kd'], pred_test_output['prediction_label'])
    mse = mean_squared_error(pred_test_output['kd'], pred_test_output['prediction_label'])

    with out3:
        clear_output()
        display(pred_output)
        display(fig)
        display("r2_score:"+str(r2))
        display("mse:"+str(mse))
        display("save to "+content["outname"]+"_pred.csv")

In [13]:
def on_button3_clickedv2(b):
    """Predict for the selected input and compare it with the model's reference set.

    Parameters
    ----------
    b :
        Argument supplied by the button's click handler; not used.

    Rows of the input whose ``kd`` is missing (or all rows, when the input
    has no ``kd`` column) form the test set; rows with a ``kd`` value form
    the validation set. The reference predictions are read from
    ``./background/<tag>_pred.csv``, where ``<tag>`` is the third
    ``_``-separated field of the model name.

    Test set: histogram of predictions against the reference plus
    mean/max/min, saved to ``pred_<outname>_by_<tag>_tset.csv``.
    Validation set: scatter and grouped bar chart of ``kd`` vs. prediction
    plus r2 / mse, saved to ``pred_<outname>_by_<tag>_vset.csv``.
    Everything is rendered in `out3`.
    """
    if not content['model'] or not content['input']:
        # Without both selections load_model/read_csv would fail inside the
        # callback and the user would see nothing at all.
        with out3:
            clear_output()
            display("select a model and an input file before running the prediction")
        return

    model_name_fields = content['model'].split('_')
    if len(model_name_fields) < 3:
        # The reference file name is derived from the third field; report
        # the problem instead of failing with IndexError.
        with out3:
            clear_output()
            display("cannot locate reference predictions for model "+content['model'])
        return
    model_tag = model_name_fields[2]
    vset_csv = "pred_"+content["outname"]+"_by_"+model_tag+"_vset.csv"
    tset_csv = "pred_"+content["outname"]+"_by_"+model_tag+"_tset.csv"

    saved_ml = load_model(content['model'])
    df_test = pd.read_csv(content["input"], header=0, index_col=0)

    # Split into test set (no ground truth) and validation set (kd known).
    if 'kd' in df_test.columns:
        kd_missing = df_test['kd'].isna()
        df_tset = df_test[kd_missing].drop(columns=['kd'])
        df_vset = df_test[~kd_missing]
    else:
        # No kd column at all: every row is a test row.
        df_tset = df_test
        df_vset = df_test.iloc[0:0]

    has_tset = not df_tset.empty
    has_vset = not df_vset.empty

    # The same reference file serves both branches; read it once.
    if has_tset or has_vset:
        df_rset_pred = pd.read_csv('./background/'+model_tag+'_pred.csv', header=0, index_col=0)

    if has_tset:
        df_tset_pred = predict_model(saved_ml, data=df_tset)
        # .copy() so adding 'data' does not raise SettingWithCopyWarning.
        df_tset_output = df_tset_pred[['prediction_label']].copy()
        df_tset_output['data'] = content['outname']
        df_rset_hist = df_rset_pred[['prediction_label']].copy()
        df_rset_hist['data'] = 'skempiv2-ref'
        df_tset_final_output = pd.concat([df_rset_hist, df_tset_output])

        tset_mean = df_tset_output['prediction_label'].mean()
        tset_max = df_tset_output['prediction_label'].max()
        tset_min = df_tset_output['prediction_label'].min()
        rset_mean = df_rset_hist['prediction_label'].mean()
        rset_max = df_rset_hist['prediction_label'].max()
        rset_min = df_rset_hist['prediction_label'].min()

        # Overlaid, percent-normalised histograms of test vs. reference predictions.
        fig_tset = px.histogram(df_tset_final_output, x="prediction_label", color="data",
                                barmode="overlay", histnorm="percent", width=600, height=400)

    if has_vset:
        df_vset_pred = predict_model(saved_ml, data=df_vset)
        df_vset_output = df_vset_pred[['kd', 'prediction_label']].copy()

        # Build the bar-chart frame separately so its 'id' column does not
        # leak into the frame used for the scatter plot and the metrics.
        df_vset_bar = df_vset_output.assign(id=df_vset_output.index)
        df_vset_bar2 = pd.melt(df_vset_bar, id_vars=["id"], value_vars=["kd", "prediction_label"])
        fig_vset_bar = px.bar(df_vset_bar2, x='id', y='value', color='variable',
                              barmode='group', width=600, height=400)

        df_vset_output['data'] = content['outname']
        df_rset_scatter = df_rset_pred[['kd', 'prediction_label']].copy()
        df_rset_scatter['data'] = model_tag
        df_vset_final_output = pd.concat([df_rset_scatter, df_vset_output])
        fig_vset = px.scatter(df_vset_final_output, x="kd", y="prediction_label", color="data",
                              opacity=0.6, width=600, height=400)

        vset_r2 = round(r2_score(df_vset_output['kd'], df_vset_output['prediction_label']), 3)
        rset_r2 = round(r2_score(df_rset_scatter['kd'], df_rset_scatter['prediction_label']), 3)
        vset_mse = round(mean_squared_error(df_vset_output['kd'], df_vset_output['prediction_label']), 3)
        rset_mse = round(mean_squared_error(df_rset_scatter['kd'], df_rset_scatter['prediction_label']), 3)

    with out3:
        clear_output()
        if not has_tset and not has_vset:
            display("input file contains no rows to predict")
        if has_vset:
            display(fig_vset)
            display("validation set r2_score:"+str(vset_r2)+" / reference set r2_score:"+str(rset_r2))
            display(fig_vset_bar)
            display("validation set mse:"+str(vset_mse)+" / reference set mse:"+str(rset_mse))
            df_vset_pred.to_csv(vset_csv)
            display("save to "+vset_csv)
        if has_tset:
            display(fig_tset)
            display("test set mean:"+str(tset_mean)+" max:"+str(tset_max)+" min:"+str(tset_min))
            display("reference set mean:"+str(rset_mean)+" max:"+str(rset_max)+" min:"+str(rset_min))
            df_tset_pred.to_csv(tset_csv)
            display("save to "+tset_csv)

In [14]:
# "Prediction" button: a click runs on_button3_clickedv2 (the earlier
# on_button3_clicked is not wired to any button).
button3 = widgets.Button(description="Prediction")
button3.on_click(on_button3_clickedv2)
#button3

In [15]:
# Two-column layout: model and input pickers on the left, prediction output
# and its button on the right.
scene = HBox([VBox([out,button,out2,button2]),VBox([out3,button3]) ])

In [16]:
# Render the assembled widget layout.
display(scene)

HBox(children=(VBox(children=(Output(layout=Layout(border_bottom='1px solid black', border_left='1px solid bla…

### Instruction
step1. Load Model: load an Extra Trees regressor ML model for binding-affinity prediction. Available options:
<br>
- saved_model_skempiv2-ubv-denovo-pdbb.pkl: trained on SKEMPIv2+DeNovo+PDBBind dataset 
- saved_model_skempiv2-ubv-denovo.pkl: trained on SKEMPIv2+DeNovo dataset 
- saved_model_skempiv2.pkl: trained on SKEMPIv2 dataset
- customized model: see the ml_regressor_train.ipynb notebook for how to train a customized model
<br>

step2. Load Inputs: load an extracted feature matrix from module 1. The input file name should start with "ppi_index_extract_..."
<br>
step3. Prediction: show prediction output.
<br>
##### Note: if the prediction does not show, check the input "ppi_index_extract..." file and make sure there are no duplicate tables/headers within a single CSV file.