# In [None]:
import os
import pickle
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv
from torch_geometric.data import Data
from pypdf import PdfReader
import spacy
from spacy.tokens import Doc
from spacy.language import Language
from spacy.matcher import Matcher
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
import re
from typing import List, Dict, Tuple, Set
from openai import OpenAI
from collections import defaultdict
import logging
import faiss
import threading
import concurrent.futures
from itertools import chain
import time
import traceback

# Send every record (DEBUG and above) both to rag_system.log and to the console.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_log_handlers = [logging.FileHandler('rag_system.log'), logging.StreamHandler()]
logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT, handlers=_log_handlers)

# Shared application logger; classes below look it up by the same name.
logger = logging.getLogger("ASUSimRAG")

# Run PyTorch work on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


class PDFTextExtractor:
    """Extract plain, whitespace-normalized ASCII text from a PDF via pypdf."""

    def __init__(self, pdf_path: str):
        self.pdf_path = pdf_path

    def extract_text(self) -> str:
        """Extract and clean text from a PDF file.

        Pages whose ``extract_text()`` yields nothing are skipped. The remaining
        page texts are joined with single spaces and passed through
        :meth:`_clean_text`.

        Returns:
            The cleaned document text.
        """
        reader = PdfReader(self.pdf_path)
        page_texts = []
        for page_num, page in enumerate(reader.pages, start=1):
            page_text = page.extract_text()
            if page_text:
                page_texts.append(page_text)
            logger.debug(f"Extracted text from page {page_num}.")
        # Build once with join instead of repeated string concatenation.
        text = self._clean_text(" ".join(page_texts))
        logger.debug("Extracted and cleaned text from PDF.")
        return text

    @staticmethod
    def _clean_text(text: str) -> str:
        """Reduce ``text`` to ASCII with single spaces and no leading/trailing space.

        Steps:
        1. NFKD-normalize, so PDF ligatures such as ``"\\ufb01"`` become ``"fi"``
           and accented letters split into base letter + combining mark.
        2. Drop combining marks, so ``"résumé"`` becomes ``"resume"`` instead of
           being broken into separate words.
        3. Replace any remaining non-ASCII run with a space.
        4. Collapse whitespace last, so spaces introduced by step 3 are also
           squeezed to one.
        """
        import unicodedata

        decomposed = unicodedata.normalize('NFKD', text)
        text = ''.join(ch for ch in decomposed if not unicodedata.combining(ch))
        text = re.sub(r'[^\x00-\x7F]+', ' ', text)  # Remove non-ASCII characters
        return re.sub(r'\s+', ' ', text).strip()


class SentenceProcessor:
    """Split raw text into sentences and run a spaCy pipeline over each sentence.

    Attributes:
        text: Raw input text.
        nlp: spaCy pipeline used both for splitting and for per-sentence processing.
        matcher: ``Matcher`` over ``nlp.vocab``. It is created here but not used by
            the methods of this class.
        sentences: Filled by :meth:`split_into_sentences`.
        processed_docs: Filled by :meth:`preprocess_sentences`, one ``Doc`` per sentence.
    """

    def __init__(self, text: str, nlp: Language):
        self.text = text
        self.nlp = nlp
        self.matcher = Matcher(self.nlp.vocab)
        self.sentences = []
        self.processed_docs = []

    def split_into_sentences(self) -> List[str]:
        """Split ``self.text`` into sentences.

        A new sentence starts at each of spaCy's own sentence starts, and also at
        the token after every semicolon (unless the semicolon is the final token).
        Empty spans are dropped.

        Returns:
            The sentence strings, also stored in ``self.sentences``.
        """
        doc = self.nlp(self.text)

        # Token indices at which a sentence begins. Collecting indices avoids
        # mutating the Doc's own sentence boundaries.
        starts = {sent.start for sent in doc.sents}
        starts.update(
            token.i + 1
            for token in doc
            if token.text == ';' and token.i + 1 < len(doc)
        )

        # Closing sentinel so the last sentence runs to the end of the doc.
        bounds = sorted(starts) + [len(doc)]
        spans = (doc[start:end].text.strip() for start, end in zip(bounds, bounds[1:]))
        self.sentences = [sentence for sentence in spans if sentence]
        logger.debug(f"Split text into {len(self.sentences)} sentences.")
        return self.sentences

    def preprocess_sentences(self) -> List[Doc]:
        """Run the spaCy pipeline (lemmatization, POS tagging, NER, ...) on every sentence.

        This uses ``nlp.pipe``, spaCy's batched streaming API, instead of calling
        one shared ``nlp`` object from a thread pool. A ``Language`` object is not
        documented as thread-safe, and threads offer little benefit for this
        CPU-bound work.

        Returns:
            One ``Doc`` per entry of ``self.sentences``, in the same order. The
            list is also stored in ``self.processed_docs``.
        """
        self.processed_docs = list(self.nlp.pipe(self.sentences))
        logger.debug("Preprocessed sentences with spaCy.")
        return self.processed_docs


class EnhancedKnowledgeGraphBuilder:
    """Build a directed multigraph of named entities and subject-verb-object relations.

    Nodes are strings with two attributes:

    - ``sentences``: the sentences in which the node appeared as a named entity.
    - ``label``: the spaCy entity label, ``'Ontology'``, or ``None``.

    Text-extracted edges are keyed by the verb lemma and carry ``label='relation'``.
    Ontology edges are keyed by, and labelled with, the ontology relation name.
    """

    def __init__(self, processed_docs: List[Doc], sentences: List[str], ontology_path: str = None):
        """Store the inputs, load the optional ontology and build the graph immediately.

        Args:
            processed_docs: spaCy docs indexed in step with ``sentences``;
                ``sentences[i]`` is recorded on entities found in ``processed_docs[i]``.
            sentences: Source sentence strings.
            ontology_path: Optional path to a pickled ontology (see :meth:`load_ontology`).
        """
        self.processed_docs = processed_docs
        self.sentences = sentences
        self.ontology_path = ontology_path
        self.knowledge_graph = nx.MultiDiGraph()
        self.entity_relations = defaultdict(list)
        self.ontology = self.load_ontology()
        self.build_knowledge_graph()

    def load_ontology(self):
        """Load the pickled ontology at ``ontology_path``, or return ``None`` if absent.

        The result is used as a mapping ``entity -> iterable of (related_entity, relation)``
        (see :meth:`enrich_with_ontology`).

        NOTE(security): ``pickle.load`` can execute arbitrary code. Only point
        ``ontology_path`` at trusted files.
        """
        if self.ontology_path and os.path.exists(self.ontology_path):
            with open(self.ontology_path, 'rb') as f:
                ontology = pickle.load(f)
            logger.debug("Loaded ontology for knowledge graph enrichment.")
            return ontology
        logger.debug("No ontology path provided or file does not exist.")
        return None

    def extract_relations(self, doc: Doc) -> List[Tuple[str, str, str]]:
        """Extract ``(subject, verb_lemma, object)`` triples from dependency parses.

        Every token with dependency ``nsubj`` or ``nsubjpass`` is paired with each
        ``dobj`` child of its syntactic head. The head's lemma is used as the verb.
        """
        relations = []
        for sent in doc.sents:
            for token in sent:
                if token.dep_ not in ('nsubj', 'nsubjpass'):
                    continue
                subj = token.text
                verb = token.head.lemma_
                for child in token.head.children:
                    if child.dep_ == 'dobj':
                        obj = child.text
                        relations.append((subj, verb, obj))
                        logger.debug(f"Extracted relation: {subj} -[{verb}]-> {obj}")
        return relations

    def enrich_with_ontology(self, entity: str, label: str):
        """Add ontology neighbours of ``entity`` as ``'Ontology'`` nodes with relation-keyed edges.

        ``label`` is accepted for interface compatibility and is not used here.
        Re-adding the same ``(entity, related_entity, relation)`` updates the
        existing multigraph edge rather than duplicating it, because the edge key
        is the relation.
        """
        if self.ontology and entity in self.ontology:
            for related_entity, relation in self.ontology[entity]:
                self.add_node_if_not_exists(related_entity, label='Ontology')
                self.knowledge_graph.add_edge(entity, related_entity, key=relation, label=relation)
                logger.debug(f"Enriched graph with ontology relation: {entity} -[{relation}]-> {related_entity}")

    def add_node_if_not_exists(self, node_text: str, label: str = None, sentence: str = None):
        """Ensure ``node_text`` is a node, optionally recording a sentence and a label.

        A new node starts with ``sentences=[]`` and the given ``label``. For any
        node, ``sentence`` (if given) is appended. ``label`` (if given) is applied
        only when the node has no label yet.

        Every node is created with a ``label`` key, possibly ``None``, so "no label
        yet" is tested as ``label is None`` rather than as key absence. A key-absence
        check would never let a node first seen as an unlabeled relation argument
        pick up its entity label later.
        """
        graph = self.knowledge_graph
        if not graph.has_node(node_text):
            graph.add_node(node_text, sentences=[], label=label)
            logger.debug(f"Added node: {node_text} with label: {label}")
        attrs = graph.nodes[node_text]
        if sentence:
            attrs.setdefault('sentences', []).append(sentence)
        if label and attrs.get('label') is None:
            attrs['label'] = label

    def build_knowledge_graph(self):
        """Populate the graph from entities, ontology links and extracted relations."""
        for idx, doc in enumerate(self.processed_docs):
            for ent in doc.ents:
                self.add_node_if_not_exists(ent.text, label=ent.label_, sentence=self.sentences[idx])
                self.enrich_with_ontology(ent.text, ent.label_)

            for subj, rel, obj in self.extract_relations(doc):
                self.add_node_if_not_exists(subj)
                self.add_node_if_not_exists(obj)
                self.knowledge_graph.add_edge(subj, obj, key=rel, label='relation')
                self.entity_relations[subj].append((obj, rel))
                logger.debug(f"Added edge: {subj} -[{rel}]-> {obj}")

        logger.info(f"Constructed knowledge graph with {self.knowledge_graph.number_of_nodes()} nodes and {self.knowledge_graph.number_of_edges()} edges.")

    def get_knowledge_graph(self) -> nx.MultiDiGraph:
        """Return the constructed knowledge graph."""
        return self.knowledge_graph

    def convert_to_pyg_data(self, embedding_model: SentenceTransformer = None):
        """Convert the NetworkX graph to a PyTorch Geometric ``Data`` object.

        Args:
            embedding_model: Optional already-loaded SentenceTransformer used to
                embed node names. If omitted, a new ``'all-MiniLM-L6-v2'`` model is loaded.

        Returns:
            ``(data, idx_to_node, node_to_idx)``. ``data.x`` holds the node-name
            embeddings, ``data.edge_index`` has shape ``(2, num_edges)`` and
            ``data.edge_type`` has shape ``(num_edges,)``. All three are on ``device``.
        """
        node_to_idx = {node: idx for idx, node in enumerate(self.knowledge_graph.nodes())}
        idx_to_node = {idx: node for node, idx in node_to_idx.items()}

        edge_pairs = []
        edge_types = []
        for u, v, attrs in self.knowledge_graph.edges(data=True):
            edge_pairs.append([node_to_idx[u], node_to_idx[v]])
            # The multigraph edge *key* is not stored in the attribute dict, so the
            # edge type comes from the stored 'label' attribute: 'relation' for
            # text-extracted edges, the relation name for ontology edges. This keeps
            # the number of RGCN relation types small.
            edge_types.append(self.get_relation_id(attrs.get('label', 'relation')))

        if edge_pairs:
            edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
        else:
            # PyG expects shape (2, 0) for a graph without edges, not (0,).
            edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_type = torch.tensor(edge_types, dtype=torch.long)

        if embedding_model is None:
            embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        node_embeddings = embedding_model.encode(list(node_to_idx.keys()), convert_to_tensor=True)

        # All tensors must share a device, or RGCNConv fails on CUDA.
        data = Data(
            x=node_embeddings.to(device),
            edge_index=edge_index.to(device),
            edge_type=edge_type.to(device),
        )
        logger.debug("Converted NetworkX graph to PyTorch Geometric Data object.")
        return data, idx_to_node, node_to_idx

    def get_relation_id(self, relation: str) -> int:
        """Return a stable integer ID for ``relation``, assigning the next ID on first sight."""
        if not hasattr(self, 'relation_to_id'):
            self.relation_to_id = {}
            self.next_relation_id = 0
        if relation not in self.relation_to_id:
            self.relation_to_id[relation] = self.next_relation_id
            self.next_relation_id += 1
        return self.relation_to_id[relation]


class AdvancedGNNReasoner(torch.nn.Module):
    """Three stacked relational GCN layers producing per-node embeddings.

    Each layer is followed by ReLU. Dropout (p=0.3) is applied after the first
    two layers only, never after the output layer.
    """

    def __init__(self, num_node_features: int, num_relations: int, hidden_channels: int = 512, out_channels: int = 256):
        super().__init__()
        # The attribute names conv1/conv2/conv3/dropout become state_dict keys;
        # keep them fixed so previously saved weights still load.
        self.conv1 = RGCNConv(num_node_features, hidden_channels, num_relations)
        self.conv2 = RGCNConv(hidden_channels, hidden_channels, num_relations)
        self.conv3 = RGCNConv(hidden_channels, out_channels, num_relations)
        self.dropout = torch.nn.Dropout(p=0.3)

    def forward(self, data):
        """Run ``data.x`` through the three RGCN layers using ``data.edge_index`` / ``data.edge_type``."""
        hidden = data.x
        edge_index = data.edge_index
        edge_type = data.edge_type

        layers = (self.conv1, self.conv2, self.conv3)
        last = len(layers) - 1
        for depth, conv in enumerate(layers):
            hidden = F.relu(conv(hidden, edge_index, edge_type))
            if depth != last:
                hidden = self.dropout(hidden)
        return hidden


class EmbeddingManager:
    """Encode sentences with SentenceTransformer and index them in FAISS.

    Attributes:
        sentences: Sentences to embed. Row ``i`` of ``sentence_embeddings`` belongs
            to ``sentences[i]``.
        text_model: The ``'all-mpnet-base-v2'`` SentenceTransformer.
        sentence_embeddings: float32 array with one (un-normalized) row per sentence.
        index: ``faiss.IndexFlatL2`` over L2-normalized copies of the embeddings.
        embeddings_path: Optional pickle cache for ``sentence_embeddings``.
    """

    def __init__(self, sentences: List[str], embeddings_path: str = None):
        self.sentences = sentences
        self.text_model = SentenceTransformer('all-mpnet-base-v2')
        self.sentence_embeddings = None
        self.index = None
        self.embeddings_path = embeddings_path
        self.compute_embeddings()
        self.build_faiss_index()

    def compute_embeddings(self, batch_size: int = 32):
        """Compute or load embeddings for all sentences.

        A cache at ``embeddings_path`` is reused only if it holds a 2-D array with
        exactly one row per sentence. Otherwise the embeddings are recomputed and
        the cache is overwritten.

        Args:
            batch_size: Number of sentences encoded per ``encode`` call.

        Returns:
            The float32 embedding matrix, also stored in ``self.sentence_embeddings``.
        """
        if self.embeddings_path and os.path.exists(self.embeddings_path):
            self.sentence_embeddings = self._load_cached_embeddings()

        if self.sentence_embeddings is None:
            self.sentence_embeddings = self._encode_sentences(batch_size)
            if self.embeddings_path:
                self._save_embeddings()

        return self.sentence_embeddings

    def _load_cached_embeddings(self):
        """Return the cached matrix if it matches ``self.sentences``, else ``None``."""
        try:
            with open(self.embeddings_path, 'rb') as f:
                # NOTE(security): pickle.load can execute arbitrary code; only use trusted cache files.
                cached = np.asarray(pickle.load(f), dtype=np.float32)
        except Exception as e:
            logger.error(f"Error loading embeddings: {e}")
            return None
        if cached.ndim != 2 or cached.shape[0] != len(self.sentences):
            # A cache written for a different document would map search hits to the wrong sentences.
            logger.warning(
                f"Cached embeddings have shape {cached.shape} but there are "
                f"{len(self.sentences)} sentences; recomputing."
            )
            return None
        logger.info("Loaded sentence embeddings from disk.")
        return cached

    def _encode_sentences(self, batch_size: int):
        """Encode ``self.sentences`` in batches into a float32 array of shape ``(n, dim)``."""
        chunks = [
            self.text_model.encode(self.sentences[start:start + batch_size], convert_to_tensor=True).cpu().numpy()
            for start in range(0, len(self.sentences), batch_size)
        ]
        if not chunks:
            # np.vstack([]) raises; return an empty matrix with the model's width instead.
            dim = self.text_model.get_sentence_embedding_dimension() or 0
            return np.zeros((0, dim), dtype=np.float32)
        return np.vstack(chunks).astype(np.float32, copy=False)

    def _save_embeddings(self):
        """Best-effort write of ``self.sentence_embeddings`` to ``embeddings_path``."""
        try:
            with open(self.embeddings_path, 'wb') as f:
                pickle.dump(self.sentence_embeddings, f)
            logger.info("Saved sentence embeddings to disk.")
        except Exception as e:
            logger.error(f"Error saving embeddings: {e}")

    def build_faiss_index(self):
        """Build an exact ``IndexFlatL2`` over L2-normalized copies of the embeddings.

        On unit vectors, L2 distance is monotonic in cosine similarity, so
        nearest-by-L2 equals most-similar-by-cosine. The stored
        ``sentence_embeddings`` are left un-normalized.

        Raises:
            Whatever FAISS or NumPy raised, after logging it.
        """
        try:
            # FAISS needs a C-contiguous float32 array; np.array always copies here,
            # so normalize_L2's in-place update does not touch sentence_embeddings.
            embeddings = np.array(self.sentence_embeddings, dtype=np.float32, order='C')
            dimension = embeddings.shape[1]

            self.index = faiss.IndexFlatL2(dimension)
            faiss.normalize_L2(embeddings)
            self.index.add(embeddings)

            logger.debug(f"Built FAISS index with {self.index.ntotal} vectors of dimension {dimension}")

        except Exception as e:
            logger.error(f"Error building FAISS index: {e}")
            raise


class SemanticSearcher:
    """Nearest-neighbour sentence search over an ``EmbeddingManager``'s FAISS index."""

    def __init__(self, embedding_manager: EmbeddingManager):
        self.embedding_manager = embedding_manager
        # Captured once; rebuilding the manager's index later is not picked up here.
        self.index = self.embedding_manager.index
        self.logger = logging.getLogger("ASUSimRAG")

    def semantic_search(self, query: str, top_k: int = 5) -> List[Tuple[int, float]]:
        """Return up to ``top_k`` ``(sentence_index, cosine_similarity)`` pairs, best first.

        The index holds L2-normalized vectors, and the query is encoded with
        ``normalize_embeddings=True``. ``IndexFlatL2`` reports *squared* L2
        distances, and for unit vectors ``d = 2 - 2*cos``, so the cosine is
        recovered exactly as ``1 - d / 2``. This score is absolute: it is
        comparable across queries, and a lone result is not forced to ~0 as a
        per-query rescale by the largest distance would do.

        Returns:
            ``[]`` when the index is missing or empty, or when ``top_k <= 0``.
            On any other error, the fallback ``[(0, 1.0)]``.
        """
        try:
            self.logger.debug(f"Starting semantic search for query: {query}")

            if self.index is None or self.index.ntotal == 0 or top_k <= 0:
                self.logger.warning("Semantic search skipped: empty index or non-positive top_k.")
                return []

            query_embedding = self.embedding_manager.text_model.encode(
                [query],
                convert_to_tensor=True,
                normalize_embeddings=True
            )
            # FAISS requires float32 input.
            query_embedding = query_embedding.cpu().numpy().astype(np.float32, copy=False)

            actual_k = min(top_k, self.index.ntotal)
            distances, indices = self.index.search(query_embedding, actual_k)

            similarities = 1.0 - distances / 2.0

            results = []
            for idx, sim in zip(indices[0], similarities[0]):
                # FAISS pads missing results with -1.
                if 0 <= idx < len(self.embedding_manager.sentences):
                    results.append((int(idx), float(sim)))
                    self.logger.debug(f"Match {idx}: {self.embedding_manager.sentences[idx][:100]}... (score: {sim:.3f})")

            return results

        except Exception as e:
            self.logger.error(f"Error in semantic search: {str(e)}")
            self.logger.debug(traceback.format_exc())
            return [(0, 1.0)]  # Return first sentence as fallback


class ContextBuilder:
    """Assemble retrieval context from sentence neighbourhoods and knowledge-graph nodes.

    Three artefacts are cached as pickles in the current working directory:
    ``sentence_graph.pkl``, ``sentence_embeddings_for_graph.pkl`` and
    ``entity_embeddings.pkl``. Entity embeddings are stored *before* projection.
    Each cache is reused only while it still matches the current inputs.

    NOTE(review): ``projection`` is a randomly initialized ``Linear`` layer that
    is never trained or saved here. Projected similarities are therefore only
    comparable within one instance and vary between runs; consider persisting or
    seeding it.

    NOTE(security): the caches and ``gnn_model.pth`` are read with pickle /
    ``torch.load``. Only run against trusted files.
    """

    def __init__(
        self, 
        sentences: List[str], 
        knowledge_graph: nx.MultiDiGraph, 
        encoder: SentenceTransformer, 
        processed_docs: List[Doc],
        gnn_data: Data,
        idx_to_node: Dict[int, str],
        node_to_idx: Dict[str, int]
    ):
        self.sentences = sentences
        self.knowledge_graph = knowledge_graph
        self.encoder = encoder
        self.processed_docs = processed_docs
        self.gnn_data = gnn_data
        self.idx_to_node = idx_to_node
        self.node_to_idx = node_to_idx
        # Not used by the methods of this class; kept for external callers.
        self.nlp = spacy.load('en_core_web_sm')
        self.sentence_indices = {sent: idx for idx, sent in enumerate(self.sentences)}
        # Built once, on first use, by _get_gnn_model.
        self._gnn_model = None

        try:
            self.sentence_graph = self.build_sentence_graph()
            logger.debug("Built sentence graph successfully")
        except Exception as e:
            logger.error(f"Error building sentence graph: {str(e)}")
            self.sentence_graph = nx.Graph()

        # Project encoder vectors to 256 dims, the GNN's default output width.
        try:
            self.projection = torch.nn.Linear(self._encoder_dimension(), 256).to(device)
            logger.debug("Initialized projection layer")
        except Exception as e:
            logger.error(f"Error initializing projection layer: {str(e)}")
            # Simple fallback: truncate to the first 256 dimensions.
            self.projection = lambda x: x[:, :256] if x.shape[1] > 256 else x

        try:
            self.entity_nodes, self.entity_embeddings = self.load_or_compute_entity_embeddings()
            logger.debug(f"Loaded entity embeddings: {self.entity_embeddings.shape}")
        except Exception as e:
            logger.error(f"Error in entity embeddings initialization: {str(e)}")
            # Empty embeddings that won't break downstream code.
            self.entity_nodes = []
            self.entity_embeddings = np.zeros((0, 256))

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_numpy(values) -> np.ndarray:
        """Return ``values`` as a NumPy array; torch tensors are detached and moved to CPU."""
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def _encoder_dimension(self, default: int = 768) -> int:
        """Width of the encoder's output vectors, or ``default`` if the encoder cannot report it."""
        get_dimension = getattr(self.encoder, 'get_sentence_embedding_dimension', None)
        dimension = get_dimension() if callable(get_dimension) else None
        return int(dimension) if dimension else default

    # ----------------------------------------------------------- sentence graph

    def build_sentence_graph(self) -> nx.Graph:
        """Return a graph with one node per sentence index, linked by similarity and adjacency.

        The graph is loaded from ``sentence_graph.pkl`` when that file's node set
        is exactly ``range(len(self.sentences))``. Otherwise it is rebuilt and
        saved. On any error, a path graph linking consecutive sentences
        (weight 0.5) is returned instead.
        """
        graph_path = 'sentence_graph.pkl'
        try:
            sentence_graph = self._load_cached_sentence_graph(graph_path)
            if sentence_graph is None:
                sentence_graph = self._compute_sentence_graph()
                with open(graph_path, 'wb') as f:
                    pickle.dump(sentence_graph, f)
                logger.info("Saved sentence graph to disk.")

            logger.debug("Built sentence graph based on similarity.")
            return sentence_graph

        except Exception as e:
            logger.error(f"Error building sentence graph: {str(e)}")
            logger.debug(traceback.format_exc())

            fallback_graph = nx.Graph()
            for i in range(len(self.sentences)):
                fallback_graph.add_node(i)
                if i > 0:
                    fallback_graph.add_edge(i-1, i, weight=0.5)
            return fallback_graph

    def _load_cached_sentence_graph(self, graph_path: str):
        """Return the cached graph if it covers exactly the current sentence indices, else ``None``."""
        if not os.path.exists(graph_path):
            return None
        with open(graph_path, 'rb') as f:
            sentence_graph = pickle.load(f)
        if set(sentence_graph.nodes) != set(range(len(self.sentences))):
            logger.warning("Cached sentence graph does not match the current sentences; rebuilding.")
            return None
        logger.info("Loaded sentence graph from disk.")
        return sentence_graph

    def _compute_sentence_graph(self, threshold: float = 0.60) -> nx.Graph:
        """Link sentence pairs whose cosine similarity exceeds ``threshold``.

        Adjacent sentences are always linked, with weight 0.5 unless their
        similarity already exceeds ``threshold``, so context windows can always
        reach neighbouring sentences. Edges are inserted row by row, in the same
        order as a pairwise ``(i, j > i)`` scan. That fixes the neighbour
        iteration order seen by :meth:`get_sentence_window`.
        """
        n = len(self.sentences)
        sentence_graph = nx.Graph()
        sentence_graph.add_nodes_from(range(n))
        if n == 0:
            return sentence_graph

        embeddings = self._to_numpy(self.get_sentence_embeddings())
        similarity_matrix = cosine_similarity(embeddings)

        for i in range(n):
            if i + 1 < n:
                adjacent_sim = similarity_matrix[i][i + 1]
                sentence_graph.add_edge(i, i + 1, weight=adjacent_sim if adjacent_sim > threshold else 0.5)
            # Vectorized scan of the rest of row i (columns i + 2 .. n - 1).
            far = (np.flatnonzero(similarity_matrix[i, i + 2:] > threshold) + i + 2).tolist()
            for j in far:
                sentence_graph.add_edge(i, j, weight=similarity_matrix[i][j])

        # Defensive: adjacency edges already connect the graph, but keep the guard.
        if not nx.is_connected(sentence_graph):
            logger.warning("Sentence graph is not fully connected. Adding minimal connecting edges.")
            components = list(nx.connected_components(sentence_graph))
            for i in range(len(components) - 1):
                comp1 = list(components[i])[0]
                comp2 = list(components[i + 1])[0]
                sentence_graph.add_edge(comp1, comp2, weight=0.1)

        return sentence_graph

    def get_sentence_embeddings(self, sentences=None):
        """Encode sentences for graph construction and return a CPU ``torch.Tensor``.

        Args:
            sentences: Sentences to encode. When ``None``, ``self.sentences`` is
                encoded via the ``sentence_embeddings_for_graph.pkl`` cache. The
                cache is reused only if it has one row per sentence; otherwise it
                is recomputed and overwritten.

        Returns:
            A CPU tensor with one row per sentence. A cache hit is converted to a
            tensor too, so both paths return the same type.
        """
        if sentences is not None:
            return self._encode_on_cpu(sentences)

        embeddings_path = 'sentence_embeddings_for_graph.pkl'
        if os.path.exists(embeddings_path):
            with open(embeddings_path, 'rb') as f:
                cached = np.asarray(pickle.load(f))
            if cached.ndim == 2 and cached.shape[0] == len(self.sentences):
                logger.info("Loaded sentence embeddings for graph from disk.")
                return torch.from_numpy(cached)
            logger.warning("Cached graph embeddings do not match the current sentences; recomputing.")

        embeddings = self._encode_on_cpu(self.sentences)
        with open(embeddings_path, 'wb') as f:
            pickle.dump(self._to_numpy(embeddings), f)
        logger.info("Saved sentence embeddings for graph to disk.")
        return embeddings

    def _encode_on_cpu(self, sentences):
        """Encode ``sentences`` with ``self.encoder`` and return the tensor on CPU."""
        return self.encoder.encode(sentences, convert_to_tensor=True).cpu()

    # -------------------------------------------------------- entity embeddings

    def load_or_compute_entity_embeddings(self):
        """Return ``(entity_nodes, projected_embeddings)`` for every knowledge-graph node.

        Each node is represented by its first non-empty string sentence, or by its
        own name when it has none. Raw (unprojected) embeddings are cached in
        ``entity_embeddings.pkl``. The cache is reused only if its node list equals
        the current graph's node list. The result is passed through
        ``self.projection`` and returned as a NumPy array.

        Returns ``([], np.zeros((0, 256)))`` on error or for an empty graph.
        """
        try:
            entity_embeddings_path = 'entity_embeddings.pkl'
            current_nodes = list(self.knowledge_graph.nodes)
            cached = self._load_cached_entity_embeddings(entity_embeddings_path, current_nodes)

            if cached is not None:
                entity_nodes, raw_embeddings = cached
            elif not current_nodes:
                logger.warning("Knowledge graph has no nodes; no entity embeddings to compute.")
                return [], np.zeros((0, 256))
            else:
                entity_nodes = current_nodes
                entity_sentences = [self._representative_sentence(node) for node in entity_nodes]

                logger.debug(f"Processing {len(entity_nodes)} entities")
                logger.debug(f"First few entity sentences: {entity_sentences[:3]}")

                raw_embeddings = self._to_numpy(self.get_sentence_embeddings(entity_sentences))

                with open(entity_embeddings_path, 'wb') as f:
                    pickle.dump((entity_nodes, raw_embeddings), f)
                logger.info("Saved entity embeddings to disk.")

            # Project without building an autograd graph; as_tensor avoids the
            # copy warning torch.tensor() emits for tensor input.
            with torch.no_grad():
                projected = self.projection(torch.as_tensor(raw_embeddings, dtype=torch.float32).to(device))
            entity_embeddings = self._to_numpy(projected)

            logger.debug(f"Final entity embeddings shape: {entity_embeddings.shape}")
            return entity_nodes, entity_embeddings

        except Exception as e:
            logger.error(f"Error in entity embeddings computation: {str(e)}")
            logger.debug(traceback.format_exc())
            return [], np.zeros((0, 256))  # Empty embeddings with correct dimension

    def _load_cached_entity_embeddings(self, path: str, current_nodes: list):
        """Return cached ``(nodes, raw_embeddings)`` if they match ``current_nodes``, else ``None``."""
        if not os.path.exists(path):
            return None
        with open(path, 'rb') as f:
            entity_nodes, raw_embeddings = pickle.load(f)
        if list(entity_nodes) != current_nodes:
            logger.warning("Cached entity embeddings do not match the current knowledge graph; recomputing.")
            return None
        logger.info("Loaded entity embeddings from disk.")
        return list(entity_nodes), raw_embeddings

    def _representative_sentence(self, node) -> str:
        """First non-empty string in the node's ``sentences`` attribute, else ``str(node)``."""
        sentences = self.knowledge_graph.nodes[node].get('sentences', [])
        valid_sentences = [s for s in sentences if s and isinstance(s, str)]
        return valid_sentences[0] if valid_sentences else str(node)

    # ------------------------------------------------------------------ windows

    def get_sentence_window(self, center_idx: int, window_size: int = 7) -> List[str]:
        """Collect up to ``window_size`` sentences by BFS over the sentence graph from ``center_idx``.

        Sentences are returned in BFS visit order, starting with the centre. An
        out-of-range ``center_idx`` is replaced by 0. If a node's neighbours
        cannot be read from the graph, the adjacent indices are used instead.
        """
        from collections import deque

        try:
            if not (0 <= center_idx < len(self.sentences)):
                logger.warning(f"Invalid center index {center_idx}, using index 0")
                center_idx = 0

            visited = set()
            queue = deque([center_idx])
            window = []

            while queue and len(window) < window_size:
                current = queue.popleft()
                if current in visited or not (0 <= current < len(self.sentences)):
                    continue
                visited.add(current)
                window.append(self.sentences[current])

                try:
                    queue.extend(list(self.sentence_graph.neighbors(current)))
                except Exception as e:
                    logger.warning(f"Error getting neighbors for index {current}: {str(e)}")
                    if current > 0:
                        queue.append(current - 1)
                    if current < len(self.sentences) - 1:
                        queue.append(current + 1)

            # If the window is empty, take the direct context.
            if not window and 0 <= center_idx < len(self.sentences):
                window = [self.sentences[center_idx]]
                if center_idx > 0:
                    window.insert(0, self.sentences[center_idx - 1])
                if center_idx < len(self.sentences) - 1:
                    window.append(self.sentences[center_idx + 1])

            logger.debug(f"Built sentence window with {len(window)} sentences")
            return window

        except Exception as e:
            logger.error(f"Error building sentence window: {str(e)}")
            if 0 <= center_idx < len(self.sentences):
                return [self.sentences[center_idx]]
            # Slice rather than index so an empty sentence list yields [] instead of raising.
            return list(self.sentences[:1])

    # --------------------------------------------------------------- retrieval

    def find_relevant_subgraph(self, query: str) -> Set[str]:
        """Find knowledge-graph nodes relevant to ``query``.

        The result is the union of two signals:

        1. Cosine similarity > 0.3 between the projected query and each projected
           entity embedding. This is skipped when there are no entity embeddings.
        2. Cosine similarity > 0.4 between the projected query and the GNN output
           for each node. If the GNN fails, the error is logged and this signal
           is skipped.

        Returns:
            The set of relevant node names, or an empty set on unexpected errors.
        """
        try:
            query_embedding = self.encoder.encode([query], convert_to_tensor=True).to(device)
            with torch.no_grad():
                query_embedding = self.projection(query_embedding)
            query_embedding = self._to_numpy(query_embedding).reshape(1, -1)

            relevant_nodes = set()
            entity_embeddings = self._to_numpy(self.entity_embeddings)
            if len(self.entity_nodes) > 0 and entity_embeddings.shape[0] > 0:
                similarities = cosine_similarity(query_embedding, entity_embeddings)[0]
                threshold = 0.3  # Low threshold for better recall
                relevant_nodes = {
                    node for node, sim in zip(self.entity_nodes, similarities) if sim > threshold
                }
            else:
                logger.debug("No entity embeddings available; skipping entity similarity.")

            try:
                gnn_model = self._get_gnn_model()
                with torch.no_grad():
                    # Data and model must share a device for RGCNConv.
                    node_embeddings = gnn_model(self.gnn_data.to(device))
                if torch.is_tensor(node_embeddings) and node_embeddings.dim() == 3:
                    node_embeddings = node_embeddings.squeeze(0)
                node_embeddings = self._to_numpy(node_embeddings)

                node_similarities = cosine_similarity(query_embedding, node_embeddings)[0]
                gnn_threshold = 0.4  # Low threshold for better recall
                for i, sim in enumerate(node_similarities):
                    if sim > gnn_threshold:
                        node = self.idx_to_node.get(i)
                        if node is not None:
                            relevant_nodes.add(node)

                logger.debug(f"GNN relevant nodes found: {len(relevant_nodes)}")

            except Exception as e:
                logger.error(f"Error in GNN processing: {str(e)}")
                logger.debug(traceback.format_exc())
                # Continue with just the embedding-similarity results.

            logger.debug(f"Found {len(relevant_nodes)} relevant nodes")
            return relevant_nodes

        except Exception as e:
            logger.error(f"Error finding relevant subgraph: {str(e)}")
            logger.debug(traceback.format_exc())
            return set()

    def _get_gnn_model(self, weights_path: str = 'gnn_model.pth') -> torch.nn.Module:
        """Build the GNN once, loading ``weights_path`` if it exists, and reuse it afterwards.

        Caching matters beyond speed. Without saved weights the model is randomly
        initialized, so rebuilding it per call would give different answers for
        the same query.
        """
        model = getattr(self, '_gnn_model', None)
        if model is not None:
            return model

        model = AdvancedGNNReasoner(
            num_node_features=self.gnn_data.num_node_features,
            num_relations=int(self.gnn_data.edge_type.max().item()) + 1
        ).to(device)
        if os.path.exists(weights_path):
            # NOTE(security): torch.load unpickles; only load trusted weight files.
            model.load_state_dict(torch.load(weights_path, map_location=device))
            logger.debug("Loaded pre-trained GNN model weights.")
        else:
            logger.warning("GNN weights not found. Using randomly initialized weights.")
        model.eval()
        self._gnn_model = model
        return model

    def get_query_embedding(self, query: str) -> np.ndarray:
        """Return the L2-normalized, unprojected encoder embedding of ``query`` as a 1-D array."""
        query_embedding = self.encoder.encode([query], convert_to_tensor=True)
        query_embedding = query_embedding.cpu().numpy()
        faiss.normalize_L2(query_embedding)
        return query_embedding[0]

    def build_context(self, top_matches: List[Tuple[int, float]], relevant_nodes: Set[str]) -> List[str]:
        """Build an ordered, de-duplicated list of context sentences.

        Sentences are collected in this order:

        1. The sentence window around each semantic-search hit.
        2. The ``sentences`` attribute of each relevant knowledge-graph node,
           with nodes visited in sorted order so the context (and the downstream
           prompt) is deterministic.

        If nothing is found, the first three sentences are used. On unexpected
        errors, only the first sentence (or ``[]``) is returned.
        """
        try:
            context_parts = []

            for idx, _ in top_matches:
                try:
                    context_parts.extend(self.get_sentence_window(idx))
                except Exception as e:
                    logger.warning(f"Error getting window for index {idx}: {str(e)}")
                    if 0 <= idx < len(self.sentences):
                        context_parts.append(self.sentences[idx])

            for node in sorted(relevant_nodes, key=str):
                try:
                    if node in self.knowledge_graph.nodes:
                        context_parts.extend(self.knowledge_graph.nodes[node].get('sentences', []))
                except Exception as e:
                    logger.warning(f"Error getting connected sentences for node {node}: {str(e)}")

            if not context_parts and len(self.sentences) > 0:
                logger.warning("No context found, using fallback")
                context_parts = self.sentences[:3]

            # Remove duplicates while preserving first-seen order.
            context = list(dict.fromkeys(context_parts))

            logger.debug(f"Built context with {len(context)} unique sentences")
            return context

        except Exception as e:
            logger.error(f"Error building context: {str(e)}")
            return [self.sentences[0]] if len(self.sentences) > 0 else []


class AnswerGenerator:
    """Wrapper around the OpenAI chat-completions API that produces final answers.

    ``generate_answer`` never raises: every failure is logged and converted into
    a user-facing message string, so callers always receive a ``str``.
    """

    DEFAULT_MODEL = "gpt-4"

    _AUTH_ERROR_MSG = "There was an issue with the API authentication. Please check your API key."
    _RATE_LIMIT_MSG = "The service is currently experiencing high demand. Please try again in a moment."
    _BAD_REQUEST_MSG = "There was an issue with the request format. Please try rephrasing your question."
    _GENERIC_ERROR_MSG = "I apologize, but I encountered an error while generating the answer. Please try again."

    # OpenAI SDK (v1) exception class names -> user-facing message. Matching on
    # class names (including base classes) avoids importing the SDK's error
    # types here and is more reliable than searching the error text.
    _MESSAGES_BY_ERROR_TYPE = {
        "AuthenticationError": _AUTH_ERROR_MSG,
        "RateLimitError": _RATE_LIMIT_MSG,
        "BadRequestError": _BAD_REQUEST_MSG,
    }

    def __init__(self, openai_key: str, model: str = DEFAULT_MODEL):
        """Create the OpenAI client.

        Args:
            openai_key: API key. An empty string is passed as ``None`` so the
                OpenAI client falls back to the ``OPENAI_API_KEY`` environment
                variable instead of sending an empty key.
            model: Chat model used for every completion request.
        """
        self.client = OpenAI(api_key=openai_key or None)
        self.model = model
        self.logger = logging.getLogger("ASUSimRAG")

    def generate_answer(self, prompt: str) -> str:
        """Generate an answer for *prompt* using OpenAI's chat API.

        Args:
            prompt: Fully rendered user prompt (context + question).

        Returns:
            The stripped model answer, or an explanatory message if the prompt
            is invalid, the response is empty, or the API call fails.
        """
        if not self.validate_prompt(prompt):
            self.logger.warning("Refusing to call the OpenAI API with an empty or non-string prompt")
            return "Please provide a valid question."

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a highly knowledgeable assistant that provides accurate and reliable answers based on the provided context."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=1000,
                temperature=0.2,  # Low temperature for more deterministic, context-bound answers
                top_p=0.9,
                frequency_penalty=0.5,
                presence_penalty=0.0,
            )

            if not response or not response.choices:
                self.logger.error("OpenAI API returned empty response or no choices")
                return "I apologize, but I cannot generate an answer at the moment due to a technical issue."

            # The API may return ``None`` as message content; treat it as an
            # empty answer instead of crashing on ``None.strip()``.
            answer = (response.choices[0].message.content or "").strip()

            if not answer:
                self.logger.warning("Generated answer is empty")
                return "I cannot provide an answer based on the available information."

            return answer

        except Exception as e:
            self.logger.error(f"Error in OpenAI API call: {e}")
            self.logger.debug(traceback.format_exc())
            return self._message_for_error(e)

    def _message_for_error(self, error: Exception) -> str:
        """Map an API exception to a user-facing message (by type, then by text)."""
        for cls in type(error).__mro__:
            message = self._MESSAGES_BY_ERROR_TYPE.get(cls.__name__)
            if message is not None:
                return message

        # Fallback for errors whose type is unknown: inspect the message text.
        text = str(error).lower()
        if "api_key" in text:
            return self._AUTH_ERROR_MSG
        if "rate_limit" in text:
            return self._RATE_LIMIT_MSG
        if "invalid_request" in text:
            return self._BAD_REQUEST_MSG
        return self._GENERIC_ERROR_MSG

    def validate_prompt(self, prompt: str) -> bool:
        """Return True when *prompt* is a string with at least one non-whitespace character."""
        return isinstance(prompt, str) and bool(prompt.strip())


class RAGSystem:
    """Retrieval-augmented question answering over a single PDF.

    Construction eagerly runs the offline pipeline:
    1. text extraction
    2. sentence processing
    3. knowledge-graph build
    4. sentence embeddings
    5. GNN setup

    Most stages cache their output under fixed file names in the current
    working directory and reuse those files on later runs.

    NOTE: the cache file names do not depend on ``pdf_path``. After switching
    to a different PDF, delete the cached files; otherwise the previous
    document's data is loaded silently.
    """

    # Lower the threshold to be more lenient with matches.
    # NOTE(review): originally justified by "L2 distance normalization" of the
    # search scores — confirm against SemanticSearcher.
    SIMILARITY_THRESHOLD = 0.2
    # A top score at or below this is treated as "nothing relevant at all".
    MIN_RELEVANCE_SCORE = 0.001
    # Number of semantic-search hits requested per query.
    SEARCH_TOP_K = 15
    # Character budget for the context inserted into the LLM prompt.
    MAX_CONTEXT_CHARS = 4000

    def __init__(self, pdf_path: str, openai_key: str, ontology_path: str = None):
        """Create the system and immediately run :meth:`initialize_system`.

        Args:
            pdf_path: PDF to index (only read when no text cache exists).
            openai_key: API key forwarded to ``AnswerGenerator``.
            ontology_path: Optional ontology forwarded to the graph builder.
        """
        self.pdf_path = pdf_path
        self.openai_key = openai_key
        self.ontology_path = ontology_path
        self.logger = logging.getLogger("ASUSimRAG")
        self.text = None
        self.sentences = None
        self.processed_docs = None
        self.knowledge_graph = None
        self.embedding_manager = None
        self.gnn_model = None
        self.answer_generator = AnswerGenerator(openai_key=self.openai_key)
        self.nlp = spacy.load('en_core_web_sm')
        self.gnn_data = None
        self.idx_to_node = None
        self.node_to_idx = None
        self.initialize_system()

    def initialize_system(self):
        """Initialize all components of the RAG system, in dependency order."""
        start_time = time.time()
        self.extract_text()
        self.process_sentences()
        self.build_knowledge_graph()
        self.compute_embeddings()
        self.initialize_gnn()
        logger.info(f"System initialized in {time.time() - start_time:.2f} seconds.")

    def extract_text(self):
        """Populate ``self.text`` from the on-disk cache, or extract it from the PDF and cache it."""
        text_path = 'extracted_text.txt'
        if os.path.exists(text_path):
            with open(text_path, 'r', encoding='utf-8') as f:
                self.text = f.read()
            logger.info("Loaded extracted text from disk.")
        else:
            extractor = PDFTextExtractor(self.pdf_path)
            self.text = extractor.extract_text()
            with open(text_path, 'w', encoding='utf-8') as f:
                f.write(self.text)
            logger.info("Saved extracted text to disk.")

    def process_sentences(self):
        """Populate ``self.sentences`` and ``self.processed_docs`` from cache or by processing ``self.text``.

        The pickle caches are produced by this method; only load files you created.
        """
        sentences_path = 'sentences.pkl'
        processed_docs_path = 'processed_docs.pkl'
        if os.path.exists(sentences_path) and os.path.exists(processed_docs_path):
            with open(sentences_path, 'rb') as f:
                self.sentences = pickle.load(f)
            with open(processed_docs_path, 'rb') as f:
                self.processed_docs = pickle.load(f)
            logger.info("Loaded sentences and processed docs from disk.")
        else:
            processor = SentenceProcessor(self.text, self.nlp)
            self.sentences = processor.split_into_sentences()
            self.processed_docs = processor.preprocess_sentences()
            with open(sentences_path, 'wb') as f:
                pickle.dump(self.sentences, f)
            with open(processed_docs_path, 'wb') as f:
                pickle.dump(self.processed_docs, f)
            logger.info("Saved sentences and processed docs to disk.")

    @staticmethod
    def _load_torch_object(path: str):
        """Load a locally cached, fully pickled torch object (e.g. a PyG ``Data``).

        ``torch.load`` defaults to ``weights_only=True`` since PyTorch 2.6. That
        mode may refuse classes that are not allow-listed, such as PyG ``Data``.
        This cache is written by :meth:`build_knowledge_graph`, so full
        unpickling is requested explicitly. Never use this on untrusted files.
        """
        import inspect  # local: only needed for the torch version probe below

        load_kwargs = {"map_location": device}
        # Older torch releases do not accept the ``weights_only`` keyword at all.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        return torch.load(path, **load_kwargs)

    def build_knowledge_graph(self):
        """Load or build the knowledge graph, its PyG form and the node-index mappings.

        Raises:
            ValueError: if graph, PyG data and mappings disagree on the node count.
        """
        kg_path = 'knowledge_graph.pkl'
        mappings_path = 'node_mappings.pkl'
        gnn_data_path = 'gnn_data.pt'
        if os.path.exists(kg_path) and os.path.exists(mappings_path) and os.path.exists(gnn_data_path):
            with open(kg_path, 'rb') as f:
                self.knowledge_graph = pickle.load(f)
            with open(mappings_path, 'rb') as f:
                mappings = pickle.load(f)
                self.idx_to_node = mappings['idx_to_node']
                self.node_to_idx = mappings['node_to_idx']
            try:
                self.gnn_data = self._load_torch_object(gnn_data_path)
                logger.info("Loaded knowledge graph, node mappings, and gnn_data from disk.")
            except Exception as e:
                logger.error(f"Error loading gnn_data: {e}")
                logger.debug(traceback.format_exc())
                raise
        else:
            graph_builder = EnhancedKnowledgeGraphBuilder(self.processed_docs, self.sentences, self.ontology_path)
            self.knowledge_graph = graph_builder.get_knowledge_graph()
            self.gnn_data, self.idx_to_node, self.node_to_idx = graph_builder.convert_to_pyg_data()
            with open(kg_path, 'wb') as f:
                pickle.dump(self.knowledge_graph, f)
            with open(mappings_path, 'wb') as f:
                pickle.dump({'idx_to_node': self.idx_to_node, 'node_to_idx': self.node_to_idx}, f)
            torch.save(self.gnn_data, gnn_data_path)
            logger.info("Saved knowledge graph, node mappings, and gnn_data to disk.")

        # Stale or partially written caches show up here as count mismatches.
        num_nodes_kg = len(self.knowledge_graph.nodes())
        num_nodes_pyg = self.gnn_data.num_nodes if self.gnn_data is not None else None
        num_nodes_mapping = len(self.idx_to_node)

        if num_nodes_pyg != num_nodes_mapping or num_nodes_mapping != num_nodes_kg:
            logger.error(f"Node count mismatch: Knowledge Graph={num_nodes_kg}, PyG Data={num_nodes_pyg}, Mappings={num_nodes_mapping}")
            raise ValueError("Inconsistent node counts between knowledge graph, PyG data, and node mappings.")
        logger.info(f"Node counts are consistent: {num_nodes_kg} nodes.")

    def compute_embeddings(self):
        """Compute or load sentence embeddings (caching is handled by ``EmbeddingManager``)."""
        self.embedding_manager = EmbeddingManager(self.sentences, embeddings_path='sentence_embeddings.pkl')

    def initialize_gnn(self):
        """Build the GNN reasoner sized to ``self.gnn_data`` and load saved weights if present.

        Raises:
            Exception: any construction or loading error is logged and re-raised.
        """
        try:
            edge_type = self.gnn_data.edge_type
            # ``Tensor.max()`` raises on an empty tensor; fall back to a single
            # relation type for a graph without edges.
            num_relations = edge_type.max().item() + 1 if edge_type.numel() > 0 else 1
            self.gnn_model = AdvancedGNNReasoner(
                num_node_features=self.gnn_data.num_node_features,
                num_relations=num_relations
            ).to(device)
            gnn_weights_path = 'gnn_model.pth'
            if os.path.exists(gnn_weights_path):
                self.gnn_model.load_state_dict(torch.load(gnn_weights_path, map_location=device))
                logger.info("Loaded pre-trained GNN model weights.")
            else:
                logger.warning("GNN weights not found. Training from scratch is recommended for optimal performance.")
                # TODO: add a training procedure; the freshly initialised model is used as-is.
        except Exception as e:
            logger.error(f"Error initializing GNN model: {e}")
            logger.debug(traceback.format_exc())
            raise

    @staticmethod
    def _truncate_context(text: str, max_chars: int) -> str:
        """Trim *text* to at most *max_chars* characters, cutting on a word boundary.

        A single token longer than the budget is hard-cut instead of dropped.
        """
        if max_chars <= 0:
            return ""
        if len(text) <= max_chars:
            return text
        clipped = text[:max_chars]
        # If the cut lands mid-word, back off to the previous space.
        if not text[max_chars].isspace():
            boundary = clipped.rfind(' ')
            if boundary > 0:
                clipped = clipped[:boundary]
        return clipped.rstrip()

    @staticmethod
    def _build_prompt(context_str: str, query: str) -> str:
        """Render the LLM prompt wrapping the retrieved context and the user's question."""
        return f"""Use the following context information to answer the question accurately and comprehensively. If the information in the context is insufficient for a complete answer, focus on what can be confidently stated from the available information.

Context:
---------------------
{context_str}
---------------------

Question: {query}

Provide a detailed answer that:
1. Directly addresses the key points in the question
2. Includes specific requirements or eligibility criteria if mentioned
3. References any relevant deadlines or processes
4. Notes any additional resources or next steps

Answer:"""

    def answer_question(self, query: str) -> str:
        """Answer *query* using retrieval-augmented generation.

        Pipeline:
        1. Semantic search over sentence embeddings.
        2. Relevance gate on the best score.
        3. Context assembly from search hits plus knowledge-graph nodes; if
           that fails, the raw hits are used instead.
        4. Prompt construction and the LLM call.

        Args:
            query: Natural-language question.

        Returns:
            The generated answer, or a user-facing message explaining why none
            could be produced. Exceptions are logged, never raised.
        """
        start_time = time.time()

        try:
            # Whitespace-only strings carry no question either.
            if not isinstance(query, str) or not query.strip():
                return "Please provide a valid question."

            self.logger.debug(f"Processing query: {query}")
            self.logger.debug(f"Number of sentences in corpus: {len(self.sentences)}")

            searcher = SemanticSearcher(self.embedding_manager)
            self.logger.debug("Initialized SemanticSearcher")

            try:
                top_matches = searcher.semantic_search(query, top_k=self.SEARCH_TOP_K)
                self.logger.debug(f"Semantic search returned {len(top_matches)} matches")

                for idx, score in top_matches:
                    if 0 <= idx < len(self.sentences):
                        self.logger.debug(f"Match: index={idx}, score={score}, text={self.sentences[idx][:100]}...")
                    else:
                        self.logger.error(f"Invalid index {idx} in search results")

            except Exception as e:
                self.logger.error(f"Error in semantic search: {str(e)}")
                return "I encountered an error while searching for relevant information."

            if not top_matches:
                self.logger.info("No matches found in semantic search.")
                return "I couldn't find any relevant information to answer your question."

            # Assumes semantic_search returns hits best-first with higher = more similar.
            top_match_score = top_matches[0][1]
            self.logger.debug(f"Top match similarity score: {top_match_score}")

            if top_match_score < self.SIMILARITY_THRESHOLD:
                self.logger.info(f"Top match score {top_match_score} below threshold {self.SIMILARITY_THRESHOLD}.")
                # Weak matches are still used; only essentially-zero scores abort.
                if top_match_score > self.MIN_RELEVANCE_SCORE:
                    self.logger.debug("Using available matches despite low similarity score")
                else:
                    return "I couldn't find sufficiently relevant information to answer your question."

            try:
                self.logger.debug("Initializing ContextBuilder")
                context_builder = ContextBuilder(
                    sentences=self.sentences,
                    knowledge_graph=self.knowledge_graph,
                    encoder=self.embedding_manager.text_model,
                    processed_docs=self.processed_docs,
                    gnn_data=self.gnn_data,
                    idx_to_node=self.idx_to_node,
                    node_to_idx=self.node_to_idx
                )

                self.logger.debug("Finding relevant subgraph")
                relevant_nodes = context_builder.find_relevant_subgraph(query)
                self.logger.debug(f"Found {len(relevant_nodes)} relevant nodes")

                self.logger.debug("Building context")
                context = context_builder.build_context(top_matches, relevant_nodes)
                self.logger.debug(f"Built context with {len(context)} sentences")

            except Exception as e:
                self.logger.error(f"Error in context building: {str(e)}")
                self.logger.debug(traceback.format_exc())

                # Fallback: use the raw search hits that point at real sentences.
                self.logger.debug("Attempting to use fallback context")
                context = [self.sentences[idx] for idx, _ in top_matches if 0 <= idx < len(self.sentences)]

                if not context:
                    return "I encountered an error while building context for your question."

            context_str = ' '.join(context)
            if len(context_str) > self.MAX_CONTEXT_CHARS:
                # Enforce the budget in characters (the previous word-count
                # heuristic kept only about half of the allowed context).
                context_str = self._truncate_context(context_str, self.MAX_CONTEXT_CHARS)
                self.logger.debug(f"Truncated context to {len(context_str)} characters")

            prompt = self._build_prompt(context_str, query)

            try:
                self.logger.debug("Generating answer using OpenAI")
                answer = self.answer_generator.generate_answer(prompt)
                self.logger.info(f"Answer generated in {time.time() - start_time:.2f} seconds")
                return answer

            except Exception as e:
                self.logger.error(f"Error generating answer: {str(e)}")
                return "I encountered an error while generating the answer. Please try again."

        except Exception as e:
            self.logger.error(f"Error in RAG pipeline: {str(e)}")
            self.logger.debug(traceback.format_exc())
            return "I encountered an error while processing your question. Please try again."


# Example usage
def main():
    """Build a RAGSystem and answer the example questions, printing each answer.

    Configuration comes from the environment so secrets are not kept in source:

    - ``OPENAI_API_KEY``: key passed to the system; empty if unset.
    - ``ASU_RAG_PDF_PATH``: PDF to index; defaults to the original path.
    - ``ASU_RAG_ONTOLOGY_PATH``: optional ontology; ``None`` if unset or empty.

    Errors are logged with tracebacks and reported on stdout instead of raised.
    """
    # Logging is configured once at import time (module-level basicConfig); a
    # second basicConfig call here would be a silent no-op.
    logger = logging.getLogger("ASUSimRAG")

    try:
        logger.info("Initializing RAG system...")
        rag = RAGSystem(
            pdf_path=os.environ.get("ASU_RAG_PDF_PATH", "/Users/rohit/Desktop/ASU/Finances.pdf"),
            openai_key=os.environ.get("OPENAI_API_KEY", ""),
            # RAGSystem's own default for "no ontology" is None, not "".
            ontology_path=os.environ.get("ASU_RAG_ONTOLOGY_PATH") or None,
        )
        logger.info("RAG system initialized successfully")

        questions = [
            "What are the options and requirements for obtaining scholarships and awards as a transfer student at ASU?"
        ]

        for question in questions:
            logger.info(f"Processing question: {question}")
            try:
                answer = rag.answer_question(question)
                print(f"\nQuestion: {question}")
                print(f"Answer: {answer}")
                logger.info("Successfully generated answer")
            except Exception as e:
                logger.exception(f"Error processing question: {str(e)}")
                print(f"Error: Unable to process question - {str(e)}")

    except Exception as e:
        logger.exception(f"Error in main: {e}")
        print(f"Error: {e}")

if __name__ == "__main__":
    main()