In [2]:
from langchain_community.utilities import SQLDatabase
from langchain.sql_database import SQLDatabase
from langchain_groq import ChatGroq
import psycopg2
import os
from dotenv import load_dotenv

In [3]:
load_dotenv()

# Credentials are read from the environment (optionally populated from a .env
# file by load_dotenv) so that no secret is hardcoded in the notebook.
api_key = os.getenv('API_KEY')
database_url = os.getenv('DATABASE_URL')

# Fail fast with an actionable message: os.getenv returns None for a missing
# variable, and every later database cell (SQLDatabase.from_uri,
# psycopg2.connect) would otherwise fail with a less obvious error.
if not database_url:
    raise RuntimeError("DATABASE_URL is not set; add it to your environment or .env file")

In [4]:
# Setting the PostgreSQL URI using the loaded database URL.
# This constant-style name is what SQLDatabase.from_uri and psycopg2.connect use below.
POSTGRES_URI = database_url

In [5]:
# Initializing the SQLDatabase utility for LangChain with the PostgreSQL URI.
# NOTE(review): get_table_info() on this object returns sample rows from each
# usable table (the `users` table included), and write_query sends that text to
# the LLM provider. If those rows may hold personal data, consider passing
# `sample_rows_in_table_info=0` or `ignore_tables=[...]` here — confirm the
# impact on query quality first.
sql_db = SQLDatabase.from_uri(POSTGRES_URI)

In [6]:
# Confirm which SQL dialect the prompt will be told to target.
print(sql_db.dialect)

postgresql


In [7]:
# List the tables LangChain exposes from this database.
print(sql_db.get_usable_table_names())

['analysis', 'entities', 'locations', 'mine_sites', 'stockpile_analysis', 'surveys', 'users']


In [8]:
# Groq-hosted chat model that turns natural-language questions into SQL.
# temperature=0 asks for the least random output; max_tokens and timeout are
# explicitly left as None, and max_retries=2 allows two retries per call.
# NOTE(review): hosted model ids get retired over time — confirm
# "mixtral-8x7b-32768" is still served before re-running this notebook.
llm = ChatGroq(
    api_key=api_key,
    model="mixtral-8x7b-32768",
    max_retries=2,
    max_tokens=None,
    temperature=0,
    timeout=None,
)

In [9]:
from typing_extensions import TypedDict

# Application state for the question -> SQL -> rows -> answer pipeline.
# Built with the functional TypedDict form; instances are plain dicts.
State = TypedDict(
    "State",
    {
        "question": str,  # natural-language question from the user
        "query": str,     # SQL text produced by write_query
        "result": str,    # stringified rows returned by the database
        "answer": str,    # final text handed back to the caller
    },
)

In [10]:
from langchain import hub

# Fetch LangChain Hub's text-to-SQL prompt. Its template variables are
# {dialect}, {top_k}, {table_info} and {input} (see the printed output below).
query_prompt_template = hub.pull(
    "langchain-ai/sql-query-system-prompt"
)

# Sanity-check that the template holds exactly one message, then print it.
assert len(query_prompt_template.messages) == 1
query_prompt_template.messages[-1].pretty_print()




Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
{table_info}

Question: {input}


In [11]:
from typing_extensions import Annotated


# Structured-output schema passed to llm.with_structured_output() in write_query.
# Only `#` comments are used here on purpose: a class docstring would presumably
# become the schema/tool description sent to the model — TODO confirm before adding one.
class QueryOutput(TypedDict):
    # Annotated[type, default, description]: `...` presumably marks the field as
    # required, and the string is the field description shown to the LLM.
    query: Annotated[str, ..., "Syntactically valid SQL query."]


# Function to generate SQL queries using the LLM
def write_query(state: State, top_k: int = 10):
    """Ask the LLM for a SQL query that answers ``state["question"]``.

    The LangChain Hub prompt is filled with the database dialect, the schema
    text from ``sql_db.get_table_info()``, a row limit and the user's question.
    The model is then asked for a ``QueryOutput`` structure.

    Args:
        state: Mapping with at least a ``"question"`` key.
        top_k: Row limit the prompt tells the model to apply unless the
            question asks for a specific number of results (default 10).

    Returns:
        A partial state update: ``{"query": <generated SQL text>}``.

    Raises:
        ValueError: If the model returned no usable query.
    """
    prompt = query_prompt_template.invoke(
        {
            "dialect": sql_db.dialect,
            "top_k": top_k,
            # Re-read on every call; this text also contains sample rows,
            # which are sent to the LLM provider along with the question.
            "table_info": sql_db.get_table_info(),
            "input": state["question"],
        }
    )

    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    # The structured-output parser can yield None (or an empty query) when the
    # model does not produce the expected tool call. Fail with a clear message
    # instead of an opaque TypeError/KeyError on result["query"].
    if not result or not result.get("query"):
        raise ValueError(f"LLM did not return a SQL query for question: {state['question']!r}")
    return {"query": result["query"]}

In [41]:
# Inspect the schema text write_query passes to the prompt as {table_info}:
# CREATE TABLE statements followed by sample rows from each table.
# NOTE(review): this cell was executed out of order (In [41]); move it next to
# the other inspection cells so the notebook reads top to bottom.
sql_db.get_table_info()

'\nCREATE TABLE analysis (\n\tid INTEGER NOT NULL, \n\tsurvey_id VARCHAR(255), \n\tversion VARCHAR(50) NOT NULL, \n\tstatus VARCHAR(50), \n\tcreated_by VARCHAR(255), \n\tcreated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, \n\tupdated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP\n)\n\n/*\n3 rows from analysis table:\nid\tsurvey_id\tversion\tstatus\tcreated_by\tcreated_at\tupdated_at\n1\t1\t1\tCompleted\tAdmin\t2024-12-23 10:24:59.518234+05:30\t2024-12-23 10:24:59.518234+05:30\n2\t2\t1\tCompleted\tAdmin\t2024-12-23 10:24:59.518234+05:30\t2024-12-23 10:24:59.518234+05:30\n3\t3\t1\tCompleted\tAdmin\t2024-12-23 10:24:59.518234+05:30\t2024-12-23 10:24:59.518234+05:30\n*/\n\n\nCREATE TABLE entities (\n\tid INTEGER NOT NULL, \n\tname VARCHAR(255) NOT NULL, \n\tlocation_ids VARCHAR(255)[]\n)\n\n/*\n3 rows from entities table:\nid\tname\tlocation_ids\n1\tTata\tNone\n2\tAdani\tNone\n*/\n\n\nCREATE TABLE locations (\n\tid INTEGER NOT NULL, \n\tname VARCHAR(255) NOT NULL, \n\tad

In [12]:
# Open a raw psycopg2 connection to the same database LangChain inspects.
def connect_to_db():
    """Return a new psycopg2 connection to ``POSTGRES_URI``.

    The connection is returned open; the caller is responsible for closing it.
    """
    return psycopg2.connect(POSTGRES_URI)

In [13]:
# Executing the SQL query
def execute_sql_query(query: str, read_only: bool = True):
    """Run ``query`` on a fresh connection and return every result row.

    The SQL comes from an LLM prompted with free-form user input, so it is
    untrusted. By default the session is read-only, which makes the server
    reject any write instead of merely leaving it uncommitted.

    Args:
        query: SQL statement to execute. It must produce a result set,
            because ``cursor.fetchall()`` raises for statements that return
            no rows.
        read_only: Put the session in read-only mode (default ``True``).

    Returns:
        List of row tuples, as returned by ``cursor.fetchall()``.
    """
    conn = connect_to_db()
    try:
        if read_only:
            # Must be configured before the first statement opens a transaction.
            conn.set_session(readonly=True)
        with conn.cursor() as cursor:  # closes the cursor on exit
            cursor.execute(query)
            return cursor.fetchall()
    finally:
        # Always release the connection, even if execute/fetchall raises.
        # Nothing is committed, so closing discards the open transaction.
        conn.close()

In [14]:
# End-to-end helper: question -> generated SQL -> database rows -> answer text.
def query_agent(question: str):
    """Answer ``question`` by generating SQL with the LLM and executing it.

    Returns:
        The string ``"The result is: <rows>"``, where ``<rows>`` is the list
        of tuples fetched from the database.
    """
    state = State(question=question, query="", result="", answer="")

    # Ask the LLM for SQL, then run it against PostgreSQL.
    state["query"] = write_query(state)["query"]
    rows = execute_sql_query(state["query"])

    state["result"] = str(rows)
    state["answer"] = f"The result is: {rows}"
    return state["answer"]

In [15]:
# First example: look up a single mine by name.
user_input = "what is the location id of Sagasahi Iron Ore Mine"  # Example question

In [16]:
# Echo the example question so the transcript below is self-describing.
print("User Input:", user_input)

User Input: what is the location id of Sagasahi Iron Ore Mine


In [17]:
# Preview the SQL generated for the current question; query_agent below makes its own, separate LLM call.
write_query({"question": user_input})

{'query': "SELECT location_id FROM mine_sites WHERE name = 'Sagasahi Iron Ore Mine';"}

In [18]:
# Run the whole pipeline (generate SQL -> execute -> format) for the example question.
response = query_agent(user_input)

In [19]:
# Show the formatted answer: the raw row list, not a natural-language reply.
print("Agent Response:", response)

Agent Response: The result is: [(3,)]


In [20]:
# Second example: an aggregate (count) question.
user_input = "what are the total number of mines"  # Example question

In [21]:
# Echo the example question so the transcript below is self-describing.
print("User Input:", user_input)

User Input: what are the total number of mines


In [22]:
# Preview the SQL generated for the current question; query_agent below makes its own, separate LLM call.
write_query({"question": user_input})

{'query': 'SELECT COUNT(*) FROM mine_sites;'}

In [23]:
# Run the whole pipeline (generate SQL -> execute -> format) for the example question.
response = query_agent(user_input)

In [24]:
# Show the formatted answer: the raw row list, not a natural-language reply.
print("Agent Response:", response)

Agent Response: The result is: [(11,)]
