## Generate a MongoDB query using OpenAI function calling

In [None]:
!pip install python-dotenv
!pip install pymongo
!pip install tenacity
!pip install termcolor 
!pip install openai
!pip install requests
!pip install json5

In [1]:
import ast
import functools
import json
import os

import json5
from dotenv import load_dotenv
from openai import OpenAI
from pymongo import MongoClient
from termcolor import colored

In [2]:
# Load variables from a local .env file (presumably including OPENAI_API_KEY)
# into the process environment *before* building the client: the OpenAI SDK
# reads OPENAI_API_KEY from the environment by default.
_ = load_dotenv()
openai_client = OpenAI()

# Default chat model for chat_completion(). Swap in "gpt-4-1106-preview" to
# try the larger model.
GPT_MODEL = "gpt-3.5-turbo-1106"

In [3]:
def pretty_print_conversation(messages):
    """Print each chat message, colour-coded by the role that produced it.

    System, user and function messages print their content (function messages
    are also labelled with the function name). Assistant messages print the
    requested ``function_call`` when one is present, otherwise their content.
    Messages with any other role are skipped silently.
    """
    palette = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "function": "magenta",
    }

    for msg in messages:
        role = msg["role"]
        if role not in palette:
            continue

        if role == "assistant":
            call = msg.get("function_call")
            label, body = role, (call if call else msg["content"])
        elif role == "function":
            label = f"function ({msg['name']})"
            body = msg["content"]
        else:
            label, body = role, msg["content"]

        print(colored(f"{label}: {body}\n", palette[role]))

In [12]:
def chat_completion(messages, functions=None, model=GPT_MODEL):
    """Send *messages* to the chat completions API and return the first choice.

    Args:
        messages: Conversation so far, as a list of ``{"role": ..., ...}`` dicts.
        functions: Optional list of function schemas the model may call. When
            empty or ``None`` the parameter is left out of the request.
        model: Model name to request; defaults to ``GPT_MODEL``.

    Returns:
        A dict with ``assistant_message`` (the first choice's message as a plain
        dict, minus a ``None`` ``tool_calls`` entry) and ``finish_reason`` (that
        choice's finish reason, e.g. ``"stop"`` or ``"function_call"``).
    """
    request = {
        "model": model,  # honour the caller's choice instead of the global
        "messages": messages,
        "temperature": 0,
        "response_format": {"type": "json_object"},
    }
    # Only send `functions` when there is something to offer; omitting the key
    # is safer than sending null or an empty list.
    if functions:
        request["functions"] = functions

    response = openai_client.chat.completions.create(**request)
    print(f"chat_completion: {response}")

    # Round-trip through JSON to turn the SDK's model objects into plain dicts
    # that can be appended straight back onto `messages`.
    response_json = json.loads(response.model_dump_json())
    first_choice = response_json["choices"][0]
    assistant_message = first_choice["message"]
    finish_reason = first_choice["finish_reason"]

    # Drop an empty `tool_calls` entry so it is not echoed back to the API when
    # this message is resent as part of the conversation.
    if assistant_message.get("tool_calls") is None:
        assistant_message.pop("tool_calls", None)
    print(f"Assistant full message: {assistant_message}")

    return {
        "assistant_message": assistant_message,
        "finish_reason": finish_reason,
    }

In [5]:
# Function (tool) schema offered to the model through chat_completion(). All
# three arguments are declared as strings, so the model serialises the data
# and the pipeline into text; execute_function_call() parses them back before
# calling verify_query().
# NOTE(review): the backslash continuations and the triple-quoted description
# embed the following lines' leading spaces in the text sent to the model.
functions = [
    {
        "name": "verify_query",
        "description": f"""This function is used to verify that the MongoDB query produces the expected output data for the given user input. 
                        Input to this function should be a fully formed MongoDB query.""",
        "parameters": {
            "type": "object",
            "properties": {
                "inputData": {
                    "type": "string",
                    "description": "The sample input in JSON format given by the user",
                },
                "expectedOutput": {
                    "type": "string",
                    "description": "The expected output data in JSON format given by the user",
                },
                # Parsed by execute_function_call() into the list of stages
                # passed to collection.aggregate().
                "mongoDBQuery": {
                    "type": "string",
                    "description": "MongoDB aggregation pipeline stages array that produces the expected output for the given input data. \
                    This field corresponds to just the list of stages in the aggregation pipeline \
                    and shouldn't contain the `db.collection.aggregate` prefix"
                }
            },
            "required": ["inputData", "expectedOutput", "mongoDBQuery"]
        }
    }
]

In [6]:
@functools.lru_cache(maxsize=None)
def _get_mongo_client(connection_string):
    """Return a MongoClient for *connection_string*, created once and reused."""
    return MongoClient(connection_string)


def get_database():
    """Return the database that holds the scratch verification collection.

    The connection string is read from the ``MONGODB_CONNECTION_STRING``
    environment variable (e.g. from the ``.env`` file loaded by
    ``load_dotenv()``). The database name comes from ``MONGODB_DB_NAME`` and
    defaults to ``"pymongo_tutorial"``.

    Raises:
        RuntimeError: if ``MONGODB_CONNECTION_STRING`` is not set.
    """
    # Credentials must never live in source control; take them from the environment.
    connection_string = os.environ.get("MONGODB_CONNECTION_STRING")
    if not connection_string:
        raise RuntimeError(
            "MONGODB_CONNECTION_STRING is not set; add it to your environment or .env file."
        )
    db_name = os.environ.get("MONGODB_DB_NAME", "pymongo_tutorial")
    # pymongo recommends one long-lived MongoClient per process (each client
    # owns a connection pool), so reuse it instead of building one per call.
    return _get_mongo_client(connection_string)[db_name]

def insert_data(collection, data_array):
    """Insert the documents in *data_array* into *collection* in one batch.

    Nothing happens when *collection* is ``None`` or *data_array* is empty or
    ``None``. NOTE: pymongo documents that ``insert_many`` adds an ``_id`` key
    to each inserted dict in place, so the caller's documents are modified.
    """
    # Truthiness also covers None, which `len()` would reject with TypeError.
    if collection is None or not data_array:
        return
    collection.insert_many(data_array)

def delete_data(collection):
    """Empty *collection* by deleting every document; ``None`` is ignored."""
    if collection is None:
        return
    # An empty filter document matches everything in the collection.
    collection.delete_many({})

def execute_query(collection, aggregation_pipeline=None):
    """Run *aggregation_pipeline* on *collection* and return all result documents.

    Raises:
        ValueError: if *aggregation_pipeline* is missing or empty.
    """
    if not aggregation_pipeline:
        raise ValueError("Aggregation pipeline must be provided.")
    # Materialise the cursor so callers get a plain list.
    return list(collection.aggregate(aggregation_pipeline))

In [7]:
def verify_query(input_data, expected_output, query):
    """Load *input_data* into a scratch collection, run *query*, compare results.

    The ``mqg`` collection is emptied first, so only *input_data* is present
    when the pipeline runs.

    Args:
        input_data: List of documents to insert before running the query.
        expected_output: List of documents the query should return.
        query: Aggregation pipeline (list of stages) to execute.

    Returns:
        True if the query returns exactly the expected documents. Document
        order is ignored because both lists are sorted by a canonical JSON form
        first, so a pipeline whose only job is ordering (``$sort``) is not
        distinguished here.
    """
    db = get_database()
    collection = db["mqg"]

    delete_data(collection)
    insert_data(collection, input_data)

    result_from_db = execute_query(collection, query)
    print(f"Result: {result_from_db}")

    def canonical(doc):
        # default=str keeps BSON values such as ObjectId (e.g. an `_id` the
        # pipeline did not project away) from making json.dumps raise TypeError.
        return json.dumps(doc, sort_keys=True, default=str)

    expected_sorted = sorted(expected_output, key=canonical)
    actual_sorted = sorted(result_from_db, key=canonical)
    return expected_sorted == actual_sorted

In [8]:
def _parse_tool_argument(value, parsers):
    """Decode one tool argument, trying each parser in *parsers* in order."""
    if not isinstance(value, str):
        # The model sent structured JSON rather than a string; already decoded.
        return value
    last_error = None
    for parse in parsers:
        try:
            return parse(value)
        except (ValueError, SyntaxError) as exc:
            last_error = exc
    raise ValueError(f"could not parse tool argument {value!r}: {last_error}") from last_error


def execute_function_call(message):
    """Execute the function call requested in an assistant *message*.

    Only ``verify_query`` is supported: its three arguments are decoded into
    Python objects and the query is run via ``verify_query()``.

    Returns:
        A dict with ``result`` set to ``"success"``, ``"failure"`` or
        ``"error"`` and a human-readable ``message`` that the caller sends back
        to the model.
    """
    success = {"result": "success", "message": "This query produces the expected output"}
    failure = {"result": "failure", "message": "This query doesn't produce the expected output"}

    function_name = message["function_call"]["name"]
    if function_name != "verify_query":
        return {"result": "error", "message": f"Error: function {function_name} does not exist"}

    try:
        arguments = json.loads(message["function_call"]["arguments"])
        print(f"Assistant Generated Function Arguments: {arguments}")

        # The schema asks for JSON, but the model may send Python-style
        # literals (single quotes, True/None), so accept either form. Each
        # field tries its original parser first and falls back to the other.
        input_data = _parse_tool_argument(arguments["inputData"], (ast.literal_eval, json5.loads))
        expected_output = _parse_tool_argument(arguments["expectedOutput"], (ast.literal_eval, json5.loads))
        query = _parse_tool_argument(arguments["mongoDBQuery"], (json5.loads, ast.literal_eval))

        result_bool = verify_query(input_data, expected_output, query)
    except Exception as exc:
        # Tool boundary: a malformed argument or a failing pipeline should be
        # reported back to the model so it can retry, not abort the conversation.
        print(f"Function call failed: {exc!r}")
        return {"result": "error", "message": f"Error: {type(exc).__name__}: {exc}"}

    print(f"Results Match: {result_bool}")
    return success if result_bool else failure

In [9]:
# Sample documents loaded into the scratch collection: one per player.
input_data = [
    {"name": "Sachin", "team": "India"},
    {"name": "Sourav", "team": "India"},
    {"name": "Lara", "team": "West Indies"},
]

# Result the generated aggregation pipeline must reproduce: players per team.
output_data = [
    {"team": "India", "playerCount": 2},
    {"team": "West Indies", "playerCount": 1},
]

In [10]:
# System message: casts the model as a MongoDB query expert.
system_prompt = (
    "\n"
    "You are a MongoDB expert with great expertise in writing MongoDB queries \n"
    "for any given data to produce an expected output.\n"
)

# Main task prompt. A backslash at the end of a line joins it to the next one
# inside the string, so each is preceded by a space to keep words apart.
user_prompt = f"""
Below are your tasks to perform. Follow the instructions entirely without missing anything. 

1. Your task is to write a MongoDB Query, specifically an aggregation pipeline \
that would produce the expected output for the given input.

2. Verify that executing the generated query actually produces the output \
matching the given expected output.

3. The final response should always be returned in JSON with the following fields.
```
mongoDBQuery: The MongoDB aggregation pipeline to produce the expected output for a given input. \
This field corresponds to just the list of stages in the aggregation pipeline \
and shouldn't contain the "db.collection.aggregate" prefix.

queryExplanation: A detailed explanation for the query that was returned.
```

Input data: {input_data} 
Expected output data: {output_data}
"""

# Follow-up user turn appended when execute_function_call reports "failure"
# (the proposed pipeline did not reproduce the expected output).
user_prompt_failure_scenario = (
    "\n"
    "The query that you provided as argument to the function call didn't produce the expected output. "
    "Try again to write a new query that would produce the expected output\n"
)

In [None]:
####This cell needs to be deleted####
# messages = []
# messages.append({"role": "system", "content": system_prompt})
# messages.append({"role": "user", "content": user_prompt})

# assistant_message = chat_completion(messages, functions)
# messages.append(assistant_message)

# pretty_print_conversation(messages)
# #print(json.dumps(messages, indent=2))

In [None]:
####This cell needs to be deleted####

# if assistant_message.get("function_call"):
#     result = execute_function_call(assistant_message)
#     print(f"Results after invoking the function: {result}")

#     messages.append({"role": "function", "name": assistant_message["function_call"]["name"], "content": result["message"]})
#     if result["result"] == "failure":
#         messages.append({"role": "user", "content": user_prompt_failure_scenario})
    
#     assistant_message = chat_completion(messages, functions)
#     messages.append(assistant_message)
    
#     pretty_print_conversation(messages)

In [11]:
# Conversation driver: ask the model for a pipeline, run each verify_query
# call it requests, feed the outcome back, and stop once the model answers
# without a function call or after `max_loop` model calls.
counter = 1  # number of chat_completion calls made so far
max_loop = 3

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_prompt},
]

chat_response = chat_completion(messages, functions)
assistant_message = chat_response["assistant_message"]
finish_reason = chat_response["finish_reason"]

messages.append(assistant_message)
pretty_print_conversation(messages)

while counter < max_loop and finish_reason != "stop":
    print(f"Counter: {counter}")
    counter += 1

    if finish_reason == "function_call":
        result = execute_function_call(assistant_message)
        print(f"Results after invoking the function: {result}")

        # Report the function outcome; on a mismatch also ask explicitly for a new query.
        messages.append({"role": "function", "name": assistant_message["function_call"]["name"], "content": result["message"]})
        if result["result"] == "failure":
            messages.append({"role": "user", "content": user_prompt_failure_scenario})

        chat_response = chat_completion(messages, functions)
        assistant_message = chat_response["assistant_message"]
        finish_reason = chat_response["finish_reason"]

        messages.append(assistant_message)
        pretty_print_conversation(messages)
    else:
        # Unknown finish reason (neither "stop" nor "function_call"): exit the loop.
        finish_reason = "stop"

if finish_reason == "function_call":
    # The round limit was hit while the model still wanted a verification; that
    # call was never executed, so the message carries no final answer.
    print(
        f"Stopped after {counter} model calls without a final answer. "
        f"Last requested function call: {assistant_message['function_call']}"
    )
else:
    print(f"Final Response from the assistant: {assistant_message['content']}")

chat_completion: ChatCompletion(id='chatcmpl-8XjgTAbQ4IJq7Fb7DRexLQZ3k8AZL', choices=[Choice(finish_reason='function_call', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=FunctionCall(arguments='{"inputData":"[{\'name\': \'Sachin\', \'team\': \'India\'}, {\'name\': \'Sourav\', \'team\': \'India\'}, {\'name\': \'Lara\', \'team\': \'West Indies\'}]","expectedOutput":"[{\'team\': \'India\', \'playerCount\': 2}, {\'team\': \'West Indies\', \'playerCount\': 1}]","mongoDBQuery":"[{\'$group\': {\'_id\': \'$team\', \'playerCount\': {\'$sum\': 1}}}, {\'$project\': {\'team\': \'$_id\', \'playerCount\': 1, \'_id\': 0}}]"}', name='verify_query'), tool_calls=None))], created=1703051125, model='gpt-3.5-turbo-1106', object='chat.completion', system_fingerprint='fp_772e8125bb', usage=CompletionUsage(completion_tokens=144, prompt_tokens=394, total_tokens=538))
[31msystem: 
You are a MongoDB expert with great expertise in writing MongoDB queries 
for