# In [None]:
"""
Query Engine for 10-K Filings Analysis
Handles query processing and retrieval from vector store
"""

import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import numpy as np
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


@dataclass
class QueryResult:
    """A single retrieval hit produced by QueryEngine."""

    # Text of the retrieved chunk.
    content: str
    # Relevance score; QueryEngine.process_query keeps hits whose score is
    # >= score_threshold, and rank_results_by_relevance overwrites it with a
    # cosine similarity.
    score: float
    # Metadata attached to the hit by the vector store. QueryEngine reads the
    # keys 'company', 'ticker', 'year' and 'chunk_id' when present.
    metadata: Dict[str, Any]
    # Origin of the hit; process_query uses 'unknown' when the store omits it.
    source: str


class QueryEngine:
    """
    Handles query processing and retrieval from vector store.

    The vector store is duck-typed. This class calls
    ``search(embedding, top_k=...)``, whose hits are read with ``.get`` (so
    presumably dicts), and ``get_surrounding_chunks(chunk_id, window_size)``.
    """

    def __init__(self, vector_store, embedding_model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize query engine

        Args:
            vector_store: Vector store instance (FAISS/Chroma)
            embedding_model_name: Name of the SentenceTransformer model used
                to embed queries and (for re-ranking) result contents
        """
        self.vector_store = vector_store
        self.embedding_model = SentenceTransformer(embedding_model_name)

    def process_query(self, query: str, top_k: int = 5,
                      score_threshold: float = 0.5) -> List[QueryResult]:
        """
        Process a single query and return relevant results.

        This is a best-effort call. Any exception raised while embedding,
        searching or formatting is logged with its traceback, and an empty
        list is returned.

        Args:
            query: The search query
            top_k: Number of hits requested from the vector store
            score_threshold: Hits whose 'score' is below this value are
                dropped (a hit without a score counts as 0)

        Returns:
            List of QueryResult objects in the order the store returned them
        """
        try:
            query_embedding = self.embedding_model.encode([query])
            hits = self.vector_store.search(query_embedding[0], top_k=top_k)

            formatted_results = []
            for hit in hits:
                score = hit.get('score', 0)
                if score >= score_threshold:
                    formatted_results.append(QueryResult(
                        content=hit.get('content', ''),
                        score=score,
                        # Normalise a missing/None metadata entry to {} so
                        # downstream filters can call .get() safely.
                        metadata=hit.get('metadata') or {},
                        source=hit.get('source', 'unknown'),
                    ))

            # %.50s truncates the query to 50 characters, like query[:50].
            logger.info("Retrieved %d results for query: %.50s...",
                        len(formatted_results), query)
            return formatted_results

        except Exception as e:
            # Deliberate best-effort boundary: callers get [] instead of an
            # exception, but the traceback is preserved in the log.
            logger.exception("Error processing query: %s", e)
            return []

    def multi_query_search(self, queries: List[str],
                           top_k: int = 5,
                           score_threshold: float = 0.5) -> Dict[str, List[QueryResult]]:
        """
        Process multiple queries and return results for each

        Args:
            queries: List of search queries
            top_k: Number of top results per query
            score_threshold: Minimum score forwarded to process_query

        Returns:
            Dictionary mapping each query to its results. Duplicate queries
            collapse into a single key.
        """
        return {
            query: self.process_query(query, top_k, score_threshold)
            for query in queries
        }

    @staticmethod
    def _metadata_text(metadata: Dict[str, Any], key: str) -> str:
        """Return metadata[key] as a string, or '' when missing or None."""
        value = metadata.get(key)
        return '' if value is None else str(value)

    def semantic_search(self, query: str, company_filter: Optional[str] = None,
                        year_filter: Optional[str] = None,
                        top_k: int = 10,
                        score_threshold: float = 0.5) -> List[QueryResult]:
        """
        Perform semantic search with optional metadata filters.

        Args:
            query: Search query
            company_filter: Case-insensitive substring matched against the
                'company' or 'ticker' metadata values
            year_filter: Substring matched against str(metadata['year'])
            top_k: Maximum number of results to return
            score_threshold: Minimum score forwarded to process_query

        Returns:
            Up to top_k QueryResult objects that pass all given filters
        """
        if top_k <= 0:
            return []

        # Over-fetch so that enough candidates survive the metadata filters.
        candidates = self.process_query(query, top_k * 2, score_threshold)

        # Hoist the loop-invariant normalisation of the filters.
        company_needle = company_filter.lower() if company_filter else None
        year_needle = str(year_filter) if year_filter else None

        filtered_results = []
        for result in candidates:
            metadata = result.metadata or {}

            if company_needle is not None:
                company = self._metadata_text(metadata, 'company').lower()
                ticker = self._metadata_text(metadata, 'ticker').lower()
                if company_needle not in company and company_needle not in ticker:
                    continue

            if year_needle is not None:
                if year_needle not in self._metadata_text(metadata, 'year'):
                    continue

            filtered_results.append(result)

            # Stop when we have enough results
            if len(filtered_results) >= top_k:
                break

        return filtered_results

    def get_context_window(self, query_result: QueryResult,
                           window_size: int = 2) -> str:
        """
        Get expanded context around a query result.

        Args:
            query_result: The original query result
            window_size: Number of chunks before/after to request from the
                vector store

        Returns:
            The chunks returned by ``get_surrounding_chunks`` joined with
            spaces. Falls back to ``query_result.content`` when the result has
            no chunk id, the store returns nothing usable, or the lookup fails.
        """
        try:
            chunk_id = query_result.metadata.get('chunk_id')
            # Compare against None explicitly: 0 is a legitimate chunk id.
            if chunk_id is None or (isinstance(chunk_id, str) and not chunk_id):
                return query_result.content

            context_chunks = self.vector_store.get_surrounding_chunks(
                chunk_id, window_size
            )
            context = " ".join(context_chunks)
            # An empty window would otherwise lose the hit's own text.
            return context if context.strip() else query_result.content

        except Exception as e:
            logger.warning("Could not get context window: %s", e)
            return query_result.content

    def rank_results_by_relevance(self, results: List[QueryResult],
                                  query: str) -> List[QueryResult]:
        """
        Re-rank results by cosine similarity to the original query.

        On success, each input result's ``score`` is overwritten in place with
        its cosine similarity to the query. A new list of the same objects,
        sorted by descending score, is returned. A zero-norm embedding scores
        0.0 instead of NaN. On failure, the inputs are returned unchanged and
        unsorted.

        Args:
            results: List of query results
            query: Original query

        Returns:
            Re-ranked list of results
        """
        if not results:
            return results

        try:
            query_embedding = np.asarray(
                self.embedding_model.encode([query])[0], dtype=float
            )
            # One batched encode call instead of one call per result.
            content_embeddings = np.asarray(
                self.embedding_model.encode([result.content for result in results]),
                dtype=float,
            )
            dots = content_embeddings @ query_embedding
            norms = np.linalg.norm(content_embeddings, axis=1) * np.linalg.norm(query_embedding)
            # Guard 0/0: NaN scores would make the sort order undefined.
            similarities = np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)
        except Exception as e:
            logger.warning("Could not re-rank results: %s", e)
            return results

        # Scores are assigned only after every similarity is computed, so a
        # failure above never leaves the inputs partially re-scored.
        for result, similarity in zip(results, similarities):
            result.score = float(similarity)

        return sorted(results, key=lambda r: r.score, reverse=True)

    def extract_key_phrases(self, text: str, max_phrases: int = 5) -> List[str]:
        """
        Extract the most frequent two-word phrases (bigrams) from text.

        Text is split into sentences on '.', '!' and '?' and lower-cased.
        Adjacent alphabetic words form a bigram. Bigrams containing a stop
        word, or of 4 characters or fewer, are skipped. Ties in frequency
        keep first-occurrence order.

        Args:
            text: Input text
            max_phrases: Maximum number of phrases to extract

        Returns:
            List of key phrases, most frequent first
        """
        import re
        from collections import Counter

        stop_words = frozenset({
            'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of',
            'with', 'by', 'is', 'are', 'was', 'were', 'a', 'an',
        })

        phrase_counts = Counter()
        for sentence in re.split(r'[.!?]+', text):
            words = re.findall(r'\b[a-zA-Z]+\b', sentence.lower())
            # Bigrams only: pair each word with its successor.
            for first, second in zip(words, words[1:]):
                if first in stop_words or second in stop_words:
                    continue
                phrase = f"{first} {second}"
                if len(phrase) > 4:  # Minimum phrase length
                    phrase_counts[phrase] += 1

        # most_common() with no argument is a stable sort by count, so ties
        # keep insertion order; slicing keeps the original semantics for any
        # max_phrases value.
        return [phrase for phrase, _ in phrase_counts.most_common()][:max_phrases]