In [3]:
from langchain_community.llms import Ollama

llm__ollama = Ollama(
    model="llama3.2"
)  # assuming you have Ollama installed and have llama3 model pulled with `ollama pull llama3 `

llm__ollama.invoke("generate a sql query which can count the total rows") 

  llm__ollama = Ollama(


"Here is an example of SQL queries that can be used to count the total number of rows in a table:\n\n**Method 1: Using COUNT() function**\n\n```sql\nSELECT COUNT(*) AS total_rows FROM table_name;\n```\n\nIn this method, `COUNT(*)` returns the total number of rows in the specified table. The `AS` keyword is used to give an alias to the column.\n\n**Method 2: Using GROUP BY clause**\n\n```sql\nSELECT COUNT(*) AS total_rows FROM table_name GROUP BY null;\n```\n\nThis method uses the `GROUP BY` clause with a grouping column that doesn't affect the count, i.e., `null`. This is a common technique when you don't want to specify any conditions.\n\n**Method 3: Using LIMIT**\n\n```sql\nSELECT COUNT(*) AS total_rows FROM table_name LIMIT 0;\n```\n\nIn this method, we use `LIMIT 0` to retrieve no rows from the table. The `COUNT(*)` function still returns the total number of rows.\n\nNote that these methods assume you are using a SQL dialect that supports these features."

In [4]:
from langchain_community.utilities import SQLDatabase
import urllib.parse
# from sqlalchemy import create_engine
quoted_password = urllib.parse.quote("Prachi@28")

snowflake_db = SQLDatabase.from_uri('snowflake://{user}:{password}@{account_identifier}/SNOWFLAKE_SAMPLE_DATA/TPCDS_SF100TCL'.format(
        user='PRACHIBH',
        password=quoted_password,
        account_identifier='dmhwtcd-nt69450',
    ))

# print(db.dialect)
# print(db.get_usable_table_names())
# db.run("SELECT * FROM Artist LIMIT 10;")

# construct sqlalchemy engine from uri 
# error : full path need to be mentioned => so, full path as in till the db with .db included ?  

In [5]:
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS

from typing import Any, Dict, List, Optional, Union, TypedDict, TYPE_CHECKING

if TYPE_CHECKING:
    from langchain_community.utilities.sql_database import SQLDatabase

def _strip(text: str) -> str:
    print("insisde _strip")
    return text.strip()

class SQLInput(TypedDict):
    """Input for a SQL Chain."""
    question: str

class SQLInputWithTables(TypedDict):
    """Input for a SQL Chain."""
    question: str
    table_names_to_use: List[str]

def create_sql_query_chain(
    llm: BaseLanguageModel,
    db: SQLDatabase,
    prompt: Optional[BasePromptTemplate] = None,
    k: int = 1,
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:

    if prompt is not None:
        prompt_to_use = prompt
    elif db.dialect in SQL_PROMPTS:
        prompt_to_use = SQL_PROMPTS[db.dialect]
    else:
        prompt_to_use = PROMPT

    if {"input", "top_k", "table_info"}.difference(
        prompt_to_use.input_variables + list(prompt_to_use.partial_variables)
    ):
        raise ValueError(
            f"Prompt must have input variables: 'input', 'top_k', "
            f"'table_info'. Received prompt with input variables: "
            f"{prompt_to_use.input_variables}. Full prompt:\n\n{prompt_to_use}"
        )
    if "dialect" in prompt_to_use.input_variables:
        prompt_to_use = prompt_to_use.partial(dialect=db.dialect)

    inputs = {
        "input": lambda x: x["question"] + "\nSQLQuery: ",
        "table_info": lambda x: db.get_table_info(
            table_names=x.get("table_names_to_use")
        ),
    }

    return (
        RunnablePassthrough.assign(**inputs)  # type: ignore
        | prompt_to_use.partial(top_k=str(k))
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | _strip
    )

# Example usage:
# Make sure to replace `llm__ollama` and `db` with your actual language model and database objects.



In [6]:
chain = create_sql_query_chain(llm__ollama,snowflake_db)
response = chain.invoke({
    "question" : "extract male count from CUSTOMER_DEMOGRAPHICS where CD_CREDIT_RATING are good",
    "table_names_to_use": ['customer_demographics']
})
print(response)

insisde _strip
Question: extract male count from CUSTOMER_DEMOGRAPHICS where CD_CREDIT_RATING are good
SQLQuery: 
SELECT cd_dep_employed_count 
FROM customer_demographics 
WHERE cd_gender = 'M' AND cd_credit_rating IN ('Good', 'Excellent')
LIMIT 1;


In [27]:
import re

# Define the regex pattern
pattern = r"SQLQuery:\s*((?:.|\n)*?)(;|$)"

# Search for the pattern in the text
match = re.search(pattern, response)

# Extract the SQL query if the pattern is found
if match:
    sql_query = match.group(1)
else:
    print("SQLQuery not found") 

sql_query = sql_query + ';' 
print(f'sql_query : {sql_query}')

sql_query : SELECT cd_dep_employed_count 
FROM customer_demographics 
WHERE cd_gender = 'M' AND cd_credit_rating IN ('Good', 'Excellent')
LIMIT 1;


In [28]:
import re
from typing import Optional
from langchain_core.runnables import Runnable

class ExecuteDB(Runnable):
    def __init__(self, db: SQLDatabase, response : str):
        self.db = db
        self.response = response

    def invoke(self) -> Optional[str]:
        pattern = r"SQLQuery:\s*(.*)"
        match = re.search(pattern, self.response)
        if match:
            sql_query = match.group(1)
            return self.db.run("SELECT cd_dep_employed_count FROM customer_demographics WHERE cd_gender = 'M' AND cd_credit_rating = 'Good' ;")
        else:
            print("SQL not found")
            return None

def execurteDB(db: SQLDatabase, response: str) -> Runnable:
    return ExecuteDB(db, response)

write_query = create_sql_query_chain(llm__ollama,snowflake_db)
execute_query = execurteDB(snowflake_db, response)
chain_2 = write_query | execute_query  
response = chain_2.invoke({
    "question" : "extract male count from CUSTOMER_DEMOGRAPHICS where CD_CREDIT_RATING are good",
    "table_names_to_use": ['customer_demographics']
})


insisde _strip


TypeError: ExecuteDB.invoke() takes 1 positional argument but 3 were given