In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
import faiss
import torch
import warnings
warnings.filterwarnings("ignore")
from sklearn.feature_extraction.text import TfidfVectorizer


# RAG Examples

documents = [
    "The Brazil national football team is the national team of Brazil and is governed by the Brazilian Football Confederation (CBF).",
    "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.",
    "The Great Pyramid of Giza is a pyramid located in the Giza pyramid complex in Egypt. It is the oldest of the Seven Wonders of the Ancient World, and the only one still in existence.",
]

# Doc vectorizer
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

def embed(texts):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        embeddings = model(**inputs).last_hidden_state.mean(dim=1)
    return embeddings.numpy()

document_embeddings = embed(documents)

# FAISS Index
index = faiss.IndexFlatL2(document_embeddings.shape[1])
index.add(document_embeddings)

# Generative Model
model_name = "t5-small"
generator_tokenizer = AutoTokenizer.from_pretrained(model_name)
generator_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# RAG method
def rag_query(query, top_k=3):
    # Embedding the query
    query_embedding = embed([query])
    
    # Search for the most similar documents
    distances, indices = index.search(query_embedding, top_k)

    # Get the most similar documents
    most_similar_documents = [documents[i] for i in indices[0]]

    # Concat most similar documents with the query
    context = "\n".join(most_similar_documents) + "\n\n" + query

    # Generate the response
    inputs = generator_tokenizer.encode(context, return_tensors="pt", max_length=512, truncation=True)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
    outputs = generator_model.generate(inputs, max_length=50, num_beams=2)
    answer = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)

    return answer

# Example of usage

question = "What is the name of the federation that governs the Brazil national football team?"
print(rag_query(question))

question = "What is the name of the engineer who designed the Eiffel Tower?"
print(rag_query(question))

question = "What is the name of the oldest of the Seven Wonders of the Ancient World?"
print(rag_query(question))



  from .autonotebook import tqdm as notebook_tqdm


is governed by the Brazilian Football Confederation (CBF). The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave
is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower. The Eiffel Tower
is the oldest of the Seven Wonders of the Ancient World, and the only one still in existence. The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It
