In [None]:
import os, json
from typing_extensions import TypedDict, List

from langgraph.graph import START, StateGraph, END
from langchain_core.documents import Document
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.chat_models import ChatOllama
from langchain_core.prompts import ChatPromptTemplate

In [None]:
from prompts import (
    query_transformation_prompt,
    llm_prompt,
)

from configs import (
    QDRANT_URL,
    QDRANT_API_KEY,
    OLLAMA_URL
)

In [None]:
class chat(TypedDict):
    """State schema shared by the nodes of the LangGraph pipeline.

    Each node function returns a partial dict containing a subset of these keys.
    """
    # Raw question supplied by the caller.
    user_query: str
    # Rewrite of `user_query` produced by `query_transformation`; used as the search text.
    transformed_query: str
    # Per-hit `Document.metadata` dicts from the vector store, in result order.
    metadata: List[dict]
    # Per-hit `Document.page_content`, aligned index-for-index with `metadata`.
    summaries: List[str]
    # Per-hit score from `similarity_search_with_score`, aligned with `metadata`.
    similarity_scores: List[float]
    # Final LLM response text from `generate_answer`.
    answer: str
    # `total_tokens` from the answer call's usage metadata (0 when not reported).
    token_count: int

In [None]:
# Query-time embedding models for hybrid (dense + sparse) search.
# NOTE(review): these presumably must match the models used when the Qdrant
# collection was indexed — confirm before changing either name.
DENSE_EMBEDDING_MODEL = "BAAI/bge-base-en-v1.5"
SPARSE_EMBEDDING_MODEL = "Qdrant/bm25"

dense_embeddings = HuggingFaceEmbeddings(model_name=DENSE_EMBEDDING_MODEL)
sparse_embeddings = FastEmbedSparse(model_name=SPARSE_EMBEDDING_MODEL)

In [None]:
# NOTE(review): "axRiv" looks like a transposition of "arXiv", but the string must
# match the existing Qdrant collection exactly — do not "fix" it without renaming
# the collection on the server as well.
QDRANT_COLLECTION = "axRiv_research_papers"

# Attach to an already-populated collection (no documents are written here) and
# query it with both the dense and the sparse embedder.
vector_store = QdrantVectorStore.from_existing_collection(
    collection_name=QDRANT_COLLECTION,
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
    prefer_grpc=True,
    embedding=dense_embeddings,
    sparse_embedding=sparse_embeddings,
    retrieval_mode=RetrievalMode.HYBRID,
)

In [None]:
OLLAMA_MODEL = "mistral"

# Chat model shared by `query_transformation` and `generate_answer`.
# NOTE(review): `langchain_community.chat_models.ChatOllama` is deprecated in
# favour of `langchain_ollama.ChatOllama`; consider migrating the import.
llm = ChatOllama(
    base_url=OLLAMA_URL,
    model=OLLAMA_MODEL,
    num_predict=1024,  # upper bound on tokens generated per call
    temperature=0.0,   # minimise sampling randomness for repeatable answers
)

In [None]:
def query_transformation(state: chat) -> dict:
    """Rewrite the user's question into a query better suited for retrieval.

    Sends ``query_transformation_prompt`` (system message) and the raw user
    query (human message) through ``llm``.

    Args:
        state: Graph state; only ``state["user_query"]`` is read.

    Returns:
        Partial state update ``{"transformed_query": str}``. The model output is
        stripped of surrounding whitespace. If it is empty, or is not a plain
        string, the original user query is used instead, so the downstream
        vector search never runs on an empty or non-string query.
    """
    # NOTE(review): the system prompt is parsed as a template, so any literal
    # "{" / "}" in query_transformation_prompt must be doubled — confirm.
    prompt = ChatPromptTemplate.from_messages([
        ("system", query_transformation_prompt),
        ("human", "{query}"),
    ])
    chain = prompt | llm
    response = chain.invoke({"query": state["user_query"]})

    content = response.content
    # Message content may be a list of content blocks rather than a str;
    # only a plain string is usable as search text.
    rewritten = content.strip() if isinstance(content, str) else ""
    return {"transformed_query": rewritten or state["user_query"]}

In [None]:
def retrieve_documents(state: chat) -> dict:
    """Search the Qdrant store with the rewritten query.

    Args:
        state: Graph state; only ``state["transformed_query"]`` is read.

    Returns:
        Partial state update with three parallel lists, one entry per hit in the
        order the store returned them (at most 5 hits): ``metadata``
        (``Document.metadata``), ``summaries`` (``Document.page_content``) and
        ``similarity_scores`` (the score paired with each document).
    """
    hits = list(
        vector_store.similarity_search_with_score(state["transformed_query"], k=5)
    )
    return {
        "metadata": [doc.metadata for doc, _score in hits],
        "summaries": [doc.page_content for doc, _score in hits],
        # NOTE(review): the store is configured with RetrievalMode.HYBRID, so these
        # may be fused ranking scores rather than raw cosine similarities — confirm
        # before thresholding on them.
        "similarity_scores": [score for _doc, score in hits],
    }

In [None]:
def _format_papers(metadata: List[dict], summaries: List[str]) -> str:
    """Render retrieved papers as a bulleted Markdown block for the LLM context."""
    # NOTE(review): metadata keys 'Title' (capitalised) and 'url' (lower-case) are
    # assumed to match the indexed payload schema — confirm against the collection.
    return "\n\n".join(
        f"• **{(md or {}).get('Title', 'No title')}** ({(md or {}).get('url', 'No URL')})\n  {summ}"
        for md, summ in zip(metadata, summaries)
    )


def generate_answer(state: chat) -> dict:
    """Answer the user's question, grounded in the retrieved papers.

    Args:
        state: Graph state; reads ``user_query``, ``metadata`` and ``summaries``.

    Returns:
        Partial state update ``{"answer": str, "token_count": int}``, where
        ``token_count`` is ``total_tokens`` from the response's usage metadata,
        or 0 when the client does not report it.

    Raises:
        RuntimeError: if the chain invocation returns ``None``.
    """
    papers_block = _format_papers(state["metadata"], state["summaries"])
    if not papers_block:
        # With an empty context the model has nothing to ground on and may invent
        # papers; say so explicitly instead of sending a blank section.
        papers_block = "No relevant papers were retrieved."

    prompt = ChatPromptTemplate.from_messages([
        ("system", llm_prompt),
        ("human", "{query}\n{context}"),
    ])
    chain = prompt | llm
    response = chain.invoke({
        "query":   state["user_query"],
        "context": papers_block,
    })

    if response is None:
        raise RuntimeError("LLM invocation returned None—check your LLM client or network settings.")

    # usage_metadata may be absent or None depending on the client/version.
    usage_meta = getattr(response, "usage_metadata", {}) or {}
    token_count = usage_meta.get("total_tokens", 0)

    return {"answer": response.content, "token_count": token_count}

In [None]:
def _build_chat_graph() -> StateGraph:
    """Assemble the linear query -> retrieve -> answer pipeline (not yet compiled)."""
    steps = (
        ("query_transformation", query_transformation),
        ("retrieve_documents", retrieve_documents),
        ("generate_answer", generate_answer),
    )

    builder = StateGraph(chat)
    for node_name, node_fn in steps:
        builder.add_node(node_name, node_fn)

    # Chain START -> each step in declaration order -> END.
    route = [START] + [node_name for node_name, _ in steps] + [END]
    for src, dst in zip(route, route[1:]):
        builder.add_edge(src, dst)
    return builder


chat_builder = _build_chat_graph()
chat_llm = chat_builder.compile()

In [None]:
if __name__ == "__main__":
    # Seed every key declared on `chat`; the graph fills them in as it runs.
    init_state: chat = {
        "user_query": "What is Reinforcement Learning?",
        "transformed_query": "",
        "metadata": [],
        "summaries": [],
        "similarity_scores": [],
        "answer": "",
        "token_count": 0,
    }
    result = chat_llm.invoke(init_state)

    print("=== ANSWER ===")
    print(result["answer"])

    print("\n=== Similarity Scores ===")
    print(result["similarity_scores"])

    papers = result["metadata"]
    # Report the actual number of hits; retrieval may return fewer than k.
    print(f"\n=== Top {len(papers)} Paper Metadata ===")
    for i, md in enumerate(papers, start=1):
        print(f"\nPaper {i}:")
        # default=str keeps non-JSON payload values (e.g. datetimes) from raising;
        # ensure_ascii=False prints accented author names verbatim.
        print(json.dumps(md, indent=2, ensure_ascii=False, default=str))