In [None]:
!python -m pip install --upgrade pip
!python -m pip uninstall -y torch torchvision torchaudio
!python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install faiss-gpu-cu12 datasets sentence-transformers

In [None]:
import os

from huggingface_hub import login

# Never hardcode an access token in a notebook: it ends up in version control and shared copies.
# Read it from the HF_TOKEN environment variable instead; when the variable is unset, `token`
# is None and `login` falls back to its interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import torch
from sentence_transformers import SentenceTransformer

# Embedding model shared by the document-encoding and query-encoding cells below.
EMBEDDING_MODEL_ID = "google/embeddinggemma-300m"
model = SentenceTransformer(EMBEDDING_MODEL_ID)

In [None]:
from typing import Iterator, Dict, Sequence, Optional
from pathlib import Path

def read_text_file(path: Path, encoding: str = "utf-8", max_bytes: Optional[int] = 5_000_000) -> Optional[str]:
    """Best-effort read of a text file.

    Args:
        path: File to read.
        encoding: Encoding tried first; undecodable bytes are dropped.
        max_bytes: Files larger than this many bytes are skipped. ``None``
            disables the size check.

    Returns:
        The file contents, or ``None`` when the file exceeds ``max_bytes`` or
        cannot be read at all. Never raises for ordinary errors.
    """
    try:
        too_large = max_bytes is not None and path.stat().st_size > max_bytes
        if not too_large:
            return path.read_text(encoding=encoding, errors="ignore")
        return None
    except Exception:
        # Fall through to the raw-bytes attempt below.
        pass

    try:
        raw = path.read_bytes()
        return raw.decode("utf-8", errors="ignore")
    except Exception:
        return None

def chunk_text(text: str, size: int = 1000, overlap: int = 200) -> Iterator[str]:
    """Yield overlapping, fixed-size character windows over ``text``.

    Consecutive chunks start ``size - overlap`` characters apart. Iteration
    stops with the first chunk that reaches the end of ``text``, so the last
    chunk is never a fragment already fully contained in the previous one.

    Args:
        text: The string to split. An empty string yields nothing.
        size: Maximum chunk length in characters; must be > 0.
        overlap: Characters shared between neighbouring chunks; must be >= 0.
            Values >= ``size`` are clamped to ``size - 1``.

    Yields:
        Chunks of at most ``size`` characters, in order.

    Raises:
        ValueError: If ``size <= 0`` or ``overlap < 0``. Because this is a
            generator, the error surfaces on first iteration.
    """
    if size <= 0:
        raise ValueError("size must be > 0")
    if overlap < 0:
        raise ValueError("overlap must be >= 0")
    overlap = min(overlap, size - 1)  # guarantees a step of at least 1
    step = size - overlap
    total = len(text)
    for start in range(0, total, step):
        yield text[start:start + size]
        if start + size >= total:
            # This window already covers the tail; any later window would
            # only repeat text that has been emitted.
            break

def iter_file_chunks(
    base_dir: str | Path,
    patterns: Sequence[str] = ("*.txt", "*.md"),
    recursive: bool = False,
    size: int = 1000,
    overlap: int = 200,
    max_bytes: Optional[int] = 5_000_000,
    include_titles: bool = True,
) -> Iterator[Dict[str, str | int]]:
    base = Path(base_dir)
    paths = []
    for pat in patterns:
        paths.extend(base.rglob(pat) if recursive else base.glob(pat))
    paths = [p for p in paths if p.is_file()]
    for p in paths:
        text = read_text_file(p, max_bytes=max_bytes)
        if not text:
            continue
        title = p.stem if include_titles else "none"
        for idx, c in enumerate(chunk_text(text, size=size, overlap=overlap)):
            yield {"source": str(p), "title": title, "chunk": c, "chunk_index": idx}

In [None]:
# Chunk every matching file under the corpus directory and keep the records in memory.
base_dir = "/workspace/corpus"
records = [rec for rec in iter_file_chunks(base_dir=base_dir)]

print(len(records), "chunks")

# Sanity check: source path, chunk position and chunk length for the first three records.
for r in records[:3]:
    source, position, length = r["source"], r["chunk_index"], len(r["chunk"])
    print(source, position, length)

def as_doc(text, title=None):
    """Format a chunk as ``title: <title> | text: <text>``.

    A missing or empty ``title`` is rendered as the literal ``none``.
    """
    label = title or "none"
    return f"title: {label} | text: {text}"

doc_inputs = [as_doc(rec["chunk"], rec["title"]) for rec in records]

# as_doc() has already written the "title: ... | text: ..." prefix into every input.
# encode_document() applies the model's configured "document" prompt by default, which
# would prepend a second prefix; prompt="" disables that so each text is embedded as formatted.
# The device falls back to CPU when no GPU is present, matching the generator cell.
doc_embs = model.encode_document(
    doc_inputs,
    prompt="",
    batch_size=8,
    device="cuda" if torch.cuda.is_available() else "cpu",
    convert_to_numpy=True,
)

In [None]:
import numpy as np, faiss

def l2_normalize(x):
    """Scale ``x`` so that its L2 norm along axis 1 is (almost exactly) one.

    A tiny epsilon (1e-12) is added to each norm so all-zero rows divide
    cleanly and stay zero instead of producing NaNs.
    """
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / (norms + 1e-12)

# Normalise the document vectors and store them in a flat inner-product index;
# on unit-length vectors the inner product equals cosine similarity.
emb = l2_normalize(doc_embs.astype(np.float32))
dim = emb.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(emb)

def retrieve(q, k=5):
    """Return the top-``k`` indexed chunks for query ``q``.

    Args:
        q: Query string, embedded with ``model.encode_query``.
        k: Maximum number of hits; capped at the number of indexed vectors.

    Returns:
        A list of ``(record_index, score)`` tuples as plain ``int``/``float``,
        in the order the index returns them. The list is empty when ``k <= 0``
        or the index holds no vectors.
    """
    k = min(k, index.ntotal)
    if k <= 0:
        # Nothing to search for: skip encoding the query and avoid asking the
        # index for zero neighbours.
        return []
    qv = model.encode_query(q).astype(np.float32)[None, :]  # add a batch dimension
    qv = l2_normalize(qv)
    D, I = index.search(qv, k)
    # -1 ids mark unfilled result slots; drop them.
    return [(int(i), float(d)) for i, d in zip(I[0], D[0]) if i != -1]

# Example query: print score, title and a one-line preview of each hit.
q = "Type your query here"
hits = retrieve(q, k=5)
for doc_id, score in hits:
    hit = records[doc_id]
    preview = hit["chunk"][:160].replace("\n", " ")
    print(round(score, 4), hit["title"], preview)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "google/gemma-2-2b-it"
tok = AutoTokenizer.from_pretrained(model_id)

# Half precision on a GPU, full precision on CPU.
if torch.cuda.is_available():
    gen_device, gen_dtype = "cuda", torch.float16
else:
    gen_device, gen_dtype = "cpu", torch.float32

gen = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=gen_dtype)
gen = gen.to(gen_device)
gen.eval()

In [None]:
def build_prompt(question, context_chunks):
    """Assemble the generation prompt.

    The context chunks are separated by blank lines and placed under a
    ``Context:`` heading, followed by the question and a trailing ``Answer:`` cue.
    """
    joined = "\n\n".join(context_chunks)
    return (
        "Use the context to answer the question.\n\n"
        f"Context:\n{joined}\n\n"
        f"Question: {question}\nAnswer:"
    )

topk = 3
hits = retrieve(q, k=topk)
ctx_chunks = [records[i]["chunk"] for i, _ in hits]
prompt = build_prompt(q, ctx_chunks)

# The generator is an instruction-tuned chat model, so send the prompt as a user turn
# through the tokenizer's chat template instead of tokenizing the raw string.
# add_generation_prompt=True appends the marker that opens the model's reply.
messages = [{"role": "user", "content": prompt}]
inputs = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(gen.device)

with torch.no_grad():
    out = gen.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )

# generate() returns prompt + continuation; slice off the prompt tokens before decoding.
prompt_len = inputs["input_ids"].shape[1]
answer = tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
print(f"Answer: {answer}")