# Generate MongoDB Query using function calling

In [2]:
!pip install python-dotenv
!pip install pymongo
!pip install termcolor 
!pip install openai
!pip install requests
!pip install json5

Collecting python-dotenv
  Using cached python_dotenv-1.0.0-py3-none-any.whl (19 kB)
Installing collected packages: python-dotenv
Successfully installed python-dotenv-1.0.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Collecting pymongo
  Obtaining dependency information for pymongo from https://files.pythonhosted.org/packages/00/5b/52158a2666945f517e000e6cebb70bc7b36971376c6b63683faa2955ae8e/pymongo-4.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Using cached pymongo-4.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (22 kB)
Collecting dnspython<3.0.0,>=1.16.0 (from pymongo)
  Obtaining dependency information for dnspython<3.0.0,>=1.16.0 from https://files.pythonhosted.org/packages/f6/b4/0a9bee52c50f226a3cbfb54263d02bb421c7f2adc13

In [3]:
import json
import json5
import ast
from openai import OpenAI
import os
from termcolor import colored
from pymongo import MongoClient
from dotenv import load_dotenv

## Load the OpenAI API key

Create a `.env` file containing the values below; they are loaded as environment variables. Replace the example values with your own and keep the `.env` file out of version control.

- `OPENAI_API_KEY="sk-TZP9XNLsdfskh23423jh234"`
- `MONGODB_CONNECTION_STRING="mongodb+srv://selvam:abcdefgh@playground.cwfqcov.mongodb.net/?retryWrites=true&w=majority"`
- `DB_NAME=mongodb_test`
- `COLLECTION_NAME=mqg`

In [None]:
# Populate the process environment from the `.env` file; the return value is not needed.
_ = load_dotenv()

# MongoDB settings, looked up in one pass. A setting is None when its variable is not set.
MONGODB_CONNECTION_STRING, DB_NAME, COLLECTION_NAME = (
    os.environ.get(setting_name)
    for setting_name in ("MONGODB_CONNECTION_STRING", "DB_NAME", "COLLECTION_NAME")
)

In [4]:

# Shared OpenAI API client used by every chat-completion request in this notebook.
# No key is passed explicitly.
# NOTE(review): presumably the client reads OPENAI_API_KEY from the environment — confirm.
openai_client = OpenAI()

# Default model id for chat completions; the commented-out line is an alternative model id.
#GPT_MODEL = "gpt-3.5-turbo-1106"
GPT_MODEL = "gpt-4-1106-preview"

## Define a custom function

In [7]:
# Tool catalogue advertised to the model: a single function, `verify_query`,
# whose three required arguments are all declared as strings.
tools = [
    dict(
        type="function",
        function=dict(
            name="verify_query",
            description="""This function is used to verify that the MongoDB query produces the expected output data for the given user input. 
                            Input to this function should be a fully formed MongoDB query.""",
            parameters=dict(
                type="object",
                properties=dict(
                    inputData=dict(
                        type="string",
                        description="The sample input in JSON format given by the user",
                    ),
                    expectedOutput=dict(
                        type="string",
                        description="The expected output data in JSON format given by the user",
                    ),
                    mongoDBQuery=dict(
                        type="string",
                        description="MongoDB aggregation pipeline stages array that produces the expected output for the given input data. \
                        This field corresponds to just the list of stages in the aggregation pipeline \
                        and shouldn't contain the `db.collection.aggregate` prefix",
                    ),
                ),
                required=["inputData", "expectedOutput", "mongoDBQuery"],
            ),
        ),
    )
]

## Define function to invoke Chat Completion API

In [6]:
def chat_completion(messages, tools=None, tool_choice=None, model=GPT_MODEL):
    """Send the conversation to the Chat Completions API and return the first choice.

    Args:
        messages: Chat history to send, as a list of message dicts.
        tools: Optional list of tool definitions. Left out of the request when None.
        tool_choice: Optional tool-choice directive. Left out of the request when None.
        model: Model id to use for this request. Defaults to GPT_MODEL.

    Returns:
        A dict with two keys:
        - "assistant_message": the first choice's message as a plain dict, ready to
          be appended to the chat history.
        - "finish_reason": the first choice's finish reason.
    """
    request_kwargs = {
        "model": model,  # honour the caller's choice instead of always using GPT_MODEL
        "messages": messages,
        "temperature": 0,
        "response_format": {"type": "json_object"},
    }
    # Only include the optional tool fields when the caller supplied them, so the
    # request never carries explicit nulls for `tools` / `tool_choice`.
    if tools is not None:
        request_kwargs["tools"] = tools
    if tool_choice is not None:
        request_kwargs["tool_choice"] = tool_choice

    completion = openai_client.chat.completions.create(**request_kwargs)
    print(f"chat_completion: {completion}")

    # Round-trip through JSON to get plain dicts/lists instead of SDK model objects.
    completion_json = json.loads(completion.model_dump_json())
    first_choice = completion_json["choices"][0]
    assistant_message = first_choice["message"]
    finish_reason = first_choice["finish_reason"]

    # This is done to avoid a bug when we invoke the Chat Completion API
    # the next time around with the AI returned message in chat history
    # that has function_call as None.
    if assistant_message.get("function_call") is None:
        assistant_message.pop("function_call", None)

    return {
        "assistant_message": assistant_message,
        "finish_reason": finish_reason,
    }

## Utility functions to read / write data in MongoDB 

In [8]:
# Lazily created client shared by every get_database() call. Re-running this cell
# resets it, which forces a reconnect with the current connection string.
_mongo_client = None


def get_database():
    """Return the database named by DB_NAME, reusing a single MongoClient.

    The client is created on first use from MONGODB_CONNECTION_STRING and then
    cached, so repeated calls do not open a fresh connection pool each time.

    Returns:
        The database object obtained by indexing the client with DB_NAME.
    """
    global _mongo_client
    if _mongo_client is None:
        _mongo_client = MongoClient(MONGODB_CONNECTION_STRING)
    return _mongo_client[DB_NAME]

def insert_data(collection, data_array):
    """Insert the documents in `data_array` into `collection`.

    Does nothing when `collection` is None or when `data_array` is None or empty,
    so `insert_many` is never called without documents.

    Args:
        collection: Target collection (anything exposing `insert_many`), or None.
        data_array: List of documents to insert; may be None or empty.
    """
    # Compare against None explicitly rather than relying on truthiness:
    # PyMongo collection objects do not support bool().
    if collection is None or not data_array:
        return
    collection.insert_many(data_array)

def delete_data(collection):
    """Remove every document from `collection`; a None collection is ignored."""
    if collection is None:
        return
    # The empty filter is passed straight to delete_many.
    collection.delete_many({})

def execute_query(collection, aggregation_pipeline=None):
    """Run an aggregation pipeline against `collection` and return the results as a list.

    Args:
        collection: Object exposing `aggregate(pipeline)`.
        aggregation_pipeline: The pipeline to run; must be non-empty.

    Returns:
        A list built from whatever `collection.aggregate` yields.

    Raises:
        ValueError: If `aggregation_pipeline` is None or empty.
    """
    if not aggregation_pipeline:
        raise ValueError("Aggregation pipeline must be provided.")

    cursor = collection.aggregate(aggregation_pipeline)
    return list(cursor)

## Custom function tool implementation

In [9]:
def verify_query(input_data, expected_output, query):
    """Check whether running `query` over `input_data` yields `expected_output`.

    Side effects: every document in the COLLECTION_NAME collection is deleted,
    then `input_data` is inserted, before the pipeline is executed.

    Args:
        input_data: List of documents to load into the collection.
        expected_output: List of documents the pipeline is expected to produce.
        query: Aggregation pipeline (list of stages) to execute.

    Returns:
        True when the pipeline's results equal `expected_output`, ignoring the
        order of the documents; False otherwise.
    """
    db = get_database()
    collection = db[COLLECTION_NAME]

    # Start from an empty collection so results depend only on `input_data`.
    delete_data(collection)
    insert_data(collection, input_data)

    result_from_db = execute_query(collection, query)
    print(f"Result: {result_from_db}")

    def canonical_key(document):
        # Canonical text form used purely as a sort key. `default=str` keeps values
        # that JSON cannot encode (for example an ObjectId `_id` left in the results)
        # from raising TypeError, so such a query is reported as a mismatch instead.
        return json.dumps(document, sort_keys=True, default=str)

    # Sort both sides the same way so the comparison ignores document order.
    expected_sorted = sorted(expected_output, key=canonical_key)
    actual_sorted = sorted(result_from_db, key=canonical_key)

    return expected_sorted == actual_sorted

## Function that invokes the custom function tool 

In [10]:
def _parse_document_list(text):
    """Parse a model-supplied argument string into Python objects.

    Python-literal syntax is tried first; if that is rejected (for example because
    the text uses JSON's true/false/null), the text is parsed with json5 instead.
    """
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return json5.loads(text)


def execute_function_call(message):
    """Run the tool call carried by an assistant message and report the outcome.

    Only the first entry of `message["tool_calls"]` is considered.

    Args:
        message: Assistant message dict containing a non-empty "tool_calls" list.

    Returns:
        A dict with "result" set to "success", "failure" or "error", and a
        human-readable "message". "error" is returned when the requested function
        is not `verify_query`.
    """
    success = { "result": "success", "message": "This query produces the expected output" }
    failure = { "result": "failure", "message": "This query doesn't produce the expected output" }

    tool_function = message["tool_calls"][0]["function"]
    function_name = tool_function["name"]

    if function_name != "verify_query":
        # Report the name from the tool call itself; tool-call messages do not
        # carry a usable "function_call" entry.
        return {
            "result": "error",
            "message": f"Error: function {function_name} does not exist",
        }

    arguments = json.loads(tool_function["arguments"])
    print(f"Assistant Generated Function Arguments: {arguments}")

    # The arguments arrive as strings; turn them into Python lists.
    input_data = _parse_document_list(arguments["inputData"])
    expected_output = _parse_document_list(arguments["expectedOutput"])
    query = json5.loads(arguments["mongoDBQuery"])

    result_bool = verify_query(input_data, expected_output, query)
    print(f"Results Match: {result_bool}")

    return success if result_bool else failure

## Data

In [11]:
# Sample documents handed to the model as the "input" collection.
input_data = [
    {"name": player, "team": side}
    for player, side in (
        ("Sachin", "India"),
        ("Sourav", "India"),
        ("Lara", "West Indies"),
    )
]

# Documents the generated aggregation pipeline is expected to produce.
output_data = [
    {"team": side, "playerCount": count}
    for side, count in (("India", 2), ("West Indies", 1))
]

## Prompts

In [12]:
# System message: establishes the assistant's role. No interpolation is needed,
# so this is a plain string.
system_prompt = """
You are a MongoDB expert with great expertise in writing MongoDB queries 
for any given data to produce an expected output.
"""

# Main task description. A backslash at the end of a line joins it to the next
# line with nothing in between, so each continued line ends with a space before
# the backslash to keep the words apart. The field list directly follows the
# instruction that introduces it.
user_prompt = f"""
Below are your tasks to perform. Follow the instructions entirely without missing anything. 

1. Your task is to write a MongoDB Query, specifically an aggregation pipeline \
that would produce the expected output for the given input.

2. Verify that executing the generated query actually produces the output \
matching the given expected output.

3. Do not make a call to the same tool twice in a single response. 

4. The final response should always be returned in JSON with the following fields.
```
mongoDBQuery: The MongoDB aggregation pipeline to produce the expected output for a given input. \
This field corresponds to just the list of stages in the aggregation pipeline \
and shouldn't contain the "db.collection.aggregate" prefix.

queryExplanation: A detailed explanation for the query that was returned.
```

Input data: {input_data} 
Expected output data: {output_data}
"""

# Follow-up user message sent after a tool call reports that the query failed.
user_prompt_failure_scenario = """
The query that you provided as argument to the function call didn't produce the expected output. \
Try again to write a new query that would produce the expected output
"""

## Utility function to format the chat conversation

In [5]:
def pretty_print_conversation(messages):
    """Print each chat message in a colour chosen by its role.

    System and user messages show their content. Assistant messages show the
    function of their first tool call when they have tool calls, otherwise their
    content. Tool messages show the tool name and content. Messages with any
    other role are skipped.

    Args:
        messages: Iterable of chat message dicts, each with a "role" key.
    """
    palette = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "tool": "magenta"
    }

    for msg in messages:
        role = msg["role"]

        if role in ("system", "user"):
            text = f"{role}: {msg['content']}\n"
        elif role == "assistant":
            if msg.get("tool_calls"):
                text = f"assistant: {msg['tool_calls'][0]['function']}\n"
            else:
                text = f"assistant: {msg['content']}\n"
        elif role == "tool":
            text = f"function ({msg['name']}): {msg['content']}\n"
        else:
            # Unrecognised role: nothing is printed.
            continue

        print(colored(text, palette[role]))

## Core logic for this example

In [13]:
# Drives the generate -> verify -> retry conversation with the model.
# `counter` starts at 1 and the loop runs while counter < max_loop, so the loop
# body executes at most max_loop - 1 times.
counter = 1
max_loop = 3

# Seed the chat history with the system and user prompts.
messages = []
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": user_prompt})

# First request: tool_choice names `verify_query`, asking the model to respond
# with a call to that function.
chat_response = chat_completion(
    messages=messages, 
    tools=tools, 
    tool_choice={"type": "function", "function": {"name": "verify_query"}}
)

# Keep the assistant's reply in the history so later requests include it.
assistant_message = chat_response["assistant_message"]
messages.append(assistant_message)

pretty_print_conversation(messages)
#print(json.dumps(messages, indent=2))

while counter < max_loop:
    print(f"Counter: {counter}")
    counter += 1
    
    if assistant_message.get("tool_calls"):
        # Run the requested tool (only the first tool call is used).
        result = execute_function_call(assistant_message)
        print(f"Results after invoking the function: {result}")

        # Report the tool's outcome back to the model, tied to the first tool call's id.
        messages.append({
            "role": "tool", 
            "tool_call_id": assistant_message["tool_calls"][0]['id'], 
            "name": assistant_message["tool_calls"][0]["function"]["name"], 
            "content": result["message"]
        })

        if result["result"] == "failure":
            # The query did not match: ask for a new query and again name
            # `verify_query` in tool_choice.
            messages.append({"role": "user", "content": user_prompt_failure_scenario})
            chat_response = chat_completion(
                messages=messages, 
                tools=tools, 
                tool_choice={"type": "function", "function": {"name": "verify_query"}}
            )
        else:
            # Any other result ("success" or "error"): request a reply without
            # passing a tool_choice.
            chat_response = chat_completion(messages, tools)
            
        assistant_message = chat_response["assistant_message"]
        messages.append(assistant_message)
        
        pretty_print_conversation(messages)
    else:
        # The assistant replied without a tool call; stop iterating.
        break

if assistant_message.get("tool_calls"):
    # The last reply is still a tool call that was not executed above, so report
    # the query it carries, with an empty explanation.
    arguments = json.loads(assistant_message["tool_calls"][0]["function"]["arguments"])
    mongoDBQuery = json5.loads(arguments["mongoDBQuery"])
    return_val = {
        "mongoDBQuery": mongoDBQuery,
        "queryExplanation": ""
    }
    print(f"The best response from the assistant: {return_val}")
else:
    # The last reply has no tool calls; print its content as the final answer.
    print(f"Final Response from the assistant: {assistant_message['content']}")

chat_completion: ChatCompletion(id='chatcmpl-8dVihaUDNh7mr3ZffKOg3JiZ3UUPj', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_7owge0JDeWWmynyfjvdI5SvP', function=Function(arguments='{\n  "inputData": "[{\'name\': \'Sachin\', \'team\': \'India\'}, {\'name\': \'Sourav\', \'team\': \'India\'}, {\'name\': \'Lara\', \'team\': \'West Indies\'}]",\n  "expectedOutput": "[{\'team\': \'India\', \'playerCount\': 2}, {\'team\': \'West Indies\', \'playerCount\': 1}]",\n  "mongoDBQuery": "[{\'$group\': {\'_id\': \'$team\', \'playerCount\': {\'$sum\': 1}}}, {\'$project\': {\'team\': \'$_id\', \'_id\': 0, \'playerCount\': 1}}]"\n}', name='verify_query'), type='function')]))], created=1704427415, model='gpt-4-1106-preview', object='chat.completion', system_fingerprint='fp_3905aa4f79', usage=CompletionUsage(completion_tokens=140, prompt_tokens=421, total_token