In [1]:
from langgraph.graph import StateGraph, START, END
from typing import TypedDict
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
import os

# -------------- 🔐 API Key Setup --------------
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(), override=True)

# -------------- 🧠 Define State Schema --------------
class State(TypedDict):
    messages: list
    department: str
    retrieved_docs: list
    final_answer: str

# -------------- 🔧 LLM + Embeddings --------------
llm = ChatOpenAI(model_name='gpt-4o-mini', temperature=0.5)
embedding_model = OpenAIEmbeddings()

# Load vectorstore (Ensure FAISS index exists)
vectorstore = FAISS.load_local("policy_faiss_index", embedding_model)

# Optional memory (can persist between runs if needed)
memory = ConversationBufferMemory(return_messages=True)

# -------------- 🧩 Node 1: Understand Query --------------
def understand_query(state: State) -> State:
    return {**state, "messages": state.get("messages", [])}

# -------------- 🧩 Node 2: Classify Department --------------
def classify_department(state: State) -> State:
    query = state["messages"][-1]["content"]
    prompt = PromptTemplate.from_template(
        "Classify this query into one department: HR, Finance, or IT:\n\n{query}"
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    department = chain.run(query=query).strip().lower()
    return {**state, "department": department}

# -------------- 🧩 Node 3: Retrieve Relevant Docs --------------
def retrieve_docs(state: State) -> State:
    query = state["messages"][-1]["content"]
    docs = vectorstore.similarity_search(query, k=5)
    return {**state, "retrieved_docs": docs}

# -------------- 🧩 Node 4: Filter Docs with LLM --------------
def filter_docs(state: State) -> State:
    docs = state["retrieved_docs"]
    combined = "\n\n".join([doc.page_content for doc in docs])
    prompt = PromptTemplate.from_template(
        "Here is some policy content:\n{docs}\n\nFilter out irrelevant or redundant parts and retain the most helpful content."
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    filtered = chain.run(docs=combined)
    return {**state, "retrieved_docs": [filtered]}

# -------------- 🧩 Node 5: Summarize Final Answer --------------
def summarize_answer(state: State) -> State:
    content = state["retrieved_docs"][0] if state["retrieved_docs"] else "No relevant policy found."
    prompt = PromptTemplate.from_template(
        "Summarize the following for an employee in simple, helpful language:\n\n{content}"
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    summary = chain.run(content=content)
    return {**state, "final_answer": summary}

# -------------- ⚙️ Build LangGraph --------------
builder = StateGraph(State)
builder.add_node("UnderstandQuery", understand_query)
builder.add_node("ClassifyDepartment", classify_department)
builder.add_node("RetrieveDocs", retrieve_docs)
builder.add_node("FilterDocs", filter_docs)
builder.add_node("SummarizeAnswer", summarize_answer)

builder.set_entry_point("UnderstandQuery")
builder.add_edge("UnderstandQuery", "ClassifyDepartment")
builder.add_edge("ClassifyDepartment", "RetrieveDocs")
builder.add_edge("RetrieveDocs", "FilterDocs")
builder.add_edge("FilterDocs", "SummarizeAnswer")
builder.add_edge("SummarizeAnswer", END)

graph = builder.compile()

# -------------- 🚀 Run the System --------------
if __name__ == "__main__":
    query = "How many casual leaves can I take in a year?"
    input_state = {
        "messages": [HumanMessage(content=query)]
    }

    result = graph.invoke(input_state)
    print("✅ Final Answer:\n", result["final_answer"])


ValueError: The de-serialization relies loading a pickle file. Pickle files can be modified to deliver a malicious payload that results in execution of arbitrary code on your machine.You will need to set `allow_dangerous_deserialization` to `True` to enable deserialization. If you do this, make sure that you trust the source of the data. For example, if you are loading a file that you created, and know that no one else has modified the file, then this is safe to do. Do not set this to `True` if you are loading a file from an untrusted source (e.g., some random site on the internet.).