In [None]:
import csv
import os

import joblib
import numpy as np

eps = 1e-6  # keeps the precision/recall divisions finite when a denominator is 0


def score_func_precision_recall(eval_result_path, outputpath, thresholds=None):
    """Sweep confidence thresholds and record open-world precision/recall.

    Loads a joblib dump with keys ``"y_pred"``, ``"y_true"`` and
    ``"unmonitored_label"``, converts ``y_pred`` to row-wise softmax
    probabilities, and for every threshold ``TH`` counts:

    * TP -- monitored sample predicted as its *own* class with max prob >= TH
    * FN -- monitored sample that is not a TP (wrong class, or prob below TH)
    * FP -- unmonitored sample predicted as another class with max prob >= TH
    * TN -- unmonitored sample that is not an FP

    Each row is printed and written to ``outputpath`` as CSV.

    Parameters
    ----------
    eval_result_path : str
        Path of the joblib file produced by the evaluation run.
    outputpath : str
        Destination CSV path; its parent directory is created if missing.
    thresholds : iterable of float, optional
        Thresholds to sweep. Defaults to 15 values in roughly [0.11, 0.99].

    Returns
    -------
    list of list
        One ``[TH, TP, TN, FP, FN, Precision, Recall]`` row per threshold,
        i.e. the CSV content without its header.
    """
    # NOTE(security): joblib.load unpickles the file and can execute arbitrary
    # code -- only use it on evaluation dumps you produced yourself.
    data = joblib.load(eval_result_path)
    logits = np.asarray(data["y_pred"], dtype=float)
    y_true = np.asarray(data["y_true"])
    unmon_label = data["unmonitored_label"]

    # Row-wise softmax. Subtracting each row's max leaves the result unchanged but
    # stops np.exp from overflowing to inf (inf / inf -> nan) on large logits.
    exp_scores = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = exp_scores / exp_scores.sum(axis=1, keepdims=True)

    # Everything below is threshold-independent, so compute it once up front
    # instead of once per sample per threshold.
    # Accept one-hot (n, C) labels as before, or plain (n,) integer labels.
    true_labels = y_true.argmax(axis=1) if y_true.ndim > 1 else y_true
    predicted = probs.argmax(axis=1)
    max_prob = probs.max(axis=1)

    is_monitored = true_labels != unmon_label
    # A monitored sample can only become a TP if its own class was predicted; an
    # unmonitored sample can only become an FP if some other class was predicted.
    tp_candidates = is_monitored & (predicted == true_labels)
    fp_candidates = ~is_monitored & (predicted != unmon_label)
    n_monitored = int(np.count_nonzero(is_monitored))
    n_unmonitored = int(is_monitored.size) - n_monitored

    if thresholds is None:
        upper_bound = 1.0
        # 15 values from ~0.11 to 0.99, spaced more densely as they approach 1.0.
        thresholds = upper_bound - upper_bound / np.logspace(0.05, 2, num=15, endpoint=True)

    out_dir = os.path.dirname(outputpath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    header = ["TH", "TP", "TN", "FP", "FN", "Precision", "Recall"]
    fmt_str = "{:.2f}:\t{}\t{}\t{}\t{}\t{:.3f}\t{:.3f}"
    rows = []
    with open(outputpath, "w", encoding="utf-8", newline="") as file:
        csv_writer = csv.writer(file)
        csv_writer.writerow(header)
        # evaluate list performance at different thresholds
        # high threshold will yield higher precision, but reduced recall
        for TH in thresholds:
            confident = max_prob >= TH
            TP = int(np.count_nonzero(tp_candidates & confident))
            FP = int(np.count_nonzero(fp_candidates & confident))
            FN = n_monitored - TP  # wrong class, or correct class below TH
            TN = n_unmonitored - FP
            res = [TH, TP, TN, FP, FN, float(TP) / (TP + FP + eps), float(TP) / (TP + FN + eps)]
            print(fmt_str.format(*res))
            csv_writer.writerow(res)
            rows.append(res)
    return rows


import matplotlib.pyplot as plt
import pandas as pd


def draw_pr_curve(csv_paths):
    """Plot precision-recall curves from CSV files written by ``score_func_precision_recall``.

    Each file must contain ``Precision`` and ``Recall`` columns. One line is drawn
    per file, labelled with the file name stripped of its directory and extension.

    Parameters
    ----------
    csv_paths : iterable of str
        Paths of the per-run PR CSV files.
    """
    _, ax = plt.subplots(figsize=(5, 4))
    for csv_path in csv_paths:
        df = pd.read_csv(csv_path)
        # basename + splitext keeps dots inside the name ("pr.v2.csv" -> "pr.v2") and
        # uses the platform's path separator, unlike split("/") + split(".")[0].
        label = os.path.splitext(os.path.basename(csv_path))[0]
        ax.plot(df["Recall"], df["Precision"], label=label, lw=2, marker=".")
    # Configure the Axes created above explicitly instead of pyplot's "current axes".
    ax.legend()
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(0.1, 1.05)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title("PR Curve")
    plt.show()

In [None]:
# Score each RF evaluation run with the same threshold sweep. A run's dump lives at
# <EVAL_RUN_DIR>/<setting>/<setting>.pkl and its PR table goes to <PR_LOG_DIR>/pr_<setting>.
EVAL_RUN_DIR = "run/test/20240322testow/ours/RF"
PR_LOG_DIR = "log"
for setting in ("front", "empty"):
    score_func_precision_recall(f"{EVAL_RUN_DIR}/{setting}/{setting}.pkl", f"{PR_LOG_DIR}/pr_{setting}")

In [None]:
# Overlay the PR curves of both runs; the trailing ";" keeps the cell output to the figure.
draw_pr_curve(csv_paths=["log/pr_front", "log/pr_empty"]);