In [1]:
from transformers import BertTokenizer, BertModel, GPT2Tokenizer, GPT2LMHeadModel
import torch
import faiss
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Toy corpus that serves as the retrieval set
documents = [
    "Document 1 text",
    "Document 2 text",
    "Document 3 text",
]

# BERT encoder and its matching tokenizer (bert-base-uncased checkpoint)
model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


In [3]:
# Encode documents
def encode_documents(documents):
    """Embed each document with BERT and stack the embeddings row-wise.

    Every text is tokenized on its own (truncated to at most 512 tokens), run
    through the module-level ``model``, and reduced to one vector by averaging
    ``last_hidden_state`` over the token axis (``dim=1``).

    Args:
        documents: Iterable of strings to embed.

    Returns:
        A 2-D ``numpy.ndarray`` with one row per input document, in input order.

    Raises:
        ValueError: If ``documents`` is empty (``np.vstack`` rejects an empty list).
    """
    embeddings = []
    # Inference only: disable autograd so no computation graph is built or
    # kept alive for each forward pass.
    with torch.no_grad():
        for doc in documents:
            inputs = tokenizer(doc, return_tensors='pt', truncation=True, padding=True, max_length=512)
            outputs = model(**inputs)
            # Mean-pool token vectors into a single (1, hidden) row per document.
            embeddings.append(outputs.last_hidden_state.mean(dim=1).numpy())
    return np.vstack(embeddings)

# Embed the whole corpus once, up front
document_embeddings = encode_documents(documents)

# Flat L2 FAISS index sized to the embedding width, filled with the corpus vectors
index = faiss.IndexFlatL2(document_embeddings.shape[-1])
index.add(document_embeddings)

In [4]:
# GPT-2 language model and its tokenizer, used for the generation step
gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')
gpt_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Function to generate text
def generate_text(prompt, max_new_tokens=150):
    """Continue ``prompt`` with GPT-2 and return the decoded text.

    Args:
        prompt: Text to condition the generation on.
        max_new_tokens: Maximum number of tokens to generate *in addition to*
            the prompt. A budget on new tokens (rather than on total length)
            keeps long retrieval-augmented prompts from exhausting the limit
            before any answer is produced.

    Returns:
        The first generated sequence decoded to a string with special tokens
        removed. For a decoder-only model such as GPT-2, ``generate`` returns
        the prompt tokens followed by the continuation, so the string starts
        with the prompt.

    Note:
        The prompt is not truncated here; prompt plus ``max_new_tokens`` must
        still fit in the model's context window.
    """
    encoded = gpt_tokenizer(prompt, return_tensors='pt')
    outputs = gpt_model.generate(
        encoded['input_ids'],
        # Pass the mask explicitly instead of letting generate() infer it.
        attention_mask=encoded['attention_mask'],
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        # GPT-2 defines no pad token; reuse EOS so generate() does not have to guess.
        pad_token_id=gpt_tokenizer.eos_token_id,
    )
    return gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)

In [5]:
# Combine Retrieval and Generation
def rag_model(query, top_k=3):
    """Answer ``query`` by retrieving nearby documents and prompting the generator.

    The query is embedded with ``encode_documents``, its nearest neighbours are
    looked up in the module-level FAISS ``index``, the matching entries of
    ``documents`` are joined into a context string, and the resulting prompt is
    passed to ``generate_text``.

    Args:
        query: The user question.
        top_k: Maximum number of documents to retrieve. Values larger than the
            number of indexed vectors are clamped to that number.

    Returns:
        Whatever ``generate_text`` returns for the assembled prompt.
    """
    # Step 1: Retrieve relevant documents
    query_embedding = encode_documents([query])[0]

    # FAISS pads the result with -1 when fewer than k vectors are indexed, and
    # documents[-1] would then silently return the last document as a bogus
    # hit. Clamp k to the index size and drop any negative ids.
    k = max(0, min(top_k, index.ntotal))
    retrieved_docs = []
    if k > 0:
        _, indices = index.search(query_embedding.reshape(1, -1), k)
        retrieved_docs = [documents[idx] for idx in indices[0] if idx >= 0]

    # Step 2: Concatenate retrieved documents to form context
    context = " ".join(retrieved_docs)

    # Step 3: Generate response using the generative model
    prompt = f"Context: {context}\n\nQuestion: {query}\nAnswer:"
    return generate_text(prompt)

: 

In [6]:
# Run the full retrieve-then-generate pipeline on a sample question
query = "How can I enable dark mode in the latest version of your software?"
response = rag_model(query=query)
print(response)