In [None]:
from langchain import hub
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict
import sys
sys.path.append("..")
from langchain_ollama import ChatOllama
from scripts.chromaDB_handler import ChromaDataManager
from scripts.config import DATA_PATH
import os
import torch

In [None]:
# Settings for the chat model handle that the generation step calls via llm.invoke().
llm_settings = {
    "model": "gemma3:4b-it-qat",
    "temperature": 0,
}
llm = ChatOllama(**llm_settings)

In [None]:
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, PromptTemplate

# Wording of the single human message sent to the model. The {question} and
# {context} placeholders are filled in when the prompt is invoked.
template = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Use three sentences maximum and keep the answer concise.\n"
    "Question: {question} \nContext: {context} \nAnswer:"
)

# Placeholder names used by the template above.
input_variables = ['context', 'question']

# Build the prompt bottom-up: plain template -> human message -> chat prompt.
# The placeholder names are derived from the template text by the constructors.
prompt_template = PromptTemplate.from_template(template)
human_message_prompt = HumanMessagePromptTemplate(prompt=prompt_template)
prompt = ChatPromptTemplate.from_messages([human_message_prompt])



In [None]:
# Device string handed to ChromaDataManager: 'cuda' when torch reports a usable GPU, otherwise 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model_path = "nomic-ai/nomic-embed-text-v1"

# Handle whose search_vector_store() is queried by the retrieval step.
data_manager = ChromaDataManager(
    model_path=model_path,
    collection_name='textCollection',
    data_path=DATA_PATH,
    device=device,
)

In [None]:
from langgraph.checkpoint.memory import MemorySaver

class State(TypedDict):
    """State schema shared by the `retrieve` and `generate` graph nodes."""
    question: str  # read by retrieve() for the vector search and by generate() for the prompt
    context: str  # written by retrieve(): prior-turn transcript (if any) followed by retrieved text
    answer: str  # written by generate(): content of the model response
    tokens: int  # written by generate(): total token count taken from the response's usage metadata
    tokens_per_second: float  # written by generate(): tokens / total duration in seconds, rounded to 2 places
    history: List[dict]  # prior turns as {"question": ..., "answer": ...}; generate() appends the current turn


def retrieve(state: State):
    """Build the prompt context for the current question.

    Queries ``data_manager`` for the single best match to ``state["question"]``,
    joins the matched text in (ID, chunk_number) order, and prefixes the
    transcript of earlier turns when the state carries any.

    Args:
        state: Graph state. ``question`` is required; ``history`` is optional
            and may be missing or ``None``.

    Returns:
        ``{"context": <str>}`` - the history transcript, a blank line, then the
        retrieved text; or just the retrieved text when there is no history.
    """
    # The result is expected to be a pandas DataFrame with 'ID', 'chunk_number'
    # and 'Text' columns (that is how it is used below).
    retrieved_docs = data_manager.search_vector_store(state["question"], n_results=1)

    # Order the chunks without inplace=True so the frame handed back by
    # data_manager is left untouched.
    ordered_docs = (
        retrieved_docs
        .sort_values(by=['ID', 'chunk_number'], ascending=True)
        .reset_index(drop=True)
    )
    docs = "\n".join(ordered_docs['Text'])

    # Format past history; treat a missing or None history as "no earlier turns".
    past_turns = state.get("history") or []
    history_context = "\n".join(
        f"User: {turn['question']}\nBot: {turn['answer']}" for turn in past_turns
    )

    if history_context:
        return {"context": history_context + "\n\n" + docs}
    return {"context": docs}

def generate(state: State):
    """Answer the question with the LLM and record usage statistics.

    Args:
        state: Graph state. ``question`` and ``context`` are required;
            ``history`` is optional and may be missing or ``None``.

    Returns:
        dict with:
            ``answer`` - content of the model response,
            ``tokens`` - total token count from the usage metadata (0 if absent),
            ``tokens_per_second`` - tokens / total duration, rounded to two
            places (0.0 when the duration is absent or zero),
            ``history`` - the earlier turns plus this question/answer pair.
    """
    messages = prompt.invoke({
        "question": state["question"],
        "context": state["context"]
    })
    response = llm.invoke(messages)

    # usage_metadata is optional on the response message and the duration key
    # depends on the backend, so fall back to empty mappings instead of raising.
    usage = response.usage_metadata or {}
    meta = response.response_metadata or {}

    total_tokens = usage.get("total_tokens") or 0
    # total_duration is divided by 1e9, i.e. it is treated as nanoseconds.
    total_duration_sec = (meta.get("total_duration") or 0) / 1e9
    tokens_per_sec = round(total_tokens / total_duration_sec, 2) if total_duration_sec else 0.0

    # Update chat history; concatenation builds a new list, so the list held
    # in the incoming state is not mutated.
    updated_history = (state.get("history") or []) + [{
        "question": state["question"],
        "answer": response.content
    }]

    return {
        "answer": response.content,
        "tokens": total_tokens,
        "tokens_per_second": tokens_per_sec,
        "history": updated_history
    }


# Wire the two nodes in order: START -> retrieve -> generate.
graph_builder = StateGraph(State)
graph_builder.add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")

# Compile with a MemorySaver checkpointer attached.
graph = graph_builder.compile(checkpointer=MemorySaver())

In [None]:
def create_state(question: str) -> State:
    """Return the graph input for one question, with output fields blanked.

    "history" is not set here. NOTE(review): presumably this is deliberate so
    that a history already checkpointed for the thread is not overwritten -
    confirm before adding it.
    """
    blank_outputs = {
        "context": "",
        "answer": "",
        "tokens": 0,
        "tokens_per_second": 0.0,
    }
    return {"question": question, **blank_outputs}

# --- Utility: run with memory ---
def qa_chat(question: str, thread_id: str = "default_thread") -> State:
    """Run one question through the compiled graph under a thread id.

    Args:
        question: Text of the question to ask.
        thread_id: Value passed as ``config["configurable"]["thread_id"]``.

    Returns:
        The result of ``graph.invoke`` for this question.
    """
    return graph.invoke(
        create_state(question),
        config={"configurable": {"thread_id": thread_id}},
    )

In [None]:
response = qa_chat(
    "What are the two methods of calculating the total cost installed?",
    thread_id="cost_calc",
)

# Print the answer followed by the usage figures returned for this call.
for label, field in (
    ("Answer:", "answer"),
    ("Tokens Used:", "tokens"),
    ("Speed (tokens/sec):", "tokens_per_second"),
):
    print(label, response[field])

In [None]:
# Ask a follow-up question on the same thread_id ("cost_calc") as the first question.
response2 = qa_chat(
    "Explain the second method in more detail.",
    thread_id="cost_calc",
)

report_lines = [
    ("Follow-up:", response2["answer"]),
    ("Tokens Used:", response2["tokens"]),
    ("Speed (tokens/sec):", response2["tokens_per_second"]),
]
for label, value in report_lines:
    print(label, value)

In [None]:
# Clear memory: swap in a fresh MemorySaver so later invocations use a new checkpointer.
# NOTE(review): this replaces the checkpointer for every thread_id, not just one - confirm that is intended.
graph.checkpointer = MemorySaver()