In [3]:
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_community.document_loaders import PDFPlumberLoader
from langchain_community.vectorstores import FAISS
from langchain_experimental.text_splitter import SemanticChunker
from langchain_core.documents import Document

from langchain_core.messages import SystemMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode, tools_condition

import os

# Ollama model name passed to OllamaEmbeddings.
EMBED_MODEL = 'nomic-embed-text'
# Ollama model name passed to ChatOllama.
LANG_MODEL = 'qwen2.5:7b'
# Directory whose entries are listed and loaded as source documents.
LIB_ROOT = 'library'
# Directory holding the JSON cache of split documents (one file per source path).
VECTOR_CACHE_ROOT = 'cache'


In [4]:
# knowledge base
# The same embedder is used for semantic chunking here and for the FAISS index
# built later in this cell.
embedder = OllamaEmbeddings(model=EMBED_MODEL)
text_splitter = SemanticChunker(embedder)

# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it has no
# check-then-create race, also works when the cache path is nested, and makes
# re-running this cell safe.
os.makedirs(VECTOR_CACHE_ROOT, exist_ok=True)

# Retrieval setup
def md5(filename):
    """Return the hexadecimal MD5 digest of a string.

    The string is encoded with ``codecs.encode`` (UTF-8 by default) before
    hashing. Used to derive a cache file name from a document path; this is a
    naming key, not a security measure.
    """
    import codecs
    import hashlib

    encoded = codecs.encode(filename)
    digest = hashlib.md5(encoded)
    return digest.hexdigest()

def to_doc(path):
    """Load a PDF, split it into chunks, and cache the chunks on disk.

    The chunks are cached as JSON in ``VECTOR_CACHE_ROOT`` under a file named
    after the MD5 of ``path``. The key is derived from the path string only,
    so a PDF that changes at the same path keeps using the old cache until
    the cache file is deleted.

    Args:
        path: Path of the PDF file to load.

    Returns:
        The list of ``Document`` chunks, either read back from a valid cache
        file or freshly produced by ``text_splitter``.
    """
    from typing import List
    from pydantic import TypeAdapter, ValidationError

    adapter = TypeAdapter(List[Document])
    cache_file = os.path.join(VECTOR_CACHE_ROOT, md5(path))

    if os.path.exists(cache_file):
        print(f'cache for {path} exists, will use cache')
        with open(cache_file, 'rb') as c:
            content = c.read()
        try:
            return adapter.validate_json(content)
        except ValidationError:
            # A truncated or otherwise invalid cache must not block every
            # later run; fall through and rebuild it from the PDF.
            print(f'cache for {path} is corrupt, rebuilding')

    loader = PDFPlumberLoader(path)
    print(f'loading doc {path}')
    docs = loader.load()
    print(f'spliting {path}')
    ret = text_splitter.split_documents(docs)

    # Write to a sibling temp file and rename it into place so an interrupted
    # write never leaves a half-written cache behind. The hex digest contains
    # no dot, so the ".tmp" name cannot clash with another cache entry.
    tmp_file = cache_file + '.tmp'
    with open(tmp_file, 'wb') as f:
        f.write(adapter.dump_json(ret))
    os.replace(tmp_file, cache_file)
    return ret

def build_vector_store():
    """Build a FAISS vector store from every PDF directly under ``LIB_ROOT``.

    Only regular files whose name ends in ``.pdf`` (case-insensitive) are
    loaded, so sub-directories and stray files in the library folder no longer
    reach the PDF loader. Files are processed in sorted order so the document
    order does not depend on the file system's listing order.

    Returns:
        The FAISS vector store built from all chunks with ``embedder``.

    Raises:
        ValueError: If no document chunks could be collected.
    """
    docfiles = []
    for name in sorted(os.listdir(LIB_ROOT)):
        candidate = os.path.join(LIB_ROOT, name)
        if name.lower().endswith('.pdf') and os.path.isfile(candidate):
            docfiles.append(candidate)

    # Flatten the per-file chunk lists into one list of documents.
    documents = [d for ds in map(to_doc, docfiles) if ds is not None for d in ds]
    if not documents:
        # Fail with an actionable message instead of an opaque error from
        # inside the vector store constructor.
        raise ValueError(f'no PDF content found under {LIB_ROOT!r}')

    print('building vector store')
    vector = FAISS.from_documents(documents, embedder)

    return vector

# Module-level vector store; search_for_paper queries it via similarity_search.
vector = build_vector_store()

cache for library\infra-agent.pdf exists, will use cache
cache for library\mixture-of-experts.pdf exists, will use cache
cache for library\optimization.pdf exists, will use cache
cache for library\original.pdf exists, will use cache
cache for library\survey.pdf exists, will use cache
building vector store


In [14]:

from langchain_core.messages import AIMessageChunk

# model and workflow
@tool(response_format="content_and_artifact", description="retrieve paper content from knowledgebase given keywords")
def search_for_paper(query: str):
    """Retrieve information related to a query."""
    # Two closest chunks from the module-level vector store.
    hits = vector.similarity_search(query, k=2)

    # One "Source / Content" section per hit, separated by blank lines.
    sections = [
        f"Source: {hit.metadata}\nContent: {hit.page_content}"
        for hit in hits
    ]

    # content_and_artifact format: (serialized text, raw documents).
    return "\n\n".join(sections), hits

# Chat model used by both query_or_respond and generate; num_ctx is passed
# through to Ollama as the context-window size option.
model = ChatOllama(model=LANG_MODEL, num_ctx=8192)

# Step 1: Generate an AIMessage that may include a tool-call to be sent.
def query_or_respond(state: MessagesState):
    """Let the model either answer directly or request a retrieval.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.

    Returns:
        A state update containing the model's response message, which may
        carry a tool call for ``search_for_paper``.
    """
    # bind_tools returns a new runnable and leaves `model` unchanged, so the
    # bound copy is the one that has to be invoked. Invoking the bare model
    # means it never sees the tool and the retrieval branch is never taken.
    model_with_tools = model.bind_tools([search_for_paper])
    response = model_with_tools.invoke(state["messages"])
    # MessagesState appends messages to state instead of overwriting
    return {"messages": [response]}

# Step 2: Execute the retrieval.
# This node is registered with graph_builder.add_node(tools) and addressed as
# "tools" in the graph edges.
tools = ToolNode([search_for_paper])

# Step 3: Generate a response using the retrieved content.
def generate(state: MessagesState):
    """Produce the final answer from the most recent retrieval results.

    The contents of the trailing run of tool messages are appended to a fixed
    system instruction. That system message is then placed in front of the
    human, system and tool-call-free AI messages from the history before the
    model is invoked.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.

    Returns:
        A state update containing the model's answer message.
    """
    history = state["messages"]

    # Walk back from the end to find where the trailing tool messages start.
    tail_start = len(history)
    while tail_start > 0 and history[tail_start - 1].type == "tool":
        tail_start -= 1
    tool_messages = history[tail_start:]

    # System prompt: fixed instructions followed by the retrieved context.
    docs_content = "\n\n".join(msg.content for msg in tool_messages)
    instructions = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
    )
    system_message_content = instructions + "\n\n" + docs_content

    def _is_conversational(msg):
        # Keep human/system turns and AI turns that are not tool requests.
        if msg.type in ("human", "system"):
            return True
        return msg.type == "ai" and not msg.tool_calls

    conversation_messages = list(filter(_is_conversational, history))
    prompt = [SystemMessage(system_message_content), *conversation_messages]

    # Run
    return {"messages": [model.invoke(prompt)]}

# Build graph
graph_builder = StateGraph(MessagesState)

# Nodes are added without explicit names; the edges below address them as
# "query_or_respond", "tools" and "generate".
graph_builder.add_node(query_or_respond)
graph_builder.add_node(tools)
graph_builder.add_node(generate)

graph_builder.set_entry_point("query_or_respond")
# tools_condition chooses the hop after the first model call; the mapping
# routes its "tools" result to the tool node and END to the end of the graph.
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {END: END, "tools": "tools"},
)
# After a retrieval, always generate the answer and then finish.
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

# The MemorySaver instance is handed to compile() as the checkpointer.
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)

def do_chat(thread_id="main"):
    """Read one user message from stdin and run it through the graph.

    The message is echoed between banner lines, then the graph is streamed
    and the newest message of every emitted state is pretty-printed.

    Args:
        thread_id: Value placed under ``configurable.thread_id`` in the run
            config passed to ``graph.stream``. Defaults to ``"main"``, the
            value that was previously hard-coded.
    """
    config = {"configurable": {"thread_id": thread_id}}

    input_message = input(">: ")
    print('===== User Message =====')
    print(input_message)
    print('========================')

    # stream_mode="values" yields the whole state after each step, so only the
    # last message is printed each time. For token-by-token output, stream with
    # stream_mode="messages" and print the AIMessageChunk contents instead; see
    # https://python.langchain.com/docs/concepts/streaming/
    for step in graph.stream(
        {"messages": [{"role": "user", "content": input_message}]},
        stream_mode="values", config=config
    ):
        step["messages"][-1].pretty_print()

# Keep chatting until the kernel is interrupted (KeyboardInterrupt) or stdin
# is closed (EOFError). Both end the loop cleanly instead of leaving the cell
# with a traceback; any other exception still propagates.
while True:
    try:
        do_chat()
    except (KeyboardInterrupt, EOFError):
        print('chat ended')
        break

===== User Message =====
explain paper for me OPTIMIZING MIXTURE OF EXPERTS USING DYNAMIC RECOMPILATIONS

explain paper for me OPTIMIZING MIXTURE OF EXPERTS USING DYNAMIC RECOMPILATIONS


ValueError: test