# RAG Demo (No sentence-transformers)
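This notebook builds a minimal retrieval-augmented generation (RAG) pipeline using only the Hugging Face `transformers` library for both embedding and generation (the `sentence-transformers` package is not required), with FAISS for nearest-neighbour retrieval.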

In [None]:
# Install dependencies (for Colab)
# !pip install transformers faiss-cpu torch

In [4]:
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForSeq2SeqLM
import torch
import faiss
import numpy as np

In [5]:
# Mean pooling function
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence axis, ignoring masked positions.

    Args:
        model_output: Indexable model output whose element 0 holds the
            per-token embeddings (treated here as batch x tokens x hidden).
        attention_mask: Tensor with one entry per token (batch x tokens);
            non-zero marks tokens that should contribute to the average.

    Returns:
        A tensor with one pooled vector per batch row (batch x hidden).
        A row whose mask is entirely zero yields a zero vector rather than NaN.
    """
    token_embeddings = model_output[0]
    # Broadcast the mask across the hidden dimension so it can weight each token vector.
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * input_mask_expanded).sum(1)
    # Clamp the token count so a fully masked row divides by a tiny positive
    # number instead of zero (which would produce NaN and poison the index).
    token_counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return summed / token_counts

In [6]:
# Embedding function
def embed_texts(texts, tokenizer, model):
    """Embed a batch of strings into fixed-size vectors via mean pooling.

    Args:
        texts: List of strings to embed.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Encoder whose output is passed to ``mean_pooling``.

    Returns:
        A NumPy array with one row per input text, always on the CPU.
    """
    # Tokenize on the CPU, then move the batch to wherever the model lives so
    # the forward pass also works when the model has been moved to a GPU.
    device = next(model.parameters()).device
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        model_output = model(**encoded_input)
    pooled = mean_pooling(model_output, encoded_input['attention_mask'])
    # Tensor.numpy() only works on CPU tensors; .cpu() is a no-op when already there.
    return pooled.cpu().numpy()

In [7]:
# Load models
# --- Embedding side -------------------------------------------------------
# A plain Hugging Face encoder; sentence vectors are produced by mean pooling
# its token outputs in embed_texts.
# Alternative checkpoint tried earlier: 'sentence-transformers/all-MiniLM-L6-v2'
# NOTE(review): mpnet-base does not appear to be tuned for sentence
# similarity -- confirm retrieval quality is acceptable.
embedding_model_name = "microsoft/mpnet-base"
tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
embed_model = AutoModel.from_pretrained(embedding_model_name)

# --- Generation side ------------------------------------------------------
# Seq2seq model wrapped in a text2text pipeline for answering from context.
gen_model_name = 'google/flan-t5-base'
gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
qa_pipeline = pipeline(
    task='text2text-generation',
    model=gen_model,
    tokenizer=gen_tokenizer,
)

Downloading tokenizer_config.json: 100%|██████████| 350/350 [00:00<?, ?B/s] 
Downloading vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 904kB/s]
Downloading tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 1.66MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 112/112 [00:00<?, ?B/s] 
Downloading config.json: 100%|██████████| 612/612 [00:00<?, ?B/s] 


RuntimeError: Failed to import transformers.models.deta.configuration_deta because of the following error (look up to see its traceback):
No module named 'transformers.models.deta.configuration_deta'

In [None]:
# Documents
# Tiny in-memory knowledge base: each string is one retrievable passage.
docs = [
    "ProductX is the latest widget released in 2024. It features improved battery life.",
    "To reset ProductX, hold the power button for 10 seconds until the LED blinks.",
    "Our support plans include Basic, Plus, and Enterprise tiers, offering 24/7 support in higher tiers.",
]

In [None]:
# Embed documents
# One pooled vector per passage; the query cell maps hit positions back into `docs`.
doc_embeddings = embed_texts(texts=docs, tokenizer=tokenizer, model=embed_model)

In [None]:
# FAISS index
# Flat L2 index sized to the width of the document embeddings.
embedding_dim = doc_embeddings.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
# Insert every document vector; positions in the index follow row order.
index.add(doc_embeddings)

In [None]:
# Query
query = "How can I reset ProductX?"
query_vec = embed_texts([query], tokenizer, embed_model)[0]
# Number of passages to retrieve; all valid hits are joined into the context.
k = 1
# index.search takes a 2-D batch of queries, so re-wrap the single vector.
distances, indices = index.search(np.array([query_vec]), k)
# FAISS fills missing neighbours with label -1 when fewer than k vectors are
# indexed; skip those so -1 is never used as a (valid) Python list index.
retrieved_docs = [docs[i] for i in indices[0] if i != -1]
# With k = 1 this is exactly the single best-matching passage.
retrieved_text = "\n".join(retrieved_docs)
print(f'Retrieved: {retrieved_text}')

In [None]:
# Prompt and generate
# Put the retrieved passage ahead of the question so the model answers from it.
prompt = f"Context: {retrieved_text}\nQuestion: {query}\nAnswer:"
# The pipeline returns a list of dicts; take the text of the first generation.
generations = qa_pipeline(prompt, max_length=100)
result = generations[0]['generated_text']
print(f'Answer: {result}')