In [75]:
import os
import glob
import numpy as np
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer
from langchain_chroma import Chroma
import gradio as gr
import time
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import plotly.graph_objects as go

from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_ollama import ChatOllama
from langchain_core.documents import Document
from typing import List
from langchain_core.retrievers import BaseRetriever
from ensemble_retriever import EnsembleRetriever

In [None]:
# 1. Divide into chunks
# First, list every file in the knowledge base (all sub-folders, any depth).
KNOWLEDGE_BASE_DIR = "knowledge-base"


def list_knowledge_base_files(root: str = KNOWLEDGE_BASE_DIR) -> list[str]:
    """Return the sorted paths of every regular file below ``root``.

    ``glob`` only expands ``**`` to "any depth" when ``recursive=True``; without it
    the pattern means exactly one sub-folder level. Directories also match ``*``,
    so they are filtered out here instead of crashing ``open`` later.
    """
    pattern = os.path.join(root, "**", "*")
    return sorted(p for p in glob.glob(pattern, recursive=True) if os.path.isfile(p))


def read_all_text(paths) -> str:
    """Concatenate the UTF-8 text of ``paths``, appending a blank line after each file."""
    parts = []
    for file_path in paths:
        with open(file_path, "r", encoding="utf-8") as f:
            parts.append(f.read())
        parts.append("\n\n")
    return "".join(parts)


folders = list_knowledge_base_files()  # despite the name, these are file paths
print(f"Found {len(folders)} files in the knowledge base")

# How many characters are in all the documents?
entire_knowledge_base = read_all_text(folders)
print(f"Total characters in knowledge base: {len(entire_knowledge_base):,}")

In [None]:
# Read in documents using LangChain's loaders.
# Each immediate sub-folder of the knowledge base is one "department"
# (per the original notes: company, employees, schools, visas).
KNOWLEDGE_BASE_DIR = "knowledge-base"
text_loader_kwargs = {"autodetect_encoding": True}


def list_department_folders(root: str = KNOWLEDGE_BASE_DIR) -> list[str]:
    """Return the sorted immediate sub-directories of ``root``.

    Plain files sitting directly under ``root`` are skipped, because
    ``DirectoryLoader`` is meant to be given a directory.
    """
    return sorted(p for p in glob.glob(os.path.join(root, "*")) if os.path.isdir(p))


def extract_entity_from_filename(file_path):
    """Return the file name without its directory and extension (e.g. drop ``.md``)."""
    filename = os.path.basename(file_path)
    return os.path.splitext(filename)[0]


folders = list_department_folders()

documents = []
for folder in folders:
    doc_type = os.path.basename(folder)  # department name = folder name
    loader = DirectoryLoader(folder, glob="**/*.md", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)
    for doc in loader.load():
        # Enrich metadata so retrieval can later filter by department / entity
        doc.metadata.update({
            "department": doc_type,                                          # group / folder
            "entity": extract_entity_from_filename(doc.metadata["source"]),  # entity name (file stem)
            "source_file": doc.metadata["source"],                           # original path
            "language": "vi",  # hard-coded: assumes the whole knowledge base is Vietnamese
        })
        documents.append(doc)

print("Total documents loaded:", len(documents))

In [None]:
# Inspect the enriched metadata attached to the first loaded document
first_document = documents[0]
first_document.metadata

In [None]:
# Divide into CHUNKS
CHUNK_SIZE = 400    # characters per chunk (splitter measures length with len())
CHUNK_OVERLAP = 80  # characters shared by consecutive chunks so facts on a boundary survive

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    # Try paragraph breaks first, then lines, sentences, words, and finally single characters
    separators=["\n\n", "\n", ". ", " ", ""],
)

chunks = text_splitter.split_documents(documents)
print(f"Created {len(chunks)} chunks")
if chunks:
    print(f"First chunk:\n\n{chunks[0]}")
else:
    print("No chunks created - check that the knowledge base loaded any documents")

In [None]:
# 2. Encode chunks and store in vector store
# Load variables from a local .env file; values in the file win over existing env vars
env_loaded = load_dotenv(override=True)
env_loaded

In [82]:
# Choose an embedding model.
# embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
#
# IMPORTANT: E5 models expect every passage to start with "passage: " and every
# search query to start with "query: ".
E5_PASSAGE_PREFIX = "passage: "
E5_QUERY_PREFIX = "query: "


class E5Embeddings(HuggingFaceEmbeddings):
    """HuggingFace embeddings that add the E5 ``"query: "`` prefix to search queries.

    LangChain vector-store retrievers embed the raw search text through
    ``embed_query``, so applying the prefix here covers every retriever built on
    this store. Text that already starts with the prefix is left unchanged, so
    callers that add it themselves do not end up with ``"query: query: ..."``.
    """

    def embed_query(self, text: str) -> list[float]:
        if not text.startswith(E5_QUERY_PREFIX):
            text = E5_QUERY_PREFIX + text
        return super().embed_query(text)


embeddings = E5Embeddings(
    model_name="intfloat/multilingual-e5-base",
    encode_kwargs={"normalize_embeddings": True},
)


def _with_passage_prefix(text: str) -> str:
    """Prepend the E5 passage prefix unless already present (keeps this cell re-runnable)."""
    return text if text.startswith(E5_PASSAGE_PREFIX) else E5_PASSAGE_PREFIX + text


chunks = [
    Document(
        page_content=_with_passage_prefix(c.page_content),
        metadata=dict(c.metadata),  # copy so old and new chunks do not share one dict
    )
    for c in chunks
]

In [83]:
# Name of the on-disk directory holding the Chroma database (any name works)
db_name = "vector_db"

# If a persisted database already exists, drop its collection so the store is rebuilt from scratch
already_persisted = os.path.exists(db_name)
if already_persisted:
    stale_store = Chroma(persist_directory=db_name, embedding_function=embeddings)
    stale_store.delete_collection()

In [None]:
# Build the Chroma vector store from the split chunks and persist it on disk
vectorstore = Chroma.from_documents(
    persist_directory=db_name,  # directory where the database is stored
    embedding=embeddings,       # HuggingFace embedding function
    documents=chunks,           # the split text chunks
)
# Sanity check: how many documents were written to the store
stored_documents = vectorstore._collection.count()
print(f"Vectorstore created with {stored_documents} documents")

In [None]:
# Grab the underlying Chroma collection (private attribute of the LangChain wrapper)
collection = vectorstore._collection

# ------ investigate our vectors ------------------
# Fetch a single stored embedding
sample = collection.get(limit=1, include=["embeddings"])["embeddings"]
if len(sample) == 0:
    raise ValueError("The vector store is empty - no embeddings to inspect")
sample_embedding = sample[0]

# Number of dimensions (elements) in each vector
dimensions = len(sample_embedding)
print(f"Embedding dimensions: {dimensions:,}")
# -------------------------------------------------
# -------------------------------------------------

In [87]:
MODEL = "llama3.2"

# Plain similarity retriever over the whole vector store (default search settings)
retriever = vectorstore.as_retriever()
# Local Ollama chat model used below for answering, routing, rewriting and re-ranking
llm = ChatOllama(model=MODEL, temperature=0.7)

In [None]:
# Baseline: ask the bare LLM, without any retrieved context
baseline_question = "Who is Lan ?"
llm.invoke(baseline_question)

In [None]:
# See which chunks the plain vector retriever returns for the same question
probe_question = "Who is Lan ?"
retriever.invoke(probe_question)

In [88]:
# System prompt for the Korea Study counselling assistant (kept in Vietnamese on purpose).
# English gist: "You are a Korea study-abroad consultant at the Korea Study center. Answer
# questions about the center, staff, schools and visa information concisely and accurately.
# Use relevant information from the provided context. If you don't know the answer, say so
# clearly; never make up information without suitable context. Context: {context}"
# `{context}` is filled via str.format with the retrieved chunks; any other literal
# braces added to this template would have to be doubled ({{ }}).
SYSTEM_PROMPT_TEMPLATE = """
Bạn là chuyên gia tư vấn du học Hàn Quốc tại trung tâm Korea Study. 
Nhiệm vụ của bạn là trả lời các câu hỏi liên quan đến trung tâm, nhân viên, trường học và thông tin visa một cách ngắn gọn và chính xác. 
Nếu có thông tin liên quan trong ngữ cảnh được cung cấp, hãy sử dụng để trả lời câu hỏi.
Nếu bạn không biết câu trả lời, hãy nói rõ rằng bạn không biết. Tuyệt đối không bịa thông tin nếu không có ngữ cảnh phù hợp.
Ngữ cảnh:
{context}
"""

In [89]:
def answer_question(question: str, history):
    """Answer ``question`` with basic RAG: plain vector retrieval plus one LLM call.

    ``history`` is accepted for chat-callback compatibility but is not used.
    The retrieved chunks are printed for debugging before the LLM is called.
    """
    retrieved = retriever.invoke(question)

    # --- Debug: show the retrieved chunks ---
    print(f"Found {len(retrieved)} chunks for question: {question}")
    for position, chunk in enumerate(retrieved, start=1):
        print(f"\n--- Chunk {position} ---\n{chunk.page_content}\n")
    # ----------------------------------------

    messages = [
        SystemMessage(content=SYSTEM_PROMPT_TEMPLATE.format(
            context="\n\n".join(chunk.page_content for chunk in retrieved)
        )),
        HumanMessage(content=question),
    ]
    return llm.invoke(messages).content

In [None]:
# Try the basic RAG pipeline on "Who is Lan?" (in Vietnamese) with an empty history
sample_question = "Lan là ai?"
answer_question(sample_question, [])

# Metadata filtering

In [94]:
# Multi-department routing
def route_retriever(question: str):
    q = question.lower()
    departments = []

    # People intent
    if any(x in q for x in ["ai", "who", "là ai", "nhân viên"]):
        departments.append("employees")

    # School intent
    if any(x in q for x in ["trường", "university", "đại học"]):
        departments.append("schools")

    # Visa intent
    if any(x in q for x in ["visa", "thị thực", "d2", "d4"]):
        departments.append("visas")

    # Company intent
    if any(x in q for x in ["korea study", "trung tâm", "công ty"]):
        departments.append("company")

    # No clear intent → search everything
    if not departments:
        return vectorstore.as_retriever(search_kwargs={"k": 6})

    # Multiple departments → broaden search
    return vectorstore.as_retriever(
        search_kwargs={
            "k": 8,
            "filter": {"department": {"$in": departments}}
        }
    )

# Hybrid search

In [95]:
# Keyword retriever (content level): lexical complement to the vector search
import re

_TOKEN_RE = re.compile(r"\w+")


def _tokenize(text: str) -> set[str]:
    """Lower-cased set of word tokens in ``text`` (Unicode-aware, so Vietnamese syllables stay whole)."""
    return set(_TOKEN_RE.findall(text.lower()))


class KeywordRetriever(BaseRetriever):
    """Rank ``documents`` by how many distinct query words they contain.

    Matching is on whole words. A plain substring test lets short query tokens
    such as "ai" match inside other words ("hai", "mai", ...), so nearly every
    chunk would come back, unranked. Only documents sharing at least one word
    with the query are returned, best first (ties keep corpus order), capped at ``k``.
    """

    documents: List[Document]
    k: int = 8  # maximum number of documents returned per query

    class Config:
        arbitrary_types_allowed = True

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        query_tokens = _tokenize(query)
        if not query_tokens:
            return []
        scored = []
        for position, doc in enumerate(self.documents):
            score = len(query_tokens & _tokenize(doc.page_content))
            if score:
                scored.append((-score, position, doc))
        # Sort on (score desc, original position) only; Documents themselves are not comparable
        scored.sort(key=lambda item: item[:2])
        return [doc for _, _, doc in scored[: self.k]]

    async def _aget_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        return self._get_relevant_documents(query)


keyword_retriever = KeywordRetriever(documents=chunks)

In [96]:
# Hybrid retriever (vector + keyword)
def hybrid_retriever_for_question(question: str):
    """Combine the keyword retriever with the routed vector retriever for ``question``.

    The two are merged by the project's ``EnsembleRetriever`` with weights
    0.3 (keyword) and 0.7 (vector).
    """
    routed_vector = route_retriever(question)
    components = [keyword_retriever, routed_vector]
    fusion_weights = [0.3, 0.7]
    return EnsembleRetriever(retrievers=components, weights=fusion_weights)

In [97]:
def answer_question(question: str, history):
    """Answer ``question`` with hybrid RAG (keyword + routed vector retrieval).

    This definition replaces any earlier ``answer_question``. ``history`` is accepted
    but unused. The retrieved chunks are printed for debugging before the LLM call.
    """
    hybrid = hybrid_retriever_for_question(question)
    hits = hybrid.invoke(question)

    # --- Debug: show the retrieved chunks ---
    print(f"Found {len(hits)} chunks for question: {question}")
    for n, hit in enumerate(hits, 1):
        print(f"\n--- Chunk {n} ---\n{hit.page_content}\n")
    # ----------------------------------------

    context = "\n\n".join(hit.page_content for hit in hits)
    reply = llm.invoke([
        SystemMessage(content=SYSTEM_PROMPT_TEMPLATE.format(context=context)),
        HumanMessage(content=question),
    ])
    return reply.content

# Evaluation

In [56]:
# Notebook setup
from test_questions import load_tests_from_jsonl
from evaluate_rag import (
    evaluate_retrieval,
    evaluate_answer_with_llm,
    RetrievalMetrics,
    AnswerMetrics
)

from collections import defaultdict
import random

In [None]:
# Test questions file: generate it (presumably the tests.jsonl loaded below) and print a summary
from test_questions import print_test_summary, save_tests_to_jsonl

save_tests_to_jsonl()
print_test_summary()

In [None]:
# Load the test set from JSONL
TESTS_PATH = "tests.jsonl"
tests = load_tests_from_jsonl(TESTS_PATH)
print(f"Loaded {len(tests)} test questions")

# Quick sanity check: display the first test case
tests[0]

In [102]:
# Evaluate retrieval quality of the hybrid retriever on every test question
RETRIEVAL_EVAL_K = 10  # number of retrieved chunks scored per question

retrieval_results = []
for test in tests:
    # Local name on purpose: assigning to `retriever` here would overwrite the
    # notebook's global baseline retriever with the last test's hybrid retriever.
    test_retriever = hybrid_retriever_for_question(test.question)
    metrics, _ = evaluate_retrieval(
        question=test.question,
        expected_keywords=test.keywords,
        retriever=test_retriever,
        k=RETRIEVAL_EVAL_K,
    )
    retrieval_results.append({
        "category": test.category,
        "mrr": metrics.mrr,
        "ndcg": metrics.ndcg,
        "coverage": metrics.keyword_coverage,
    })


In [None]:
# Summary of retrieval metrics (overall and per category)
from collections import defaultdict


def _mean(values) -> float:
    """Arithmetic mean of ``values``; 0.0 when empty instead of raising ZeroDivisionError."""
    values = list(values)
    return sum(values) / len(values) if values else 0.0


if not retrieval_results:
    print("No retrieval results - run the evaluation loop first.")

# overall
avg_mrr = _mean(r["mrr"] for r in retrieval_results)
avg_ndcg = _mean(r["ndcg"] for r in retrieval_results)
avg_cov = _mean(r["coverage"] for r in retrieval_results)

print("=== RETRIEVAL SUMMARY ===")
print(f"Avg MRR: {avg_mrr:.3f}")
print(f"Avg nDCG: {avg_ndcg:.3f}")
print(f"Avg Coverage: {avg_cov:.1f}%")

# by category
by_cat = defaultdict(list)
for r in retrieval_results:
    by_cat[r["category"]].append(r)

print("\n=== BY CATEGORY ===")
for cat, items in by_cat.items():
    print(
        f"{str(cat):15s} | "  # str(): a None category would break the ':15s' format spec
        f"MRR={_mean(i['mrr'] for i in items):.3f} | "
        f"nDCG={_mean(i['ndcg'] for i in items):.3f} | "
        f"Coverage={_mean(i['coverage'] for i in items):.1f}%"
    )

# Visualization

In [None]:
# ---------- Prepare data ----------
result = collection.get(include=["embeddings", "documents", "metadatas"])

vectors = np.array(result["embeddings"])
# Named chunk_texts (not `documents`) so the list of loaded LangChain Documents
# from the loading step is not overwritten with plain strings.
chunk_texts = result["documents"]
doc_types = [(metadata or {}).get("department", "unknown") for metadata in result["metadatas"]]

# Colour per document type (department); anything else is drawn blue
color_map = {
    "company": "gray",
    "employees": "green",
    "visas": "red",
    "schools": "orange",
}
colors = [color_map.get(t, "blue") for t in doc_types]

# ---------- PCA 3D projection ----------
# Standardize each dimension before PCA (PCA is sensitive to feature scale)
scaler = StandardScaler()
vectors_scaled = scaler.fit_transform(vectors)

pca = PCA(n_components=3, random_state=42)
reduced_vectors = pca.fit_transform(vectors_scaled)

print("Explained variance ratio:", pca.explained_variance_ratio_)

# ---------- 3D Visualization ----------
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=reduced_vectors[:, 0],
            y=reduced_vectors[:, 1],
            z=reduced_vectors[:, 2],
            mode="markers",
            marker=dict(
                size=5,
                color=colors,
                opacity=0.8,
            ),
            # Hover label: department + first 100 characters of the chunk
            text=[
                f"Loại: {t}<br>Văn bản: {d[:100]}..."
                for t, d in zip(doc_types, chunk_texts)
            ],
            hoverinfo="text",
        )
    ]
)

fig.update_layout(
    title="Biểu đồ PCA 3D của Vector Store (Debug Retrieval Space)",
    scene=dict(
        xaxis_title="PC1",
        yaxis_title="PC2",
        zaxis_title="PC3",
    ),
    width=900,
    height=700,
    margin=dict(r=10, b=10, l=10, t=40),
)

fig.show()



# Bonus: advanced RAG pipeline (query rewriting, LLM routing, LLM re-ranking)

In [104]:
# 1. Query Rewrite
def rewrite_query_llm(question: str, history=[]):
    prompt = f"""
Bạn đang chuẩn bị tìm thông tin trong knowledge base.

Lịch sử hội thoại:
{history}

Câu hỏi hiện tại:
{question}

Viết lại thành MỘT câu truy vấn ngắn, rõ ràng, cụ thể,
phù hợp để search trong knowledge base.
Chỉ trả về câu truy vấn, KHÔNG giải thích.
"""
    response = llm.invoke(prompt)
    return response.content.strip()

In [105]:
# 2. LLM metadata routing (replaces the rule-based router)
import re

from pydantic import BaseModel, ValidationError

ALLOWED_DEPARTMENTS = ("employees", "schools", "visas", "company")


class Route(BaseModel):
    departments: list[str]


def _parse_route(reply: str) -> list[str]:
    """Extract the department list from an LLM reply; ``[]`` if nothing usable.

    Tolerates code fences or prose around the JSON object, and drops names that
    are not in ALLOWED_DEPARTMENTS (an unknown name would make the metadata
    filter match nothing). Duplicates are removed, order is kept.
    """
    match = re.search(r"\{.*\}", reply, flags=re.DOTALL)
    if match is None:
        return []
    try:
        route = Route.model_validate_json(match.group(0))
    except ValidationError:
        return []
    departments = []
    for name in route.departments:
        name = name.strip().lower()
        if name in ALLOWED_DEPARTMENTS and name not in departments:
            departments.append(name)
    return departments


def llm_route(question: str):
    """Ask the LLM which departments ``question`` concerns.

    Returns a list of valid department names. Malformed model output no longer
    raises; it degrades to ``[]``, which ``retrieve_docs`` treats as "no filter".
    """
    prompt = f"""
Bạn là hệ thống định tuyến truy vấn cho knowledge base.

Các department có thể có:
- employees
- schools
- visas
- company

Câu hỏi:
{question}

Trả về JSON hợp lệ, ví dụ:
{{"departments": ["schools", "employees"]}}

KHÔNG giải thích.
"""
    response = llm.invoke(prompt)
    return _parse_route(str(response.content))


In [106]:
# 3. Reuse the existing vector store
def retrieve_docs(question: str, k=8, history=None):
    """Rewrite ``question``, route it to departments, then run a filtered vector search.

    Args:
        question: the user's question.
        k: number of chunks to retrieve.
        history: optional chat history forwarded to the query rewriter
            (defaults to None, which behaves like the previous empty history).
    Returns:
        The documents returned by the vector retriever.
    """
    rewritten = rewrite_query_llm(question, history if history is not None else [])
    print("Rewrite:", rewritten)

    departments = llm_route(rewritten)
    print("Route:", departments)

    # Only send a metadata filter when routing actually found departments
    search_kwargs = {"k": k}
    if departments:
        search_kwargs["filter"] = {"department": {"$in": departments}}
    retriever = vectorstore.as_retriever(search_kwargs=search_kwargs)

    # E5 models expect the "query: " prefix on search queries
    return retriever.invoke("query: " + rewritten)

In [None]:
import json
import re

def rerank_llm(question, chunks, max_retries=3):
    if not chunks:
        return []

    prompt = "You are a document re-ranker.\n"
    prompt += "Given a question and document chunks, return a JSON array of chunk IDs (1-based) from most relevant to least relevant.\n"
    prompt += f"Question: {question}\nChunks:\n"

    for idx, chunk in enumerate(chunks):
        prompt += f"# CHUNK ID: {idx+1}\n{chunk.page_content}\n\n"

    prompt += "Reply only with a JSON array like [1,2,3,...]."

    for attempt in range(max_retries):
        try:
            response = llm.invoke(prompt)
            
            # Handle both AIMessage and string responses
            if hasattr(response, 'content'):
                reply = response.content.strip()
            else:
                reply = str(response).strip()

            # Parse JSON
            try:
                order = json.loads(reply)
            except:
                # fallback: extract numbers
                order = [int(n) for n in re.findall(r"\d+", reply)]

            if order and all(0 < i <= len(chunks) for i in order):
                # Return unique indices only
                seen = set()
                ranked_docs = []
                for i in order:
                    if i not in seen and 0 < i <= len(chunks):
                        ranked_docs.append(chunks[i-1])
                        seen.add(i)
                return ranked_docs

        except Exception as e:
            print(f"[rerank] Retry {attempt+1}: {e}")

    print("[rerank] Fallback: returning original docs")
    return chunks

In [108]:
# 5. Answer
def answer_question_advanced(question: str, history=None, top_n: int = 5):
    """Answer ``question`` with the advanced pipeline.

    Steps: rewrite + LLM routing, retrieve 12 chunks, LLM re-rank, and put the
    best ``top_n`` in the prompt. ``history`` defaults to None (no shared mutable
    default) and is currently not used. Only the chunks that actually reach the
    prompt are printed for debugging. The E5 "passage: " indexing prefix is
    stripped from the context shown to the LLM.
    """
    docs = retrieve_docs(question, k=12)
    docs = rerank_llm(question, docs)
    context_docs = docs[:top_n]

    # --- Debug: show the chunks that go into the prompt ---
    print(f"Found {len(docs)} chunks for question: {question}")
    print(f"Using top {len(context_docs)} after re-ranking")
    for i, doc in enumerate(context_docs, start=1):
        print(f"\n--- Chunk {i} ---\n{doc.page_content}\n")
    # ------------------------------------------------------
    context = "\n\n".join(d.page_content.removeprefix("passage: ") for d in context_docs)

    system_prompt = SYSTEM_PROMPT_TEMPLATE.format(context=context)

    response = llm.invoke([
        SystemMessage(content=system_prompt),
        HumanMessage(content=question),
    ])
    return response.content

In [None]:
# Try the advanced pipeline: "What conditions does the D4 visa require?"
visa_question = "Visa D4 cần điều kiện gì?"
answer_question_advanced(visa_question)

In [None]:
# --- Chat wrapper around answer_question_advanced ---
def chat_advanced(user_message, history=None):
    """Gradio ChatInterface callback that answers with the advanced RAG pipeline.

    user_message: the user's new question.
    history: previous messages as a list of {"role": ..., "content": ...} dicts
        (managed by ChatInterface; None is treated as an empty list).
    Returns only the answer text. If the pipeline ever returns a tuple
    (answer_text, docs), its first element is used.
    """
    history = [] if history is None else history

    # Run the advanced RAG pipeline
    result = answer_question_advanced(user_message, history)

    # ChatInterface keeps the history itself; it only needs the reply text
    return result[0] if isinstance(result, tuple) else result


# --- Build and launch the chat UI ---
interface = gr.ChatInterface(
    fn=chat_advanced,
    type="messages",
    title="Welcome to Korea Study chatbot",
    description="Hỏi bất cứ điều gì về du học Hàn Quốc",
)
interface.launch(inbrowser=True)