In [None]:
from openai import AzureOpenAI
from dotenv import load_dotenv
import os
import time
import pandas as pd
from jinja2 import Environment, FileSystemLoader
from openai.types.beta.threads import TextContentBlock
from openai.types.beta.threads.runs import ToolCallsStepDetails
from pathlib import Path
import json
from tqdm.notebook import tqdm
from IPython.display import Image
import sys
from sklearn.metrics import accuracy_score

# Make the parent directory importable so the local `utils` package can be found.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from utils.utils import convert_types
from utils.vars import DATA_DIR, EXCEPT_FILES, QUESTION_PATH

# Load variables from a .env file; values already set in the environment take precedence.
load_dotenv(override=False)

In [None]:
def create_or_retrieve_assistants(
    client: "AzureOpenAI",
    file_path: "str | os.PathLike",
    prompt_path: str,
    assistant_name: str,
    model: str = "gpt-4o",
) -> str:
    """
    Create or retrieve an assistant with the given name and return its id.

    If assistants named ``assistant_name`` already exist, the first one yielded
    by the list endpoint is reused (a message is printed when there are several).
    Otherwise ``file_path`` is uploaded for the code interpreter, the Jinja2
    template at ``prompt_path`` is rendered as the instructions, and a new
    code-interpreter assistant is created.

    Args:
        client: Azure OpenAI client.
        file_path: Path (str or pathlib.Path) of the data file to attach.
        prompt_path: Template path, resolved relative to the current directory.
        assistant_name: Name used to look up or create the assistant.
        model: Model (deployment) name used only when a new assistant is created.

    Returns:
        The id of the existing or newly created assistant.
    """
    # Iterating the page object auto-paginates; ``.data`` only holds the first
    # page, so an existing assistant beyond it would be missed and duplicated.
    matches = [
        assistant
        for assistant in client.beta.assistants.list()
        if assistant.name == assistant_name
    ]
    if matches:
        if len(matches) > 1:
            print("More than one assistant with the same name. Select the first one.")
        return matches[0].id

    # upload the data; the context manager closes the handle after the upload
    with open(file_path, "rb") as data_file:
        data = client.files.create(file=data_file, purpose="assistants")
    # render the instructions from the template
    instruction = (
        Environment(loader=FileSystemLoader(".")).get_template(prompt_path).render()
    )
    assistant = client.beta.assistants.create(
        name=assistant_name,
        instructions=instruction,
        tools=[{"type": "code_interpreter"}],
        tool_resources={"code_interpreter": {"file_ids": [data.id]}},
        model=model,
        temperature=0,
        top_p=1,
    )
    return assistant.id


def _collect_messages(client: AzureOpenAI, thread_id: str):
    """Return (transcript lines, last assistant text, downloaded file attachments) for a thread."""
    result_message = []
    result_attachment = []
    seen_file_ids = set()
    answer_text = ""
    # order="asc" gives chronological order; iterating the page auto-paginates
    for message in client.beta.threads.messages.list(thread_id=thread_id, order="asc"):
        for item in message.content:
            if not isinstance(item, TextContentBlock):
                continue
            for annotation in item.text.annotations:
                # only file_path annotations reference generated files (citations do not)
                if annotation.type != "file_path":
                    continue
                file_id = annotation.file_path.file_id
                if file_id in seen_file_ids:
                    continue
                seen_file_ids.add(file_id)
                result_attachment.append(
                    {
                        "file_bytes": client.files.content(file_id).read(),
                        "file_name": Path(annotation.text).name,
                    }
                )
            result_message.append(f"{message.role}: {item.text.value}")
            if message.role == "assistant":
                answer_text = item.text.value
    return result_message, answer_text, result_attachment


def _parse_answer(answer_text: str):
    """Return the "output" field of a JSON answer, or the raw text if it is not such JSON."""
    try:
        return json.loads(answer_text)["output"]
    except json.decoder.JSONDecodeError:
        print(f"JSONDecodeError: {answer_text}")
    except (KeyError, TypeError):
        # valid JSON, but not an object with an "output" key
        print(f"Unexpected answer format: {answer_text}")
    return answer_text


def _collect_code(client: AzureOpenAI, thread_id: str, run_id: str) -> list:
    """Return the code-interpreter inputs executed during a run, in execution order."""
    result_code = []
    run_steps = client.beta.threads.runs.steps.list(
        thread_id=thread_id, run_id=run_id, order="asc"
    )
    for step in run_steps:
        if not isinstance(step.step_details, ToolCallsStepDetails):
            continue
        # a single step may contain several tool calls; keep every code call
        for tool_call in step.step_details.tool_calls:
            if tool_call.type == "code_interpreter":
                result_code.append(tool_call.code_interpreter.input)
    return result_code


def ask_assistant_a_question(
    client: AzureOpenAI,
    question: str,
    assistant_id: str,
    poll_interval_s: float = 1.0,
) -> dict:
    """
    Ask ``question`` in a fresh thread, wait for the run, and return its results.

    The thread is always deleted afterwards, whether the run succeeds or not.

    Args:
        client: Azure OpenAI client.
        question: User message sent to the assistant.
        assistant_id: Id of the assistant to run.
        poll_interval_s: Seconds to wait between run-status polls.

    Returns:
        A dict with the run's usage fields plus ``question``, ``answer_pred``
        (the "output" field of the assistant's JSON reply, or the raw reply
        text), ``message`` (the transcript), ``code`` (code-interpreter inputs),
        ``attachment`` (downloaded generated files) and ``execution_time_s``.

    Raises:
        RuntimeError: if the run ends in any status other than "completed".
    """
    # create and run a thread
    run = client.beta.threads.create_and_run(
        assistant_id=assistant_id,
        thread={"messages": [{"role": "user", "content": question}]},
    )
    try:
        # poll until the run leaves its pending states
        while run.status in ("queued", "in_progress", "cancelling"):
            time.sleep(poll_interval_s)
            run = client.beta.threads.runs.retrieve(
                thread_id=run.thread_id, run_id=run.id
            )

        if run.status != "completed":
            # "requires_action" cannot be served here (no function tools are
            # submitted), so it is treated like failed/expired/cancelled runs.
            raise RuntimeError(
                f"Run {run.id} did not complete "
                f"(status={run.status!r}, last_error={run.last_error})"
            )

        result_message, answer_text, result_attachment = _collect_messages(
            client, run.thread_id
        )
        answer_pred = _parse_answer(answer_text)
        result_code = _collect_code(client, run.thread_id, run.id)

        return {
            **run.usage.to_dict(),
            "question": question,
            "answer_pred": convert_types(answer_pred),
            "message": result_message,
            "code": result_code,
            "attachment": result_attachment,
            "execution_time_s": run.completed_at - run.created_at,
        }
    finally:
        # cleanup runs on success too (previously skipped by the early return)
        client.beta.threads.delete(run.thread_id)

In [None]:
# Build the Azure OpenAI client; the key and endpoint come from environment variables.
client = AzureOpenAI(
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
    api_version="2024-05-01-preview",  # only support this version
)

In [None]:
# # delete assistant
# print (client.beta.assistants.list().data)
# client.beta.assistants.delete(assistant_id="asst_v3sGqHyIYbaE2FDWDaktod6w")

## Run all questions against each of the 3 data files

In [None]:
# One assistant per data file is named f"{assistant_name_prefix}_{file stem}".
assistant_name_prefix = "code_interpreter"
# Jinja2 template with the assistant instructions (relative to the notebook directory).
prompt_path = "prompt/code_interpreter_instruction.jinja2"

In [None]:
# Load the question set: a "question" column plus one expected-answer column per data file.
df_questions = pd.read_csv(filepath_or_buffer=QUESTION_PATH)

In [None]:
# Ask every question against every data file (one assistant per file).
# Files are sorted so the run order is deterministic across machines.
df_result = []
for file_path in sorted(Path(DATA_DIR).glob("*.csv")):
    if file_path.name in EXCEPT_FILES:
        continue
    print(f"file: {file_path.name}")
    assistant_id = create_or_retrieve_assistants(
        client=client,
        file_path=file_path,
        prompt_path=prompt_path,
        assistant_name=f"{assistant_name_prefix}_{file_path.stem}",
    )
    for _, row in tqdm(
        df_questions.iterrows(), total=len(df_questions), desc=file_path.name
    ):
        result = ask_assistant_a_question(
            client=client, question=row["question"], assistant_id=assistant_id
        )
        df_result.append(
            {
                **result,
                "file": file_path.name,
                # the expected answer for this file is stored in a column named after it
                "answer_true": convert_types(row[file_path.name]),
            }
        )

In [None]:
df_result = pd.DataFrame(df_result)
# sort=False keeps files in order of first appearance (same as .unique()).
for file, df_file in df_result.groupby("file", sort=False):
    # Compare as strings for BOTH the score and the mismatch listing so they agree
    # (raw object comparison could disagree with the string-based accuracy).
    true_str = df_file["answer_true"].astype(str)
    pred_str = df_file["answer_pred"].astype(str)
    print(
        f"File: {file}; Accuracy: {accuracy_score(true_str.tolist(), pred_str.tolist())}"
    )
    for _, row in df_file[true_str != pred_str].iterrows():
        print(f"question: {row['question']}")
        print(f"answer_pred: {row['answer_pred']}; answer_true: {row['answer_true']}")
        # debug: the transcript, then the code executed during the run
        print("\n".join(row["message"]))
        print("\n".join(row["code"]))
        print("*" * 50)

In [None]:
# Per-file descriptive statistics of the numeric result columns, shown untruncated.
summary_by_file = df_result.groupby(["file"]).describe()
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(summary_by_file)

## Run one question

In [None]:
# list all assistants (iterating the page auto-paginates past the first page)
pd.DataFrame.from_records(
    [assistant.to_dict() for assistant in client.beta.assistants.list()]
)

In [None]:
# Ad-hoc check: ask one question to an already existing assistant.
assistant_id = "asst_TvV7npSBPELrEfgeA2XXcewL"
question = "What is the time column?"
result = ask_assistant_a_question(client, question=question, assistant_id=assistant_id)
result

## Appendix: Generate an image

In [None]:
question = "Generate a box plot of the target column using seaborn with text annotation for min, max, q1, q3, and median."
assistant_id = create_or_retrieve_assistants(
    client=client,
    file_path="../../data/air_passengers.csv",
    prompt_path="prompt/code_interpreter_instruction.jinja2",
    assistant_name="code_interpreter_air_passengers_others",
)
result = ask_assistant_a_question(
    client=client, question=question, assistant_id=assistant_id
)

# print the output
print("\n".join(result["message"]))
# print the code from steps
print("\n".join(result["code"]))
# Show the result with attachment names only; the raw file bytes would flood the output.
{**result, "attachment": [item["file_name"] for item in result["attachment"]]}

In [None]:
# Save the first generated file to the working directory and render it.
if not result["attachment"]:
    raise ValueError(
        "The assistant returned no file attachment; inspect result['message']."
    )
img = result["attachment"][0]
Path(img["file_name"]).write_bytes(img["file_bytes"])
Image(filename=img["file_name"], width=1000)