In [76]:
# Install packages
# NOTE(review): package versions are unpinned, so a fresh run may resolve different
# releases than the ones this notebook was written against; pin them
# (e.g. `package==X.Y.Z`) for reproducible results.
%pip install --upgrade google-genai sentence-transformers faiss-cpu langchain-text-splitters python-dotenv --quiet


Note: you may need to restart the kernel to use updated packages.


In [77]:
# Imports

from google import genai
from google.genai import types

import os,math, uuid
from dataclasses import dataclass
from typing import List, Dict, Tuple

import numpy as np
import faiss

from sentence_transformers import SentenceTransformer # for local embedding
from langchain_text_splitters import RecursiveCharacterTextSplitter #simple chunker 

from dotenv import load_dotenv
load_dotenv()  # copy variables from a local .env file (if any) into os.environ

# The API key must come from the environment; never hardcode it in the notebook.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError("GEMINI_API_KEY is not set")

# Gemini client used by ask_gemini() for generation.
client = genai.Client(api_key=GEMINI_API_KEY)

# Model identifiers
GEMINI_GEN_MODEL = "gemini-2.5-flash"  # fast, good for RAG
LOCAL_EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # small, fast, good quality





In [78]:
# Example documents on which we'll perform retrieval.
# Each doc is a dict with keys: id, title, text, and optional "meta" (metadata dict).
DOCUMENTS = [
    dict(
        id="doc1",
        title="Baking Basics",
        text="Cakes are baked at moderate temperatures. Common ingredients are flour, sugar, eggs, and butter. Icing is added after the cake cools.",
        meta=dict(source="kitchen-notes", lang="en"),
    ),
    dict(
        id="doc2",
        title="Healthy Desserts",
        text="For a lighter dessert, substitute part of the sugar with fruit purees. Consider whole-grain flour. Yogurt frostings can reduce fat.",
        meta=dict(source="health-blog", lang="en"),
    ),
    dict(
        id="doc3",
        title="Birthday Traditions",
        text="Many cultures celebrate birthdays with a sweet cake, candles, and a wish. Popular flavors include chocolate and vanilla.",
        meta=dict(source="culture-wiki", lang="en"),
    ),
]

In [79]:
# Chunking
# Goal: split long docs into chunks so retrieval can target the right part.
# We'll use a character-based splitter with a small overlap.

def chunk_documents(docs: List[Dict], chunk_size: int, chunk_overlap: int) -> List[Dict]:
    """
    Split each document's text into overlapping character-based chunks.

    Args:
        docs: document dicts with keys "id" and "text", plus optional "title"
            and "meta".
        chunk_size: target maximum chunk length, passed to
            RecursiveCharacterTextSplitter.
        chunk_overlap: overlap between consecutive chunks of the same document,
            passed to RecursiveCharacterTextSplitter.

    Returns:
        A flat list of chunk dicts with keys: chunk_id, doc_id, text, meta, title.
        Each chunk receives its own copy of the document's meta dict, so mutating
        one chunk's metadata does not affect the source document or sibling chunks.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        # NOTE(review): with the splitter's default keep_separator, the "." may be
        # attached to the start of the following chunk — confirm the chunk text
        # looks as intended.
        separators=["\n\n", "\n", ".", " ", ""],
    )
    chunks = []
    for doc in docs:
        meta = doc.get("meta") or {}
        title = doc.get("title", "")
        for i, part in enumerate(splitter.split_text(doc["text"])):
            chunks.append({
                "chunk_id": f"{doc['id']}::chunk{i}",
                "doc_id": doc["id"],
                "text": part,
                "meta": dict(meta),  # per-chunk copy: no shared mutable metadata
                "title": title,
            })
    return chunks


CHUNKS = chunk_documents(DOCUMENTS, chunk_size=100, chunk_overlap=40)
assert CHUNKS, "chunking produced no chunks"
print(len(CHUNKS))


6


In [80]:
# Embeddings and storing it in FAISS (vector store)

local_embedder = SentenceTransformer(LOCAL_EMBED_MODEL)

def embed_texts(texts: List[str]) -> np.ndarray:
    """
    Turn a list of texts into L2-normalized dense vectors.

    Returns a 2D float32 numpy array of shape (N, D), where D is the embedder's
    output dimension. An empty input returns an array of shape (0, D), so callers
    can always rely on ``shape[1]``.
    """
    if len(texts) == 0:
        # Return a well-formed empty (0, D) array; encode() on an empty list
        # does not return a 2D array.
        dim = local_embedder.get_sentence_embedding_dimension()
        return np.zeros((0, dim), dtype="float32")
    # normalize_embeddings=True makes inner product equal to cosine similarity.
    vecs = local_embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
    return vecs.astype("float32")  # FAISS works with float32 vectors

# Build a FAISS Index
@dataclass
class VectorIndex:
    """
    Container bundling a FAISS index with the chunks it was built from.

    Holds:
      - index: FAISS inner-product index (for fast vector similarity); as built by
        build_faiss_index, FAISS row i holds the embedding of chunk i
      - id_to_chunk: mapping from FAISS row id -> our chunk dict
      - dim: vector dimension (number of columns of the indexed vectors)
    """
    index: faiss.IndexFlatIP  # flat (exact) inner-product index
    id_to_chunk: Dict[int, Dict]  # FAISS row id -> chunk dict
    dim: int  # embedding dimension

def build_faiss_index(chunks: List[Dict]) -> VectorIndex:
    """
    Embed chunk texts and load them into a flat inner-product FAISS index.

    Steps:
      1) Embed each chunk's "text" (embed_texts normalizes vectors, so inner
         product equals cosine similarity)
      2) Build a FAISS IndexFlatIP sized to the embedding dimension
      3) Add vectors in chunk order, so FAISS row i corresponds to chunks[i]

    Args:
        chunks: chunk dicts with at least a "text" key (e.g. from chunk_documents).

    Returns:
        VectorIndex holding the populated index, the row-id -> chunk mapping,
        and the vector dimension.

    Raises:
        ValueError: if ``chunks`` is empty, or if the embeddings do not come back
            as an (N, D) array with one row per chunk.
    """
    if not chunks:
        raise ValueError("build_faiss_index() needs at least one chunk to index")

    vecs = embed_texts([c["text"] for c in chunks])
    if vecs.ndim != 2 or vecs.shape[0] != len(chunks):
        raise ValueError(
            f"expected embeddings of shape ({len(chunks)}, D), got {vecs.shape}"
        )
    dim = vecs.shape[1]

    index = faiss.IndexFlatIP(dim)  # inner product == cosine for normalized vectors
    index.add(vecs)

    # A fresh flat index numbers rows 0..N-1 in insertion order, matching enumerate().
    id_to_chunk = dict(enumerate(chunks))
    return VectorIndex(index=index, id_to_chunk=id_to_chunk, dim=dim)


In [81]:
# Build the vector index
vindex = build_faiss_index(CHUNKS)
# Every chunk must occupy exactly one FAISS row; search() maps row ids back to chunks.
assert vindex.index.ntotal == len(CHUNKS) == len(vindex.id_to_chunk)
print(f"Built index with {vindex.index.ntotal} chunks")


Built index with 6 chunks


In [82]:
# Retrieval (top-k search) + simple formatting for the LLM
def search(vindex: VectorIndex, query: str, k: int = 4) -> List[Dict]:
    """
    Return the top-k indexed chunks most similar to ``query``, best first.

    1) Embed the query the same way as documents (embed_texts)
    2) Search the top-k vectors in FAISS
    3) Return copies of the matching chunk dicts, each with added keys
       "score" (float similarity) and "rank" (0-based position in the result)

    ``k`` is clamped to the number of indexed vectors. ``k <= 0`` or an empty
    index returns an empty list without querying FAISS.
    """
    k = min(k, vindex.index.ntotal)
    if k <= 0:
        return []

    qvec = embed_texts([query])
    scores, ids = vindex.index.search(qvec, k)

    results = []
    for sid, score in zip(ids[0], scores[0]):
        if sid == -1:  # FAISS pads with -1 when fewer than k results exist
            continue
        ch = dict(vindex.id_to_chunk[int(sid)])
        if isinstance(ch.get("meta"), dict):
            # Copy nested metadata so callers can't mutate the index's chunk.
            ch["meta"] = dict(ch["meta"])
        ch["score"] = float(score)
        ch["rank"] = len(results)
        results.append(ch)
    return results

def make_rag_prompt(question: str, contexts: List[Dict]) -> str:
    """
    Build a plain-text RAG prompt for Gemini.

    The prompt contains:
      - instructions to answer only from the context and to cite sources,
      - one line per context chunk, tagged "[n:label]", where n is the 1-based
        position and label is the chunk's title, falling back to its doc_id,
      - the question, followed by "Answer:".

    The citation instruction refers to those same "[n:label]" tags, so the
    format the model is told to cite matches what it actually sees.

    Args:
        question: the user question.
        contexts: chunk dicts with a "text" key and optional "title"/"doc_id".

    Returns:
        The full prompt string. If ``contexts`` is empty, the context section
        says "(no context retrieved)".
    """
    header = (
        "You are a helpful assistant. Answer the question using ONLY the context.\n"
        "If the answer is not in the context, say you don't know.\n"
        "Cite sources using the bracketed tag shown before each passage, e.g. [1:Title].\n\n"
    )
    ctx_lines = []
    for i, c in enumerate(contexts, start=1):
        tag = c.get("title") or c.get("doc_id") or "unknown"
        ctx_lines.append(f"[{i}:{tag}] {c['text']}")
    ctx_block = "\n".join(ctx_lines) if ctx_lines else "(no context retrieved)"

    q_block = f"\n\nQuestion: {question}\nAnswer:"
    return header + ctx_block + q_block



In [84]:
# Generate the answer with Gemini

def ask_gemini(prompt: str) -> str:
    """
    Send a single user prompt to Gemini and return the generated text.

    Safety: the harassment and hate-speech categories use BLOCK_LOW_AND_ABOVE,
    the strictest blocking threshold, which suits public-facing apps.

    Args:
        prompt: the full prompt text (e.g. from make_rag_prompt).

    Returns:
        The response text. If the SDK's ``.text`` accessor is empty, this falls
        back to joining the text parts of the first candidate. It returns
        "(no text returned)" when no text is available, for example when a
        candidate has no content.
    """
    safety = [
        types.SafetySetting(
            category=category,
            threshold=types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
        )
        for category in (
            types.HarmCategory.HARM_CATEGORY_HARASSMENT,
            types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        )
    ]
    config = types.GenerateContentConfig(
        safety_settings=safety,
        temperature=0.3,  # lower temp for grounded answers
    )
    resp = client.models.generate_content(
        model=GEMINI_GEN_MODEL,
        contents=[types.Content(role="user", parts=[types.Part(text=prompt)])],
        config=config,
    )

    # Preferred path: the SDK's convenience accessor.
    text = getattr(resp, "text", None)
    if text:
        return text

    # Fallback: guard every level, since a candidate's content may be None and
    # individual parts may carry no text.
    candidates = getattr(resp, "candidates", None) or []
    if candidates:
        content = getattr(candidates[0], "content", None)
        parts = getattr(content, "parts", None) or []
        texts = [part.text for part in parts if getattr(part, "text", None)]
        if texts:
            return "".join(texts)
    return "(no text returned)"

def basic_rag_answer(question: str, k: int = 4) -> Tuple[str, List[Dict]]:
    """
    Answer ``question`` with retrieval-augmented generation over the global ``vindex``.

    Pipeline: retrieve the top-k chunks -> format them into a prompt -> ask Gemini.

    Returns:
        (answer_text, contexts_used): the model's answer and the retrieved chunk
        dicts (as returned by search) that were placed in the prompt.
    """
    contexts = search(vindex, question, k=k)
    answer = ask_gemini(make_rag_prompt(question, contexts))
    return answer, contexts

# Try it: run one question end-to-end, then list the retrieved chunks
# (rank, similarity score rounded to 3 places, source label).
q = "How do people usually celebrate birthdays with cake?"
ans, used = basic_rag_answer(q, k=3)

print("ANSWER:\n", ans, "\n")
print("CONTEXT USED:")
for hit in used:
    source_label = hit["title"] or hit["doc_id"]
    print(hit["rank"], round(hit["score"], 3), source_label)


ANSWER:
 Many cultures celebrate birthdays with a sweet cake, candles, and a wish [1:Birthday Traditions]. 

CONTEXT USED:
0 0.715 Birthday Traditions
1 0.444 Baking Basics
2 0.399 Baking Basics
