In [None]:

# Ollama model tags: one used for embeddings, one used for chat generation.
EMBED_MODEL = "nomic-embed-text"
LANG_MODEL = "deepseek-r1:7b"

# Directory that is scanned for source documents, and directory where the
# split documents are cached between runs.
LIB_ROOT = "library"
VECTOR_CACHE_ROOT = "cache"


In [None]:
from langchain_ollama import OllamaEmbeddings
from langchain_community.document_loaders import PDFPlumberLoader
from langchain_community.vectorstores import FAISS
from langchain_experimental.text_splitter import SemanticChunker
from langchain_core.documents import Document

from langgraph.checkpoint.memory import MemorySaver

import os

# knowledge base
# Embedding client for the Ollama model named by EMBED_MODEL. The same object
# drives the semantic splitter here and the FAISS index built further down.
embedder = OllamaEmbeddings(model=EMBED_MODEL)
# Splitter that chunks documents with the help of the embedder.
text_splitter = SemanticChunker(embedder)

# Make sure the cache directory exists. makedirs(exist_ok=True) creates any
# missing parent directories and does not fail when the directory is already
# there, so there is no window between an existence check and the creation.
os.makedirs(VECTOR_CACHE_ROOT, exist_ok=True)

# Retrieval setup
def md5(filename):
    """Return the hexadecimal MD5 digest of the string ``filename``.

    Only the text that is passed in gets hashed (after ``codecs.encode``
    turns it into bytes with its default codec); the file it may name is
    never opened.
    """
    import codecs
    import hashlib

    raw = codecs.encode(filename)
    digest = hashlib.md5()
    digest.update(raw)
    return digest.hexdigest()

def to_doc(path):
    """Load the PDF at ``path`` and split it into chunks, using an on-disk cache.

    The split result is serialised to JSON and stored in ``VECTOR_CACHE_ROOT``
    in a file named after ``md5(path)``. On later calls the cached JSON is
    returned instead of re-loading and re-splitting the PDF, provided the
    cache file is at least as recent as the PDF itself.

    Args:
        path: Path of the PDF file to load.

    Returns:
        The list of ``Document`` chunks produced by ``text_splitter``, either
        freshly computed or read back from the cache.
    """
    from typing import List
    from pydantic import TypeAdapter

    adapter = TypeAdapter(List[Document])
    cache_file = os.path.join(VECTOR_CACHE_ROOT, md5(path))

    if os.path.exists(cache_file):
        # The cache key is derived from the path only, so a PDF edited in
        # place would otherwise keep serving outdated chunks. Compare
        # modification times to detect that.
        try:
            cache_is_fresh = os.path.getmtime(cache_file) >= os.path.getmtime(path)
        except OSError:
            # The source cannot be stat'ed; keep trusting the cache, which is
            # what happened before the freshness check existed.
            cache_is_fresh = True

        if cache_is_fresh:
            print(f'cache for {path} exists, will use cache')
            with open(cache_file, 'rb') as c:
                return adapter.validate_json(c.read())
        print(f'cache for {path} is older than the file, rebuilding')

    loader = PDFPlumberLoader(path)
    print(f'loading doc {path}')
    docs = loader.load()
    print(f'spliting {path}')
    ret = text_splitter.split_documents(docs)

    # Write to a temporary file and rename it into place, so an interrupted
    # write can never leave a truncated cache file behind for the next run.
    tmp_file = cache_file + '.tmp'
    with open(tmp_file, 'wb') as f:
        f.write(adapter.dump_json(ret))
    os.replace(tmp_file, cache_file)
    return ret

def build_vector_store():
    """Build a FAISS vector store from the PDF files found in ``LIB_ROOT``.

    Every regular file in ``LIB_ROOT`` whose name ends in ``.pdf``
    (case-insensitive) is converted to chunks with ``to_doc``; all chunks are
    then embedded with ``embedder`` and indexed. Files are processed in sorted
    order so the result does not depend on directory listing order.

    Returns:
        The FAISS vector store built from all chunks.

    Raises:
        ValueError: If no chunks could be collected (for example, the library
            directory contains no PDF files).
    """
    # os.listdir also returns sub-directories and unrelated files (notebook
    # checkpoints, OS metadata files, ...); only PDFs can be loaded by to_doc.
    docfiles = sorted(
        candidate
        for candidate in (os.path.join(LIB_ROOT, name) for name in os.listdir(LIB_ROOT))
        if candidate.lower().endswith('.pdf') and os.path.isfile(candidate)
    )

    documents = [d for ds in map(to_doc, docfiles) if ds is not None for d in ds]
    if not documents:
        # Fail early with an actionable message instead of an obscure error
        # from the index constructor when it is handed nothing to index.
        raise ValueError(f'no PDF content found in {LIB_ROOT!r}; nothing to index')

    print('building vector store')
    return FAISS.from_documents(documents, embedder)

# Build the vector store once when this cell runs; the `retrieve` step
# defined later in the notebook queries this module-level object.
vector = build_vector_store()

In [None]:

from langchain_core.messages import AIMessageChunk
from typing import List
from typing_extensions import Annotated, TypedDict
from langgraph.graph import START, StateGraph
from langchain_core.prompts import PromptTemplate
from langchain_ollama import ChatOllama
from langchain_core.documents import Document


# Desired schema for response
class AnswerWithSources(TypedDict):
    """An answer to the question, with sources."""

    # NOTE(review): nothing in the visible cells references this class —
    # presumably intended as a structured-output schema; confirm or remove.

    # Free-text answer.
    answer: str
    # List of source labels. The Annotated metadata holds `...` followed by a
    # human-readable description of the field.
    sources: Annotated[
        List[str],
        ...,
        "List of sources (author + year) used to answer the question",
    ]

class State(TypedDict):
    # State schema handed to StateGraph; the graph steps read from it and
    # return partial updates keyed by these field names.

    # The user's question; read by `retrieve` and `generate`.
    question: str
    # Documents returned by the similarity search in `retrieve`.
    context: List[Document]
    # Content of the model response, produced by `generate`.
    answer: str


# Prompt text. The {context} and {question} placeholders are filled in by
# `generate` before the prompt is sent to the model.
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
Always say "thanks for asking!" at the end of the answer.

{context}

Question: {question}

Helpful Answer:"""
prompt = PromptTemplate.from_template(template)

# Chat model served through Ollama; num_ctx is passed through as 8192.
model = ChatOllama(model=LANG_MODEL, num_ctx=8192)

# Define application steps
def retrieve(state: State):
    """Graph step: fetch documents similar to the question.

    Runs a similarity search on the module-level ``vector`` store with
    ``state["question"]`` and returns the hits as the ``context`` update.
    """
    question = state["question"]
    return {"context": vector.similarity_search(question)}


def generate(state: State):
    """Graph step: answer the question from the retrieved context.

    Joins the ``page_content`` of every document in ``state["context"]`` with
    blank lines, renders ``prompt`` with that text and the question, invokes
    ``model`` and returns the response content as the ``answer`` update.
    """
    context_text = "\n\n".join([doc.page_content for doc in state["context"]])
    prompt_value = prompt.invoke({"question": state["question"], "context": context_text})
    return {"answer": model.invoke(prompt_value).content}

# Build graph: two steps run in sequence (retrieve, then generate), with
# START wired to the first one.
graph_builder = StateGraph(State)
graph_builder.add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")

# Compile the application with a MemorySaver checkpointer.
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)


def do_chat():
    """Run one chat turn: read a question and stream the answer to stdout.

    Prompts with ``">: "``, feeds the entered text to ``graph`` as the
    ``question`` and prints the content of every ``AIMessageChunk`` found in
    the streamed items as it arrives. A blank line is ignored. All turns use
    the same checkpointer thread id, ``"main"``.

    Returns:
        None.
    """
    config = {"configurable": {"thread_id": "main"}}

    input_message = input(">: ")
    if not input_message.strip():
        # Nothing was asked; skip retrieval and generation for a blank line.
        return

    # https://python.langchain.com/docs/concepts/streaming/
    for step in graph.stream(
        {'question': input_message},
        stream_mode="messages",
        config=config,
    ):
        # Each streamed item is iterated and only the message chunks in it are
        # printed (presumably the item is a (chunk, metadata) pair — confirm
        # against the LangGraph streaming docs).
        for chunk in step:
            if isinstance(chunk, AIMessageChunk):
                # flush so partial output shows up immediately on buffered stdout
                print(chunk.content, end='', flush=True)

    # Finish the streamed line so the next ">: " prompt starts on its own line.
    print()

# Chat until the user interrupts the kernel (KeyboardInterrupt) or input hits
# end-of-file. Catching these lets the cell finish cleanly instead of ending
# with a traceback, so the cells after it can still run.
try:
    while True:
        do_chat()
except (KeyboardInterrupt, EOFError):
    print('\nchat ended')

In [None]:

from IPython.display import Image, display
# Render the compiled graph as a Mermaid PNG and show it inline.
graph_png = graph.get_graph().draw_mermaid_png()
display(Image(graph_png))