In [33]:
import pandas as pd
import os
from openai import OpenAI

from dotenv import load_dotenv
import json

# Load OPENAI_API_KEY from .env file into the process environment, so the key
# is never hard-coded in the notebook.
load_dotenv()

# OpenAI API client; used below to upload the batch input file and create the batch job.
client = OpenAI()

In [26]:
def extract_json_from_string(input_string):
    """Parse the JSON object embedded in a piece of free text.

    Takes everything from the first ``{`` to the last ``}`` (inclusive) and
    parses it with ``json.loads``.

    Parameters
    ----------
    input_string : str
        Text expected to contain a JSON object somewhere inside it.

    Returns
    -------
    dict
        The parsed JSON object.

    Raises
    ------
    ValueError
        If ``input_string`` is not a string, contains no ``{``, or has no ``}``
        after its first ``{``. Malformed JSON between the braces raises
        ``json.JSONDecodeError``, which is also a ``ValueError`` subclass.
    """
    # Missing model output (e.g. None for a failed request) has no JSON to parse.
    if not isinstance(input_string, str):
        raise ValueError("No JSON found in the input string.")

    start_index = input_string.find('{')
    end_index = input_string.rfind('}')

    # Check the raw rfind result (-1 when '}' is absent) before adding 1 for the
    # slice end; also reject a '}' that only appears before the first '{'.
    if start_index == -1 or end_index < start_index:
        raise ValueError("No JSON found in the input string.")

    return json.loads(input_string[start_index:end_index + 1])

## Filtration approach

In [3]:
def create_prompt(prompt, custom_id, model="gpt-4o-2024-08-06", max_tokens=1000, temperature=0):
    """Build one OpenAI Batch API request line for a chat completion.

    Parameters
    ----------
    prompt : str
        Content of the single user message.
    custom_id : str
        Identifier echoed back in the batch output, used to match results to inputs.
    model : str, optional
        Chat model to query (default ``"gpt-4o-2024-08-06"``).
    max_tokens : int, optional
        Completion token limit (default 1000).
    temperature : float, optional
        Sampling temperature (default 0).

    Returns
    -------
    dict
        A request object ready to be written as one JSONL line with ``json.dumps``.
    """
    return {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": model,
            "messages": [
                {"role": "user", "content": prompt},
            ],
            "max_tokens": max_tokens,
            "temperature": temperature,
        },
    }

In [24]:
discarded_examples = pd.read_csv("../../data/wdc/filtered/small/discarded_examples.csv")

# Read the prompt template inside a `with` block so the handle is closed right
# away instead of being left to the garbage collector. JSON text is UTF-8, so
# state the encoding rather than relying on the platform default.
with open("../../prompts/error_analysis.json", encoding="utf-8") as template_file:
    prompt_template = json.load(template_file)

In [50]:
# Sanity check: number of discarded examples (one prompt is built per row below).
discarded_examples["label"].shape

(494,)

In [12]:
# Build one Batch API request per discarded example, then upload the requests
# as a JSONL file and submit them as a single batch job.
template_text = prompt_template.get("prompt")  # identical for every row: look it up once

prompts = []
for index, row in discarded_examples.iterrows():
    # if label = 1, then the system answer is yes, else no;
    # the correct label is always the opposite answer
    if row["label"] == 1:
        system_answer, correct_answer = "yes", "no"
    else:
        system_answer, correct_answer = "no", "yes"

    prompt = (
        template_text
        .replace("{{entity_1}}", row["title_left"])
        .replace("{{entity_2}}", row["title_right"])
        .replace("{{system_answer}}", system_answer)
        .replace("{{correct_label}}", correct_answer)
    )
    # The DataFrame index is used as custom_id so each result line can be
    # matched back to its example.
    prompts.append(create_prompt(prompt, str(index)))

# Write the requests to a local staging file, one JSON request per line.
batch_file_path = "filter.jsonl"
with open(batch_file_path, "w") as f:
    for request in prompts:
        f.write(json.dumps(request) + "\n")

try:
    # Upload through a `with` block so the handle is closed before the file is
    # deleted; an open handle makes os.remove fail on Windows.
    with open(batch_file_path, "rb") as batch_file:
        batch_input_file = client.files.create(file=batch_file, purpose="batch")
finally:
    # The local JSONL is only a staging copy: delete it even if the upload fails.
    os.remove(batch_file_path)

batch_input_file_id = batch_input_file.id

batch = client.batches.create(
    input_file_id=batch_input_file_id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={"description": "Error analysis for discarded examples"}
)

# Show the batch id so the job can be tracked and its output downloaded later.
batch.id

In [35]:
def parse_response(response):
    body = response.get("body", {})
    usage = body.get("usage", {})
    choices = body.get("choices", [{}])
    message = choices[0].get("message", {}) if choices else {}

    return pd.Series({
        "status_code": response.get("status_code"),
        "request_id": response.get("request_id"),
        "completion_id": body.get("id"),
        "created": body.get("created"),
        "model": body.get("model"),
        "content": message.get("content"),
        "prompt_tokens": usage.get("prompt_tokens"),
        "completion_tokens": usage.get("completion_tokens"),
        "total_tokens": usage.get("total_tokens"),
    })

In [44]:
def extract_error_classes(error_dict, class_number):
    """Return the integer value stored for one error class.

    Parameters
    ----------
    error_dict : dict
        Mapping from error-class number (as a string key, e.g. ``"3"``) to its value.
    class_number : int
        Error class to look up; converted with ``str`` to form the key.

    Returns
    -------
    int or None
        ``int(value)``, or ``None`` when the key is absent or its value is
        ``None``/empty string. A value of 0 is returned as 0, not treated as missing.
    """
    error_class = error_dict.get(str(class_number))
    # Explicit checks instead of truthiness, so a legitimate 0 is kept.
    if error_class is None or error_class == "":
        return None
    return int(error_class)

In [45]:
# Load the batch completions (JSON Lines) and flatten each `response` object
# into its own columns with parse_response.
error_analysis = pd.read_json("../../data/wdc/filtered/small/discarded_analysis.jsonl", lines=True)
error_analysis_parsed = error_analysis["response"].apply(parse_response)
error_analysis = pd.concat([error_analysis, error_analysis_parsed], axis=1)

# Parse the JSON object out of each model reply.
error_analysis["error_classes"] = error_analysis["content"].apply(extract_json_from_string)

# Add one column per error class 1-9, holding that class's extracted value.
error_analysis = error_analysis.assign(**{
    f'error_class_{class_id}': error_analysis['error_classes'].apply(extract_error_classes, args=(class_id,))
    for class_id in range(1, 10)
})

In [46]:
# For each error class, count the examples whose value exceeds the threshold.
SCORE_THRESHOLD = 70

counts = {
    column: (error_analysis[column] > SCORE_THRESHOLD).sum()
    for column in (f'error_class_{class_id}' for class_id in range(1, 10))
}

print(counts)

{'error_class_1': 223, 'error_class_2': 335, 'error_class_3': 239, 'error_class_4': 0, 'error_class_5': 249, 'error_class_6': 34, 'error_class_7': 77, 'error_class_8': 38, 'error_class_9': 0}
