In [None]:
!pip install openai
!pip install python-dotenv

In [1]:
from openai import OpenAI
import os
from dotenv import load_dotenv
import time
import json5
import json

In [2]:
# Load variables from a local .env file (if present) into os.environ.
# OpenAI() below is constructed without an explicit api_key, so it relies on
# the environment (presumably OPENAI_API_KEY set via .env) for credentials.
# The return value is deliberately not bound to `_`: assigning `_` in IPython
# makes it stop updating the `_`/`__`/`___` last-output cache.
load_dotenv()
openai_client = OpenAI()

In [3]:
# Chat model identifiers. GPT3_MODEL backs the assistant created below;
# GPT4_MODEL is kept as an alternative.
GPT3_MODEL, GPT4_MODEL = "gpt-3.5-turbo-1106", "gpt-4-1106-preview"

In [4]:
# Sample collection documents, as JSON text, embedded verbatim in the prompt.
input_data = """
[
  { "name": "Sachin", "team": "India" },
  { "name": "Sourav", "team": "India" },
  { "name": "Lara", "team": "West Indies" }
]
"""

# Result the generated aggregation pipeline should produce for input_data:
# one document per team with the number of players in that team.
output_data = """
[
 { "team": "India", "playerCount": 2 },
 { "team": "West Indies", "playerCount": 1 }
]
"""

In [5]:
# Persona text passed as the assistant's `instructions`; the concrete task is
# described separately in the user prompt.
assistant_instructions = (
    "You are a MongoDB expert with great expertise in writing MongoDB queries "
    "for any given data to produce an expected output."
)

In [6]:
def get_user_prompt(input_data, output_data):
    """Build the user message asking for a MongoDB aggregation pipeline as JSON.

    The prompt asks the model to answer with a JSON object containing the
    ``mongoDBQuery`` and ``queryExplanation`` fields.

    Args:
        input_data: Text of the input documents to embed in the prompt.
        output_data: Text of the expected aggregation result.

    Returns:
        The prompt string.
    """
    # Built from adjacent literals so the source indentation of this function
    # does not leak into the prompt text sent to the model.
    return (
        "Your task is to write a MongoDB Query, specifically an aggregation pipeline "
        "that would produce the expected output for the given input.\n"
        "\n"
        "You will always return a JSON response with the following fields.\n"
        "```\n"
        "mongoDBQuery: The MongoDB aggregation pipeline to produce the expected output for a given input. "
        "This field corresponds to just the list of stages in the aggregation pipeline "
        "and shouldn't contain the \"db.collection.aggregate\" prefix.\n"
        "\n"
        "queryExplanation: A detailed explanation for the query that was returned.\n"
        "```\n"
        "\n"
        f"Input data: {input_data}\n"
        f"Expected output data: {output_data}\n"
    )

In [7]:
# First-round prompt built from the sample input/output defined above.
user_prompt = get_user_prompt(input_data=input_data, output_data=output_data)

In [8]:
# Create the "MongoDB SME" assistant. Note: every execution of this cell calls
# create() again, so re-running it makes another assistant.
assistant = openai_client.beta.assistants.create(
    model=GPT3_MODEL,
    name="MongoDB SME",
    instructions=assistant_instructions,
)

# Keep the id so later cells can update this same assistant.
MONGO_DB_SME_ASSISTANT_ID = assistant.id

In [9]:
# Start a new conversation thread (no initial messages are passed here).
thread = openai_client.beta.threads.create()

In [10]:
# Post the prompt to the thread as a user message.
message = openai_client.beta.threads.messages.create(
    role="user",
    content=user_prompt,
    thread_id=thread.id,
)

In [11]:
# Ask the assistant to process the thread; the run is polled in the next cell.
run = openai_client.beta.threads.runs.create(
    assistant_id=assistant.id,
    thread_id=thread.id,
)

In [12]:
# Poll the run until it completes, ends in a failure state, or we give up
# after roughly max_attempts * sleep_interval seconds.
max_attempts = 60
sleep_interval = 2
# Run statuses from which the run can never reach "completed".
terminal_failure_statuses = {"failed", "cancelled", "expired", "incomplete"}

for i in range(max_attempts):
    print("assistant working...")
    try:
        updated_run = openai_client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id
        )
        if updated_run.status == "completed":
            messages = openai_client.beta.threads.messages.list(
                thread_id=thread.id
            )
            # messages.data[0] is taken as the assistant's latest reply
            # (assumes the list is returned newest-first).
            print(f"Assistants Response: {messages.data[0].content[0].text.value}")
            break
        if updated_run.status in terminal_failure_statuses:
            # No point waiting out the remaining attempts.
            print(f"Error: run ended with status '{updated_run.status}': "
                  f"{getattr(updated_run, 'last_error', None)}")
            break
    except Exception as e:
        print(f"Error: {str(e)}. Trying again...")
    # Wait before the next poll; finished runs break out above without sleeping.
    time.sleep(sleep_interval)
else:
    # If max_attempts reached without completion
    print("Timeout: Assistant didn't respond in time. Please try again.")

assistant working...
assistant working...
assistant working...
assistant working...
assistant working...
Assistants Response: Here's the MongoDB aggregation pipeline that can produce the expected output for the given input:

```json
[
  {
    $group: {
      _id: "$team",
      playerCount: { $sum: 1 }
    }
  },
  {
    $project: {
      _id: 0,
      team: "$_id",
      playerCount: 1
    }
  }
]
```

Explanation:
1. `$group` stage is used to group the documents by the `team` field and calculate the count of players within each group using the `$sum` accumulator.
2. The `_id` field is used to group the documents by the `team` field.
3. In the following `$project` stage, the `_id` field is excluded and the `team` and `playerCount` fields are projected.

This pipeline first groups the documents by the `team` field and calculates the count of players for each team. Then, in the projection stage, it reshapes the output to match the expected format.


## Format the assistant response using function calling

In [13]:
# Give the existing assistant a formatResponse function tool. The polling cell
# further down reads the query and explanation back from the arguments of the
# tool call in the run's required_action.
assistant = openai_client.beta.assistants.update(
    MONGO_DB_SME_ASSISTANT_ID,
    tools=[
        {
            "type": "function",
            "function": {
                "name": "formatResponse",
                "description": "Format the assistant's response before responding to user",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "mongoDBQuery": {
                            "type": "string",
                            # Adjacent literals keep the schema text free of the
                            # newlines/indentation a triple-quoted string embeds.
                            "description": (
                                "The MongoDB aggregation pipeline to produce the expected output for a given input. "
                                "This field corresponds to just the list of stages in the aggregation pipeline "
                                "and shouldn't contain the 'db.collection.aggregate' prefix."
                            ),
                        },
                        "queryExplanation": {
                            "type": "string",
                            "description": "A detailed explanation for the query that was returned.",
                        },
                    },
                    "required": ["mongoDBQuery", "queryExplanation"],
                },
            },
        }
    ],
)

In [14]:
def get_user_prompt(input_data, output_data):
    """Build the user message for the formatResponse round.

    The model is told to always return its answer through the
    ``formatResponse`` tool rather than as free text. This cell intentionally
    redefines the earlier ``get_user_prompt`` for the function-calling variant.

    Args:
        input_data: Text of the input documents to embed in the prompt.
        output_data: Text of the expected aggregation result.

    Returns:
        The prompt string.
    """
    # Built from adjacent literals so the source indentation of this function
    # does not leak into the prompt text sent to the model.
    return (
        "Your task is to write a MongoDB Query, specifically an aggregation pipeline "
        "that would produce the expected output for the given input.\n"
        "\n"
        "Important: You will always format the response using the formatResponse tool "
        "before responding to user.\n"
        "\n"
        f"Input data: {input_data}\n"
        f"Expected output data: {output_data}\n"
    )

In [17]:
# Rebuild the prompt with the formatResponse-aware get_user_prompt.
user_prompt = get_user_prompt(input_data=input_data, output_data=output_data)

In [18]:
# Start a fresh thread for the function-calling round (no initial messages passed).
thread = openai_client.beta.threads.create()

In [19]:
# Post the formatResponse-round prompt to the new thread as a user message.
message = openai_client.beta.threads.messages.create(
    content=user_prompt,
    role="user",
    thread_id=thread.id,
)

In [20]:
# Launch the (now tool-enabled) assistant on the thread.
run = openai_client.beta.threads.runs.create(
    assistant_id=assistant.id,
    thread_id=thread.id,
)

In [21]:
# Poll until the assistant calls formatResponse (status "requires_action"),
# replies without the tool, fails, or we run out of attempts.
max_attempts = 60
sleep_interval = 2
# Run statuses from which the run can never progress further.
terminal_failure_statuses = {"failed", "cancelled", "expired", "incomplete"}

for i in range(max_attempts):
    try:
        updated_run = openai_client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id
        )
        if updated_run.status == "requires_action":
            # The structured answer is the argument payload of the first tool call.
            # The run is left in requires_action; no tool output is submitted here.
            tool_call = updated_run.required_action.submit_tool_outputs.tool_calls[0]
            assistant_response = json5.loads(tool_call.function.arguments)
            print(f"Assistant Response:\n{assistant_response}")
            break
        if updated_run.status == "completed":
            # The model answered in plain text instead of calling formatResponse.
            messages = openai_client.beta.threads.messages.list(thread_id=thread.id)
            print("Assistant replied without calling formatResponse:\n"
                  f"{messages.data[0].content[0].text.value}")
            break
        if updated_run.status in terminal_failure_statuses:
            print(f"Error: run ended with status '{updated_run.status}': "
                  f"{getattr(updated_run, 'last_error', None)}")
            break
    except Exception as e:
        print(f"Error: {str(e)}. Trying again...")
    # Wait before the next poll; finished runs break out above without sleeping.
    time.sleep(sleep_interval)
else:
    # If max_attempts reached without completion, then assistant call timed out
    print("Timeout: Assistant didn't respond in time. Please try again.")

Assistant Response:
{'mongoDBQuery': '[{\n  "$group": {\n    "_id": "$team",\n    "playerCount": { "$sum": 1 }\n  }\n},{\n  "$project": {\n    "team": "$_id",\n    "playerCount": 1,\n    "_id": 0\n  }\n}]', 'queryExplanation': "The MongoDB aggregation query consists of two stages:\n1. $group stage: Groups the documents by the 'team' field and calculates the count of players in each team using the $sum aggregation operator.\n2. $project stage: Reshapes the output to display the 'team' and 'playerCount' fields while removing the _id field."}


## Verify the query by using a function tool

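The assistant proposes a query by calling the `executeQuery` function tool. We run the query locally. If it does not produce the expected output, we submit that result back to the assistant so it can try again.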

In [22]:
# JSON-schema description of the executeQuery function tool. The assistant is
# expected to call it with its proposed pipeline (and an explanation); the
# notebook then runs execute_query on the pipeline.
execute_query_function_interface = {
    "name": "executeQuery",
    "description": "Execute the MongoDB Query on the given input data to verify the output",
    "parameters": {
        "type": "object",
        "properties": {
            "mongoDBQuery": {
                "type": "string",
                # Adjacent literals keep the schema text free of the
                # newlines/indentation a triple-quoted string embeds.
                "description": (
                    "The MongoDB aggregation pipeline to produce the expected output for a given input. "
                    "This field corresponds to just the list of stages in the aggregation pipeline "
                    "and shouldn't contain the 'db.collection.aggregate' prefix."
                ),
            },
            "queryExplanation": {
                "type": "string",
                "description": "A detailed explanation for the query that was returned.",
            },
        },
        "required": ["mongoDBQuery", "queryExplanation"],
    },
}

In [23]:
# Update the assistant so the tools list passed to it contains only the
# executeQuery function described above.
execute_query_tool = {"type": "function", "function": execute_query_function_interface}
assistant = openai_client.beta.assistants.update(
    MONGO_DB_SME_ASSISTANT_ID,
    tools=[execute_query_tool],
)

In [24]:
def process_user_input(user_input, assistant_id=None, client=None):
    """Start a new thread for ``user_input`` and launch the assistant on it.

    Args:
        user_input: Prompt text posted to the new thread as the user message.
        assistant_id: Id of the assistant to run; defaults to the notebook's
            global ``assistant.id``.
        client: OpenAI client to use; defaults to the global ``openai_client``.

    Returns:
        A ``(thread, run)`` tuple: the created thread and the run started on it.
    """
    if client is None:
        client = openai_client
    if assistant_id is None:
        assistant_id = assistant.id

    # Create a new thread
    thread = client.beta.threads.create()

    # Post the caller's prompt (not a notebook-level global) as the user message.
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=user_input,
    )

    # Create a run to invoke the assistant
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant_id,
    )
    return thread, run

In [25]:
def get_completed_run(thread, run, max_attempts=60, sleep_interval=2, client=None):
    """Poll ``run`` until it needs attention, finishes, or the poll budget runs out.

    Args:
        thread: Thread the run belongs to (only ``thread.id`` is read).
        run: Run to poll (only ``run.id`` is read).
        max_attempts: Maximum number of status polls.
        sleep_interval: Seconds to wait between polls.
        client: OpenAI client to poll with; defaults to the global ``openai_client``.

    Returns:
        The refreshed run when its status is ``"completed"`` or
        ``"requires_action"``. ``None`` if the run ends in a terminal failure
        status, or does not reach one of those statuses within ``max_attempts``
        polls.
    """
    if client is None:
        client = openai_client
    # Statuses from which the run can never reach completed/requires_action.
    terminal_failure_statuses = {"failed", "cancelled", "expired", "incomplete"}

    for attempt in range(max_attempts):
        if attempt:
            # Sleep only between polls, so a ready run is returned without delay.
            time.sleep(sleep_interval)
        try:
            run = client.beta.threads.runs.retrieve(
                thread_id=thread.id,
                run_id=run.id
            )
        except Exception as e:
            # Report the error (e.g. a transient API failure) and keep polling.
            print(f"Error: {str(e)}. Trying again...")
            continue
        if run.status in ("completed", "requires_action"):
            return run
        if run.status in terminal_failure_statuses:
            # Don't wait out the remaining budget for a run that cannot progress.
            return None
    # max_attempts reached without a usable status: treat as a timeout.
    return None

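### Custom function: `execute_query` (stub)

Placeholder for real query execution. It always reports `"success"`, so the retry path below is never exercised yet.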

In [26]:
def execute_query(mongoDBQuery):
    """Placeholder for running ``mongoDBQuery`` against the sample data.

    No query is executed yet: the pipeline argument is ignored and the stub
    always reports ``"success"``.
    """
    outcome = "success"
    return outcome

In [27]:
# Prompt for the executeQuery round: the model is asked to verify its query by
# executing it. The space before the backslash continuation keeps "pipeline"
# and "that" separate when the two source lines are joined.
user_prompt = f"""
Your task is to write a MongoDB Query, specifically an aggregation pipeline \
that would produce the expected output for the given input.

Important: You will always execute the query to verify that it produces the expected output.

Input data: {input_data} 
Expected output data: {output_data}
"""

In [28]:
thread, run = process_user_input(user_prompt)

# Give the assistant up to three rounds to propose a query that passes execute_query.
for i in range(3):
    run = get_completed_run(thread, run)

    if not run:
        # No usable run came back. Stop here: polling again would pass run=None
        # back into get_completed_run and fail on every attempt.
        print("Error: Assistant timed out.")
        break

    if run.status == "requires_action":
        tool_call = run.required_action.submit_tool_outputs.tool_calls[0]
        function_name = tool_call.function.name
        arguments = json5.loads(tool_call.function.arguments)
        print(f"Function Name: {function_name}\nArguments: {arguments}")

        response = execute_query(arguments["mongoDBQuery"])
        if response == "success":
            # Note: the run is left in requires_action; no tool output is submitted.
            print(f"Assistant Response - MongoDB Query: {arguments['mongoDBQuery']}")
            break

        # Report the failed verification so the run resumes and the assistant
        # can propose another query in the next round.
        run = openai_client.beta.threads.runs.submit_tool_outputs(
            thread_id=thread.id,
            run_id=run.id,
            tool_outputs=[
                {
                    "tool_call_id": tool_call.id,
                    "output": (
                        "The generated MongoDB Query didn't produce the "
                        "expected output. Please try again"
                    ),
                }
            ]
        )
    elif run.status == "completed":
        messages = openai_client.beta.threads.messages.list(thread.id)
        print(f"Assistant Response: {messages.data[0].content[0].text.value}")
        break
else:
    print("Error: Assistant couldn't produce the query for the given input.")

Function Name: executeQuery
Arguments: {'mongoDBQuery': '[{\n  "$group": {\n    "_id": "$team",\n    "playerCount": { "$sum": 1 }\n  }\n},\n{\n  "$project": {\n    "_id": 0,\n    "team": "$_id",\n    "playerCount": 1\n  }\n} ]', 'queryExplanation': "The aggregation pipeline first groups the data by the 'team' field and calculates the count of players for each team. Then, it projects the output to display only the 'team' and 'playerCount' fields."}
Assistant Response - MongoDB Query: [{
  "$group": {
    "_id": "$team",
    "playerCount": { "$sum": 1 }
  }
},
{
  "$project": {
    "_id": 0,
    "team": "$_id",
    "playerCount": 1
  }
} ]
