In [1]:
from __future__ import annotations

import json
import os
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
# Raise the pandas display limit so cell outputs show up to 500 rows before truncating.
pd.set_option('display.max_rows', 500)

### Define quality criteria

In [2]:
@dataclass
class TestCriteria:
    """Thresholds a system must satisfy to be flagged as high quality.

    Every threshold is an annotated dataclass field, so each one can be
    overridden per instance, e.g. ``TestCriteria(max_entry_r=0.3)``.
    ``get_high_quality_systems`` applies ``max_*`` values as inclusive upper
    bounds and ``min_*`` values as inclusive lower bounds.
    """

    # --- Thresholds that were already constructor parameters. ---------------
    # Their relative order is kept so positional construction keeps its meaning.
    max_entry_resolution: float = 3.5
    ligand_min_average_occupancy: float = 0.8
    ligand_min_average_rscc: float = 0.8
    ligand_max_average_rsr: float = 0.3
    pocket_min_average_occupancy: float = 0.8
    pocket_min_average_rscc: float = 0.8
    pocket_max_average_rsr: float = 0.3

    # --- Thresholds that used to be un-annotated class attributes. ----------
    # Without an annotation @dataclass ignores an attribute: it could not be
    # passed to the constructor and was absent from repr/eq. Defaults unchanged.
    # Entry-level thresholds
    max_entry_r: float = 0.4
    max_entry_rfree: float = 0.45
    max_entry_r_minus_rfree: float = 0.05
    # Ligand-level thresholds
    ligand_max_num_unresolved_heavy_atoms: int = 0
    ligand_max_alt_count: int = 1
    ligand_max_percent_outliers_clashes: float = 0
    # Pocket-level thresholds
    pocket_max_num_unresolved_heavy_atoms: int = 0
    pocket_max_alt_count: int = 1
    pocket_max_percent_outliers_clashes: float = 100


def get_high_quality_systems(
    row: pd.Series,
    criteria: TestCriteria
) -> bool:
    """Tell whether one annotation-table row passes every quality threshold.

    Args:
        row: A single row carrying the ``system_type``, ``entry_*``,
            ``system_ligand_*``, ``system_pocket_*`` and
            ``system_num_covalent_ligands`` values read below.
        criteria: Object exposing the threshold attributes of ``TestCriteria``.

    Returns:
        ``True`` only for a "holo" row whose ``entry_r`` and
        ``system_ligand_average_rscc`` are not ``None`` and for which every
        entry, ligand and pocket check holds; ``False`` otherwise.
    """
    # Only holo systems are eligible.
    if row.system_type != "holo":
        return False
    # Rows without an R-factor or a ligand RSCC are rejected outright.
    if row.entry_r is None or row.system_ligand_average_rscc is None:
        return False

    entry_checks = [
        row.entry_resolution <= criteria.max_entry_resolution,
        row.entry_r <= criteria.max_entry_r,
        row.entry_rfree <= criteria.max_entry_rfree,
        row.entry_r_minus_rfree <= criteria.max_entry_r_minus_rfree,
    ]
    ligand_checks = [
        # The allowance for unresolved heavy atoms grows by the number of covalent ligands.
        row.system_ligand_num_unresolved_heavy_atoms
        <= row.system_num_covalent_ligands + criteria.ligand_max_num_unresolved_heavy_atoms,
        # NOTE: max_alt_count is a misnomer - this counts the total number of conformers!
        row.system_ligand_max_alt_count <= criteria.ligand_max_alt_count,
        row.system_ligand_average_occupancy >= criteria.ligand_min_average_occupancy,
        row.system_ligand_average_rscc >= criteria.ligand_min_average_rscc,
        row.system_ligand_average_rsr <= criteria.ligand_max_average_rsr,
        row.system_ligand_percent_outliers_clashes <= criteria.ligand_max_percent_outliers_clashes,
    ]
    pocket_checks = [
        row.system_pocket_num_unresolved_heavy_atoms <= criteria.pocket_max_num_unresolved_heavy_atoms,
        row.system_pocket_max_alt_count <= criteria.pocket_max_alt_count,
        row.system_pocket_average_occupancy >= criteria.pocket_min_average_occupancy,
        row.system_pocket_average_rscc >= criteria.pocket_min_average_rscc,
        row.system_pocket_average_rsr <= criteria.pocket_max_average_rsr,
        row.system_pocket_percent_outliers_clashes <= criteria.pocket_max_percent_outliers_clashes,
    ]
    return bool(np.logical_and.reduce(entry_checks + ligand_checks + pocket_checks))

# Default thresholds for the high-quality filter.
quality_config = TestCriteria()

# Load the annotation table.
df = pd.read_parquet("gs://plinder-collab-bucket/2024-04/v1/index/annotation_table.parquet")
# Per-system sum of `ligand_is_covalent`, broadcast back onto every row of that system.
df["system_num_covalent_ligands"] = df.groupby("system_id")["ligand_is_covalent"].transform("sum")
# Row-wise quality flag; `criteria` is forwarded by DataFrame.apply to the function.
df["passes_quality"] = df.apply(get_high_quality_systems, axis=1, criteria=quality_config)

# IDs of every system with at least one row that passed all checks.
all_systems_passing_quality = set(df.loc[df["passes_quality"], "system_id"])



### Literature reference values for DiffDock performance on other datasets

In [None]:
# Externally reported success rates that are written into the merged results table.
# Layout: split name -> test-set tag -> success-rate column -> value.
# `get_all_merged_pred_dataset` assigns each value to the rows whose `Split` and
# `Test_tag` equal the first two key levels; np.nan marks values that are not available.
# NOTE(review): values appear to be percentages, matching the 0-100 scale produced by
# `evaluate_performance` - confirm against the cited reference.
imputed_diffdock_performance_pdbbind = {
    # Reference: https://arxiv.org/abs/2402.18396
    "PDBBIND-TIME": {
        "Split-test": {"SR-1-mean": 35.0  , "SR-1-std": 0.00, "SR-10-mean": 48, "SR-10-std": np.nan},
        "PoseBusters": {"SR-1-mean": 38.0, "SR-1-std": 0.00, "SR-10-mean": np.nan, "SR-10-std":np.nan}},
    }

In [None]:
def evaluate_performance(
    prediction_csv: Path,
    target_sys_ids: list[str],
    test_set_tag="Split-test",
    topn: int = 1,
    in_target_list: bool = True,
    rmsd_threshold: float = 2,
    use_all: bool = False) -> tuple[int, float]:
    """
    Compute the top-N success rate from one predictions CSV.

    Args:
        prediction_csv: Path (or anything ``pd.read_csv`` accepts) of a CSV
            with ``id``, ``rank`` and ``rmsd`` columns.
        target_sys_ids: System IDs used to restrict the evaluation.
        test_set_tag: ``"PoseBusters"`` evaluates every system in the CSV and
            ignores ``target_sys_ids``; any other tag applies the selection below.
        topn: Keep only predictions with ``rank <= topn``.
        in_target_list: When ``use_all`` is false, ``True`` keeps the systems
            in ``target_sys_ids`` and ``False`` keeps the systems NOT in it.
        rmsd_threshold: A system is a success when at least one kept
            prediction has ``rmsd <= rmsd_threshold``.
        use_all: Evaluate every system, ignoring ``target_sys_ids``.

    Returns:
        ``(number_of_systems, success_rate_in_percent)``. When no system is
        selected the result is ``(0, nan)`` instead of a ZeroDivisionError.
    """
    prediction_df = pd.read_csv(prediction_csv)

    # Select systems. `use_all` takes precedence; otherwise `in_target_list` picks
    # the target set or its complement. (Previously the complement was only
    # reachable with use_all=True, and in_target_list=False with use_all=False
    # crashed on an unbound variable.)
    if test_set_tag == "PoseBusters" or use_all:
        selected_df = prediction_df
    else:
        is_target = prediction_df.id.isin(target_sys_ids)
        selected_df = prediction_df[is_target if in_target_list else ~is_target]

    # Keep the top-N ranked predictions of each system.
    selected_df = selected_df[selected_df["rank"] <= topn]
    num_systems = len(selected_df.id.unique())
    if num_systems == 0:
        # Nothing to evaluate: report an undefined rate rather than dividing by zero.
        return 0, float("nan")

    # A system counts as a success if any of its kept predictions is within the threshold.
    success_df = selected_df[selected_df.rmsd <= rmsd_threshold]
    return num_systems, len(success_df.id.unique()) / num_systems * 100


def compare_performance(
    dict_of_prediction_csvs: dict[str, list[Path]],
    target_sys_ids: list[str],
    test_set_tag="Split-test",
    topn = 1,
    in_target_list: bool = True,
    rmsd_threshold: int = 2,
    use_all: bool = False) -> pd.DataFrame:
    """
    Compare performance between the splits defined by `dict_of_prediction_csvs`.

    Args:
        dict_of_prediction_csvs: Maps a split name to the prediction CSVs of
            that split (one CSV per repeat); each CSV is scored with
            `evaluate_performance`.
        target_sys_ids, test_set_tag, topn, in_target_list, rmsd_threshold,
            use_all: Forwarded unchanged to `evaluate_performance`.

    Returns:
        One row per split with columns ``Split`` (upper-cased name), ``Size``
        (system count reported for the last CSV of the split), ``Test_tag``,
        ``SR-{topn}-mean`` and ``SR-{topn}-std`` (mean / standard deviation of
        the success rate over the split's CSVs). A split without any CSV gets
        ``Size`` 0 and NaN statistics.
    """
    rows = []
    for split, list_of_prediction_csvs in dict_of_prediction_csvs.items():
        # Reset per split so an empty split neither fails on an unbound name
        # nor inherits the size of the previous split.
        size = 0
        results = []
        for prediction_csv in list_of_prediction_csvs:
            size, perf = evaluate_performance(
                prediction_csv,
                target_sys_ids,
                test_set_tag,
                topn, in_target_list, rmsd_threshold, use_all)
            results.append(perf)
        if results:
            mean_sr, std_sr = np.mean(results), np.std(results)
        else:
            # No CSV matched this split: statistics are undefined.
            mean_sr = std_sr = np.nan
        rows.append([split.upper(), size, test_set_tag, mean_sr, std_sr])
    return pd.DataFrame(rows, columns=["Split", "Size", "Test_tag", f"SR-{topn}-mean",  f"SR-{topn}-std"])

def get_top1_and_top10_performance(
    dict_of_prediction_csvs: dict[Path],
    target_sys_ids: list[str],
    test_set_tag="Split-test",
    in_target_list: bool =True,
    rmsd_threshold: int = 2,
    use_all: bool = False) -> pd.DataFrame:
    """
    Build one table holding both the top-1 and the top-10 success rates per split.

    `compare_performance` is run once with ``topn=1`` and once with ``topn=10``;
    the ``SR-10-mean`` / ``SR-10-std`` columns of the second run are joined onto
    the first run using the ``Split`` column.
    """
    def _success_rates(topn):
        # Same inputs for both runs; only the rank cut-off differs.
        return compare_performance(
            dict_of_prediction_csvs,
            target_sys_ids,
            test_set_tag,
            topn=topn,
            in_target_list=in_target_list,
            rmsd_threshold=rmsd_threshold,
            use_all=use_all)

    top1_df = _success_rates(1)
    top10_columns = _success_rates(10)[["Split", "SR-10-mean", "SR-10-std"]]
    return top1_df.merge(top10_columns, on="Split")


def get_top1_and_top10_perfromance_stratified(
        performance_df_dict: dict[str, pd.DataFrame],
        tag="",  exclude=None):
    """
    Stack several performance tables, labelling each with its stratum.

    Args:
        performance_df_dict: Maps a stratum label (e.g. a quality class) to
            its performance table.
        tag: Name of the column that receives the stratum label.
        exclude: Optional ``Split`` value whose rows are removed from the result.

    Returns:
        The concatenation of all tables with the extra ``tag`` column. The
        input tables are left untouched (they used to be modified in place).
    """
    df_list = []
    for label, df in performance_df_dict.items():
        # Label a copy so the caller's frame does not silently gain a column.
        labelled_df = df.copy()
        labelled_df[tag] = label
        df_list.append(labelled_df)
    combined_df = pd.concat(df_list)

    if exclude is not None:
        combined_df = combined_df[combined_df.Split != exclude]
    return combined_df

def extract_leakage_from_json(
        leagake_json,  split,
        test_set_tag="Split-test",
        is_high_quality=True,
        links_threshold=0):
    """
    Read one leaked-fraction value for `split` from a leakage JSON file.

    The key looked up under ``data[split]`` is
    ``fraction_leaked_train_<subset>[_hq]_<links_threshold>``, where ``<subset>``
    is the lower-cased text after the last ``-`` of `test_set_tag` and the
    ``_hq`` part is present only when `is_high_quality` is true and the tag is
    ``"Split-test"``.

    Returns:
        ``(quality_label, leaked_fraction)``.
    """
    with open(leagake_json, "r") as handle:
        leakage_by_split = json.load(handle)

    subset = test_set_tag.lower().split('-')[-1]
    hq_part = "hq_" if (is_high_quality and test_set_tag == "Split-test") else ""
    tag_name = f"fraction_leaked_train_{subset}_{hq_part}{links_threshold}"

    # NOTE(review): the label is always "High", even when `is_high_quality` is
    # False - confirm this is intended for the later merge on `Quality`.
    return "High", leakage_by_split[split][tag_name]

def extract_data_size_from_json(
        leagake_json,  split):
    """
    Read the train / validation / test sizes recorded for `split`.

    Returns:
        ``(num_train, num_val, num_test)`` taken from ``data[split]`` of the JSON file.
    """
    with open(leagake_json, "r") as handle:
        split_entry = json.load(handle)[split]
    return tuple(split_entry[key] for key in ("num_train", "num_val", "num_test"))

def make_leakage_dataframe(
        dict_of_all_leakage_metrics,
        list_of_splits, test_set_tag="Split-test",
        is_high_quality=True, links_threshold=0):
    """
    Tabulate the leaked fraction for every (metric, split) combination.

    Args:
        dict_of_all_leakage_metrics: Maps a metric label to its leakage JSON file.
        list_of_splits: Split names to look up in each JSON file.
        test_set_tag, is_high_quality, links_threshold: Forwarded to
            `extract_leakage_from_json`.

    Returns:
        DataFrame with columns ``Metric``, ``Split`` (upper-cased), ``Test_tag``,
        ``Quality`` and ``Leakage``.
    """
    def _row(metric, json_path, split):
        # Look the value up first, then assemble the output row.
        quality, fraction = extract_leakage_from_json(
            json_path, split,
            test_set_tag, is_high_quality,
            links_threshold)
        return [metric, split.upper(), test_set_tag, quality, fraction]

    rows = [
        _row(metric, json_path, split)
        for metric, json_path in dict_of_all_leakage_metrics.items()
        for split in list_of_splits
    ]
    return pd.DataFrame(rows, columns=["Metric", "Split", "Test_tag", "Quality", "Leakage"])


def get_all_merged_pred_dataset(
    dict_of_prediction_csvs: dict[str, dict[str, list[Path]]],
    leakage_json_dict:  dict[str, Path],
    splits: list[str],
    target_sys_ids: list[str],
    reference_dict_to_input,
    rmsd_threshold: int = 2, is_high_quality=True, links_threshold=0):
    """
    Assemble the performance, leakage and merged tables used by the plots.

    Args:
        dict_of_prediction_csvs: Maps a test-set tag (e.g. ``"Split-test"``,
            ``"PoseBusters"``) to a ``{split name: [prediction CSVs]}`` dict.
        leakage_json_dict: Maps a leakage metric label to its JSON file.
        splits: Split names to read from every leakage JSON file.
        target_sys_ids: System IDs treated as high quality.
        reference_dict_to_input: ``{split: {test tag: {column: value}}}`` of
            reference values written into the merged table.
        rmsd_threshold: Success threshold forwarded to the evaluation.
        is_high_quality, links_threshold: Forwarded to `make_leakage_dataframe`.

    Returns:
        ``(final_preds_df, leakage_df, merged_df)``: the stratified performance
        table, the leakage table for both test-set tags, and their outer merge
        on ``Split`` / ``Test_tag`` / ``Quality`` with the reference values filled in.
    """
    final_preds = []
    for test_set_tag, prediction_csv_dict in dict_of_prediction_csvs.items():
        # Build the strata for THIS test set only. The "High+Low" table used to
        # leak in from the previous iteration for PoseBusters (and was unbound
        # when PoseBusters came first).
        strata = {}
        if test_set_tag != "PoseBusters":
            strata["High+Low"] = get_top1_and_top10_performance(
                prediction_csv_dict,
                target_sys_ids,
                test_set_tag=test_set_tag,
                # Previously omitted, so a non-default threshold was silently ignored here.
                rmsd_threshold=rmsd_threshold,
                use_all=True)
        strata["High"] = get_top1_and_top10_performance(
            prediction_csv_dict,
            target_sys_ids,
            test_set_tag=test_set_tag,
            in_target_list=True,
            rmsd_threshold=rmsd_threshold,
            use_all=False)
        final_preds.append(
            get_top1_and_top10_perfromance_stratified(strata, tag="Quality"))
    final_preds_df = pd.concat(final_preds, ignore_index=True).drop_duplicates()

    # Compute leakage for both test-set tags
    split_leakage = make_leakage_dataframe(
        leakage_json_dict,
        splits, test_set_tag="Split-test",
        is_high_quality=is_high_quality, links_threshold=links_threshold)

    posebusters_leakage = make_leakage_dataframe(
        leakage_json_dict,
        splits, test_set_tag="PoseBusters",
        is_high_quality=is_high_quality, links_threshold=links_threshold)
    leakage_df = pd.concat([split_leakage, posebusters_leakage])
    # Outer merge keeps leakage rows of splits that have no predictions of their own.
    merged_df = final_preds_df.merge(leakage_df, how="outer", on=['Split','Test_tag', 'Quality'])

    # Impute literature reference values into the matching Split / Test_tag rows
    for split, per_test_tag in reference_dict_to_input.items():
        for test_tag, success_rates in per_test_tag.items():
            row_mask = (merged_df['Split'] == split) & (merged_df['Test_tag'] == test_tag)
            for sr_tag, val in success_rates.items():
                merged_df.loc[row_mask, sr_tag] = val
    return final_preds_df, leakage_df, merged_df

def plot_diffdock_performance_vs_leakage(
        performance_vs_leakage_df,
        show_only = None, exclude_tanimoto=True, plot_dimension=(400, 650)):
    """
    Scatter the top-10 success rate against leakage, one coloured series per metric.

    Args:
        performance_vs_leakage_df: Frame with at least ``Metric``, ``Leakage``
            and ``SR-10-mean`` columns.
        show_only: Optional ``(column, allowed_values)`` pair; only rows whose
            ``column`` value is in ``allowed_values`` are plotted.
        exclude_tanimoto: When true, rows whose ``Metric`` contains ``"LIGAND"``
            are dropped.
        plot_dimension: ``(height, width)`` passed to the figure.

    The figure is shown directly; nothing is returned.
    """
    marker_symbols = ['circle', 'square', 'star', 'diamond', 'hourglass' , 'pentagon']
    marker_colors = ["lightblue", "pink", "lightgreen", "purple", "goldenrod" ]

    plot_df = performance_vs_leakage_df
    if exclude_tanimoto:
        is_ligand_metric = plot_df.Metric.apply(lambda metric: "LIGAND" in metric)
        plot_df = plot_df[~is_ligand_metric]
    if show_only is not None:
        column, allowed_values = show_only[0], show_only[1]
        plot_df = plot_df[plot_df[column].isin(allowed_values)]

    # One colour and one symbol per metric that is left after filtering.
    n_metrics = len(plot_df.Metric.unique())
    fig = px.scatter(
        plot_df,
        x="Leakage",
        y="SR-10-mean",
        color="Metric",
        trendline="lowess",
        color_discrete_sequence=marker_colors[:n_metrics],
        symbol="Metric",
        symbol_sequence=marker_symbols[:n_metrics],
        height=plot_dimension[0],
        width=plot_dimension[1],
    )

    legend_style = dict(
        x=0,
        y=1,
        traceorder="normal",
        font=dict(size=18, color="black"),
    )
    fig.update_layout(
        {"plot_bgcolor": "rgba(0, 0, 0, 0)"},
        legend=legend_style,
        font=dict(size=18),
        yaxis_title="Success Rate",
        xaxis_title="Fraction of leaked systems",
        yaxis=dict(tickfont=dict(size=18)),
        xaxis=dict(tickfont=dict(size=18)),
        legend_title=None,
    )

    axis_line_style = dict(showline=True, linewidth=2, linecolor='black', color='black')
    fig.update_xaxes(**axis_line_style)
    fig.update_yaxes(**axis_line_style)
    fig.update_traces(
        marker=dict(size=12, line=dict(width=1, color='DarkSlateGrey')),
        selector=dict(mode='markers'),
    )

    # Options for the "download plot" button of the rendered figure.
    export_options = {
        'format': 'png',  # one of png, svg, jpeg, webp
        'filename': 'top10_diffdock_vs_leakage',
        'scale': 6,  # Multiply title/legend/axis/canvas sizes by this factor
    }
    fig.show(config={'toImageButtonOptions': export_options})


def plot_diffdock_performance_bar(performance_dataset, pdbbind_diffdock_line, posebusters_line=None, topn=1):
    """
    Grouped bar chart of the top-N success rate per split, faceted by test set.

    Args:
        performance_dataset: Frame with ``Split``, ``Test_tag``, ``Quality``,
            ``SR-{topn}-mean`` and ``SR-{topn}-std`` columns.
        pdbbind_diffdock_line: y-value of the dashed reference line drawn in
            the first facet.
        posebusters_line: Optional y-value of a dashed reference line drawn in
            the second facet.
        topn: Which success-rate columns to plot (``SR-{topn}-*``).

    Only the PLINDER-V0 / PLINDER-ECOD / PLINDER-TIME splits are shown, and the
    PoseBusters rows labelled "High+Low" are left out. The figure is shown
    directly; nothing is returned.
    """
    color_dict={"High": "#89CFF0", "High+Low": "#c9a0dc"}

    shown_splits = performance_dataset.Split.isin(["PLINDER-V0", "PLINDER-ECOD",  "PLINDER-TIME"])
    posebusters_all = (performance_dataset.Test_tag == "PoseBusters") \
        & (performance_dataset.Quality == "High+Low")
    plot_df = performance_dataset[shown_splits & ~posebusters_all][
        ["Split", f"SR-{topn}-mean", f"SR-{topn}-std", "Quality", "Test_tag"]].drop_duplicates()

    fig = px.bar(
        plot_df,
        x="Split", y=f"SR-{topn}-mean",
        color="Quality",
        barmode='group',
        color_discrete_map=color_dict,
        # NOTE(review): `Quality` is categorical, so this continuous scale is
        # presumably unused - kept as in the original call.
        color_continuous_scale=color_dict,
        facet_col="Test_tag",
        error_y=f"SR-{topn}-std",
        error_y_minus=f"SR-{topn}-std",
        height=400,
        width=1200)
    fig.update_layout(
        {"plot_bgcolor": "rgba(0, 0, 0, 0)",
         "paper_bgcolor": "rgba(0,0,0,0)",
         "title": f"Top-{topn}"},
        font=dict(
            size=18,  # Set the font size here
        ),
        title=dict(font=dict(color="black")),
        legend=dict(font=dict(color="black", size=15)),
        yaxis_title="Success Rate",
        bargap=0.05,
        yaxis=dict(tickfont=dict(size=18)))

    # Dashed reference line in the first facet.
    fig.add_hline(y=pdbbind_diffdock_line, line_dash="dash",
            annotation=dict( x=0.35, y=pdbbind_diffdock_line, font_size=18,  align='left'),
            annotation_text="PDBBind-DiffDock", row=1, col=1, line_color="black"
            )
    if posebusters_line is not None:
        # Dashed reference line in the second facet.
        fig.add_hline(y=posebusters_line, line_dash="dash",
                annotation=dict( x=0.35, y=posebusters_line, font_size=18, align='left'),
                annotation_text="PDBBind-DiffDock", row=1, col=2,
                line_color="black"
                )
    # Facet titles read "Test_tag=<value>"; keep only the value.
    fig.for_each_annotation(lambda a: a.update(text=a.text.replace("Test_tag=", "")))
    fig.update_annotations(font=dict( size=18, color="black"))
    fig.update_xaxes(showline=True, linewidth=2, linecolor='black', color='black')
    fig.update_yaxes(showline=True, linewidth=2, linecolor='black', color='black')

    config = {
        'toImageButtonOptions': {
            'format': 'png', # one of png, svg, jpeg, webp
            # Name the exported image after this chart and the plotted top-N; the
            # previous name was copied from the leakage scatter plot.
            'filename': f'top{topn}_diffdock_performance',
            'scale':6 # Multiply title/legend/axis/canvas sizes by this factor
        }
    }
    fig.show(config=config)


In [None]:
# Root directory of the local PLINDER data. Set the PLINDER_LOCAL_DATA environment
# variable to point at your own copy; otherwise the original default path is used.
ROOT = Path(os.environ.get("PLINDER_LOCAL_DATA", "/Users/yusuf/plinder_local_data"))

### Set paths to the leakage JSON files
> Computed by @Jay

In [None]:
# Leakage metric label -> JSON file holding the leaked fractions for that metric.
leakage_json_dict = {
    label: ROOT / relative_path
    for label, relative_path in (
        ('POCKET SHARED ≥ 30', 'fractions_new/pocket_qcov__30.json'),
        ('POCKET SHARED ≥ 50', 'fractions_new/pocket_qcov__50.json'),
        ('POCKET SHARED ≥ 70', 'fractions_new/pocket_qcov__70.json'),
        ('POCKET LDDT ≥ 30', 'fractions_new/pocket_lddt__30.json'),
        ('POCKET LDDT ≥ 50', 'fractions_new/pocket_lddt__50.json'),
        ('POCKET LDDT ≥ 70', 'fractions_new/pocket_lddt__70.json'),
        ('POCKET LDDT ≥ 90', 'fractions_new/pocket_lddt__90.json'),
        ('PROTEIN SEQSIM ≥ 30', 'fractions_new/protein_seqsim_weighted_sum__30.json'),
        ('PROTEIN SEQSIM ≥ 50', 'fractions_new/protein_seqsim_weighted_sum__50.json'),
        ('PLI SHARED ≥ 20', 'fractions_new/pli_qcov__20.json'),
        ('PLI SHARED ≥ 30', 'fractions_new/pli_qcov__30.json'),
        ('PLI SHARED ≥ 50', 'fractions_new/pli_qcov__50.json'),
        ('PLI SHARED ≥ 70', 'fractions_new/pli_qcov__70.json'),
    )
}

In [None]:
# Split name -> prediction CSVs found under ROOT/testset for that split.
# NOTE(review): no later cell visible here reads `input_dict`; the main call
# repeats the same globs inline.
input_dict = {
    split_name: list(ROOT.glob(pattern))
    for split_name, pattern in (
        ("plinder-v0", "testset/*hard*.csv"),
        ("plinder-time", "testset/*time*.csv"),
        ("plinder-ecod", "testset/*ecod*.csv"),
    )
}

In [None]:
# Score all prediction CSVs, read the leakage fractions and merge both tables.
perf, leakage_df, perf_leakage = get_all_merged_pred_dataset(
    dict_of_prediction_csvs={
        "Split-test": {
            "plinder-v0": list(ROOT.glob("testset/*hard*.csv")),
            "plinder-time": list(ROOT.glob("testset/*time*.csv")),
            "plinder-ecod": list(ROOT.glob("testset/*ecod*.csv")),
        },
        "PoseBusters": {
            "plinder-v0": list(ROOT.glob("pb_paper/*hard*.csv")),
            "plinder-time": list(ROOT.glob("pb_paper/*time*.csv")),
            "plinder-ecod": list(ROOT.glob("pb_paper/*ecod*.csv")),
        },
    },
    leakage_json_dict=leakage_json_dict,
    # Split keys exactly as they are looked up in the leakage JSON files.
    splits=[
        "plinder-ECOD",
        "plinder-v0",
        "plinder-Time",
        "PDBBind-Original",
        "PDBBind-LP",
        "PDBBind-Time",
    ],
    target_sys_ids=all_systems_passing_quality,
    reference_dict_to_input=imputed_diffdock_performance_pdbbind,
    rmsd_threshold=2,
    is_high_quality=True,
    links_threshold=0,
)

#### Plot Diffdock performance vs leakage

In [None]:
# Preview the performance table joined with the leakage table.
perf.merge(
    leakage_df,
    on=["Split", "Test_tag", "Quality"],
).head()

Unnamed: 0,Split,Size,Test_tag,SR-1-mean,SR-1-std,SR-10-mean,SR-10-std,Quality,Metric,Leakage
0,PLINDER-V0,3518,Split-test,18.186341,0.256169,25.669798,0.527778,High,POCKET SHARED ≥ 30,0.160901
1,PLINDER-V0,3518,Split-test,18.186341,0.256169,25.669798,0.527778,High,POCKET SHARED ≥ 50,0.089836
2,PLINDER-V0,3518,Split-test,18.186341,0.256169,25.669798,0.527778,High,POCKET SHARED ≥ 70,0.047466
3,PLINDER-V0,3518,Split-test,18.186341,0.256169,25.669798,0.527778,High,POCKET LDDT ≥ 30,0.031108
4,PLINDER-V0,3518,Split-test,18.186341,0.256169,25.669798,0.527778,High,POCKET LDDT ≥ 50,0.0


#### Top 1 Diffdock performance

In [None]:
# Top-1 success rates with a reference line in each facet.
plot_diffdock_performance_bar(
    perf_leakage,
    pdbbind_diffdock_line=35,  # PDBbind-DiffDock baseline
    posebusters_line=38,
    topn=1,
)

#### Figure 1A and 1B: Top 10 Diffdock performance

In [None]:
# Top-10 success rates; no reference line is drawn in the second facet.
plot_diffdock_performance_bar(
    perf_leakage,
    pdbbind_diffdock_line=48,  # PDBbind-DiffDock baseline
    posebusters_line=None,
    topn=10,
)

#### Figure 2C

In [None]:
# Join performance with leakage, then plot only the four selected metrics.
performance_vs_leakage_df = perf.merge(leakage_df, on=["Split", "Test_tag", "Quality"])
plot_diffdock_performance_vs_leakage(
    performance_vs_leakage_df[
        performance_vs_leakage_df.Metric.isin(
            ["PLI SHARED ≥ 50", "PROTEIN SEQSIM ≥ 30", "POCKET LDDT ≥ 50", "POCKET SHARED ≥ 50"]
        )
    ]
)

#### All metric leakage plot

In [None]:
# Join performance with leakage and plot every metric on a wider canvas.
performance_vs_leakage_df = perf.merge(leakage_df, on=["Split", "Test_tag", "Quality"])
plot_diffdock_performance_vs_leakage(
    performance_vs_leakage_df,
    exclude_tanimoto=True,
    plot_dimension=(400, 1200),
)