In [None]:
%pip install -r requirements.txt

In [1]:
%pip install -U google-generativeai langchain-google-genai langchain-huggingface langchain langchain-community langsmith pymysql sentence-transformers

Collecting langchain-google-genai
  Using cached langchain_google_genai-2.1.7-py3-none-any.whl.metadata (7.0 kB)
INFO: pip is looking at multiple versions of langchain-google-genai to determine which version is compatible with other requirements. This could take a while.
  Using cached langchain_google_genai-2.1.6-py3-none-any.whl.metadata (7.0 kB)
  Using cached langchain_google_genai-2.1.5-py3-none-any.whl.metadata (5.2 kB)
  Using cached langchain_google_genai-2.1.4-py3-none-any.whl.metadata (5.2 kB)
  Using cached langchain_google_genai-2.1.3-py3-none-any.whl.metadata (4.7 kB)
  Using cached langchain_google_genai-2.1.2-py3-none-any.whl.metadata (4.7 kB)
  Using cached langchain_google_genai-2.1.1-py3-none-any.whl.metadata (4.7 kB)
  Using cached langchain_google_genai-2.1.0-py3-none-any.whl.metadata (3.6 kB)
INFO: pip is still looking at multiple versions of langchain-google-genai to determine which version is compatible with other requirements. This could take a while.
  Using ca

In [2]:
# --- Load Environment Variables ---
# Ensure your .env file exists in the same directory as this script

import os
from dotenv import load_dotenv
import decimal # New import for Decimal type handling

load_dotenv()

True

In [3]:
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
db_user =  os.getenv("db_user")
db_password =  os.getenv("db_password")
db_host =  os.getenv("db_host")
db_port =  os.getenv("db_port")
db_name =  os.getenv("db_name")

In [4]:
import google.generativeai as genai

# --- Configure Gemini API ---
# This line configures the Gemini API with your API key.
# It's crucial for authenticating your requests to Google's generative models.
genai.configure(api_key=GOOGLE_API_KEY)

In [5]:
# --- Initialize SQLDatabase ---
# This connects to your MySQL database and extracts its schema information.
# The 'sample_rows_in_table_info=3' helps Gemini understand the data types and typical values.

from langchain.utilities import SQLDatabase

print("--- Connecting to MySQL Database and Fetching Schema ---")

try:
    db = SQLDatabase.from_uri(
        f"mysql+pymysql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}",
        sample_rows_in_table_info=3
    )
    database_schema = db.table_info
    print("Database Schema:\n", database_schema)
    print("---------------------------------------------------\n")
except Exception as e:
    print(f"Error connecting to database or fetching schema: {e}")
    print("Please ensure MySQL is running, your .env variables are correct, and pymysql is installed.")
    exit() # Exit if database connection fails, as further steps depend on it.


--- Connecting to MySQL Database and Fetching Schema ---
Database Schema:
 
CREATE TABLE discounts (
	discount_id INTEGER NOT NULL COMMENT 'Unique ID for each discount entry' AUTO_INCREMENT, 
	t_shirt_id INTEGER NOT NULL COMMENT 'Foreign key referencing the t-shirt being discounted', 
	pct_discount DECIMAL(5, 2) COMMENT 'Discount percentage (0-100)', 
	PRIMARY KEY (discount_id), 
	CONSTRAINT discounts_ibfk_1 FOREIGN KEY(t_shirt_id) REFERENCES t_shirts (t_shirt_id), 
	CONSTRAINT discounts_chk_1 CHECK ((`pct_discount` between 0 and 100))
)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci COMMENT='Table storing discount information for t-shirts'

/*
3 rows from discounts table:
discount_id	t_shirt_id	pct_discount
1	1	10.00
2	2	15.00
3	3	20.00
*/


CREATE TABLE t_shirts (
	t_shirt_id INTEGER NOT NULL COMMENT 'Unique ID for each t-shirt entry' AUTO_INCREMENT, 
	brand ENUM('Van Huesen','Levi','Nike','Adidas') NOT NULL COMMENT 'Brand of the t-shirt', 
	color ENUM('Red','Blue','

In [6]:
# --- List Available Gemini Models ---
# It's good practice to list models to confirm which ones support 'generateContent'
# in your region and for your API key.
print("Available models supporting 'generateContent':")
available_models = []
for m in genai.list_models():
    if 'generateContent' in m.supported_generation_methods:
        available_models.append(m.name)
        print(f"  - {m.name}")

if not available_models:
    print("\nNo models found that support 'generateContent'. Please check your API key and region availability.")
    exit()

Available models supporting 'generateContent':
  - models/gemini-1.0-pro-vision-latest
  - models/gemini-pro-vision
  - models/gemini-1.5-pro-latest
  - models/gemini-1.5-pro-002
  - models/gemini-1.5-pro
  - models/gemini-1.5-flash-latest
  - models/gemini-1.5-flash
  - models/gemini-1.5-flash-002
  - models/gemini-1.5-flash-8b
  - models/gemini-1.5-flash-8b-001
  - models/gemini-1.5-flash-8b-latest
  - models/gemini-2.5-pro-preview-03-25
  - models/gemini-2.5-flash-preview-04-17
  - models/gemini-2.5-flash-preview-05-20
  - models/gemini-2.5-flash
  - models/gemini-2.5-flash-preview-04-17-thinking
  - models/gemini-2.5-flash-lite-preview-06-17
  - models/gemini-2.5-pro-preview-05-06
  - models/gemini-2.5-pro-preview-06-05
  - models/gemini-2.5-pro
  - models/gemini-2.0-flash-exp
  - models/gemini-2.0-flash
  - models/gemini-2.0-flash-001
  - models/gemini-2.0-flash-exp-image-generation
  - models/gemini-2.0-flash-lite-001
  - models/gemini-2.0-flash-lite
  - models/gemini-2.0-flash-p

In [7]:
# --- Initialize Gemini Model ---
# We'll use 'gemini-1.5-flash' as it's often a good balance of speed and capability for free tier.
# You can change this to another model name from the 'Available models' list if needed.
model_name_to_use = 'gemini-1.5-flash' 

In [8]:
# Verify the chosen model is in the available list
if f"models/{model_name_to_use}" not in available_models and model_name_to_use not in available_models:
    print(f"\nWarning: The chosen model '{model_name_to_use}' is not in the list of available models.")
    print("Attempting to use it anyway, but it might result in an error.")
    print("Consider changing 'model_name_to_use' to one from the list above.")
else:
    model = genai.GenerativeModel(model_name_to_use)
    print("Model initialised 😄\nGood to go !!!")

Model initialised 😄
Good to go !!!


In [51]:
# --- Example 1: Basic Text Generation with Gemini ---
# This shows how Gemini responds to a general question.
prompt_general = "Who is the Prime minister of India?"
print(f"\n--- Asking Gemini (General Question): '{prompt_general}' ---")
try:
    response_general = model.generate_content(prompt_general)
    print(response_general.text.strip())
except Exception as e:
    print(f"Error during general text generation: {e}")
print("---------------------------------------------------\n")


--- Asking Gemini (General Question): 'Who is the Prime minister of India?' ---
Error during general text generation: 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerDayPerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "model"
    value: "gemini-1.5-flash"
  }
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_value: 50
}
, links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, retry_delay {
  seconds: 25
}
]
---------------------------------------------------



In [19]:
import pymysql

# ================================================================
# --- NEW FUNCTION FOR SQL GENERATION AND EXECUTION ---
# ================================================================

def ask_sql_query(user_question: str, gemini_model, sql_database_obj):
    """
    Generates a MySQL query from a natural language question using Gemini,
    executes it against the provided database, and returns the results.

    Args:
        user_question (str): The natural language question from the user.
        gemini_model: The initialized Gemini GenerativeModel instance.
        sql_database_obj: The initialized langchain.utilities.SQLDatabase object.

    Returns:
        list: A list of tuples containing the query results.
              Returns 0 if it's an aggregate query yielding no results (SUM/COUNT=NULL or empty).
              Returns None if an error occurs.
    """
    print(f"\n--- Processing user question: '{user_question}' ---")

    # Construct the prompt for Gemini
    sql_generation_prompt = f"""
Given the following MySQL database schema:

{sql_database_obj.table_info}

Generate a valid MySQL query for the following question:
"{user_question}"

When generating the query, ensure that:
1. Column names from the schema are used exactly as provided.
2. String comparisons (like 'Nike', 'extra small', 'white') should match the case and format
   of data as it is actually stored in the database. If exact matches are needed,
   use '='. If a case-insensitive match might be better, consider functions like LOWER().
3. For questions asking "how many" or "total", generate a query that returns a single aggregate value (e.g., SUM, COUNT).
   If no matching rows are found, the query should logically yield 0 or NULL for that aggregate.

Provide only the SQL query, without any additional explanation or text.
"""

    # print("--- Prompt sent to Gemini ---")
    # print(sql_generation_prompt)
    # print("---------------------------\n")

    generated_sql = None
    try:
        response_sql = gemini_model.generate_content(sql_generation_prompt)
        generated_sql = response_sql.text.strip()

        # Remove Markdown code block delimiters
        if generated_sql.startswith("```sql"):
            generated_sql = generated_sql[len("```sql"):].strip()
        if generated_sql.endswith("```"):
            generated_sql = generated_sql[:-len("```")].strip()
            
        print("--- Generated SQL Query (cleaned) ---")
        print(generated_sql)
        print("-----------------------------------\n")

    except Exception as e:
        print(f"Error during SQL generation by Gemini: {e}")
        return None # Indicate failure

    # Execute the generated SQL query
    print("--- Executing SQL Query and Fetching Results ---")
    try:
        conn = pymysql.connect(
            host=db_host,
            user=db_user,
            password=db_password,
            database=db_name,
            port=int(db_port)
        )
        cursor = conn.cursor()
        cursor.execute(generated_sql)
        results = cursor.fetchall()
        
        if cursor.description:
            columns = [i[0] for i in cursor.description]
            print("Columns:", columns)
        
        # --- ENHANCED HANDLING FOR EMPTY/NULL/SINGLE AGGREGATE RESULTS ---
        if not results or (len(results) == 1 and results[0][0] is None):
            print("Query Results: 0 (or no matching data)")
            return 0 # Return 0 for aggregate queries with no results
        elif len(results) == 1 and len(results[0]) == 1 and isinstance(results[0][0], (int, float, decimal.Decimal)):
            # If it's a single aggregate value (like COUNT(*), SUM), return the scalar directly
            print(f"Query Result: {results[0][0]}")
            return results[0][0]
        else:
            print("Query Results:")
            for row in results:
                print(row)
            return results # Return the actual results (list of tuples for multi-row/col)

    except pymysql.Error as db_error:
        print(f"Error executing SQL query: {db_error}")
        return None # Indicate failure

    finally:
        if 'conn' in locals() and conn.open: # Ensure connection is closed even on error
            conn.close()

In [21]:
# This is the core logic for SpeakSQL project.
# Test Query 1: Your original example
q1 = "How many tshirts do we have left for Nike in extra small size and red colour?"
result1 = ask_sql_query(q1, model, db)
print(f"\nFinal Answer for '{q1}': {result1}")

print("\n---------------------------------------------------\n")


--- Processing user question: 'How many tshirts do we have left for Nike in extra small size and red colour?' ---
Error during SQL generation by Gemini: 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerDayPerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "model"
    value: "gemini-1.5-flash"
  }
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_value: 50
}
, links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, retry_delay {
  seconds: 55
}
]

Final Answer for 'How many tshirts do we have left for Nike in extra small size and red colour?': None

---------------------------------------------------



In [23]:
q2 = "If we have to sell all the Levi's tshirt today. How much revenue will our store generate?"
result2 = ask_sql_query(q2, model, db)
print(f"\nFinal Answer for '{q2}': {result2}")


--- Processing user question: 'If we have to sell all the Levi's tshirt today. How much revenue will our store generate?' ---
Error during SQL generation by Gemini: 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerDayPerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "model"
    value: "gemini-1.5-flash"
  }
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_value: 50
}
, links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, retry_delay {
  seconds: 42
}
]

Final Answer for 'If we have to sell all the Levi's tshirt today. How much revenue will our store generate?': None


In [25]:
q3 = "Show me the total number of products in stock."
result3 = ask_sql_query(q3, model, db)
print(f"\nFinal Answer for '{q3}': {result3}")


--- Processing user question: 'Show me the total number of products in stock.' ---
Error during SQL generation by Gemini: 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerDayPerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "model"
    value: "gemini-1.5-flash"
  }
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_value: 50
}
, links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, retry_delay {
  seconds: 41
}
]

Final Answer for 'Show me the total number of products in stock.': None


In [27]:
q4 = "How much is the price of inventory for all small sized tshirts?"
result4 = ask_sql_query(q4, model, db)
print(f"\nFinal Answer for '{q4}': {result4}")


--- Processing user question: 'How much is the price of inventory for all small sized tshirts?' ---
Error during SQL generation by Gemini: 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerDayPerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "model"
    value: "gemini-1.5-flash"
  }
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_value: 50
}
, links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, retry_delay {
  seconds: 40
}
]

Final Answer for 'How much is the price of inventory for all small sized tshirts?': None


In [29]:
q5 = "If we have to sell all the Levi's tshirt with discounts applied. How much revenue will our store generate? (post discounts)"
# q5 = """SELECT SUM(t.stock_quantity * IFNULL(t.price * (1 - d.pct_discount / 100), t.price)) as Total_Revenue
#         FROM t_shirts AS t
#         LEFT JOIN discounts AS d
#           ON t.t_shirt_id = d.t_shirt_id
#         WHERE t.brand = 'Levi';"""

# OR 

# q5 = """SELECT SUM(a.Total_Amount * ((100- COALESCE(discounts.pct_discount, 0))/100)) as Total_Revenue
#         FROM (SELECT SUM(price*stock_quantity) as Total_Amount, t_shirt_id from t_shirts
#             	WHERE brand = 'Levi'
#             	GROUP BY t_shirt_id) a
#         LEFT JOIN discounts 
#         ON a.t_shirt_id = discounts.t_shirt_id;"""

# OR

# WITH TotalAmounts AS (
#     SELECT 
#         t_shirt_id,
#         SUM(price * stock_quantity) AS Total_Amount
#     FROM t_shirts
#     WHERE brand = 'Levi'
#     GROUP BY t_shirt_id
# )

# SELECT 
#     SUM(TotalAmounts.Total_Amount * ((100 - COALESCE(discounts.pct_discount, 0)) / 100)) AS Total_Revenue
# FROM TotalAmounts
# LEFT JOIN discounts 
# ON TotalAmounts.t_shirt_id = discounts.t_shirt_id;


result5 = ask_sql_query(q5, model, db)
print(f"\nFinal Answer for '{q5}': {result5}")


--- Processing user question: 'If we have to sell all the Levi's tshirt with discounts applied. How much revenue will our store generate? (post discounts)' ---
Error during SQL generation by Gemini: 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerDayPerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "model"
    value: "gemini-1.5-flash"
  }
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_value: 50
}
, links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, retry_delay {
  seconds: 39
}
]

Final Answer for 'If we have to sell all the Levi's tshirt with discounts applied. How much revenue will our store generate? (post discounts)'

In [31]:
q6 = "How many White coloured Levi's tshirts do we have in each size?"
result6 = ask_sql_query(q6, model, db)
print(f"\nFinal Answer for '{q6}': {result6}")


--- Processing user question: 'How many White coloured Levi's tshirts do we have in each size?' ---
Error during SQL generation by Gemini: 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerDayPerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "model"
    value: "gemini-1.5-flash"
  }
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_value: 50
}
, links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, retry_delay {
  seconds: 38
}
]

Final Answer for 'How many White coloured Levi's tshirts do we have in each size?': None


## Few Shot Learning

In [33]:
# few_shots = [
#     {
#         "Question":"How many tshirts do we have left for Nike in extra small size and red colour?",
#         "SQLQuery":"""SELECT SUM(stock_quantity)
#                       FROM t_shirts
#                       WHERE brand = 'Nike' AND size = 'XS' AND color = 'Red';""",
#         "SQLResult":"Result of the SQL Query",
#         "Answer":"83"
#     },
#     {
#         "Question":"How much is the price of inventory for all small sized tshirts?",
#         "SQLQuery":"""SELECT SUM(price * stock_quantity)
#                       FROM t_shirts
#                       WHERE size = 'S';""",
#         "SQLResult":"Result of the SQL Query",
#         "Answer":"13924"
#     },
#     {
#         "Question":"If we have to sell all the Levi's tshirt with discounts applied. How much revenue will our store generate? (post discounts)",
#         "SQLQuery":"""SELECT SUM(a.Total_Amount * ((100- COALESCE(discounts.pct_discount, 0))/100)) as Total_Revenue
#                       FROM (SELECT SUM(price*stock_quantity) as Total_Amount, t_shirt_id from t_shirts
#                           	WHERE brand = 'Levi'
#                           	GROUP BY t_shirt_id) a
#                       LEFT JOIN discounts 
#                       ON a.t_shirt_id = discounts.t_shirt_id""",
#         "SQLResult":"Result of the SQL Query",
#         "Answer":"30972.450000"
#     },
#     {
#         "Question":"If we have to sell all the Levi's tshirt today. How much revenue will our store generate?",
#         "SQLQuery":"""SELECT SUM(t_shirts.price * t_shirts.stock_quantity)
#                       FROM t_shirts
#                       WHERE t_shirts.brand = 'Levi'; """,
#         "SQLResult":"Result of the SQL Query",
#         "Answer":"32170"
#     },
#     {
#         "Question":"How many White coloured Levi's tshirts do we have in each size?",
#         "SQLQuery":"""SELECT SUM(stock_quantity)
#                       FROM t_shirts
#                       WHERE brand = 'Levi' AND color = 'White'; """,
#         "SQLResult":"Result of the SQL Query",
#         "Answer":"369"
#     },
#     {
#         "Question": "How many Adidas t-shirts are available in total?",
#         "SQLQuery": """SELECT SUM(stock_quantity)
#                        FROM t_shirts
#                        WHERE brand = 'Adidas';""",
#         "SQLResult": "Result of the SQL Query",
#         "Answer": "758"
#     },
#     {
#         "Question": "What is the total inventory value of all black coloured t-shirts?",
#         "SQLQuery": """SELECT SUM(price * stock_quantity)
#                        FROM t_shirts
#                        WHERE color = 'Black';""",
#         "SQLResult": "Result of the SQL Query",
#         "Answer": "22275"
#     },
#     {
#         "Question": "How many Van Huesen t-shirts do we have in medium size?",
#         "SQLQuery": """SELECT SUM(stock_quantity)
#                        FROM t_shirts
#                        WHERE brand = 'Van Huesen' AND size = 'M';""",
#         "SQLResult": "Result of the SQL Query",
#         "Answer": "89"
#     },
#     {
#         "Question": "What is the total revenue if we sell all Adidas t-shirts with discounts applied?",
#         "SQLQuery": """SELECT SUM(a.Total_Amount * ((100 - COALESCE(discounts.pct_discount, 0)) / 100)) AS Total_Revenue
#                        FROM (SELECT SUM(price * stock_quantity) AS Total_Amount, t_shirt_id 
#                              FROM t_shirts 
#                              WHERE brand = 'Adidas'
#                              GROUP BY t_shirt_id) a
#                        LEFT JOIN discounts 
#                        ON a.t_shirt_id = discounts.t_shirt_id;""",
#         "SQLResult": "Result of the SQL Query",
#         "Answer": "19887.250000"
#     }
# ]

In [35]:
# ================================================================
# --- Load Few-Shot Examples from JSON File ---
# ================================================================

import json

few_shots_file_path = 'few_shots.json'

try:
    with open(few_shots_file_path, 'r') as f:
        few_shot_examples = json.load(f)
    print(f"\n--- Loaded {len(few_shot_examples)} few-shot examples from '{few_shots_file_path}' ---")
except FileNotFoundError:
    print(f"\nError: '{few_shots_file_path}' not found. Please create the file with your few-shot examples.")
    few_shot_examples = [] # Initialize as empty to avoid errors
except json.JSONDecodeError as e:
    print(f"\nError decoding JSON from '{few_shots_file_path}': {e}")
    print("Please ensure your few_shots.json file is valid JSON.")
    few_shot_examples = []


--- Loaded 9 few-shot examples from 'few_shots.json' ---


In [37]:
from langchain_huggingface import HuggingFaceEmbeddings
import numpy as np

print("--- Initializing HuggingFace Embedding Model ---")
try:
    # Using the updated HuggingFaceEmbeddings from langchain_huggingface
    embedding_model_name = 'sentence-transformers/all-MiniLM-L6-v2'
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    print(f"HuggingFace Embedding Model '{embedding_model_name}' loaded successfully.")
except Exception as e:
    print(f"Error loading HuggingFace embedding model: {e}")
    print("Please ensure 'langchain-huggingface' is installed: pip install -U langchain-huggingface")
    # It's good practice to exit or handle this error gracefully if embeddings are critical
    exit() 

--- Initializing HuggingFace Embedding Model ---
HuggingFace Embedding Model 'sentence-transformers/all-MiniLM-L6-v2' loaded successfully.


In [39]:
from langchain.vectorstores import Chroma # Import the Chroma vector store class.

# Prepare the text data for vectorization.
# It concatenates all string values from each dictionary in few_shot_examples into a single string for each example. This combined string will be embedded.
to_vectorize = ["".join(example.values()) for example in few_shot_examples]

# Create a Chroma vector store from the prepared text data.
# - `to_vectorize`: The list of strings to be embedded and stored.
# - `embedding`: The embedding model (HuggingFaceEmbeddings instance) used to convert text to vectors.
# - `metadatas`: A list of dictionaries (your original few_shot_examples) that will be associated with each embedded text. This allows you to retrieve the original question, SQL query, etc., when a similar query is made.
vectorstore = Chroma.from_texts(to_vectorize, embedding=embeddings, metadatas=few_shot_examples)

In [41]:
from langchain.prompts import SemanticSimilarityExampleSelector

# Initialize the example selector.
# - `vectorstore`: The Chroma vector store containing your embedded few-shot examples.
# - `k`: The number of most similar examples to retrieve.
example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore, 
    k=2
)

In [43]:
# Select examples based on a new question.
# It finds the 'k' most semantically similar examples from the vectorstore.
example_selector.select_examples({"Question":"How many VanHeusen tshirts I have in my store?"})

[{'SQLResult': ' Result of the SQL Query - ',
  'Question': 'How many tshirts do we have left for Nike in extra small size and red colour?',
  'Answer': '83',
  'SQLQuery': "SELECT SUM(stock_quantity) FROM t_shirts WHERE brand = 'Nike' AND size = 'XS' AND color = 'Red';"},
 {'SQLResult': ' Result of the SQL Query - ',
  'Answer': '83',
  'SQLQuery': "SELECT SUM(stock_quantity) FROM t_shirts WHERE brand = 'Nike' AND size = 'XS' AND color = 'Red';",
  'Question': 'How many tshirts do we have left for Nike in extra small size and red colour?'}]

## Adding the prompts that Langchain already provides.

In [44]:
# Import predefined prompt parts for SQL chains.
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt 

In [45]:
print(_mysql_prompt, PROMPT_SUFFIX)

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of

In [46]:
from langchain.prompts.prompt import PromptTemplate # Import the PromptTemplate class.

# Define a template for formatting each individual few-shot example.
# It specifies how a single (Question, SQLQuery, SQLResult, Answer) tuple
# should be presented within the overall prompt to the LLM.
example_prompt = PromptTemplate(
    input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
    template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}"
)

In [47]:
from langchain.prompts import FewShotPromptTemplate # Import the FewShotPromptTemplate class.

# Create a FewShotPromptTemplate instance.
few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector, # Uses the semantic similarity selector to pick examples.
    example_prompt=example_prompt,     # Defines how each selected example is formatted.
    prefix=_mysql_prompt,                     # `_mysql_prompt` typically sets up the context for MySQL SQL generation.
    suffix=PROMPT_SUFFIX,                    # `PROMPT_SUFFIX` usually contains the placeholder for the user's new question.
    input_variables=["input", "table_info", "top_k"], # Variables expected by the final prompt.
    # `input`: The user's new natural language question.
    # `table_info`: The database schema information.
    # `top_k`: (Optional) A parameter for top-k results in SQL.
    # template = "" # Not needed if using prefix/suffix and example_selector
)


In [48]:
from langchain_google_genai import ChatGoogleGenerativeAI # For LangChain's Gemini LLM wrapper

llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0, google_api_key=GOOGLE_API_KEY)

In [49]:
from langchain_experimental.sql import SQLDatabaseChain

# Create the chain using your LLM, database connection, and the few-shot prompt.
# Setting verbose=True is great for debugging, as it shows you the intermediate steps.
sql_chain = SQLDatabaseChain.from_llm(
    llm=llm,
    db=db,
    verbose=True,
    prompt=few_shot_prompt
)

In [50]:
# Test Query 1: Original example
q1 = "How many tshirts do we have left for Nike in extra small size and white colour?"
print(f"\nQuestion 1: {q1}")
try:
    response1 = sql_chain.invoke(q1)
    print(f"Final Answer for '{q1}': {response1}")
except Exception as e:
    print(f"Error running chain for '{q1}': {e}")

print("\n---------------------------------------------------\n")


Question 1: How many tshirts do we have left for Nike in extra small size and white colour?


[1m> Entering new SQLDatabaseChain chain...[0m
How many tshirts do we have left for Nike in extra small size and white colour?
SQLQuery:

Retrying langchain_google_genai.chat_models._chat_with_retry.<locals>._chat_with_retry in 2.0 seconds as it raised ResourceExhausted: 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerDayPerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "model"
    value: "gemini-1.5-flash"
  }
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_value: 50
}
, links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, retry_delay {
  seconds: 28
}
].


Error running chain for 'How many tshirts do we have left for Nike in extra small size and white colour?': 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerDayPerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "model"
    value: "gemini-1.5-flash"
  }
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_value: 50
}
, links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, retry_delay {
  seconds: 25
}
]

---------------------------------------------------

