In [1]:
from pathlib import Path
import json
import pandas as pd
import numpy as np
from dataclasses import dataclass

In [2]:
@dataclass
class TestCriteria:
    """Thresholds deciding whether a holo system counts as "high quality".

    ``max_*`` fields are inclusive upper bounds and ``min_*`` fields inclusive lower
    bounds on the matching per-system columns checked in ``get_high_quality_systems``.

    Every field carries a type annotation: ``@dataclass`` silently ignores unannotated
    class attributes, so without annotations most thresholds could not be overridden
    through the constructor (e.g. ``TestCriteria(max_entry_r=0.3)``) and were missing
    from ``repr`` / ``dataclasses.asdict``.
    """

    # ENTRY
    max_entry_resolution: float = 3.5
    max_entry_r: float = 0.4
    max_entry_rfree: float = 0.45
    max_entry_r_minus_rfree: float = 0.05
    # LIGAND
    # Extra unresolved heavy atoms allowed on top of one per covalent ligand.
    ligand_max_num_unresolved_heavy_atoms: int = 0
    # Compared against *_max_alt_count, which counts total conformers, so 1 = single conformer.
    ligand_max_alt_count: int = 1
    ligand_min_average_occupancy: float = 0.8
    ligand_min_average_rscc: float = 0.8
    ligand_max_average_rsr: float = 0.3
    ligand_max_percent_outliers_clashes: float = 0
    # POCKET
    pocket_max_num_unresolved_heavy_atoms: int = 0
    pocket_max_alt_count: int = 1
    pocket_min_average_occupancy: float = 0.8
    pocket_min_average_rscc: float = 0.8
    pocket_max_average_rsr: float = 0.3
    pocket_max_percent_outliers_clashes: float = 100


def get_high_quality_systems(
    row: pd.Series,
    criteria: TestCriteria
) -> bool:
    """Return True if an annotation-table row describes a high-quality holo system.

    A system passes only when ``row.system_type == "holo"`` and every entry-, ligand-
    and pocket-level metric lies within the inclusive bounds given by ``criteria``.
    A missing metric (None, NaN or pd.NA) fails the system instead of raising or
    propagating NA.

    Parameters
    ----------
    row : pd.Series
        One row of the annotation table; columns are read as attributes.
    criteria : TestCriteria
        Thresholds to apply.

    Returns
    -------
    bool
        True if every check passes, False otherwise.
    """
    if row.system_type != "holo":
        return False

    # (observed value, inclusive upper bound)
    upper_bounds = [
        # ENTRY
        (row.entry_resolution, criteria.max_entry_resolution),
        (row.entry_r, criteria.max_entry_r),
        (row.entry_rfree, criteria.max_entry_rfree),
        (row.entry_r_minus_rfree, criteria.max_entry_r_minus_rfree),
        # LIGAND: the unresolved-heavy-atom allowance grows by one per covalent ligand.
        (
            row.system_ligand_num_unresolved_heavy_atoms,
            row.system_num_covalent_ligands + criteria.ligand_max_num_unresolved_heavy_atoms,
        ),
        # NOTE: max_alt_count is a misnomer - it counts the total number of conformers.
        (row.system_ligand_max_alt_count, criteria.ligand_max_alt_count),
        (row.system_ligand_average_rsr, criteria.ligand_max_average_rsr),
        (row.system_ligand_percent_outliers_clashes, criteria.ligand_max_percent_outliers_clashes),
        # POCKET
        (row.system_pocket_num_unresolved_heavy_atoms, criteria.pocket_max_num_unresolved_heavy_atoms),
        (row.system_pocket_max_alt_count, criteria.pocket_max_alt_count),
        (row.system_pocket_average_rsr, criteria.pocket_max_average_rsr),
        (row.system_pocket_percent_outliers_clashes, criteria.pocket_max_percent_outliers_clashes),
    ]
    # (observed value, inclusive lower bound)
    lower_bounds = [
        # LIGAND
        (row.system_ligand_average_occupancy, criteria.ligand_min_average_occupancy),
        (row.system_ligand_average_rscc, criteria.ligand_min_average_rscc),
        # POCKET
        (row.system_pocket_average_occupancy, criteria.pocket_min_average_occupancy),
        (row.system_pocket_average_rscc, criteria.pocket_min_average_rscc),
    ]

    # pd.isna catches None, NaN and pd.NA alike; an `is not None` test misses the latter two.
    for value, bound in upper_bounds:
        if pd.isna(value) or pd.isna(bound) or not value <= bound:
            return False
    for value, bound in lower_bounds:
        if pd.isna(value) or pd.isna(bound) or not value >= bound:
            return False
    return True

quality_config = TestCriteria()

In [4]:
df = pd.read_parquet("v1/index/annotation_table.parquet")
# Count of covalent ligands per system, broadcast back onto every row of that system
# (assumes ligand_is_covalent is boolean, so the sum counts True values - verify).
df = df.assign(
    system_num_covalent_ligands=df.groupby("system_id")["ligand_is_covalent"].transform("sum")
)
# Row-wise quality flag; DataFrame.apply forwards the extra keyword argument to the function.
df["passes_quality"] = df.apply(get_high_quality_systems, axis=1, criteria=quality_config)

In [5]:
# IDs of systems with at least one annotation row flagged as passing the quality filter.
passes_quality = set(df.loc[df["passes_quality"], "system_id"])

In [6]:
# Split tables keyed by the name used throughout the rest of the notebook. Each table is
# read later for its "system_id" and "split" columns.

# Workshop splits: only files whose stem is listed here are loaded, under these names.
# ('diffdock_pdbind_nolig_or_rec' -> 'PDBBind-Time' is not loaded.)
replace_keys = {
    "ecod_split": "plinder-ECOD",
    "plinder_v0_no_posbuster": "plinder-v0",
    "time_split": "plinder-Time",
}
# Filter on the file stem *before* reading, so unused or non-CSV files in the directory are
# never parsed. Sort so that the insertion order of split_tables (which every later loop and
# printout follows) does not depend on the filesystem's directory-listing order.
split_tables = {
    replace_keys[path.stem]: pd.read_csv(path)
    for path in sorted(Path("v1/workshop_split/").iterdir())
    if path.stem in replace_keys
}
# Full plinder splits from batch 6 (batch 5 is not loaded), keyed by file stem.
for x in sorted(Path("v1/splits/batch_6/").iterdir()):
    split_tables[x.stem] = pd.read_parquet(x)
# External benchmark splits.
split_tables["pdbbind"] = pd.read_csv("v1/other_splits/pdbbind/pdbbind.csv")
split_tables["pdbbind_lp"] = pd.read_csv("v1/other_splits/pdbbind_lp/pdbbind_lp.csv")
split_tables["equibind"] = pd.read_csv("v1/other_splits/equibind/equibind.csv")
split_tables["dockgen"] = pd.read_csv("v1/other_splits/dockgen/dockgen.csv")

In [7]:
# System IDs listed as missing; they are subtracted from every split and from PoseBusters.
missing = {*pd.read_csv("v1/missing_pdb_ids/missing_pdb_ids.csv")["system_id"]}

In [8]:
# PoseBusters system IDs, one per line, minus systems in `missing`.
# Blank lines are skipped so "" never enters the set: it would inflate len(posebusters),
# the denominator of the train-vs-PoseBusters leakage fraction.
# NOTE(review): every non-blank line is treated as an ID - if this .csv has a header row,
# that header ends up in the set too; confirm the file format.
with open("v1/other_splits/posebusters/posebuster_system_ids.csv") as f:
    posebusters = {line.strip() for line in f if line.strip()}.difference(missing)

In [9]:
# Output directory for the per-(metric, threshold) leakage-fraction JSON files.
fractions_dir = Path("v1") / "fractions"
fractions_dir.mkdir(exist_ok=True)

In [10]:
# Leakage definition: for each (metric, threshold), a query system counts as "leaked" when
# the precomputed table v1/leakage/{split}__{metric}__{pair}.parquet has a row for it whose
# "similarity" is >= threshold (inclusive, although the later table labels print ">").
metric_thresholds = [
    ("pli_qcov", 50),
    ("pocket_qcov", 50),
    ("pocket_lddt", 50),
    ("protein_lddt_qcov_weighted_sum", 50),
    ("protein_seqsim_weighted_sum", 30),
    ("tanimoto_similarity_max", 30),
]
SIMILARITY_COLUMN = "similarity"


def _leaked_systems(pair_table: pd.DataFrame, threshold: float, queries: set) -> set:
    """Return the systems in ``queries`` having a row in ``pair_table`` with similarity >= threshold."""
    is_leaked = (pair_table[SIMILARITY_COLUMN] >= threshold) & pair_table["query_system"].isin(queries)
    return set(pair_table.loc[is_leaked, "query_system"])


def _safe_fraction(numerator: int, denominator: int) -> float:
    """Return numerator / denominator, or NaN (instead of raising) when the denominator is 0."""
    return numerator / denominator if denominator else float("nan")


# train/val/test system sets per split, minus missing systems. They do not depend on the
# metric, so they are built once instead of once per metric.
split_sets = {
    name: {
        part: set(table.loc[table["split"] == part, "system_id"]).difference(missing)
        for part in ("train", "val", "test")
    }
    for name, table in split_tables.items()
}

fractions_all = {}
for metric, threshold in metric_thresholds:
    print(metric, threshold)
    fractions = {}
    for name, parts in split_sets.items():
        train, val, test = parts["train"], parts["val"], parts["test"]
        fractions[name] = {
            "num_test": len(test),
            "num_train": len(train),
            "num_val": len(val),
            "fraction_hq": _safe_fraction(len(test & passes_quality), len(test)),
        }
        # Query set for each comparison. The reference side (train or val) is not used here;
        # presumably it is already encoded in the per-pair leakage file - verify.
        query_sets = {
            "train_test": test,
            "train_val": val,
            "val_test": test,
            "train_posebusters": posebusters,
        }
        for pair_name, queries in query_sets.items():
            pair_table = pd.read_parquet(f"v1/leakage/{name}__{metric}__{pair_name}.parquet")
            leaked = _leaked_systems(pair_table, threshold, queries)
            fractions[name][f"fraction_leaked_{pair_name}_0"] = _safe_fraction(len(leaked), len(queries))
            if pair_name.endswith("_test"):
                # High-quality variant: numerator and denominator restricted to passes_quality.
                fractions[name][f"fraction_leaked_{pair_name}_hq_0"] = _safe_fraction(
                    len(leaked & passes_quality), len(queries & passes_quality)
                )
    fractions_all[(metric, threshold)] = fractions

pli_qcov 50
pocket_qcov 50
pocket_lddt 50
protein_lddt_qcov_weighted_sum 50
protein_seqsim_weighted_sum 30
tanimoto_similarity_max 30


In [11]:
# Display names for the non-plinder benchmarks in the exported JSON files; splits not listed
# here keep their split_tables key. (Named distinctly from the table-label `map_keys` that the
# next cell defines, so the two mappings cannot shadow each other.)
fraction_json_names = {
    "pdbbind": "PDBBind-Original",
    "pdbbind_lp": "PDBBind-LP",
    "equibind": "PDBBind-DiffDock",
    "dockgen": "DockGen",
}
for (metric, threshold), fractions in fractions_all.items():
    # One JSON file per (metric, threshold), written into fractions_dir.
    out_path = fractions_dir / f"{metric}__{threshold}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({fraction_json_names.get(k, k): v for k, v in fractions.items()}, f)

In [14]:
# Display labels for the PDBBind-derived benchmarks in the LaTeX tables; every other split
# is labelled with its lower-cased split_tables key.
map_keys = {
    "pdbbind": "pdbbind-original",
    "pdbbind_lp": "pdbbind-lp",
    "equibind": "pdbbind-diffdock",
}
# One row per benchmark split (full plinder "splits" tables are skipped here):
# `rows` holds train-vs-test leakage among high-quality test systems and `rows2` holds
# train-vs-PoseBusters leakage, as 2-decimal strings keyed by "metric > threshold".
rows = []
rows2 = []
for split_name in split_tables:
    if "splits" in split_name:
        continue
    print(split_name)
    label = map_keys[split_name] if split_name in map_keys else split_name.lower()
    test_row = {"split": label}
    posebusters_row = {"split": label}
    for metric_name, cutoff in metric_thresholds:
        column = f"{metric_name} > {cutoff}"
        split_fractions = fractions_all[(metric_name, cutoff)][split_name]
        test_row[column] = f'{split_fractions["fraction_leaked_train_test_hq_0"]:.2f}'
        posebusters_row[column] = f'{split_fractions["fraction_leaked_train_posebusters_0"]:.2f}'
    rows.append(test_row)
    rows2.append(posebusters_row)

plinder-ECOD
plinder-v0
plinder-Time
pdbbind
pdbbind_lp
equibind
dockgen


In [17]:
# Prints LaTeX table rows: one group per metric and one line per (reference, query) pair, with
# columns in `order`. plinder splits report the high-quality ("_hq") fraction when the
# query set is test; every other cell uses the unfiltered fraction.
metric_order = ["pli_qcov > 50", "pocket_lddt > 50", "pocket_qcov > 50", "protein_lddt_qcov_weighted_sum > 50", "protein_seqsim_weighted_sum > 30", "tanimoto_similarity_max > 30"]
order = ["pdbbind", "pdbbind_lp", "equibind", "dockgen", "plinder-ECOD", "plinder-Time", "plinder-v0"]
for metric_label in metric_order:
    metric_name, _, cutoff_text = metric_label.partition(" > ")
    cutoff = int(cutoff_text)
    print(metric_name, cutoff)
    per_split = fractions_all[(metric_name, cutoff)]
    for pair in ["train_posebusters", "train_test", "train_val", "val_test"]:
        reference, query = pair.split("_")
        cells = []
        for name in order:
            use_hq = query == "test" and name.startswith("plinder")
            key = f"fraction_leaked_{pair}_hq_0" if use_hq else f"fraction_leaked_{pair}_0"
            cells.append(f"{per_split[name][key]:.2f}")
        print(f"{reference} vs. {query} & " + " & ".join(cells) + " \\\\")
        

pli_qcov 50
train vs. posebusters & 0.51 & 0.48 & 0.52 & 0.60 & 0.64 & 0.72 & 0.40 \\
train vs. test & 0.88 & 0.71 & 0.27 & 0.05 & 0.30 & 0.80 & 0.04 \\
train vs. val & 0.73 & 0.64 & 0.78 & 0.13 & 0.05 & 0.78 & 0.10 \\
val vs. test & 0.89 & 0.69 & 0.08 & 0.01 & 0.05 & 0.59 & 0.00 \\
pocket_lddt 50
train vs. posebusters & 0.66 & 0.67 & 0.69 & 0.74 & 0.74 & 0.88 & 0.47 \\
train vs. test & 0.97 & 0.84 & 0.53 & 0.09 & 0.49 & 0.96 & 0.00 \\
train vs. val & 0.84 & 0.84 & 0.89 & 0.20 & 0.30 & 0.95 & 0.77 \\
val vs. test & 0.97 & 0.83 & 0.23 & 0.03 & 0.18 & 0.79 & 0.00 \\
pocket_qcov 50
train vs. posebusters & 0.63 & 0.64 & 0.65 & 0.70 & 0.70 & 0.83 & 0.47 \\
train vs. test & 0.96 & 0.81 & 0.47 & 0.06 & 0.35 & 0.88 & 0.09 \\
train vs. val & 0.82 & 0.80 & 0.87 & 0.14 & 0.07 & 0.87 & 0.29 \\
val vs. test & 0.94 & 0.80 & 0.16 & 0.01 & 0.08 & 0.70 & 0.00 \\
protein_lddt_qcov_weighted_sum 50
train vs. posebusters & 0.68 & 0.68 & 0.70 & 0.72 & 0.75 & 0.88 & 0.48 \\
train vs. test & 0.97 & 0.85 & 0.5

## Train-vs-test leakage fractions (high-quality test systems)

In [19]:
# Split labels as they appear in the "split" column of `rows`, in table order.
order = ["pdbbind-original", "pdbbind-diffdock", "dockgen", "pdbbind-lp", "plinder-time", "plinder-ecod", "plinder-v0"]
# .loc[rows, columns] selects and orders both axes in one step.
print(
    pd.DataFrame(rows)
    .set_index("split")
    .loc[order, metric_order]
    .to_latex()
)

\begin{tabular}{lllllll}
\toprule
 & pli_qcov > 50 & pocket_lddt > 50 & pocket_qcov > 50 & protein_lddt_qcov_weighted_sum > 50 & protein_seqsim_weighted_sum > 30 & tanimoto_similarity_max > 30 \\
split &  &  &  &  &  &  \\
\midrule
pdbbind-original & 0.91 & 1.00 & 1.00 & 1.00 & 1.00 & 0.62 \\
pdbbind-diffdock & 0.43 & 0.76 & 0.73 & 0.76 & 0.80 & 0.43 \\
dockgen & 0.04 & 0.08 & 0.05 & 0.08 & 0.18 & 0.64 \\
pdbbind-lp & 0.77 & 0.87 & 0.86 & 0.89 & 0.94 & 0.40 \\
plinder-time & 0.80 & 0.96 & 0.88 & 0.95 & 0.98 & 0.54 \\
plinder-ecod & 0.30 & 0.49 & 0.35 & 0.49 & 0.60 & 0.52 \\
plinder-v0 & 0.04 & 0.00 & 0.09 & 0.01 & 0.37 & 0.58 \\
\bottomrule
\end{tabular}



## Train-vs-PoseBusters leakage fractions

In [20]:
# Split labels as they appear in the "split" column of `rows2`, in table order.
order = ["pdbbind-original", "pdbbind-diffdock", "dockgen", "pdbbind-lp", "plinder-time", "plinder-ecod", "plinder-v0"]
# .loc[rows, columns] selects and orders both axes in one step.
print(
    pd.DataFrame(rows2)
    .set_index("split")
    .loc[order, metric_order]
    .to_latex()
)

\begin{tabular}{lllllll}
\toprule
 & pli_qcov > 50 & pocket_lddt > 50 & pocket_qcov > 50 & protein_lddt_qcov_weighted_sum > 50 & protein_seqsim_weighted_sum > 30 & tanimoto_similarity_max > 30 \\
split &  &  &  &  &  &  \\
\midrule
pdbbind-original & 0.51 & 0.66 & 0.63 & 0.68 & 0.77 & 0.57 \\
pdbbind-diffdock & 0.52 & 0.69 & 0.65 & 0.70 & 0.78 & 0.59 \\
dockgen & 0.60 & 0.74 & 0.70 & 0.72 & 0.81 & 0.61 \\
pdbbind-lp & 0.48 & 0.67 & 0.64 & 0.68 & 0.79 & 0.57 \\
plinder-time & 0.72 & 0.88 & 0.83 & 0.88 & 0.93 & 0.66 \\
plinder-ecod & 0.64 & 0.74 & 0.70 & 0.75 & 0.81 & 0.65 \\
plinder-v0 & 0.40 & 0.47 & 0.47 & 0.48 & 0.64 & 0.64 \\
\bottomrule
\end{tabular}



## Full plinder splits: val-vs-test (high-quality) and train-vs-PoseBusters leakage

In [22]:
# Full plinder splits only (split_tables keys containing "splits"): `rows` gets val-vs-test
# leakage among high-quality test systems and `rows2` gets train-vs-PoseBusters leakage.
# NOTE: this rebinds the rows/rows2 used by the benchmark LaTeX tables; re-run the cell that
# builds them before regenerating those tables.
rows, rows2 = [], []
for split_name in (key for key in split_tables if "splits" in key):
    print(split_name)
    label = map_keys.get(split_name, split_name.lower())
    val_test_row = {"split": label}
    posebusters_row = {"split": label}
    for metric_name, cutoff in metric_thresholds:
        column = f"{metric_name} > {cutoff}"
        split_fractions = fractions_all[(metric_name, cutoff)][split_name]
        val_test_row[column] = f'{split_fractions["fraction_leaked_val_test_hq_0"]:.2f}'
        posebusters_row[column] = f'{split_fractions["fraction_leaked_train_posebusters_0"]:.2f}'
    rows.append(val_test_row)
    rows2.append(posebusters_row)

splits_metaflow_config_split_batch_6_9_yaml_e9ca06e682c3cb2f9340542d1ec1f6dc
splits_metaflow_config_split_batch_6_2_yaml_ff5ef415b01555aeb2bdef1f1d52a44e
splits_metaflow_config_split_batch_6_10_yaml_53e4b810607fa6672394e41abcbf73ed
splits_metaflow_config_split_batch_6_7_yaml_f2f13aafadd53ca34a8f399b5e88e7da


In [23]:
# Display the rows built in the previous cell: for each full plinder split, the val-vs-test
# leakage among high-quality test systems per "metric > threshold", as 2-decimal strings.
rows

[{'split': 'splits_metaflow_config_split_batch_6_9_yaml_e9ca06e682c3cb2f9340542d1ec1f6dc',
  'pli_qcov > 50': '0.00',
  'pocket_qcov > 50': '0.00',
  'pocket_lddt > 50': '0.00',
  'protein_lddt_qcov_weighted_sum > 50': '0.00',
  'protein_seqsim_weighted_sum > 30': '0.15',
  'tanimoto_similarity_max > 30': '0.45'},
 {'split': 'splits_metaflow_config_split_batch_6_2_yaml_ff5ef415b01555aeb2bdef1f1d52a44e',
  'pli_qcov > 50': '0.00',
  'pocket_qcov > 50': '0.00',
  'pocket_lddt > 50': '0.15',
  'protein_lddt_qcov_weighted_sum > 50': '0.17',
  'protein_seqsim_weighted_sum > 30': '0.29',
  'tanimoto_similarity_max > 30': '0.57'},
 {'split': 'splits_metaflow_config_split_batch_6_10_yaml_53e4b810607fa6672394e41abcbf73ed',
  'pli_qcov > 50': '0.00',
  'pocket_qcov > 50': '0.00',
  'pocket_lddt > 50': '0.03',
  'protein_lddt_qcov_weighted_sum > 50': '0.04',
  'protein_seqsim_weighted_sum > 30': '0.18',
  'tanimoto_similarity_max > 30': '0.39'},
 {'split': 'splits_metaflow_config_split_batch_6_7_