# Module 3 — Embeddings & Vector Store (FAISS)

## Objectives
- Understand embeddings and **semantic search**. 💡
- Generate embeddings with:
  - **Local**: `sentence-transformers/all-MiniLM-L6-v2` (free, CPU-friendly).
  - **Cloud**: Gemini `text-embedding-004` (optional, quota-limited).
- Build and query a **FAISS vector index** (cosine similarity).
- **Persist** the index and metadata locally and reload it.
- Optional: **MMR re-ranking**; optional answer synthesis with Gemini. 🤖

## Prerequisites
- Run **Module 2** and produce JSONL chunks in `data/outputs`.
- For Gemini embeddings/answering: set `GOOGLE_API_KEY` in your environment.

### Install dependencies if needed:

In [None]:
import sys
!{sys.executable} -m pip install -q faiss-cpu sentence-transformers tqdm
!{sys.executable} -m pip install -q google-generativeai  # optional: only needed for Gemini embeddings/answering

In [None]:
# %pip install -q faiss-cpu sentence-transformers tqdm

# Optional for Gemini
# %pip install -q google-generativeai

import os
import sys
import json
import time
import math
import glob
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
from tqdm.notebook import tqdm

# The `faiss-cpu` and `faiss-gpu` pip packages both install the module `faiss`
try:
    import faiss
except ImportError as e:
    raise ImportError("FAISS is not installed; run `pip install faiss-cpu` (see the cell above).") from e

print(f"Python: {sys.version.split()[0]}")
print(f"FAISS: {getattr(faiss, '__version__', 'unknown')}")

## Configuration
- `INPUT_CHUNKS`: path to Module 2 output JSONL (combined or per-doc).
- `EMBEDDING_BACKEND`: `"minilm"` (local) or `"gemini"` (cloud).
- `INDEX_DIR`: folder to persist FAISS index and metadata.
- `BATCH_SIZE`: adjust to trade speed vs memory.
- `MMR`: optional re-ranking for diversity.
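`BATCH_SIZE` is only the slice length used when texts are fed to the encoder. A minimal sketch of that slicing, using a hypothetical helper `iter_batches` (not part of sentence-transformers or the Gemini SDK):

```python
from typing import Iterator, List, Sequence

def iter_batches(items: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most `batch_size` items (hypothetical helper)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])

# 150 texts with BATCH_SIZE = 64 -> batches of 64, 64 and 22
sizes = [len(b) for b in iter_batches([f"chunk {i}" for i in range(150)], 64)]
print(sizes)  # [64, 64, 22]
```

Larger batches usually mean fewer calls and better throughput, at the cost of peak memory (local model) or request size (cloud API).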

In [None]:
INPUT_CHUNKS = Path("./data/outputs/all_chunks.jsonl")   # change if needed
INDEX_DIR = Path("./data/indexes/my_corpus")
INDEX_DIR.mkdir(parents=True, exist_ok=True)

EMBEDDING_BACKEND = "minilm"  # "minilm" or "gemini"
BATCH_SIZE = 64
MAX_TEXT_LEN = 4000  # char limit per chunk for embedding API safety

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")

## Load chunks from Module 2

Expects JSONL lines with:
`id`, `text`, `metadata: {doc_id, page, section_title, section_path, ...}`
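A quick schema check before embedding can catch records that Module 2 produced in a different shape. A minimal sketch; `validate_chunk` is hypothetical, the sample record is made up, and treating `doc_id`/`page`/`section_title` as expected keys is an assumption based on the fields this notebook reads for citations:

```python
import json
from typing import Any, Dict, List

# Assumption: the metadata keys this module reads for citations
EXPECTED_METADATA_KEYS = ("doc_id", "page", "section_title")

def validate_chunk(record: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in one JSONL record (empty list means OK)."""
    problems = []
    if not isinstance(record.get("id"), str) or not record["id"]:
        problems.append("missing or empty 'id'")
    if not isinstance(record.get("text"), str) or not record["text"].strip():
        problems.append("missing or empty 'text'")
    metadata = record.get("metadata", {})
    if not isinstance(metadata, dict):
        problems.append("'metadata' is not an object")
    else:
        for key in EXPECTED_METADATA_KEYS:
            if key not in metadata:
                problems.append(f"metadata missing '{key}'")
    return problems

# Made-up record for illustration
line = '{"id": "contract-0001", "text": "Either party may terminate...", "metadata": {"doc_id": "contract", "page": 3}}'
print(validate_chunk(json.loads(line)))  # ["metadata missing 'section_title'"]
```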

In [None]:
def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    items = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                items.append(json.loads(line))
    return items

assert INPUT_CHUNKS.exists(), f"Missing chunks file: {INPUT_CHUNKS}. Please run Module 2 first."
chunks = load_jsonl(INPUT_CHUNKS)
print(f"Loaded {len(chunks)} chunks from {INPUT_CHUNKS}")

print("\n--- Inspecting a sample ---")
for r in chunks[:3]:
    print(r["id"], "| page:", r["metadata"].get("page"), "| section:", r["metadata"].get("section_title"))
    print(r["text"][:140].replace("\n"," "), "...\n")

## Embedding backends
- **MiniLM (local)**: 384-dim, fast on CPU, no API cost.
- **Gemini (cloud)**: `text-embedding-004` (768-dim), requires `GOOGLE_API_KEY`.

We’ll normalize embeddings to unit length and use FAISS `IndexFlatIP` (inner product) to emulate cosine similarity.
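Why this works: cosine similarity is $\cos(a,b) = \frac{a \cdot b}{\lVert a\rVert\,\lVert b\rVert}$, so once every vector has norm 1 the denominator is 1 and the inner product that `IndexFlatIP` computes is exactly the cosine. A small NumPy check on random 384-dim vectors (the MiniLM size):

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=384), rng.normal(size=384)

# Cosine similarity computed directly
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

# Inner product after L2-normalizing both vectors (what IndexFlatIP scores)
a_unit, b_unit = a / np.linalg.norm(a), b / np.linalg.norm(b)
inner_product = a_unit @ b_unit

print(np.isclose(cosine, inner_product))  # True
```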

In [None]:
class MiniLMBackend:
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        from sentence_transformers import SentenceTransformer
        import torch
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(model_name, device=device)
        self.model_name = model_name
        self.dim = self.model.get_sentence_embedding_dimension()

    def encode(self, texts: List[str], batch_size: int = 64) -> np.ndarray:
        # SentenceTransformer returns np.ndarray float32 by default
        return self.model.encode(texts, batch_size=batch_size, convert_to_numpy=True, normalize_embeddings=False)

class GeminiBackend:
    def __init__(self, model_name: str = "models/text-embedding-004", api_key: Optional[str] = None):
        assert api_key, "GOOGLE_API_KEY is required for Gemini embeddings."
        import google.generativeai as genai
        genai.configure(api_key=api_key)
        self.genai = genai
        self.model_name = model_name
        self.dim = 768  # text-embedding-004 output dimension

    def encode(self, texts: List[str], batch_size: int = 100,
               task_type: str = "retrieval_document") -> np.ndarray:
        vecs: List[List[float]] = []
        # Pass a list of strings per call so each request embeds a whole batch
        for i in tqdm(range(0, len(texts), batch_size), desc="Gemini embedding"):
            batch = texts[i:i+batch_size]
            # Truncate overly long texts before sending them to the API
            batch_truncated = [t[:MAX_TEXT_LEN] if MAX_TEXT_LEN and len(t) > MAX_TEXT_LEN else t for t in batch]
            resp = self.genai.embed_content(model=self.model_name, content=batch_truncated, task_type=task_type)
            vecs.extend(resp["embedding"])
        return np.asarray(vecs, dtype=np.float32)

    def encode_query(self, text: str) -> np.ndarray:
        # Gemini distinguishes queries ("retrieval_query") from indexed documents ("retrieval_document")
        return self.encode([text], batch_size=1, task_type="retrieval_query")

def get_backend(name: str):
    if name.lower() == "minilm":
        be = MiniLMBackend()
        print(f"Using MiniLM backend: {be.model_name}, dim={be.dim}")
        return be
    elif name.lower() == "gemini":
        be = GeminiBackend(api_key=GOOGLE_API_KEY)
        print(f"Using Gemini backend: {be.model_name}, dim={be.dim}")
        return be
    else:
        raise ValueError("EMBEDDING_BACKEND must be 'minilm' or 'gemini'.")

backend = get_backend(EMBEDDING_BACKEND)

## Prepare texts and compute embeddings
- Keep a `docstore` list aligned row-for-row with the vectors: entry `i` holds the `id`, `text`, and `metadata` of vector `i`, used for retrieval results and citations.
- Normalize embeddings to unit length for cosine search via inner product.
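Because the docstore is a plain list in the same order as the embedding rows, the integer ids FAISS returns are simply list positions. A minimal sketch of that lookup; `hits_to_records` and `toy_store` are hypothetical. It also skips `-1`, which FAISS uses to pad results when the index holds fewer than `k` vectors:

```python
from typing import Any, Dict, List, Sequence

def hits_to_records(scores: Sequence[float],
                    ids: Sequence[int],
                    store: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Join FAISS (score, row id) pairs with docstore entries; row i <-> store[i]."""
    results = []
    for score, idx in zip(scores, ids):
        if idx < 0:  # FAISS pads missing neighbours with -1
            continue
        rec = store[idx]
        results.append({"score": float(score), "id": rec["id"], "metadata": rec["metadata"]})
    return results

toy_store = [
    {"id": "a", "text": "...", "metadata": {"page": 1}},
    {"id": "b", "text": "...", "metadata": {"page": 2}},
]
# Asking for k=3 from a 2-vector index: the third slot comes back as id -1
hits = hits_to_records([0.9, 0.4, float("-inf")], [1, 0, -1], toy_store)
print([r["id"] for r in hits])  # ['b', 'a']
```

Without the `idx < 0` check, `store[-1]` would silently return the last record, since Python lists accept negative indices.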

In [None]:
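def prepare_texts(records: List[Dict[str, Any]], max_len: Optional[int] = MAX_TEXT_LEN) -> List[str]:
    """Strips whitespace and truncates each chunk to `max_len` characters (0/None disables truncation)."""
    texts = []
    for r in records:
        t = r["text"].strip()
        if max_len and len(t) > max_len:
            t = t[:max_len]
        texts.append(t)
    return texts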

def l2_normalize(vecs: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return vecs / norms

texts = prepare_texts(chunks)
print(f"Embedding {len(texts)} chunks...")

emb = backend.encode(texts, batch_size=BATCH_SIZE).astype(np.float32)
assert emb.shape[1] == backend.dim, f"Dimension mismatch: got {emb.shape[1]}, expected {backend.dim}"
emb = l2_normalize(emb)

print(f"Embeddings shape: {emb.shape}")

# Build docstore aligned with embeddings
docstore = [
    {
        "id": r["id"],
        "text": r["text"],
        "metadata": r.get("metadata", {})
    }
    for r in chunks
]

## Build FAISS index (cosine similarity via inner product)
- `IndexFlatIP` with L2-normalized vectors.
- Persist index, vectors (optional for MMR), and docstore to disk.
- Store a small manifest (backend, dim, count, creation time) so a reload can be checked against what was built.
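`IndexFlatIP` is exhaustive: it scores the query against every stored vector and returns the `k` largest inner products. A NumPy sketch of the same computation, useful as a sanity check on small corpora rather than a replacement for FAISS (`exact_top_k` is a hypothetical name; results should match the flat index up to ties):

```python
import numpy as np

def exact_top_k(vectors: np.ndarray, query: np.ndarray, k: int):
    """Exhaustive inner-product search over all rows: returns (scores, row ids), best first."""
    scores = vectors @ query                      # one inner product per stored vector
    k = min(k, scores.shape[0])
    if k <= 0:
        return scores[:0], np.arange(0)
    top = np.argpartition(-scores, k - 1)[:k]     # the k best rows, in arbitrary order
    top = top[np.argsort(-scores[top])]           # order them by descending score
    return scores[top], top

vecs = np.array([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]], dtype=np.float32)  # unit-length rows
q = np.array([0.8, 0.6], dtype=np.float32)                               # unit-length query
scores, ids = exact_top_k(vecs, q, k=2)
print(ids)  # [1 0]
```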

In [None]:
def build_faiss_index(vectors: np.ndarray) -> faiss.Index:
    dim = vectors.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(vectors)
    return index

def save_index(index: faiss.Index,
               vectors: np.ndarray,
               docstore: List[Dict[str, Any]],
               backend_name: str,
               dim: int,
               out_dir: Path):
    out_dir.mkdir(parents=True, exist_ok=True)
    faiss.write_index(index, str(out_dir / "index.faiss"))
    # Optional: save vectors for MMR re-ranking and reproducibility
    np.save(out_dir / "vectors.npy", vectors)
    
    # Save docstore and manifest
    with (out_dir / "docstore.jsonl").open("w", encoding="utf-8") as f:
        for d in docstore:
            f.write(json.dumps(d, ensure_ascii=False) + "\n")
            
    manifest = {
        "backend": backend_name,
        "dim": dim,
        "count": int(vectors.shape[0]),
        "created_at": time.time()
    }
    with (out_dir / "manifest.json").open("w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)

def load_index(in_dir: Path) -> Tuple[faiss.Index, np.ndarray, List[Dict[str, Any]], Dict[str, Any]]:
    index = faiss.read_index(str(in_dir / "index.faiss"))
    vectors = np.load(in_dir / "vectors.npy")
    docstore = load_jsonl(in_dir / "docstore.jsonl")
    manifest = json.loads((in_dir / "manifest.json").read_text(encoding="utf-8"))
    return index, vectors, docstore, manifest

# Build and save
index = build_faiss_index(emb)
save_index(index, emb, docstore, EMBEDDING_BACKEND, backend.dim, INDEX_DIR)
print(f"Index saved to {INDEX_DIR}")

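# Reload to verify the round trip
index2, vectors2, store2, manifest2 = load_index(INDEX_DIR)
assert index2.ntotal == len(store2) == vectors2.shape[0] == manifest2["count"], "Index, vectors, and docstore are out of sync"
print(f"Reloaded index: dim={manifest2['dim']}, count={manifest2['count']}, backend={manifest2['backend']}")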

## Search utilities
- `search_faiss`: pure FAISS top-k search.
- `mmr_rerank`: optional **Maximal Marginal Relevance** reranking using stored vectors.
- `search`: convenience wrapper with optional MMR.
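MMR scores each remaining candidate $d$ as $\lambda\,\mathrm{sim}(q,d) - (1-\lambda)\max_{s \in S}\mathrm{sim}(d,s)$, where $S$ is the set already picked, and adds the best one greedily. A toy NumPy example (the vectors and the name `mmr_select` are made up for illustration) showing how it drops a near-duplicate that pure relevance ranking keeps:

```python
import numpy as np

def mmr_select(query, cand_vecs, top_k, lambda_mult=0.5):
    """Greedy MMR over L2-normalized candidate rows; returns positions into cand_vecs."""
    rel = cand_vecs @ query  # relevance of each candidate to the query
    selected, remaining = [], list(range(len(cand_vecs)))

    def mmr_score(i):
        # Redundancy = highest similarity to anything already selected (0 for the first pick)
        redundancy = max((float(cand_vecs[i] @ cand_vecs[j]) for j in selected), default=0.0)
        return lambda_mult * float(rel[i]) - (1 - lambda_mult) * redundancy

    while remaining and len(selected) < top_k:
        best = max(remaining, key=mmr_score)
        selected.append(best)
        remaining.remove(best)
    return selected

q = np.array([1.0, 0.0, 0.0])
cands = np.array([
    [0.8, 0.6, 0.0],  # A: relevant (sim 0.8)
    [0.8, 0.6, 0.0],  # B: exact duplicate of A
    [0.6, 0.0, 0.8],  # C: less relevant (sim 0.6) but different from A
])
print(mmr_select(q, cands, top_k=2, lambda_mult=1.0))  # [0, 1]  pure relevance keeps the duplicate
print(mmr_select(q, cands, top_k=2, lambda_mult=0.5))  # [0, 2]  MMR swaps B for C
```

In the second round with $\lambda = 0.5$, B scores $0.5 \cdot 0.8 - 0.5 \cdot 1.0 = -0.1$ while C scores $0.5 \cdot 0.6 - 0.5 \cdot 0.48 = 0.06$.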

In [None]:
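def embed_query(q: str, backend) -> np.ndarray:
    q = q.strip()
    # Backends that distinguish queries from documents (e.g. Gemini's "retrieval_query"
    # task type) expose encode_query; otherwise embed the query like a document.
    if hasattr(backend, "encode_query"):
        qv = backend.encode_query(q)
    else:
        qv = backend.encode([q], batch_size=1)
    return l2_normalize(qv.astype(np.float32))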

def search_faiss(index: faiss.Index, query_vec: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
    D, I = index.search(query_vec, top_k)
    return D[0], I[0]

def mmr_rerank(query_vec: np.ndarray,
               candidates_idx: np.ndarray,
               all_vecs: np.ndarray,
               top_k: int = 5,
               lambda_mult: float = 0.5) -> List[int]:
    """Greedy Maximal Marginal Relevance.

    Repeatedly picks the candidate maximizing
        lambda * sim(query, doc) - (1 - lambda) * max_{s in selected} sim(doc, s).
    `candidates_idx` are row indices into `all_vecs` (the full L2-normalized matrix),
    so every dot product below is a cosine similarity. Returns selected row indices.
    """
    candidate_list = [int(i) for i in candidates_idx if i >= 0]  # drop FAISS -1 padding
    # Similarity of each candidate (not the whole corpus) to the query
    sims_to_query = {i: float(all_vecs[i] @ query_vec[0]) for i in candidate_list}
    selected: List[int] = []

    while candidate_list and len(selected) < top_k:
        best_candidate, best_score = -1, -np.inf
        for idx in candidate_list:
            if not selected:
                # First pick: pure relevance
                score = sims_to_query[idx]
            else:
                max_sim_to_selected = max(float(all_vecs[idx] @ all_vecs[j]) for j in selected)
                score = lambda_mult * sims_to_query[idx] - (1 - lambda_mult) * max_sim_to_selected
            if score > best_score:
                best_candidate, best_score = idx, score
        selected.append(best_candidate)
        candidate_list.remove(best_candidate)

    return selected

def search(query: str,
           index: faiss.Index,
           backend,
           store: List[Dict[str, Any]],
           vectors: Optional[np.ndarray] = None,
           top_k: int = 5,
           use_mmr: bool = False,
           mmr_lambda: float = 0.5,
           fetch_k: int = 25) -> List[Dict[str, Any]]:
    
    qv = embed_query(query, backend)
    apply_mmr = use_mmr and vectors is not None
    # MMR needs a larger candidate pool (fetch_k) to choose a diverse top_k from
    D, I = search_faiss(index, qv, top_k=max(fetch_k, top_k) if apply_mmr else top_k)
    # FAISS pads with -1 when the index holds fewer than k vectors; drop those slots
    valid = I >= 0
    D, I = D[valid], I[valid]
    
    if apply_mmr:
        selected_indices = mmr_rerank(qv, I, vectors, top_k=top_k, lambda_mult=mmr_lambda)
        # Report plain cosine scores (query · doc) for the MMR-selected hits, in MMR order
        final_I = np.array(selected_indices, dtype=int)
        final_D = np.array([(vectors[i] @ qv[0]) for i in final_I], dtype=np.float32)
    else:
        final_I, final_D = I, D
        
    results = []
    for score, idx in zip(final_D[:top_k], final_I[:top_k]):
        rec = store[idx]
        results.append({
            "score": float(score),
            "id": rec["id"],
            "text": rec["text"],
            "metadata": rec["metadata"]
        })
    return results

## Try some queries

Use examples relevant to your PDFs. For the sample from Module 1/2: `"termination"`, `"confidentiality"`, `"indirect damages"`.

In [None]:
queries = [
    "termination notice",
    "confidentiality obligations",
    "indirect damages liability"
]

for q in queries:
    print(f"\n--- Query: '{q}' (with MMR) ---")
    res = search(q, index2, backend, store2, vectors=vectors2, top_k=3, use_mmr=True, mmr_lambda=0.5)
    for r in res:
        m = r["metadata"]
        cite = f"{m.get('doc_id','?')} p.{m.get('page','?')}"
        print(f"  score={r['score']:.3f} | {cite} | section: {m.get('section_title')}")
        print("    " + r["text"][:220].replace("\n", " ") + "...")

## Mini Task
1)  Point `INPUT_CHUNKS` to your Module 2 JSONL.
2)  Choose `EMBEDDING_BACKEND = "minilm"` (or `"gemini"` with `GOOGLE_API_KEY` set).
3)  Run embedding → indexing → search.
4)  Inspect top hits and verify they make sense with citations (`doc_id`/`page`/`section`).

## Optional: Answer synthesis with Gemini
- Retrieve top-k chunks, then ask Gemini to compose a concise answer citing pages.
- Requires `GOOGLE_API_KEY`.
- Keep prompts small; pass only the needed chunk texts to control cost.
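"Keep prompts small" can be made concrete with a budget: add retrieved chunks in rank order, each labelled with its citation in the same `Source [doc_id p.page]` form used below, and stop before the budget is exceeded. A sketch assuming a character budget as a rough proxy for tokens; `build_context`, `max_chars`, and the toy hits are hypothetical:

```python
from typing import Any, Dict, List

def build_context(retrieved: List[Dict[str, Any]], max_chars: int = 6000) -> str:
    """Concatenate ranked chunks as 'Source [doc_id p.page]' blocks within a character budget."""
    blocks, used = [], 0
    for r in retrieved:
        m = r.get("metadata", {})
        block = f"Source [{m.get('doc_id', '?')} p.{m.get('page', '?')}]:\n{r['text']}"
        extra = len(block) + (2 if blocks else 0)  # count the "\n\n" separator too
        if used + extra > max_chars:
            break  # stop at the first chunk that would overflow; lower-ranked chunks are dropped
        blocks.append(block)
        used += extra
    return "\n\n".join(blocks)

hits = [
    {"text": "Either party may terminate on 30 days' notice.", "metadata": {"doc_id": "msa", "page": 4}},
    {"text": "x" * 500, "metadata": {"doc_id": "msa", "page": 9}},
]
ctx = build_context(hits, max_chars=200)
print(ctx)  # only the p.4 chunk fits in 200 characters
```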

In [None]:
USE_GEMINI_FOR_ANSWERS = bool(GOOGLE_API_KEY) and EMBEDDING_BACKEND == 'gemini'

def answer_with_gemini(query: str, retrieved: List[Dict[str, Any]]) -> str:
    assert GOOGLE_API_KEY, "GOOGLE_API_KEY not set."
    import google.generativeai as genai
    genai.configure(api_key=GOOGLE_API_KEY)
    
    ctx_lines = []
    for r in retrieved:
        m = r["metadata"]
        cite = f"{m.get('doc_id','?')} p.{m.get('page','?')}"
        ctx_lines.append(f"Source [{cite}]:\n{r['text']}")
    context = "\n\n".join(ctx_lines)
    
    prompt = (
        "You are a helpful assistant. Use only the provided context to answer the question. "
        "Your answer must be concise and grounded in the provided text. "
        "Cite your sources in brackets like [doc_id p.page]. If the information is not in the context, say you don't know.\n\n"
        f"-- CONTEXT --\n{context}\n\n-- QUESTION --\n{query}\n\n-- ANSWER --"
    )
    
    model = genai.GenerativeModel("gemini-1.5-flash")
    resp = model.generate_content(prompt)
    return resp.text

if USE_GEMINI_FOR_ANSWERS:
    q = "What are the termination conditions in this agreement?"
    print(f"\n--- Answering with Gemini: '{q}' ---")
    top = search(q, index2, backend, store2, vectors=vectors2, top_k=5, use_mmr=True)
    ans = answer_with_gemini(q, top)
    print(ans)
else:
    print("\nGemini answering disabled (GOOGLE_API_KEY not set or backend is not 'gemini').")
    print("Retrieval results above can guide manual inspection.")

## Updating the index with new chunks
- You can append new documents without rebuilding from scratch (`IndexFlatIP` supports `add`).
- Remember to:
  - Embed new chunks with the **same backend**.
  - L2-normalize embeddings.
  - Append to FAISS, `vectors.npy`, and `docstore.jsonl`.
  - Update manifest count.
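The checklist above boils down to keeping vectors and docstore the same length and order, and refusing vectors from a different model. A NumPy-only sketch of those checks (FAISS is left out so it runs anywhere; `index.add` would receive the same `new_vecs`). `append_rows` and the toy data are hypothetical, and the duplicate-id check is an extra assumption not in the list:

```python
from typing import Any, Dict, List, Tuple
import numpy as np

def append_rows(vectors: np.ndarray,
                store: List[Dict[str, Any]],
                manifest: Dict[str, Any],
                new_vecs: np.ndarray,
                new_records: List[Dict[str, Any]],
                backend_name: str) -> Tuple[np.ndarray, List[Dict[str, Any]], Dict[str, Any]]:
    """Append L2-normalized rows and docstore entries; return an updated manifest."""
    if backend_name != manifest["backend"]:
        raise ValueError(f"index was built with {manifest['backend']!r}, not {backend_name!r}")
    if new_vecs.shape[1] != manifest["dim"]:
        raise ValueError(f"dim mismatch: {new_vecs.shape[1]} != {manifest['dim']}")
    if len(new_vecs) != len(new_records):
        raise ValueError("each new vector needs exactly one docstore record")
    known_ids = {d["id"] for d in store}
    dupes = [r["id"] for r in new_records if r["id"] in known_ids]
    if dupes:
        raise ValueError(f"ids already indexed: {dupes[:5]}")
    vectors = np.concatenate([vectors, new_vecs.astype(np.float32)], axis=0)
    store = store + [{"id": r["id"], "text": r["text"], "metadata": r.get("metadata", {})}
                     for r in new_records]
    manifest = {**manifest, "count": int(vectors.shape[0])}  # keep the manifest count in sync
    return vectors, store, manifest

toy_vecs = np.eye(3, dtype=np.float32)[:2]  # two existing 3-dim unit vectors
toy_store = [{"id": "c1", "text": "...", "metadata": {}}, {"id": "c2", "text": "...", "metadata": {}}]
toy_manifest = {"backend": "minilm", "dim": 3, "count": 2}
toy_vecs, toy_store, toy_manifest = append_rows(
    toy_vecs, toy_store, toy_manifest,
    np.array([[0.0, 0.0, 1.0]]), [{"id": "c3", "text": "..."}], "minilm")
print(toy_manifest["count"], len(toy_store))  # 3 3
```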

In [None]:
def add_chunks_to_index(new_chunks_path: Path,
                        index_dir: Path,
                        backend) -> None:
    # Load existing
    index, vectors, store, manifest = load_index(index_dir)
    
    # Load new chunks
    new_records = load_jsonl(new_chunks_path)
    new_texts = prepare_texts(new_records)
    if not new_texts:
        print("No new chunks found.")
        return
        
    print(f"Adding {len(new_records)} new chunks to index...")
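    new_emb = backend.encode(new_texts, batch_size=BATCH_SIZE).astype(np.float32)
    # Guard against mixing embedding models: dimensions must match the existing index
    assert new_emb.shape[1] == manifest["dim"], (
        f"Dimension mismatch: new={new_emb.shape[1]}, index={manifest['dim']} (built with '{manifest['backend']}')")
    new_emb = l2_normalize(new_emb)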
    
    # Add to index and buffers
    index.add(new_emb)
    vectors = np.concatenate([vectors, new_emb], axis=0)
    for r in new_records:
        store.append({"id": r["id"], "text": r["text"], "metadata": r.get("metadata", {})})
        
    # Save back
    save_index(index, vectors, store, manifest['backend'], manifest['dim'], index_dir)
    print(f"Index updated: total vectors={index.ntotal}")

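# Example (disabled by default)
# After an update, call load_index(INDEX_DIR) again: index2/vectors2/store2 above still hold the pre-update data.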
# new_file = Path("./data/outputs/new_doc_chunks.jsonl")
# if new_file.exists():
#     add_chunks_to_index(new_file, INDEX_DIR, backend)
# else:
#     print(f"Skipping index update: '{new_file}' not found.")

## Practical notes
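- `all-MiniLM-L6-v2` (384-dim) is a sensible default: small, fast on CPU, and robust for general-purpose semantic search. Its model card notes that input longer than 256 word pieces is truncated, so very long chunks are only partially embedded.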
- Normalize embeddings and use `IndexFlatIP` to emulate cosine similarity.
- Persist a manifest including backend name and dim; **avoid mixing different embedding models** in one index.
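- For 100k+ chunks, consider approximate IVF/HNSW indexes for faster search (e.g., `IndexIVFFlat`), but start with `Flat`: it is exact, so it doubles as the baseline for checking an approximate index's recall. For cosine search, build the approximate index with `faiss.METRIC_INNER_PRODUCT` and keep normalizing vectors; IVF indexes also need `train()` before `add()`.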
- Keep `doc_id`/`page`/`section` in docstore for citations and traceability.
- MMR re-ranking can improve diversity when top results are near-duplicates.
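Before replacing `Flat` with an IVF/HNSW index, it helps to measure how far the approximate results drift from the exact ones. A common check is recall@k: the fraction of each query's exact top-k ids that the approximate search also returns, averaged over a query set. A minimal sketch over id matrices like the `I` arrays returned by `index.search` (`recall_at_k` and the toy ids are hypothetical):

```python
import numpy as np

def recall_at_k(exact_ids: np.ndarray, approx_ids: np.ndarray, k: int) -> float:
    """Mean fraction of each query's exact top-k ids found in the approximate top-k."""
    hits = []
    for exact_row, approx_row in zip(exact_ids[:, :k], approx_ids[:, :k]):
        truth = {int(i) for i in exact_row if i >= 0}   # ignore -1 padding
        found = {int(i) for i in approx_row if i >= 0}
        hits.append(len(truth & found) / max(len(truth), 1))
    return float(np.mean(hits))

exact = np.array([[4, 2, 7], [1, 0, 3]])   # e.g. from IndexFlatIP
approx = np.array([[4, 7, 9], [1, 0, 3]])  # e.g. from an approximate index
print(round(recall_at_k(exact, approx, k=3), 3))  # 0.833  (2/3 and 3/3, averaged)
```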

### Next steps
- Integrate with a LangChain retriever and QA chain.
- Add caching for embeddings (hash chunk text → vector).
- Add quality evals: query sets, hit-rate, manual judgments with a small rubric.
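For the embedding cache idea listed above, one approach is to key each text by a SHA-256 hash of the model name plus the text (so switching models never returns stale vectors) and embed only cache misses. A sketch assuming any backend with the `encode(texts, batch_size)` shape used in this notebook; `CachedEncoder` and `FakeBackend` are hypothetical, and the cache here is in-memory only:

```python
import hashlib
from typing import Dict, List
import numpy as np

class CachedEncoder:
    """Wrap a backend exposing .encode(texts, batch_size) -> (n, dim) array; cache by text hash."""

    def __init__(self, backend, model_name: str):
        self.backend = backend
        self.model_name = model_name
        self.cache: Dict[str, np.ndarray] = {}

    def _key(self, text: str) -> str:
        # Include the model name so vectors from different models never collide
        return hashlib.sha256(f"{self.model_name}\n{text}".encode("utf-8")).hexdigest()

    def encode(self, texts: List[str], batch_size: int = 64) -> np.ndarray:
        keys = [self._key(t) for t in texts]
        missing: Dict[str, str] = {}  # key -> text, deduplicated, insertion-ordered
        for k, t in zip(keys, texts):
            if k not in self.cache:
                missing.setdefault(k, t)
        if missing:
            new_vecs = self.backend.encode(list(missing.values()), batch_size=batch_size)
            for k, v in zip(missing.keys(), new_vecs):
                self.cache[k] = np.asarray(v, dtype=np.float32)
        return np.stack([self.cache[k] for k in keys])

class FakeBackend:
    """Stand-in for MiniLM/Gemini: deterministic 4-dim vectors; counts how many texts it embeds."""
    def __init__(self):
        self.calls = 0

    def encode(self, texts, batch_size=64):
        self.calls += len(texts)
        return np.array([[len(t), 1.0, 0.0, 0.0] for t in texts], dtype=np.float32)

fake = FakeBackend()
enc = CachedEncoder(fake, "fake-model")
enc.encode(["a", "bb", "a"])
enc.encode(["bb", "ccc"])
print(fake.calls)  # 3  (only "a", "bb" and "ccc" were actually embedded)
```

Persisting `cache` to disk (for example alongside `vectors.npy`) would let it survive notebook restarts.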