In [1]:
import chromadb
import dspy
from sqlalchemy import create_engine, text
import os
from dotenv import load_dotenv
from train_set import train_data

load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
lm = dspy.LM('groq/qwen-2.5-32b', api_key=os.getenv('GROQ_API_KEY'))
dspy.configure(lm=lm)

In [3]:
chroma_client = chromadb.PersistentClient(path="./chroma_db")
db_collection = chroma_client.get_collection(name="sql_schema")

In [4]:
def create_db_connection():
    engine = create_engine("sqlite:///electrical_parts.db", echo=True)
    return engine.connect()

def execute_sql(query: str):
    """Executes the SQL query in SQLite and fetches results."""
    try:
        # validate_sql_query(query)
        with create_db_connection() as conn:
            result = conn.execute(text(query)).fetchall()
        return result
    except Exception as e:
        return {"error": str(e), "valid": False}

In [5]:
class RetrieveSchema(dspy.Module):
    def forward(self, user_query: str):
        """Retrieves relevant schema details from ChromaDB."""
        table_results = db_collection.query(
            query_texts=[user_query], n_results=3, where={"type": "table"}
        )
        tables = [doc["table_name"] for doc in table_results.get("metadatas", [])[0]]

        column_results = db_collection.query(
            query_texts=[user_query],
            n_results=3,
            where={"$and": [{"type": {"$eq": "column"}}, {"table": {"$in": tables}}]},
        )
        columns = [
            (doc["table"], doc["columns"])
            for doc in column_results.get("metadatas", [])[0]
        ]

        relationship_results = db_collection.query(
            query_texts=[user_query], n_results=3, where={"type": "relationship"}
        )
        relationships = [
            (doc["table1"], doc["table2"], doc["relationship_type"])
            for doc in relationship_results.get("metadatas", [])[0]
        ]

        # pprint({"tables": tables, "columns": columns, "relationships": relationships})
        return {"tables": tables, "columns": columns, "relationships": relationships}


In [15]:
class GenerateSQL(dspy.Signature):
    """Generate appropriate response to the user's question about the database.
        Dont just give tabular data, instead give in a meaningful polite sentence format with proper result and not just [result]
    For SELECT queries: Generate SQL and answer.
    For other operations: Return standard warning message."""
    
    question: str = dspy.InputField()
    context: str = dspy.InputField()
    history: list[str] = dspy.InputField(default=[])
    sql_query: str = dspy.OutputField(desc="Empty if operation not allowed")
    answer: str = dspy.OutputField()

sql_query_generator = dspy.ReAct(GenerateSQL, tools=[execute_sql])

In [16]:
class SQLAgent(dspy.Module):
    def __init__(self):
        super().__init__()
        self.retrieve = RetrieveSchema()
        self.react = sql_query_generator
        
    def forward(self, question, context=None, history=None):
        history = history or []
        if context is None:
            context = self.retrieve(question)
            
        response = self.react(question=question, context=context, history=history)
        
        # ReAct returns the full trace, we need to extract the final prediction
        if hasattr(response, 'answer'):
            return dspy.Prediction(
                answer=response.answer,
                sql_query=getattr(response, 'sql_query', '')
            )
        return dspy.Prediction(answer="Error: No valid response generated", sql_query="")

In [17]:
def validate_prediction(example, pred, trace=None):
    try:
        # For non-SELECT operations
        if not example.sql_query:
            return pred.answer.startswith("Sorry, but you are not allowed")
        
        # For SELECT queries
        if not hasattr(pred, 'sql_query') or not pred.sql_query:
            return False
            
        return (example.sql_query.lower().strip() == pred.sql_query.lower().strip() and 
                example.answer.lower() in pred.answer.lower())
    except Exception as e:
        print(f"Validation error: {e}")
        return False

In [18]:
optimizer = dspy.BootstrapFewShot(
    metric=validate_prediction,
    max_bootstrapped_demos=8,
    max_labeled_demos=8,
    teacher_settings=dict(lm=dspy.LM('groq/qwen-2.5-32b', api_key=os.getenv('GROQ_API_KEY')))
)

agent = SQLAgent()

In [19]:
# Optimize
optimized_agent = optimizer.compile(
    agent, 
    trainset=train_data
)

 20%|██        | 1/5 [00:03<00:13,  3.25s/it]

2025-04-02 12:18:19,481 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2025-04-02 12:18:19,483 INFO sqlalchemy.engine.Engine SELECT product_id FROM products WHERE product_name = 'LED Light Bulb 10W';
2025-04-02 12:18:19,485 INFO sqlalchemy.engine.Engine [generated in 0.00376s] ()
2025-04-02 12:18:19,489 INFO sqlalchemy.engine.Engine ROLLBACK
2025-04-02 12:18:20,909 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2025-04-02 12:18:20,914 INFO sqlalchemy.engine.Engine PRAGMA table_info(products);
2025-04-02 12:18:20,916 INFO sqlalchemy.engine.Engine [generated in 0.00720s] ()
2025-04-02 12:18:20,920 INFO sqlalchemy.engine.Engine ROLLBACK
2025-04-02 12:18:22,379 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2025-04-02 12:18:22,380 INFO sqlalchemy.engine.Engine SELECT product_id FROM products WHERE name = 'LED Light Bulb 10W';
2025-04-02 12:18:22,382 INFO sqlalchemy.engine.Engine [generated in 0.00328s] ()
2025-04-02 12:18:22,385 INFO sqlalchemy.engine.Engine ROLLBACK


 80%|████████  | 4/5 [00:54<00:14, 14.36s/it]

2025-04-02 12:19:16,096 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2025-04-02 12:19:16,102 INFO sqlalchemy.engine.Engine DELETE FROM customers WHERE id NOT IN (SELECT customer_id FROM orders WHERE order_date >= DATE('now', '-6 months'));
2025-04-02 12:19:16,105 INFO sqlalchemy.engine.Engine [generated in 0.00879s] ()
2025-04-02 12:19:16,109 INFO sqlalchemy.engine.Engine ROLLBACK
2025-04-02 12:19:23,973 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2025-04-02 12:19:23,975 INFO sqlalchemy.engine.Engine DELETE FROM customers WHERE customer_id NOT IN (SELECT customer_id FROM orders WHERE order_date >= DATE('now', '-6 months'));
2025-04-02 12:19:23,976 INFO sqlalchemy.engine.Engine [generated in 0.00381s] ()
2025-04-02 12:19:23,983 INFO sqlalchemy.engine.Engine ROLLBACK


100%|██████████| 5/5 [01:40<00:00, 20.09s/it]

Bootstrapped 0 full traces after 4 examples for up to 1 rounds, amounting to 5 attempts.





In [23]:
response = optimized_agent("Purchase the LED Light Bulb for me of quantity 5 for me")
print(response)

2025-04-02 12:23:51,601 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2025-04-02 12:23:51,602 INFO sqlalchemy.engine.Engine SELECT product_id FROM products WHERE name = 'LED Light Bulb';
2025-04-02 12:23:51,602 INFO sqlalchemy.engine.Engine [generated in 0.00388s] ()
2025-04-02 12:23:51,602 INFO sqlalchemy.engine.Engine ROLLBACK
2025-04-02 12:23:53,098 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2025-04-02 12:23:53,098 INFO sqlalchemy.engine.Engine SELECT product_id FROM products WHERE name = 'LED Light Bulb';
2025-04-02 12:23:53,098 INFO sqlalchemy.engine.Engine [generated in 0.00392s] ()
2025-04-02 12:23:53,098 INFO sqlalchemy.engine.Engine ROLLBACK
2025-04-02 12:23:54,664 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2025-04-02 12:23:54,664 INFO sqlalchemy.engine.Engine INSERT INTO order_items (product_id, quantity) VALUES (1, 5);
2025-04-02 12:23:54,664 INFO sqlalchemy.engine.Engine [generated in 0.00300s] ()
2025-04-02 12:23:54,664 INFO sqlalchemy.engine.Engine ROLLBACK
Predi

In [12]:
only_agent = SQLAgent()
resp = only_agent("how many copper wires are there!")
resp.answer

2025-04-02 12:16:21,927 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2025-04-02 12:16:21,928 INFO sqlalchemy.engine.Engine SELECT COUNT(*) FROM products WHERE name = 'copper wire';
2025-04-02 12:16:21,929 INFO sqlalchemy.engine.Engine [generated in 0.00273s] ()
2025-04-02 12:16:21,932 INFO sqlalchemy.engine.Engine ROLLBACK


'There are no copper wires in the products table.'