In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# Final metrics per model as (model name, Recall@10, ARP@10).
# Edit these values to change what the plot below shows.
data = [
    ("Softmax", 0.0909, 0.00020),
    ("Sampled softmax", 0.0958, 0.00019),
    ("BCE", 0.0911, 0.00022),
    ("gBCE", 0.0869, 0.00020),
    ("Uniform", 0.0958, 0.00019),
    ("InBatch", 0.0128, 0.00002),
    ("InBatchLogQ", 0.0326, 0.00020),
    ("Mixed-0.4", 0.0862, 0.00012),
    ("MixedLogQ-0.4", 0.0854, 0.00082),
    ("Shifted", 0.0958, 0.00019),
    ("MLM", 0.0717, 0.00020),
    ("AllAction", 0.1137, 0.00015),
    ("DenseAllAction", 0.0318, 0.00044),
    ("SASRec", 0.0958, 0.00019),
    ("BERT4Rec", 0.0717, 0.00020),
    ("Dot", 0.0958, 0.00019),
    ("Cosine", 0.0232, 0.00033),
    ("Mixed-0.2", 0.0861, 0.00014),
    ("Mixed-0.6", 0.0876, 0.00012),
    ("Mixed-0.8", 0.0701, 0.00010),
    ("MixedLogQ-0.2", 0.0870, 0.00082),
    ("MixedLogQ-0.6", 0.0943, 0.00071),
    ("MixedLogQ-0.8", 0.0859, 0.00068),
]


df = pd.DataFrame(data, columns=["Model", "Recall", "ARP"])

# Pareto front in the (Recall, ARP) plane. A model is dominated when some
# other model has Recall at least as high AND ARP at least as low, and is
# strictly better on at least one of the two.
recall_values = df["Recall"].to_numpy()
arp_values = df["ARP"].to_numpy()

# Broadcast to an n x n grid instead of nesting two iterrows() loops:
# entry [j, i] compares competitor j (rows) against candidate i (columns).
no_worse = (
    (recall_values[:, None] >= recall_values[None, :])
    & (arp_values[:, None] <= arp_values[None, :])
)
strictly_better = (
    (recall_values[:, None] > recall_values[None, :])
    | (arp_values[:, None] < arp_values[None, :])
)
# The diagonal (a model against itself) is never strictly better, so a
# model cannot dominate itself; exact ties do not dominate each other.
is_dominated = (no_worse & strictly_better).any(axis=0)

# Index labels of the non-dominated rows, in original row order.
pareto_optimal = df.index[~is_dominated].tolist()


# Reuse the mask rather than scanning `pareto_optimal` once per row.
df["Type"] = ["Non-Pareto" if flag else "Pareto" for flag in is_dominated]

# Scatter plot of every model in the Recall/ARP plane, coloured and marked
# by whether it lies on the Pareto front (the "Type" column of `df`).
sns.set(style="whitegrid", context="talk")
fig, ax = plt.subplots(figsize=(14, 8))

type_colors = {"Pareto": "#E74C3C", "Non-Pareto": "#3498DB"}
sns.scatterplot(
    data=df,
    x="Recall",
    y="ARP",
    hue="Type",
    style="Type",
    palette=type_colors,
    s=120,
    ax=ax,
)

# Only the Pareto-optimal points get a text label, placed 10 points above
# the marker and centred horizontally on it.
label_style = dict(
    textcoords="offset points",
    xytext=(0, 10),
    ha='center',
    fontsize=10,
    fontweight='bold',
)
front = df.loc[pareto_optimal]
for model_name, x_pos, y_pos in zip(front["Model"], front["Recall"], front["ARP"]):
    ax.annotate(model_name, (x_pos, y_pos), **label_style)

ax.set_xlabel("Recall@10", fontsize=14)
ax.set_ylabel("ARP@10", fontsize=14)
ax.legend(title="", fontsize=12)  # empty title hides the "Type" legend header
fig.tight_layout()
ax.grid(True, linestyle='--', alpha=0.5)
plt.show()