In [6]:
import pandas as pd
from sklearn.metrics import (
    average_precision_score,roc_auc_score
)
import numpy as np
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import f1_score


In [7]:
# Inclusive IMGT position ranges of the three CDR loops.
cdr1 = list(range(27, 38 + 1))
cdr2 = list(range(56, 65 + 1))
cdr3 = list(range(105, 117 + 1))
cdrs = cdr1 + cdr2 + cdr3
# Every IMGT_bis position counted as "Whole sequence" (0-128). Named
# `all_positions` rather than `all` so the builtin `all()` is not shadowed.
all_positions = list(range(129))
# Framework = every position outside the CDRs, in ascending order.
framework = sorted(set(all_positions) - set(cdrs))

# Per-chain CDR regions; each is scored on the light chain, then the heavy chain.
CDR_RANGES = {"CDR1": cdr1, "CDR2": cdr2, "CDR3": cdr3}
# Regions scored with both chains pooled, after all per-chain CDR entries.
POOLED_REGIONS = {"CDRs": cdrs, "Framework": framework, "Whole sequence": all_positions}


def scores_by_region(predictions, column, metric, threshold=None):
    """Apply ``metric`` to every antibody region and return ``{region: score}``.

    Keys, in insertion order: "CDR1 light", "CDR1 heavy", "CDR2 light",
    "CDR2 heavy", "CDR3 light", "CDR3 heavy", then the pooled (both-chain)
    regions "CDRs", "Framework" and "Whole sequence". Rows whose ``IMGT_bis``
    falls outside ``all_positions`` are not counted in any region.

    Args:
        predictions: DataFrame with ``IMGT_bis`` (integer IMGT position;
            ``get_ap_roc_f1_mcc_df`` derives it from ``IMGT``), ``chain_type``
            ('light' / 'heavy'), ``labels`` and ``column``.
        column: name of the score column to evaluate.
        metric: callable ``metric(y_true, y_score)`` taking two lists, e.g. an
            sklearn metric.
        threshold: if not None, scores are binarised to int ``score >= threshold``
            before calling ``metric``; if None the raw scores are passed.

    Returns:
        dict mapping region name to whatever ``metric`` returns.
    """
    positions = predictions["IMGT_bis"]
    chains = predictions["chain_type"]

    def score(mask):
        subset = predictions[mask]
        y_score = subset[column]
        if threshold is not None:
            # NaN scores compare False, i.e. they count as negative predictions.
            y_score = (y_score >= threshold).astype(int)
        return metric(subset["labels"].tolist(), y_score.tolist())

    scores = {}
    for cdr_name, cdr_positions in CDR_RANGES.items():
        in_cdr = positions.isin(cdr_positions)
        for chain in ("light", "heavy"):
            scores[f"{cdr_name} {chain}"] = score(in_cdr & (chains == chain))
    for region_name, region_positions in POOLED_REGIONS.items():
        scores[region_name] = score(positions.isin(region_positions))
    return scores


def get_f1_scores(predictions, column="prediction", threshold=0.5):
    """Per-region F1 of ``predictions[column] >= threshold`` against ``labels``."""
    return scores_by_region(predictions, column, f1_score, threshold=threshold)


def get_mcc_scores(predictions, column="prediction", threshold=0.5):
    """Per-region Matthews correlation of ``predictions[column] >= threshold`` against ``labels``."""
    return scores_by_region(predictions, column, matthews_corrcoef, threshold=threshold)


def get_ap_scores(predictions, column="prediction"):
    """Per-region average precision of the raw ``predictions[column]`` scores."""
    return scores_by_region(predictions, column, average_precision_score)


def get_roc_scores(predictions, column="prediction"):
    """Per-region ROC AUC of the raw ``predictions[column]`` scores.

    sklearn's ``roc_auc_score`` raises ValueError if a region's labels contain
    only one class.
    """
    return scores_by_region(predictions, column, roc_auc_score)


# Compare LLM, Paragraph, and combined (Paragraph where available, otherwise LLM) predictions

In [9]:
# Score column of each model in the frame built by `combine_predictions`.
MODEL_COLUMNS = {
    "LLM": "prediction_llm",
    "Paragraph": "prediction_paragraph",
    "Combined": "prediction",
}


def add_imgt_bis(predictions):
    """Return a copy of ``predictions`` with an integer ``IMGT_bis`` column.

    ``IMGT_bis`` is ``IMGT`` with one trailing letter removed (e.g. an insertion
    code: "111A" -> 111). ``IMGT`` is cast to ``str`` first, so a column that
    ``read_csv`` parsed as integers (no letters present) also works.
    """
    imgt_bis = (
        predictions["IMGT"]
        .astype(str)
        .str.replace(r'[a-zA-Z]$', '', regex=True)
        .astype(int)
    )
    return predictions.assign(IMGT_bis=imgt_bis)


def combine_predictions(predictions_llm, predictions_paragraph):
    """Left-join the Paragraph scores onto the LLM predictions.

    Rows are matched on ``pdb``, ``IMGT`` and ``chain_type``. ``IMGT`` is compared
    as a string, so int- and str-typed CSV columns still match. Every LLM row is
    kept, in its original order. Resulting score columns:

    - ``prediction_llm``: the LLM score.
    - ``prediction_paragraph``: the Paragraph score, or 0 where Paragraph has no
      prediction for that residue.
    - ``prediction``: the Paragraph score where available, else the LLM score.
    """
    keys = ["pdb", "IMGT", "chain_type"]
    llm = predictions_llm.assign(IMGT=predictions_llm["IMGT"].astype(str))
    paragraph = predictions_paragraph[keys + ["prediction"]]
    paragraph = paragraph.assign(IMGT=paragraph["IMGT"].astype(str))
    # NOTE(review): duplicate keys in the Paragraph file would silently duplicate
    # LLM rows; add validate="many_to_one" if that should never happen.
    combined = pd.merge(llm, paragraph, on=keys, how="left", suffixes=("_llm", "_paragraph"))

    paragraph_scores = combined["prediction_paragraph"]
    combined["prediction"] = paragraph_scores.fillna(combined["prediction_llm"])
    combined["prediction_paragraph"] = paragraph_scores.fillna(0)
    return combined


def get_ap_roc_f1_mcc_df(llm_path, paragraph_path, threshold=0.5):
    """Score one seed's LLM, Paragraph and combined predictions per region.

    Args:
        llm_path: CSV of LLM predictions. It must have ``pdb``, ``IMGT``,
            ``chain_type``, ``prediction`` and ``labels``.
        paragraph_path: CSV of Paragraph predictions. It must have ``pdb``,
            ``IMGT``, ``chain_type`` and ``prediction``.
        threshold: cut-off applied to scores for F1 and MCC.

    Returns:
        ``(ap_df, roc_df, f1_df, mcc_df)``. Each is a DataFrame indexed by
        region name, with columns ``LLM``, ``Paragraph`` and ``Combined``.
    """
    predictions_llm = add_imgt_bis(pd.read_csv(llm_path))
    predictions_paragraph = pd.read_csv(paragraph_path)
    combined = combine_predictions(predictions_llm, predictions_paragraph)

    def per_model(scorer, **kwargs):
        # One column per model; rows are the regions returned by `scorer`.
        return pd.DataFrame({
            model: scorer(combined, column=column, **kwargs)
            for model, column in MODEL_COLUMNS.items()
        })

    ap_df = per_model(get_ap_scores)
    roc_df = per_model(get_roc_scores)
    f1_df = per_model(get_f1_scores, threshold=threshold)
    mcc_df = per_model(get_mcc_scores, threshold=threshold)
    return ap_df, roc_df, f1_df, mcc_df


In [8]:
# Average each per-seed metric table over seeds 1-16. Each LLM run is paired
# with the Paragraph predictions in prediction_abb3.csv (presumably Paragraph
# on ABodyBuilder3-modelled structures; TODO confirm).
ap_df_list, roc_df_list, f1_df_list, mcc_df_list = [], [], [], []
per_metric = (ap_df_list, roc_df_list, f1_df_list, mcc_df_list)

for seed in range(1, 17):
    llm_path = f"/home/athenes/benchmark2/paragraph/250106/lr-0.00001_dr-0.4,0.4,0.4_mk-0.2_bs-16_dim1-2000,1000,500_alphas-4,5,6_pen-0.00001_weight_1_multi___emb_all_seed_{seed}/prediction_test_set.csv"
    paragraph_path = f"/home/athenes/benchmark2/3D_paragraph/one-hot/{seed}/prediction_abb3.csv"
    ap_df, roc_df, f1_df, mcc_df = get_ap_roc_f1_mcc_df(llm_path, paragraph_path)
    for bucket, table in zip(per_metric, (ap_df, roc_df, f1_df, mcc_df)):
        bucket.append(table)

# Element-wise mean across seeds (all tables share the same region index).
final_abb3_ap, final_abb3_roc, final_abb3_f1, final_abb3_mcc = (
    sum(runs) / len(runs) for runs in per_metric
)


In [9]:
# Average each per-seed metric table over seeds 1-16. Each LLM run is paired
# with the Paragraph predictions in prediction_.csv (presumably Paragraph on
# crystal structures, given the final_crystal_* names; TODO confirm).
ap_df_list, roc_df_list, f1_df_list, mcc_df_list = [], [], [], []
per_metric = (ap_df_list, roc_df_list, f1_df_list, mcc_df_list)

for seed in range(1, 17):
    llm_path = f"/home/athenes/benchmark2/paragraph/250106/lr-0.00001_dr-0.4,0.4,0.4_mk-0.2_bs-16_dim1-2000,1000,500_alphas-4,5,6_pen-0.00001_weight_1_multi___emb_all_seed_{seed}/prediction_test_set.csv"
    paragraph_path = f"/home/athenes/benchmark2/3D_paragraph/one-hot/{seed}/prediction_.csv"
    ap_df, roc_df, f1_df, mcc_df = get_ap_roc_f1_mcc_df(llm_path, paragraph_path)
    for bucket, table in zip(per_metric, (ap_df, roc_df, f1_df, mcc_df)):
        bucket.append(table)

# Element-wise mean across seeds (all tables share the same region index).
final_crystal_ap, final_crystal_roc, final_crystal_f1, final_crystal_mcc = (
    sum(runs) / len(runs) for runs in per_metric
)


In [10]:
# Last three rows are the pooled regions: CDRs, Framework, Whole sequence.
for table in (final_abb3_ap, final_abb3_roc):
    print(table.iloc[-3:])


                     LLM  Paragraph  Combined
CDRs            0.771617   0.779981  0.779843
Framework       0.555931   0.407509  0.541183
Whole sequence  0.742213   0.722582  0.746903
                     LLM  Paragraph  Combined
CDRs            0.871009   0.871299  0.871193
Framework       0.966327   0.766798  0.961907
Whole sequence  0.966458   0.933710  0.965233


In [12]:
# Last three rows are the pooled regions: CDRs, Framework, Whole sequence.
for table in (final_crystal_ap, final_crystal_roc, final_crystal_mcc, final_crystal_f1):
    print(table.iloc[-3:])


                     LLM  Paragraph  Combined
CDRs            0.771617   0.804365  0.804365
Framework       0.555931   0.437125  0.573009
Whole sequence  0.742213   0.747320  0.772229
                     LLM  Paragraph  Combined
CDRs            0.871009   0.884247  0.884247
Framework       0.966327   0.767355  0.962992
Whole sequence  0.966458   0.936074  0.967905
                     LLM  Paragraph  Combined
CDRs            0.567949   0.568494  0.568494
Framework       0.536639   0.501921  0.548297
Whole sequence  0.674982   0.677121  0.683605
                     LLM  Paragraph  Combined
CDRs            0.732942   0.735048  0.735048
Framework       0.535007   0.507706  0.556435
Whole sequence  0.705818   0.704540  0.709663


In [None]:
# NOTE(review): this unexecuted cell repeats the first two prints of the
# In [12] cell above and can likely be deleted.
for table in (final_crystal_ap, final_crystal_roc):
    print(table.iloc[-3:])