In [1]:
from openai import OpenAI
from prompts import colloquial_style_aug_sys_prompt
from functools import cache

import os
import json
from dotenv import load_dotenv
load_dotenv()  # python-dotenv: copy variables from a local .env file into os.environ

# The API key comes from the environment so it is never written into the notebook.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

In [2]:
# System prompt for the product-augmentation pass. It asks the model to rewrite a
# generic customer query into per-product variants (smartphone, laptop, washing
# machine) and to reply with a JSON object whose "customer_queries" key holds the
# list of variants, or null when no augmentation applies.
# augment_queries() sends this as the system message, one user query per request,
# with JSON-object output enforced.
# NOTE(review): rule 2b ("leave it as is") and rule 4 ("return null ... not
# product-related") overlap for non-product queries, and the examples only show
# the null form. Confirm which is intended: if the model echoes the query back,
# the unchanged original ends up in the augmented output.
# NOTE(review): the trailing "###" looks like a leftover delimiter; confirm it is wanted.
product_aug_sys_prompt = """\
You are an AI assistant specializing in data augmentation tasks.
Remember:
1. Keep the core message and inquiry intact.
2. For the given queries, modify them such that:
   a) If the query mentions a general product category or could apply to multiple products, create variations using the specific products: smartphone, laptop, and washing machine.
   b) If the query already mentions a specific product or is not related to product inquiries, leave it as is.
   c) For the given query, if a certain product does not make sense to be included, then feel free to skip that product only.
   d) Maintain the original tone, style, and level of formality of the query.
3. Return the augmented queries in JSON format with the key "customer_queries".
4. If the query cannot be augmented or the augmentation does not make sense or is not product-related, return null for "customer_queries".

Examples:

Example input 1:
"When will my order arrive? It's been a week since I placed it."

Example output 1 in JSON format:
{
    "customer_queries": [
        "When will my smartphone arrive? It's been a week since I placed the order.",
        "When will my laptop arrive? It's been a week since I placed the order.",
        "When will my washing machine arrive? It's been a week since I placed the order."
    ]
}


Example input 2:
"hey! quick question – is there a time limit for when i gotta report if i got the wrong item? i'd really appreciate your help!"

Example output 2 in JSON format:
{
    "customer_queries": [
        "hey! quick question – is there a time limit for when i gotta report if i got the wrong smartphone? i'd really appreciate your help!",
        "hey! quick question – is there a time limit for when i gotta report if i got the wrong laptop? i'd really appreciate your help!",
        "hey! quick question – is there a time limit for when i gotta report if i got the wrong washing machine? i'd really appreciate your help!"
    ]
}

Example input 3:
"Are there any exceptions to the types of items I can order if I want free shipping in the 4-hour window?"

Example output 3 in JSON format:
{
    "customer_queries": null
}
###
"""

In [3]:
# Titles of the query groups eligible for product augmentation, one per line in
# product_augmentable.txt. Surrounding whitespace is stripped and blank lines are
# skipped, so a trailing newline cannot produce an empty "" title.
# `with` closes the file handle, which the bare open(...).readlines() leaked.
with open("product_augmentable.txt", 'r', encoding="utf-8") as _titles_file:
    augmentable = [line.strip() for line in _titles_file if line.strip()]

In [4]:
def augment_queries(queries: list[str], model: str = "gpt-4o-mini") -> list[str]:
    """Expand each customer query into product-specific variants via the LLM.

    Each query is sent in its own chat-completions request. The system message
    is ``product_aug_sys_prompt`` and JSON-object output is enforced. The
    variants found under the ``"customer_queries"`` key of each reply are
    concatenated in order.

    Args:
        queries: Raw customer queries to augment.
        model: Chat-completions model name; defaults to the model used originally.

    Returns:
        A flat list of augmented query strings. A query contributes nothing
        when its reply has no content, has no ``"customer_queries"`` key, or
        maps that key to ``null`` or an empty list.

    Raises:
        json.JSONDecodeError: if a reply's content is not valid JSON.
    """
    augmented: list[str] = []
    for query in queries:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": product_aug_sys_prompt},
                {"role": "user", "content": query},
            ],
            response_format={"type": "json_object"},
        )
        content = response.choices[0].message.content
        if not content:
            # No text content (e.g. a refusal): nothing to parse.
            continue
        payload = json.loads(content)
        # A missing key means "no augmentation", like an explicit null.
        variants = payload.get("customer_queries") if isinstance(payload, dict) else None
        if isinstance(variants, str):
            # A bare string would otherwise be extended character by character.
            variants = [variants]
        if variants:
            augmented.extend(v for v in variants if isinstance(v, str))
    return augmented

In [5]:
def persist(data: dict, file):
    """Write ``data`` to ``file`` as a single JSON Lines record.

    The record is serialized in full before anything is written. A
    serialization error therefore leaves the file untouched instead of leaving
    a truncated, unparseable line (``json.dump`` streams partial chunks).

    Args:
        data: JSON-serializable mapping to write.
        file: Writable text-mode file object. Non-ASCII characters are written
            as-is (``ensure_ascii=False``), so it needs a Unicode-capable encoding.

    Raises:
        TypeError: if ``data`` contains a value ``json`` cannot serialize.
    """
    record = json.dumps(data, ensure_ascii=False)
    file.write(record + '\n')

In [6]:
def main(
    query_file_name: str,
    output_file_name: str = "product_aug.jsonl",
    limit: "int | None" = 1,
):
    """Augment eligible query groups from a JSONL file and append the results.

    Each non-blank line of ``query_file_name`` is parsed as a JSON record that
    must have ``title``, ``tone``, ``style`` and ``customer_queries`` keys.
    Only records whose ``title`` is in ``augmentable`` are processed. For each
    one, the queries are expanded with ``augment_queries`` and appended to
    ``output_file_name`` as one JSON line of the form
    ``{"customer_queries": [...], "title": ..., "tone": ..., "style": ...}``.

    Args:
        query_file_name: Path of the input JSONL file.
        output_file_name: JSONL file the augmented records are appended to.
        limit: Maximum number of records to write. The default of 1 keeps the
            existing behaviour of stopping after the first written record;
            pass ``None`` to process the whole file.
    """
    written = 0
    # `with` guarantees both files are closed (and the output flushed) even if an
    # API call or a malformed line raises part-way through.
    with open(query_file_name, 'r', encoding="utf-8") as query_file, \
            open(output_file_name, 'a', encoding="utf-8") as data_file:
        for line in query_file:
            if limit is not None and written >= limit:
                break
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            record = json.loads(line)
            if record["title"] not in augmentable:
                continue

            augmented_record = {"customer_queries": augment_queries(record["customer_queries"])}
            augmented_record.update({"title": record["title"], "tone": record["tone"], "style": record["style"]})
            persist(augmented_record, data_file)
            written += 1

In [7]:
# Run the augmentation over the raw query dump; results are appended to product_aug.jsonl.
main(query_file_name="queries.jsonl")