<a href="https://colab.research.google.com/github/vichitrarora/warp/blob/main/rewardfunctionfinal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
!pip install fuzzywuzzy



In [25]:
import json
import re
from fuzzywuzzy import fuzz

def extract_json_from_find(query: str):
    """Extract and parse the first ``{...}`` object that follows ``find(``.

    The object is located by brace matching. Braces that appear inside
    single- or double-quoted string literals are ignored, so a value such as
    ``"a}b"`` does not terminate the object early.

    The extracted text is first parsed as strict JSON. Only if that fails are
    single quotes replaced by double quotes (shell-style quoting) and the
    parse retried; this keeps apostrophes inside valid JSON strings intact.

    Args:
        query: Query text expected to contain ``find(`` followed by an object.

    Returns:
        The parsed object (whatever ``json.loads`` yields for the text).

    Raises:
        ValueError: If ``query`` is not a string, contains no ``find(``, has
            no ``{`` after it, has unbalanced braces, or the extracted text
            cannot be parsed as JSON.
    """
    if not isinstance(query, str):
        raise ValueError(
            f"⚠️ Error parsing MongoDB query: expected a string, got {type(query).__name__}"
        )

    start = query.find("find(")
    if start == -1:
        raise ValueError("⚠️ Error: Invalid MongoDB Query Format. Ensure the query is correctly formatted.")
    brace_start = query.find("{", start)
    if brace_start == -1:
        raise ValueError("⚠️ Error: Could not find a valid JSON object in find().")

    depth = 0
    quote_char = None  # the quote that opened the current string, if any
    escaped = False    # True when the previous in-string char was a backslash
    json_text = None
    for i, ch in enumerate(query[brace_start:], start=brace_start):
        if quote_char is not None:
            # Inside a string literal: only look for its (unescaped) end.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote_char:
                quote_char = None
            continue
        if ch == '"' or ch == "'":
            quote_char = ch
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                json_text = query[brace_start:i + 1]
                break

    if json_text is None:
        raise ValueError("⚠️ Error: Mismatched braces in MongoDB query.")

    try:
        return json.loads(json_text)
    except json.JSONDecodeError:
        pass  # fall through to the shell-style single-quote retry

    try:
        return json.loads(json_text.replace("'", '"'))  # Ensure valid JSON
    except json.JSONDecodeError as e:
        raise ValueError(f"⚠️ Error parsing MongoDB query: {e}") from e


In [26]:
def convert_mongo_shell_to_json(mongo_shell_query: str):
    """Convert a query given as JSON or as a shell ``find(...)`` call to Python data.

    The input is first parsed as plain JSON. If that fails, the filter object
    is extracted from a ``find(...)`` call via ``extract_json_from_find``.

    Args:
        mongo_shell_query: The query text.

    Returns:
        The parsed query, or ``None`` (after printing an error message) when
        the text is neither valid JSON nor a parseable ``find(...)`` call.
    """
    try:
        return json.loads(mongo_shell_query)
    except json.JSONDecodeError:
        pass  # not plain JSON; try the shell form below

    try:
        # extract_json_from_find already returns the parsed object, so it
        # must not be passed through json.loads a second time.
        return extract_json_from_find(mongo_shell_query)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        print(f"⚠️ Error: {e}")
        return None

In [27]:
def convert_schema_shell_to_json(schema_shell: str):
    """Parse a schema written as JSON, tolerating shell-style single quotes.

    The text is parsed as strict JSON first so that apostrophes inside valid
    JSON strings are preserved. Only if that fails are single quotes replaced
    by double quotes and the parse retried.

    Args:
        schema_shell: The schema text.

    Returns:
        The parsed schema, or ``None`` (after printing an error message) when
        neither attempt yields valid JSON.
    """
    try:
        return json.loads(schema_shell)
    except json.JSONDecodeError:
        pass  # retry below with single quotes normalised

    try:
        return json.loads(schema_shell.replace("'", '"'))
    except json.JSONDecodeError:
        print("⚠️ Error: Invalid Schema Format. Ensure it is correctly formatted.")
        return None

In [28]:
# Modified: Do not propagate ignored keys as prefix.
def extract_fields_from_mongo(query: dict):
    """Collect the dotted field paths referenced in a parsed query.

    Nested dictionaries and lists are walked recursively. Wrapper keys
    (``collection``, ``filter``) and operator keys (any string key starting
    with ``$``, e.g. ``$gte``, ``$gt``, ``$nin``, ``$or``) are neither
    recorded as fields nor added to the path prefix; their values are still
    traversed.

    Args:
        query: Parsed query data (nested dicts / lists / scalars).

    Returns:
        A set of field paths such as ``{"age", "address.city"}``.
    """
    fields = set()
    IGNORED_KEYS = {"collection", "filter"}

    def is_operator(key):
        # Every query operator is "$"-prefixed, so a prefix test covers
        # operators that a fixed whitelist would miss.
        return isinstance(key, str) and key.startswith("$")

    def extract_keys(obj, prefix=""):
        if isinstance(obj, dict):
            for key, value in obj.items():
                if key in IGNORED_KEYS or is_operator(key):
                    # Skip the key itself but keep the current prefix.
                    extract_keys(value, prefix)
                else:
                    full_key = f"{prefix}.{key}" if prefix else key
                    fields.add(full_key)
                    extract_keys(value, full_key)
        elif isinstance(obj, list):
            for item in obj:
                extract_keys(item, prefix)

    extract_keys(query)
    return fields

In [29]:
# Modified: Extract fields differently based on schema format.
def extract_fields_from_schema(schema: dict):
    """Collect field names from a schema in one of two supported layouts.

    * Document layout: at least one top-level value is a dict with a
      ``"document"`` key. For each such collection the leaf names under
      ``document["properties"]`` are returned *without* any prefix; entries
      that themselves carry ``"properties"`` are descended into rather than
      recorded.
    * Otherwise: the structure is walked recursively and every entry of a
      ``"fields"`` list is returned prefixed by the dotted path of keys
      leading to it (e.g. ``"users.user_id"``).

    Top-level values that are not dicts, and ``"document"`` values that are
    not dicts, are skipped in the document layout instead of raising.

    Args:
        schema: Parsed schema mapping.

    Returns:
        A set of field names (leaf names or dotted paths, per layout).
    """
    fields = set()

    def collect_leaf_fields(properties):
        # Return the leaf property names beneath a "properties" mapping.
        leaves = set()
        if isinstance(properties, dict):
            for key, value in properties.items():
                if isinstance(value, dict) and "properties" in value:
                    leaves |= collect_leaf_fields(value["properties"])
                else:
                    leaves.add(key)
        return leaves

    def collect_nested_fields(obj, prefix=""):
        # Add every "fields" list entry to `fields`, prefixed by its key path.
        if isinstance(obj, dict):
            if "fields" in obj and isinstance(obj["fields"], list):
                for field in obj["fields"]:
                    full_field = f"{prefix}.{field}" if prefix else field
                    fields.add(full_field)
            else:
                for key, value in obj.items():
                    new_prefix = f"{prefix}.{key}" if prefix else key
                    collect_nested_fields(value, new_prefix)
        elif isinstance(obj, list):
            for item in obj:
                collect_nested_fields(item, prefix)

    # Check if schema uses document format (contains "document" key)
    is_document_format = any(
        "document" in val for val in schema.values() if isinstance(val, dict)
    )

    if not is_document_format:
        collect_nested_fields(schema)
        return fields

    for collection_data in schema.values():
        if not isinstance(collection_data, dict):
            continue  # e.g. a top-level metadata scalar next to collections
        document = collection_data.get("document")
        if isinstance(document, dict) and "properties" in document:
            fields |= collect_leaf_fields(document["properties"])
    return fields

In [30]:
def extract_collection_from_mongo_query(mongo_query: str):
    """Return the collection name targeted by a query string, if it can be found.

    Two forms are recognised:

    * a JSON object with a ``"collection"`` key, whose value is returned;
    * shell syntax ``db.<name>.find(`` / ``db.<name>.findOne(``, from which
      ``<name>`` is returned.

    Args:
        mongo_query: The raw query text.

    Returns:
        The collection name, or ``None`` when neither form matches.
    """
    try:
        mongo_dict = json.loads(mongo_query)
    except json.JSONDecodeError:
        mongo_dict = None

    # Valid JSON is not necessarily an object (e.g. "42" or "null"); only a
    # dict can carry a "collection" key.
    if isinstance(mongo_dict, dict) and "collection" in mongo_dict:
        return mongo_dict["collection"]

    match = re.search(r'db\.(\w+)\.find(?:One)?\s*\(', mongo_query)
    return match.group(1) if match else None


In [31]:
def extract_valid_collections_from_schema(schema: dict):
    """Return the schema's top-level keys as the set of known collections.

    If ``schema`` is not a dictionary, an error message is printed and an
    empty set is returned.
    """
    if isinstance(schema, dict):
        return {collection_name for collection_name in schema}
    print("⚠️ Error: Schema is not a dictionary.")
    return set()

In [32]:
def is_schema_linking_correct(nlp_query: str, mongo_output: dict, mongo_query: str, database: str, schema: dict, threshold=80):
    """Score how well the fields used in a query line up with a schema.

    Each query field (lower-cased, full path) is compared with
    ``fuzz.ratio`` against the lower-cased last dotted segment of every
    schema field. Per query field:

    * if one or more schema fields score ``>= threshold``, the field
      contributes ``1 / (number of such schema fields)``;
    * otherwise, if the best score is at least 50, it contributes
      ``best / 100``;
    * otherwise it contributes nothing.

    The contributions are averaged over the query fields. The average is
    halved when the collection extracted from ``mongo_query`` is not a
    top-level key of ``schema``. ``nlp_query`` and ``database`` are accepted
    but not used in the computation.

    Returns:
        The score rounded to two decimals, or ``0.0`` when the query
        references no fields.
    """
    query_fields = extract_fields_from_mongo(mongo_output)
    schema_fields = extract_fields_from_schema(schema)

    if not query_fields:
        return 0.0

    accumulated = 0
    for field in query_fields:
        needle = field.lower()
        # Similarity against the leaf name of every schema field.
        ratios = [
            fuzz.ratio(needle, candidate.split(".")[-1].lower())
            for candidate in schema_fields
        ]
        strong_match_count = sum(1 for ratio in ratios if ratio >= threshold)
        top_ratio = max(ratios, default=0)

        if strong_match_count:
            # Several equally plausible schema fields dilute the credit.
            accumulated += 1 / strong_match_count
        elif top_ratio >= 50:
            accumulated += top_ratio / 100

    schema_score = accumulated / len(query_fields)

    predicted_collection = extract_collection_from_mongo_query(mongo_query)
    known_collections = extract_valid_collections_from_schema(schema)
    if predicted_collection not in known_collections:
        schema_score *= 0.5
    return round(schema_score, 2)

In [33]:
# Interactive driver: prompt for a natural-language query, the generated
# MongoDB query, a database name and a schema, then print the schema-linking
# score together with the fields extracted from the query and the schema.
nlp_query = input("Enter NLP Query: ")
mongo_shell_query = input("Enter MongoDB Output Query (Mongo Shell Format): ")
mongo_output = convert_mongo_shell_to_json(mongo_shell_query)
# convert_mongo_shell_to_json returns None when the query cannot be parsed.
if mongo_output is None:
    print("Error: Invalid MongoDB query format. Exiting.")
    exit()
database = input("Enter Database Name: ")
schema_shell = input("Enter Schema (MongoDB Shell Format): ")
schema = convert_schema_shell_to_json(schema_shell)
# convert_schema_shell_to_json returns None when the schema cannot be parsed.
if schema is None:
    print("Error: Invalid Schema format. Exiting.")
    exit()
# The raw query text is passed as well; the scorer extracts the collection
# name from it.
schema_score = is_schema_linking_correct(nlp_query, mongo_output, mongo_shell_query, database, schema)
print(f"\nSchema Linking Score: {schema_score:.2f}")
# Show the extracted field sets for inspection (recomputed here).
print("Extracted Query Fields:", extract_fields_from_mongo(mongo_output))
print("Extracted Schema Fields:", extract_fields_from_schema(schema))


Enter NLP Query: Find users who registered in the last 6 months.
Enter MongoDB Output Query (Mongo Shell Format): {   "collection": "customers",   "filter": { "signup_date": { "$gte": "2024-09-01" } } }
Enter Database Name: usef
Enter Schema (MongoDB Shell Format): {   "users": {     "fields": ["user_id", "username", "registration_date"]   } }

Schema Linking Score: 0.28
Extracted Query Fields: {'signup_date'}
Extracted Schema Fields: {'users.registration_date', 'users.username', 'users.user_id'}
