In [5]:
! pip install --quiet langchain langchain_cohere langchain_community langchain-openai tiktoken langchainhub chromadb langgraph gpt4all

In [1]:
import os
from dotenv import load_dotenv

from langchain.schema import Document
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_cohere import CohereEmbeddings, ChatCohere
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field



In [2]:
# "Yes" selects the local stack (GPT4All embeddings + Ollama chat model);
# any other value selects the hosted Cohere embeddings / chat model.
run_local = "No"
# Model used for the SQL generation step only: 'openai' -> ChatOpenAI, anything else -> ChatCohere.
generator_model = "cohere" # cohere | openai
# Load environment variables from a .env file (presumably the provider API keys - confirm).
# Kept as the last expression so the cell displays whether a .env file was loaded.
load_dotenv()


True

In [3]:
# Model tag handed to ChatOllama whenever run_local == "Yes".
# The model has to be available locally first, e.g.:
# ! ollama pull gemma:2b
local_llm = "gemma:2b"

### Get embeddings from SQL metadata

Using a document composed of SQL `CREATE` statements and their comments, following the approach in
https://arxiv.org/pdf/2204.00498.pdf


In [4]:
# Read the SQL seed script that describes the database schema.
loader = TextLoader("./seed_db.sql")
documents = loader.load()

# Cut the script into token-counted chunks; consecutive chunks overlap so a
# table definition that straddles a boundary is still seen in one piece.
splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500,
    chunk_overlap=200,
)
all_splits = splitter.split_documents(documents)

# Embedding backend: local GPT4All when run_local is "Yes", Cohere otherwise.
embedding = GPT4AllEmbeddings() if run_local == "Yes" else CohereEmbeddings()

# Embed every chunk and index it in a Chroma collection.
vectorstore = Chroma.from_documents(
    collection_name="db-metadata-embeddings",
    documents=all_splits,
    embedding=embedding,
)

### Define graph state

In [5]:
from typing_extensions import TypedDict
from typing import List

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user's original question
        optimised_question: question rewritten for vector search
        rewrite_needed: 'yes' when no retrieved document was graded relevant, otherwise 'no'
        generation: LLM generation
        documents: list of retrieved documents (the grading step reads
            `.page_content` on each item, so these are presumably Document
            objects rather than plain strings - TODO confirm the annotation)
    """
    question : str
    optimised_question: str
    rewrite_needed: str
    generation : str
    documents : List[str]

### Define graph nodes and edges

In [6]:
# Retrieval  (node)

def retrieve(state):
    """
    Fetch documents from the vector store for the current question.

    The rewritten question is used when an earlier grading pass flagged that a
    rewrite was needed; otherwise the user's original question is used.

    Args:
        state (dict): The current graph state

    Returns:
        dict: State update whose "documents" key holds the retriever results
    """
    print("---RETRIEVING---")
    retriever = vectorstore.as_retriever()

    # "rewrite_needed" is only written by the grading step, so it can be
    # missing on the very first pass: read it defensively instead of indexing.
    if state.get("rewrite_needed") == "yes":
        question = state["optimised_question"]
    else:
        question = state["question"]

    documents = retriever.invoke(question)
    return {"documents": documents}

In [7]:
# Retrieval Grader (node)

# Pydantic object for structured output
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    # NOTE(review): this class is passed to llm.with_structured_output, so the
    # docstring above and the description below are presumably forwarded to the
    # model as schema text - treat them as prompt content, not just comments.
    # grade_documents compares this value against "yes".
    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")

# Instructions given to the relevance grader.
preamble = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""

# Grader model: local Ollama model when run_local is "Yes", Cohere otherwise.
llm = (
    ChatOllama(model=local_llm, temperature=0)
    if run_local == "Yes"
    else ChatCohere(model="command-r", temperature=0)
)

# Constrain the model's answer to the GradeDocuments schema.
structured_llm_grader = llm.with_structured_output(GradeDocuments, preamble=preamble)

# Single human turn carrying the document text and the user question.
grade_prompt = ChatPromptTemplate.from_messages(
    [("human", "Retrieved document: \n\n {document} \n\n User question: {question}")]
)

def grade_documents(state):
    """
    Keep only the retrieved documents that the grader marks as relevant.

    Args:
        state (dict): The current graph state; reads "question" and "documents"

    Returns:
        dict: State update with the filtered "documents" and "rewrite_needed",
            which is "yes" when no document survived grading and "no" otherwise
    """
    retrieval_grader = grade_prompt | structured_llm_grader

    print("---CHECKING DOCUMENT RELEVANCE---")
    question = state["question"]
    documents = state["documents"]

    # Filter for relevant docs
    filtered_docs = []
    for i, d in enumerate(documents):
        result = retrieval_grader.invoke({"question": question, "document": d.page_content})
        # Normalise the score: models may answer "Yes" / "yes " or yield no
        # structured result at all; anything that is not "yes" counts as not relevant.
        grade = str(getattr(result, "binary_score", "")).strip().lower()
        if grade == "yes":
            print(f"------grader: document {i+1} relevant---")
            filtered_docs.append(d)
        else:
            print(f"------grader: document {i+1} not relevant---")

    # With nothing relevant left, the question has to be rewritten before retrying.
    rewrite_needed = "no" if filtered_docs else "yes"
    return {"documents": filtered_docs, "rewrite_needed": rewrite_needed}

  warn_beta(


In [8]:
# Rewrite question to be more suitable for similarity search (node)

# Instructions for the question re-writer.
preamble = """You are an expert at re-writing questions so that they are optimized \n 
     for vectorstore retrieval. Take the original question and create an improved question. \n
"""

# Human turn: the original question followed by a cue for the rewritten one.
rewrite_human_template = "Original question: {question} \nNew question: "

if run_local == "Yes":
    llm = ChatOllama(model=local_llm, temperature=0)
    # The local model gets no bound preamble, so the instructions travel as a
    # system message; otherwise it would only see the bare question.
    rewrite_messages = [("system", preamble), ("human", rewrite_human_template)]
else:
    llm = ChatCohere(model="command-r", temperature=0).bind(preamble=preamble)
    rewrite_messages = [("human", rewrite_human_template)]

# Prompt
rewrite_prompt = ChatPromptTemplate.from_messages(rewrite_messages)

# question -> prompt -> model -> plain string
question_rewriter = rewrite_prompt | llm | StrOutputParser()


def rewrite_question(state):
    """
    Ask the rewriter chain for a retrieval-friendly version of the question.

    Args:
        state (dict): The current graph state; reads "question"

    Returns:
        dict: State update with the rewritten text under "optimised_question"
    """
    print("---REWRITING QUESTION---")
    rewritten = question_rewriter.invoke({"question": state["question"]})
    return {"optimised_question": rewritten}

In [9]:
# Decide whether to generate or rewrite question (conditional edge)

def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Only the "rewrite_needed" flag set by the grading step is consulted.

    Args:
        state (dict): The current graph state

    Returns:
        str: "rewrite" when the flag is "yes", otherwise "generate"
    """

    print("---ASSESS GRADED DOCUMENTS---")

    if state["rewrite_needed"] == "yes":
        print("------decision: No documents relevant to question, rewriting---")
        return "rewrite"

    print("------decision: Generate---")
    return "generate"

In [10]:
# Generate! (node)

# Model for the SQL generation step, chosen independently of run_local.
if generator_model == 'openai':
    llm = ChatOpenAI(model="gpt-4", temperature=0)
else:
    # Use the `model` keyword like every other ChatCohere call in this notebook;
    # NOTE(review): `model_name` does not appear to be a ChatCohere field, so the
    # requested model was likely being ignored - confirm against the library docs.
    llm = ChatCohere(model="command-r", temperature=0)


# Prompt for the SQL generator: retrieved schema context plus the user question.
prompt = PromptTemplate(
    template="""You are a postgres expert. 
    Use the following pieces of retrieved context (relevant DB tables) to generate a syntactically correct SQL query that answers the user's question.
    When creating location-based queries, use tables containing GIS objects as this permits highly optimised geo-spatial querying.
    Only return the SQL query, with no explanation or preamble.
    If you can't formulate a valid query, just say that you don't know. \n 
    Here is the retrieved context: \n\n {documents} \n\n
    Here is the user question: {question} \n""",
    input_variables=["question", "documents"],
)

# prompt -> model -> plain string
generation_chain = prompt | llm | StrOutputParser()

def generate(state):
    """
    Generate a SQL query from the graded documents and the user's question.

    Args:
        state (dict): The current graph state; reads "question" and "documents"

    Returns:
        dict: State update with "documents" (as a list), "question" and the
            model output under "generation"
    """
    print("---GENERATING---")
    question = state["question"]
    documents = state["documents"]
    if not isinstance(documents, list):
        documents = [documents]

    # Give the prompt the documents' text rather than the repr of the list,
    # which would escape newlines and mix metadata into the schema context.
    context = "\n\n".join(getattr(d, "page_content", str(d)) for d in documents)

    generation = generation_chain.invoke({"documents": context, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

In [11]:
# Eval (conditional edge)

# Model shared by the three evaluators below: a local Ollama model in JSON
# mode when run_local is "Yes", Cohere otherwise.
llm = (
    ChatOllama(model=local_llm, format="json", temperature=0)
    if run_local == "Yes"
    else ChatCohere(model="command-r", temperature=0)
)

class ValidityEval(BaseModel):
    """Binary score for validity of generated SQL query."""

    # NOTE(review): this class is passed to llm.with_structured_output, so the
    # docstring and the description are presumably sent to the model as schema text.
    # evaluate() compares this value against "yes".
    binary_score: str = Field(description="Generated query is valid SQL, 'yes' or 'no'")

# Instructions for the SQL-validity grader. The JSON key named in the last
# sentence has to match the ValidityEval field ("binary_score") that the
# structured output is parsed into.
preamble = """You are a grader assessing the validity of a SQL query generated from text. If the query parses
as valid SQL, give a binary score of 'yes', otherwise give a binary score 'no'. 
Provide the binary score as JSON with a single key 'binary_score' and no explanation."""
structured_evaluator = llm.with_structured_output(ValidityEval, preamble=preamble)

# Single human turn carrying only the generated query.
validity_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "Generated SQL query: \n\n {generation}"),
    ]
)

validity_evaluator = validity_prompt | structured_evaluator
class AnswerEval(BaseModel):
    """Binary score to assess whether generated SQL query answers the user's question."""

    # NOTE(review): this class is passed to llm.with_structured_output, so the
    # docstring and the description are presumably sent to the model as schema text.
    # evaluate() compares this value against "yes".
    binary_score: str = Field(description="Generated query addresses the question, 'yes' or 'no'")

# Instructions for the "does the query answer the question" grader. The JSON
# key named in the last sentence has to match the AnswerEval field
# ("binary_score") that the structured output is parsed into.
preamble = """You are a SQL expert that can determine whether a SQL query addresses / resolves a question \n
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question. \n
Provide the binary score as JSON with a single key 'binary_score' and no explanation."""
structured_evaluator = llm.with_structured_output(AnswerEval, preamble=preamble)

# Human turn carrying both the user question and the generated query.
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "User question: \n\n {question} \n\n Generated SQL query: \n\n {generation}"),
    ]
)
answer_evaluator = answer_prompt | structured_evaluator

class IdempEval(BaseModel):
    """Binary score for whether SQL query is idempotent."""

    # NOTE(review): "idempotent" is used loosely here - the grader preamble that
    # accompanies this schema asks whether the query leaves the stored data
    # unchanged (no insert / update / delete / drop), i.e. whether it is read-only.
    # The class is passed to llm.with_structured_output, so the docstring and the
    # description are presumably sent to the model as schema text.
    binary_score: str = Field(description="Generated query is idempotent, 'yes' or 'no'")

# Instructions for the "query does not modify the database" grader. The JSON
# key named in the last sentence has to match the IdempEval field
# ("binary_score") that the structured output is parsed into.
preamble = """You are a grader assessing whether a generated SQL query permanently \n
alters the database. (i.e. whether it deletes, inserts, or updates entries in database tables, or deletes tables). \n
Permit JOINS and string functions that only temporarily alter the data returned by the query). \n
If the query does indeed leave the data unchanged, give a binary score of 'yes', otherwise give a binary score 'no'. 
Provide the binary score as JSON with a single key 'binary_score' and no explanation."""

structured_evaluator = llm.with_structured_output(IdempEval, preamble=preamble)

# Single human turn carrying only the generated query.
idemp_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "Generated SQL query: \n\n {generation}"),
    ]
)
idemp_evaluator = idemp_prompt | structured_evaluator

def evaluate(state):
    """
    Determines whether generation answers question, is valid SQL, and is idempotent.

    The three graders run in a fixed order (validity, answers the question,
    leaves data unchanged) and evaluation stops at the first one that does not
    answer "yes"; the rejected query is printed in that case.

    Args:
        state (dict): The current graph state; reads "question" and "generation"

    Returns:
        str: "success" when every check passes, otherwise "fail"
    """

    print("---EVALUATING GENERATION---")
    question = state["question"]
    generation = state["generation"]

    # (evaluator, evaluator input, message on pass, message on failure)
    checks = (
        (
            validity_evaluator,
            {"generation": generation},
            "query is valid sql",
            "query is not valid sql",
        ),
        (
            answer_evaluator,
            {"question": question, "generation": generation},
            "query answers user's question",
            "query does not answer user's question",
        ),
        (
            idemp_evaluator,
            {"generation": generation},
            "query is idempotent",
            "query is not idempotent",
        ),
    )

    for evaluator, payload, passed_msg, failed_msg in checks:
        result = evaluator.invoke(payload)
        # Normalise the score: tolerate "Yes" / "yes " and treat a missing
        # structured result as a failed check instead of raising.
        score = str(getattr(result, "binary_score", "")).strip().lower()
        if score != "yes":
            print(f"------decision: {failed_msg}---")
            print(generation)
            return "fail"
        print(f"------decision: {passed_msg}---")

    return "success"

In [12]:
from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Register each node under the name the edges below refer to.
for node_name, node_fn in (
    ("retrieve", retrieve),
    ("grade_documents", grade_documents),
    ("rewrite_question", rewrite_question),
    ("generate", generate),
):
    workflow.add_node(node_name, node_fn)

# Wiring: retrieve -> grade -> (rewrite -> retrieve | generate) -> (generate | END)
# NOTE(review): neither loop (rewrite -> retrieve, failed eval -> generate) has
# an explicit retry cap in this notebook.
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {"rewrite": "rewrite_question", "generate": "generate"},
)
workflow.add_edge("rewrite_question", "retrieve")
workflow.add_conditional_edges(
    "generate",
    evaluate,
    {"fail": "generate", "success": END},  # failed eval => re-generate
)

# Turn the graph definition into a runnable app.
app = workflow.compile()

In [15]:
# Run the graph on a simple lookup question.
inputs = {"question": "Get tail numbers of all aircraft with model '787'"}
for output in app.stream(inputs):
    # Emit one blank separator line per node update in this step.
    for _node in output:
        print("")

# The last streamed step comes from the "generate" node.
print(output['generate']['generation'])

---RETRIEVING---

---CHECKING DOCUMENT RELEVANCE---
------grader: document 1 not relevant---
------grader: document 2 relevant---
------grader: document 3 not relevant---
------grader: document 4 not relevant---
---ASSESS GRADED DOCUMENTS---
------decision: Generate---

---GENERATING---
---EVALUATING GENERATION---
------decision: query is valid sql---
------decision: query answers user's question---
------decision: query is idempotent---

```sql
SELECT tail_number
FROM aircraft_assets
WHERE model = '787';
```


In [21]:
# Run the graph on a question that needs joins and geo-spatial filtering.
inputs = {"question": "Calculate how much time aircraft AB123 spent flying over the region of interest 'Slovakia'"}
generation = None
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"(Node: '{key}')")
        # Track the most recent generation explicitly instead of relying on
        # whatever the loop variable happens to hold after the loops finish.
        if isinstance(value, dict) and "generation" in value:
            generation = value["generation"]

print(generation)

---RETRIEVING---
(Node: 'retrieve')
---CHECKING DOCUMENT RELEVANCE---
------grader: document 1 relevant---
------grader: document 2 relevant---
------grader: document 3 not relevant---
------grader: document 4 not relevant---
---ASSESS GRADED DOCUMENTS---
------decision: Generate---
(Node: 'grade_documents')
---GENERATING---
---EVALUATING GENERATION---
------decision: query is valid sql---
------decision: query answers user's question---
------decision: query is idempotent---
(Node: 'generate')
```sql
SELECT SUM(duration_hours) AS total_hours
FROM aircraft_historical_flights
WHERE
    tail_number = 'AB123'
    AND EXISTS (
        SELECT 1
        FROM aircraft_position_gis
        WHERE
            aircraft_position_gis.flight_id = aircraft_historical_flights.flight_id
            AND ST_Within(aircraft_position_gis.position, (
                SELECT region
                FROM regions_of_interest
                WHERE LOWER(name) = 'slovakia'
            ))
    );
```


### Very basic benchmarking with LangSmith
ProjectID: abe556aa-fe68-4b03-a365-7791b5c9f263

**Simple query:**

| Generation Model | Speed | Cost ($) | Quality of Query |
| ----------- | ----------- | --------- | --------- |
| Cohere cmdR | 1s (6.6s total) | 0.00083 | perfect |
| GPT4 | 1.3s (6.9s total) | 0.014 | perfect |


**Complex query:**

| Generation Model | Speed | Cost ($) | Quality of Query |
| ----------- | ----------- | --------- | --------- |
| Cohere cmdR | 3.5s (8.9s total) | 0.0013 | missed details in time calculation |
| GPT4 | 4.6s (10.2s total) | 0.031 | perfect |

