Version 2 of reproduction of the following paper : Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks

What is added from V1 : 
- Evaluation of the model
- Improvement on the model

IMPORTS

In [1]:
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import Dataset, load_dataset

import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


DATASET

In [2]:
# Load TriviaQA (unfiltered for simplicity)
dataset = load_dataset("trivia_qa", "unfiltered")

# Take first 50 examples for quick testing
test_set = dataset["train"].select(range(50))

In [3]:
# List of the features of the dataset
features = dataset["train"].features; features

{'question': Value('string'),
 'question_id': Value('string'),
 'question_source': Value('string'),
 'entity_pages': {'doc_source': List(Value('string')),
  'filename': List(Value('string')),
  'title': List(Value('string')),
  'wiki_context': List(Value('string'))},
 'search_results': {'description': List(Value('string')),
  'filename': List(Value('string')),
  'rank': List(Value('int32')),
  'title': List(Value('string')),
  'url': List(Value('string')),
  'search_context': List(Value('string'))},
 'answer': {'aliases': List(Value('string')),
  'normalized_aliases': List(Value('string')),
  'matched_wiki_entity_name': Value('string'),
  'normalized_matched_wiki_entity_name': Value('string'),
  'normalized_value': Value('string'),
  'type': Value('string'),
  'value': Value('string')}}

In [4]:
digit = 77
example = dataset["train"][digit]

# Print the question and its answer
print("Question:", example["question"])
print("Answer:", example["answer"]["value"])

Question: 1998 was the Chinese year of which creature?
Answer: Tiger


RETRIEVER

In [20]:
passages = []

for item in test_set:
    if "search_results" in item and item["search_results"]:
        r = item["search_results"]
        # Get the number of results in this search
        n = len(r["description"])
        for i in range(n):
            text = r["description"][i] if "description" in r else ""
            context = r["search_context"][i] if "search_context" in r else ""
            combined_text = f"{r['title'][i]} - {text}\n{context}"
            passages.append({"text": combined_text})


# Build embeddings
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = embed_model.encode([p["text"] for p in passages])
embeddings = np.array(embeddings, dtype=np.float32)

# Build FAISS index
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# Retrieval function
def retrieve(query, k=5):
    query_vec = embed_model.encode([query]).astype(np.float32)
    D, I = index.search(query_vec, k)
    return [passages[i]["text"] for i in I[0]]


In [None]:
print(retrieve("Who was President when the first Peanuts cartoon was published?", k=3))

['A Brief History of Charles Schulz\'s \'Peanuts\' Comic ... - TIME - Charles Schulz\'s last Sunday Peanuts comic ... Oct. 2, marks 60 years since Schulz\'s first Peanuts ... On Jan. 3, 2000, the last Peanuts daily strip was published.\nA Brief History of Charles Schulz\'s \'Peanuts\' Comic Strip - TIME\nFollow @TIME\nWhen Alex Davis was 2 years old, he pointed to a drawing his father had done and exclaimed, "Snoopy!" The problem: his father was Jim Davis, the creator of Garfield, and the picture was of the cat he made famous. Charles Schulz\'s black-and-white dog is so beloved, though, that a lasagna-loving cat can\'t even compete. Saturday, Oct. 2, marks 60 years since Schulz\'s first Peanuts strip hit newspapers. Since then, Snoopy, Charlie Brown and the gang have become the most recognizable cartoon characters in America \x97 and have left an indelible mark on American culture.\nYet leave it to the man behind Charlie Brown to experience disappointment before success. In 1947 Schulz

GENERATOR

In [26]:
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")

def generate_answer(query):
    retrieved_text = " ".join(retrieve(query, k=3))
    input_text = query + " " + retrieved_text
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=1024,
        truncation=True   # ✨ add this line
    )
    outputs = model.generate(**inputs, max_length=50)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


EVALUATE

In [27]:
def exact_match(pred, answers):
    """
    pred: string generated by your model
    answers: list of acceptable answers
    """
    pred = pred.lower().strip()
    answers = [a.lower().strip() for a in answers]
    return int(pred in answers)

for item in test_set:
    # TriviaQA answers can be in "aliases" or "value"
    answer_list = []

    if "answer" in item:
        ans = item["answer"]
        if isinstance(ans, dict):
            if "aliases" in ans:
                answer_list.extend(ans["aliases"])
            if "value" in ans:
                answer_list.append(ans["value"])
        elif isinstance(ans, str):
            answer_list.append(ans)
    
    pred = generate_answer(item["question"])
    em = exact_match(pred, answer_list)

    print("Q:", item["question"])
    print("Pred:", pred)
    print("EM:", em)
    print("---")

    


Q: Who was President when the first Peanuts cartoon was published?
Pred: A Brief History of Charles Schulz's 'Peanuts' Comic Strip - TIME civilian draw handsWhen Alex Davis was 2 years old, he pointed to a drawing his father had done and exclaimed, "Snoopy!" The problem:
EM: 0
---
Q: Which American-born Sinclair won the Nobel Prize for Literature in 1930?
Pred: Which American-born Sinclair won the Nobel Prize for Literature in 1930? Sinclair Lewis Becomes the First American to be Awarded this Nobel Prize in Literature, ... 1930, Sinclair Lewis became the first writer from the United States to receive the
EM: 0
---
Q: Where in England was Dame Judi Dench born?
Pred: Where in England was Dame Judi Dench born? (2007 TV Day) (special)TV Day andHer Series (2007)TV Conversation (2005) (2007–present) (Special) (2005–present
EM: 0
---
Q: William Christensen of Madison, New Jersey, has claimed to have the world's biggest collection of what?
Pred: Updated                02/21/2002\\\\\\\\\\\\\\\

KeyboardInterrupt: 