In [1]:
from langchain_community.llms import Ollama

In [2]:
# Local Ollama-served model used to write the SQL queries below.
llm = Ollama(model="llama3.1")

In [3]:
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain

In [4]:
import os

# Read the connection URI from the DATABASE_URL environment variable so that
# connection details do not have to be committed with the notebook. The
# original local-development URI is kept as the fallback, so behaviour is
# unchanged when the variable is not set.
db = SQLDatabase.from_uri(
    os.environ.get("DATABASE_URL", "postgresql://user-name:@localhost:5432/postgres")
)

In [5]:
# Sanity check: show the table names the SQLDatabase wrapper reports as usable.
db.get_usable_table_names()

['departments_table', 'employees_table']

In [20]:
# Hand-written summary of the schema; it is prepended to the user's question
# and sent to the chain as a single "question" string.
database_description = (
    "The database consists of two tables: `public.employees_table` and `public.departments_table`. This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
    "The `public.employees_table` table records details about the employees in a company. It includes the following columns:\n"
    "- `EmployeeID`: A unique identifier for each employee.\n"
    "- `FirstName`: The first name of the employee.\n"
    "- `LastName`: The last name of the employee.\n"
    "- `DepartmentID`: A foreign key that links the employee to a department in the `public.departments_table` table.\n"
    "- `Salary`: The salary of the employee.\n\n"
    "The `public.departments_table` table contains information about the various departments in the company. It includes:\n"
    "- `DepartmentID`: A unique identifier for each department.\n"
    "- `DepartmentName`: The name of the department.\n"
    "- `Location`: The location of the department.\n\n"
    "The `DepartmentID` column in the `public.employees_table` table establishes a relationship between the employees and their respective departments in the `public.departments_table` table. This foreign key relationship allows us to join these tables to retrieve detailed information about employees and their departments."
)

# The leading space separates the question from the end of the description.
question_suffix = " How many employees have salary above 70k?"

# Build the query-writing chain from the model and database handles defined
# in earlier cells, then ask it the question.
chain = create_sql_query_chain(llm=llm, db=db)
chain_input = {"question": database_description + question_suffix}
response = chain.invoke(chain_input)

In [21]:
response

'Question: The database consists of two tables: `public.employees_table` and `public.departments_table`. This is a PostgreSQL database, so you need to use postgres-related queries.\n\nThe `public.employees_table` table records details about the employees in a company. It includes the following columns:\n- `EmployeeID`: A unique identifier for each employee.\n- `FirstName`: The first name of the employee.\n- `LastName`: The last name of the employee.\n- `DepartmentID`: A foreign key that links the employee to a department in the `public.departments_table` table.\n- `Salary`: The salary of the employee.\n\nThe `public.departments_table` table contains information about the various departments in the company. It includes:\n- `DepartmentID`: A unique identifier for each department.\n- `DepartmentName`: The name of the department.\n- `Location`: The location of the department.\n\nThe `DepartmentID` column in the `public.employees_table` table establishes a relationship between the employees

In [15]:
import re
def extract_sql_query(response):
    """Pull the SQL statement that follows the first ``SQLQuery:`` marker.

    The query may span several lines. Extraction stops at the first of:
    a ``SQLResult:`` or ``Answer:`` marker, a blank line, or the end of the
    text. A Markdown code fence (for example a ```sql fence) wrapped around
    the query is removed.

    Args:
        response: Raw text produced by the model.

    Returns:
        The query text with surrounding whitespace removed, or ``None`` when
        ``response`` is not a string, contains no ``SQLQuery:`` marker, or
        has nothing but whitespace after the marker.
    """
    # Guard against None / non-text input instead of raising TypeError.
    if not isinstance(response, str):
        return None

    # DOTALL lets the lazy capture cross newlines (multi-line queries); the
    # lookahead ends it before trailing sections the model may append.
    pattern = re.compile(
        r'SQLQuery:\s*(.*?)(?=\s*(?:SQLResult:|Answer:)|\n\s*\n|\Z)',
        re.DOTALL,
    )
    match = pattern.search(response)
    if match is None:
        return None

    query = match.group(1).strip()

    # Drop an opening fence (with an optional SQL language tag) and a closing
    # fence; each is handled separately so an unterminated fence is cleaned too.
    query = re.sub(
        r'^```(?:(?:postgresql|postgres|sqlite|mysql|plsql|tsql|psql|sql)\b)?\s*',
        '',
        query,
        flags=re.IGNORECASE,
    )
    query = re.sub(r'\s*```$', '', query).strip()

    # An empty capture is not a runnable query; report it the same way as a
    # missing marker.
    return query or None

# Isolate the generated SQL from the chain's raw text output (None if no match).
sql_query = extract_sql_query(response)

In [16]:
sql_query

'SELECT COUNT(*) FROM public.employees_table WHERE "Salary" > 70000 LIMIT 5'

In [17]:
# `sql_query` is None when no "SQLQuery:" marker was found in the model
# output; fail with a clear message instead of handing that to the database.
if not sql_query:
    raise ValueError("No SQL query could be extracted from the model response.")

# NOTE(review): this executes model-generated SQL verbatim. Presumably the
# connection should use a read-only role; confirm before using real data.
result = db.run(sql_query)
print(result)

[(1,)]
