In [1]:
from openai import OpenAI
from typing import Literal
from prompts import cheerful_tone_aug_sys_prompt, annoyed_tone_aug_sys_prompt
from functools import cache

import os
import json
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
# Shared OpenAI client used by the cells below. The key is read from the
# OPENAI_API_KEY environment variable; os.getenv returns None when it is unset.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

In [3]:
# NOTE: functools.cache cannot be applied here as-is: `queries` is a dict, and
# dicts are unhashable, so a cached wrapper would raise TypeError when called.
def augment_queries(
    queries: dict[str, list[str]],
    tone: Literal["cheerful", "annoyed"],
    model: str = "gpt-4o-mini",
):
    """Ask the chat model to rewrite the given queries in the requested tone.

    Args:
        queries: Mapping that is serialized to indented JSON and sent as the
            user message (the caller in this notebook passes
            ``{"customer_queries": [...]}``).
        tone: Selects which system prompt is sent: "cheerful" or "annoyed".
        model: Name of the chat-completions model to call. Defaults to
            "gpt-4o-mini".

    Returns:
        The ``content`` of the first choice's message, unparsed. The request
        asks for a JSON-object response, so this is expected to be a JSON
        string. NOTE(review): the content may be None in some cases (e.g. a
        refusal) -- confirm against the SDK before relying on it.

    Raises:
        KeyError: If ``tone`` is not "cheerful" or "annoyed".
    """
    system_prompts = {
        "cheerful": cheerful_tone_aug_sys_prompt,
        "annoyed": annoyed_tone_aug_sys_prompt,
    }
    sys_prompt = system_prompts[tone]

    res = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": json.dumps(queries, indent=4)},
        ],
        # JSON mode: request a JSON object in the reply.
        response_format={"type": "json_object"},
    )
    return res.choices[0].message.content


In [4]:
def persist(data: dict, file):
    """Append ``data`` to ``file`` as one JSON line (JSONL record).

    The record is serialized completely before anything is written, so a
    serialization error (e.g. a non-JSON-serializable value) cannot leave a
    truncated, invalid line behind in the output file.

    Args:
        data: JSON-serializable mapping to store.
        file: Writable text-mode file object.

    Raises:
        TypeError: If ``data`` contains a value ``json`` cannot serialize.
    """
    # ensure_ascii=False keeps non-ASCII characters readable instead of \u-escaping them.
    serialized = json.dumps(data, ensure_ascii=False)
    file.write(serialized + '\n')

In [7]:
def main():
    """Augment every record in ``queries.jsonl`` with both tones.

    Each non-blank input line is parsed as JSON and must provide the keys
    "customer_queries", "title" and "style". For each tone ("cheerful", then
    "annoyed") the queries are sent through ``augment_queries``, the JSON
    reply is parsed, tagged with the record's title, the tone and the style,
    and appended to ``tone_augmented.jsonl`` as one line per tone.

    Both tones are generated and parsed before either is written, so a failure
    on the second tone does not leave a lone first-tone line for that record.
    """
    tones = ("cheerful", "annoyed")

    # Context managers close both files even if an API call or JSON parse
    # raises. UTF-8 is explicit because persist() writes non-ASCII characters
    # unescaped, which would otherwise depend on the platform default encoding.
    with open("queries.jsonl", 'r', encoding="utf-8") as query_file, \
            open("tone_augmented.jsonl", 'a', encoding="utf-8") as persist_file:
        for line in query_file:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue

            record = json.loads(line)
            queries = {"customer_queries": record["customer_queries"]}

            augmented = []
            for tone in tones:
                # assumes the model reply parses to a JSON object (dict),
                # as requested via the JSON response format
                result = json.loads(augment_queries(queries, tone))
                result.update({"title": record["title"], "tone": tone, "style": record["style"]})
                augmented.append(result)

            for result in augmented:
                persist(result, persist_file)

In [8]:
# Run the pipeline. This makes chat-completion API calls (two per input
# record) and opens tone_augmented.jsonl in append mode, so re-running this
# cell adds the records again rather than replacing them.
main()