In [None]:
import os
import time

import gqr
import litellm
import pandas as pd
from tqdm import tqdm

In [None]:
# Provider credentials and endpoint come from the environment so they never
# appear in the notebook; each value is None when its variable is unset.
API_KEY = os.environ.get("API_KEY")
API_BASE = os.environ.get("API_BASE")

In [None]:
# Evaluation sets from gqr. Later cells read each frame's 'text' column and
# write predictions into a 'pred' column.
# NOTE(review): "id" presumably means in-domain and "ood" out-of-domain - confirm against gqr docs.
domain = gqr.load_id_test_dataset()
ood = gqr.load_ood_test_dataset()

In [None]:
def batch_categorize_passages(
    passages: list[str], model_name: str, api_key: str, api_base: str
) -> list[str]:
    """Ask an LLM to label each passage as Law, Finance, Health, or Other.

    A two-message conversation (system instructions followed by a user prompt
    that embeds the passage) is built for every passage, and all conversations
    are submitted through a single ``litellm.batch_completion`` call at
    temperature 0.

    Args:
        passages: Texts to classify.
        model_name: Forwarded to litellm as ``model``.
        api_key: Forwarded to litellm as ``api_key``.
        api_base: Forwarded to litellm as ``api_base``.

    Returns:
        One string per response, in the order litellm returned them (assumed
        to match the input order - TODO confirm). Each entry is the stripped
        message content of the first choice, or ``"Error"`` when that content
        could not be extracted; the raw response is printed in that case.
        The text is not validated against the four category names.
    """
    system_prompt = """You are a highly accurate text classifier.
    Your task is to categorize passages into one of four predefined domains.
    The ONLY valid categories are: Law, Finance, Health, and Other.

    Any passage that does not clearly belong to Law, Finance, or Health MUST be categorized as Other.
    You must respond with ONLY the category name, and nothing else.  No explanations, no extra words.
    """

    # Build one [system, user] conversation per passage.
    conversations = []
    for passage in passages:
        user_text = f"""Classify the following passage into one of the categories: Law, Finance, Health, or Other.

        Passage:
        {passage}

        Category: """
        conversations.append(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_text},
            ]
        )

    completions = litellm.batch_completion(
        model=model_name,
        api_key=api_key,
        api_base=api_base,
        temperature=0,
        messages=conversations,
    )

    labels = []
    for completion in completions:
        try:
            label = completion["choices"][0]["message"]["content"].strip()
        except Exception:
            # Best-effort: keep the output aligned with the input and show the
            # offending response instead of aborting the whole batch.
            print(completion)
            label = "Error"
        labels.append(label)
    return labels

In [None]:
# Run configuration.
MODEL_NAMES = ["openai/gpt-4o-mini", "llama3.2:3b", "llama3.1:8b"]
BATCH_SIZE = 16  # passages per batch_categorize_passages call in the accuracy runs
TIMING_BATCH_SIZES = [1, 32, 64, 128, 256]
RESULTS_DIR = "data/results"
TIMING_DIR = "data/batch"


def _classify_in_batches(texts, batch_size, model_name, desc="Processing batches"):
    """Classify `texts` in chunks of `batch_size`; return the labels concatenated in chunk order."""
    predictions = []
    for start in tqdm(range(0, len(texts), batch_size), desc=desc):
        predictions.extend(
            batch_categorize_passages(
                texts[start : start + batch_size],
                model_name=model_name,
                api_key=API_KEY,
                api_base=API_BASE,
            )
        )
    return predictions


# Create the output directories up front so a long run cannot fail at the
# final to_csv because a folder is missing.
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(TIMING_DIR, exist_ok=True)

# The timing prompts do not depend on the model, so read them once.
timing_prompts = pd.read_csv("data/batch_data.csv")["prompt"].tolist()

for model_name in MODEL_NAMES:
    # "/" in a model id would otherwise be treated as a directory separator.
    save_path = "naive_" + model_name.replace("/", "_")

    # OOD dataset: the 'pred' column is overwritten for every model, so the
    # frame is written to disk before moving on.
    ood["pred"] = _classify_in_batches(
        ood["text"].values.tolist(), BATCH_SIZE, model_name
    )
    ood.to_csv(f"{RESULTS_DIR}/{save_path}_ood_results.csv", index=False)

    # ID dataset
    domain["pred"] = _classify_in_batches(
        domain["text"].values.tolist(), BATCH_SIZE, model_name
    )
    domain.to_csv(f"{RESULTS_DIR}/{save_path}_domain_results.csv", index=False)

    # Timing benchmark: wall-clock time of each call at several batch sizes.
    batch_results = []
    for batch_size in tqdm(TIMING_BATCH_SIZES):
        for start in range(0, len(timing_prompts), batch_size):
            batch = timing_prompts[start : start + batch_size]
            start_time = time.perf_counter()
            # Previously called with only `batch`, which raised TypeError
            # because model_name/api_key/api_base are required.
            results = batch_categorize_passages(
                batch, model_name=model_name, api_key=API_KEY, api_base=API_BASE
            )
            elapsed_time = time.perf_counter() - start_time
            batch_results.append(
                {
                    "batch_size": batch_size,
                    "time_taken": elapsed_time,
                    "results": results,
                    # Label rows with the model actually being timed; this was
                    # hard-coded to "naive-gpt4o-mini" for every model.
                    "model_name": save_path,
                }
            )

    pd.DataFrame(batch_results).to_csv(
        f"{TIMING_DIR}/{save_path}-batch.csv", index=False
    )