## Generate MongoDB Query using function calling

In [None]:
!pip install python-dotenv
!pip install pymongo
!pip install tenacity
!pip install termcolor 
!pip install openai
!pip install requests
!pip install json5

In [1]:
import re
import json
import json5
import ast
import openai
import os
from termcolor import colored
import requests
from tenacity import retry, wait_random_exponential, stop_after_attempt
from pymongo import MongoClient
from dotenv import load_dotenv

In [None]:
_ = load_dotenv()
openai.api_key = os.environ['OPENAI_API_KEY']

GPT_MODEL = "gpt-3.5-turbo-0613"

In [2]:
def pretty_print_conversation(messages):
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "function": "magenta",
    }
    
    for message in messages:
        if message["role"] == "system":
            print(colored(f"system: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "user":
            print(colored(f"user: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "assistant" and message.get("function_call"):
            print(colored(f"assistant: {message['function_call']}\n", role_to_color[message["role"]]))
        elif message["role"] == "assistant" and not message.get("function_call"):
            print(colored(f"assistant: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "function":
            print(colored(f"function ({message['name']}): {message['content']}\n", role_to_color[message["role"]]))

In [3]:
@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request(messages, functions=None, function_call=None, model=GPT_MODEL):
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + openai.api_key,
    }
    json_data = {"model": model, "messages": messages, "temperature": 0}
    if functions is not None:
        json_data.update({"functions": functions})
    if function_call is not None:
        json_data.update({"function_call": function_call})
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=json_data,
        )
        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        return e

In [4]:
def get_database():
    CONNECTION_STRING = "mongodb+srv://selvam:vnA0WRL9OJqVyzrC@playground.cwfqcov.mongodb.net/?retryWrites=true&w=majority"
    DB_NAME="pymongo_tutorial"
    client = MongoClient(CONNECTION_STRING)
    return client[DB_NAME]

def insert_data(collection, data_array):
    if len(data_array) != 0 and collection is not None:
        collection.insert_many(data_array)

def delete_data(collection):
    if collection is not None:
        collection.delete_many({})

def get_query_type(query):
    pattern = r'db\.collection\.(.*?)\('
    match = re.search(pattern, query)
    if match:
        extracted_query_type = match.group(1)
        print("Extracted string:", extracted_query_type)
        return extracted_query_type
    return ""

# def is_find_query(query):
#     parsed_query = ast.literal_eval(query)
#     if isinstance(parsed_query, dict):
#         return True
#     return False

# def is_aggregation_pipeline(query):
#     if isinstance(query, list) and all(isinstance(stage, dict) for stage in query):
#         return True
#     return False

def verify_results(expected_data, actual_result): 
    # Convert the JSON objects to Python dictionaries
    print(f"expected_data: {expected_data}")
    print(f"actual_result: {actual_result}")

    # Sort the lists based on a canonical representation of each dictionary
    sorted_list1 = sorted(expected_data, key=lambda x: json.dumps(x, sort_keys=True))
    sorted_list2 = sorted(actual_result, key=lambda x: json.dumps(x, sort_keys=True))

    # Compare the sorted lists
    return sorted_list1 == sorted_list2

def extract_query_from_command(command):
    # Use regular expressions to extract the query portion
    match = re.search(r'\((.*?)\)', command)
    if match:
        return match.group(1)
    return None

def generic_query(collection, find_query=None, aggregation_pipeline=None, projection=None):
    if find_query and aggregation_pipeline:
        raise ValueError("Both find_query and aggregation_pipeline cannot be provided simultaneously.")
    
    if find_query:
        result = collection.find(find_query, projection)
    elif aggregation_pipeline:
        result = collection.aggregate(aggregation_pipeline)
    else:
        raise ValueError("Either find_query or aggregation_pipeline must be provided.")
    
    return list(result)

In [5]:
input_data = [
    {
        "name": "Sachin",
        "team": "India"
    },
    {
        "name": "Sourav",
        "team": "India"
    },
    {
        "name": "Lara",
        "team": "West Indies"
    }
]

output_data = [
    {
        "team": "India",
        "playerCount": 2
    },
    {
        "team": "West Indies",
        "playerCount": 1
    }
]

In [6]:
functions = [
    {
        "name": "verify_query",
        "description": f"""This function is used to verify that the MongoDB query produces the expected output data for the given user input. 
                        Input to this function should be a fully formed MongoDB query.""",
        "parameters": {
            "type": "object",
            "properties": {
                "inputData": {
                    "type": "string",
                    "description": "The sample input in JSON format given by the user",
                },
                "expectedOutput": {
                    "type": "string",
                    "description": "The expected output data in JSON format given by the user",
                },
                "query": {
                    "type": "string",
                    "description": "MongoDB query to produce the expected output for the given input data"
                }
            },
            "required": ["inputData", "query"]
        }
    }
]

In [7]:
def verify_query(input_data, expected_output, query): 
    db = get_database()
    collection = db["mqg"]
    EXCLUDE_ID_PROJECTION = {"_id": 0}

    delete_data(collection)
    insert_data(collection, input_data)

    query_type = get_query_type(query)
    print(f"MongoDB query type: {query_type}")
    
    #extracted_query is in the form of Javascript object where the keys and values are not enclosed in quotes
    extracted_query = extract_query_from_command(query)
    print(f"Extracted Query: {extracted_query}")
    
    #Converting the Javascript object to python object by enclosing the keys and values with quotes
    py_extracted_query = json5.loads(extracted_query)
    print(f"py_extracted_query): {py_extracted_query}")
    
    result = []
    if query_type == "find": 
        print("It's a find query.")
        result_from_db = generic_query(collection, find_query=py_extracted_query, projection=EXCLUDE_ID_PROJECTION)
    elif query_type == "aggregate": 
        print("It's an aggregation pipeline query")
        result_from_db = generic_query(collection, aggregation_pipeline=py_extracted_query)
    else:
        print(f"Invalid query: {query}")

    print(f"Result: {result_from_db}")
    
    return verify_results(expected_output, result_from_db)

In [23]:
def execute_function_call(message):
    success = { "result": "success", "message": "This query produces the expected output" }
    failure = { "result": "failure", "message": "This query doesn't produce the expected output" }
    error = { "result": "error" }
    
    if message["function_call"]["name"] == "verify_query":
        arguments = json.loads(message["function_call"]["arguments"])
        print(f"Assistant Generated Function Arguments: {arguments}")
        
        #Parsing the data to python dictionary as pymongo collection.insert_many expects data in the form of a dict
        input_data = ast.literal_eval(arguments["inputData"])
        
        #Parsing the string to python list, so that it can be used in python object comparison in verify_results fn
        expected_output = ast.literal_eval(arguments["expectedOutput"])
        
        query = arguments["query"]
        
        #Print the different arguments to the function call
        print(f"Input Data: {input_data}")
        print(f"Expected Output: {expected_output}")
        print(f"Assistant Generated Query: {query}")
        
        result_bool = verify_query(input_data, expected_output, query)
        print(f"Results Match: {result_bool}")

        if result_bool:
            success.update({"query": query})
        else:
            failure.update({"query": query})
            
        return success if result_bool else failure
    else:
        error.update({"message": f"Error: function {message['function_call']['name']} does not exist"})
        return error

In [9]:
system_prompt = f"""
Your task is to write a MongoDB query for the given input data and expected output. 

Always verify that the MongoDB query that you write, 
actually produces the expected output data for the given input using the function tool.
"""

user_prompt = f"""Input data: {input_data} and Expected output data: {output_data}"""

In [21]:
messages = []
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": user_prompt})
chat_response = chat_completion_request(messages, functions)
print(f"Assistant Response: {chat_response.json()}")
assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)
pretty_print_conversation(messages)

Assistant Response: {'id': 'chatcmpl-892df8mfuPbHT4816AyXbtxDXZA1x', 'object': 'chat.completion', 'created': 1697165787, 'model': 'gpt-3.5-turbo-0613', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': None, 'function_call': {'name': 'verify_query', 'arguments': '{\n  "inputData": "[{\'name\': \'Sachin\', \'team\': \'India\'}, {\'name\': \'Sourav\', \'team\': \'India\'}, {\'name\': \'Lara\', \'team\': \'West Indies\'}]",\n  "expectedOutput": "[{\'team\': \'India\', \'playerCount\': 2}, {\'team\': \'West Indies\', \'playerCount\': 1}]",\n  "query": "db.collection.aggregate([{$group: {_id: \'$team\', playerCount: {$sum: 1}}}, {$project: {team: \'$_id\', playerCount: 1, _id: 0}}])"\n}'}}, 'finish_reason': 'function_call'}], 'usage': {'prompt_tokens': 245, 'completion_tokens': 143, 'total_tokens': 388}}
[31msystem: 
Your task is to write a MongoDB query for the given input data and expected output. 

Always verify that the MongoDB query that you write, 
actually produce

In [22]:
if assistant_message.get("function_call"):
    result = execute_function_call(assistant_message)
    print(f"Results after invoking the function: {result}")
    messages.append({"role": "function", "name": assistant_message["function_call"]["name"], "content": result["message"]})
    chat_response = chat_completion_request(messages, functions)
    assistant_message = chat_response.json()["choices"][0]["message"]
    messages.append(assistant_message)
    pretty_print_conversation(messages)

Assistant Generated Function Arguments: {'inputData': "[{'name': 'Sachin', 'team': 'India'}, {'name': 'Sourav', 'team': 'India'}, {'name': 'Lara', 'team': 'West Indies'}]", 'expectedOutput': "[{'team': 'India', 'playerCount': 2}, {'team': 'West Indies', 'playerCount': 1}]", 'query': "db.collection.aggregate([{$group: {_id: '$team', playerCount: {$sum: 1}}}, {$project: {team: '$_id', playerCount: 1, _id: 0}}])"}
Input Data: [{'name': 'Sachin', 'team': 'India'}, {'name': 'Sourav', 'team': 'India'}, {'name': 'Lara', 'team': 'West Indies'}]
Expected Output: [{'team': 'India', 'playerCount': 2}, {'team': 'West Indies', 'playerCount': 1}]
Assistant Generated Query: db.collection.aggregate([{$group: {_id: '$team', playerCount: {$sum: 1}}}, {$project: {team: '$_id', playerCount: 1, _id: 0}}])
Extracted string: aggregate
MongoDB query type: aggregate
Extracted Query: [{$group: {_id: '$team', playerCount: {$sum: 1}}}, {$project: {team: '$_id', playerCount: 1, _id: 0}}]
py_extracted_query): [{'$g