In [5]:
import os

import openai
import pandas as pd
import psycopg2
from fuzzywuzzy import fuzz

# Database connection details.
# Each setting can be overridden through an environment variable
# (DATABASE_HOST, DATABASE_USERNAME, DATABASE_PASSWORD, DATABASE_DB,
# DATABASE_PORT); the literals below are only fallbacks so existing
# behaviour is unchanged when nothing is set.
# NOTE(review): the fallback password is a credential committed to source --
# rotate it and drop the default once the environment variables are in place.
DATABASE_HOST = os.environ.get(
    "DATABASE_HOST",
    "database-test-postgress-instance.cpk2uyae6iza.ap-south-1.rds.amazonaws.com",
)
DATABASE_USERNAME = os.environ.get("DATABASE_USERNAME", "postgres")
DATABASE_PASSWORD = os.environ.get("DATABASE_PASSWORD", "valign#123")
DATABASE_DB = os.environ.get("DATABASE_DB", "python_test_poc")
# The port is always an int, whether it comes from the environment or not.
PORT = int(os.environ.get("DATABASE_PORT", "5432"))

# Function to connect to PostgreSQL database
def connect_to_db():
    """Open a PostgreSQL connection from the module-level settings.

    Uses DATABASE_DB, DATABASE_USERNAME, DATABASE_PASSWORD, DATABASE_HOST
    and PORT as the connection parameters.

    Returns:
        The connection object returned by ``psycopg2.connect``.

    Raises:
        psycopg2.Error: Re-raised after the error message has been printed.
    """
    settings = {
        "dbname": DATABASE_DB,
        "user": DATABASE_USERNAME,
        "password": DATABASE_PASSWORD,
        "host": DATABASE_HOST,
        "port": PORT,
    }
    try:
        connection = psycopg2.connect(**settings)
    except psycopg2.Error as exc:
        print(f"Error connecting to the database: {exc}")
        raise
    return connection

# Fetch schema with column names and data types
def fetch_schema_with_data_types(conn):
    """Return the columns of the 'public' schema together with their types.

    Args:
        conn: Open database connection handed to ``pd.read_sql``.

    Returns:
        A DataFrame with one row per column and the columns ``table_name``,
        ``column_name`` and ``data_type``, read from
        ``information_schema.columns``.

    Raises:
        Exception: Any error from running the query is printed and re-raised.
    """
    try:
        # data_type is selected so the result matches what the function name
        # promises; callers that only read table_name/column_name are
        # unaffected by the extra column.
        query = """
        SELECT table_name, column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = 'public'
        """
        schema_df = pd.read_sql(query, conn)
        return schema_df
    except Exception as e:
        print(f"Error fetching schema with data types: {e}")
        raise

# Fetch unique column values (entities) from specific columns
def _quote_identifier(name):
    """Wrap *name* in double quotes, doubling any embedded double quote."""
    return '"' + str(name).replace('"', '""') + '"'


def fetch_unique_column_values(conn, table_name, column_name):
    """Return the distinct non-null values of one column as strings.

    Args:
        conn: Open database connection handed to ``pd.read_sql``.
        table_name: Table to read. A dotted name (``schema.table``) is quoted
            part by part. Names are used as exact, case-sensitive
            identifiers, i.e. as the database catalog reports them.
        column_name: Column to read, also used as an exact identifier.

    Returns:
        A list of ``str`` with one entry per distinct non-null value.

    Raises:
        Exception: Any error from running the query is printed and re-raised.
    """
    try:
        # Identifiers cannot be sent as bound parameters, so they are quoted
        # instead of being interpolated raw: this keeps reserved words and
        # names with special characters working and prevents a crafted name
        # from altering the statement.
        quoted_table = ".".join(
            _quote_identifier(part) for part in str(table_name).split(".")
        )
        quoted_column = _quote_identifier(column_name)
        query = f"SELECT DISTINCT {quoted_column} FROM {quoted_table};"
        df = pd.read_sql(query, conn)
        # Read the single result column by position: the label the driver
        # reports is not guaranteed to equal the name that was passed in.
        return df.iloc[:, 0].dropna().astype(str).tolist()  # Convert all to strings
    except Exception as e:
        print(f"Error fetching unique values from {table_name}.{column_name}: {e}")
        raise

# Fuzzy matching function
def is_fuzzy_match(query, options, threshold=75):
    """Tell whether *query* partially matches at least one of *options*.

    Both the query and every option are converted to lowercase strings
    before being scored with ``fuzz.partial_ratio``.

    Args:
        query: Value to look for; converted with ``str``.
        options: Iterable of candidate values; each is converted with ``str``.
        threshold: Minimum score for a candidate to count as a match.

    Returns:
        True as soon as one candidate scores at least *threshold*,
        False when none does (including when *options* is empty).
    """
    needle = str(query).lower()
    # any() stops at the first candidate that reaches the threshold.
    return any(
        fuzz.partial_ratio(needle, str(candidate).lower()) >= threshold
        for candidate in options
    )

# Function to classify query type
def classify_query(user_query, schema_df, conn, threshold=75):
    """Classify a user query as "database entities", "database" or "knowledge".

    The lowercased query is first fuzzy-matched against the lowercased table
    and column names found in *schema_df*. When neither matches, the query is
    a "knowledge" query. Otherwise the distinct values of every
    (table, column) pair are fetched and fuzzy-matched against the query: the
    first match makes it a "database entities" query, and if no stored value
    matches it is a plain "database" query.

    Args:
        user_query: Free-text query to classify.
        schema_df: DataFrame with ``table_name`` and ``column_name`` columns.
        conn: Open connection passed on to ``fetch_unique_column_values``.
        threshold: Minimum fuzzy score, forwarded to ``is_fuzzy_match``.

    Returns:
        One of the strings "database entities", "database" or "knowledge".

    Raises:
        Exception: Whatever ``fetch_unique_column_values`` raises for a column.

    NOTE(review): every call issues one SELECT DISTINCT per column, and
    partial matching against very short stored values may classify unrelated
    queries as "database entities" -- confirm against real data.
    """
    user_query_lower = user_query.lower()

    # Lowercased names are used for text matching only.
    table_names = schema_df['table_name'].str.lower().unique()
    column_names = schema_df['column_name'].str.lower().unique()

    mentions_schema = (
        is_fuzzy_match(user_query_lower, table_names, threshold)
        or is_fuzzy_match(user_query_lower, column_names, threshold)
    )
    if not mentions_schema:
        # Neither a table nor a column name is mentioned.
        return "knowledge"

    # The SQL lookup receives the identifiers exactly as they appear in the
    # schema (not lowercased), and each pair is visited once instead of
    # re-filtering the schema frame for every table.
    pairs = schema_df[['table_name', 'column_name']].drop_duplicates()
    for table_name, col in pairs.itertuples(index=False, name=None):
        # Fetch unique values (entities) from the specific column
        unique_values = fetch_unique_column_values(conn, table_name, col)
        if is_fuzzy_match(user_query_lower, unique_values, threshold):
            return "database entities"

    return "database"

# Example usage
if __name__ == "__main__":
    # Demo run: classify a few sample queries against the live schema.
    conn = connect_to_db()
    try:
        # Fetch schema data
        schema_df = fetch_schema_with_data_types(conn)

        # Test queries (trailing comments give the intended classification)
        user_queries = [
            "give me the count of total projects",  # Database query
            "tell me about the project completion status of project X",  # Database entities query
            "what is the meaning of life?",  # Knowledge query
        ]

        for query in user_queries:
            query_type = classify_query(query, schema_df, conn)
            print(f"Query: '{query}'")
            print(f"Classified as: {query_type}\n")
    finally:
        # Close the connection even when fetching or classifying raises.
        conn.close()


  schema_df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)


Query: 'give me the count of total projects'
Classified as: database entities



  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)


Query: 'tell me about the project completion status of project X'
Classified as: database entities

Query: 'what is the meaning of life?'
Classified as: knowledge



  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
  df = pd.read_sql(query, conn)
