<a href="https://colab.research.google.com/github/vichitrarora/warp/blob/main/rewardfunctionschemafinal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install fuzzywuzzy



In [None]:
import json
import re
from fuzzywuzzy import fuzz

def extract_json_from_find(query: str):
    """Extract and parse the first ``{...}`` object that follows ``find(``.

    The object is located by brace matching that skips over quoted strings,
    so a ``{`` or ``}`` inside a string value does not end the object early.
    The text is parsed as strict JSON first; only if that fails are single
    quotes converted to double quotes (Mongo-shell style), so valid JSON
    containing apostrophes is left untouched.

    Limitations: the first ``{`` after ``find(`` is used even if it belongs
    to a later call in the chain, and shell-only syntax such as unquoted
    keys is not supported.

    Args:
        query: Mongo shell query text, e.g. ``db.users.find({'age': 1})``.

    Returns:
        The parsed filter object (a ``dict``).

    Raises:
        ValueError: If ``query`` is not a string, contains no ``find(``,
            has no ``{`` after it, has unbalanced braces, or the extracted
            text cannot be parsed as JSON.
    """
    if not isinstance(query, str):
        raise ValueError(
            f"⚠️ Error parsing MongoDB query: expected a string, got {type(query).__name__}"
        )

    start = query.find("find(")
    if start == -1:
        raise ValueError("⚠️ Error: Invalid MongoDB Query Format. Ensure the query is correctly formatted.")

    brace_start = query.find("{", start)
    if brace_start == -1:
        raise ValueError("⚠️ Error: Could not find a valid JSON object in find().")

    depth = 0
    quote_char = None  # The quote that opened the current string, if any.
    escaped = False    # True when the previous in-string char was a backslash.
    for i, ch in enumerate(query[brace_start:], start=brace_start):
        if quote_char is not None:
            # Inside a string literal: braces here are data, not structure.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote_char:
                quote_char = None
            continue
        if ch in ('"', "'"):
            quote_char = ch
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                json_text = query[brace_start:i + 1]
                break
    else:
        raise ValueError("⚠️ Error: Mismatched braces in MongoDB query.")

    try:
        return json.loads(json_text)
    except json.JSONDecodeError:
        pass  # Not strict JSON; retry below assuming single-quoted shell syntax.

    try:
        return json.loads(json_text.replace("'", '"'))
    except json.JSONDecodeError as e:
        raise ValueError(f"⚠️ Error parsing MongoDB query: {e}") from e




In [None]:
def convert_mongo_shell_to_json(mongo_shell_query: str):
    """Turn a query given as JSON text or Mongo shell text into Python data.

    The input is first parsed as plain JSON. If that fails, it is treated as
    a Mongo shell ``find(...)`` call and the filter object is extracted with
    ``extract_json_from_find``.

    Args:
        mongo_shell_query: JSON text or a shell query such as
            ``db.users.find({'age': 1})``.

    Returns:
        The parsed query, or ``None`` (after printing an error) when the
        input is neither valid JSON nor an extractable shell query.
    """
    try:
        # First, try parsing the input directly as JSON.
        return json.loads(mongo_shell_query)
    except json.JSONDecodeError:
        pass

    try:
        # extract_json_from_find already returns the parsed object; passing
        # it through json.loads() again would raise TypeError on a dict.
        return extract_json_from_find(mongo_shell_query)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        print(f"⚠️ Error: {e}")
        return None

In [None]:
def convert_schema_shell_to_json(schema_shell: str):
    """Parse a schema given as JSON text or single-quoted shell-style text.

    Strict JSON is tried first so that apostrophes inside double-quoted
    strings are preserved. Only when that fails are single quotes converted
    to double quotes and parsing retried.

    Args:
        schema_shell: Schema text.

    Returns:
        The parsed schema, or ``None`` (after printing an error) when the
        text cannot be parsed either way.
    """
    try:
        return json.loads(schema_shell)
    except json.JSONDecodeError:
        pass  # Fall back to shell-style single quotes below.

    try:
        return json.loads(schema_shell.replace("'", '"'))
    except json.JSONDecodeError:
        print("⚠️ Error: Invalid Schema Format. Ensure it is correctly formatted.")
        return None

In [None]:
def extract_fields_from_mongo(query: dict):
    """Collect the field names referenced anywhere in a parsed Mongo query.

    Every dict key found while walking the query recursively is treated as a
    field name, except the wrapper keys ``"collection"`` and ``"filter"``
    and any key starting with ``$``. Skipping on the ``$`` prefix, rather
    than on a fixed list, keeps operators such as ``$gt``, ``$lt``, ``$nin``
    or ``$not`` from being reported as fields.

    Note: keys nested under an operator's value are still collected, so
    structured operator arguments contribute their keys as well.

    Args:
        query: Parsed query (nested dicts/lists and scalars).

    Returns:
        A set of field names; empty when none are found.
    """
    fields = set()
    IGNORED_KEYS = {"collection", "filter"}  # Wrapper keys, not document fields.

    def extract_keys(obj):
        if isinstance(obj, dict):
            for key, value in obj.items():
                is_operator = isinstance(key, str) and key.startswith("$")
                if key not in IGNORED_KEYS and not is_operator:
                    fields.add(key)
                extract_keys(value)  # Recurse to pick up deeper fields.
        elif isinstance(obj, list):
            for item in obj:
                extract_keys(item)

    extract_keys(query)
    return fields


In [None]:
def extract_fields_from_schema(schema: dict):
    """Gather every field name declared by a schema.

    Two layouts are understood:

    * ``{"collections": [{"document": {"properties": {...}}}, ...]}`` -
      the keys of each ``properties`` mapping are collected.
    * ``{<collection name>: {"fields": [...]}, ...}`` - the items of each
      ``fields`` entry are collected.

    Returns:
        A set with the field names of all collections combined.
    """
    names = set()

    if "collections" not in schema:
        # Layout keyed by collection name.
        for entry in schema.values():
            if "fields" in entry:
                names.update(entry["fields"])
        return names

    # Layout with an explicit "collections" list.
    for entry in schema["collections"]:
        if "document" not in entry:
            continue
        document = entry["document"]
        if "properties" in document:
            names.update(document["properties"].keys())
    return names


In [None]:
def extract_collection_from_mongo_query(mongo_query: str):
    """Return the collection name targeted by a query, if one can be found.

    The query is first parsed as JSON; when it is a JSON object with a
    ``"collection"`` key, that value is returned. Otherwise the text is
    searched for the shell form ``db.<name>.find(`` / ``db.<name>.findOne(``.

    Args:
        mongo_query: Query text in JSON or Mongo shell format.

    Returns:
        The collection name, or ``None`` when neither form matches.
    """
    try:
        mongo_dict = json.loads(mongo_query)
    except json.JSONDecodeError:
        mongo_dict = None  # Not JSON; fall back to the regex below.

    # Valid JSON is not necessarily an object (e.g. `123` or `"text"`), and
    # membership tests / subscripting on those would raise TypeError.
    if isinstance(mongo_dict, dict) and "collection" in mongo_dict:
        return mongo_dict["collection"]

    match = re.search(r'db\.(\w+)\.find(?:One)?\s*\(', mongo_query)
    return match.group(1) if match else None


In [None]:
def extract_valid_collections_from_schema(schema: dict):
    """List the collection names a schema declares.

    For a schema with a ``"collections"`` list, the ``"name"`` of each entry
    that has one is returned. For any other dict, its top-level keys are
    taken to be the collection names. A non-dict schema prints an error and
    yields an empty set.
    """
    if isinstance(schema, dict):
        if "collections" not in schema:
            # Layout keyed by collection name: the keys are the names.
            return set(schema.keys())

        names = set()
        for entry in schema["collections"]:
            if "name" in entry:
                names.add(entry.get("name"))
        return names

    print("⚠️ Error: Schema is not a dictionary.")
    return set()


In [None]:
def is_schema_linking_correct(nlp_query: str, mongo_output: dict, mongo_query: str, database: str, schema: dict, threshold=80):
    """Score how well a query's fields and collection line up with a schema.

    Each field used by the query is compared (case-insensitively, with
    ``fuzz.ratio``) against every schema field and its best score is kept:

    * best score >= ``threshold``: one full point;
    * best score >= 50 otherwise: ``score / 100`` of a point;
    * anything lower: nothing.

    The points are averaged over the query's fields, and the average is
    halved when the collection extracted from ``mongo_query`` is not one of
    the schema's collections. ``nlp_query`` and ``database`` are accepted
    but not used in the computation.

    Returns:
        The score rounded to two decimals, or ``0.0`` when the query
        references no fields.
    """
    used_fields = extract_fields_from_mongo(mongo_output)
    known_fields = extract_fields_from_schema(schema)

    if not used_fields:
        return 0.0

    credit = 0
    for name in used_fields:
        similarities = [
            fuzz.ratio(name.lower(), candidate.lower())
            for candidate in known_fields
        ]
        top = max(similarities, default=0)

        if top >= threshold:
            credit += 1  # Counts as a full match.
        elif top >= 50:
            credit += top / 100  # Partial credit, proportional to similarity.

    score = credit / len(used_fields)

    collection = extract_collection_from_mongo_query(mongo_query)
    if collection not in extract_valid_collections_from_schema(schema):
        score *= 0.5  # Penalty for an unknown collection.

    return round(score, 2)


In [None]:
# Interactive driver: read the NLP query, the generated Mongo query, the
# database name and the schema from stdin, then print the linking score.
nlp_query = input("Enter NLP Query: ")
mongo_shell_query = input("Enter MongoDB Output Query (Mongo Shell Format): ")
# Accepts either JSON text or shell text; returns None when neither parses.
mongo_output = convert_mongo_shell_to_json(mongo_shell_query)

if mongo_output is None:
    print("Error: Invalid MongoDB query format. Exiting.")
    exit()

database = input("Enter Database Name: ")

schema_shell = input("Enter Schema (MongoDB Shell Format): ")
# Returns None (after printing its own error) when the schema cannot be parsed.
schema = convert_schema_shell_to_json(schema_shell)

if schema is None:
    print("Error: Invalid Schema format. Exiting.")
    exit()

# The raw query text is passed as well, so the collection name can be extracted from it.
schema_score = is_schema_linking_correct(nlp_query, mongo_output, mongo_shell_query, database, schema)

print(f"\nSchema Linking Score: {schema_score:.2f}")

# Show the two field sets that were compared, to help interpret the score.
print("Extracted Query Fields:", extract_fields_from_mongo(mongo_output))
print("Extracted Schema Fields:", extract_fields_from_schema(schema))



Enter NLP Query: Find users who registered in the last 6 months.
Enter MongoDB Output Query: {   "collection": "customers",   "filter": { "signup_date": { "$gte": "2024-09-01" } } }
Enter Database Name: users_db
Enter Schema: {   "users": {     "fields": ["user_id", "username", "registration_date"]   } }

Schema Linking Score: 0.28
Extracted Query Fields: {'signup_date'}
Extracted Schema Fields: {'user_id', 'registration_date', 'username'}
