# LangGraph RAG

## Library imports

In [None]:
import os
import logging
import operator
from typing import List, Literal, Annotated, Optional, Union, Any
from typing_extensions import TypedDict
import chromadb
from chromadb.config import Settings
from pydantic import BaseModel, Field, validator

from langchain import PromptTemplate, LLMChain
from langchain import hub
from langchain_openai import ChatOpenAI
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough, RunnableSerializable
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_chroma import Chroma
from langchain_text_splitters import TokenTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader
from langchain_community.document_loaders import UnstructuredHTMLLoader
from langchain.schema import Document
from langchain_community.retrievers import TavilySearchAPIRetriever
from langchain_huggingface import HuggingFaceEmbeddings

from langgraph.graph import START, END, MessagesState, StateGraph
from langgraph.checkpoint.memory import MemorySaver, InMemorySaver
from langgraph.graph import START, END, MessagesState, StateGraph
from langgraph.checkpoint.memory import InMemorySaver

## Initialize config and API keys

In [None]:
#
# API keys
#
# load_dotenv() runs first so that the environment lookups below happen after
# it; every key is None when the corresponding variable is not set.
from dotenv import load_dotenv

load_dotenv()

deepseek_api_key, silicon_api_key, tavily_api_key, linkup_api_key = (
    os.getenv(env_name)
    for env_name in (
        "DEEPSEEK_API_KEY",
        "SILICON_API_KEY",
        "TAVILY_API_KEY",
        "LINKUP_API_KEY",
    )
)

In [None]:
#
# config
#
import tomllib
def load_config(config_file):
    try:
        with open(config_file, 'rb') as f:
            config = tomllib.load(f)
            return config
    except Exception as e:
        print(f"Load config file error: {e}")
        return None

# Load the config file.
#
# Every configuration name is initialised to None up front so that later cells
# can reference it even when the config file is missing or unreadable
# (load_config returns None in that case and the `if` body below is skipped).
deepseek_llm_model = None
deepseek_llm_temperature = None
deepseek_llm_max_tokens = None
silicon_base_url = None
silicon_llm_model = None
huggingface_embed_model = None

config_data = load_config("../config/config.toml")
if config_data:
    log_level = config_data.get('log_level')
    if log_level:
        logging.basicConfig(level=log_level)

    # deepseek
    deepseek_llm_model = config_data.get('deepseek', {}).get('model')
    deepseek_llm_temperature = config_data.get('deepseek', {}).get('temperature')
    deepseek_llm_max_tokens = config_data.get('deepseek', {}).get('max_tokens')

    # silicon
    silicon_base_url = config_data.get('silicon', {}).get('base_url')
    silicon_llm_model = config_data.get('silicon', {}).get('model')

    # huggingface
    huggingface_embed_model = config_data.get('huggingface', {}).get('embed_model')


# Fall back to built-in defaults for anything the config did not provide.
# deepseek
deepseek_llm_model = deepseek_llm_model or "deepseek-chat"

# silicon
silicon_base_url = silicon_base_url or "https://api.siliconflow.cn/v1"
silicon_llm_model = silicon_llm_model or "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

# huggingface
huggingface_embed_model = huggingface_embed_model or "sentence-transformers/all-MiniLM-L6-v2"

## LLM Init

In [None]:
# Initialise the DeepSeek chat model from the configuration values.
#
# The temperature falls back to 0.3 only when no value was configured. An
# explicit `is None` check is used (rather than `or`) so that a configured
# temperature of 0 is passed through instead of being replaced by the default.
llm_deepseek = ChatDeepSeek(
    model=deepseek_llm_model,
    temperature=0.3 if deepseek_llm_temperature is None else deepseek_llm_temperature,
    max_tokens=deepseek_llm_max_tokens,
    timeout=None,
    top_p=0.9,
    frequency_penalty=0.7,
    presence_penalty=0.5,
    max_retries=3
)

## RAG Graph

In [None]:
################################################################################
### RAG state
################################################################################
class RagState(TypedDict):
    """State dictionary read and updated by the nodes of the RAG graph."""
    # The user question; node_transform_query overwrites it with the rewritten
    # question.
    question: str
    # The answer text, written by node_generate / node_generate_fail.
    response: str
    # Working set of documents, written by the retrieve, rewrite, grading and
    # web-search nodes.
    # NOTE(review): annotated as List[str], but node_web_search puts a
    # Document object (not a str) into this field and the commented-out node
    # code reads `.page_content` — confirm the intended element type.
    documents: List[str]
    # Alternative kept for reference: accumulate documents across nodes.
    #documents: Annotated[list, operator.add]

In [None]:
################################################################################
### RAG nodes
################################################################################
def node_retrieve(state: RagState):
    """Retrieval node (placeholder).

    Reads ``state["question"]`` and returns a ``documents`` state update. The
    retriever call is still commented out, so the update is always an empty
    list.
    """
    _query = state["question"]  # input for the retriever once it is enabled

    # TODO: retrieved = retriever.invoke(_query)
    retrieved = []
    return {"documents": retrieved}


def node_retrieve_rewrite(state: RagState):
    """Document-rewrite node (placeholder).

    Walks over ``state["documents"]``; the per-document rewriter call is still
    commented out, so nothing is collected and the returned ``documents``
    update is always an empty list.
    """
    _query = state["question"]
    source_docs = state["documents"]

    rewritten = []
    for _doc in source_docs:
        # TODO: rewritten.append(document_rewriter.invoke(
        #     {"question": _query, "document": _doc.page_content}))
        pass

    return {"documents": rewritten}


def node_grade_documents(state: RagState):
    """Keep only the documents whose relevance score reaches the threshold.

    The retrieval grader call is currently commented out, so every document
    gets the fallback score and is therefore kept. Once the grader is enabled,
    its result is expected to expose a numeric ``score`` attribute.

    Args:
        state: Graph state; ``question`` and ``documents`` are read.

    Returns:
        A state update ``{"documents": [...]}`` with the documents that passed.
    """
    pass_score = 7      # minimum score for a document to count as relevant
    fallback_score = 7  # score assumed when the grader yields no result

    question = state["question"]  # input for the grader once it is enabled
    documents = state["documents"]

    # Score each doc
    filtered_docs = []
    for d in documents:
        # `grade` must be bound on every iteration even while the grader is
        # disabled; otherwise the check below raises UnboundLocalError.
        grade = None
        #grade = retrieval_grader.invoke(
        #    {"question": question, "document": d.page_content}
        #)

        # Read the score without replacing `grade` by a dict, which has no
        # `.score` attribute.
        score = fallback_score if grade is None else grade.score

        if score >= pass_score:
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")

    return {"documents": filtered_docs}

def node_web_search(state: RagState):
    """Web-search node (placeholder).

    Joins the ``content`` of all search hits into one ``Document``. The search
    tool call is still commented out, so the hit list is empty and the
    document's ``page_content`` is an empty string.

    Args:
        state: Graph state; ``question`` is read.

    Returns:
        A state update ``{"documents": [Document]}``. The document is wrapped
        in a list because ``documents`` is declared as a list in the state and
        is iterated by other nodes.
    """
    question = state["question"]  # query for the search tool once enabled

    # Web search
    docs = []
    #docs = web_search_tool.invoke({"query": question})
    combined_text = "\n".join(hit["content"] for hit in docs)
    web_results = Document(page_content=combined_text)

    return {"documents": [web_results]}

def node_generate(state: RagState):
    """Generation node (placeholder).

    Reads ``question`` and ``documents`` from the state. The RAG chain call is
    still commented out, so the returned ``response`` is always an empty
    string.
    """
    _query, _context = state["question"], state["documents"]

    # TODO: answer = rag_chain.invoke({"context": _context, "question": _query})
    answer = ""
    return {"response": answer}


def node_transform_query(state: RagState):
    """Rewrite the user question and pass the documents through unchanged.

    Args:
        state: Graph state; ``question`` is required, ``documents`` is
            optional.

    Returns:
        A state update with the rewritten ``question`` and the current
        ``documents``.

    NOTE(review): ``question_rewriter`` is not defined in the visible part of
    this notebook — it must exist before the graph is run.
    """
    question = state["question"]
    # This node is wired directly after START, so "documents" may not have
    # been set yet when the graph is invoked with only a question; default to
    # an empty list instead of raising KeyError.
    documents = state.get("documents", [])

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})

    return {"question": better_question, "documents": documents}


def node_generate_fail(state: RagState):
    """Fallback node: answer with a fixed apology message.

    ``question`` and ``documents`` are read from the state but do not
    influence the message.
    """
    _query, _docs = state["question"], state["documents"]

    return {
        "response": "Sorry, I was unable to reply to your earlier enquiry, please ask again, thank you!"
    }

In [None]:
################################################################################
### Edges conditional functions
################################################################################
def condition_plan(state: RagState):
    """Choose the data source for a question.

    The question router call is still commented out, so the source is
    hard-coded to ``"vectorstore"``.

    Args:
        state: Graph state; ``question`` is read.

    Returns:
        ``"web_search"`` or ``"vectorstore"``; ``None`` for any other source
        value.
    """
    question = state["question"]  # input for the router once it is enabled

    source = "vectorstore"
    # source = question_router.invoke({"question": question})

    # The placeholder is a plain string while the commented-out router result
    # is read through a `.datasource` attribute; accept both forms so that the
    # placeholder does not raise AttributeError.
    datasource = getattr(source, "datasource", source)

    if datasource == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "web_search"
    if datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"
    return None


def condition_retrieve(state: RagState):
    """Route on whether any documents are present in the state.

    Returns ``"success"`` when ``state["documents"]`` is truthy, otherwise
    ``"failure"``.
    """
    _query = state["question"]
    found_docs = state["documents"]

    return "success" if found_docs else "failure"


def condition_grade_documents(state: RagState):
    """Route on whether any documents are left in the state.

    Returns ``"relevant"`` when ``state["documents"]`` is truthy (an answer
    can be generated), otherwise ``"not relevant"`` (a new query should be
    generated).
    """
    _query = state["question"]
    kept_docs = state["documents"]

    return "relevant" if kept_docs else "not relevant"


def condition_generation(state: RagState):
    """Decide whether the generated response is good enough to return.

    The answer grader call is still commented out, so the fallback score is
    always used and the result is ``"useful"``. Once the grader is enabled,
    its result is expected to expose a numeric ``score`` attribute.

    Args:
        state: Graph state; ``question`` and ``response`` are read.

    Returns:
        ``"useful"`` when the score reaches the threshold, otherwise
        ``"not useful"``.
    """
    pass_score = 7      # minimum score for the response to count as useful
    fallback_score = 7  # score assumed when the grader yields no result

    question = state["question"]
    response = state["response"]

    # Check question-answering.
    # `grade` must be bound even while the grader is disabled; otherwise the
    # check below raises UnboundLocalError.
    grade = None
    # grade = answer_grader.invoke({"question": question, "generation": response})

    # Read the score without replacing `grade` by a dict, which has no
    # `.score` attribute.
    score = fallback_score if grade is None else grade.score

    return "useful" if score >= pass_score else "not useful"

In [None]:
################################################################################
### Build the RAG workflow graph
################################################################################
workflow = StateGraph(RagState)

# --- Nodes --------------------------------------------------------------------
# Each node is registered under the name of the function that implements it.
for _node_fn in (
    node_transform_query,
    node_retrieve,
    node_retrieve_rewrite,
    node_web_search,
    node_generate,
    node_generate_fail,
):
    workflow.add_node(_node_fn.__name__, _node_fn)
del _node_fn  # keep the loop variable out of the notebook namespace

# --- Edges --------------------------------------------------------------------
# Retrieval path: START -> transform query -> retrieve -> rewrite documents.
workflow.add_edge(START, "node_transform_query")
workflow.add_edge("node_transform_query", "node_retrieve")
workflow.add_edge("node_retrieve", "node_retrieve_rewrite")

# After the rewrite step: generate when documents are present, otherwise fall
# back to web search.
workflow.add_conditional_edges(
    "node_retrieve_rewrite",
    condition_retrieve,
    {"success": "node_generate", "failure": "node_web_search"},
)

# Web search always continues to generation.
workflow.add_edge("node_web_search", "node_generate")

# After generation: finish on a useful answer, otherwise emit the failure
# message.
workflow.add_conditional_edges(
    "node_generate",
    condition_generation,
    {
        "useful": END,
        "not useful": "node_generate_fail",
        "not supported": "node_generate_fail",
    },
)

workflow.add_edge("node_generate_fail", END)

In [None]:
################################################################################
### Compile Graph
################################################################################
# Compile the workflow defined above into `rag_graph`; compile() is called
# without arguments (no checkpointer is passed).
rag_graph = workflow.compile()

In [None]:
# Show the compiled graph inline in the notebook as a Mermaid-rendered PNG.
# NOTE(review): draw_mermaid_png() may require network access depending on the
# rendering method it uses — confirm for offline environments.
from IPython.display import Image, display
display(Image(rag_graph.get_graph().draw_mermaid_png()))

## Test