#### Installation & Imports

In [43]:
import os
from pathlib import Path

from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from sqlalchemy import create_engine
from sqlalchemy.engine import URL

#### Configuration & Credentials

In [44]:
# Load the environment variables from the .env file in the parent of the
# current working directory; override=True lets values from the file replace
# variables that are already set in the process environment.
env_path = Path('..') / '.env'
load_dotenv(dotenv_path=env_path, override=True)

# Grab the database credentials and the Groq API key
groq_api_key = os.getenv("GROQ_API_KEY")
db_user = os.getenv("DB_USERNAME")
db_password = os.getenv("DB_PASSWORD")
db_host = os.getenv("DB_HOST")
db_name = os.getenv("DB_DATABASE")
db_port = os.getenv("DB_PORT")

# Verify every setting the later cells rely on, not just the Groq key:
# an unset DB_* variable would otherwise be interpolated into the connection
# string as the literal text "None" and only fail at connection time.
required_settings = {
    "GROQ_API_KEY": groq_api_key,
    "DB_USERNAME": db_user,
    "DB_PASSWORD": db_password,
    "DB_HOST": db_host,
    "DB_DATABASE": db_name,
    "DB_PORT": db_port,
}
missing_settings = [name for name, value in required_settings.items() if not value]

if missing_settings:
    # Report all missing names at once; still only a warning, the notebook keeps running
    verb = "is" if len(missing_settings) == 1 else "are"
    print(f"❌ Error: {', '.join(missing_settings)} {verb} missing. Check your .env file.")
else:
    print("✅ Keys loaded successfully.")

✅ Keys loaded successfully.


#### Connect to Database

In [45]:
# Build the connection URL for the TiDB MySQL database.
# URL.create escapes each component itself, so a username or password that
# contains characters such as '@', ':', '/' or '#' no longer corrupts the URL
# the way plain f-string interpolation does.
db_uri = URL.create(
    drivername="mysql+mysqlconnector",
    username=db_user,
    password=db_password,
    host=db_host,
    # Environment variables are strings; leave the port unset when it is missing
    port=int(db_port) if db_port else None,
    database=db_name,
)

# Create the engine with keep-alive settings to prevent timeouts:
# pool_pre_ping checks a pooled connection before handing it out and
# pool_recycle=300 replaces connections older than 300 seconds.
engine = create_engine(db_uri, pool_pre_ping=True, pool_recycle=300)

# Initialize the LangChain database wrapper (3 sample rows per table in the table info)
db = SQLDatabase(engine, sample_rows_in_table_info=3)

print("✅ Database connected.")

✅ Database connected.


#### Define Training Examples

In [46]:
# Example questions and their matching SQL queries to guide the model.
# Every example shares the same placeholder for the SQL result, so the rows
# below only carry what differs: (question, SQL query, answer).
_FEW_SHOT_SQL_RESULT = "Result of the SQL query"

_few_shot_rows = [
    (
        "How many t-shirts do we have left for Nike in XS size and white color?",
        "SELECT sum(stock_quantity) FROM t_shirts WHERE brand = 'Nike' AND color = 'White' AND size = 'XS'",
        "There are 91 Nike t-shirts in XS size and white color.",
    ),
    (
        "How much is the total price of the inventory for all S-size t-shirts?",
        "SELECT SUM(price*stock_quantity) FROM t_shirts WHERE size = 'S'",
        "The total price for all S-size t-shirts is 22,292.",
    ),
    (
        "If we have to sell all the Levi’s T-shirts today with discounts applied. How much revenue our store will generate (post discounts)?",
        "SELECT sum(a.total_amount * ((100-COALESCE(discounts.pct_discount,0))/100)) as total_revenue from (select sum(price*stock_quantity) as total_amount, t_shirt_id from t_shirts where brand = 'Levi' group by t_shirt_id) a left join discounts on a.t_shirt_id = discounts.t_shirt_id",
        "The total revenue for Levi's t-shirts after discounts is 16,725.4.",
    ),
    (
        "If we have to sell all the Levi’s T-shirts today. How much revenue our store will generate without discount?",
        "SELECT SUM(price * stock_quantity) FROM t_shirts WHERE brand = 'Levi'",
        "The total revenue for Levi's t-shirts without discounts is 17,462.",
    ),
    (
        "How many white color Levi's shirt I have?",
        "SELECT sum(stock_quantity) FROM t_shirts WHERE brand = 'Levi' AND color = 'White'",
        "You have 50 white Levi's shirts.",
    ),
]

# Key order (Question, SQLQuery, SQLResult, Answer) is kept on purpose:
# code that joins example.values() depends on it.
few_shots = [
    {
        'Question': question,
        'SQLQuery': sql_query,
        'SQLResult': _FEW_SHOT_SQL_RESULT,
        'Answer': answer,
    }
    for question, sql_query, answer in _few_shot_rows
]

#### Create Vector Database

In [47]:
# Initialize the embedding model on the CPU
model_kwargs, encode_kwargs = {'device': 'cpu'}, {'normalize_embeddings': False}
embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2', model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)

# Each example is embedded as one string made of all of its field values
to_vectorize = [" ".join(example.values()) for example in few_shots]

# Deterministic ids keep this cell idempotent: without them every re-run adds
# the same examples again under fresh random ids, and the selector can then
# return the same example twice.
# NOTE(review): relies on the Chroma wrapper upserting on matching ids — confirm
# for the installed langchain/chromadb versions.
example_ids = [f"few-shot-{index}" for index in range(len(few_shots))]

# Store the examples in ChromaDB; the original dicts travel along as metadata
vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=few_shots, ids=example_ids)

# Configure the selector to pick the 2 most relevant examples per question.
# input_keys limits the similarity query to the user's question; otherwise
# every prompt input (including the full table schema text) is embedded as
# part of the query and drowns out the question itself.
example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=2,
    input_keys=["input"],
)

print("✅ Knowledge base built.")

✅ Knowledge base built.


#### Configure Prompt Template

In [48]:
# Prefix: expert persona plus output rules (no Markdown, no SQL repeated in the
# answer) and the expected response layout. {input} is filled with the question.
mysql_prompt = (
    "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run.\n"
    "IMPORTANT: Return ONLY the raw SQL code in the SQLQuery section. Do NOT use markdown.\n"
    "IMPORTANT: In the Answer section, do NOT repeat the SQL code. Only output the final natural language sentence.\n"
    "\n"
    "Format:\n"
    "Question: {input}\n"
    "SQLQuery: Raw SQL query without formatting\n"
    "SQLResult: Result of the SQLQuery\n"
    "Answer: Final natural language response\n"
)

# Suffix: restricts the model to the supplied table schema, then restates the question
PROMPT_SUFFIX = (
    "Only use these tables: {table_info}\n"
    "Question: {input}"
)

# Layout used to render each selected few-shot example
example_prompt = PromptTemplate(
    input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
    template=(
        "\nQuestion: {Question}"
        "\nSQLQuery: {SQLQuery}"
        "\nSQLResult: {SQLResult}"
        "\nAnswer: {Answer}"
    ),
)

# Final prompt = prefix + dynamically selected examples + suffix
few_shot_prompt = FewShotPromptTemplate(
    prefix=mysql_prompt,
    example_selector=example_selector,
    example_prompt=example_prompt,
    suffix=PROMPT_SUFFIX,
    input_variables=["input", "table_info", "top_k"],
)

print("✅ Instructions updated.")

✅ Instructions updated.


#### Execute Chain

In [49]:
# Groq-hosted Llama 3.3 70B chat model, with temperature pinned to 0
llm = ChatGroq(
    model_name="llama-3.3-70b-versatile",
    groq_api_key=groq_api_key,
    temperature=0,
)

# SQL chain over the database wrapper, driven by the few-shot prompt.
# return_intermediate_steps=True makes the generated SQL available in the response.
chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    prompt=few_shot_prompt,
    verbose=False,
    return_intermediate_steps=True,
)

# Define the function
def ask_database(question, sql_chain=None):
    """Run a natural-language question through the SQL chain and print the outcome.

    Args:
        question: The question to answer from the database.
        sql_chain: Optional object with an ``invoke(question)`` method returning a
            mapping that has ``'intermediate_steps'`` and ``'result'`` entries.
            When omitted, the notebook-level ``chain`` is used.

    Returns:
        None. The question, the generated SQL and the answer are printed; if
        anything raises, the error is printed instead.
    """
    try:
        # Resolve the default inside the try block so that a missing global
        # `chain` is reported like any other failure
        active_chain = chain if sql_chain is None else sql_chain

        # Run the chain
        response = active_chain.invoke(question)

        # The generated SQL is read from the second intermediate step;
        # drop a leading "SQLQuery:" label if the model echoed it
        sql_code = response['intermediate_steps'][1].split("SQLQuery:")[-1].strip()

        # Keep only the text after the last "Answer:" label
        raw_result = response['result'].split("Answer:")[-1].strip()

        # Only the last line of the answer is shown
        final_result = raw_result.split("\n")[-1].strip()

        # An empty answer, or one that is really leaked SQL (an echoed
        # "SQLQuery:" label or a statement starting with SELECT, in any
        # letter case), is not a usable answer.
        lowered_result = final_result.lower()
        if (
            not final_result
            or "sqlquery:" in lowered_result
            or lowered_result.startswith("select")
        ):
            final_result = "Result Not Found"

        # Print the clean output
        print("-" * 50)
        print(f"QUESTION: {question}")
        print(f"SQL CODE: {sql_code}")
        print(f"RESULT:   {final_result}")
        print("-" * 50)

    except Exception as e:
        # Best-effort reporting so the notebook keeps running. Note that every
        # exception type lands here, so the label is used for non-database
        # failures as well.
        print("-" * 50)
        print(f"QUESTION: {question}")
        print(f"RESULT:   ❌ Database Error: {e}")
        print("-" * 50)

print("✅ Function 'ask_database()' ready.")

✅ Function 'ask_database()' ready.


In [50]:
# Check total revenue with discounts applied.
# The question text is word-for-word one of the entries in `few_shots`, so it
# also shows how the chain behaves when an exact-match example is available.
ask_database("If we have to sell all the Levi’s T-shirts today with discounts applied. How much revenue our store will generate (post discounts)?")

--------------------------------------------------
QUESTION: If we have to sell all the Levi’s T-shirts today with discounts applied. How much revenue our store will generate (post discounts)?
SQL CODE: SELECT SUM(t.price * t.stock_quantity * ((100 - COALESCE(d.pct_discount, 0)) / 100)) AS total_revenue FROM t_shirts t LEFT JOIN discounts d ON t.t_shirt_id = d.t_shirt_id WHERE t.brand = 'Levi'
RESULT:   The total revenue for Levi's t-shirts after discounts is 4320.
--------------------------------------------------


In [51]:
# Check total revenue without any discounts.
# As with the discounted variant, this question is taken verbatim from `few_shots`.
ask_database("If we have to sell all the Levi’s T-shirts today. How much revenue our store will generate without discount?")

--------------------------------------------------
QUESTION: If we have to sell all the Levi’s T-shirts today. How much revenue our store will generate without discount?
SQL CODE: SELECT sum(price * stock_quantity) as total_revenue FROM t_shirts WHERE brand = 'Levi'
RESULT:   The total revenue for Levi's t-shirts without discounts is 4600.
--------------------------------------------------


In [52]:
# Ask about a brand that is not expected to be in the data.
# NOTE: in the recorded output this exercises the "Result Not Found" fallback
# inside ask_database, not its exception handler — no error is raised here.
ask_database("How many t-shirts do we have for the brand 'Supreme'?")

--------------------------------------------------
QUESTION: How many t-shirts do we have for the brand 'Supreme'?
SQL CODE: SELECT sum(stock_quantity) FROM t_shirts WHERE brand = 'Supreme'
RESULT:   Result Not Found
--------------------------------------------------
