# In [None]:
import torch
from transformers import RagTokenizer, RagSequenceForGeneration
from utils.vectorstore import VectorStore

class RAG:
    """Retrieval-augmented generation helper.

    Bundles a Hugging Face RAG tokenizer and sequence-generation model with a
    project ``VectorStore`` so that documents can be looked up for a query and
    then passed to the model's ``generate`` call.

    The ``config`` mapping is read for these keys: ``"model_name"``,
    ``"vector_store"``, ``"num_retrieved_docs"``, ``"max_length"`` and
    ``"num_beams"``.
    """

    def __init__(self, config, llm):
        """Load the tokenizer and model and build the vector store.

        Args:
            config: Mapping holding the settings listed in the class docstring.
            llm: Stored on the instance as ``self.llm``; no method visible in
                this class reads it.
        """
        self.config = config
        self.llm = llm

        # Tokenizer and model are both loaded from the same checkpoint name.
        # NOTE(review): the model is loaded without a retriever argument --
        # confirm that is intended given how generate() feeds it contexts.
        checkpoint = config["model_name"]
        self.tokenizer = RagTokenizer.from_pretrained(checkpoint)
        self.model = RagSequenceForGeneration.from_pretrained(checkpoint)

        self.vector_store = VectorStore(config["vector_store"])

    def retrieve(self, query):
        """Return the vector store's search results for ``query``.

        The number of results requested is ``config["num_retrieved_docs"]``,
        forwarded to the store as the ``k`` keyword argument.
        """
        top_k = self.config["num_retrieved_docs"]
        return self.vector_store.search(query, k=top_k)

    def generate(self, query, retrieved_docs):
        """Generate decoded output for ``query`` given ``retrieved_docs``.

        Args:
            query: Input passed to the tokenizer (with ``return_tensors="pt"``).
            retrieved_docs: Passed unchanged to the tokenizer's
                ``encode_contexts``; presumably the value returned by
                ``retrieve`` -- verify against callers.

        Returns:
            Whatever ``tokenizer.batch_decode`` returns for the generated
            sequences, with special tokens skipped.
        """
        question_ids = self.tokenizer(query, return_tensors="pt").input_ids

        # NOTE(review): `encode_contexts` does not appear in the documented
        # RagTokenizer API -- confirm it exists in the pinned transformers
        # version (or that a custom tokenizer is being used).
        ctx_ids, ctx_mask = self.tokenizer.encode_contexts(retrieved_docs)

        # NOTE(review): the transformers docs describe `doc_scores` as required
        # when contexts are supplied to a model without a retriever; none is
        # passed here -- TODO confirm.
        generation_kwargs = {
            "input_ids": question_ids,
            "context_input_ids": ctx_ids,
            "context_attention_mask": ctx_mask,
            "max_length": self.config["max_length"],
            "num_beams": self.config["num_beams"],
            "num_return_sequences": 1,
        }
        sequences = self.model.generate(**generation_kwargs)

        return self.tokenizer.batch_decode(sequences, skip_special_tokens=True)