In [5]:
from typing import Generator

In [6]:
import torch
import faiss
import requests
import json
import pickle
import xmltodict
from bs4 import BeautifulSoup
from pathlib import Path
from typing import List, Tuple, Dict, Optional
from tqdm import tqdm
from fastembed import TextEmbedding, SparseTextEmbedding
from fastembed.rerank.cross_encoder import TextCrossEncoder
import nltk
from nltk.tokenize import word_tokenize
import numpy as np
from nltk.corpus import stopwords


# Download required NLTK data
nltk.download('punkt', quiet=True)

class RetrievalPipeline:
    def __init__(self):

        self.stop_words = set(stopwords.words('english'))
        nltk.download('stopwords', quiet=True)

        self.device = 'mps' if torch.backends.mps.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        
        # Initialize models (removed CoreML provider for stability)
        self.dense_model = TextEmbedding(
            model_name="BAAI/bge-small-en-v1.5"
        )
        
        self.sparse_model = SparseTextEmbedding(
            model_name="prithivida/Splade_PP_en_v1"
        )
        
        self.reranker = TextCrossEncoder(
            model_name="Xenova/ms-marco-MiniLM-L-6-v2"
        )
        
        # Data storage
        self.chunks: List[str] = []
        self.chunk_to_url: Dict[str, str] = {}
        self.sparse_embeddings = []
        self.dense_embeddings: Optional[torch.Tensor] = None
        self.dense_index: Optional[faiss.Index] = None

    def fetch_sitemap(self, sitemap_url: str) -> List[str]:
        """Fetch URLs from sitemap with error handling."""
        try:
            response = requests.get(sitemap_url, timeout=15)
            response.raise_for_status()
            sitemap_dict = xmltodict.parse(response.content)
            return [url['loc'] for url in sitemap_dict['urlset']['url']]
        except Exception as e:
            print(f"Error fetching sitemap: {str(e)}")
            return []

    def fetch_content(self, url: str) -> str:
        """Fetch and clean content from URL with improved error handling."""
        try:
            response = requests.get(url, timeout=15)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')
            
            # Remove unwanted elements
            for element in soup(["script", "style", "nav", "footer", "header", "form"]):
                element.decompose()
            
            # Clean text
            text = ' '.join(
                line.strip() 
                for line in soup.get_text().splitlines() 
                if line.strip() and len(line.strip()) > 25
            )
            return text
        except Exception as e:
            print(f"Error fetching {url}: {str(e)}")
            return ""

    def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 128) -> List[str]:
        """Token-based chunking with overlap."""
        tokens = word_tokenize(text)
        chunks = []
        start = 0
        
        while start < len(tokens):
            end = start + chunk_size
            chunk_tokens = tokens[start:end]
            chunks.append(' '.join(chunk_tokens))
            start = end - overlap if end - overlap > start else start + 1
        
        return chunks

    def normalize_scores(self, scores: List[float]) -> List[float]:
        """Robust score normalization handling multiple input types."""
        if not scores:
            return []
        
        # Convert generators to numpy arrays
        if isinstance(scores, (map, filter, Generator)):
            scores = list(scores)
        
        scores = np.array(scores, dtype=np.float32)
        
        if scores.size == 0:
            return []
        
        if np.ptp(scores) == 0:
            return np.ones_like(scores).tolist()
        
        return ((scores - np.min(scores)) / np.ptp(scores)).tolist()

    def process_documents(self, sitemap_url: str):
        """Process documents with type-safe conversions and device handling"""
        data_dir = Path("data")
        data_dir.mkdir(exist_ok=True)
        
        urls = self.fetch_sitemap(sitemap_url)
        print(f"Found {len(urls)} URLs")
        
        all_chunks = []
        chunk_to_url = {}
        
        # Process URLs and chunk content
        for url in tqdm(urls, desc="Processing URLs"):
            content = self.fetch_content(url)
            if not content:
                continue
                
            chunks = self.chunk_text(content)
            for chunk in chunks:
                chunk_idx = str(len(all_chunks))
                all_chunks.append(chunk)
                chunk_to_url[chunk_idx] = url
        
        # Generate dense embeddings with explicit type handling
        print("Generating dense embeddings...")
        embeddings = list(self.dense_model.embed(all_chunks))
        
        # Create numpy array first with proper dtype
        dense_numpy = np.stack(embeddings).astype(np.float32)
        
        # Convert directly to tensor with device placement
        self.dense_embeddings = torch.tensor(
            dense_numpy,
            device=self.device,
            dtype=torch.float32
        )
        
        # Build FAISS index using CPU numpy array
        print("Building FAISS index...")
        self.dense_index = faiss.IndexFlatIP(dense_numpy.shape[1])
        self.dense_index.add(dense_numpy.astype('float32'))  # Extra type safety
        
        # Generate sparse embeddings
        print("Generating sparse embeddings...")
        self.sparse_embeddings = list(self.sparse_model.embed(all_chunks))
        
        # Save all components
        print("Saving data...")
        with open(data_dir/'chunks.pkl', 'wb') as f:
            pickle.dump(all_chunks, f)
        with open(data_dir/'chunk_to_url.json', 'w') as f:
            json.dump(chunk_to_url, f)
        with open(data_dir/'sparse_embeddings.pkl', 'wb') as f:
            pickle.dump(self.sparse_embeddings, f)
        faiss.write_index(self.dense_index, str(data_dir/'dense_index.faiss'))
        torch.save(self.dense_embeddings, data_dir/'dense_embeddings.pt')
        
        print(f"Processed {len(all_chunks)} chunks successfully")

    def load_data(self):
        """Load processed data with enhanced error handling."""
        data_dir = Path("data")
        required_files = [
            'chunks.pkl', 
            'chunk_to_url.json',
            'sparse_embeddings.pkl',
            'dense_index.faiss',
            'dense_embeddings.pt'
        ]
        
        for file in required_files:
            if not (data_dir/file).exists():
                raise FileNotFoundError(f"Missing required file: {file}")
        
        with open(data_dir/'chunks.pkl', 'rb') as f:
            self.chunks = pickle.load(f)
        with open(data_dir/'chunk_to_url.json', 'r') as f:
            self.chunk_to_url = json.load(f)
        with open(data_dir/'sparse_embeddings.pkl', 'rb') as f:
            self.sparse_embeddings = pickle.load(f)
        self.dense_index = faiss.read_index(str(data_dir/'dense_index.faiss'))
        self.dense_embeddings = torch.load(data_dir/'dense_embeddings.pt', map_location=self.device)

    def hybrid_search(self, query: str, k: int = 50) -> List[Tuple[int, str, float, float, float]]:
        """Optimized hybrid search with combined scoring."""
        # Dense search
        query_dense = np.array(list(self.dense_model.embed([query]))[0])
        dense_scores, dense_indices = self.dense_index.search(query_dense.reshape(1, -1), k)
        dense_scores = dense_scores[0].tolist()
        top_indices = dense_indices[0].tolist()
        
        # Sparse scoring only on top results
        query_sparse = list(self.sparse_model.embed([query]))[0]
        sparse_scores = []
        query_indices = set(query_sparse.indices)
        
        for idx in top_indices:
            doc_emb = self.sparse_embeddings[idx]
            score = sum(
                query_sparse.values[i] * doc_emb.values[doc_emb.indices == idx][0]
                for i, idx in enumerate(query_sparse.indices)
                if idx in doc_emb.indices
            )
            sparse_scores.append(score)
        
        # Normalize and combine scores
        norm_dense = self.normalize_scores(dense_scores)
        norm_sparse = self.normalize_scores(sparse_scores)
        
        combined = []
        for i, idx in enumerate(top_indices):
            combined_score = 0.7 * norm_dense[i] + 0.3 * norm_sparse[i]
            combined.append((idx, combined_score, norm_dense[i], norm_sparse[i]))
        
        # Sort by combined score
        combined.sort(key=lambda x: x[1], reverse=True)
        return [(idx, self.chunks[idx], *scores) for idx, *scores in combined[:k]]

    def check_relevance(self, query: str, results: List[Tuple[int, str, float, float, float]], 
                       threshold: float = 0.6) -> Tuple[bool, float]:
        """Enhanced relevance check with term normalization"""
        if not results:
            return False, 0.0
            
        # Get top result
        top_idx = results[0][0]
        top_text = results[0][1]
        
        # Semantic similarity
        query_embedding = torch.tensor(list(self.dense_model.embed([query]))[0], device=self.device)
        doc_embedding = self.dense_embeddings[top_idx]
        similarity = torch.nn.functional.cosine_similarity(
            query_embedding, doc_embedding, dim=0
        ).item()
        
        # Term normalization
        def normalize_text(text: str) -> set:
            # Remove punctuation and stopwords
            text = ''.join([c.lower() if c.isalnum() else ' ' for c in text])
            return set(word for word in text.split() if word not in self.stop_words)
        
        query_terms = normalize_text(query)
        doc_terms = normalize_text(top_text)
        
        # Flexible term matching
        term_overlap = len(query_terms & doc_terms) / len(query_terms) if query_terms else 0
        
        # Debug prints
        print(f"\nRelevance Check:")
        print(f"Top Document: {top_text[:200]}...")
        print(f"Similarity: {similarity:.2f}, Term Overlap: {term_overlap:.2f}")
        
        # Adjusted thresholds
        return (
            similarity >= threshold and term_overlap >= 0.25,
            similarity
        )

    def search(self, query: str) -> Dict:
        """Fixed search with proper score handling"""
        results = self.hybrid_search(query)
        
        # Early exit if no relevant results
        is_relevant, rel_score = self.check_relevance(query, results)
        if not is_relevant:
            return {
                "status": "no_results",
                "message": "No relevant documents found",
                "confidence": f"{rel_score:.2f}"
            }
        
        # Rerank top candidates with type conversion
        candidates = [res[1] for res in results[:20]]
        rerank_scores = list(self.reranker.rerank(query, candidates))  # Convert generator to list
        rerank_scores = self.normalize_scores(rerank_scores)
        
        # Combine scores
        final_results = []
        for (idx, text, comb, dense, sparse), rerank in zip(results[:10], rerank_scores[:10]):
            final_score = 0.5 * comb + 0.5 * rerank
            final_results.append({
                "index": idx,
                "text": text[:350] + "..." if len(text) > 350 else text,
                "url": self.chunk_to_url.get(str(idx), ""),
                "scores": {
                    "final": round(final_score, 3),
                    "dense": round(dense, 3),
                    "sparse": round(sparse, 3),
                    "rerank": round(rerank, 3)
                }
            })
        
        return {
            "status": "success",
            "confidence": f"{rel_score:.2f}",
            "results": sorted(final_results, key=lambda x: x['scores']['final'], reverse=True)
        }

    def print_results(self, search_results: Dict):
        """Enhanced result formatting."""
        if search_results["status"] != "success":
            print(f"\n❌ {search_results['message']} (Confidence: {search_results['confidence']})")
            return
            
        print(f"\n🔍 Found {len(search_results['results'])} results (Confidence: {search_results['confidence']})")
        for result in search_results["results"]:
            print(f"\n📄 {result['url']}")
            print(f"📝 {result['text']}")
            print("📊 Scores:")
            print(f"  Final: {result['scores']['final']}")
            print(f"  Dense: {result['scores']['dense']}")
            print(f"  Sparse: {result['scores']['sparse']}")
            print(f"  Rerank: {result['scores']['rerank']}")
            print("-" * 80)



In [7]:
pipeline = RetrievalPipeline()

Using device: mps


In [11]:
pipeline.process_documents("https://nextjs.org/sitemap.xml")

Found 570 URLs


Processing URLs: 100%|██████████| 570/570 [02:15<00:00,  4.20it/s]


Generating dense embeddings...
Building FAISS index...
Generating sparse embeddings...
Saving data...
Processed 1574 chunks successfully


In [8]:
pipeline.load_data()

In [9]:
results = pipeline.search("SSR")


Relevance Check:
Top Document: Rendering : Server-side Rendering ( SSR ) | Next.js MenuUsing App RouterFeatures available in /appUsing Latest Version15.1.6Building Your ApplicationRenderingServer-side Rendering ( SSR ) Server-side ...
Similarity: 0.70, Term Overlap: 1.00


In [10]:
pipeline.print_results(results)



🔍 Found 10 results (Confidence: 0.70)

📄 https://nextjs.org/docs/pages/building-your-application/rendering/server-side-rendering
📝 Rendering : Server-side Rendering ( SSR ) | Next.js MenuUsing App RouterFeatures available in /appUsing Latest Version15.1.6Building Your ApplicationRenderingServer-side Rendering ( SSR ) Server-side Rendering ( SSR ) Also referred to as `` SSR '' or `` Dynamic Rendering '' . If a page uses Server-side Rendering , the page HTML is generated on each...
📊 Scores:
  Final: 1.0
  Dense: 1.0
  Sparse: 1.0
  Rerank: 1.0
--------------------------------------------------------------------------------

📄 https://nextjs.org/learn/seo/rendering-strategies
📝 SEO : Rendering Strategies | Next.jsSign inSign in to save progress11Chapter 11Rendering StrategiesStatic Site Generation ( SSG ) Static site generation is where your HTML is generated at build time . This HTML is then used for each request . Static site generation is probably the best type of rendering strategy 

In [104]:
!conda list -e > requirements.txt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [103]:
!pip freeze > requirements.txt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
