In [3]:
# (รันครั้งแรกใน Colab ถ้ายังไม่ได้ติดตั้ง)
!pip install -q sentence-transformers faiss-cpu transformers PyPDF2 nltk

# Python imports + device
import os, re, time, json
import numpy as np
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from PyPDF2 import PdfReader
import nltk
from nltk.tokenize import sent_tokenize

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cpu


In [4]:
def read_pdf(path):
    reader = PdfReader(path)
    pages = []
    for p in reader.pages:
        try:
            pages.append(p.extract_text() or "")
        except:
            pages.append("")
    return "\n".join(pages)

def read_text(path):
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()


In [5]:
# NLTK ensure + sentence chunker (อังกฤษ)
def ensure_nltk():
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt')
    try:
        nltk.data.find('tokenizers/punkt_tab')
    except LookupError:
        try:
            nltk.download('punkt_tab')
        except:
            pass

def text_to_chunks_by_sentence(text, max_words=300, overlap_words=60):
    ensure_nltk()
    sents = sent_tokenize(text)
    chunks = []
    cur, cur_len = [], 0
    for s in sents:
        w = s.split()
        if cur_len + len(w) > max_words and cur:
            chunks.append(" ".join(cur).strip())
            if overlap_words > 0:
                tail = " ".join(" ".join(cur).split()[-overlap_words:])
                cur = tail.split(); cur_len = len(cur)
            else:
                cur, cur_len = [], 0
        cur.extend(w); cur_len += len(w)
    if cur:
        chunks.append(" ".join(cur).strip())
    return chunks

# Lightweight splitter (fallback / ภาษาไทย)
def split_sentences_light(text):
    return [s.strip() for s in re.split(r'[.!?]\s+|\n+', text) if s.strip()]

def text_to_chunks_light(text, max_words=250, overlap_words=50):
    sents = split_sentences_light(text)
    chunks = []
    cur, cur_len = [], 0
    for s in sents:
        w = s.split()
        if cur_len + len(w) > max_words and cur:
            chunks.append(" ".join(cur).strip())
            if overlap_words > 0:
                tail = " ".join(" ".join(cur).split()[-overlap_words:])
                cur = tail.split(); cur_len = len(cur)
            else:
                cur, cur_len = [], 0
        cur.extend(w); cur_len += len(w)
    if cur:
        chunks.append(" ".join(cur).strip())
    return chunks


In [7]:
# โหลด embedder (เบาและเร็ว)
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
print("Embedder loaded:", embed_model.__class__.__name__)

def build_faiss_index(chunks, embedder, batch_size=32):
    vecs = []
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i+batch_size]
        em = embedder.encode(batch, convert_to_numpy=True, show_progress_bar=False)
        vecs.append(em)
    vecs = np.vstack(vecs).astype('float32')
    faiss.normalize_L2(vecs)
    d = vecs.shape[1]
    index = faiss.IndexFlatIP(d)   # inner product on normalized vectors => cosine similarity
    index.add(vecs)
    return index, vecs


Embedder loaded: SentenceTransformer


In [8]:
def retrieve(index, embedder, chunks, query, top_k=10):
    if index.ntotal == 0: return []
    qvec = embedder.encode([query], convert_to_numpy=True).astype('float32')
    faiss.normalize_L2(qvec)
    k = min(top_k, index.ntotal)
    D, I = index.search(qvec, k)
    results = []
    for score, idx in zip(D[0], I[0]):
        if idx < 0 or idx >= len(chunks): continue
        results.append((chunks[idx], float(score), int(idx)))
    return results

def auto_top_k(num_chunks):
    k = int(np.log2(max(2, num_chunks)) * 2)
    return max(5, min(k, 50))

def filter_by_similarity(results, threshold=0.2):
    return [r for r in results if r[1] >= threshold]

def expand_queries(query):
    q = query.strip()
    return list(dict.fromkeys([
        q,
        q.lower(),
        q.replace("what", "").strip(),
        "Explain " + q,
        q + " details"
    ]))

def multi_query_retrieve(query, index, embedder, chunks, top_k):
    qs = expand_queries(query)
    all_results = []
    for q in qs:
        all_results.extend(retrieve(index, embedder, chunks, q, top_k))
    # merge unique idx keeping highest score
    unique = {}
    for txt, score, idx in all_results:
        if idx not in unique or score > unique[idx][1]:
            unique[idx] = (txt, score)
    merged = [(t, s, i) for i,(t,s) in unique.items()]
    merged.sort(key=lambda x: x[1], reverse=True)
    return merged


In [9]:
def mmr_rerank(query, results, embed_model, top_k=5, lambda_param=0.6):
    """
    results: [(text, score, idx), ...] ordered by score
    returns: list of original chunk indices selected
    """
    if not results: return []
    query_vec = embed_model.encode([query], convert_to_numpy=True).astype('float32')
    cand_texts = [r[0] for r in results]
    cand_vecs = embed_model.encode(cand_texts, convert_to_numpy=True).astype('float32')
    selected = []
    not_selected = list(range(len(cand_texts)))
    first = int(np.argmax(np.dot(cand_vecs, query_vec.T)))
    selected.append(first); not_selected.remove(first)
    while len(selected) < min(top_k, len(cand_texts)):
        mmr_scores = []
        for i in not_selected:
            relevance = float(np.dot(cand_vecs[i], query_vec.T))
            diversity = max(float(np.dot(cand_vecs[i], cand_vecs[j])) for j in selected)
            score = lambda_param * relevance - (1 - lambda_param) * diversity
            mmr_scores.append((score, i))
        mmr_scores.sort(reverse=True)
        chosen = mmr_scores[0][1]
        selected.append(chosen); not_selected.remove(chosen)
    return [results[i][2] for i in selected]


In [10]:
FALLBACK_SEQ2SEQ = "google/flan-t5-small"

def load_llm_safe(model_id=None, hf_token=None):
    if model_id:
        try:
            tok = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_token)
            m = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=hf_token)
            if getattr(m.config, "pad_token_id", None) is None:
                m.config.pad_token_id = m.config.eos_token_id
                tok.pad_token = tok.eos_token
            m.to(device)
            return tok, m, "causal"
        except Exception as e:
            print("Causal load failed:", e)
    # fallback seq2seq
    tok = AutoTokenizer.from_pretrained(FALLBACK_SEQ2SEQ)
    m = AutoModelForSeq2SeqLM.from_pretrained(FALLBACK_SEQ2SEQ)
    if getattr(m.config, "pad_token_id", None) is None:
        m.config.pad_token_id = m.config.eos_token_id
        tok.pad_token = tok.eos_token
    m.to(device)
    return tok, m, "seq2seq"

def rag_generate(prompt, tokenizer, model, model_type='causal', max_new_tokens=200, temperature=0.2):
    inputs = tokenizer(prompt, return_tensors='pt', truncation=True).to(device)
    if model_type == 'causal':
        out = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=True,
            temperature=temperature, top_p=0.95,
            pad_token_id=getattr(model.config, "pad_token_id", None),
            eos_token_id=getattr(model.config, "eos_token_id", None),
            no_repeat_ngram_size=3, early_stopping=True
        )
        text = tokenizer.decode(out[0], skip_special_tokens=True)
        return text[len(prompt):].strip() if text.startswith(prompt) else text
    else:
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        text = tokenizer.decode(out[0], skip_special_tokens=True)
        return text


In [11]:
def build_rag_prompt(query, chunks, indices, role="expert", max_context_chars=2000):
    context = ""
    for idx in indices:
        context += f"[chunk {idx}]\n" + chunks[idx].strip() + "\n\n---\n\n"
        if len(context) > max_context_chars:
            break
    prompt = f"""You are an {role}. Answer ONLY using the context below.
If the answer cannot be found in the context, reply exactly: "I don't know".

Include citations by stating [chunk number] when referring to context.
Keep your answer concise.

Context:
{context}
Question: {query}

Answer:"""
    return prompt


In [12]:
def rag_answer(query, doc_path, lang="en", model_id=None, hf_token=None,
               chunk_max_words=None, chunk_overlap=60, top_k=None,
               sim_threshold=0.18, mmr_k=4, mmr_lambda=0.6, debug=False):
    t0 = time.time()
    # load doc
    if doc_path.lower().endswith(".pdf"):
        text = read_pdf(doc_path)
    else:
        text = read_text(doc_path)
    if debug: print("doc chars:", len(text))
    # chunk
    if chunk_max_words is None:
        chunk_max_words = 400 if lang=="en" else 200
    if lang=="en":
        chunks = text_to_chunks_by_sentence(text, max_words=chunk_max_words, overlap_words=chunk_overlap)
    else:
        chunks = text_to_chunks_light(text, max_words=chunk_max_words, overlap_words=chunk_overlap)
    if debug: print("num chunks:", len(chunks))
    # embed & index
    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    if debug: print("embedder loaded")
    index, vecs = build_faiss_index(chunks, embed_model, batch_size=32)
    if debug: print("index size:", index.ntotal)
    # top_k
    if top_k is None: top_k = auto_top_k(len(chunks))
    if debug: print("top_k:", top_k)
    # retrieval
    raw = multi_query_retrieve(query, index, embed_model, chunks, top_k)
    if debug: print("raw retrieved:", len(raw))
    # filter
    filtered = filter_by_similarity(raw, threshold=sim_threshold)
    if debug: print("filtered:", len(filtered))
    if not filtered:
        return {"answer":"I don't know","selected":[], "snippets":[], "prompt":None, "time": time.time()-t0}
    # mmr
    sel_indices = mmr_rerank(query, filtered, embed_model, top_k=mmr_k, lambda_param=mmr_lambda)
    if debug: print("sel indices:", sel_indices)
    # prompt
    prompt = build_rag_prompt(query, chunks, sel_indices, role="expert")
    if debug: print("prompt len:", len(prompt))
    # load llm & generate
    tokenizer, model, model_type = load_llm_safe(model_id, hf_token)
    if debug: print("loaded llm:", model_type)
    answer = rag_generate(prompt, tokenizer, model, model_type=model_type, max_new_tokens=200, temperature=0.2)
    snippets = [{"idx": idx, "text": chunks[idx]} for idx in sel_indices]
    return {"answer":answer, "selected":sel_indices, "snippets":snippets, "prompt":prompt, "time":time.time()-t0}


In [14]:
# ตัวอย่าง (อังกฤษ .txt)
res = rag_answer("What are the key components of AI?", "/content/20241125-Generative-AI-Guideline_V2-0.pdf", lang="en", debug=True)
print("Answer:\n", res["answer"])
print("Selected indices:", res["selected"])
for s in res["snippets"]:
    print("---- chunk", s["idx"])
    print(s["text"][:400].replace("\n"," "), "\n")

# ตัวอย่าง (ไทย .pdf)
res_th = rag_answer("องค์ประกอบสำคัญของ AI มีอะไรบ้าง?", "/content/20241125-Generative-AI-Guideline_V2-0.pdf", lang="th", debug=True)
print("Answer (th):\n", res_th["answer"])


doc chars: 59518


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


num chunks: 11
embedder loaded
index size: 11
top_k: 6
raw retrieved: 7
filtered: 7


  relevance = float(np.dot(cand_vecs[i], query_vec.T))


sel indices: [10, 3, 9, 2]
prompt len: 2913


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

loaded llm: seq2seq
Answer:
 [chunk number]
Selected indices: [10, 3, 9, 2]
---- chunk 10
2024). Generative AI examples . Google. https://cloud.google.com/use -cases/generative -ai?hl=en Google Cloud. ( 2024). Grounding . Google. https://cloud.google.com/vertex - ai/generative -ai/docs/model -reference/grounding HM Government. ( 2024). Generative AI framework for HM Government . HM Government . https://assets.publishing.service.gov.uk/media/65c3b5d628a4a00012d2b a5c/6.8558_CO_Generativ 

---- chunk 3
บนพื้นฐำน ข้อมูลที่ได้รับ จำกกำรฝึกฝน หรือจำกสภำพแวดล้อม มำกกว่ำกำร ท ำงำนตำม โปรแกรมที่มนุษย์ก ำหนด Deep Learning (DL) AI ประเภท Machine Learning ที่ประมวลผลข้อมูล ขนำดใหญ่ ผ่ำนโครงข่ำยประสำทเทียม (Artificial Neural Network : ANN) ซึ่งเลียนแบบกำรท ำงำน จำกสมองของมนุษย์ Generative AI AI ประเภท Deep Leaning ที่มีควำมสำมำรถในกำรสร้ำงสรรค์ เนื้อหำใหม่ใ นหลำกหลำยรูปแบบ ทั้งข้อควำม ภำพ วิดีโอ หรือ รูปแ 

---- chunk 9
ไปใช้งำนร่วมกับแอปพลิเคชันหรือบริกำร Generative AI 3) ในกรณีที่ต้องใช้แอปพลิเค

In [16]:
import json, faiss

def save_index_and_chunks(index, chunks, index_path="faiss_index.bin", chunks_path="chunks.json"):
    faiss.write_index(index, index_path)
    with open(chunks_path, "w", encoding="utf-8") as f:
        json.dump(chunks, f, ensure_ascii=False)

def load_index_and_chunks(index_path="faiss_index.bin", chunks_path="chunks.json"):
    index = faiss.read_index(index_path)
    with open(chunks_path, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    return index, chunks
