In [None]:
!pip install openai
!pip install python-dotenv

In [1]:
from openai import OpenAI
import os
from dotenv import load_dotenv
import time
import json5
import json

In [2]:
# Load settings from a .env file; the return value is not needed, so it is
# assigned to `_` (which also keeps it out of the cell output).
_ = load_dotenv()
# Shared client used by every API call in this notebook. No key is passed here,
# so it is presumably picked up from the environment loaded above — confirm.
openai_client = OpenAI()

In [3]:
# Model identifiers. GPT3_MODEL is the one passed to assistants.create() below;
# GPT4_MODEL is not referenced by any cell visible in this notebook section.
GPT3_MODEL = "gpt-3.5-turbo-1106"
GPT4_MODEL = "gpt-4-1106-preview"

In [4]:
# Example 1 - input documents: one document per player, each tagged with a team.
# Kept as a JSON string because it is pasted verbatim into the prompt.
ex1_input_data = """
[
  {
    "name": "Sachin",
    "team": "India"
  },
  {
    "name": "Sourav",
    "team": "India"
  },
  {
    "name": "Lara",
    "team": "West Indies"
  }
]
"""

# Example 1 - expected result: number of players per team.
# Every string value must be quoted so the example is valid JSON; an unquoted
# team name would hand the model a malformed "expected output".
ex1_output_data = """
[
 {
   "team": "India",
   "playerCount": 2
 },
 {
   "team": "West Indies",
   "playerCount": 1
 }
]
"""

In [5]:
# Instructions given to the assistant when it is created: they fix its role as
# a MongoDB query-writing expert.
assistant_instructions = (
    "You are a MongoDB expert with great expertise in writing MongoDB queries "
    "for any given data to produce an expected output."
)

In [6]:
def get_user_prompt(input_data, output_data):
    """Build the user message that asks for an aggregation pipeline.

    The prompt embeds ``input_data`` and ``output_data`` verbatim and asks for
    a JSON response with two fields, ``mongoDBQuery`` and ``queryExplanation``.

    Args:
        input_data: Text describing the input documents.
        output_data: Text describing the expected output documents.

    Returns:
        The complete prompt as a single string.
    """
    # A trailing backslash joins the next line onto the current one inside the
    # literal, so those sentences are not split by a newline in the prompt.
    prompt = f"""Your task is to write a MongoDB Query, specifically an aggregation pipeline\
    that would produce the expected output for the given input.

    You will always return a JSON response with the following fields.
    ```
    mongoDBQuery: The MongoDB aggregation pipeline to produce the expected output for a given input.\
    This field corresponds to just the list of stages in the aggregation pipeline \
    and shouldn't contain the "db.collection.aggregate" prefix.
    
    queryExplanation: A detailed explanation for the query that was returned.
    ```
    
    Input data: {input_data} 
    Expected output data: {output_data}
    """
    return prompt

In [7]:
# Prompt for example 1: the sample player documents and the per-team counts.
user_prompt = get_user_prompt(
    input_data=ex1_input_data,
    output_data=ex1_output_data,
)

In [8]:
# Create the assistant with the instructions defined above and remember its id,
# which a later cell uses to update the same assistant.
assistant = openai_client.beta.assistants.create(
    model=GPT3_MODEL,
    name="MongoDB SME",
    instructions=assistant_instructions,
)
MONGO_DB_SME_ASSISTANT_ID = assistant.id

In [9]:
# Create a new thread; the following cells post the user prompt to it and
# start a run of the assistant on it.
thread = openai_client.beta.threads.create()

In [10]:
# Post the prompt to the thread as the user's message.
message = openai_client.beta.threads.messages.create(
    content=user_prompt, role="user", thread_id=thread.id
)

In [11]:
# Start a run of the assistant on the thread; its status is polled in the next cell.
run = openai_client.beta.threads.runs.create(
    assistant_id=assistant.id, thread_id=thread.id
)

In [12]:
# Poll the run until it finishes, giving up after max_attempts * sleep_interval seconds.
max_attempts = 60
sleep_interval = 2
# Statuses from which a run can no longer reach "completed"; polling further
# would only burn the whole timeout and then misreport the failure as a timeout.
terminal_failure_statuses = {"failed", "cancelled", "expired", "incomplete"}

for i in range(max_attempts):
    try:
        print("assistant working...")

        updated_run = openai_client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id
        )
        # Check if the status indicates completion
        if updated_run.status == "completed":
            messages = openai_client.beta.threads.messages.list(
                thread_id=thread.id
            )
            # Index 0 is taken as the assistant's reply; this assumes the list
            # call returns the newest message first (its default ordering).
            print(f"Assistants Response: {messages.data[0].content[0].text.value}")
            break
        # Stop early when the run ended without producing an answer.
        if updated_run.status in terminal_failure_statuses:
            print(
                f"Error: run ended with status '{updated_run.status}'. "
                f"Details: {getattr(updated_run, 'last_error', None)}"
            )
            break
    except Exception as e:
        # Best effort: report the error (e.g. a transient API failure) and poll again.
        print(f"Error: {str(e)}. Trying again...")
    # Wait before the next poll. Deliberately not in a `finally`, so that
    # breaking out of the loop does not cost one more sleep interval.
    time.sleep(sleep_interval)
else:
    # If max_attempts reached without completion
    print("Timeout: Assistant didn't respond in time. Please try again.")

assistant working...
assistant working...
assistant working...
assistant working...
assistant working...
Assistants Response: The MongoDB aggregation pipeline query to produce the expected output for the given input data is as follows:

```json
[
  {
    $group: {
      _id: "$team",
      playerCount: { $sum: 1 }
    }
  },
  {
    $project: {
      _id: 0,
      team: "$_id",
      playerCount: 1
    }
  }
]
```

Explanation:
1. The `$group` stage groups the documents by the "team" field and calculates the count of players for each team using the `$sum` accumulator.
2. The `$project` stage reshapes the output to include only the "team" and "playerCount" fields, while excluding the `_id` field.

This pipeline first groups the documents by the "team" field and then projects the result to include the team name and the count of players, as per the expected output.


## Format the assistant response using function calling

In [13]:
# Schema of the formatResponse function tool: both fields are required strings.
# The multi-line description is kept exactly as written (including its
# indentation), because that text is what gets sent to the API.
format_response_tool = {
    "type": "function",
    "function": {
        "name": "formatResponse",
        "description": "Format the assistant's response before responding to user",
        "parameters": {
            "type": "object",
            "properties": {
                "mongoDBQuery": {
                    "type": "string",
                    "description": """The MongoDB aggregation pipeline to produce the expected output for a given input. 
                                           This field corresponds to just the list of stages in the aggregation pipeline 
                                           and shouldn't contain the 'db.collection.aggregate' prefix.""",
                },
                "queryExplanation": {
                    "type": "string",
                    "description": "A detailed explanation for the query that was returned.",
                },
            },
            "required": ["mongoDBQuery", "queryExplanation"],
        },
    },
}

# Attach the tool to the assistant created earlier (looked up by its saved id).
assistant = openai_client.beta.assistants.update(
    MONGO_DB_SME_ASSISTANT_ID,
    tools=[format_response_tool],
)

In [14]:
def get_user_prompt(input_data, output_data):
    """Build the user message for the function-calling variant of the task.

    The prompt embeds ``input_data`` and ``output_data`` verbatim and tells the
    assistant to format its answer with the ``formatResponse`` tool.

    Note: this replaces the ``get_user_prompt`` defined earlier in the notebook.

    Args:
        input_data: Text describing the input documents.
        output_data: Text describing the expected output documents.

    Returns:
        The complete prompt as a single string.
    """
    # The trailing backslash joins the first two lines of the literal, so the
    # opening sentence is not split by a newline in the prompt.
    prompt = f"""Your task is to write a MongoDB Query, specifically an aggregation pipeline\
    that would produce the expected output for the given input.

    Important: You will always format the response using the formatResponse tool before responding to user. 
    
    Input data: {input_data} 
    Expected output data: {output_data}
    """
    return prompt

In [15]:
# Rebuild the example 1 prompt with the get_user_prompt defined in the previous cell.
user_prompt = get_user_prompt(
    input_data=ex1_input_data,
    output_data=ex1_output_data,
)

In [16]:
# Create a new thread for the function-calling run; this rebinds `thread`,
# so the thread created earlier in the notebook is no longer referenced by that name.
thread = openai_client.beta.threads.create()

In [17]:
# Post the rebuilt prompt to the new thread as the user's message.
message = openai_client.beta.threads.messages.create(
    content=user_prompt, role="user", thread_id=thread.id
)

In [18]:
# Start a run of the updated assistant on the new thread; the next cell polls it.
run = openai_client.beta.threads.runs.create(
    assistant_id=assistant.id, thread_id=thread.id
)

In [19]:
# Poll the run until the assistant calls the formatResponse tool, giving up
# after max_attempts * sleep_interval seconds.
max_attempts = 60
sleep_interval = 2
# Statuses from which the run will never ask for a tool call; polling further
# would only burn the whole timeout and then misreport the outcome as a timeout.
terminal_failure_statuses = {"failed", "cancelled", "expired", "incomplete"}

for i in range(max_attempts):
    try:
        updated_run = openai_client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id
        )
        # "requires_action" means the run is waiting on a tool call; the
        # arguments of the first call hold the formatted answer.
        if updated_run.status == "requires_action":
            assistant_response = json5.loads(updated_run.required_action\
                                             .submit_tool_outputs\
                                             .tool_calls[0].function.arguments)
            # Original (misspelled) name kept so any cell relying on it still works.
            assitant_response = assistant_response
            print(f"Assistant Response:\n{assistant_response}")
            break
        # The assistant answered directly instead of calling the tool.
        if updated_run.status == "completed":
            print("Error: run completed without calling the formatResponse tool.")
            break
        # Stop early when the run ended without producing an answer.
        if updated_run.status in terminal_failure_statuses:
            print(
                f"Error: run ended with status '{updated_run.status}'. "
                f"Details: {getattr(updated_run, 'last_error', None)}"
            )
            break
    except Exception as e:
        # Best effort: report the error (e.g. a transient API failure) and poll again.
        print(f"Error: {str(e)}. Trying again...")
    # Wait before the next poll. Deliberately not in a `finally`, so that
    # breaking out of the loop does not cost one more sleep interval.
    time.sleep(sleep_interval)
else:
    # If max_attempts reached without completion, then assistant call timed out
    print("Timeout: Assistant didn't respond in time. Please try again.")

Assistant Response:
{'mongoDBQuery': '[{\n  "$group": {\n    "_id": "$team",\n    "playerCount": { "$sum": 1 }\n  }\n},{\n  "$project": {\n    "_id": 0,\n    "team": "$_id",\n    "playerCount": 1\n  }\n}]\n', 'queryExplanation': "The MongoDB aggregation pipeline starts with a $group stage where we group the data by the 'team' field and use the $sum accumulator to count the players for each team. Then, in the $project stage, we reshape the output to include 'team' and 'playerCount' fields and exclude the '_id' field."}
