In [4]:
import json
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from pathlib import Path

# Directory (relative to the notebook's working directory) holding the artifacts below.
OUT_DIR = Path("..") / "data"
# FAISS index file; RAGService rewrites it after rebuilding the index in memory.
INDEX_PATH = OUT_DIR.joinpath("pdf_index.faiss")
# NumPy array of chunk embeddings, loaded by RAGService.
EMBEDDINGS_PATH = OUT_DIR.joinpath("pdf_embeddings.npy")
# JSON list of chunk records, loaded by RAGService; must be in the same order as the embeddings.
CHUNKS_PATH = OUT_DIR.joinpath("pdf_chunks.json")

def l2_normalize(a, axis=1, eps=1e-10):
    """Scale vectors in ``a`` to (approximately) unit Euclidean length.

    Args:
        a: Array-like of vectors.
        axis: Axis along which each vector's norm is computed (default 1,
            i.e. one vector per row of a 2-D array).
        eps: Small constant added to every norm so that all-zero vectors
            do not cause a division by zero (they stay all-zero).

    Returns:
        Array of the same shape as ``a`` with each vector divided by
        ``norm + eps``.
    """
    lengths = np.linalg.norm(a, axis=axis, keepdims=True)
    # keepdims=True keeps the reduced axis so the division broadcasts per vector.
    denominator = lengths + eps
    return np.divide(a, denominator)

class RAGService:
    """Retrieve PDF chunks by cosine similarity and answer from them.

    On construction the service loads the chunk records and their
    pre-computed embeddings from disk, L2-normalizes the embeddings and
    builds an in-memory FAISS inner-product index over them. Answers are
    produced either extractively (best-matching sentence of the top chunk)
    or, optionally, by a Hugging Face ``text-generation`` pipeline.
    """

    # Punctuation stripped from the edges of whitespace-separated tokens before
    # keyword matching, so that e.g. "đâu?" in a query can match "đâu" in a chunk.
    _EDGE_PUNCT = ".,;:!?\"'()[]{}<>…“”‘’«»-–—"

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2",
                 generator_model=None):  # generator_model=None => use the extractive fallback
        """Load the embedding model, chunks and embeddings, and build the index.

        Args:
            model_name: SentenceTransformer model used to embed queries. It
                must produce vectors of the same dimension as the stored
                embeddings (checked at query time in ``retrieve``).
            generator_model: Optional model name for a ``text-generation``
                pipeline. ``None`` (or any falsy value) disables generation.

        Raises:
            ValueError: If the embeddings are not a 2-D array or their row
                count differs from the number of chunks.

        Side effects:
            Writes the freshly built index to ``INDEX_PATH`` and prints a
            readiness message.
        """
        # embed model
        self.embed_model = SentenceTransformer(model_name)

        # load chunks (order must match the embeddings file);
        # expected to be a list of dicts with id, page, text
        with open(CHUNKS_PATH, 'r', encoding='utf-8') as f:
            self.chunks = json.load(f)

        # load embeddings, expected shape (N, D)
        embeddings = np.load(EMBEDDINGS_PATH)

        # sanity checks
        if embeddings.ndim != 2:
            raise ValueError(f"Embeddings must be a 2-D (N, D) array, got shape {embeddings.shape}")
        if len(embeddings) != len(self.chunks):
            raise ValueError(f"Embeddings length ({len(embeddings)}) != chunks ({len(self.chunks)}) — check you saved them in the same order")

        # FAISS works on float32
        if embeddings.dtype != np.float32:
            print("[WARN] embeddings dtype != float32, converting.")
            embeddings = embeddings.astype("float32")

        # Normalize so that inner product on the index equals cosine similarity.
        # ascontiguousarray guarantees the C-contiguous float32 layout FAISS expects.
        self.embeddings = np.ascontiguousarray(l2_normalize(embeddings), dtype=np.float32)
        dim = self.embeddings.shape[1]

        # Build a new FAISS index in memory from the loaded embeddings rather than
        # reading a stored one, to avoid index/embedding mismatch issues.
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(self.embeddings)
        # Persist the rebuilt index (overwrites any existing file at INDEX_PATH).
        faiss.write_index(self.index, str(INDEX_PATH))

        # generator: optional. If provided, it MUST be a causal/generation-capable model.
        if generator_model:
            # NOTE: pass device_map or device if needed. Use a known causal Vietnamese model.
            self.generator = pipeline('text-generation', model=generator_model)
        else:
            self.generator = None

        print("✅ RAGService ready. #chunks:", len(self.chunks), "dim:", dim)

    def retrieve(self, query, top_k=5):
        """Return up to ``top_k`` chunks most similar to ``query``.

        Args:
            query: Query string to embed and search for.
            top_k: Maximum number of results. Values <= 0 yield an empty list.

        Returns:
            List of dicts with keys ``id``, ``page``, ``text`` (taken from the
            chunk record, ``None`` if missing) and ``score`` (inner product of
            the normalized vectors, higher = closer), ordered as FAISS returns
            them.

        Raises:
            ValueError: If the query embedding dimension differs from the
                index dimension (embedding model and stored embeddings do
                not match).
        """
        # FAISS rejects k <= 0; treat it as "no results requested".
        if top_k <= 0:
            return []
        # Never ask for more neighbours than the index holds.
        k = min(int(top_k), int(self.index.ntotal))
        if k == 0:
            return []

        q_emb = self.embed_model.encode([query], convert_to_numpy=True)
        if q_emb.dtype != np.float32:
            q_emb = q_emb.astype("float32")
        q_emb = np.ascontiguousarray(l2_normalize(q_emb), dtype=np.float32)
        if q_emb.shape[1] != self.index.d:
            raise ValueError(
                f"Query embedding dim ({q_emb.shape[1]}) != index dim ({self.index.d}) "
                "— the embedding model does not match the stored embeddings"
            )

        scores, indices = self.index.search(q_emb, k)  # inner product scores (higher = closer)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            idx = int(idx)
            # FAISS pads missing neighbours with -1; also guard the upper bound.
            if not 0 <= idx < len(self.chunks):
                continue
            chunk = self.chunks[idx]
            results.append({
                "id": chunk.get("id"),
                "page": chunk.get("page"),
                "text": chunk.get("text"),
                "score": float(score)  # cosine-ish similarity in [-1,1]
            })
        return results

    @classmethod
    def _tokens(cls, text):
        """Lower-cased whitespace tokens of ``text``, edge punctuation removed, length > 1."""
        stripped = (w.strip(cls._EDGE_PUNCT).lower() for w in text.split())
        return {tok for tok in stripped if len(tok) > 1}

    @classmethod
    def _best_sentence(cls, query, text):
        """Return the '.'-delimited sentence of ``text`` sharing the most tokens with ``query``.

        Returns ``None`` when no sentence shares any token with the query;
        ties are resolved in favour of the earliest sentence.
        """
        q_words = cls._tokens(query)
        if not q_words:
            return None
        best_sent, best_count = None, 0
        for sentence in (part.strip() for part in text.split('.')):
            if not sentence:
                continue
            count = len(q_words & cls._tokens(sentence))
            if count > best_count:
                best_sent, best_count = sentence, count
        return best_sent

    def generate_or_extract(self, query, top_k=5, use_generator=False,
                            score_threshold=0.5, debug=True):
        """Answer ``query`` from the retrieved chunks.

        The extractive path is taken when the top score reaches
        ``score_threshold`` or when ``use_generator`` is False; otherwise the
        generator is called if one was configured, else the start of the top
        chunk is returned.

        Args:
            query: Question to answer.
            top_k: Number of chunks to retrieve.
            use_generator: Allow the generator for low-scoring retrievals.
            score_threshold: Top score at or above which the extractive
                answer is returned even if ``use_generator`` is True.
            debug: Print the retrieved chunks (on by default).

        Returns:
            Dict with ``answer``, ``sources`` (the retrieved list) and
            ``method`` ("none", "extractive", "generator" or
            "extractive (no generator)"). Whenever at least one chunk was
            retrieved the dict also carries the top ``score``.
        """
        retrieved = self.retrieve(query, top_k=top_k)
        if debug:
            print("DEBUG: Retrieved top-k:")
            for r in retrieved:
                # Chunk records without "text" come back as None; show them as empty.
                preview = (r.get('text') or "")[:120]
                print(f" - id={r['id']} page={r['page']} score={r['score']:.4f} text_preview={preview!r}")

        if not retrieved:
            return {"answer": "", "sources": retrieved, "method": "none"}

        top = retrieved[0]
        top_text = top.get("text") or ""
        # Used whenever no better answer can be extracted or generated.
        preview_answer = top_text[:400].strip()

        if top['score'] >= score_threshold or not use_generator:
            # Prefer the sentence sharing the most keywords with the query; if no
            # sentence shares any, fall back to the start of the chunk.
            answer = self._best_sentence(query, top_text) or preview_answer
            return {"answer": answer.strip(), "sources": retrieved, "method": "extractive", "score": top['score']}

        # else call generator (if available)
        if self.generator:
            context = " ".join((r.get("text") or "") for r in retrieved)
            prompt = f"Ngữ cảnh: {context}\nCâu hỏi: {query}\nTrả lời ngắn gọn, chính xác bằng tiếng Việt:"
            # deterministic generation
            gen = self.generator(prompt, max_new_tokens=128, do_sample=False, num_return_sequences=1)[0]['generated_text']
            # remove prompt echo if present
            if gen.startswith(prompt):
                gen = gen[len(prompt):].strip()
            return {"answer": gen.strip(), "sources": retrieved, "method": "generator", "score": top['score']}

        return {"answer": preview_answer, "sources": retrieved, "method": "extractive (no generator)", "score": top['score']}


In [6]:
# Build the service without a generator for now (extractive answers only).
svc = RAGService(generator_model=None)
res = svc.generate_or_extract("Nhà thờ Đức Bà ở đâu?", top_k=5)

# Report which strategy produced the answer, then the answer itself.
print(res['method'], res['answer'], sep="\n")

# Page and similarity score of the three best-ranked sources.
for s in res['sources'][:3]:
    print(s['page'], s['score'])


✅ RAGService ready. #chunks: 3090 dim: 384
DEBUG: Retrieved top-k:
 - id=2152 page=698 score=0.7710 text_preview='nam định'
 - id=454 page=153 score=0.7684 text_preview='gùi đan hoa văn bằng nan nhu ộm đen.'
 - id=806 page=259 score=0.7667 text_preview='đã có'
 - id=2164 page=702 score=0.7641 text_preview='cầu nam định'
 - id=1762 page=572 score=0.7599 text_preview='ngày nay đang được đầu tư khôi thành phố hồ chí minh'
extractive
nam định
698 0.7709580063819885
153 0.7683683633804321
259 0.7667286396026611
