In [1]:
"""
RAG Pipeline (Production-Ready Version)

Improvements:
- Persistent ChromaDB storage
- Clean OpenAI integration
- Safer architecture
- Better repo-ready structure
"""

import hashlib
import os
from typing import List, Dict, Optional

import chromadb
from chromadb.config import Settings
from openai import OpenAI
from sentence_transformers import SentenceTransformer


class RAGPipeline:
    """
    Retrieval-augmented generation (RAG) pipeline using:
    - sentence-transformers (text embeddings)
    - ChromaDB (on-disk vector store)
    - OpenAI API (optional answer generation)

    Typical flow: ``add_documents`` -> ``query`` -> ``generate_answer``.
    """

    # Collection settings shared by __init__ and reset so the collection is
    # always (re)created with the same cosine distance space.
    _COLLECTION_METADATA = {"hnsw:space": "cosine"}

    # Records written per upsert call. Chroma enforces a maximum batch size
    # per write, so large ingests are split into modest slices.
    _WRITE_BATCH_SIZE = 1000

    def __init__(
        self,
        embedding_model_name: str = "all-MiniLM-L6-v2",
        collection_name: str = "documents",
        persist_directory: str = "./chroma_store",
        chunk_size: int = 500,
        chunk_overlap: int = 50,
    ):
        """
        Load the embedding model and open (or create) the vector collection.

        Args:
            embedding_model_name: Name passed to ``SentenceTransformer``.
            collection_name: Chroma collection to open or create.
            persist_directory: Directory where Chroma keeps its data on disk.
            chunk_size: Maximum number of whitespace-separated words per chunk.
            chunk_overlap: Number of words shared by consecutive chunks.

        Raises:
            ValueError: If ``chunk_size``/``chunk_overlap`` are inconsistent.
        """
        # Fail fast on bad chunking settings, before the model is loaded.
        self._validate_chunking(chunk_size, chunk_overlap)

        print("üöÄ Initializing Advanced RAG Pipeline...")

        # Embedding model
        print(f"üì¶ Loading embedding model: {embedding_model_name}")
        self.embedding_model = SentenceTransformer(embedding_model_name)

        # PersistentClient is the client that writes to disk; a plain
        # chromadb.Client(Settings(persist_directory=...)) keeps data in
        # memory only on current Chroma releases.
        self.chroma_client = chromadb.PersistentClient(
            path=persist_directory,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True,
            ),
        )

        self.collection = self.chroma_client.get_or_create_collection(
            name=collection_name,
            metadata=dict(self._COLLECTION_METADATA),
        )

        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        print("‚úÖ Pipeline initialized!")

    @staticmethod
    def _validate_chunking(chunk_size: int, chunk_overlap: int) -> None:
        """Raise ValueError unless ``0 <= chunk_overlap < chunk_size``."""
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if chunk_overlap < 0:
            raise ValueError("chunk_overlap must not be negative")
        if chunk_overlap >= chunk_size:
            raise ValueError("chunk_overlap must be smaller than chunk_size")

    # ------------------------------------------------------------------
    # TEXT CHUNKING
    # ------------------------------------------------------------------
    def chunk_text(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
        """
        Split ``text`` into overlapping windows of whitespace-separated words.

        Args:
            text: Raw text to split.
            metadata: Optional metadata attached to every chunk.

        Returns:
            A list of ``{"text": str, "metadata": dict}`` entries. Each chunk
            holds its own copy of ``metadata`` (an empty dict when none was
            given). Returns an empty list when ``text`` contains no words.

        Raises:
            ValueError: If the instance's chunking settings are inconsistent.
        """
        # Re-validated here because the attributes can be changed after init.
        self._validate_chunking(self.chunk_size, self.chunk_overlap)
        step = self.chunk_size - self.chunk_overlap

        words = text.split()
        chunks = []

        for start in range(0, len(words), step):
            chunks.append(
                {
                    "text": " ".join(words[start : start + self.chunk_size]),
                    # Copy so editing one chunk's metadata cannot leak into
                    # sibling chunks or the caller's dict.
                    "metadata": dict(metadata) if metadata else {},
                }
            )

            # This window already reached the end of the text; a further
            # window would only repeat the overlapping tail.
            if start + self.chunk_size >= len(words):
                break

        return chunks

    # ------------------------------------------------------------------
    # ADD DOCUMENTS
    # ------------------------------------------------------------------
    def add_documents(self, documents: List[Dict]) -> None:
        """
        Chunk, embed and store ``documents`` in the collection.

        Args:
            documents: Dicts with a required ``"text"`` key and an optional
                ``"metadata"`` dict.

        Chunk IDs have the form ``doc_<i>_chunk_<j>_<digest>``, where the
        digest is derived from the chunk text. Records are upserted, so
        re-ingesting the same batch overwrites its own records, while a
        different batch gets different IDs instead of colliding with an
        earlier one. Nothing is embedded or stored when no chunks result.
        """
        print(f"\nüì• Processing {len(documents)} documents...")

        all_texts, all_metadatas, all_ids = [], [], []

        for doc_idx, doc in enumerate(documents):
            chunks = self.chunk_text(doc["text"], doc.get("metadata"))

            for chunk_idx, chunk in enumerate(chunks):
                chunk_body = chunk["text"]
                digest = hashlib.sha256(chunk_body.encode("utf-8")).hexdigest()[:12]

                all_texts.append(chunk_body)
                # Chroma's metadata validation rejects an empty dict, so
                # "no metadata" is sent as None instead.
                all_metadatas.append(chunk["metadata"] or None)
                all_ids.append(f"doc_{doc_idx}_chunk_{chunk_idx}_{digest}")

        print(f"üìÑ Created {len(all_texts)} chunks")

        if not all_texts:
            # Nothing to embed or store (no documents, or only empty texts).
            return

        print("üî¢ Generating embeddings...")
        embeddings = self.embedding_model.encode(
            all_texts,
            batch_size=32,
            show_progress_bar=True,
            convert_to_numpy=True,
            normalize_embeddings=True,  # IMPORTANT for cosine similarity
        ).tolist()

        print("üíæ Storing in vector database...")
        for start in range(0, len(all_ids), self._WRITE_BATCH_SIZE):
            stop = start + self._WRITE_BATCH_SIZE
            self.collection.upsert(
                embeddings=embeddings[start:stop],
                documents=all_texts[start:stop],
                metadatas=all_metadatas[start:stop],
                ids=all_ids[start:stop],
            )

        print("‚úÖ Documents added successfully!")

    # ------------------------------------------------------------------
    # RETRIEVAL
    # ------------------------------------------------------------------
    def retrieve(
        self,
        query: str,
        k: int = 3,
        filter_metadata: Optional[Dict] = None,
    ) -> List[Dict]:
        """
        Return up to ``k`` stored chunks most similar to ``query``.

        Args:
            query: Search text.
            k: Maximum number of chunks to return (must be >= 1).
            filter_metadata: Optional Chroma ``where`` filter.

        Returns:
            Dicts with ``id``, ``text``, ``score`` (``1 - distance``) and
            ``metadata`` keys, in the order Chroma returns them. Returns an
            empty list when the collection holds no records.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError("k must be a positive integer")

        stored = self.collection.count()
        if stored == 0:
            return []

        query_embedding = self.embedding_model.encode(
            query, normalize_embeddings=True
        )

        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            # Never ask for more neighbours than there are records.
            n_results=min(k, stored),
            # An empty filter dict is treated the same as "no filter".
            where=filter_metadata or None,
        )

        # One query embedding was sent, so each result field holds one row.
        return [
            {
                "id": chunk_id,
                "text": text,
                "score": 1 - distance,
                "metadata": metadata or {},
            }
            for chunk_id, text, distance, metadata in zip(
                results["ids"][0],
                results["documents"][0],
                results["distances"][0],
                results["metadatas"][0],
            )
        ]

    # ------------------------------------------------------------------
    # PROMPT BUILDER
    # ------------------------------------------------------------------
    def generate_prompt(self, query: str, retrieved_docs: List[Dict]) -> str:
        """
        Build the LLM prompt from the question and the retrieved chunks.

        Args:
            query: The user's question.
            retrieved_docs: Dicts with a ``"text"`` key, e.g. from ``retrieve``.

        Returns:
            The prompt: numbered context documents followed by the question.
        """
        context_parts = []

        for idx, doc in enumerate(retrieved_docs, 1):
            context_parts.extend((f"Document {idx}:", doc["text"], ""))

        context = "\n".join(context_parts)

        prompt = f"""Based on the following context, answer the question.

Context:
{context}

Question: {query}

Answer:"""

        return prompt

    # ------------------------------------------------------------------
    # FULL RAG QUERY
    # ------------------------------------------------------------------
    def query(
        self,
        question: str,
        k: int = 3,
        filter_metadata: Optional[Dict] = None,
    ) -> Dict:
        """
        Retrieve context for ``question`` and build the matching prompt.

        Prints the question and a 200-character preview of every hit.

        Args:
            question: The user's question.
            k: Maximum number of chunks to retrieve.
            filter_metadata: Optional Chroma ``where`` filter.

        Returns:
            Dict with ``query``, ``retrieved_documents`` and ``prompt`` keys.
        """
        print(f"\nüîç Query: {question}")
        print("=" * 80)

        retrieved_docs = self.retrieve(question, k, filter_metadata)

        for idx, doc in enumerate(retrieved_docs, 1):
            print(f"Document {idx} | Score: {doc['score']:.3f}")
            print(doc["text"][:200], "...\n")

        prompt = self.generate_prompt(question, retrieved_docs)

        return {
            "query": question,
            "retrieved_documents": retrieved_docs,
            "prompt": prompt,
        }

    # ------------------------------------------------------------------
    # LLM GENERATION
    # ------------------------------------------------------------------
    def generate_answer(
        self,
        prompt: str,
        model: str = "gpt-4.1-mini",
        api_key: Optional[str] = None,
    ) -> str:
        """
        Send ``prompt`` to the OpenAI chat completions API.

        Args:
            prompt: Full prompt text, e.g. ``query(...)["prompt"]``.
            model: OpenAI model name.
            api_key: API key; falls back to the ``OPENAI_API_KEY``
                environment variable.

        Returns:
            The text of the first choice, or an empty string when the
            response carries no text content.

        Raises:
            ValueError: If no API key is supplied or found in the environment.
        """
        api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY not found")

        client = OpenAI(api_key=api_key)

        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )

        # message.content may be None; keep the declared str return type.
        return response.choices[0].message.content or ""

    # ------------------------------------------------------------------
    # RESET COLLECTION
    # ------------------------------------------------------------------
    def reset(self):
        """Delete the collection and recreate it empty under the same name."""
        # Capture the name first so the deleted collection handle is not
        # consulted again afterwards.
        name = self.collection.name
        self.chroma_client.delete_collection(name)
        self.collection = self.chroma_client.create_collection(
            name=name,
            metadata=dict(self._COLLECTION_METADATA),
        )
        print("üóëÔ∏è Collection reset!")


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Sample corpus: four short passages, each tagged with metadata that can be
# used as a retrieval filter (see the ``category`` filter in the query below).
documents = [
    {
        "text": """
        Python is a high-level, interpreted programming language known for its 
        simplicity and readability. Created by Guido van Rossum and first released 
        in 1991, Python emphasizes code readability with significant whitespace. 
        It supports multiple programming paradigms including procedural, object-oriented, 
        and functional programming. Python is widely used in web development, data 
        science, artificial intelligence, scientific computing, and automation.
        """,
        "metadata": {"category": "programming", "language": "python"},
    },
    {
        "text": """
        Machine learning is a method of data analysis that automates analytical model 
        building. It is a branch of artificial intelligence based on the idea that 
        systems can learn from data, identify patterns and make decisions with minimal 
        human intervention. Applications include recommendation systems, image recognition, 
        natural language processing, and predictive analytics. Popular ML frameworks 
        include TensorFlow, PyTorch, and scikit-learn.
        """,
        "metadata": {"category": "ai", "subtopic": "machine-learning"},
    },
    {
        "text": """
        Vector databases are specialized databases designed to store and query high-dimensional 
        vectors efficiently. They use approximate nearest neighbor (ANN) algorithms like 
        HNSW or IVF to enable fast similarity search. Vector databases are crucial for 
        modern AI applications including semantic search, recommendation engines, and 
        retrieval-augmented generation (RAG) systems. Popular options include Pinecone, 
        Weaviate, Milvus, and ChromaDB.
        """,
        "metadata": {"category": "databases", "subtopic": "vector-db"},
    },
    {
        "text": """
        Transformers are a type of neural network architecture introduced in the paper 
        'Attention Is All You Need'. They use self-attention mechanisms to process 
        sequential data in parallel, making them much faster than recurrent neural networks. 
        Transformers have revolutionized NLP and are the foundation of models like BERT, 
        GPT, and T5. They excel at tasks like translation, summarization, and text generation.
        """,
        "metadata": {"category": "ai", "subtopic": "deep-learning"},
    },
]

# Build the pipeline with small chunks (150 words, 30-word overlap) and
# ingest the sample corpus.
rag = RAGPipeline(chunk_size=150, chunk_overlap=30)
rag.add_documents(documents)

# Retrieve the two best matches, restricted to documents tagged category "ai".
result = rag.query("Explain transformers", k=2, filter_metadata={"category": "ai"})

# Send the assembled prompt to the LLM (reads OPENAI_API_KEY from the
# environment because no api_key argument is passed).
answer = rag.generate_answer(result["prompt"])
print("\nü§ñ LLM Answer:\n", answer)

üöÄ Initializing Advanced RAG Pipeline...
üì¶ Loading embedding model: all-MiniLM-L6-v2


Loading weights: 100%|██████████| 103/103 [00:00<00:00, 2053.76it/s, Materializing param=pooler.dense.weight]
BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED: can be ignored when loading from different task/architecture; not ok if you expect identical arch.


‚úÖ Pipeline initialized!

üì• Processing 4 documents...
üìÑ Created 4 chunks
üî¢ Generating embeddings...


Batches: 100%|██████████| 1/1 [00:00<00:00, 56.61it/s]

üíæ Storing in vector database...
‚úÖ Documents added successfully!

üîç Query: Explain transformers
Document 1 | Score: 0.547
Transformers are a type of neural network architecture introduced in the paper 'Attention Is All You Need'. They use self-attention mechanisms to process sequential data in parallel, making them much  ...

Document 2 | Score: 0.104
Machine learning is a method of data analysis that automates analytical model building. It is a branch of artificial intelligence based on the idea that systems can learn from data, identify patterns  ...







ü§ñ LLM Answer:
 Transformers are a type of neural network architecture designed for processing sequential data. Introduced in the paper "Attention Is All You Need," they utilize self-attention mechanisms that allow them to handle data in parallel rather than sequentially, making them much faster than traditional recurrent neural networks. This architecture has significantly advanced the field of natural language processing (NLP) and serves as the foundation for influential models such as BERT, GPT, and T5. Transformers are particularly effective for tasks including translation, summarization, and text generation.
