In [1]:
from openai import OpenAI
from prompts import product_aug_sys_prompt
from functools import cache
from tqdm import tqdm

import os
import json
from dotenv import load_dotenv
# Load variables from a .env file into the process environment.
load_dotenv()

# OpenAI client used by augment_queries; the API key is read from the
# OPENAI_API_KEY environment variable (os.getenv returns None when it is unset).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

In [2]:
# Help-category titles for which product substitution is applicable,
# read one per line from product_augmentable.txt (surrounding whitespace stripped).
# The context manager closes the file handle as soon as the list is built.
with open("product_augmentable.txt", 'r') as category_file:
    augmentable = [line.strip() for line in category_file]

In [3]:
def augment_queries(queries: list[str]):
    """Ask the chat model for augmented versions of each query.

    Every query is sent as the user message of its own chat completion, with
    ``product_aug_sys_prompt`` as the system message and a JSON-object
    response format. The ``"customer_queries"`` list found in each JSON reply
    is appended to the result.

    Args:
        queries: Customer queries to augment.

    Returns:
        A flat list with the augmented queries of all replies, in input order.
        A reply without a ``"customer_queries"`` list (key missing, null,
        empty, or not a list) contributes nothing instead of raising.

    Raises:
        json.JSONDecodeError: If a reply is not valid JSON.
    """
    augmented = []
    for query in queries:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": product_aug_sys_prompt},
                {"role": "user", "content": query},
            ],
            response_format={"type": "json_object"},
        )
        payload = json.loads(completion.choices[0].message.content)

        # The reply shape is only requested through the prompt, so the key may
        # be absent; one malformed reply should not abort the whole batch.
        variants = payload.get("customer_queries") if isinstance(payload, dict) else None

        # Only extend from a real list: extending from a string would add its
        # characters one by one.
        if isinstance(variants, list):
            augmented.extend(variants)

    return augmented

In [4]:
def persist(data: dict, file):
    """Write ``data`` to ``file`` as a single JSON line.

    The record is serialized completely before anything is written, so a value
    that cannot be encoded raises without leaving a truncated line behind in
    the output. Non-ASCII characters are written unescaped
    (``ensure_ascii=False``).

    Args:
        data: JSON-serializable record.
        file: Writable text file object.

    Raises:
        TypeError: If ``data`` contains a value that is not JSON serializable.
    """
    file.write(json.dumps(data, ensure_ascii=False) + '\n')

In [5]:
def main(query_file_name: str):
    """Augment the queries of every augmentable record and append them to product_aug.jsonl.

    Reads ``query_file_name`` as JSON Lines. Records whose ``"title"`` is not
    in ``augmentable`` are skipped. For the others, ``"customer_queries"`` is
    passed through ``augment_queries`` and a record holding the augmented
    queries plus the original ``"title"``, ``"tone"`` and ``"style"`` is
    written with ``persist``.

    The output file is opened in append mode, so running this again adds the
    records a second time.

    Args:
        query_file_name: Path of the input JSON Lines file.
    """
    # Context managers close both files even when a request or a parse fails
    # part-way through the run.
    with open(query_file_name, 'r') as query_file, open("product_aug.jsonl", 'a') as data_file:
        for line in tqdm(query_file):
            record = json.loads(line)
            if record["title"] not in augmentable:
                continue

            augmented_record = {
                "customer_queries": augment_queries(record["customer_queries"]),
                "title": record["title"],
                "tone": record["tone"],
                "style": record["style"],
            }
            persist(augmented_record, data_file)

In [6]:
# Run the augmentation over queries.jsonl; main appends its records to product_aug.jsonl.
main("queries.jsonl")

53it [04:43,  5.36s/it]


In [7]:
# Total number of augmented queries stored across all records of product_aug.jsonl
with open("product_aug.jsonl", 'r') as records:
    n = sum(len(json.loads(record)['customer_queries']) for record in records)

In [8]:
# n is the sum of len(customer_queries) over every record read from product_aug.jsonl
print(n)

588
