In [None]:
!pip install pymilvus pymilvus[milvus_lite] datasets transformers sentence-transformers ragas evaluate

In [None]:
import json
import os

import pandas as pd
import numpy as np
import transformers, torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM

from sentence_transformers import SentenceTransformer, CrossEncoder

from datasets import Dataset

from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType

from evaluate import load
from ragas import evaluate
from ragas.metrics import (
    faithfulness,
    answer_relevancy,
    context_recall,
    context_precision,
)
from ragas.run_config import RunConfig
from ragas.llms import LangchainLLMWrapper
from langchain_openai import ChatOpenAI

from tqdm.auto import tqdm

# Read Dataset

In [None]:
# Passage corpus of the rag-mini-wikipedia dataset, addressed through an hf:// path.
passages_path = (
    "hf://datasets/rag-datasets/rag-mini-wikipedia/data/passages.parquet/part.0.parquet"
)
passages = pd.read_parquet(passages_path)

# Print the (rows, columns) shape, then preview the first rows as the cell output.
print(passages.shape)
passages.head()

In [None]:
# Test split of the rag-mini-wikipedia dataset, addressed through an hf:// path.
queries_path = (
    "hf://datasets/rag-datasets/rag-mini-wikipedia/data/test.parquet/part.0.parquet"
)
queries = pd.read_parquet(queries_path)

# Print the (rows, columns) shape, then preview the first rows as the cell output.
print(queries.shape)
queries.head()

# EDA

In [None]:
# Analyze passage lengths
passages["length"] = passages["passage"].str.len()
print(f"Min length: {passages['length'].min()}")
print(f"Max length: {passages['length'].max()}")
print(f"Mean length: {passages['length'].mean():.2f}")
print(f"Median length: {passages['length'].median()}")

# Check for missing values
print(f"\nMissing values: {passages['passage'].isna().sum()}")

# Setup Dependencies

## Prompts
We use the basic prompt because it performed the best.


In [None]:
def generate_basic_prompt(query, context):
    """Build the generator prompt: the context first, then the question.

    Args:
        query: Question text placed after the "Question:" label.
        context: Context text placed after the "Context:" label.

    Returns:
        The formatted prompt string.
    """
    template = "Context: {}: \n Question: {} "
    return template.format(context, query)

## Embedding Models
We use all-MiniLM-L6-v2 because it performed better than mpnet-base-v2.

In [None]:
# Sentence-embedding model shared by passage indexing and query encoding.
embedding_model = SentenceTransformer(
    model_name_or_path="all-MiniLM-L6-v2",
)

## Top k

*   Retrieving 10: We retrieve 10 candidates to maximize recall, since relevant passages often appear beyond the top-5 (e.g., recall is higher at k=10 vs k=3/5)
*   Using top-5: We then feed only the top-5 (after reranking) to the model to boost precision and faithfulness, avoiding the noise and dilution that shows up when all 10 are passed in.

In [None]:
# retrieval_top_k: number of candidates requested from the vector search.
# context_top_k: number of reranked candidates kept for the prompt context.
retrieval_top_k, context_top_k = 10, 5

## Model

In [None]:
# Seq2seq generator and its tokenizer, both loaded from the same checkpoint name.
model_name = "google/flan-t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Helper Functions

In [None]:
def create_embeddings_and_rag_data(model, passages_df=None):
    """Embed every passage and build the records inserted into the vector store.

    Args:
        model: Object exposing ``encode(texts, convert_to_tensor=...,
            show_progress_bar=..., batch_size=...)`` whose return value
            supports ``.tolist()``.
        passages_df: DataFrame with a "passage" column. When omitted, the
            notebook-level ``passages`` frame is used, which keeps existing
            one-argument calls working.

    Returns:
        A tuple ``(embeddings, rag_data)``: the raw output of ``model.encode``
        and a list of ``{"id", "passage", "embedding"}`` dicts, one per
        passage. ``id`` is the passage's position in the frame, not its
        DataFrame index label.
    """
    if passages_df is None:
        passages_df = passages

    # Materialise the texts once and reuse them for encoding and for the records,
    # instead of a row-wise ``iloc`` lookup per passage.
    texts = passages_df["passage"].tolist()

    embeddings = model.encode(
        texts,
        convert_to_tensor=True,
        show_progress_bar=False,
        batch_size=64,
    )

    # One bulk conversion to nested Python lists rather than one per row.
    vectors = embeddings.tolist()

    rag_data = [
        {"id": position, "passage": text, "embedding": vector}
        for position, (text, vector) in enumerate(zip(texts, vectors))
    ]

    return embeddings, rag_data

In [None]:
def create_schema(embed_dim):
    """Build the collection schema used to store passages and their vectors.

    Args:
        embed_dim: Dimension of the "embedding" float-vector field.

    Returns:
        A ``CollectionSchema`` with three fields: an INT64 primary key "id"
        (not auto-generated), a VARCHAR "passage" with ``max_length=2600``,
        and a FLOAT_VECTOR "embedding" of size ``embed_dim``.
    """
    fields = [
        FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=False,
        ),
        FieldSchema(
            name="passage",
            dtype=DataType.VARCHAR,
            max_length=2600,
        ),
        FieldSchema(
            name="embedding",
            dtype=DataType.FLOAT_VECTOR,
            dim=embed_dim,
        ),
    ]

    return CollectionSchema(
        fields=fields,
        description="RAG Wikipedia passages",
        auto_id=False,
    )

In [None]:
def setup_collection(embed_dim, collection_name, rag_data):
    """(Re)create a collection, fill it with records, index it and load it.

    Uses the notebook-level ``client``. If a collection with the same name
    already exists it is dropped first, so every call starts from scratch.

    Args:
        embed_dim: Dimension passed to ``create_schema`` for the vector field.
        collection_name: Name of the collection to create.
        rag_data: Records handed to ``client.insert``.
    """
    existing_collections = client.list_collections()
    if collection_name in existing_collections:
        client.drop_collection(collection_name)

    client.create_collection(
        collection_name=collection_name,
        schema=create_schema(embed_dim),
    )
    client.insert(collection_name=collection_name, data=rag_data)

    # FLAT index with cosine metric on the "embedding" field.
    vector_index = MilvusClient.prepare_index_params()
    vector_index.add_index(
        field_name="embedding",
        index_type="FLAT",
        metric_type="COSINE",
    )
    client.create_index(
        collection_name=collection_name,
        index_params=vector_index,
    )

    client.load_collection(collection_name=collection_name)

    print(f"{collection_name} created and loaded into memory")

In [None]:
def generate_answer(prompt, model, tokenizer):
    """Generate an answer for one prompt and return it as text.

    Args:
        prompt: Prompt passed to the tokenizer.
        model: Object whose ``generate(**inputs)`` produces output sequences.
        tokenizer: Callable that encodes the prompt (``return_tensors="pt"``)
            and offers ``batch_decode``.

    Returns:
        The first decoded sequence, decoded with ``skip_special_tokens=True``.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    generated = model.generate(**encoded)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]

In [None]:
def get_output_file_name(exp_name):
    """Return the JSON file name holding the raw outputs of an experiment."""
    return f"out_{exp_name}.json"


def get_results_file_name(exp_name):
    """Return the JSON file name holding the config and metrics of an experiment."""
    return f"exp_{exp_name}.json"

In [None]:
def save_outputs(results, exp_name):
    """Write the experiment's outputs to its output JSON file.

    Args:
        results: JSON-serialisable object to store.
        exp_name: Experiment name used to derive the file name.
    """
    target = get_output_file_name(exp_name)

    with open(target, "w") as handle:
        json.dump(results, handle)

    print(f"Saved {target}")

In [None]:
def save_results(config, metrics, exp_name):
    """Write an experiment's config and metrics to its results JSON file.

    Args:
        config: JSON-serialisable description of the experiment settings.
        metrics: JSON-serialisable metric values.
        exp_name: Experiment name used to derive the file name.
    """
    target = get_results_file_name(exp_name)
    payload = {"config": config, "metrics": metrics}

    with open(target, "w") as handle:
        json.dump(payload, handle)

    print(f"Saved {target}")

In [None]:
# Seed NumPy's global RNG so the sampled subsets are repeatable on a fresh run.
np.random.seed(42)


def select_random_subset(results, size=25):
    """Return a random sample of ``results`` drawn without replacement.

    Args:
        results: Sequence to sample from.
        size: Number of items requested. If it exceeds ``len(results)`` the
            sample is capped at ``len(results)`` instead of raising, so small
            result sets can still be evaluated.

    Returns:
        A list of up to ``size`` items of ``results`` in random order; an
        empty list when ``results`` is empty.
    """
    sample_size = min(size, len(results))
    if sample_size == 0:
        # Nothing to draw; skip the RNG call entirely.
        return []

    random_indices = np.random.choice(len(results), size=sample_size, replace=False)
    return [results[i] for i in random_indices]

In [None]:
# SQuAD metric loaded once and reused by every evaluation call.
squad_metric = load("squad")


def perform_basic_evaluation(results):
    """Score predicted answers against ground truths with the SQuAD metric.

    Args:
        results: Sequence of dicts with "predicted_answer" and "ground_truth".

    Returns:
        Dict with "f1_score" and "exact_match" taken from the metric output.
    """
    predictions = []
    references = []

    for position, record in enumerate(results):
        # The same positional id ties a prediction to its reference.
        example_id = str(position)
        predictions.append(
            {
                "id": example_id,
                "prediction_text": record["predicted_answer"],
            }
        )
        references.append(
            {
                "id": example_id,
                "answers": {
                    "text": [record["ground_truth"]],
                    "answer_start": [0],
                },
            }
        )

    scores = squad_metric.compute(predictions=predictions, references=references)

    return {
        "f1_score": scores["f1"],
        "exact_match": scores["exact_match"],
    }

In [29]:
def perform_ragas_evaluation(results):
    """Score a random sample of 100 results with four RAGAS metrics.

    ``answer_relevancy`` is evaluated in one ``evaluate`` call and
    ``faithfulness``, ``context_recall`` and ``context_precision`` in a second
    call; both share the same run configuration.

    Args:
        results: Sequence of dicts with "question", "predicted_answer",
            "contexts" and "ground_truth".

    Returns:
        Dict mapping each metric name to the NaN-ignoring mean of its scores.
    """
    sample = select_random_subset(results, size=100)

    # Dataset column name -> key of the result records feeding that column.
    column_sources = {
        "question": "question",
        "answer": "predicted_answer",
        "contexts": "contexts",
        "ground_truth": "ground_truth",
    }
    dataset = Dataset.from_dict(
        {
            column: [record[key] for record in sample]
            for column, key in column_sources.items()
        }
    )

    run_config = RunConfig(max_workers=8, timeout=60)

    relevancy_result = evaluate(
        dataset, metrics=[answer_relevancy], run_config=run_config
    )
    grounding_result = evaluate(
        dataset,
        metrics=[faithfulness, context_recall, context_precision],
        run_config=run_config,
    )

    # Merge per-metric means; the three-metric result is added first.
    # NOTE(review): `_scores_dict` is a private attribute of the evaluation
    # result - confirm it is available in the installed ragas version.
    agg_scores = {}
    for eval_result in (grounding_result, relevancy_result):
        for metric, values in eval_result._scores_dict.items():
            agg_scores[metric] = float(np.nanmean(values))

    return agg_scores

In [None]:
def evaluate_results(results):
    """Run the basic and the RAGAS evaluation and merge their metrics.

    Args:
        results: Sequence of result dicts accepted by both evaluations.

    Returns:
        One dict holding the basic metrics followed by the RAGAS metrics; on a
        key clash the RAGAS value wins.
    """
    combined = {}
    combined.update(perform_basic_evaluation(results))
    combined.update(perform_ragas_evaluation(results))
    return combined

# Helper Functions for Advanced RAG

## Query Rewriting

In [None]:
# Chat model used for query rewriting; temperature 0 is set on the client.
llm = ChatOpenAI(
    temperature=0,
    model="gpt-4o-mini",
)


def rewrite_query(query):
    """Ask the chat model to rewrite a question as a keyword-style query.

    Args:
        query: Original question, sent as the user message.

    Returns:
        The model's reply content with surrounding whitespace stripped.
    """
    system_prompt = """
        You are a query rewriting assistant.Rewrite the given question into a concise keyword-style query optimized for Wikipedia retrieval in a RAG setup.
        Rules:
        - Keep it as concise as possible.
        - Remove filler words like 'what', 'is', 'was'.
        - Do not explain.
        - Output ONLY the rewritten query, nothing else.
    """

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]
    reply = llm.invoke(messages)

    return reply.content.strip()

## Reranking

In [None]:
# Cross-encoder that scores (query, passage) pairs for reranking.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")


def rerank(query, retrieved_docs, top_k=5):
    """Reorder retrieved documents by cross-encoder score and keep the best.

    Each dict in ``retrieved_docs`` is updated in place with a "rerank_score"
    float.

    Args:
        query: Query text paired with every document.
        retrieved_docs: List of dicts that each carry a "text" entry.
        top_k: Maximum number of documents to return.

    Returns:
        The ``top_k`` documents with the highest "rerank_score", best first.
    """
    pair_scores = cross_encoder.predict(
        [(query, doc["text"]) for doc in retrieved_docs]
    )

    for position, value in enumerate(pair_scores):
        retrieved_docs[position]["rerank_score"] = float(value)

    return sorted(
        retrieved_docs,
        key=lambda doc: doc["rerank_score"],
        reverse=True,
    )[:top_k]

# Run Experiment

In [None]:
# Local database file backing the Milvus client, and the collection built for this run.
client = MilvusClient("rag_wikipedia_mini.db")
collection_name = "all_MiniLM_L6_v2"

embeddings, rag_data = create_embeddings_and_rag_data(embedding_model)

# Take the vector size from the embeddings actually produced instead of
# hard-coding it, so the schema stays correct if the embedding model changes.
embedding_dim = int(embeddings.shape[1])
setup_collection(embedding_dim, collection_name, rag_data)

# The experiment name encodes the collection name, prompt style and both top-k settings.
experiment_name = f"advanced_{collection_name}_basic_{retrieval_top_k}_{context_top_k}"

In [None]:
results = []

# Generation is skipped when this experiment's output file already exists;
# `results` then stays empty.
output_path = get_output_file_name(experiment_name)

if not os.path.exists(output_path):
    for index, row in tqdm(queries.iterrows(), total=len(queries)):
        # Advanced RAG Optimization - Query Rewrite
        # NOTE(review): from here on `query` holds the rewritten text, so the
        # reranker and the generator prompt also receive the rewritten query
        # rather than the original question - confirm this is intended.
        query = rewrite_query(row["question"])

        query_vector = embedding_model.encode(query).tolist()
        search_results = client.search(
            collection_name=collection_name,
            data=[query_vector],
            output_fields=["passage"],
            limit=retrieval_top_k,
        )

        # Advanced RAG Optimization - Reranking
        docs = [
            {"text": hit.entity.get("passage"), "score": hit.score}
            for hit in search_results[0]
        ]
        top_results = rerank(query, docs, context_top_k)

        # The kept passages are joined into one newline-separated context string.
        context = [doc["text"] for doc in top_results]
        context_str = "\n".join(context)

        prompt = generate_basic_prompt(query, context_str)
        answer = generate_answer(prompt, model, tokenizer)

        record = {
            "question": row["question"],
            "rewritten_query": query,
            "predicted_answer": answer,
            "ground_truth": row["answer"],
            "contexts": context,
        }
        results.append(record)

    save_outputs(results, experiment_name)
else:
    print(f"Output for {experiment_name} exists! Skipping...")

In [None]:
# Evaluation is skipped when this experiment's results file already exists.
results_path = get_results_file_name(experiment_name)

if not os.path.exists(results_path):
    if not results:
        # Nothing in memory: load the outputs saved for this experiment.
        with open(get_output_file_name(experiment_name), "r") as f:
            results = json.load(f)

    metrics = evaluate_results(results)

    experiment_config = {
        "embedding_model": "all_MiniLM_L6_v2",
        "prompt": "basic",
        "top_k": context_top_k,
    }
    save_results(experiment_config, metrics, experiment_name)
else:
    print(f"Results for {experiment_name} exists! Skipping...")