# Retrieval-Augmented Generation (RAG) Inference Pipeline

This notebook implements a **full Retrieval-Augmented Generation (RAG) inference system** on a **fine-tuned Mistral 7B model** using a custom-built scientific corpus.

The goal:  
- To **retrieve the most relevant scientific text passages** for a user query  
- To **construct a clean, context-rich prompt**  
- To **generate grounded, intelligent answers** using the fine-tuned LLM

---

## What This Notebook Does:

1. **Load the Fine-Tuned Model**:
   - Load the fine-tuned Mistral-7B-Instruct model, with its LoRA adapters merged into the base weights.

2. **Load the FAISS Index + Chunk Metadata**:
   - Load the pre-computed FAISS index containing dense semantic embeddings of scientific paper chunks.
   - Load the associated `chunk_metadata.json` for human-readable titles and texts.

3. **Define a Semantic Retriever**:
   - Given a user query, embed it using a SentenceTransformer (BGE model).
   - Perform a fast similarity search over the FAISS index to retrieve the top-k relevant text chunks.

4. **Construct the RAG Prompt**:
   - Assemble the retrieved excerpts cleanly into a structured prompt.
   - Prepend a role instruction ("You are an expert scientific assistant...") and place the user question at the end.

5. **Run Model Inference**:
   - Tokenize and feed the constructed prompt into the LLM.
   - Generate an answer using greedy decoding (reproducible, deterministic outputs).

6. **Inspect Retrieved Context (Optional)**:
   - Print the top-k retrieved chunks to validate retrieval relevance.

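7. **Test with Sample Queries**:
   - Run the pipeline on two scientific questions and inspect the retrieved chunks to check retrieval grounding and answer quality.

8. **Clean Notebook Metadata**:
   - Remove widget metadata and clear code-cell outputs so the notebook renders cleanly.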

---

## Key Techniques Used:

- **Retrieval-Augmented Generation (RAG) Architecture**
- **Semantic Search with FAISS + Sentence-Transformers**
- **Prompt Engineering for Scientific QA**
- **Inference with a LoRA-Fine-Tuned (Merged) Model**
- **Efficient FAISS Indexing and Retrieval**
- **Deterministic Decoding and Token Budgeting (greedy decoding, context and output limits)**


## Step 1: Mounting Google Drive

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Navigate to the repo folder
%cd /content/drive/MyDrive/llm-finetuning-project/llm-finetuning-summarizer

# List repo contents
!ls

## Step 2: Installing Dependencies and Importing Libraries

In [None]:
# accelerate is required for device_map="auto" when loading the model
!pip install -q sentence-transformers faiss-cpu transformers accelerate

In [None]:
import os
import json
import numpy as np

import faiss
from sentence_transformers import SentenceTransformer

from transformers import AutoModelForCausalLM, AutoTokenizer

import torch
from huggingface_hub import login

In [None]:
login()

## Step 3: Verifying GPU and Environment

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

## Step 4: Loading the FAISS Index and Chunk Metadata

### Load FAISS Index and Chunk Metadata

In this section, we prepare the two essential components for semantic retrieval:

1. **FAISS Index** (`faiss_index.bin`)
   - A dense vector index containing semantic embeddings of all text chunks.
   - Used to perform fast top-k nearest neighbor search given a user query embedding.
   - Loaded with `faiss.read_index`; once in memory, it can be queried immediately.

2. **Chunk Metadata** (`chunk_metadata.json`)
   - A JSON list aligned with the index: entry `i` holds the original text chunk for FAISS vector `i`.
   - Contains important fields such as:
     - `arxiv_id` (paper source)
     - `chunk_id` (local identifier)
     - `title` (paper title)
     - `text` (actual content of the chunk)

Both the FAISS index and the metadata are loaded into memory here, so that the vector positions returned by a search can be mapped back to readable text and source information for the user's question.

**Outcome**:  
- FAISS index with ~8724 vectors ready for search.  
- Metadata list with ~8724 entries (one per indexed vector) for text reconstruction.
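To make the semantics of `index.search` concrete, here is a minimal NumPy sketch of exact top-k inner-product search. It assumes a flat (exhaustive) inner-product index such as `faiss.IndexFlatIP`; the notebook does not say which index type `faiss_index.bin` uses, so this illustrates the behavior rather than reproducing it. The helper name `brute_force_search` is hypothetical. Like FAISS, it returns a pair of `(n_queries, k)` arrays and fills slots it cannot populate with index `-1`.

```python
import numpy as np


def brute_force_search(database: np.ndarray, queries: np.ndarray, k: int):
    """Exact top-k inner-product search (hypothetical helper, not the FAISS API).

    database: (n_vectors, dim) float32 array
    queries:  (n_queries, dim) float32 array
    Returns (scores, indices), each of shape (n_queries, k).
    Slots beyond n_vectors get index -1 and score -inf.
    """
    scores_all = queries @ database.T  # (n_queries, n_vectors) inner products
    n_queries, n_vectors = scores_all.shape
    scores = np.full((n_queries, k), -np.inf, dtype=np.float32)
    indices = np.full((n_queries, k), -1, dtype=np.int64)
    m = min(k, n_vectors)
    # Sort by descending inner product and keep the best m positions
    order = np.argsort(-scores_all, axis=1)[:, :m]
    indices[:, :m] = order
    scores[:, :m] = np.take_along_axis(scores_all, order, axis=1)
    return scores, indices


# Three orthogonal unit vectors; the query matches vector 1 exactly
db = np.eye(3, dtype=np.float32)
q = np.array([[0.0, 1.0, 0.0]], dtype=np.float32)
scores, idx = brute_force_search(db, q, k=5)
print(idx)  # vector 1 ranks first; the last two slots are -1 (only 3 vectors exist)
```

Because an approximate index (IVF, HNSW, and so on) may return slightly different neighbors, an exact search like this can serve as a reference when spot-checking retrieval on a small sample.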

In [None]:
# Load FAISS index
faiss_index_path = "./data/rag_corpus/faiss_index.bin"
index = faiss.read_index(faiss_index_path)
print(f"Loaded FAISS index with {index.ntotal} vectors.")

In [None]:
# Load chunk metadata
metadata_path = "./data/rag_corpus/chunk_metadata.json"
with open(metadata_path, "r") as f:
    chunk_metadata = json.load(f)
print(f"Loaded metadata for {len(chunk_metadata)} chunks.")

## Step 5: Loading the Fine-Tuned Model

In [None]:
# Load merged model and tokenizer
model_path = "./models/merged-finetuned-mistral"

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [None]:
# The Mistral tokenizer does not define a pad token; reuse EOS for padding
tokenizer.pad_token = tokenizer.eos_token

In [None]:
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")

In [None]:
model.eval()
print("Model and tokenizer successfully loaded.")

## Step 6: Define the Retriever Function

### Defining the Semantic Retriever Function

In this section, we define a function that enables **semantic retrieval** of the most relevant scientific text chunks given a user query.

---

#### What This Function Does:

1. **Embeds the User Query**:
   - The query is transformed into a dense 768-dimensional vector using the same `bge-base-en-v1.5` model that was used to embed the chunks.
   - The embedding is L2-normalized; assuming the chunk embeddings were normalized the same way, inner-product search then ranks chunks by cosine similarity.

2. **Performs FAISS Semantic Search**:
   - The query embedding is searched against the FAISS index to find the top-k most similar chunk embeddings.
   - FAISS search returns the indices of the top-k matches based on inner-product similarity.

3. **Maps FAISS Results Back to Full Text**:
   - The retrieved indices are used to gather the corresponding metadata (e.g., title, chunk text) for each matching chunk.
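The normalization in step 1 is what lets an inner-product search rank by cosine similarity: the inner product of two unit-length vectors equals the cosine of the angle between them. The NumPy check below verifies that identity on random 768-dimensional vectors standing in for BGE embeddings. `l2_normalize` mimics what `encode(..., normalize_embeddings=True)` does, but it is an illustrative helper, not the library's code.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale a vector (or each row of a matrix) to unit L2 norm."""
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.clip(norms, 1e-12, None)  # guard against division by zero


def cosine(u: np.ndarray, v: np.ndarray) -> float:
    """Cosine similarity of two raw (unnormalized) vectors."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


rng = np.random.default_rng(0)
a, b = rng.normal(size=768), rng.normal(size=768)  # stand-ins for 768-d embeddings

# Inner product of the normalized vectors equals cosine similarity of the raw ones
ip_normalized = float(l2_normalize(a) @ l2_normalize(b))
print(ip_normalized, cosine(a, b))
```

The identity only helps if the chunk vectors stored in the index were normalized the same way when the index was built.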

---

#### Outputs:

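- A **list of metadata dictionaries** (the full entries from `chunk_metadata.json`), each containing:
  - `title` of the original paper
  - `text` of the retrieved chunk
  - `arxiv_id` of the source paper
  - `chunk_id`, which together with `arxiv_id` identifies the chunk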

---

#### Why This Matters:

- Retrieval-augmented generation (**RAG**) grounds the model's answers in passages retrieved from an indexed corpus, rather than relying only on what the model absorbed during training.
- Answer quality depends heavily on retrieval quality: if the retriever misses the relevant passages, the LLM has nothing to ground its answer in.
- Relevant context tends to improve answer accuracy, reduces hallucination, and lets answers be traced back to specific papers in the corpus.

---

**Outcome**:  
We can now input any user question and retrieve the top-k most semantically relevant scientific excerpts, ready for answer generation.

In [None]:
# Load Query Encoder (same model as used during chunk embedding)
embedder = SentenceTransformer("BAAI/bge-base-en-v1.5", device=model.device)
print("Query embedder loaded.")

In [None]:
# Define Semantic Retriever Function
def retrieve_relevant_chunks(query, top_k=5):
    """
    Embed the query, search FAISS, and retrieve top-k relevant chunks.

    Args:
        query (str): User question.
        top_k (int): Number of top chunks to retrieve.

    Returns:
        List of dicts containing 'title', 'text', and 'chunk_id'.
    """
    # Step 1: Embed the query
    query_embedding = embedder.encode(query, normalize_embeddings=True)
    query_embedding = np.expand_dims(query_embedding, axis=0)  # FAISS expects (batch_size, dim)

    # Step 2: Search FAISS
    distances, indices = index.search(query_embedding, top_k)
    retrieved_chunks = []

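    # Step 3: Map FAISS results back to chunk metadata.
    # FAISS pads empty result slots with -1, which would silently index the
    # last list element, so accept only non-negative, in-range positions.
    for idx in indices[0]:
        if 0 <= idx < len(chunk_metadata):
            retrieved_chunks.append(chunk_metadata[idx])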

    return retrieved_chunks

print("Retriever function defined successfully.")
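`index.search` also returns a similarity score for every hit, which `retrieve_relevant_chunks` discards. Keeping the score next to each chunk makes weak matches easy to spot when inspecting results. Below is a minimal sketch, assuming the metadata is a list aligned with the index (as the retriever's positional lookup implies); the function `attach_scores` and the added `score` key are hypothetical, not part of `chunk_metadata.json`.

```python
import numpy as np


def attach_scores(distances, indices, metadata):
    """Pair each retrieved metadata entry with its similarity score (hypothetical helper).

    distances, indices: arrays of shape (1, k), as returned by index.search
    metadata: list where position i describes FAISS vector i
    """
    results = []
    for score, idx in zip(distances[0], indices[0]):
        idx = int(idx)
        if idx < 0 or idx >= len(metadata):  # -1 marks an empty result slot
            continue
        entry = dict(metadata[idx])  # copy so the shared metadata is not mutated
        entry["score"] = float(score)
        results.append(entry)
    return results


# Toy metadata and search output (not real corpus data)
meta = [
    {"title": "Paper A", "text": "first chunk", "chunk_id": 0},
    {"title": "Paper B", "text": "second chunk", "chunk_id": 0},
]
hits = attach_scores(np.array([[0.9, 0.5, -np.inf]]), np.array([[1, 0, -1]]), meta)
for h in hits:
    print(h["title"], round(h["score"], 2))
```

In the notebook, the same helper would be called as `attach_scores(distances, indices, chunk_metadata)` inside the retriever, in place of the plain mapping loop.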

## Step 7: Prompt Construction and Inference Function

In this section, we define a complete **Retrieval-Augmented Generation (RAG)** pipeline, where:

- A **user query** is first used to retrieve the most relevant scientific text passages.
- These passages are then assembled into a coherent **context prompt**.
- The **fine-tuned LLM** uses this context to generate an intelligent, grounded answer.

---

#### What This Function Does:

1. **Semantic Retrieval**:
   - Uses the previously defined retriever function to fetch the top-k most semantically relevant chunks from the scientific corpus.

2. **Context Assembly**:
   - Adds the retrieved chunks in rank order until the next chunk would exceed the context budget (`max_context_tokens`, default 2048 tokens).
   - Keeps the prompt well under the 4096-token cap applied when it is tokenized, leaving room for the instruction and the question.

3. **Prompt Construction**:
   - Builds a structured prompt that:
     - Introduces the model's role ("You are an expert scientific assistant...")
     - Clearly separates the provided excerpts.
     - Presents the user’s question cleanly.

4. **Model Inference**:
   - Tokenizes the prompt, truncating it at 4096 tokens as a safety cap.
   - Feeds the input into the fine-tuned LLM for generation.
   - Uses **greedy decoding** (`do_sample=False`), so repeated runs with the same model, prompt, and setup produce the same answer.

5. **Postprocessing**:
   - Strips the prompt from the decoded output so that only the model's answer is returned.
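Steps 2 and 3 boil down to greedy, rank-ordered packing under a token budget followed by string templating. The sketch below isolates that logic from the model so it can be exercised without loading Mistral. `build_rag_prompt` is a hypothetical helper that mirrors the notebook's prompt format; `count_tokens` stands in for `len(tokenizer.tokenize(...))`, and the whitespace counter in the usage lines is only a placeholder, since Mistral's tokenizer splits text into subword pieces.

```python
from typing import Callable, Dict, List


def build_rag_prompt(query: str,
                     chunks: List[Dict[str, str]],
                     count_tokens: Callable[[str], int],
                     max_context_tokens: int = 2048) -> str:
    """Pack retrieved chunks (in rank order) into a prompt without exceeding the budget."""
    blocks, used = [], 0
    for chunk in chunks:
        block = f"[Title: {chunk['title']}]\n{chunk['text']}\n"
        n = count_tokens(block)
        if used + n > max_context_tokens:
            break  # stop at the first chunk that does not fit, preserving rank order
        blocks.append(block)
        used += n
    context = "\n\n".join(blocks)
    return (
        "You are an expert scientific assistant. Use the provided excerpts to answer the question.\n\n"
        f"Excerpts:\n{context}\n\n"
        f"Question: {query}\n"
        "Answer:"
    )


# Usage with a crude whitespace "tokenizer" (placeholder only)
chunks = [
    {"title": "Paper A", "text": "short excerpt"},
    {"title": "Paper B", "text": "a much longer excerpt " * 50},
]
prompt = build_rag_prompt("What is LoRA?", chunks, lambda s: len(s.split()), max_context_tokens=20)
print(prompt)
```

Using `break` rather than `continue` means one long, highly ranked chunk can end packing even if later, shorter chunks would still fit. Switching to `continue` trades strict rank order for fuller use of the budget.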

---

#### Key Parameters:

| Parameter | Default | Meaning |
|-----------|---------|---------|
| `top_k` | 5 | Number of passages to retrieve |
| `max_context_tokens` | 2048 | Max tokens allowed for retrieved context |
| `max_new_tokens` | 512 | Max tokens allowed for the generated answer |
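The three limits interact. The tokenized prompt (instruction, context, and question) is capped at 4096 tokens by the `max_length` argument, and generation then appends up to `max_new_tokens` more. Below is a back-of-the-envelope check; `overhead_tokens` is a hypothetical estimate for the instruction, title headers, and question, which you would measure with the tokenizer in practice, and both function names are illustrative.

```python
def prompt_headroom(context_tokens: int, overhead_tokens: int, max_prompt_tokens: int = 4096) -> int:
    """Tokens left before the tokenizer's max_length truncation would kick in."""
    return max_prompt_tokens - (context_tokens + overhead_tokens)


def total_sequence_length(prompt_tokens: int, max_new_tokens: int = 512) -> int:
    """Upper bound on the length of prompt plus generated answer."""
    return prompt_tokens + max_new_tokens


# Default 2048-token context budget plus an assumed 100 tokens of overhead
prompt_tokens = 2048 + 100
print(prompt_headroom(2048, 100))            # 1948
print(total_sequence_length(prompt_tokens))  # 2660
```

A negative headroom means the tokenizer would truncate the prompt. Whether the full sequence fits the model depends on its configured context length (for example `max_position_embeddings` in the model config), which is worth checking rather than assuming.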

---

#### Outputs:

- A final, human-readable **answer** generated by the model, grounded in real scientific text excerpts.

---

In [None]:
def generate_answer_with_retrieval(query, top_k=5, max_context_tokens=2048, max_new_tokens=512):
    """
    Full RAG pipeline: retrieve relevant chunks → build prompt → generate answer.

    Args:
        query (str): User's question.
        top_k (int): Number of chunks to retrieve.
        max_context_tokens (int): Max tokens to allocate for context passages.
        max_new_tokens (int): Max tokens the model can generate for the answer.

    Returns:
        str: The model's generated answer.
    """

    # Step 1: Retrieve relevant context passages
    retrieved_chunks = retrieve_relevant_chunks(query, top_k=top_k)

    # Step 2: Assemble the context
    context_blocks = []
    total_tokens = 0

    for chunk in retrieved_chunks:
        chunk_text = f"[Title: {chunk['title']}]\n{chunk['text']}\n"
        chunk_tokens = len(tokenizer.tokenize(chunk_text))

        if total_tokens + chunk_tokens <= max_context_tokens:
            context_blocks.append(chunk_text)
            total_tokens += chunk_tokens
        else:
            break  # stop adding more chunks if token budget exceeded

    assembled_context = "\n\n".join(context_blocks)

    # Step 3: Build the final prompt
    prompt = (
        f"You are an expert scientific assistant. Use the provided excerpts to answer the question.\n\n"
        f"Excerpts:\n{assembled_context}\n\n"
        f"Question: {query}\n"
        f"Answer:"
    )

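    # Step 4: Tokenize the full prompt (a single sequence, so no padding is needed).
    # max_length=4096 is a safety cap. By default truncation drops tokens from the
    # end, which would cut off the question, so the context budget above keeps
    # prompts well below it.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=4096,
    ).to(model.device)

    # Step 5: Generate the answer with greedy decoding
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            do_sample=False,                      # greedy decoding
            pad_token_id=tokenizer.eos_token_id,  # set explicitly; the model defines no pad token
        )

    # Step 6: Decode only the newly generated tokens. For decoder-only models,
    # generate() returns the prompt followed by the continuation, so slicing off
    # the prompt is more robust than splitting on "Answer:", which breaks if the
    # answer itself contains that string.
    prompt_length = inputs["input_ids"].shape[1]
    final_answer = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    return final_answer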
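Two ways to isolate the answer from `generate()` output can be compared on toy data: splitting the decoded text on the `"Answer:"` marker, and slicing off the prompt tokens before decoding (for decoder-only models, `generate()` returns the prompt ids followed by the new ids). The integer lists below are stand-ins for token-id tensors and `new_token_ids` is a hypothetical helper; with real tensors the slice is `outputs[0][inputs["input_ids"].shape[1]:]`.

```python
from typing import List


def new_token_ids(output_ids: List[int], prompt_length: int) -> List[int]:
    """Drop the echoed prompt: generate() output = prompt ids + new ids."""
    return output_ids[prompt_length:]


# Toy ids: prompt = [1, 2, 3], generated continuation = [7, 8]
prompt_ids = [1, 2, 3]
output_ids = prompt_ids + [7, 8]
print(new_token_ids(output_ids, len(prompt_ids)))  # [7, 8]

# String splitting over-trims when the answer itself contains the marker
decoded = "Question: q\nAnswer: The key step. Answer: rank r is small."
print(decoded.split("Answer:")[-1].strip())  # rank r is small.
```

The slicing approach does not depend on the wording of the prompt or the answer, which is why it is generally preferred for decoder-only models.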

## Step 8: Testing the Pipeline with Sample Questions

In [None]:
sample_question = "What are the benefits of parameter-efficient fine-tuning methods?"

# Generate the answer
rag_answer = generate_answer_with_retrieval(sample_question, top_k=5)

# Display
print("Question:")
print(sample_question)
print("\nModel Answer:")
print(rag_answer)

In [None]:
# Inspect Retrieved Chunks for a Query

retrieved_chunks = retrieve_relevant_chunks(sample_question, top_k=5)

for i, chunk in enumerate(retrieved_chunks):
    print(f"--- Chunk {i+1} ---")
    print(f"Title: {chunk['title']}")
    print(f"Text excerpt:\n{chunk['text'][:500]}...")  # Print only the first 500 characters
    print("\n")

In [None]:
sample_question_2 = "How does LoRA improve the efficiency of fine-tuning large language models?"

rag_answer_2 = generate_answer_with_retrieval(sample_question_2, top_k=5)

print("Question:")
print(sample_question_2)
print("\nModel Answer:")
print(rag_answer_2)

In [None]:
# Inspect Retrieved Chunks for a Query

retrieved_chunks_2 = retrieve_relevant_chunks(sample_question_2, top_k=5)

for i, chunk in enumerate(retrieved_chunks_2):
    print(f"--- Chunk {i+1} ---")
    print(f"Title: {chunk['title']}")
    print(f"Text excerpt:\n{chunk['text'][:500]}...")  # Print only the first 500 characters
    print("\n")
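The two inspection cells above repeat the same printing loop. A small formatting helper could replace both; `format_chunks` is a hypothetical name, and unlike the original loop it appends `...` only when an excerpt is actually truncated.

```python
from typing import Dict, List


def format_chunks(chunks: List[Dict[str, str]], max_chars: int = 500) -> str:
    """Render retrieved chunks as numbered, truncated excerpts."""
    parts = []
    for i, chunk in enumerate(chunks, start=1):
        text = chunk["text"]
        excerpt = text if len(text) <= max_chars else text[:max_chars] + "..."
        parts.append(f"--- Chunk {i} ---\nTitle: {chunk['title']}\nText excerpt:\n{excerpt}\n")
    return "\n".join(parts)


# In the notebook: print(format_chunks(retrieve_relevant_chunks(sample_question_2, top_k=5)))
demo = [{"title": "Paper A", "text": "x" * 600}, {"title": "Paper B", "text": "short"}]
print(format_chunks(demo, max_chars=10))
```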

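## Step 9: Cleaning the Notebook Metadata

This housekeeping step strips the notebook-level `widgets` metadata, which can cause rendering errors (for example on GitHub), and clears code-cell outputs.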

In [None]:
!pip install -q nbformat

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
import nbformat
import os, json, pathlib

In [None]:
nb_path = pathlib.Path("/content/drive/MyDrive/llm-finetuning-project/llm-finetuning-summarizer/notebooks/14_rag_retrieval_and_inference.ipynb")   # adjust if filename differs
nb = json.loads(nb_path.read_text())

# Delete the troublesome metadata
nb.get("metadata", {}).pop("widgets", None)

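# (optional but helpful) clear outputs. nbformat v4 only allows "outputs" and
# "execution_count" on code cells, so leave markdown/raw cells untouched.
for cell in nb["cells"]:
    if cell.get("cell_type") == "code":
        cell["outputs"] = []
        cell["execution_count"] = None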

nb_path.write_text(json.dumps(nb, indent=1, ensure_ascii=False))
print("Notebook cleaned.")
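To confirm the cleaning worked, the saved JSON can be checked for leftovers. The sketch below, using a hypothetical `cleanliness_problems` helper, flags notebook-level `widgets` metadata, code cells that still carry outputs or an execution count, and non-code cells that carry `outputs` or `execution_count`, keys that the nbformat v4 schema only allows on code cells. For a full schema check, `nbformat.validate` is the more thorough option.

```python
from typing import Any, Dict, List


def cleanliness_problems(nb: Dict[str, Any]) -> List[str]:
    """List leftovers that the cleaning cell is meant to remove (hypothetical check)."""
    problems = []
    if "widgets" in nb.get("metadata", {}):
        problems.append("notebook metadata still has 'widgets'")
    for i, cell in enumerate(nb.get("cells", [])):
        if cell.get("cell_type") == "code":
            if cell.get("outputs"):
                problems.append(f"code cell {i} still has outputs")
            if cell.get("execution_count") is not None:
                problems.append(f"code cell {i} still has an execution count")
        elif "outputs" in cell or "execution_count" in cell:
            # nbformat v4 only allows these keys on code cells
            problems.append(f"non-code cell {i} has code-only keys")
    return problems


# In the notebook: cleanliness_problems(json.loads(nb_path.read_text())) should be []
dirty = {"metadata": {"widgets": {}},
         "cells": [{"cell_type": "markdown", "metadata": {}, "source": "", "outputs": []}]}
clean = {"metadata": {},
         "cells": [{"cell_type": "code", "metadata": {}, "source": "", "outputs": [], "execution_count": None}]}
print(cleanliness_problems(dirty))
print(cleanliness_problems(clean))
```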