In [29]:
import pandas as pd
import numpy as np
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import f1_score
from sklearn.metrics import (
    average_precision_score,roc_auc_score
)

In [30]:
# Position lookup lists shared by the cells below.
# NOTE(review): `all` shadows the Python builtin of the same name for the rest
# of the notebook; the name is kept because other cells may rely on it.
all = [*range(129)]
# Inclusive position windows for the three CDRs, compared against the integer
# "IMGT_bis" column — presumably IMGT numbering; confirm.
cdr1 = [*range(25, 41)]
cdr2 = [*range(54, 68)]
cdr3 = [*range(103, 120)]
# `cdrs` is read as a global by get_f1_scores and get_mcc_scores.
cdrs = [*cdr1, *cdr2, *cdr3]

def get_f1_scores(predictions, column="prediction", threshold1=0.5, threshold2=0.5):
    """Return the F1 score of ``predictions`` after region-specific thresholding.

    Rows whose ``IMGT_bis`` value is listed in the module-level ``cdrs`` are
    binarised with ``threshold1``; all remaining rows are binarised with
    ``threshold2``. A score equal to its threshold counts as positive.

    Args:
        predictions: DataFrame with ``IMGT_bis``, ``labels`` and ``column``.
        column: Name of the column holding the scores to binarise.
        threshold1: Cut-off applied to rows inside ``cdrs``.
        threshold2: Cut-off applied to rows outside ``cdrs``.

    Returns:
        The result of ``f1_score`` on the pooled labels and binary calls.
    """
    in_cdr = predictions["IMGT_bis"].isin(cdrs)

    # Binarise each region with its own cut-off; CDR rows are pooled first.
    binary_calls = []
    truth = []
    regions = ((predictions[in_cdr], threshold1), (predictions[~in_cdr], threshold2))
    for region_rows, cutoff in regions:
        binary_calls.extend((region_rows[column] >= cutoff).astype(int).tolist())
        truth.extend(region_rows["labels"].tolist())

    return f1_score(truth, binary_calls)


def get_mcc_scores(predictions, column="prediction", threshold1=0.5, threshold2=0.5):
    """Return the Matthews correlation coefficient after region-specific thresholding.

    Rows whose ``IMGT_bis`` value is listed in the module-level ``cdrs`` are
    binarised with ``threshold1`` and the other rows with ``threshold2``
    (score >= threshold gives 1, otherwise 0).

    Args:
        predictions: DataFrame with ``IMGT_bis``, ``labels`` and ``column``.
        column: Name of the column holding the scores to binarise.
        threshold1: Cut-off applied to rows inside ``cdrs``.
        threshold2: Cut-off applied to rows outside ``cdrs``.

    Returns:
        The result of ``matthews_corrcoef`` on the pooled labels and calls.
    """
    cdr_mask = predictions["IMGT_bis"].isin(cdrs)
    cdr_rows = predictions.loc[cdr_mask]
    other_rows = predictions.loc[~cdr_mask]

    # Pool both regions, CDR rows first, so labels and calls stay aligned.
    predicted = [
        *(cdr_rows[column] >= threshold1).astype(int).tolist(),
        *(other_rows[column] >= threshold2).astype(int).tolist(),
    ]
    observed = [*cdr_rows["labels"].tolist(), *other_rows["labels"].tolist()]

    return matthews_corrcoef(observed, predicted)
def get_ap_scores(predictions, column="prediction"):
    """Return the average precision of ``predictions[column]`` against ``predictions["labels"]``.

    Args:
        predictions: DataFrame with a ``labels`` column and the score column.
        column: Name of the column holding the scores.

    Returns:
        The result of ``average_precision_score(labels, scores)``.
    """
    scores, truth = predictions[column].tolist(), predictions["labels"].tolist()
    return average_precision_score(truth, scores)
def get_roc_scores(predictions, column="prediction"):
    """Return the ROC AUC of ``predictions[column]`` against ``predictions["labels"]``.

    Args:
        predictions: DataFrame with a ``labels`` column and the score column.
        column: Name of the column holding the scores.

    Returns:
        The result of ``roc_auc_score(labels, scores)``.
    """
    score_values = predictions[column].tolist()
    true_labels = predictions["labels"].tolist()
    return roc_auc_score(true_labels, score_values)
def get_ap_roc_f1_mcc_df(llm_path1, llm_path2, threshold=0.5, column="labels_llm2"):
    """Join two prediction CSVs and return their per-PDB-averaged F1 and MCC.

    The first file is left-joined with the ``prediction`` / ``labels`` columns
    of the second on ``(pdb, IMGT, chain_type)``; rows containing any missing
    value after the join are dropped. ``labels_llm1`` serves as ground truth
    and ``column`` as the score that is thresholded (with the default, that is
    the second file's ``labels`` column).

    NOTE(review): despite the name, no AP / ROC value is computed and the
    result is a pair of dicts, not a DataFrame.

    Args:
        llm_path1: Path of the first predictions CSV.
        llm_path2: Path of the second predictions CSV.
        threshold: Cut-off passed as both ``threshold1`` and ``threshold2``.
        column: Column of the merged frame used as the score.

    Returns:
        ``(f1_dict, mcc_dict)``; each holds ``"llm"`` (mean over pdb groups)
        and ``"metric"`` (``"f1"`` or ``"mcc"``).
    """
    first_run = pd.read_csv(llm_path1)
    second_run = pd.read_csv(llm_path2)

    merged = first_run.merge(
        second_run[['pdb', 'IMGT', 'chain_type', 'prediction', 'labels']],
        on=["pdb", "IMGT", "chain_type"],
        how="left",
        suffixes=("_llm1", "_llm2"),
    ).dropna()
    merged["labels"] = merged["labels_llm1"]
    merged["prediction"] = merged[column]
    # Drop a single trailing letter from the IMGT id, then cast to int.
    merged["IMGT_bis"] = merged["IMGT"].str.replace(r'[a-zA-Z]$', '', regex=True).astype(int)

    # One (f1, mcc) pair per pdb group, both computed with the same cut-off.
    per_pdb = [
        (
            get_f1_scores(group, column="prediction", threshold1=threshold, threshold2=threshold),
            get_mcc_scores(group, column="prediction", threshold1=threshold, threshold2=threshold),
        )
        for _, group in merged.groupby("pdb")
    ]
    f1_llm_list = [f1_value for f1_value, _ in per_pdb]
    mcc_llm_list = [mcc_value for _, mcc_value in per_pdb]

    # Average the per-pdb values.
    f1_dict = {"llm": np.mean(f1_llm_list), "metric": "f1"}
    mcc_dict = {"llm": np.mean(mcc_llm_list), "metric": "mcc"}

    return f1_dict, mcc_dict

# paragraph

In [31]:
# Directory holding the prediction CSVs of the paragraph benchmark run.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
PARAGRAPH_RUN_DIR = "/home/athenes/Paraplume/benchmark/paragraph/250526/lr-0.00005_dr-0.4,0.4,0.4_mk-0.4_bs-16_dim1-2000,1000,500_alphas-4,5,6_pen-0.00001_weight_1_emb_all_seed_1"

# Test-set predictions: the main file plus the "_heavy" and "_light" variants.
paired_chain = pd.read_csv(f"{PARAGRAPH_RUN_DIR}/prediction_test_set.csv")
heavy = pd.read_csv(f"{PARAGRAPH_RUN_DIR}/prediction_test_set_heavy.csv")
light = pd.read_csv(f"{PARAGRAPH_RUN_DIR}/prediction_test_set_light.csv")

In [32]:
# Per-pdb evaluation of the four (input file, chain) combinations.
# Each entry is (printed title, predictions frame, chain_type value to keep).
evaluation_runs = [
    ("paired heavy", paired_chain, "heavy"),
    ("single heavy", heavy, "heavy"),
    ("paired light", paired_chain, "light"),
    ("single light", light, "light"),
]
for run_index, (run_title, run_frame, chain) in enumerate(evaluation_runs):
    # Integer position column: drop a single trailing letter from the IMGT id,
    # then cast to int. This adds / overwrites "IMGT_bis" on the frame itself.
    run_frame["IMGT_bis"] = run_frame["IMGT"].str.replace(r'[a-zA-Z]$', '', regex=True).astype(int)
    f1_list, mcc_list, ap_list, roc_list = [], [], [], []
    for _, df_pdb in run_frame.query(f"chain_type=='{chain}'").groupby("pdb"):
        # Skip structures with a single label class: ROC AUC is undefined there.
        if len(set(df_pdb["labels"].unique())) == 1:
            continue
        f1_list.append(get_f1_scores(df_pdb, column="prediction", threshold1=0.5, threshold2=0.5))
        mcc_list.append(get_mcc_scores(df_pdb, column="prediction", threshold1=0.5, threshold2=0.5))
        ap_list.append(get_ap_scores(df_pdb, column="prediction"))
        roc_list.append(get_roc_scores(df_pdb, column="prediction"))
    print(run_title)
    # Report the mean of each metric over the retained pdb groups.
    for name, each in zip(["f1", "mcc", "ap", "roc"], [f1_list, mcc_list, ap_list, roc_list]):
        print(name, np.mean(each))
    # Separator between reports; none after the last one.
    if run_index < len(evaluation_runs) - 1:
        print("="*20)

paired heavy
f1 0.7236896925779821
mcc 0.7005545517457616
ap 0.7937420877877883
roc 0.9685792594943431
single heavy
f1 0.7089969350855303
mcc 0.6818109899278485
ap 0.7876350150463332
roc 0.9681252743675688
paired light
f1 0.6468729436018179
mcc 0.638525809029712
ap 0.7529076595919341
roc 0.9685747184477042
single light
f1 0.5913634580170368
mcc 0.5777079510400438
ap 0.6827285552637917
roc 0.9579505977343682


# pecan

In [33]:
# Directory holding the prediction CSVs of the pecan benchmark run.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
PECAN_RUN_DIR = "/home/athenes/Paraplume/benchmark/pecan/250526/lr-0.00005_dr-0.4,0.4,0.4_mk-0.4_bs-16_dim1-2000,1000,500_alphas-4,5,6_pen-0.00001_weight_1_emb_all_seed_1"

# Test-set predictions: the main file plus the "_heavy" and "_light" variants.
paired_chain = pd.read_csv(f"{PECAN_RUN_DIR}/prediction_test_set.csv")
heavy = pd.read_csv(f"{PECAN_RUN_DIR}/prediction_test_set_heavy.csv")
light = pd.read_csv(f"{PECAN_RUN_DIR}/prediction_test_set_light.csv")

In [34]:
# Per-pdb evaluation of the four (input file, chain) combinations.
# Each entry is (printed title, predictions frame, chain_type value to keep).
evaluation_runs = [
    ("paired heavy", paired_chain, "heavy"),
    ("single heavy", heavy, "heavy"),
    ("paired light", paired_chain, "light"),
    ("single light", light, "light"),
]
for run_index, (run_title, run_frame, chain) in enumerate(evaluation_runs):
    # Integer position column: drop a single trailing letter from the IMGT id,
    # then cast to int. This adds / overwrites "IMGT_bis" on the frame itself.
    run_frame["IMGT_bis"] = run_frame["IMGT"].str.replace(r'[a-zA-Z]$', '', regex=True).astype(int)
    f1_list, mcc_list, ap_list, roc_list = [], [], [], []
    for _, df_pdb in run_frame.query(f"chain_type=='{chain}'").groupby("pdb"):
        # Skip structures with a single label class: ROC AUC is undefined there.
        if len(set(df_pdb["labels"].unique())) == 1:
            continue
        f1_list.append(get_f1_scores(df_pdb, column="prediction", threshold1=0.5, threshold2=0.5))
        mcc_list.append(get_mcc_scores(df_pdb, column="prediction", threshold1=0.5, threshold2=0.5))
        ap_list.append(get_ap_scores(df_pdb, column="prediction"))
        roc_list.append(get_roc_scores(df_pdb, column="prediction"))
    print(run_title)
    # Report the mean of each metric over the retained pdb groups.
    for name, each in zip(["f1", "mcc", "ap", "roc"], [f1_list, mcc_list, ap_list, roc_list]):
        print(name, np.mean(each))
    # Separator between reports; none after the last one.
    if run_index < len(evaluation_runs) - 1:
        print("="*20)

paired heavy
f1 0.6898727252317213
mcc 0.6661756816975023
ap 0.7710970789241476
roc 0.966392968071185
single heavy
f1 0.6819616326609673
mcc 0.6582749359081305
ap 0.7655074436336462
roc 0.9654268308453322
paired light
f1 0.6172233862614466
mcc 0.6074586715640241
ap 0.7485541752744719
roc 0.9668687956145908
single light
f1 0.5335902854756335
mcc 0.5186410189523603
ap 0.6713787392307432
roc 0.9530561883726795


# mipe

In [35]:
# Directory holding the prediction CSVs of the mipe benchmark run (sub-folder "0").
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
MIPE_RUN_DIR = "/home/athenes/Paraplume/benchmark/mipe/250526/lr-0.00005_dr-0.4,0.4,0.4_mk-0.4_bs-16_dim1-2000,1000,500_alphas-4,5,6_pen-0.00001_weight_1_emb_all_seed_1/0"

# Test-set predictions: the main file plus the "_heavy" and "_light" variants.
paired_chain = pd.read_csv(f"{MIPE_RUN_DIR}/prediction_test_set.csv")
heavy = pd.read_csv(f"{MIPE_RUN_DIR}/prediction_test_set_heavy.csv")
light = pd.read_csv(f"{MIPE_RUN_DIR}/prediction_test_set_light.csv")

In [36]:
# Per-pdb evaluation of the four (input file, chain) combinations.
# Each entry is (printed title, predictions frame, chain_type value to keep).
evaluation_runs = [
    ("paired heavy", paired_chain, "heavy"),
    ("single heavy", heavy, "heavy"),
    ("paired light", paired_chain, "light"),
    ("single light", light, "light"),
]
for run_index, (run_title, run_frame, chain) in enumerate(evaluation_runs):
    # Integer position column: drop a single trailing letter from the IMGT id,
    # then cast to int. This adds / overwrites "IMGT_bis" on the frame itself.
    run_frame["IMGT_bis"] = run_frame["IMGT"].str.replace(r'[a-zA-Z]$', '', regex=True).astype(int)
    f1_list, mcc_list, ap_list, roc_list = [], [], [], []
    for _, df_pdb in run_frame.query(f"chain_type=='{chain}'").groupby("pdb"):
        # Skip structures with a single label class: ROC AUC is undefined there.
        if len(set(df_pdb["labels"].unique())) == 1:
            continue
        f1_list.append(get_f1_scores(df_pdb, column="prediction", threshold1=0.5, threshold2=0.5))
        mcc_list.append(get_mcc_scores(df_pdb, column="prediction", threshold1=0.5, threshold2=0.5))
        ap_list.append(get_ap_scores(df_pdb, column="prediction"))
        roc_list.append(get_roc_scores(df_pdb, column="prediction"))
    print(run_title)
    # Report the mean of each metric over the retained pdb groups.
    for name, each in zip(["f1", "mcc", "ap", "roc"], [f1_list, mcc_list, ap_list, roc_list]):
        print(name, np.mean(each))
    # Separator between reports; none after the last one.
    if run_index < len(evaluation_runs) - 1:
        print("="*20)

paired heavy
f1 0.6447108844108996
mcc 0.6234009189046661
ap 0.7260233033509083
roc 0.958768212298608
single heavy
f1 0.6075360275879612
mcc 0.5916669018547479
ap 0.7251036623963572
roc 0.9588217204222435
paired light
f1 0.644185399268788
mcc 0.638152382722701
ap 0.7580078334331026
roc 0.9718203419120744
single light
f1 0.5682453424487163
mcc 0.5670303844833805
ap 0.73599696588962
roc 0.9645158935201342
