<a href="https://colab.research.google.com/github/rain027/RAG_learning/blob/main/RAG_BASICS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install chromadb
!pip install rank-bm25

In [None]:
# nltk supplies the word and sentence tokenizers used for overlapping chunking
!pip install nltk
# NOTE(review): tiktoken is not imported in any cell visible here — confirm it is still needed
!pip install tiktoken

In [None]:
!pip install rouge-score bert-score

In [None]:
!pip install PyPDF2


In [None]:
# pulling a a rag article from wikipedia and feeding the cleaned version to the pipeline for better retrieval

!pip install wikipedia-api
import wikipediaapi

# Add user_agent to follow Wikipedia's policy
wiki_wiki = wikipediaapi.Wikipedia(
    user_agent="RAG-Experiment/1.0 (https://colab.research.google.com/)",
    language="en"
)

page = wiki_wiki.page("Retrieval-augmented generation")

if page.exists():
    text = page.text
    print("Extracted characters:", len(text))
    print(text[:1000])  # preview first 1000 characters
else:
    print("Page not found")

# Whitespace normalisation helper (optional clean-up step)
import re
def clean_text(text):
    """Return ``text`` with every whitespace run collapsed to one space and both ends trimmed."""
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()

# Normalise whitespace in the fetched article text before it is chunked.
text = clean_text(text)




In [None]:

from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.utils import embedding_functions



In [None]:
from sentence_transformers import CrossEncoder

In [None]:
from rank_bm25 import BM25Okapi

In [None]:
import nltk
# Download the Punkt tokenizer resources; presumably these back the NLTK
# tokenizers imported in this notebook — confirm against the installed NLTK version.
nltk.download('punkt')
nltk.download('punkt_tab')
from nltk.tokenize import word_tokenize

In [None]:
from rouge_score import rouge_scorer
from bert_score import score as bert_score


In [None]:
from nltk.tokenize import sent_tokenize

In [None]:
# Step 1: Load the sentence-embedding model used to encode chunks and queries
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
model = SentenceTransformer(EMBEDDING_MODEL_NAME)



In [None]:
# Step 2: Setup ChromaDB
chroma_client = chromadb.Client()
# get_or_create_collection keeps this cell re-runnable: create_collection
# raises if a collection with the same name already exists.
collection = chroma_client.get_or_create_collection(name="docss")



In [None]:
# Step 3.1 : token based chunking function
def chunk_text_by_tokens(text, max_tokens=300, overlap=50):
    """Split ``text`` into overlapping chunks of word tokens.

    The text is tokenized with ``word_tokenize``; each chunk is the
    space-joined slice of up to ``max_tokens`` tokens, and consecutive
    chunks start ``max_tokens - overlap`` tokens apart.

    Args:
        text: the document to split.
        max_tokens: maximum number of tokens per chunk (must be positive).
        overlap: number of tokens shared by consecutive chunks
            (must satisfy ``0 <= overlap < max_tokens``).

    Returns:
        A list of chunk strings; empty when ``text`` yields no tokens.

    Raises:
        ValueError: if ``max_tokens`` or ``overlap`` is out of range. An
            overlap that is not smaller than ``max_tokens`` would make the
            window stop advancing and the loop would never terminate.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer")
    if not 0 <= overlap < max_tokens:
        raise ValueError("overlap must satisfy 0 <= overlap < max_tokens")

    tokens = word_tokenize(text)
    step = max_tokens - overlap  # distance between consecutive chunk starts
    return [
        " ".join(tokens[start:start + max_tokens])
        for start in range(0, len(tokens), step)
    ]

In [None]:
def chunk_text_for_evaluation(text, max_tokens=150, overlap=25):
    """Create smaller overlapping token chunks specifically for evaluation.

    Uses the same windowing as ``chunk_text_by_tokens`` and differs only in
    its smaller defaults, so it delegates to that function instead of
    repeating the loop.

    Args:
        text: the document to split.
        max_tokens: number of tokens per chunk (must be positive).
        overlap: tokens overlapping between consecutive chunks
            (must satisfy ``0 <= overlap < max_tokens``).

    Returns:
        A list of chunk strings.

    Raises:
        ValueError: if the window could not advance (``overlap`` not smaller
            than ``max_tokens``) or a size is out of range; without this
            check such arguments would loop forever.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer")
    if not 0 <= overlap < max_tokens:
        raise ValueError("overlap must satisfy 0 <= overlap < max_tokens")
    return chunk_text_by_tokens(text, max_tokens=max_tokens, overlap=overlap)


In [None]:
# Step 3.3: split the loaded document (held in `text`) into chunks.
# `chunks` uses the larger retrieval window; `eval_chunks` uses the smaller
# evaluation window.
eval_chunks = chunk_text_for_evaluation(text, max_tokens=150, overlap=25)
chunks = chunk_text_by_tokens(text, max_tokens=300, overlap=50)

print(f"Total chunks created: {len(chunks)}")
print(chunks[0])


In [None]:
# Step 4: Embed and Store chunks in DB
# Encode every chunk in one batched call and insert them with a single add(),
# instead of one model call and one insert per chunk. IDs stay the chunk
# positions as strings. The guard keeps an empty chunk list a no-op, as the
# per-chunk loop was.
if chunks:
    chunk_embeddings = model.encode(chunks).tolist()
    collection.add(
        documents=list(chunks),
        embeddings=chunk_embeddings,
        ids=[str(i) for i in range(len(chunks))],
    )





In [None]:
# Build the BM25 index over lowercased, whitespace-split chunks
tokenized_chunks = list(map(str.split, map(str.lower, chunks)))
bm25 = BM25Okapi(tokenized_chunks)


In [None]:
# Load the cross-encoder that scores (query, chunk) pairs for reranking
RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
cross_encoder = CrossEncoder(RERANKER_MODEL_NAME)

In [None]:
# Retrieval function

def _bm25_top_chunks(query, top_k):
    """Return the ``top_k`` chunks with the highest BM25 score for ``query``."""
    # Lowercase + whitespace split, the same tokenization used for the index.
    scores = bm25.get_scores(query.lower().split())
    ranked = sorted(zip(chunks, scores), key=lambda pair: pair[1], reverse=True)
    return [chunk for chunk, _ in ranked[:top_k]]


def _embedding_top_chunks(query, top_k):
    """Return the ``top_k`` documents the vector collection returns for ``query``."""
    query_emb = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_emb], n_results=top_k)
    return results["documents"][0]


def retrieve_chunks(query, method="embedding", top_k=5, use_hybrid=False, use_rerank=True, final_k=3):
    """Retrieve, optionally rerank, print and return the best chunks for ``query``.

    Args:
        query: the question to search for.
        method: 'bm25' or 'embedding'; only consulted when ``use_hybrid`` is False.
        top_k: number of chunks to retrieve (per retriever) before optional reranking.
        use_hybrid: combine BM25 + embedding retrieval (order-preserving union,
            BM25 results first, duplicates removed).
        use_rerank: apply cross-encoder re-ranking to the retrieved candidates.
        final_k: how many chunks to print and return (default 3, the
            previously hard-coded value).

    Returns:
        A list with at most ``final_k`` chunk strings, or ``None`` when
        ``method`` is not recognised in single-method mode.
    """
    if use_hybrid:
        # dict.fromkeys drops duplicates while keeping first-seen order.
        top_chunks = list(dict.fromkeys(
            _bm25_top_chunks(query, top_k) + _embedding_top_chunks(query, top_k)
        ))
    elif method == "embedding":
        top_chunks = _embedding_top_chunks(query, top_k)
    elif method == "bm25":
        top_chunks = _bm25_top_chunks(query, top_k)
    else:
        print("Invalid method! Choose 'bm25' or 'embedding'.")
        return None

    # --- Optional cross-encoder re-ranking ---
    # Skipped when nothing was retrieved, so predict() is never called on an empty batch.
    if use_rerank and top_chunks:
        rerank_scores = cross_encoder.predict([(query, c) for c in top_chunks])
        top_chunks = [c for _, c in sorted(zip(rerank_scores, top_chunks), reverse=True)]

    best_chunks = top_chunks[:final_k]

    # --- Display the returned chunks ---
    print("\n=== Top Chunks ===")
    for i, c in enumerate(best_chunks):
        print(f"{i+1}. {c}\n")

    return best_chunks


In [None]:
def evaluate_retrieval(retrieved_chunks, reference_text, max_tokens=50):
    """Print ROUGE and BERTScore for retrieved chunks against a reference text.

    Each chunk is first trimmed to its first ``max_tokens`` whitespace-split
    words. Trimmed chunks that share at least one word with the reference are
    joined into a single candidate text, which is then scored.

    NOTE(review): the overlap test is a case-insensitive *substring* check, so
    short reference tokens (e.g. "a" or ".") match almost any chunk and the
    filter rarely removes anything — confirm whether whole-word matching was
    intended.

    Args:
        retrieved_chunks: list of retrieved text chunks; ``None`` or an empty
            list is reported and skipped instead of raising.
        reference_text: ground-truth text.
        max_tokens: max tokens to consider per chunk for evaluation.

    Returns:
        None. Scores are printed.
    """
    # retrieve_chunks returns None for an invalid method; there is nothing to score then.
    if not retrieved_chunks:
        print("No retrieved chunks to evaluate.")
        return

    # Step 1: Trim each retrieved chunk
    trimmed_chunks = [" ".join(c.split()[:max_tokens]) for c in retrieved_chunks]

    # Step 2: Split reference into sentences and lowercase their words once
    reference_sents = sent_tokenize(reference_text)
    reference_words = [word.lower() for sent in reference_sents for word in sent.split()]

    # Step 3: Keep the chunks that overlap with the reference
    eval_text = []
    for chunk in trimmed_chunks:
        chunk_lower = chunk.lower()  # hoisted: lowercased once per chunk, not once per word
        if any(word in chunk_lower for word in reference_words):
            eval_text.append(chunk)
    retrieved_text = " ".join(eval_text)

    # Step 4: Compute ROUGE
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge_scores = scorer.score(reference_text, retrieved_text)
    print("\n--- ROUGE Scores ---")
    for key, value in rouge_scores.items():
        print(f"{key}: Precision={value.precision:.3f}, Recall={value.recall:.3f}, F1={value.fmeasure:.3f}")

    # Step 5: Compute BERTScore
    P, R, F1 = bert_score([retrieved_text], [reference_text], lang='en', model_type='roberta-large', rescale_with_baseline=True)
    print(f"\nBERTScore F1: {F1[0].item():.3f}")

In [None]:
# Step 5: Query
query = "What is retrieval augmented generation?"

# Ground-truth passage that the retrieved chunks are scored against.
# It is written with spaces around punctuation, the same shape the chunking
# functions produce when they join word tokens with single spaces.
reference_text = """Retrieval-augmented generation ( RAG ) is a technique that enables large language models ( LLMs ) to retrieve and incorporate new information . With RAG , LLMs do not respond to user queries until they refer to a specified set of documents . These documents supplement information from the LLM 's pre-existing training data . This allows LLMs to use domain-specific and/or updated information that is not available in the training data . For example , this helps LLM-based chatbots access internal company data or generate responses based on authoritative sources . RAG improves large language models ( LLMs ) by incorporating information retrieval before generating responses . Unlike traditional LLMs that rely on static training data , RAG pulls relevant text from databases , uploaded documents , or web sources . According to Ars Technica , `` RAG is a way of improving LLM performance , in essence by blending the LLM process with a web search or other document look-up process to help LLMs stick to the facts . '' This method helps reduce AI hallucinations , which have caused chatbots to describe policies that do not exist , or recommend nonexistent legal cases to lawyers that are looking for citations to support their arguments . RAG also reduces the need to retrain LLMs with new data , saving on computational and financial costs . Beyond efficiency gains , RAG also allows LLMs to include sources in their responses , so users can verify the cited sources . This provides greater transparency , as users can cross-check retrieved content to ensure accuracy and relevance . The term RAG was first introduced in a 2020 research paper from Meta ."""

# Hybrid retrieval? Only an answer of "y" (any case, surrounding spaces ignored) enables it.
hybrid_answer = input("Do you want to use hybrid retrieval (BM25 + embeddings)? (y/n): ")
use_hybrid = hybrid_answer.strip().lower() == "y"

# Cross-encoder reranking? Same "y"-only rule.
rerank_answer = input("Do you want to apply cross-encoder reranking? (y/n): ")
use_rerank = rerank_answer.strip().lower() == "y"

# A single retrieval method is only asked for when hybrid is off;
# anything other than the two known names keeps the "embedding" default.
method = "embedding"
if not use_hybrid:
    method_input = input("Choose retrieval method (bm25 / embedding): ").strip().lower()
    if method_input in ("bm25", "embedding"):
        method = method_input


# Run retrieval with the options the user just chose (previously the call
# hard-coded embedding + hybrid + rerank and ignored the answers above).
retrieved_chunks = retrieve_chunks(query, method=method, top_k=5, use_hybrid=use_hybrid, use_rerank=use_rerank)
evaluate_retrieval(retrieved_chunks, reference_text, max_tokens=50)