In [3]:
import pandas as pd
import json
import re
import os
import warnings
from datetime import datetime
from pymongo import MongoClient
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
from langchain_huggingface import HuggingFacePipeline
import torch

# Silence warnings emitted by TensorFlow and other libraries
warnings.filterwarnings('ignore')

print("Setting up the system...")

# Use the GPU when CUDA is available; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device set to use {device}")

# Module-level MongoDB handles: local server -> product_db -> products
client = MongoClient("mongodb://localhost:27017/")
db = client["product_db"]
collection = db["products"]

# Function to load CSV data into MongoDB
def load_csv_to_mongodb(csv_file, db_name, collection_name):
    """Replace the contents of a MongoDB collection with the rows of a CSV file.

    Args:
        csv_file: Path of the CSV file to read.
        db_name: Name of the target database on the local MongoDB server.
        collection_name: Name of the collection to clear and refill.

    Returns:
        True when the rows were inserted, False (after printing the error)
        when reading, converting or inserting failed, or the CSV has no rows.
    """
    try:
        # Read the file the caller asked for
        df = pd.read_csv(csv_file)

        # Handle date formats (convert date strings to datetime objects);
        # values that do not match DD-MM-YYYY become NaT
        if 'LaunchDate' in df.columns:
            df['LaunchDate'] = pd.to_datetime(df['LaunchDate'], format='%d-%m-%Y', errors='coerce')

        # Convert DataFrame to list of dictionaries, storing unparseable dates
        # as null. NOTE(review): pymongo is expected to reject pandas NaT when
        # encoding documents -- confirm against the installed driver version.
        data = [
            {field: (None if value is pd.NaT else value) for field, value in row.items()}
            for row in df.to_dict(orient='records')
        ]

        # Bail out before touching the collection: insert_many() rejects an
        # empty list, and by then the existing data would already be deleted
        if not data:
            print(f"❌ Error loading data: no rows found in {csv_file}")
            return False

        # The context manager closes the connection once the load is done
        with MongoClient('mongodb://localhost:27017/') as client:
            target = client[db_name][collection_name]

            # Clear existing data and insert new data
            target.delete_many({})
            target.insert_many(data)

        print(f"✅ Data loaded into MongoDB collection: {collection_name}")
        return True
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        return False

# Initialize LLM for query generation
def initialize_llm():
    """Build a LangChain wrapper around a FLAN-T5 text2text pipeline.

    Returns:
        The ``HuggingFacePipeline`` instance, or ``None`` when any step of the
        setup raises (the error is printed and the caller is expected to fall
        back to rule-based query generation).
    """
    # A smaller model that works well for specific tasks
    checkpoint = "google/flan-t5-base"

    try:
        # Tokenizer first, then the seq2seq weights
        text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

        generator = pipeline(
            "text2text-generation",
            model=seq2seq_model,
            tokenizer=text_tokenizer,
            max_length=512,
            device=device,
        )

        wrapped = HuggingFacePipeline(pipeline=generator)

        # Smoke-test the raw pipeline and show what it produces
        ai_test_output = generator("Find products with a rating above 4.5")
        print(f"🧠 AI Test Output: {ai_test_output}")  # Debugging line to show AI response
    except Exception as e:
        print(f"❌ Error initializing LLM: {e}")
        print("Falling back to rule-based query generation.")
        return None
    else:
        return wrapped

# Rule-based MongoDB Query Generator

# (phrases looked for in the lower-cased question, value stored in "Category")
_CATEGORY_PHRASES = (
    (("electronics",), "Electronics"),
    (("home & kitchen", "home and kitchen"), "Home & Kitchen"),
    (("sports",), "Sports"),
)

# Brands recognised even when the question does not put them in quotes
_KNOWN_BRANDS = ("Nike", "Sony", "Apple", "Samsung", "Logitech", "Corsair", "Dyson")

# Comparison symbols written directly after the word "rating"
_SYMBOL_OPERATORS = {"<": "$lt", "<=": "$lte", ">": "$gt", ">=": "$gte"}

# Comparison phrases written before a number ("more than 200")
_WORD_OPERATORS = {
    "below": "$lt",
    "under": "$lt",
    "less than": "$lt",
    "lower than": "$lt",
    "fewer than": "$lt",
    "above": "$gt",
    "over": "$gt",
    "more than": "$gt",
    "greater than": "$gt",
    "higher than": "$gt",
    "at least": "$gte",
    "at most": "$lte",
}

# Comparison words written after a number ("4.5 or higher")
_SUFFIX_OPERATORS = {
    "higher": "$gte",
    "above": "$gte",
    "more": "$gte",
    "greater": "$gte",
    "lower": "$lte",
    "below": "$lte",
    "less": "$lte",
}

# Longest phrases first so the alternation never stops at a shorter prefix
_COMPARATOR_WORDS = "|".join(sorted(_WORD_OPERATORS, key=len, reverse=True))
_SUFFIX_WORDS = "|".join(sorted(_SUFFIX_OPERATORS, key=len, reverse=True))
_NUMBER = r"(\d+(?:\.\d+)?)"
_RATING_LEAD = r"\brat(?:ing|ed)s?(?:\s+(?:of|is|are))?\s+"


def _one_or_many(field, values):
    """Return an equality / ``$in`` condition for *values*, or None if empty."""
    if not values:
        return None
    if len(values) == 1:
        return {field: values[0]}
    return {field: {"$in": values}}


def _category_condition(text_lower):
    """Build the Category condition from the category phrases in the text."""
    found = [
        stored
        for phrases, stored in _CATEGORY_PHRASES
        if any(phrase in text_lower for phrase in phrases)
    ]
    return _one_or_many("Category", found)


def _brand_condition(input_text):
    """Build the Brand condition from quoted names and known brand names."""
    canonical = {brand.lower(): brand for brand in _KNOWN_BRANDS}
    brands = []
    seen = set()

    def remember(name):
        key = name.lower()
        if key and key not in seen:
            seen.add(key)
            # Prefer the stored spelling when the brand is a known one
            brands.append(canonical.get(key, name))

    # Quoted names following "brand"/"brands", e.g. brand 'Nike' or 'Sony'
    quoted_run = r"brands?\s+((?:['\"][^'\"]+['\"][\s,]*(?:(?:or|and)\s+)?)+)"
    for run in re.finditer(quoted_run, input_text, re.IGNORECASE):
        for name in re.findall(r"['\"]([^'\"]+)['\"]", run.group(1)):
            remember(name.strip())

    # Known brands mentioned as whole words anywhere in the question
    text_lower = input_text.lower()
    for brand in _KNOWN_BRANDS:
        if re.search(r"\b" + re.escape(brand.lower()) + r"\b", text_lower):
            remember(brand)

    return _one_or_many("Brand", brands)


def _rating_condition(text_lower):
    """Build the Rating condition from symbolic or worded comparisons."""
    # "rating >= 4.5"
    symbolic = re.search(r"rating\s*([<>=]+)\s*(\d+\.?\d*)", text_lower)
    if symbolic:
        symbol, value = symbolic.group(1), float(symbolic.group(2))
        if symbol in ("=", "=="):
            return {"Rating": value}
        operator = _SYMBOL_OPERATORS.get(symbol)
        # An unrecognised symbol run (e.g. "=>") yields no condition
        return {"Rating": {operator: value}} if operator else None

    # "rating below 4.5", "rated at least 4"
    worded = re.search(_RATING_LEAD + "(" + _COMPARATOR_WORDS + r")\s+" + _NUMBER, text_lower)
    if worded:
        return {"Rating": {_WORD_OPERATORS[worded.group(1)]: float(worded.group(2))}}

    # "rating of 4.5 or higher"
    trailing = re.search(_RATING_LEAD + _NUMBER + r"\s+or\s+(" + _SUFFIX_WORDS + ")", text_lower)
    if trailing:
        return {"Rating": {_SUFFIX_OPERATORS[trailing.group(2)]: float(trailing.group(1))}}

    return None


def _review_condition(text_lower):
    """Build the ReviewCount condition from "<comparator> N reviews"."""
    match = re.search(r"(?:(" + _COMPARATOR_WORDS + r")\s+)?(\d+)\s+reviews", text_lower)
    if not match:
        return None
    count = int(match.group(2))
    # Only a comparator written directly before the number applies to it;
    # with none, the count is matched exactly
    if match.group(1) is None:
        return {"ReviewCount": count}
    return {"ReviewCount": {_WORD_OPERATORS[match.group(1)]: count}}


def _launch_date_condition(text_lower):
    """Build the LaunchDate condition from "after/before <month> <day>, <year>"."""
    match = re.search(
        r"\b(after|before)\s+([a-z]+)\.?\s+(\d{1,2})(?:st|nd|rd|th)?,?\s+(\d{4})",
        text_lower,
    )
    if not match:
        return None
    direction, month, day, year = match.groups()
    operator = "$gt" if direction == "after" else "$lt"
    # Accept full ("january") and abbreviated ("jan") month names
    for fmt in ("%B %d %Y", "%b %d %Y"):
        try:
            moment = datetime.strptime(f"{month} {day} {year}", fmt)
        except ValueError:
            continue
        return {"LaunchDate": {operator: moment}}
    return None


def _min_percent_regex(minimum):
    """Return a regex matching "<number>%" strings whose number is >= minimum."""
    digits = str(minimum)
    width = len(digits)
    alternatives = []
    for pos, char in enumerate(digits):
        digit = int(char)
        remaining = width - pos - 1
        if remaining == 0:
            # Same leading digits, last digit at least as large
            alternatives.append(f"{digits[:pos]}[{digit}-9]")
        elif digit < 9:
            # Same prefix, a larger digit here, anything afterwards
            alternatives.append(f"{digits[:pos]}[{digit + 1}-9]" + "[0-9]" * remaining)
    # Any number with more digits than the minimum is larger
    alternatives.append("[1-9][0-9]{" + str(width) + ",}")
    return "^(?:" + "|".join(alternatives) + r")(?:\.[0-9]+)?\s*%"


def _discount_condition(text_lower):
    """Build the Discount condition from "discount of N% or more"."""
    match = re.search(r"discount\s+of\s+(\d+)\s*%\s+or\s+(?:more|higher|greater)", text_lower)
    if not match:
        return None
    # Discount is stored as a string like "10%", so the numeric threshold is
    # expressed as a regex over the digits
    return {"Discount": {"$regex": _min_percent_regex(int(match.group(1)))}}


def create_rule_based_mongo_query():
    """Return a function that turns a natural-language question into a query.

    The returned ``generate_query(input_text)`` recognises launch dates,
    categories, ratings, review counts, brands, stock status and minimum
    discounts, and returns the MongoDB filter as a JSON string: ``{"$and":
    [...]}`` when at least one condition was found, ``{}`` otherwise.
    Datetimes are serialised with ``str()``.

    Sorting phrases are not part of the filter; callers decide the sort
    order separately.
    """
    def generate_query(input_text):
        # Lower-case once for case-insensitive matching
        text_lower = input_text.lower()

        candidates = [
            _launch_date_condition(text_lower),
            _category_condition(text_lower),
            _rating_condition(text_lower),
            _review_condition(text_lower),
            # Brands keep the original casing of quoted names
            _brand_condition(input_text),
            {"Stock": {"$gt": 0}} if "in stock" in text_lower else None,
            _discount_condition(text_lower),
        ]
        conditions = [condition for condition in candidates if condition is not None]

        query_dict = {"$and": conditions} if conditions else {}

        # default=str is only needed for the datetime in a LaunchDate condition
        return json.dumps(query_dict, default=str, indent=2)

    return generate_query

# Function to execute MongoDB queries with better date handling

# Keys whose string values may be datetimes that were serialised to text
_DATE_VALUE_KEYS = frozenset({"LaunchDate", "$gt", "$gte", "$lt", "$lte"})
# Cheap guard so only strings that start like YYYY-MM-DD are parsed
_ISO_DATE_PREFIX = re.compile(r"\d{4}-\d{2}-\d{2}")


def _restore_datetimes(node):
    """Recursively turn ISO date strings under date-bearing keys into datetimes.

    The structure is modified in place. Strings that look like dates but do
    not parse are left unchanged.
    """
    if isinstance(node, list):
        for item in node:
            _restore_datetimes(item)
        return
    if not isinstance(node, dict):
        return
    for key, value in list(node.items()):
        if isinstance(value, (dict, list)):
            _restore_datetimes(value)
        elif isinstance(value, str) and key in _DATE_VALUE_KEYS and _ISO_DATE_PREFIX.match(value):
            try:
                # Handles "2022-01-01", "2022-01-01 00:00:00" and the same
                # with fractional seconds, as produced by str(datetime)
                node[key] = datetime.fromisoformat(value)
            except ValueError:
                pass


def execute_query(query_str, sort_options=None):
    """Run a JSON-encoded MongoDB filter against the module-level collection.

    Args:
        query_str: The filter as a JSON object string. Date strings under
            ``LaunchDate`` or the ``$gt``/``$gte``/``$lt``/``$lte`` operators
            are converted back to ``datetime`` before the query runs.
        sort_options: Optional sort specification passed to ``Cursor.sort``,
            e.g. ``[("Price", -1)]``.

    Returns:
        A list of matching documents without ``_id``; an empty list when
        nothing matches or when parsing/executing the query fails (the error
        is printed).
    """
    try:
        # Parse the JSON query string
        query_dict = json.loads(query_str)
        if not isinstance(query_dict, dict):
            print(f"❌ Error executing query: query must be a JSON object, got {type(query_dict).__name__}")
            return []

        # Convert serialised dates back into datetime objects
        _restore_datetimes(query_dict)

        # Execute the query, sorting only when the caller asked for it
        cursor = collection.find(query_dict, {"_id": 0})
        if sort_options:
            cursor = cursor.sort(sort_options)
        return list(cursor)
    except json.JSONDecodeError as e:
        print(f"❌ Error: Query is not valid JSON: {str(e)}")
        return []
    except Exception as e:
        print(f"❌ Error executing query: {str(e)}")
        return []

# Function to save results to CSV
def save_results_to_csv(results, file_name):
    """Write the result records to ``file_name`` as CSV without an index column.

    Returns True when the file was written; False when there was nothing to
    save or writing failed (a message is printed in both cases).
    """
    try:
        if results:
            # One row per record, columns taken from the record keys
            pd.DataFrame(results).to_csv(file_name, index=False)
            print(f"✅ Results saved to {file_name}")
            return True

        print(f"⚠️ No results to save to {file_name}")
        return False
    except Exception as e:
        print(f"❌ Error saving results: {e}")
        return False

# Function to display results in a nice format
def display_results(results):
    """Print the result records as a table followed by the record count.

    Prints a notice and returns when ``results`` is empty. If building or
    printing the table fails, the error is reported and each record is
    printed on its own line instead.
    """
    if not results:
        print("⚠️ No results found")
        return

    try:
        # A DataFrame gives aligned columns for free
        table = pd.DataFrame(results)

        print("\n📋 Results:")
        print(table.to_string())
        print(f"\nTotal records: {len(results)}")
    except Exception as e:
        print(f"❌ Error displaying results: {e}")
        # Fallback to simpler display method: one record per line
        print(*results, sep="\n")

# Save queries to a file
def save_query_to_file(question, query, filename="Queries_generated.txt"):
    """Append a question and its generated query to a text file.

    Args:
        question: The natural-language question.
        query: The generated query text.
        filename: File to append to; created if it does not exist.

    Returns:
        True when the entry was written, False (after printing the error)
        when the file could not be opened or written.
    """
    try:
        # Explicit UTF-8: questions are free text, and the platform default
        # encoding (e.g. cp1252 on Windows) cannot represent every character
        with open(filename, "a", encoding="utf-8") as f:
            f.write(f"Question: {question}\n")
            f.write(f"Query: {query}\n\n")
        return True
    except (OSError, UnicodeError) as e:
        print(f"❌ Error saving query: {e}")
        return False

# Main function for running test cases
def run_test_cases():
    """Run the three predefined questions end to end.

    For each question the query is generated, printed, appended to
    Queries_generated.txt, executed, displayed and saved to
    test_case<N>.csv. The query file is recreated at the start of the run.
    """
    questions = (
        "Find all products with a rating below 4.5 that have more than 200 reviews and are offered by the brand 'Nike' or 'Sony'",
        "Which products in the Electronics category have a rating of 4.5 or higher and are in stock",
        "List products launched after January 1, 2022, in the Home & Kitchen or Sports categories with a discount of 10% or more, sorted by price in descending order",
    )

    # Only test case 3 asks for ordering: Price, descending
    sort_by_case = {3: [("Price", -1)]}

    # Rule-based generation is used for now
    query_generator = create_rule_based_mongo_query()

    # Start the query log from scratch
    with open("Queries_generated.txt", "w") as f:
        f.write("# Generated MongoDB Queries\n\n")

    for number, question in enumerate(questions, start=1):
        print(f"\n--- Test Case {number} ---")
        print(f"🔎 Question: {question}")

        generated_query = query_generator(question)
        print(f"📝 Generated Query:\n{generated_query}")

        save_query_to_file(question, generated_query)

        # Cases without an entry run unsorted (None)
        results = execute_query(generated_query, sort_by_case.get(number))

        display_results(results)
        save_results_to_csv(results, f"test_case{number}.csv")

# Interactive mode function
def _prompt_sort_options(user_query):
    """Ask for a sort field when the question mentions a sort direction.

    Returns a one-element sort specification, or None when the question
    does not mention sorting/ordering together with a direction.
    """
    lowered = user_query.lower()
    if "sort" not in lowered and "order" not in lowered:
        return None

    if "descending" in lowered or "high to low" in lowered:
        direction = -1
    elif "ascending" in lowered or "low to high" in lowered:
        direction = 1
    else:
        return None

    sort_field = input("Enter field to sort by (e.g., Price): ")
    return [(sort_field, direction)]


def _run_interactive_query(query_generator):
    """Read one question, run it, and let the user display/save the results."""
    user_query = input("\nEnter your query (e.g., 'Find products with rating above 4.5'): ")

    generated_query = query_generator(user_query)
    print(f"\n📝 Generated MongoDB Query:\n{generated_query}")

    results = execute_query(generated_query, _prompt_sort_options(user_query))

    if not results:
        print("⚠️ No results found for your query.")
    else:
        display_option = input("\nWhat would you like to do with the results?\n1. Display\n2. Save to CSV\n3. Both\nEnter your choice (1-3): ")

        if display_option in ("1", "3"):
            display_results(results)

        if display_option in ("2", "3"):
            filename = input("Enter filename for CSV (default: query_results.csv): ") or "query_results.csv"
            save_results_to_csv(results, filename)

    # Keep a history of every question asked, with or without results
    save_query_to_file(user_query, generated_query)


def _print_schema():
    """Print field names and Python value types of one sample document."""
    print("\n📊 Database Schema:")
    sample = collection.find_one({}, {"_id": 0})
    if not sample:
        print("⚠️ No data found in the database.")
        return

    print("Collection: products")
    print("Fields:")
    for field, value in sample.items():
        print(f"  - {field}: {type(value).__name__}")


def interactive_mode():
    """Menu loop: run natural-language queries, show the schema, or exit."""
    print("\n=== Interactive MongoDB Query System ===")

    # Rule-based generation is used for now
    query_generator = create_rule_based_mongo_query()

    while True:
        print("\nOptions:")
        print("1. Enter a natural language query")
        print("2. View database schema")
        print("3. Exit")

        choice = input("\nEnter your choice (1-3): ")

        if choice == "3":
            print("Exiting the system. Goodbye!")
            break

        if choice == "1":
            _run_interactive_query(query_generator)
        elif choice == "2":
            _print_schema()
        else:
            print("❌ Invalid choice. Please enter a number between 1 and 3.")

# Main function
def main():
    """Load the CSV into MongoDB, then serve the main menu until exit.

    The CSV path is asked for once (defaulting to sample_data.csv) and, if
    that path does not exist, one more time before giving up.
    """
    print("\n====== Automated Data Query and Retrieval System ======")

    csv_file = input("Enter the path to your CSV file (default: sample_data.csv): ") or "sample_data.csv"

    if not os.path.exists(csv_file):
        # One retry, this time without a default
        print(f"❌ CSV file not found: {csv_file}")
        csv_file = input("Please enter a valid path to your CSV file: ")
        if not os.path.exists(csv_file):
            print("❌ CSV file not found again. Exiting...")
            return

    if not load_csv_to_mongodb(csv_file, "product_db", "products"):
        print("❌ Failed to load data into MongoDB. Exiting...")
        return

    # Menu choices that hand over to another mode
    actions = {"1": run_test_cases, "2": interactive_mode}

    while True:
        print("\nMain Menu:")
        print("1. Run predefined test cases")
        print("2. Interactive mode")
        print("3. Exit")

        choice = input("\nEnter your choice (1-3): ")

        if choice == "3":
            print("Exiting the system. Goodbye!")
            break

        action = actions.get(choice)
        if action is None:
            print("❌ Invalid choice. Please enter a number between 1 and 3.")
        else:
            action()


if __name__ == "__main__":
    main()

Setting up the system...
Device set to use cpu

Enter the path to your CSV file (default: sample_data.csv): C:/Users/SAI/Downloads/sample_data (1).csv
✅ Data loaded into MongoDB collection: products

Main Menu:
1. Run predefined test cases
2. Interactive mode
3. Exit

Enter your choice (1-3): 1

--- Test Case 1 ---
🔎 Question: Find all products with a rating below 4.5 that have more than 200 reviews and are offered by the brand 'Nike' or 'Sony'
📝 Generated Query:
{
  "$and": [
    {
      "Rating": {
        "$lt": 4.5
      }
    },
    {
      "ReviewCount": {
        "$gt": 200
      }
    },
    {
      "Brand": {
        "$in": [
          "Nike",
          "Sony"
        ]
      }
    }
  ]
}

📋 Results:
   ProductID    ProductName Category  Price  Rating  ReviewCount  Stock Discount Brand LaunchDate
0        104  Running Shoes   Sports  49.99     4.3          500    200      20%  Nike 2022-02-10

Total records: 1
✅ Results saved to test_case1.csv

--- Test Case 2 ---
🔎 Question: