# Inference

This notebook uses the OpenAI API to generate responses from our fine-tuned models on the test split and then evaluates those predictions.

In [None]:
# Configuration variables

# Dataset to evaluate; one of "statics", "assistments09", "assistments17".
dataset_name: str = "statics"
# Data-formatting approach ("minimal" or "extended"); together with dataset_name
# it selects the file jsonl_files/<dataset_name>-<approach>-test.jsonl.
approach: str = "minimal"
# Identifier of the fine-tuned model on the OpenAI API (redacted: ANON).
model_name_from_api: str = ""

In [None]:
# Imports

import os
import pandas as pd
from tqdm import tqdm
import tenacity
from retry import retry
import backoff 
import openai
from openai import OpenAI

# NOTE(review): presumably disables Weights & Biases logging for libraries that
# honour this variable; nothing visible in this notebook uses wandb directly.
os.environ.update({"WANDB_DISABLED": "true"})

# From https://galea.medium.com/how-to-love-jsonl-using-json-line-format-in-your-workflow-b6884f65175b

from json import JSONEncoder

class MyEncoder(JSONEncoder):
    """JSON encoder that serialises arbitrary objects through their instance ``__dict__``.

    Used by ``dump_jsonl`` so that plain attribute-holder objects (for example
    ``JSONLDataObject``) can be written as JSON objects.
    """

    def default(self, o):
        """Return ``o.__dict__`` for objects that have one.

        Objects without an instance dict (sets, slotted classes, ...) are passed
        to the base implementation, which raises the ``TypeError`` that the json
        module documents for unserialisable values, instead of leaking an
        ``AttributeError``.
        """
        try:
            attributes = o.__dict__
        except AttributeError:
            return super().default(o)
        return attributes
        
import json

def dump_jsonl(data, output_path, append=False):
    """
    Serialise every element of ``data`` as one JSON document per line.

    Values the json module cannot encode natively are converted via MyEncoder
    (i.e. through their instance ``__dict__``). Non-ASCII characters are written
    as-is in UTF-8. When ``append`` is true the file is extended instead of
    truncated. A short summary with the record count is printed afterwards.
    """
    open_mode = 'a+' if append else 'w'
    with open(output_path, open_mode, encoding='utf-8') as handle:
        handle.writelines(
            json.dumps(record, ensure_ascii=False, cls=MyEncoder) + '\n'
            for record in data
        )
    print(f'Wrote {len(data)} records to {output_path}')

def load_jsonl(input_path) -> list:
    """
    Read list of objects from a JSON lines file.

    Each non-blank line is parsed with ``json.loads``; blank lines (such as a
    trailing empty line) are skipped instead of raising. Surrounding whitespace,
    including ``\\n`` / ``\\r\\n`` line endings, is accepted by ``json.loads``
    itself, so no manual stripping is needed. Malformed lines raise
    ``json.JSONDecodeError``.

    Returns the parsed objects in file order and prints how many were loaded.
    """
    data = []
    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Skip whitespace-only lines; json.loads('') would raise.
            if not line.strip():
                continue
            data.append(json.loads(line))
    print('Loaded {} records from {}'.format(len(data), input_path))
    return data

class JSONLDataObject:
    """A single prompt/completion pair, serialisable to JSON via its ``__dict__``."""

    # Class-level defaults; every instance overrides both in __init__.
    prompt: str = ""
    completion: str = ""

    def __init__(self, prompt, completion):
        self.prompt, self.completion = prompt, completion

    def __repr__(self):
        # Same text as repr() of the (prompt, completion) tuple.
        return f"({self.prompt!r}, {self.completion!r})"


In [None]:
# Credentials are redacted (ANON). Either paste them below, or leave them empty
# and export OPENAI_API_KEY / OPENAI_ORG_ID so no secret has to live in the notebook.
API_KEY = "" # ANON
ORGANIZATION = "" # ANON
# Pass None rather than "" for an empty value: the SDK only falls back to the
# environment variables when it receives None, and an empty organization string
# would otherwise be sent as a blank OpenAI-Organization header.
openai.api_key = API_KEY or None
openai.organization = ORGANIZATION or None
client = OpenAI(api_key=API_KEY or None, organization=ORGANIZATION or None)

In [None]:
# Load the held-out test split for the configured dataset/approach combination.
test_data = load_jsonl(
    "jsonl_files/{}-{}-test.jsonl".format(dataset_name, approach)
)
print(len(test_data), test_data[0])  # sanity check: size and first record

In [None]:
# Per-example results, filled in test-set order by the inference cell:
# predicted label, the prompt sent, and the stringified logprobs object.
all_completions, all_original, all_logprobs = [], [], []

In [None]:
@backoff.on_exception(
    backoff.expo,
    (openai.RateLimitError, openai.APIConnectionError, openai.InternalServerError),
    max_tries=8,
)
def _complete(prompt):
    """Request a 2-token, temperature-0 completion with top-20 logprobs for one prompt.

    Retried with exponential backoff on transient API errors so a long run is
    not aborted by a single rate-limit or connection failure.
    """
    return client.completions.create(
        model=model_name_from_api,
        prompt=prompt,
        max_tokens=2,
        temperature=0.0,
        logprobs=20
    )


def _first_token(text):
    """Return the first whitespace-separated token of *text*, or "" if there is none.

    Splitting on any whitespace (not just ' ') matters because completions may
    start with a space, e.g. " CORRECT"; text.split(' ')[0] would yield "".
    """
    tokens = text.split()
    return tokens[0] if tokens else ""


# Resume where a previous, interrupted run of this cell stopped instead of
# restarting at 0, which would append duplicate rows misaligned with test_data.
start = len(all_completions)
with tqdm(total=len(test_data), initial=start) as pbar:
    for record in test_data[start:]:
        prompt = record['prompt']
        response = _complete(prompt)
        choice = response.choices[0]
        all_completions.append(_first_token(choice.text))
        all_original.append(prompt)
        all_logprobs.append(str(choice.logprobs).strip())
        pbar.update(1)

unique_values = list(set(all_completions))
print(unique_values) # should be ['CORRECT', 'WRONG']

In [None]:
# sanity check: all four counts should be identical
print(*map(len, (test_data, all_completions, all_original, all_logprobs)))

In [None]:
def _is_correct_label(label):
    """Return True if *label* denotes a CORRECT outcome (case-insensitive).

    A plain substring test for 'correct' also matches 'incorrect', which would
    silently count an INCORRECT label or prediction as positive; exclude it.
    """
    text = label.lower()
    return 'correct' in text and 'incorrect' not in text


# Ground-truth labels, in test-set order.
true_completions = [record['completion'] for record in test_data]

# Normalise raw model outputs to the two canonical labels.
all_completions_cleaned = [
    'CORRECT' if _is_correct_label(completion) else 'WRONG'
    for completion in all_completions
]

# convert all CORRECT to 1 and all WRONG to 0
true_completions_binary = [int(_is_correct_label(c)) for c in true_completions]
all_completions_binary = [int(_is_correct_label(c)) for c in all_completions_cleaned]

print(sorted(set(true_completions_binary)), len(true_completions_binary)) # should start with [0, 1]
print(sorted(set(all_completions_binary)), len(all_completions_binary)) # should start with [0, 1]

In [None]:
from sklearn.metrics import balanced_accuracy_score, accuracy_score, f1_score, roc_auc_score, precision_score, recall_score

print("Count of zero:", all_completions_binary.count(0))
print("Count of one:", all_completions_binary.count(1))
print("")

# Each metric compares ground truth (1 = CORRECT, 0 = WRONG) with the predictions.
for metric_label, metric_fn in (
    ("Balanced accuracy =>", balanced_accuracy_score),
    ("Raw Accuracy =>", accuracy_score),
    ("F1 =>", f1_score),
    ("Precision =>", precision_score),
    ("Recall =>", recall_score),
):
    print(metric_label, "{:.4f}".format(metric_fn(true_completions_binary, all_completions_binary)))


def _save_lines(path, items):
    """Write each item of *items* to *path* on its own line (UTF-8).

    Creates the parent directory first so a fresh checkout without
    inference_results/ does not fail after all API calls have been paid for.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f"{item}\n" for item in items)


# save true_completions_binary, all_completions_binary and the raw logprobs to files
results_prefix = f'inference_results/{dataset_name}-{approach}'
_save_lines(f'{results_prefix}-true_completions_binary.txt', true_completions_binary)
_save_lines(f'{results_prefix}-all_completions_binary.txt', all_completions_binary)
_save_lines(f'{results_prefix}-logprobs.txt', all_logprobs)