# In [None]:
import os
import re
import pandas as pd
import pandasql as ps
from dotenv import load_dotenv
from google import genai


# Load variables from a local .env file into the process environment so the
# lookup below can find GOOGLE_API_KEY.
load_dotenv()
# os.getenv returns None when GOOGLE_API_KEY is not set; no check is done here.
API_KEY = os.getenv("GOOGLE_API_KEY")


# Module-level Gemini client, created once at import time and used by
# ask_gemini() for every request.
client = genai.Client(api_key=API_KEY)

def ask_gemini(prompt: str) -> str:
    """
    Send a prompt to the Gemini model and return the cleaned response text.

    The raw reply is post-processed so it can be used as a bare query:
    Markdown code fences, whole-line comments (``///``, ``--``, ``#``) and a
    leading bold ``**...**`` span on a line are removed, then surrounding
    whitespace is trimmed.

    Args:
        prompt: Full prompt text sent as the request contents.

    Returns:
        The cleaned response text (may be empty if the reply contained
        nothing but fences/comments).

    Raises:
        ValueError: If the response carries no text at all.
    """
    response = client.models.generate_content(
        model="gemini-2.5-flash-lite",
        contents=prompt
    )

    # ``response.text`` can be None when the reply has no text parts; fail
    # with a clear message instead of an AttributeError on ``.strip()``.
    raw_text = response.text
    if raw_text is None:
        raise ValueError("Gemini returned no text for the prompt")
    raw_text = raw_text.strip()

    # Drop opening fences (optionally tagged sqlite/sql/python, any letter
    # case) and closing fences. "sqlite" is listed before "sql" so the longer
    # tag is consumed whole rather than leaving "ite" behind.
    clean_text = re.sub(
        r"^```(?:sqlite|sql|python)?\s*|\s*```$",
        "",
        raw_text,
        flags=re.MULTILINE | re.IGNORECASE,
    )
    # Blank out lines that start with a comment marker (///, -- or #).
    clean_text = re.sub(r"^///.*|^--.*|^#.*", "", clean_text, flags=re.MULTILINE)
    # Remove a bold span at the start of a line (opening ** through the last
    # ** on that line), e.g. a "**SQL Query:**" heading.
    clean_text = re.sub(r"^\*\*.*\*\*", "", clean_text, flags=re.MULTILINE)

    return clean_text.strip()

def build_system_prompt_multi(dfs: dict, user_question: str) -> str:
    """
    Assemble the SQL-generation prompt for a set of named DataFrames.

    Every table contributes one block listing each column with its dtype and
    up to three non-null sample values (rendered as strings). Blocks are
    separated by a blank line and are followed by the user's question and the
    output rules for the model.

    Args:
        dfs: Mapping of table name -> DataFrame.
        user_question: Natural-language question to embed in the prompt.

    Returns:
        The complete prompt text.
    """
    def describe_column(frame, column) -> str:
        # One schema bullet: name, dtype and the first three non-null values.
        series = frame[column]
        dtype_name = str(series.dtype)
        samples = series.dropna().astype(str).head(3).tolist()
        return f"- {column} ({dtype_name}), examples: {samples}"

    table_sections = [
        f"Table '{table}' schema:\n"
        + "\n".join(describe_column(frame, column) for column in frame.columns)
        for table, frame in dfs.items()
    ]
    schema_full = "\n\n".join(table_sections)

    # The prompt is spelled out line by line so that exact whitespace
    # (including the trailing space on the first sentence line) is explicit.
    return (
        "\n"
        "You are a SQL generator. You are given multiple pandas DataFrames. \n"
        "You must ONLY output a clean, valid SQL query that can run with pandasql.sqldf().\n"
        "Do NOT include any markdown formatting, code fences, explanations, or comments.\n"
        "Just return the pure SQL query as plain text.\n"
        "\n"
        f"{schema_full}\n"
        "\n"
        f"User Question: {user_question}\n"
        "\n"
        "Remember:\n"
        "- Use exact DataFrame variable names as table names\n"
        "- Use exact column names from schema\n"
        "- Return ONLY the SQL query as plain text\n"
        "- No markdown, no explanations, no formatting\n"
    )

def format_results_nlp(result_df: pd.DataFrame, user_question: str) -> str:
    """
    Render query results as a numbered, human-readable list.

    Args:
        result_df: Query result to describe.
        user_question: The question the results answer; echoed in the header.

    Returns:
        ``"No results found for: <question>"`` when the frame is empty.
        Otherwise a ``"Found N result(s) for: <question>"`` header, a blank
        line, and one line per row of the form
        ``"<n>. col: value | col: value"``. Null values are left out of a row.

    Note:
        Rows are read with ``DataFrame.iterrows()``, which does not preserve
        per-column dtypes across a row, so values are rendered as that
        method yields them.
    """
    if result_df.empty:
        return f"No results found for: {user_question}"

    total_count = len(result_df)
    plural = "" if total_count == 1 else "s"
    summary = f"Found {total_count} result{plural} for: {user_question}\n\n"

    results_list = []
    # Number rows by position, not by index label: labels may be non-numeric
    # or non-contiguous (e.g. after filtering), which would give wrong
    # numbers or a TypeError with ``label + 1``.
    for position, (_, row) in enumerate(result_df.iterrows(), start=1):
        row_info = [
            f"{col}: {value}" for col, value in row.items() if pd.notna(value)
        ]
        results_list.append(f"{position}. " + " | ".join(row_info))

    return summary + "\n".join(results_list)

def query_with_ai_multi(dfs: dict, user_question: str):
    """
    Answer a natural-language question over several DataFrames via generated SQL.

    Builds the schema-aware prompt, asks Gemini for a SQL query, runs that
    query with pandasql against the given tables, and formats the result.

    Args:
        dfs: Mapping of table name -> DataFrame; the names are the SQL tables.
        user_question: Natural-language question to answer.

    Returns:
        ``(sql_query, result_df, formatted_text)`` on success, or
        ``(None, None, "Error processing query: <message>")`` if any step
        raises.
    """
    try:
        llm_prompt = build_system_prompt_multi(dfs, user_question)
        generated_sql = ask_gemini(llm_prompt)

        # pandasql resolves table names through this mapping; a shallow copy
        # keeps the caller's dict untouched.
        tables = dict(dfs)
        rows = ps.sqldf(generated_sql, tables)

        readable = format_results_nlp(rows, user_question)
    except Exception as exc:
        # Boundary handler: report any failure as text instead of raising.
        return None, None, f"Error processing query: {exc}"

    return generated_sql, rows, readable

if __name__ == "__main__":
    # Demo run: load the two CSV tables, ask one question, print the outcome.
    Personnel = pd.read_csv("final_unified_personnel_smart_modified.csv")
    Companies = pd.read_csv("unified_companies_complete.csv")

    # Dict keys are the table names the generated SQL is expected to use.
    dfs = {"Personnel": Personnel, "Companies": Companies}
    user_question = "Find all Architects in mumbai location"

    sql, result_df, nlp_output = query_with_ai_multi(dfs, user_question)

    if not sql:
        # No SQL came back: the third element holds the error message.
        print(nlp_output)
    else:
        divider = "=" * 50

        print(divider, "NATURAL LANGUAGE RESULTS:", divider, nlp_output, sep="\n")
        print(f"\n{divider}", "Generated SQL Query:", divider, sql, sep="\n")

        if not result_df.empty:
            print(f"\n{divider}", "Raw Data (first 5 rows):", divider, sep="\n")
            print(result_df.head().to_string(index=False))