# Zero-shot Inference

This file uses the GPT-3.5 API to perform zero-shot inference.

In [None]:
import os
import pandas as pd
from tqdm import tqdm
import tenacity
from retry import retry
import backoff 
import openai
from openai import OpenAI, AzureOpenAI


# Run configuration: the chat model to query and the dataset/approach pair
# that selects which JSONL test file is loaded further down.
model_name = "gpt-3.5-turbo"
dataset_name = "assistments09" # statics, assistments09, assistments17
approach = "minimal" # minimal, extended

# Credentials are blanked out in this copy ("ANON"); fill them in before running.
# NOTE(review): consider reading these from environment variables instead of
# committing them to the notebook.
API_KEY = "" # ANON
ORGANIZATION = "" # ANON
# The credentials are set both on the openai module and on an explicit client;
# the inference loop below only uses `client`.
openai.api_key = API_KEY
openai.organization = ORGANIZATION
client = OpenAI(api_key=API_KEY, organization=ORGANIZATION)


# Presumably turns off Weights & Biases logging in libraries that honour this
# variable; nothing visible in this file uses wandb -- TODO confirm it is needed.
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# From https://galea.medium.com/how-to-love-jsonl-using-json-line-format-in-your-workflow-b6884f65175b

from json import JSONEncoder

class MyEncoder(JSONEncoder):
    """JSON encoder that serialises arbitrary objects via their ``__dict__``.

    Objects that the stock encoder does not understand are encoded as a JSON
    object built from their instance attributes.
    """

    def default(self, o):
        """Return a JSON-serialisable stand-in for *o*.

        Returns:
            The instance attribute dictionary of *o*.

        Raises:
            TypeError: if *o* has no ``__dict__`` (e.g. a ``set`` or a slotted
                object); raised by the base implementation so callers get the
                standard "not JSON serializable" error.
        """
        try:
            return o.__dict__
        except AttributeError:
            # Follow the JSONEncoder contract: defer to the base class, which
            # raises a descriptive TypeError.
            return super().default(o)

import json

def dump_jsonl(data, output_path, append=False):
    """
    Write an iterable of objects to a JSON lines file.

    Each object is serialised with ``MyEncoder`` and written on its own line.

    Args:
        data: iterable of records to write; lists, generators and other
            one-shot iterables are all accepted.
        output_path: path of the file to write.
        append: when true, add to the end of an existing file instead of
            overwriting it.
    """
    mode = 'a+' if append else 'w'
    # Count while writing instead of calling len(data) afterwards, so that
    # unsized iterables (e.g. generators) do not fail after the file is written.
    written = 0
    with open(output_path, mode, encoding='utf-8') as f:
        for record in data:
            json_record = json.dumps(record, ensure_ascii=False, cls=MyEncoder)
            f.write(json_record + '\n')
            written += 1
    print('Wrote {} records to {}'.format(written, output_path))

def load_jsonl(input_path) -> list:
    """
    Read list of objects from a JSON lines file.

    Blank or whitespace-only lines (such as a trailing empty line at the end
    of the file) are skipped instead of aborting the load.

    Args:
        input_path: path of the JSON lines file to read.

    Returns:
        The decoded records, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    data = []
    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Strip only surrounding whitespace; the previous rstrip('\n|\r')
            # also removed a literal '|' character by accident.
            record = line.strip()
            if not record:
                continue
            data.append(json.loads(record))
    print('Loaded {} records from {}'.format(len(data), input_path))
    return data

class JSONLDataObject:
    """A single prompt/completion pair."""

    # Class-level defaults; __init__ always shadows both on the instance.
    prompt = ""
    completion = ""

    def __init__(self, prompt, completion):
        self.prompt, self.completion = prompt, completion

    def __repr__(self):
        # Rendered exactly like the tuple (prompt, completion).
        pair = (self.prompt, self.completion)
        return repr(pair)


In [None]:
# Load the test split for the configured dataset/approach. Later cells read
# the 'prompt' and 'completion' keys of each record.
test_data = load_jsonl(f'jsonl_files/{dataset_name}-{approach}-test.jsonl')
# Sanity check: record count and the first record (IndexError if the file is empty).
print(len(test_data), test_data[0])

In [None]:
# Parallel accumulators filled by the inference loop: the parsed model answer,
# the prompt that produced it, and the log-probabilities returned with it.
all_completions_temp_0, all_original_temp_0, all_logprobs = [], [], []

In [None]:
# Instruction sent with every request. It never changes, so it is built once
# outside the loop.
SYSTEM_PROMPT = "You are an instructor and want to predict whether a student will get a question CORRECT or WRONG. The only information you have is the student's previous answers to a series of related questions. You know how many questions they got CORRECT and how many they got WRONG. Based on this information, you should make a prediction by outputting a single word: CORRECT if you think the student will answer the next question correctly, and WRONG if you think the student will answer the next question wrong. Output no other word at all, this is very important. Try to estimate the knowledge of the student before making your prediction."

# Resume right after the last prediction already collected instead of at a
# hard-coded offset. A fresh run therefore starts at record 0 and a re-run
# after a crash continues where it stopped, so the prediction lists always
# stay aligned with test_data.
start_index = len(all_completions_temp_0)

for i in tqdm(range(start_index, len(test_data))):
    prompt = test_data[i]['prompt']
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        temperature=0.0,
        logprobs=True,
        top_logprobs=5,
    )
    choice = response.choices[0]
    # The message content can be None, and the reply may start with
    # whitespace. Keep the first whitespace-delimited word, or '' if there is
    # none, so that a leading space no longer produces an empty answer.
    words = (choice.message.content or "").split()
    all_completions_temp_0.append(words[0] if words else "")
    all_original_temp_0.append(prompt)
    all_logprobs.append(choice.logprobs)

unique_values = list(set(all_completions_temp_0))
print(unique_values) # should be ['CORRECT', 'WRONG']
# All four lengths should match once every test record has been processed.
print(len(all_completions_temp_0), len(all_original_temp_0), len(all_logprobs), len(test_data))

In [None]:
def _label_to_binary(text):
    """Map a CORRECT/WRONG style label to 1 (correct) or 0 (anything else).

    "incorrect" contains the substring "correct", so it is excluded
    explicitly; otherwise an "INCORRECT" reply would be scored as CORRECT.
    """
    lowered = text.lower()
    return int('correct' in lowered and 'incorrect' not in lowered)


print("Example test data", test_data[0])

# Ground-truth labels, in test-set order.
true_completions = [record['completion'] for record in test_data]

# Normalise every model reply to exactly 'CORRECT' or 'WRONG'.
all_completions_cleaned = [
    'CORRECT' if _label_to_binary(completion) else 'WRONG'
    for completion in all_completions_temp_0
]

# convert all CORRECT to 1 and all WRONG to 0
true_completions_binary = [_label_to_binary(completion) for completion in true_completions]
all_completions_binary = [_label_to_binary(completion) for completion in all_completions_cleaned]

print(list(set(true_completions_binary))) # should be [0, 1]
print(list(set(all_completions_binary))) # should be [0, 1]

In [None]:
from sklearn.metrics import balanced_accuracy_score, accuracy_score, f1_score, precision_score, recall_score

# Class balance of the predictions: number of WRONG (0) and CORRECT (1) outputs.
print(all_completions_binary.count(0))
print(all_completions_binary.count(1))

# Each scorer is called as scorer(y_true, y_pred) and printed to 4 decimals.
for metric_label, metric_fn in (
    ("Balanced accuracy =>", balanced_accuracy_score),
    ("Raw Accuracy =>", accuracy_score),
    ("F1 =>", f1_score),
    ("Precision =>", precision_score),
    ("Recall =>", recall_score),
):
    print(metric_label, "{:.4f}".format(metric_fn(true_completions_binary, all_completions_binary)))

# Create the output directory up front so the results of a long, paid
# inference run are not lost to a FileNotFoundError.
os.makedirs('inference_results', exist_ok=True)

# One item per line: true labels, predicted labels, and the raw logprobs
# (stringified) for every processed record.
for output_path, items in (
    (f'inference_results/feb16pm-{dataset_name}-{approach}-nospace-zero-shot-true_completions_binary.txt', true_completions_binary),
    (f'inference_results/feb16pm-{dataset_name}-{approach}-nospace-zero-shot-all_completions_binary.txt', all_completions_binary),
    (f'inference_results/feb16pm-{dataset_name}-{approach}-nospace-zero-shot-logprobs.txt', all_logprobs),
):
    # Explicit UTF-8, matching the JSONL helpers; logprob tokens may be non-ASCII.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines("%s\n" % str(item) for item in items)