In [195]:
import os

from langchain_groq import ChatGroq

# Read the Groq key from the environment instead of hard-coding it: the old
# literal "GROQ_API_KEY" is a placeholder, not a usable key, and real secrets
# must never be committed inside a notebook.
llm = ChatGroq(
    model="llama-3.1-8b-instant",
    groq_api_key=os.environ["GROQ_API_KEY"],
)


In [196]:
# Connect to the MySQL database and capture its schema (including 3 sample rows
# per table) so it can be pasted into the LLM prompts below.
import os
from urllib.parse import quote_plus

from langchain_community.utilities import SQLDatabase

db_user = "root"
# Read the password from the environment rather than hard-coding it.
db_password = os.environ.get("DB_PASSWORD", "")
db_host = "localhost"
db_name = "amazon_tshirts"

# quote_plus escapes characters such as '@', ':' or '/' that would otherwise
# be parsed as URI delimiters and break the connection string.
db = SQLDatabase.from_uri(
    f"mysql+pymysql://{db_user}:{quote_plus(db_password)}@{db_host}/{db_name}",
    sample_rows_in_table_info=3,
)

schema = db.get_table_info()
print(schema)



CREATE TABLE discounts (
	discount_id INTEGER NOT NULL AUTO_INCREMENT, 
	t_shirt_id INTEGER NOT NULL, 
	pct_discount DECIMAL(5, 2), 
	PRIMARY KEY (discount_id), 
	CONSTRAINT discounts_ibfk_1 FOREIGN KEY(t_shirt_id) REFERENCES t_shirts (t_shirt_id), 
	CONSTRAINT discounts_chk_1 CHECK ((`pct_discount` between 0 and 100))
)DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci

/*
3 rows from discounts table:
discount_id	t_shirt_id	pct_discount
1	1	10.00
2	2	15.00
3	3	20.00
*/


CREATE TABLE t_shirts (
	t_shirt_id INTEGER NOT NULL AUTO_INCREMENT, 
	brand ENUM('Van Huesen','Levi','Nike','Adidas') NOT NULL, 
	color ENUM('Red','Blue','Black','White') NOT NULL, 
	size ENUM('XS','S','M','L','XL') NOT NULL, 
	price INTEGER, 
	stock_quantity INTEGER NOT NULL, 
	PRIMARY KEY (t_shirt_id), 
	CONSTRAINT t_shirts_chk_1 CHECK ((`price` between 10 and 50))
)DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci

/*
3 rows from t_shirts table:
t_shirt_id	brand	color	size	price	stock

In [186]:

import re

from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Zero-shot text-to-SQL: the model sees only the schema and the question.
prompt = PromptTemplate.from_template("""
You are a MySQL expert. Use the schema to write a correct query.
Return ONLY SQL. Do not explain.

RULES:
If the question asks "how many" or "left in stock", use SUM(stock_quantity)
Return ONLY the raw SQL query.
Do NOT wrap it in ``` or sql tags.
Do NOT add explanations.

Schema:
{schema}

Question:
{question}

SQL:
""")

chain = prompt | llm | StrOutputParser()

question = "How many white color Levi's t shirts we have available?"

sql = chain.invoke({"schema": schema, "question": question})
# The prompt forbids ``` fences, but models do not always comply; strip any
# fence (optionally tagged "sql") so the text can be passed to db.run as-is.
sql = re.sub(r"^```(?:sql)?\s*|\s*```$", "", sql.strip(), flags=re.IGNORECASE)

print("\n🧠 Generated SQL:\n", sql)


🧠 Generated SQL:
 SELECT SUM(stock_quantity) 
FROM t_shirts 
WHERE brand = 'Levi' AND color = 'White';


In [187]:
# `sql` is LLM-generated and the connection logs in as root, so only execute
# read-only SELECT statements; anything else is rejected rather than run.
# (A dedicated read-only database user remains the real safeguard.)
if not sql.lstrip().upper().startswith("SELECT"):
    raise ValueError(f"Refusing to run non-SELECT SQL:\n{sql}")

result = db.run(sql)
print("\n📊 Result:\n", result)


📊 Result:
 [(Decimal('169'),)]


In [218]:
# Hand-written question -> SQL pairs used as retrieval-based few-shot examples.
# The vector-store cell reads only 'Question' and 'SQLQuery'; 'SQLResult' and
# 'Answer' are kept for reference.
few_shots = [
    {'Question': "How many t-shirts do we have left for Nike in XS size and white color?",
     'SQLQuery': "SELECT sum(stock_quantity) FROM t_shirts WHERE brand = 'Nike' AND color = 'White' AND size = 'XS'",
     'SQLResult': "Result of the SQL query",
     'Answer': '66'},
    {'Question': "How much is the total price of the inventory for all S-size t-shirts?",
     'SQLQuery': "SELECT SUM(price*stock_quantity) FROM t_shirts WHERE size = 'S'",
     'SQLResult': "Result of the SQL query",
     'Answer': '18734'},
    {'Question': "If we have to sell all the Levi’s T-shirts today with discounts applied. How much revenue  our store will generate (post discounts)?",
     'SQLQuery': """SELECT sum(a.total_amount * ((100-COALESCE(discounts.pct_discount,0))/100)) as total_revenue from
(select sum(price*stock_quantity) as total_amount, t_shirt_id from t_shirts where brand = 'Levi'
group by t_shirt_id) a left join discounts on a.t_shirt_id = discounts.t_shirt_id
 """,
     'SQLResult': "Result of the SQL query",
     'Answer': '16309.600000'},
    {'Question': "If we have to sell all the Levi’s T-shirts today. How much revenue our store will generate without discount?",
     'SQLQuery': "SELECT SUM(price * stock_quantity) FROM t_shirts WHERE brand = 'Levi'",
     'SQLResult': "Result of the SQL query",
     'Answer': '16691'},
    {'Question': "How many white color Levi's shirt I have?",
     'SQLQuery': "SELECT sum(stock_quantity) FROM t_shirts WHERE brand = 'Levi' AND color = 'White'",
     'SQLResult': "Result of the SQL query",
     'Answer': '169'},
    # The question keeps the user's misspelling on purpose, but the SQL literal
    # must match the t_shirts.brand ENUM ('Van Huesen'); 'Van Heuson' matches
    # no rows and would teach the model a wrong value.
    {'Question': "How much revenue  our store will generate by selling all Van Heuson TShirts with discount?",
     'SQLQuery': """SELECT sum(a.total_amount * ((100-COALESCE(discounts.pct_discount,0))/100)) as total_revenue from
(select sum(price*stock_quantity) as total_amount, t_shirt_id from t_shirts where brand = 'Van Huesen'
group by t_shirt_id) a left join discounts on a.t_shirt_id = discounts.t_shirt_id
 """,
     'SQLResult': "Result of the SQL query",
     # NOTE(review): no expected 'Answer' has been recorded for this example.
     },
]

In [219]:
from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# One document per few-shot example: the question followed by its reference SQL.
# Built without source indentation so no stray whitespace is embedded or later
# pasted into the prompt.
fewshot_docs = [
    Document(
        page_content=f"Q: {ex['Question']}\nA: {ex['SQLQuery']}",
        metadata={"type": "sql_example"},
    )
    for ex in few_shots
]

# Must be named `embeddings`: this cell and the next pass `embeddings` to Chroma.
# The previous name `emb` only worked if `embeddings` leaked from an old kernel.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

persist_dir = "fewshot_chroma"

# Stable ids make re-running this cell upsert the same records instead of
# appending duplicates to the persisted collection (duplicates can make the
# k=2 retriever return the same example twice).
vector_store = Chroma.from_documents(
    fewshot_docs,
    embeddings,
    ids=[f"fewshot-{i}" for i in range(len(fewshot_docs))],
    persist_directory=persist_dir,
)

vector_store.persist()


In [220]:
# Re-open the persisted few-shot collection and wrap it as a retriever that
# returns the 2 nearest examples for a given question.
vector_store = Chroma(
    persist_directory=persist_dir,
    embedding_function=embeddings,
)

retriever = vector_store.as_retriever(search_kwargs=dict(k=2))


In [221]:
def get_fewshot_examples(query):
    """Return the few-shot examples most similar to ``query`` as a single string.

    Uses the module-level ``retriever`` and joins the ``page_content`` of each
    retrieved document, separating consecutive examples with a blank line.
    """
    matches = retriever.invoke(query)
    contents = [match.page_content for match in matches]
    return "\n\n".join(contents)


In [222]:
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Few-shot text-to-SQL prompt: retrieved question/SQL examples are inserted
# ahead of the schema. The ENUM rule matters because user questions misspell
# brand names (e.g. "van hueson"), and a literal that is not one of the ENUM
# values silently matches no rows.
prompt = PromptTemplate.from_template("""
You are a MySQL expert. Use the schema and examples to write a correct SQL query.
Return ONLY SQL. No explanations. No ```.

EXAMPLES:
{examples}

RULES:
- If the question asks "how many" or "left in stock", use SUM(stock_quantity)
- Always return a single SQL query
- Use only valid columns
- For ENUM columns, use only the exact values listed in the schema (e.g. brand = 'Van Huesen'), even if the question spells them differently

Schema:
{schema}

Question:
{question}

SQL:
""")

sql_chain = prompt | llm | StrOutputParser()


In [226]:
import re

query = " how much revenue it will make after selling van hueson t_shirt with discount of extra small size"

fewshot_context = get_fewshot_examples(query)

sql = sql_chain.invoke({
    "schema": schema,
    "question": query,
    "examples": fewshot_context,
})
# Remove any ``` / ```sql fence the model may add despite the prompt.
sql = re.sub(r"^```(?:sql)?\s*|\s*```$", "", sql.strip(), flags=re.IGNORECASE)

print("\n🧠 Generated SQL:\n", sql)

# The SQL is model output; refuse anything that is not a read-only SELECT
# rather than executing it (a read-only DB user is the real safeguard).
if not sql.upper().startswith("SELECT"):
    raise ValueError(f"Refusing to run non-SELECT SQL:\n{sql}")

print("\n📊 DB Result:\n", db.run(sql))



🧠 Generated SQL:
 SELECT SUM(a.total_amount * ((100-COALESCE(d.pct_discount,0))/100)) as total_revenue 
FROM (SELECT SUM(price*stock_quantity) as total_amount, t_shirt_id 
      FROM t_shirts 
      WHERE brand = 'Van Huesen' AND size = 'XS' 
      GROUP BY t_shirt_id) a 
LEFT JOIN discounts d ON a.t_shirt_id = d.t_shirt_id

📊 DB Result:
 [(Decimal('989.000000'),)]
