In [None]:
import os, json
from typing_extensions import TypedDict, List
import ast
import re, json, ast
from langgraph.graph import START, StateGraph, END
from langchain_core.documents import Document
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.chat_models import ChatOllama
from langchain_core.prompts import PromptTemplate

In [None]:
from dataset_prompt import (
    query_transformation_prompt,
    dataset_prompt 
)

from configs import (
    QDRANT_URL,
    QDRANT_API_KEY,
    OLLAMA_URL
)

In [None]:
class chat(TypedDict):
    """Shared state passed between the nodes of the LangGraph pipeline.

    Each node reads the keys it needs and returns a partial dict whose keys
    are merged back into this state.
    """

    # Question exactly as supplied by the caller.
    user_query: str
    # LLM-rewritten form of ``user_query``; this is what gets sent to retrieval.
    transformed_query: str
    # ``Document.metadata`` of each retrieved hit, in retrieval order.
    metadata: List[dict]
    # ``Document.page_content`` of each retrieved hit, in retrieval order.
    summaries: List[str]
    # Score the vector store returned alongside each hit, in retrieval order.
    similarity_scores: List[float]
    # Text produced by the dataset-extraction step.
    datasets: str

In [None]:
# Model identifiers for the two encoders handed to the Qdrant vector store:
# one dense embedding model and one sparse (BM25) model.
DENSE_MODEL_NAME = "BAAI/bge-base-en-v1.5"
SPARSE_MODEL_NAME = "Qdrant/bm25"

dense_embeddings = HuggingFaceEmbeddings(model_name=DENSE_MODEL_NAME)
sparse_embeddings = FastEmbedSparse(model_name=SPARSE_MODEL_NAME)

In [None]:
# Collection that must already exist on the Qdrant server; nothing is
# created or indexed here.
COLLECTION_NAME = (
    "The Next Decade in AI Four Steps Towards Robust Artificial Intelligence"
)

# Connection settings come from the local ``configs`` module.
qdrant_connection = {
    "url": QDRANT_URL,
    "api_key": QDRANT_API_KEY,
    "prefer_grpc": True,
}

# Hybrid mode uses both the dense and the sparse encoder for retrieval.
vector_store = QdrantVectorStore.from_existing_collection(
    collection_name=COLLECTION_NAME,
    retrieval_mode=RetrievalMode.HYBRID,
    embedding=dense_embeddings,
    sparse_embedding=sparse_embeddings,
    **qdrant_connection,
)

In [None]:
# Settings for the chat model served by Ollama at OLLAMA_URL.
# NOTE(review): temperature 0.0 is presumably chosen for repeatable output and
# num_predict presumably caps the generated length — confirm against the
# ChatOllama documentation.
ollama_settings = {
    "model": "mistral",
    "base_url": OLLAMA_URL,
    "temperature": 0.0,
    "num_predict": 512,
}

llm = ChatOllama(**ollama_settings)

In [None]:
def query_transformation(state: chat) -> dict:
    """Rewrite the user's question with the LLM before retrieval.

    Args:
        state: Graph state; only ``user_query`` is read.

    Returns:
        A partial state update ``{"transformed_query": <stripped LLM reply>}``.
    """
    # Fill the ``query`` slot of the rewrite template with the raw question.
    rewrite_request = PromptTemplate.from_template(
        query_transformation_prompt
    ).format(query=state["user_query"])

    llm_reply = llm.invoke(rewrite_request)
    rewritten_query = llm_reply.content.strip()
    return {"transformed_query": rewritten_query}

In [None]:
def retrieve_documents(state: chat) -> dict:
    """Fetch the top five hits for the transformed query from the vector store.

    Args:
        state: Graph state; only ``transformed_query`` is read.

    Returns:
        A partial state update with three parallel lists, one entry per hit
        and in the order the store returned them: ``metadata``
        (``doc.metadata``), ``summaries`` (``doc.page_content``) and
        ``similarity_scores`` (the score paired with each document).
    """
    hits = vector_store.similarity_search_with_score(
        state["transformed_query"], k=5
    )

    # Each hit is a (document, score) pair; split it into per-field lists.
    return {
        "metadata": [document.metadata for document, _ in hits],
        "summaries": [document.page_content for document, _ in hits],
        "similarity_scores": [hit_score for _, hit_score in hits],
    }

In [None]:
def extract_datasets(state: chat) -> dict:
    """Ask the LLM to extract dataset names from the retrieved passages.

    The passages in ``state["summaries"]`` are joined with blank lines,
    substituted into the ``context`` slot of ``dataset_prompt`` and sent to
    ``llm``. The stripped reply is printed and returned as a state update.

    Args:
        state: Graph state; only ``summaries`` is read.

    Returns:
        A partial state update ``{"datasets": <stripped LLM reply>}``. The
        reply is returned as raw text; it is not parsed into a Python list.
    """
    paper_content = "\n\n".join(state["summaries"])

    prompt = PromptTemplate.from_template(dataset_prompt)
    formatted_prompt = prompt.format(context=paper_content)
    response = llm.invoke(formatted_prompt)

    raw_output = response.content.strip()

    # Kept so the answer stays visible when the cell is run interactively.
    print("\nRaw LLM Output:\n", raw_output)

    # Without this return the node yields None and the ``datasets`` key in the
    # graph state is never populated.
    return {"datasets": raw_output}

In [None]:
# Assemble the pipeline as a straight line:
# START -> query_transformation -> retrieve_documents -> extract_datasets -> END
chat_builder = StateGraph(chat)

# Register every step under the name used for it in the edge list below.
for node_name, node_fn in (
    ("query_transformation", query_transformation),
    ("retrieve_documents", retrieve_documents),
    ("extract_datasets", extract_datasets),
):
    chat_builder.add_node(node_name, node_fn)

# Connect each stage to the one that follows it.
pipeline_order = [
    START,
    "query_transformation",
    "retrieve_documents",
    "extract_datasets",
    END,
]
for upstream, downstream in zip(pipeline_order, pipeline_order[1:]):
    chat_builder.add_edge(upstream, downstream)

chat_llm = chat_builder.compile()

In [None]:
if __name__ == "__main__":
    # Question put to the pipeline.
    question = "List the datasets used in this paper, with exact names. Give output in the form of a python List and do not include sections or any other explanations. "

    # Every state key is seeded; the graph nodes overwrite the placeholder
    # values with their own results as the run progresses.
    init_state: chat = chat(
        user_query=question,
        transformed_query="",
        metadata=[],
        summaries=[],
        similarity_scores=[],
        datasets="",
    )

    # Final graph state after all nodes have run.
    result = chat_llm.invoke(init_state)