In [None]:
!pip install pymilvus pymilvus[milvus_lite] datasets transformers sentence-transformers ragas evaluate

In [None]:
import json
import os

import pandas as pd
import numpy as np
import transformers, torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM

from sentence_transformers import SentenceTransformer

from datasets import Dataset

from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType

from evaluate import load
from ragas import evaluate
from ragas.metrics import (
    faithfulness,
    answer_relevancy,
    context_recall,
    context_precision,
)
from ragas.run_config import RunConfig
from ragas.llms import LangchainLLMWrapper
from langchain_openai import ChatOpenAI

from tqdm.auto import tqdm

# Read Dataset

In [None]:
# Retrieval corpus: the passages that get embedded and indexed in Milvus.
passages_url = (
    "hf://datasets/rag-datasets/rag-mini-wikipedia/data/passages.parquet/part.0.parquet"
)
passages = pd.read_parquet(passages_url)

print(passages.shape)
passages.head()

In [None]:
# Evaluation questions with their reference answers.
queries_url = (
    "hf://datasets/rag-datasets/rag-mini-wikipedia/data/test.parquet/part.0.parquet"
)
queries = pd.read_parquet(queries_url)

print(queries.shape)
queries.head()

# Exploratory Data Analysis (EDA)

In [None]:
# Analyze passage lengths (in characters). The lengths live in their own
# Series so this exploratory cell does not add a column to `passages`.
passage_lengths = passages["passage"].str.len()
print(f"Min length: {passage_lengths.min()}")
print(f"Max length: {passage_lengths.max()}")
print(f"Mean length: {passage_lengths.mean():.2f}")
print(f"Median length: {passage_lengths.median()}")

# The Milvus schema gives `passage` a fixed VARCHAR max_length. Non-ASCII text
# is longer in UTF-8 bytes than in characters. NOTE(review): Milvus is
# understood to measure max_length in bytes; confirm for the installed version.
passage_bytes = passages["passage"].str.encode("utf-8").str.len()
print(f"Max UTF-8 byte length: {passage_bytes.max()}")

# Check for missing values
print(f"\nMissing values: {passages['passage'].isna().sum()}")

# Set Up Dependencies

## Prompts
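1. Basic prompt: only the retrieved context and the question.
2. Persona prompt: the model answers as a calm, factual, Wikipedia-style guide.
3. Chain-of-thought (CoT) prompt: the model is told to reason step by step before giving a concise answer.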


In [None]:
def generate_basic_prompt(query, context):
    """Build the minimal prompt: retrieved context followed by the question.

    Args:
        query: The question to answer.
        context: Retrieved passages, either one pre-joined string or an
            iterable of passage strings (joined one per line).

    Returns:
        The prompt string for the generator model.
    """
    # Accept a list of passages too, so a Python list repr ("['...', '...']")
    # can never end up inside the prompt.
    if not isinstance(context, str):
        context = "\n".join(context)
    return f"Context: {context}\nQuestion: {query}"

In [None]:
def generate_persona_prompt(query, context):
    """Build a prompt that casts the model as a factual, Wikipedia-style guide.

    Args:
        query: The question to answer.
        context: Retrieved passages, either one pre-joined string or an
            iterable of passage strings (joined one per line).

    Returns:
        The prompt string: persona instructions, then context and question.
    """
    # Accept a list of passages too, so a Python list repr never reaches the prompt.
    if not isinstance(context, str):
        context = "\n".join(context)

    # Built from explicit lines so the prompt carries no source-code indentation.
    instructions = (
        "You are a knowledgeable and trustworthy Wikipedia-style guide.\n"
        "Your role is to explain answers clearly, objectively, and concisely, "
        "using only the retrieved passages.\n"
        "\n"
        "Role alignment:\n"
        "- Speak with the calm, factual tone of a reference editor.\n"
        "- Present information as if you are curating reliable knowledge.\n"
        "- Avoid speculation or personal opinions.\n"
        "\n"
        "Guidelines:\n"
        "- Keep answers short and direct, but add a brief explanation if it improves clarity.\n"
        "- Use clear, well-structured sentences that feel authoritative and easy to read.\n"
        '- If the passages do not contain the answer, say: "The passage does not provide enough information."\n'
    )
    return f"{instructions}\nContext: {context}\nQuestion: {query}"

In [None]:
def generate_cot_prompt(query, context):
    """Build a chain-of-thought style prompt that asks for step-by-step reasoning.

    Args:
        query: The question to answer.
        context: Retrieved passages, either one pre-joined string or an
            iterable of passage strings (joined one per line).

    Returns:
        The prompt string: process instructions, then context and question.
    """
    # Accept a list of passages too, so a Python list repr never reaches the prompt.
    if not isinstance(context, str):
        context = "\n".join(context)

    # Built from explicit lines so the prompt carries no source-code indentation.
    instructions = (
        "You are an assistant that answers questions using only the retrieved passages.\n"
        "\n"
        "Process:\n"
        "1. Read the question carefully.\n"
        "2. Identify the most relevant information in the passages.\n"
        "3. Reason step by step to connect facts and resolve conflicts.\n"
        "4. Give a clear and concise final answer.\n"
        "\n"
        "Answer characteristics:\n"
        "- Short and precise, avoiding unnecessary words.\n"
        "- Faithful to the passages, with no outside knowledge.\n"
        "- Direct phrasing that can be matched exactly when possible.\n"
        '- If the passages do not provide enough information, reply: "The passage does not provide enough information."\n'
    )
    return f"{instructions}\nContext: {context}\nQuestion: {query}"

## Embedding Models

In [None]:
# Two sentence-embedding backbones; the numeric suffix is the vector size
# used for the matching Milvus FLOAT_VECTOR field in the experiment grid.
embedding_model_384, embedding_model_768 = (
    SentenceTransformer(checkpoint)
    for checkpoint in ("all-MiniLM-L6-v2", "all-mpnet-base-v2")
)

## Model

In [None]:
# Answer generator: FLAN-T5 base (seq2seq) and its matching tokenizer.
model_name = "google/flan-t5-base"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define the Experiment Configurations

In [None]:
# Experiment grid: every (embedding model, prompt, top_k) combination is run.
# Embedding entries are (collection/experiment name, vector dim, model).
configs = dict(
    prompts=[
        ("basic", generate_basic_prompt),
        ("persona", generate_persona_prompt),
        ("cot", generate_cot_prompt),
    ],
    embedding_models=[
        ("all_MiniLM_L6_v2", 384, embedding_model_384),
        ("all_mpnet_base_v2", 768, embedding_model_768),
    ],
    top_k=[3, 5, 10],
)

# Generated answers and metrics are written here as JSON files.
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)

# Helper Functions

In [None]:
def create_embeddings_and_rag_data(model, passages_df=None, batch_size=64):
    """Embed all passages and build the row dicts to insert into Milvus.

    Args:
        model: Embedding model exposing ``encode`` (e.g. a SentenceTransformer).
        passages_df: DataFrame with a ``passage`` column. Defaults to the
            notebook-level ``passages`` frame, so existing calls keep working.
        batch_size: Batch size forwarded to ``model.encode``.

    Returns:
        ``(embeddings, rag_data)``. ``embeddings`` is the tensor returned by
        ``model.encode(..., convert_to_tensor=True)``. ``rag_data`` is a list
        of ``{"id", "passage", "embedding"}`` dicts, where ``id`` is the
        passage's position in ``passages_df``.
    """
    if passages_df is None:
        passages_df = passages

    texts = passages_df["passage"].tolist()
    embeddings = model.encode(
        texts,
        convert_to_tensor=True,
        show_progress_bar=False,
        batch_size=batch_size,
    )

    # One bulk conversion instead of a per-row `.tolist()` (and per-row
    # device transfer when the tensor lives on a GPU).
    vectors = embeddings.tolist()
    rag_data = [
        {"id": idx, "passage": text, "embedding": vector}
        for idx, (text, vector) in enumerate(zip(texts, vectors))
    ]

    return embeddings, rag_data

In [None]:
def create_schema(embed_dim, max_length=2600):
    """Build the Milvus collection schema for the RAG passage store.

    Args:
        embed_dim: Dimensionality of the FLOAT_VECTOR ``embedding`` field; must
            match the embedding model producing the vectors.
        max_length: ``max_length`` of the VARCHAR ``passage`` field (default
            2600, the previous hard-coded value). Raise it if the corpus has
            longer passages; Milvus is expected to reject longer values on insert.

    Returns:
        A ``CollectionSchema`` with ``id`` (INT64 primary key supplied by the
        caller, ``auto_id=False``), ``passage`` (VARCHAR) and ``embedding``
        (FLOAT_VECTOR) fields.
    """
    id_field = FieldSchema(
        name="id",
        dtype=DataType.INT64,
        is_primary=True,
        auto_id=False,
    )
    passage_field = FieldSchema(
        name="passage",
        dtype=DataType.VARCHAR,
        max_length=max_length,
    )
    embedding_field = FieldSchema(
        name="embedding",
        dtype=DataType.FLOAT_VECTOR,
        dim=embed_dim,
    )

    return CollectionSchema(
        fields=[id_field, passage_field, embedding_field],
        description="RAG Wikipedia passages",
        auto_id=False,
    )

In [None]:
def setup_collection(embed_dim, collection_name, rag_data, milvus_client=None):
    """(Re)create a Milvus collection, insert ``rag_data``, index it and load it.

    Any existing collection with the same name is dropped first, so each call
    rebuilds the collection from scratch.

    Args:
        embed_dim: Vector dimensionality passed to ``create_schema``.
        collection_name: Name of the collection to (re)create.
        rag_data: Rows to insert, each a dict with ``id``, ``passage`` and
            ``embedding`` keys matching the schema.
        milvus_client: Client to operate on. Defaults to the notebook-level
            ``client`` so existing calls keep working, without the function
            silently depending on that global.
    """
    if milvus_client is None:
        milvus_client = client

    if collection_name in milvus_client.list_collections():
        milvus_client.drop_collection(collection_name)

    milvus_client.create_collection(
        collection_name=collection_name,
        schema=create_schema(embed_dim),
    )
    milvus_client.insert(collection_name=collection_name, data=rag_data)

    # Exact (brute-force) search with cosine similarity on the vector field.
    index_params = MilvusClient.prepare_index_params()
    index_params.add_index(
        field_name="embedding",
        index_type="FLAT",
        metric_type="COSINE",
    )
    milvus_client.create_index(
        collection_name=collection_name,
        index_params=index_params,
    )

    # The collection must be loaded before it can be searched.
    milvus_client.load_collection(collection_name=collection_name)

    print(f"{collection_name} created and loaded into memory")

In [None]:
def generate_answer(prompt, model, tokenizer, max_new_tokens=None):
    """Generate one answer for ``prompt`` with a Hugging Face generation model.

    Args:
        prompt: Full prompt string (instructions, context, question). It is
            not truncated here, so prompts longer than the model's input
            limit are passed through unchanged.
        model: Model exposing ``generate``; inputs are moved to its ``device``
            when it has one.
        tokenizer: Tokenizer matching ``model``.
        max_new_tokens: Optional cap on generated tokens. ``None`` (default)
            leaves the length to the model's generation config, as before.

    Returns:
        The decoded text of the first generated sequence.
    """
    inputs = tokenizer(prompt, return_tensors="pt")

    # Keep inputs on the model's device; otherwise `generate` fails as soon as
    # the model is moved to a GPU.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = inputs.to(device)

    generate_kwargs = {}
    if max_new_tokens is not None:
        generate_kwargs["max_new_tokens"] = max_new_tokens

    outputs = model.generate(**inputs, **generate_kwargs)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

In [None]:
def get_output_file_name(exp_name):
    """Return the JSON path holding the raw generations of experiment ``exp_name``."""
    return f"{results_dir}/out_{exp_name}.json"


def get_results_file_name(exp_name):
    """Return the JSON path holding the evaluation metrics of experiment ``exp_name``."""
    return f"{results_dir}/exp_{exp_name}.json"

In [None]:
def save_outputs(results, exp_name):
    """Write one experiment's per-question generations to its output JSON file.

    Args:
        results: JSON-serialisable list of per-question records.
        exp_name: Experiment identifier; the path comes from
            ``get_output_file_name`` so it cannot drift from the loader.
    """
    filename = get_output_file_name(exp_name)
    with open(filename, "w") as f:
        json.dump(results, f)

    print(f"Saved {filename}")

In [None]:
def save_results(config, metrics, exp_name):
    """Persist an experiment's configuration alongside its metrics.

    Writes the JSON object ``{"config": config, "metrics": metrics}`` to
    ``get_results_file_name(exp_name)`` and reports the path.
    """
    filename = get_results_file_name(exp_name)
    payload = dict(config=config, metrics=metrics)

    with open(filename, "w") as handle:
        json.dump(payload, handle)

    print(f"Saved {filename}")

In [None]:
# Seed NumPy's global RNG for any global-RNG use elsewhere; the sampler below
# uses its own seeded generator instead.
np.random.seed(42)


def select_random_subset(results, size=25, seed=42):
    """Return up to ``size`` distinct items of ``results`` chosen at random.

    Each call builds a fresh generator from ``seed``. For equal-length inputs
    the same positions are therefore picked on every call, regardless of how
    many samples were drawn before. Because every experiment's results follow
    the query order, all configurations are scored on the same questions.

    Args:
        results: Sequence to sample from.
        size: Number of items wanted; clamped to ``len(results)``.
        seed: Seed for the per-call generator.

    Returns:
        A list of the sampled items (random order, no repeats).
    """
    size = min(size, len(results))
    if size <= 0:
        return []
    rng = np.random.default_rng(seed)
    indices = rng.choice(len(results), size=size, replace=False)
    return [results[i] for i in indices]

In [None]:
squad_metric = load("squad")


def perform_basic_evaluation(results):
    """Score predicted answers against ground truths with the SQuAD metric.

    Each record becomes a SQuAD example with a single reference answer, keyed
    by its position in ``results``. ``answer_start`` is a fixed placeholder.

    Returns:
        ``{"f1_score": ..., "exact_match": ...}`` as reported by the metric.
    """
    predictions, references = [], []
    for position, record in enumerate(results):
        example_id = str(position)
        predictions.append(
            {"id": example_id, "prediction_text": record["predicted_answer"]}
        )
        references.append(
            {
                "id": example_id,
                "answers": {"text": [record["ground_truth"]], "answer_start": [0]},
            }
        )

    scores = squad_metric.compute(predictions=predictions, references=references)
    return {"f1_score": scores["f1"], "exact_match": scores["exact_match"]}

In [None]:
def perform_ragas_evaluation(results, sample_size=100):
    """Score a random subset of ``results`` with RAGAS metrics.

    Args:
        results: Records with ``question``, ``predicted_answer``, ``contexts``
            and ``ground_truth`` keys.
        sample_size: Number of records to evaluate (default 100). Clamped to
            ``len(results)`` so short result lists do not raise.

    Returns:
        Mapping of metric name to its mean score over the sample, ignoring
        NaN/None entries (NaN when no valid score exists).
    """
    sample = select_random_subset(results, size=min(sample_size, len(results)))

    dataset = Dataset.from_dict(
        {
            "question": [r["question"] for r in sample],
            "answer": [r["predicted_answer"] for r in sample],
            "contexts": [r["contexts"] for r in sample],
            "ground_truth": [r["ground_truth"] for r in sample],
        }
    )

    run_config = RunConfig(max_workers=8, timeout=60)
    # answer_relevancy runs in its own `evaluate` call, the remaining metrics in
    # a second one; both use the same (parallel, 8-worker) run config.
    relevancy_result = evaluate(
        dataset, metrics=[answer_relevancy], run_config=run_config
    )
    grounding_result = evaluate(
        dataset,
        metrics=[faithfulness, context_recall, context_precision],
        run_config=run_config,
    )

    # NOTE(review): `_scores_dict` is a private ragas attribute; confirm it
    # exists in the installed ragas version (public alternative: `to_pandas()`).
    agg_scores = {}
    for eval_result in (grounding_result, relevancy_result):
        for metric, values in eval_result._scores_dict.items():
            # dtype=float maps None to NaN; an all-NaN column yields NaN
            # explicitly instead of triggering nanmean's RuntimeWarning.
            scores = np.asarray(values, dtype=float)
            if np.isnan(scores).all():
                agg_scores[metric] = float("nan")
            else:
                agg_scores[metric] = float(np.nanmean(scores))

    return agg_scores

In [None]:
def evaluate_results(results):
    """Run the SQuAD and RAGAS evaluations and merge their metric dicts.

    SQuAD scores are computed first; RAGAS metrics are merged in afterwards
    (a RAGAS key would override a SQuAD key of the same name).
    """
    metrics = dict(perform_basic_evaluation(results))
    metrics.update(perform_ragas_evaluation(results))
    return metrics

# Run Experiments

In [None]:
# The three axes of the experiment grid iterated by the run loop.
prompts, embedding_models, top_ks = (
    configs["prompts"],
    configs["embedding_models"],
    configs["top_k"],
)

In [None]:
client = MilvusClient("rag_wikipedia_mini.db")

for embed_model_name, embed_dim, embedding_model in tqdm(embedding_models):
    # The Milvus collection and the query embeddings are only needed when some
    # experiment for this embedding model still has to generate answers, so
    # both are built lazily. A re-run with all outputs cached skips that work.
    collection_ready = False
    query_vectors = None

    for prompt_name, prompt_generator in prompts:
        for top_k in top_ks:
            print("=" * 18)
            print(
                f"Embedding: {embed_model_name}, Prompt: {prompt_name}, Top K: {top_k}"
            )

            experiment_name = f"{embed_model_name}_{prompt_name}_{top_k}"
            results = []

            if os.path.exists(get_output_file_name(experiment_name)):
                print(f"Output for {experiment_name} exists! Skipping...")
            else:
                if not collection_ready:
                    _, rag_data = create_embeddings_and_rag_data(embedding_model)
                    setup_collection(embed_dim, embed_model_name, rag_data)
                    collection_ready = True

                if query_vectors is None:
                    # Query embeddings depend only on the embedding model, so
                    # encode each question once per model rather than once per
                    # (prompt, top_k) combination.
                    query_vectors = [
                        embedding_model.encode(question).tolist()
                        for question in queries["question"]
                    ]

                for (_, row), query_vector in tqdm(
                    zip(queries.iterrows(), query_vectors), total=len(queries)
                ):
                    query = row["question"]

                    search_results = client.search(
                        collection_name=embed_model_name,
                        data=[query_vector],
                        output_fields=["passage"],
                        limit=top_k,
                    )

                    context = [result["passage"] for result in search_results[0]]
                    # The prompt gets the newline-joined passages; passing the
                    # list would embed its Python repr (brackets, quotes).
                    context_str = "\n".join(context)

                    prompt = prompt_generator(query, context_str)
                    answer = generate_answer(prompt, model, tokenizer)
                    results.append(
                        {
                            "question": query,
                            "predicted_answer": answer,
                            "ground_truth": row["answer"],
                            "contexts": context,
                        }
                    )

                save_outputs(results, experiment_name)

            if os.path.exists(get_results_file_name(experiment_name)):
                print(f"Results for {experiment_name} exists! Skipping...")
            else:
                # Generations were skipped above; reload them from disk.
                if not results:
                    with open(get_output_file_name(experiment_name), "r") as f:
                        results = json.load(f)
                metrics = evaluate_results(results)
                save_results(
                    {
                        "embedding_model": embed_model_name,
                        "prompt": prompt_name,
                        "top_k": top_k,
                    },
                    metrics,
                    experiment_name,
                )

In [None]:
# Bundle every output/metrics file into results.zip. Pure Python works where
# the `zip` CLI is missing, and it follows the configured `results_dir`.
import shutil

archive_path = shutil.make_archive("results", "zip", root_dir=".", base_dir=results_dir)
print(f"Saved {archive_path}")