In [None]:
import os
import re

import openai
import pandas as pd
from langchain.prompts import PromptTemplate
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm

In [None]:
# Load the prompt set. The file is JSON Lines (one JSON object per line),
# which is why `lines=True` is required.
data = pd.read_json(
    "response_curation_prompt_set_08012023_67802 prompts.jsonl",
    lines=True,
)

# To smoke-test the pipeline end to end, keep only the first 8 prompts:
# data = data[:8]

data

In [None]:
# Target domains the classifier may choose from. The order matters: it is the
# order in which the names are listed inside the prompt.
domains = [
    "science", "technology", "engineering", "math",
    "health", "coding", "business", "general",
    "finance", "legal", "writing",
]

# Pull out the column that holds the raw prompt text.
prompt_column = "prompt"
prompt_data = data[prompt_column]


# Cut the prompts into consecutive slices of at most `batch_size` rows; the
# last slice is shorter when the row count is not a multiple of `batch_size`.
# Each slice keeps the index labels of the rows it came from.
batch_size = 8  # default batch size
prompt_data_batches = [
    prompt_data.iloc[start : start + batch_size]
    for start in range(0, len(prompt_data), batch_size)
]

In [None]:
def str_domains(domains):
    """Render the domain names as an English list for use inside a prompt.

    Three or more names are joined with commas and a final ", and"
    (e.g. ``"a, b, and c"``). Two names become ``"a and b"``, one name is
    returned as-is, and an empty input gives an empty string.

    Args:
        domains: Iterable of domain names.

    Returns:
        A single string listing every name in the given order.
    """
    names = [str(name) for name in domains]
    if not names:
        return ""
    if len(names) == 1:
        # A lone item must not be prefixed with "and".
        return names[0]
    if len(names) == 2:
        return f"{names[0]} and {names[1]}"
    return f"{', '.join(names[:-1])}, and {names[-1]}"


# Zero-shot classification prompt: no worked examples, only the task
# description and the list of candidate domains. Both the domain count and the
# domain list are filled in from `domains`, so the two sentences stay
# consistent when the list changes (the count was previously hard-coded as
# "seven" in the second sentence).
Zero_Shot_template = """You are a domain classification expert.
Your task is to analyse a given prompt and classify it into one of {nb_domains} domains.
The {nb_domains} domains to choose from are {str_domains}.
                      
Prompt:
                      
{prompt}
                      
The domain classification is:"""

# `prompt` is supplied per request; the domain count and list are fixed up
# front as partial variables. The count is passed as text because partial
# variables are substituted into the template as strings.
Zero_Shot = PromptTemplate(
    template=Zero_Shot_template,
    input_variables=["prompt"],
    partial_variables={
        "nb_domains": str(len(domains)),
        "str_domains": str_domains(domains),
    },
)

In [None]:
def format_prompts(prompts, prompt_template):
    """Render each raw prompt through ``prompt_template``.

    Args:
        prompts: Iterable of raw prompt values.
        prompt_template: Object whose ``format(prompt=...)`` returns the final
            prompt text.

    Returns:
        A list with one formatted prompt per input, in input order.
    """
    return [prompt_template.format(prompt=raw_prompt) for raw_prompt in prompts]


def extract_domains(responses):
    """Turn a completion response into one cleaned label per prompt.

    Each choice's text is stripped of everything except ASCII letters and
    spaces, lower-cased and trimmed.

    When choices carry an ``index`` field they are ordered by it before
    cleaning, because completions for a batched request may come back in a
    different order than the prompts were sent. Choices without an ``index``
    keep their position in the response.

    Note: this handles the ``logprobs=0`` case only. For ``logprobs > 0`` an
    earlier draft read the candidate tokens from
    ``choice["logprobs"]["top_logprobs"][0]``, kept those matching ``domains``
    and de-duplicated them; that path is not implemented here.

    Args:
        responses: Response object exposing ``choices``, a sequence of
            mappings that each contain a ``"text"`` entry.

    Returns:
        List of cleaned label strings, one per choice, in prompt order.
    """

    def prompt_position(numbered_choice):
        # Prefer the API-provided index; fall back to the response position.
        position, choice = numbered_choice
        try:
            index = choice["index"]
        except KeyError:
            index = None
        return position if index is None else index

    # `sorted` is stable, so ties keep their response order.
    ordered = sorted(enumerate(responses.choices), key=prompt_position)
    return [
        re.sub(r"[^a-zA-Z ]", "", choice["text"]).lower().strip()
        for _, choice in ordered
    ]

In [None]:
# Configure the OpenAI client. The key is read from the OPENAI_API_KEY
# environment variable so that a real credential never has to be written into
# the notebook; the original placeholder is kept as the fallback value.
openai.api_key = os.environ.get("OPENAI_API_KEY", "api_key")


# Retry wrapper, needed to ride out Rate Limit errors. The number of attempts
# is bounded: without a stop condition a persistent error (invalid key,
# malformed request) would be retried forever and the caller would never get
# the chance to record the batch as failed. `reraise=True` surfaces the
# original API exception instead of tenacity's RetryError.
@retry(
    wait=wait_random_exponential(min=5, max=60),
    stop=stop_after_attempt(6),
    reraise=True,
)
def get_openai_response(prompts, answers_per_resp=0):
    """Send one batch of prompts to the completion endpoint and clean the result.

    Args:
        prompts: Formatted prompt(s) passed as the ``prompt`` argument of the
            completion request.
        answers_per_resp: Value forwarded as ``logprobs`` (number of top
            token predictions requested per position). Defaults to 0.

    Returns:
        The cleaned labels produced by ``extract_domains`` for this batch.

    Raises:
        Exception: The last error raised by the API call once all retry
            attempts have been used.
    """
    responses = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompts,
        max_tokens=2,
        logprobs=answers_per_resp,  # get top <=answer_per_resp predictions
    )
    return extract_domains(responses)


def get_domain_predictions(prompt_template):
    """Classify every batch in ``prompt_data_batches``.

    Each batch is formatted with ``prompt_template`` and sent through
    ``get_openai_response``. A batch that raises, or that returns a different
    number of labels than it has prompts, is reported and recorded as failed;
    none of its labels are kept, so ``predictions`` lines up one-to-one with
    the prompts of the batches that succeeded.

    Args:
        prompt_template: Template passed to ``format_prompts``.

    Returns:
        A ``(predictions, failures)`` tuple: the flat list of labels from the
        successful batches, in batch order, and the list of batches that
        failed.
    """
    predictions = []
    failures = []

    for batch in tqdm(prompt_data_batches):
        try:
            formatted_prompts = format_prompts(list(batch), prompt_template)
            batch_predictions = get_openai_response(formatted_prompts)
            if len(batch_predictions) != len(formatted_prompts):
                raise ValueError(
                    f"expected {len(formatted_prompts)} predictions, "
                    f"got {len(batch_predictions)}"
                )
        # `Exception` rather than a bare `except` so that Ctrl-C / kernel
        # interrupt still stops the loop instead of being logged as a failure.
        except Exception as error:
            print(f"Batch failed ({type(error).__name__}): {error}")
            failures.append(batch)
        else:
            predictions.extend(batch_predictions)

    return predictions, failures

In [None]:
domain_predictions, failures = get_domain_predictions(Zero_Shot)
print(len(domain_predictions), len(failures))

# Predictions exist only for batches that succeeded, so after any failure the
# list is shorter than `data`. Assigning it directly would then raise a
# length-mismatch error and throw away the whole run. Instead, rows belonging
# to failed batches are left empty and only the classified rows are filled.
if len(domain_predictions) == len(data):
    data["domain_classifications"] = domain_predictions
else:
    # Each failed batch is a slice of the prompt column and still carries the
    # index labels of its rows.
    failed_labels = {label for batch in failures for label in batch.index}
    classified_rows = [label not in failed_labels for label in data.index]
    data["domain_classifications"] = None
    data.loc[classified_rows, "domain_classifications"] = domain_predictions

data.to_csv(
    "domain_classified_data.csv", index=False
)  # output CSV file of slice of data + corresponding predictions

In [None]:
# Display the batches that could not be classified; each entry is one slice of
# the prompt column, so its index shows which rows need to be re-run.
failures