This file calculates the AUC and RMSE for the predictions of the models.

In [None]:
from math import exp, sqrt
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, roc_auc_score


class Logprobs:
    """Plain record holding four parallel log-probability fields.

    The constructor keywords match the attribute names so that a saved
    ``Logprobs(text_offset=..., ...)`` repr string can be turned back into
    an object (this file rebuilds saved objects with ``eval``).

    NOTE(review): presumably mirrors the completion-API ``Logprobs`` type
    of the OpenAI SDK -- confirm against whatever wrote the result files.
    """

    # Class-level fallbacks; each instance replaces all of them in __init__.
    text_offset: list = []
    token_logprobs: list = []
    tokens: list = []
    top_logprobs: list = []

    def __init__(self, text_offset, token_logprobs, tokens, top_logprobs) -> None:
        """Store the four fields unchanged on the instance."""
        self.text_offset, self.token_logprobs = text_offset, token_logprobs
        self.tokens, self.top_logprobs = tokens, top_logprobs

class TopLogprob:
    """One candidate token together with its raw bytes and log-probability.

    The ``bytes`` keyword shadows the builtin on purpose: it has to match
    the keyword used in saved repr strings that are rebuilt with ``eval``.
    """

    # Class-level fallbacks; each instance replaces all of them in __init__.
    token: str = ""
    bytes: list = []
    logprob: float = 0.0

    def __init__(self, token, bytes, logprob) -> None:
        """Store the three fields unchanged on the instance."""
        self.token, self.bytes, self.logprob = token, bytes, logprob

class ChatCompletionTokenLogprob:
    """A generated token plus the alternative candidates at its position.

    ``top_logprobs`` is iterated elsewhere in this file, reading ``.token``
    and ``.logprob`` from each element.

    NOTE(review): presumably mirrors the chat-completion type of the same
    name in the OpenAI SDK -- confirm against whatever wrote the files.
    """

    # Class-level fallbacks; each instance replaces all of them in __init__.
    token: str = ""
    bytes: list = []
    logprob: float = 0.0
    top_logprobs: list = []

    def __init__(self, token, bytes, logprob, top_logprobs) -> None:
        """Store the four fields unchanged on the instance."""
        self.token, self.bytes = token, bytes
        self.logprob, self.top_logprobs = logprob, top_logprobs

class ChoiceLogprobs:
    """Wrapper exposing per-token log-probability entries as ``content``.

    Elsewhere in this file ``content[0].top_logprobs`` is read from an
    object rebuilt with ``eval``, so ``content`` is presumably a sequence of
    per-token entries -- confirm against the producer of the result files.
    """

    # Class-level fallback; each instance replaces it in __init__.
    content = None

    def __init__(self, content) -> None:
        """Store ``content`` unchanged on the instance."""
        self.content = content

In [None]:
# Probability used when a prediction has no usable "correct"/"wrong" candidate.
_FALLBACK_PROB = 0.5


def _probability_of_correct(logprob_dict):
    """Return P("correct") renormalised over the two answer labels.

    A candidate token counts towards a label when the lower-cased token is a
    prefix of that label ("correct" or "wrong"), so partial tokens such as
    "cor" are included.

    Args:
        logprob_dict: mapping of candidate token -> log-probability.

    Returns:
        The share of "correct" probability mass in the combined
        "correct" + "wrong" mass, or ``None`` when that combined mass is
        zero (no matching candidate, or the probabilities underflowed).
    """
    mass_correct = 0.0
    mass_wrong = 0.0
    for token, logprob in logprob_dict.items():
        candidate = token.lower()
        if not candidate:
            # "" is a prefix of every string; without this guard an empty
            # token would always be counted as "correct".
            continue
        if "correct".startswith(candidate):
            mass_correct += exp(logprob)
        elif "wrong".startswith(candidate):
            mass_wrong += exp(logprob)
    total_mass = mass_correct + mass_wrong
    if total_mass == 0:
        return None
    return mass_correct / total_mass


for dataset_name in ["statics", "assistments09", "assistments17"]:
    for approach in ["minimal", "extended", "minimal-zero-shot"]:
        print(f"=====\nProcessing {dataset_name} with {approach} approach")
        # The zero-shot run is scored against the ground truth of its base approach.
        replaced_approach = approach.replace("-zero-shot", "")
        with open(f"inference_results/{dataset_name}-{replaced_approach}-true_completions_binary.txt", "r") as f:
            true_completions = [int(x) for x in f.read().splitlines()]

        # NOTE(review): the hard predictions are not used by the metrics below;
        # the read is kept in case later cells rely on `all_completions`.
        with open(f"inference_results/{dataset_name}-{approach}-all_completions_binary.txt", "r") as f:
            all_completions = [int(x) for x in f.read().splitlines()]

        with open(f"inference_results/{dataset_name}-{approach}-logprobs.txt", "r") as f:
            logprobs_raw = f.read().splitlines()

        probs = []
        unresolved = 0
        for logprob_line in logprobs_raw:
            # NOTE(security): eval() executes whatever the file contains. Only
            # run this on inference results you produced yourself.
            parsed = eval(logprob_line)
            if "zero-shot" in approach:
                # Chat-style result: a list of candidate objects for the first token.
                logprob_dict = {
                    candidate.token: candidate.logprob
                    for candidate in parsed.content[0].top_logprobs
                }
            else:
                # Completion-style result: already a token -> logprob mapping.
                logprob_dict = parsed.top_logprobs[0]

            prob = _probability_of_correct(logprob_dict)
            if prob is None:
                # Previously this case aborted the whole run with ZeroDivisionError.
                unresolved += 1
                prob = _FALLBACK_PROB
            probs.append(prob)

        if unresolved:
            print(
                f"Warning: {unresolved} of {len(probs)} predictions had no "
                f"'correct'/'wrong' candidate token; used {_FALLBACK_PROB} for them"
            )

        auc = roc_auc_score(true_completions, probs)
        mse = mean_squared_error(true_completions, probs)
        rmse = sqrt(mse)
        print(f"AUC: {auc}\nRMSE: {rmse}")