Version 2 of reproduction of the following paper : Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks

What is added from V1 : 
- Evaluation of the model
- Other specific things from the paper?

IMPORTS

In [3]:
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import Dataset, load_dataset

import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


DATASET

In [4]:
# Load TriviaQA (unfiltered for simplicity)
dataset = load_dataset("trivia_qa", "unfiltered")

# Let's look at the first example
print(dataset["train"][0])

# Take first 50 examples for quick testing
test_set = dataset["train"].select(range(50))

{'question': 'Who was President when the first Peanuts cartoon was published?', 'question_id': 'tc_0', 'question_source': 'http://www.triviacountry.com/', 'entity_pages': {'doc_source': ['TagMe'], 'filename': ['Peanuts.txt'], 'title': ['Peanuts'], 'wiki_context': ['Peanuts is a syndicated daily and Sunday American comic strip written and illustrated by Charles M. Schulz, which ran from October 2, 1950, to February 13, 2000, continuing in reruns afterward. The strip is the most popular and influential in the history of comic strips, with 17,897 strips published in all, making it "arguably the longest story ever told by one human being".  At its peak, Peanuts ran in over 2,600 newspapers, with a readership of 355 million in 75 countries, and was translated into 21 languages.  It helped to cement the four-panel gag strip as the standard in the United States,  and together with its merchandise earned Schulz more than $1 billion. Reprints of the strip are still syndicated and run in almost 

RETRIEVER

In [6]:
# Collect all passages (here, using search_results if available)
passages = []

for item in test_set:
    # Make sure search_results exist
    if "search_results" in item and item["search_results"]:
        for r in item["search_results"]:
            # Sometimes r is a string, sometimes a dict
            if isinstance(r, dict) and "text" in r:
                passages.append({"text": r["text"]})
            elif isinstance(r, str):
                passages.append({"text": r})

# Build embeddings
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = embed_model.encode([p["text"] for p in passages])
embeddings = np.array(embeddings, dtype=np.float32)

# Build FAISS index
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# Retrieval function
def retrieve(query, k=5):
    query_vec = embed_model.encode([query]).astype(np.float32)
    D, I = index.search(query_vec, k)
    return [passages[i]["text"] for i in I[0]]


GENERATOR

In [7]:
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")

def generate_answer(query):
    retrieved_text = " ".join(retrieve(query, k=3))  # combine top 3 passages
    input_text = query + " " + retrieved_text
    inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=50)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

EVALUATE

In [None]:
def exact_match(pred, answers):
    """
    pred: string generated by your model
    answers: list of acceptable answers
    """
    pred = pred.lower().strip()
    answers = [a.lower().strip() for a in answers]
    return int(pred in answers)

for item in test_set:
    # TriviaQA answers can be in "aliases" or "value"
    answer_list = []

    if "answer" in item:
        ans = item["answer"]
        if isinstance(ans, dict):
            if "aliases" in ans:
                answer_list.extend(ans["aliases"])
            if "value" in ans:
                answer_list.append(ans["value"])
        elif isinstance(ans, str):
            answer_list.append(ans)
    
    pred = generate_answer(item["question"])
    em = exact_match(pred, answer_list)

    print("Q:", item["question"])
    print("Pred:", pred)
    print("EM:", em)
    print("---")

    


Q: Who was President when the first Peanuts cartoon was published?
Pred: Who was President when the first Peanuts cartoon was published? title title title
EM: 0
---
Q: Which American-born Sinclair won the Nobel Prize for Literature in 1930?
Pred: Which American-born Sinclair won the Nobel Prize for Literature in 1930?Search_context search_contextSearch_Context search_Context
EM: 0
---
Q: Where in England was Dame Judi Dench born?
Pred: Where in England was Dame Judi Dench born? url url url
EM: 0
---
Q: William Christensen of Madison, New Jersey, has claimed to have the world's biggest collection of what?
Pred: William Christensen of Madison, New Jersey, has claimed to have the world's biggest collection of what? url url url
EM: 0
---
Q: In which decade did Billboard magazine first publish and American hit chart?
Pred: In which decade did Billboard magazine first publish and American hit chart? rank rank rank
EM: 0
---
Q: Where was horse racing's Breeders' Cup held in 1988?
Pred: Where 

KeyboardInterrupt: 