# RAG Demo (Stable with latest libraries, no FAISS)
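This notebook uses Hugging Face Transformers to embed a small set of documents, retrieve the document most similar to a question (by cosine similarity of the embeddings), and generate an answer from the retrieved text with a text2text model.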

In [None]:
# !pip install transformers torch scikit-learn

In [1]:
from transformers import AutoTokenizer, AutoModel, pipeline
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
# Mean pooling function
def mean_pooling(model_output, attention_mask, eps=1e-9):
    """Average token embeddings along dimension 1, weighted by the attention mask.

    Args:
        model_output: Indexable model output; ``model_output[0]`` is taken as
            the token-embedding tensor (assumed (batch, seq, dim) — confirm
            against the model in use).
        attention_mask: Tensor that is unsqueezed on its last axis and expanded
            to the size of the token embeddings; positions with 0 contribute
            nothing to the average.
        eps: Lower bound applied to the mask sum before dividing. A row whose
            mask is entirely zero would otherwise produce 0/0 = NaN; with the
            clamp it produces zeros instead.

    Returns:
        Tensor holding the mask-weighted mean of the token embeddings over
        dimension 1.
    """
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_sum = (token_embeddings * input_mask_expanded).sum(1)
    # Guard the division: the clamp only matters when the mask sum is below eps.
    mask_total = torch.clamp(input_mask_expanded.sum(1), min=eps)
    return masked_sum / mask_total

In [3]:
# Embed texts
def embed_texts(texts, tokenizer, model):
    """Tokenize ``texts``, run ``model`` on them and return mean-pooled embeddings.

    Args:
        texts: Text input handed straight to ``tokenizer``.
        tokenizer: Callable invoked with ``padding=True``, ``truncation=True``
            and ``return_tensors='pt'``; its result must be a mapping of
            tensors that includes an ``'attention_mask'`` entry.
        model: Callable invoked with the tokenizer's fields as keyword
            arguments; its output is passed to ``mean_pooling``.

    Returns:
        NumPy array produced from the pooled tensor (presumably one row per
        input text — confirm against the model in use).
    """
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

    # Put the inputs on the same device as the model weights so the forward
    # pass also works when the model has been moved off the CPU.
    parameters = getattr(model, 'parameters', None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        encoded_input = {
            name: tensor.to(first_param.device)
            for name, tensor in encoded_input.items()
        }

    with torch.no_grad():
        model_output = model(**encoded_input)

    pooled = mean_pooling(model_output, encoded_input['attention_mask'])
    # Tensor.numpy() only works on CPU tensors; .cpu() is a no-op when already there.
    return pooled.cpu().numpy()

In [36]:
# Load the models used by the rest of the notebook.
# `embedding_model` names the checkpoint behind both `tokenizer` and `model`;
# "microsoft/mpnet-base" was the alternative checkpoint noted here.
embedding_model = 'sentence-transformers/all-MiniLM-L6-v2'

# Tokenizer first, then encoder, both from the same checkpoint.
tokenizer, model = (
    AutoTokenizer.from_pretrained(embedding_model),
    AutoModel.from_pretrained(embedding_model),
)

# Text2text generation pipeline that writes the final answer.
gen_pipeline = pipeline(task='text2text-generation', model='google/flan-t5-base')



In [37]:
# Knowledge base: the passages that retrieval searches over.
docs = [
    "ProductX is the latest widget released in 2024. It features improved battery life.",
    "To reset ProductX, hold the power button for 10 seconds until the LED blinks.",
    "Our support plans include Basic, Plus, and Enterprise tiers, offering 24/7 support in higher tiers.",
]

# Embed every passage once; later cells compare queries against these vectors.
doc_embeddings = embed_texts(texts=docs, tokenizer=tokenizer, model=model)

In [38]:
# User question
# Only one question is active per run. The previous version assigned `query`
# twice in a row, so the first value was silently discarded. To try the other
# example, swap which of the two lines below is commented out.
# query = "How do I reset ProductX?"
query = "tell me something on levels?"

# embed_texts is given a one-element list; [0] takes the single resulting row.
query_embedding = embed_texts([query], tokenizer, model)[0]

# Compare the query with every document embedding; [0] takes the scores for
# the one query row passed in.
similarities = cosine_similarity([query_embedding], doc_embeddings)[0]

# The document with the highest similarity becomes the context for generation.
best_idx = int(np.argmax(similarities))
retrieved = docs[best_idx]
print(f'Retrieved: {retrieved}')

Retrieved: Our support plans include Basic, Plus, and Enterprise tiers, offering 24/7 support in higher tiers.


In [39]:
# Generate answer
# The prompt puts the retrieved passage ahead of the question and ends with
# "Answer:" for the model to continue.
prompt = f"Context: {retrieved}\nQuestion: {query}\nAnswer:"

# Take the text of the first item the pipeline returns.
generations = gen_pipeline(prompt, max_length=100)
response = generations[0]['generated_text']
print(f'Answer: {response}')

Answer: Basic, Plus, and Enterprise
