In [4]:
import pandas as pd
from sklearn.metrics import (
    average_precision_score,roc_auc_score
)
import numpy as np
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import f1_score
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

In [1]:
# Position numbers (inclusive bounds) that later cells compare against the
# numeric part of the "IMGT" column to select each CDR.
cdr1 = [*range(27, 39)]
cdr2 = [*range(56, 66)]
cdr3 = [*range(105, 118)]
cdr_ranges = {"CDR1": cdr1, "CDR2": cdr2, "CDR3": cdr3}

# Union of the three CDRs, in CDR1 -> CDR2 -> CDR3 order.
cdrs = [position for cdr in (cdr1, cdr2, cdr3) for position in cdr]

# NOTE: `all` shadows the Python builtin of the same name. The name is kept
# because the scoring functions defined in later cells read it as a global.
all = [*range(129)]

# Every position in `all` that does not belong to a CDR, in ascending order.
framework = sorted(set(all).difference(cdrs))


# LLM / Paragraph / combined

In [2]:
import pandas as pd
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

def get_ap_roc_f1_mcc_df(llm_path):
    """Compute mean per-PDB average precision and ROC AUC for each region.

    Despite its name, this function only computes AP and ROC AUC (no F1/MCC).

    The CSV at ``llm_path`` must provide the columns ``pdb``, ``IMGT``,
    ``chain_type``, ``prediction`` and ``labels``. A trailing letter on an
    ``IMGT`` value is stripped to obtain the integer position used for
    region selection.

    Scores are computed separately for every PDB and every region, then
    averaged over PDBs. The regions are each CDR of ``cdr_ranges`` crossed
    with the ``"light"`` / ``"heavy"`` chain types, followed by ``"CDRs"``,
    ``"Framework"`` and ``"Whole sequence"`` (both chains pooled).

    Parameters
    ----------
    llm_path : str or path-like
        Path of the prediction CSV file.

    Returns
    -------
    tuple[dict, dict]
        ``(ap_scores, roc_scores)``, each mapping a region name to the mean
        score over PDBs. A region that never has any residue gets ``nan``.

    Notes
    -----
    * When all labels of a subset are identical, ROC AUC is undefined and is
      recorded as 1.0 (existing convention of this notebook).
    * A (PDB, region) pair with no residue at all is skipped instead of being
      passed to scikit-learn, which rejects empty inputs.
    """
    llm = pd.read_csv(llm_path)
    # astype(str) first: read_csv infers an integer dtype when the file has no
    # lettered positions, and the .str accessor is unavailable on numeric data.
    llm["IMGT_bis"] = (
        llm["IMGT"].astype(str).str.replace(r'[a-zA-Z]$', '', regex=True).astype(int)
    )

    # (region name, positions, chain type or None for "both chains").
    regions = [
        (f"{cdr_name} {chain}", cdr_range, chain)
        for cdr_name, cdr_range in cdr_ranges.items()
        for chain in ["light", "heavy"]
    ]
    regions += [
        ("CDRs", cdrs, None),
        ("Framework", framework, None),
        ("Whole sequence", all, None),  # `all` is the notebook-level position list
    ]

    # Pre-seeding the dicts fixes the key (and later column) order.
    ap_scores = {name: [] for name, _, _ in regions}
    roc_scores = {name: [] for name, _, _ in regions}

    for _, df_pdb in llm.groupby("pdb"):
        for name, positions, chain in regions:
            selected = df_pdb["IMGT_bis"].isin(positions)
            if chain is not None:
                selected = selected & (df_pdb["chain_type"] == chain)
            subset = df_pdb.loc[selected]
            if subset.empty:
                # Nothing to score for this PDB/region (e.g. missing chain).
                continue

            labels = subset["labels"].tolist()
            predictions = subset["prediction"].tolist()

            if len(set(labels)) == 1:  # single class: ROC AUC undefined
                roc_auc = 1.0
            else:
                roc_auc = roc_auc_score(labels, predictions)

            ap_scores[name].append(average_precision_score(labels, predictions))
            roc_scores[name].append(roc_auc)

    # Average over PDBs; guard the empty case to avoid np.mean([]) warnings.
    for scores in (ap_scores, roc_scores):
        for name, values in scores.items():
            scores[name] = np.mean(values) if values else np.nan

    return ap_scores, roc_scores


In [3]:
def get_ap_roc_f1_mcc_df_one_hot(paragraph_path, llm_path="/home/athenes/Paraplume/benchmark/paragraph/250526/lr-0.00005_dr-0.4,0.4,0.4_mk-0.4_bs-16_dim1-2000,1000,500_alphas-4,5,6_pen-0.00001_weight_1_emb_all_seed_1/prediction_test_set.csv"):
    """Score the predictions of ``paragraph_path`` on the rows of ``llm_path``.

    Despite its name, this function only computes AP and ROC AUC (no F1/MCC).

    The CSV at ``llm_path`` supplies the set of residues and their ``labels``;
    the CSV at ``paragraph_path`` supplies the ``prediction`` values. The two
    are left-joined on ``pdb``, ``IMGT`` and ``chain_type``; residues that
    have no match in ``paragraph_path`` receive a prediction of 0.

    Scores are computed separately for every PDB and every region, then
    averaged over PDBs. The regions are each CDR of ``cdr_ranges`` crossed
    with the ``"light"`` / ``"heavy"`` chain types, followed by ``"CDRs"``,
    ``"Framework"`` and ``"Whole sequence"`` (both chains pooled).

    Parameters
    ----------
    paragraph_path : str or path-like
        CSV with columns ``pdb``, ``IMGT``, ``chain_type``, ``prediction``.
    llm_path : str or path-like, optional
        CSV with columns ``pdb``, ``IMGT``, ``chain_type``, ``prediction``
        and ``labels``. Its own ``prediction`` column is not used for scoring.

    Returns
    -------
    tuple[dict, dict]
        ``(ap_scores, roc_scores)``, each mapping a region name to the mean
        score over PDBs. A region that never has any residue gets ``nan``.

    Notes
    -----
    * When all labels of a subset are identical, ROC AUC is undefined and is
      recorded as 1.0 (existing convention of this notebook).
    * A (PDB, region) pair with no residue at all is skipped instead of being
      passed to scikit-learn, which rejects empty inputs.
    * Duplicate (pdb, IMGT, chain_type) keys in ``paragraph_path`` would
      duplicate rows in the join; this is not checked here.
    """
    merge_keys = ["pdb", "IMGT", "chain_type"]

    predictions_llm = pd.read_csv(llm_path)
    predictions_paragraph = pd.read_csv(paragraph_path)

    # Normalise IMGT to strings on both sides: read_csv infers an integer dtype
    # for files without lettered positions, which would break both the .str
    # accessor below and a merge against a string-typed column.
    predictions_llm["IMGT"] = predictions_llm["IMGT"].astype(str)
    predictions_paragraph["IMGT"] = predictions_paragraph["IMGT"].astype(str)

    predictions_llm["IMGT_bis"] = (
        predictions_llm["IMGT"].str.replace(r'[a-zA-Z]$', '', regex=True).astype(int)
    )

    # Both frames carry a "prediction" column, hence the suffixes.
    combined = pd.merge(
        predictions_llm,
        predictions_paragraph[merge_keys + ["prediction"]],
        on=merge_keys,
        how="left",
        suffixes=("_llm", "_paragraph"),
    )

    # Residues absent from the paragraph file are scored as 0.
    combined["prediction"] = combined["prediction_paragraph"].fillna(0)

    # (region name, positions, chain type or None for "both chains").
    regions = [
        (f"{cdr_name} {chain}", cdr_range, chain)
        for cdr_name, cdr_range in cdr_ranges.items()
        for chain in ["light", "heavy"]
    ]
    regions += [
        ("CDRs", cdrs, None),
        ("Framework", framework, None),
        ("Whole sequence", all, None),  # `all` is the notebook-level position list
    ]

    # Pre-seeding the dicts fixes the key (and later column) order.
    ap_scores = {name: [] for name, _, _ in regions}
    roc_scores = {name: [] for name, _, _ in regions}

    for _, df_pdb in combined.groupby("pdb"):
        for name, positions, chain in regions:
            selected = df_pdb["IMGT_bis"].isin(positions)
            if chain is not None:
                selected = selected & (df_pdb["chain_type"] == chain)
            subset = df_pdb.loc[selected]
            if subset.empty:
                # Nothing to score for this PDB/region (e.g. missing chain).
                continue

            labels = subset["labels"].tolist()
            predictions = subset["prediction"].tolist()

            if len(set(labels)) == 1:  # single class: ROC AUC undefined
                roc_auc = 1.0
            else:
                roc_auc = roc_auc_score(labels, predictions)

            ap_scores[name].append(average_precision_score(labels, predictions))
            roc_scores[name].append(roc_auc)

    # Average over PDBs; guard the empty case to avoid np.mean([]) warnings.
    for scores in (ap_scores, roc_scores):
        for name, values in scores.items():
            scores[name] = np.mean(values) if values else np.nan

    return ap_scores, roc_scores


# Table S3: Paraplume

In [5]:
# Collect two rows per seed (one for AP, one for ROC AUC), each tagged with
# the metric name and the seed it came from.
records = []
for seed in range(1, 17):
    llm_path = f"/home/athenes/Paraplume/benchmark/paragraph/250526/lr-0.00005_dr-0.4,0.4,0.4_mk-0.4_bs-16_dim1-2000,1000,500_alphas-4,5,6_pen-0.00001_weight_1_emb_all_seed_{seed}/prediction_test_set.csv"
    ap_scores, roc_scores = get_ap_roc_f1_mcc_df(llm_path)
    ap_scores.update({"metric": "ap", "seed": seed})
    roc_scores.update({"metric": "roc", "seed": seed})
    records += [ap_scores, roc_scores]
df = pd.DataFrame.from_records(records)

In [6]:
# Mean of every column per metric; the "seed" column is averaged too, so its
# value in the printed table carries no meaning.
print(df.groupby("metric").agg("mean"))

        CDR1 light  CDR1 heavy  CDR2 light  CDR2 heavy  CDR3 light  \
metric                                                               
ap        0.786264    0.789808    0.450736    0.804707    0.789705   
roc       0.910678    0.922713    0.989536    0.931404    0.940954   

        CDR3 heavy      CDRs  Framework  Whole sequence  seed  
metric                                                         
ap        0.838380  0.785496   0.667942        0.757530   8.5  
roc       0.892584  0.872421   0.977329        0.966174   8.5  


# Table S3: Paragraph (crystal)

In [7]:
records=[]
for seed in tqdm(range(1, 17)):
    one_hot_path = f"/home/athenes/Paraplume/benchmark/paragraph/3D/{seed}/prediction_.csv"
    ap_scores, roc_scores = get_ap_roc_f1_mcc_df_one_hot(one_hot_path)
    ap_scores["metric"]="ap"
    roc_scores["metric"]="roc"
    ap_scores["seed"]=seed
    roc_scores["seed"]=seed
    records.append(ap_scores)
    records.append(roc_scores)
df = pd.DataFrame.from_records(records)

100%|██████████| 16/16 [02:02<00:00,  7.64s/it]


In [8]:
# Mean of every column per metric; the "seed" column is averaged too, so its
# value in the printed table carries no meaning.
print(df.groupby("metric").agg("mean"))

        CDR1 light  CDR1 heavy  CDR2 light  CDR2 heavy  CDR3 light  \
metric                                                               
ap        0.799641    0.802783    0.451739    0.804122    0.808919   
roc       0.928089    0.934435    0.990539    0.929676    0.953174   

        CDR3 heavy      CDRs  Framework  Whole sequence  seed  
metric                                                         
ap        0.892786  0.821705   0.565694        0.769627   8.5  
roc       0.921862  0.888321   0.831195        0.939005   8.5  
