### Investigating the optimal inference threshold for binary relation extraction (RE)

In [None]:
import json
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_recall_curve
import numpy as np
import pandas as pd

# Load the model's mention-based predictions and the dev-set ground truth.
with open('/kaggle/input/predictions/Predictions_mention_based_PubMedBERT_29.json', encoding='utf-8') as f:  # predictions
    preds = json.load(f)
with open('/kaggle/input/predictions/dev.json', encoding='utf-8') as f:  # ground truth
    gts = json.load(f)

# Articles present in the predictions but absent from the ground truth would make every
# prediction in them an FP; surface that instead of silently skewing the labels.
missing_articles = set(preds) - set(gts)
if missing_articles:
    print(f"Warning: {len(missing_articles)} predicted articles have no ground-truth entry")

# Gold relations. Only the start indices are used because span-boundary errors are not
# considered in binary RE (heuristic). The article id is part of the key: start indices are
# per-document, so without it relations from different articles that share start indices
# and labels would collide and FPs could be counted as TPs.
# NOTE(review): assumes preds and gts are keyed by the same article ids — confirm.
true_pairs = {
    (article_id, rel['subject_start_idx'], rel['object_start_idx'], rel['subject_label'], rel['object_label'])
    for article_id, data in gts.items()
    for rel in data.get('relations', [])
}

scores = []  # confidence score of every predicted relation
labels = []  # 1 = correct (TP), 0 = incorrect (FP)

for article_id, doc in preds.items():
    for rel in doc.get('binary_mention_based_relations', []):
        # predictions use '*_start_index' whereas the ground truth uses '*_start_idx'
        key = (article_id, rel['subject_start_index'], rel['object_start_index'], rel['subject_label'], rel['object_label'])
        scores.append(rel['score'])
        labels.append(1 if key in true_pairs else 0)

In [None]:
# One row per prediction: its confidence and whether it matched a gold relation.
df = pd.DataFrame(
    {
        'Confidence score': scores,
        'Correctness': ['FP' if not is_correct else 'TP' for is_correct in labels],
    }
)

# Colours reused by the histogram cells below.
pastel_green, pastel_red = '#a8e6a8', '#e6b8b8'

# Box plot comparing the confidence distributions of TPs and FPs.
plt.figure(figsize=(6, 5))
sns.boxplot(
    data=df,
    x='Correctness',
    y='Confidence score',
    palette={'TP': pastel_green, 'FP': pastel_red},
)
plt.tight_layout()
plt.savefig("binary_RE_box_Plot")
plt.show()

# Per-class summary statistics, shown as the cell's output.
df.groupby('Correctness')['Confidence score'].describe()

In [None]:
# Split the confidence scores by correctness (reused by the zoomed-in histogram).
tp_scores, fp_scores = [], []
for score, label in zip(scores, labels):
    if label == 1:
        tp_scores.append(score)
    elif label == 0:
        fp_scores.append(score)

plt.figure(figsize=(8, 5))

# Overlaid histograms of the FP and TP score distributions (FPs drawn first, TPs on top).
for values, opacity, name, colour in (
    (fp_scores, 0.6, 'False Positives', pastel_red),
    (tp_scores, 0.7, 'True Positives', pastel_green),
):
    plt.hist(values, bins=20, alpha=opacity, label=name, color=colour, edgecolor='grey')

plt.xlabel('Confidence Score')
plt.ylabel('Frequency')
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# Zoomed-in histogram of the high-confidence region [0.90, 1.00]; lower scores are not shown.
# np.linspace is used instead of np.arange with a float step: arange's output length is
# subject to floating-point rounding, whereas linspace fixes the number of edges exactly.
bins = np.linspace(0.90, 1.00, 21)  # 21 edges -> 20 bins of width 0.005

plt.figure(figsize=(8, 5))
plt.hist(fp_scores, bins=bins, alpha=0.6, label='FPs', color=pastel_red, edgecolor='grey')
plt.hist(tp_scores, bins=bins, alpha=0.7, label='TPs', color=pastel_green, edgecolor='grey')
xticks = np.linspace(0.90, 1.00, 11)  # one tick every 0.01
plt.xticks(xticks, [f"{x:.2f}" for x in xticks])

plt.xlabel('Confidence Score')
plt.ylabel('Frequency')
plt.legend()
plt.tight_layout()
plt.savefig("histogram_tp_fp.png")
plt.show()

In [None]:
# "Compute precision-recall pairs for different probability thresholds", cf. https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html, accessed 7th May 2025
# This is for mention-based binary RE; it is only an approximation for tag-based RE, cf. https://www.blog.trainindata.com/precision-recall-curves/, accessed 7th May 2025
# NOTE: only the emitted predictions are scored here (gold relations the model missed are
# not included), so "recall" is relative to the TPs the model produced, not to all gold relations.

scores = np.array(scores)
labels = np.array(labels)  # 1 = TP, 0 = FP

precision, recall, thresholds = precision_recall_curve(labels, scores)
# precision/recall have one more entry than thresholds: the final (precision=1, recall=0)
# point has no threshold, so it is dropped to keep f1_scores index-aligned with thresholds.
p, r = precision[:-1], recall[:-1]
f1_scores = 2 * (p * r) / (p + r + 1e-12)  # small epsilon avoids division by zero

best_idx       = np.argmax(f1_scores)
best_thresh    = thresholds[best_idx]
best_precision = precision[best_idx]
best_recall    = recall[best_idx]
best_f1        = f1_scores[best_idx]

print(f"Best threshold = {best_thresh}")
print(f"Precision = {best_precision}")
print(f"Recall = {best_recall}")
print(f"F1 = {best_f1}")

plt.figure(figsize=(8, 6))
plt.plot(recall, precision, label='Precision-Recall curve')
# mark the threshold with the highest F1
plt.scatter(recall[best_idx], precision[best_idx], color='red', label=f'Best Threshold = {best_thresh} (F1 = {best_f1})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.legend(loc='best')
plt.savefig("precision-recall_curve.png")
plt.show()

In [None]:
from sklearn.metrics import precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt

# Precision-recall analysis that also counts false negatives (gold relations never predicted).
with open('/kaggle/input/predictions/Predictions_mention_based_PubMedBERT_29.json', encoding='utf-8') as f:  # predictions
    preds = json.load(f)
with open('/kaggle/input/predictions/dev.json', encoding='utf-8') as f:  # ground truth
    gts = json.load(f)


# Gold relations keyed per article. Only start indices are used because span-boundary errors
# are not considered in binary RE. The article id prevents collisions between relations from
# different articles that share start indices and labels.
# NOTE(review): assumes preds and gts are keyed by the same article ids — confirm.
true_pairs = {
    (article_id, rel['subject_start_idx'], rel['object_start_idx'], rel['subject_label'], rel['object_label'])
    for article_id, data in gts.items()
    for rel in data.get('relations', [])
}

scores = []        # confidence score per prediction (0.0 for injected FNs)
labels = []        # 1 = gold relation (TP or FN), 0 = FP
pred_pairs = set()  # keys of every predicted relation

for article_id, doc in preds.items():
    for rel in doc.get('binary_mention_based_relations', []):
        # predictions use '*_start_index' whereas the ground truth uses '*_start_idx'
        key = (article_id, rel['subject_start_index'], rel['object_start_index'], rel['subject_label'], rel['object_label'])
        pred_pairs.add(key)
        scores.append(rel['score'])
        labels.append(1 if key in true_pairs else 0)

# Include FNs, i.e. gold relations missing from the predictions. They get score 0.0, which is
# only a heuristic: the prediction file contains positive relations only, so no real score exists.
for _ in true_pairs - pred_pairs:
    scores.append(0.0)
    labels.append(1)

precision, recall, thresholds = precision_recall_curve(labels, scores)
# precision/recall have one more entry than thresholds: the final (precision=1, recall=0)
# point has no threshold, so it is dropped to keep f1_scores index-aligned with thresholds.
p, r = precision[:-1], recall[:-1]
f1_scores = 2 * (p * r) / (p + r + 1e-12)  # small epsilon avoids division by zero

best_idx       = np.argmax(f1_scores)
best_thresh    = thresholds[best_idx]
best_precision = precision[best_idx]
best_recall    = recall[best_idx]
best_f1        = f1_scores[best_idx]

print(f"Best threshold = {best_thresh}")
print(f"Precision = {best_precision}")
print(f"Recall = {best_recall}")
print(f"F1 = {best_f1}")
print(f"Average precision = {average_precision_score(labels, scores)}")

plt.figure(figsize=(8, 6))
plt.plot(recall, precision, label='Precision-Recall curve')
# plot threshold point (max F1)
plt.scatter(recall[best_idx], precision[best_idx], color='red', label=f'Best Threshold = {best_thresh} (F1 = {best_f1})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.legend(loc='best')
# distinct filename so the predictions-only PR curve saved earlier is not overwritten
plt.savefig("precision-recall_curve_with_fns.png")
plt.show()