# In [12]:
# ============================================================================
# GST GRIEVANCE RESOLUTION MULTI-AGENT SYSTEM (ENHANCED VERSION)
# ============================================================================
# Architecture: Multi-agent system with LangGraph orchestration
# Preprocessing LLM: Gemini 2.5 Pro (see Config.PREPROCESSOR_MODEL)
# Classification and resolution LLM: Gemini 2.5 Flash preview (see Config.CLASSIFIER_MODEL / Config.RESOLVER_MODEL)
# Embeddings: LOCAL sentence-transformers (NO API CALLS)
# Storage: Persistent FAISS vector database
# Web Search: Tavily API / DuckDuckGo fallback
# ============================================================================

# --- IMPORTS ---
import os
import json
import time
import re
import pickle
from pathlib import Path
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, TypedDict, Annotated
from enum import Enum
import pandas as pd
from dotenv import load_dotenv
from IPython.display import display, Markdown

# PyTorch - MUST IMPORT BEFORE Config class
import torch
from openai import OpenAI

# LangChain & LangGraph
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.runnables import RunnablePassthrough
from langgraph.graph import StateGraph, END
from pydantic import BaseModel, Field

# Vector store and embeddings - LOCAL MODELS ONLY
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer

## Knowledge Graph
import sqlite3
import json
import networkx as nx
from pathlib import Path


# Web search providers
try:
    from tavily import TavilyClient
    TAVILY_AVAILABLE = True
except ImportError:
    TAVILY_AVAILABLE = False

try:
    from langchain_community.tools import DuckDuckGoSearchRun
    from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
    DUCKDUCKGO_AVAILABLE = True
except ImportError:
    DUCKDUCKGO_AVAILABLE = False

# Twitter API (optional)
try:
    import tweepy
except ImportError:
    tweepy = None

# Logging
import logging

# Setup logging
# Configure the root logger once for the whole module: INFO level, with
# timestamp, logger name and level in front of every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used by every class and function below.
logger = logging.getLogger(__name__)

# Load environment variables
# Must run before the Config class body executes, because Config reads its
# API keys with os.getenv() at class-definition time.
load_dotenv()

# --- CONFIGURATION ---
class Config:
    """Class-level settings shared by every agent in this module.

    All values are plain class attributes evaluated once, when the class body
    runs: API keys are read from the environment (``None`` when unset) and the
    embedding device is detected immediately via :meth:`get_device`.
    """

    # --- Credentials (from the environment / .env file) ---
    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
    TWITTER_BEARER_TOKEN = os.getenv("TWITTER_BEARER_TOKEN")
    TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
    DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")

    # --- Model identifiers, one per pipeline stage ---
    PREPROCESSOR_MODEL = "gemini-2.5-pro"
    CLASSIFIER_MODEL = "gemini-2.5-flash-preview-09-2025"
    RESOLVER_MODEL = "gemini-2.5-flash-preview-09-2025"
    DEEPSEEK_MODEL = "deepseek-chat"

    # --- Local sentence-transformers model used for embeddings ---
    LOCAL_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    @staticmethod
    def get_device():
        """Return the preferred torch device name: "mps", "cuda" or "cpu".

        Preference order is Apple Silicon (MPS), then NVIDIA CUDA, then CPU.
        The choice is logged at INFO level on every call.
        """
        # Older torch builds may not expose torch.backends.mps at all.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            logger.info("✅ Using Apple Silicon GPU (MPS)")
            return "mps"

        if torch.cuda.is_available():
            logger.info("✅ Using NVIDIA CUDA GPU")
            return "cuda"

        logger.info("ℹ️ Using CPU (no GPU detected)")
        return "cpu"

    # Inside the class body `get_device` is still a staticmethod object, so the
    # underlying function is reached through __func__ to call it here.
    EMBEDDING_DEVICE = get_device.__func__()

    # --- Storage ---
    VECTOR_STORE_PATH = "./data/gst_knowledge_base"

    # --- Thresholds, temperatures and time limits ---
    MIN_CONFIDENCE_THRESHOLD = 95
    NULL_RESPONSE_THRESHOLD = 95
    PREPROCESSOR_TEMPERATURE = 0.2
    CLASSIFIER_TEMPERATURE = 0.2
    RESOLVER_TEMPERATURE = 0.0
    MAX_RETRIEVAL_TIME = 10

    # --- Per-source result limits ---
    MAX_LOCAL_RESULTS = 5
    MAX_WEB_RESULTS = 10
    MAX_TWITTER_RESULTS = 10

    # --- Web search behaviour ---
    WEB_SEARCH_UNRESTRICTED = True  # Set False to use domain filtering

    # --- GST-specific endpoints and handles ---
    GSTN_FAQ_URL = "https://tutorial.gst.gov.in/downloads/news/FAQ.pdf"
    CBIC_BASE_URL = "https://cbic-gst.gov.in"
    GSTN_TWITTER = "@Infosys_GSTN"
    

# Validate configuration
# Only warns: the module keeps loading without a Google key, so any failure
# surfaces later when an LLM client is actually used.
if not Config.GOOGLE_API_KEY:
    logger.warning("⚠️ WARNING: GOOGLE_API_KEY not found in .env file")

# Log device information
logger.info(f"📊 Device: {Config.EMBEDDING_DEVICE}")
logger.info(f"🔧 PyTorch version: {torch.__version__}")
# torch.backends.mps is only touched when MPS was actually selected.
if Config.EMBEDDING_DEVICE == "mps":
    logger.info(f"🍎 MPS built: {torch.backends.mps.is_built()}")

# --- ENUMS AND MODELS ---
class GrievanceCategory(str, Enum):
    """Closed set of GST grievance categories, as lowercase string slugs.

    Subclassing ``str`` makes each member compare equal to its slug value.

    NOTE(review): ``ClassificationOutput.primary_category`` is declared as a
    plain ``str``, so nothing in this section validates classifier output
    against this enum.
    """
    REGISTRATION = "registration"
    GSTR_FILING = "gstr_filing"
    EWAY_BILL = "eway_bill"
    REFUND = "refund"
    ITC_MISMATCH = "itc_mismatch"
    PENALTY_NOTICE = "penalty_notice"
    PORTAL_ERROR = "portal_error"
    API_INTEGRATION = "api_integration"
    COMPLIANCE = "compliance"
    GENERAL = "general"

class IntentType(str, Enum):
    """Closed set of query intents, as lowercase string slugs.

    NOTE(review): ``PreprocessingOutput.detected_intent`` is declared as a
    plain ``str``, so this enum is not enforced on the preprocessing result
    in this section.
    """
    INFORMATIONAL = "informational"
    PROCEDURAL = "procedural"
    ERROR_RESOLUTION = "error_resolution"
    COMPLIANCE_CLARIFICATION = "compliance_clarification"
    REFUND_STATUS = "refund_status"

class ExtractedEntity(BaseModel):
    # One entity found in the user query. The preprocessing prompt asks the
    # LLM for GSTINs, GST form names, dates and amounts.
    # (Documented with comments, not a docstring, so the model's JSON schema
    # that is embedded in prompts stays unchanged.)
    entity_type: str  # kind of entity; free-form string, not validated
    value: str  # the extracted entity value
    context: Optional[str] = None  # optional; presumably surrounding text — TODO confirm

class CoreIssue(BaseModel):
    # One distinct problem identified in the user query; a list of these is
    # passed (as JSON) to the classification prompt.
    issue_text: str  # the issue, as text
    keywords: List[str]  # keywords associated with the issue
    priority: int  # NOTE(review): scale and direction are not defined in this section

class PreprocessingOutput(BaseModel):
    # Result of PreprocessingAgent.process: built from the LLM's parsed JSON
    # reply. This model's schema is also what the agent's JsonOutputParser
    # renders into the prompt as format instructions.
    cleaned_text: str
    detected_intent: str  # plain string; not validated against IntentType
    core_issues: List[CoreIssue]
    entities: List[ExtractedEntity]
    language: str  # NOTE(review): expected format (name vs. ISO code) is not specified here

class ClassificationOutput(BaseModel):
    # Result of ClassificationAgent.process: built from the LLM's parsed JSON
    # reply; its schema is rendered into the classification prompt.
    primary_category: str  # plain string; not validated against GrievanceCategory
    secondary_categories: List[str]
    confidence_scores: Dict[str, float]  # presumably category -> score; scale not defined here — TODO confirm
    sub_type: Optional[str] = None

class RetrievalSource(BaseModel):
    # A single retrieved passage together with where it came from.
    source_type: str  # origin tag; LocalRetrievalAgent uses "local_kb"
    content: str  # the passage text
    citation: str  # human-readable source, e.g. "<filename> (Page N)" for local results
    relevance_score: float  # higher is better; local results use 1 - FAISS score
    date: Optional[str] = None  # left as None by LocalRetrievalAgent

class RetrievalOutput(BaseModel):
    # Aggregated retrieval results, one list per source type.
    # NOTE(review): the code that builds this object is outside this section.
    twitter_results: List[RetrievalSource]
    local_results: List[RetrievalSource]
    web_results: List[RetrievalSource]
    llm_reasoning: List[RetrievalSource]
    total_sources: int  # presumably the combined length of the lists above — TODO confirm
    retrieval_time: float  # presumably seconds — TODO confirm

class IssueResolution(BaseModel):
    # Resolution for one issue. `resolution` may be None, in which case
    # `reason_for_null` is available to carry an explanation.
    # NOTE(review): the producer of this model is outside this section.
    issue: str
    resolution: Optional[str]  # Optional type but no default given
    confidence: int  # presumably 0-100 (cf. Config.MIN_CONFIDENCE_THRESHOLD = 95) — TODO confirm
    legal_basis: Optional[str] = None
    procedural_steps: Optional[List[str]] = None
    source_citations: List[str]
    reason_for_null: Optional[str] = None

class ResolverOutput(BaseModel):
    # Resolver-stage result: one IssueResolution per issue plus summary flags.
    # NOTE(review): the producer of this model is outside this section.
    resolutions: List[IssueResolution]
    overall_confidence: int  # presumably on the same scale as IssueResolution.confidence — TODO confirm
    requires_escalation: bool

class FinalResponse(BaseModel):
    # User-facing answer assembled at the end of the pipeline.
    # NOTE(review): the producer of this model is outside this section.
    direct_answer: str
    detailed_explanation: Optional[str] = None
    legal_basis: Optional[str] = None
    recent_updates: Optional[str] = None
    additional_resources: List[str]
    confidence_score: int  # presumably 0-100 — TODO confirm
    requires_manual_review: bool

class AgentState(TypedDict):
    """Shared state dictionary passed from agent to agent.

    Each agent's ``process`` method receives this dict, fills in its own
    ``*_output`` slot and returns the same dict. On failure an agent appends
    ``str(exception)`` to ``errors`` and leaves its output slot untouched.
    """
    user_query: str  # raw query text; read by PreprocessingAgent
    session_id: str
    conversation_history: List[Dict[str, str]]
    preprocessing_output: Optional[PreprocessingOutput]  # set by PreprocessingAgent
    classification_output: Optional[ClassificationOutput]  # set by ClassificationAgent
    retrieval_output: Optional[RetrievalOutput]
    resolver_output: Optional[ResolverOutput]
    final_response: Optional[FinalResponse]
    timestamp: str
    processing_time: float
    errors: List[str]  # accumulated error messages; must be a list before agents run
    feedback_received: Optional[str]
    iteration_count: int
    escalation_requested: bool

class LightweightKnowledgeGraph:
    """Embedded knowledge graph retriever.

    Holds a ``networkx.DiGraph`` in memory that :meth:`load` fills from the
    ``nodes`` and ``edges`` tables of a SQLite file. When the file does not
    exist the instance is still usable, but the graph stays empty.
    """

    def __init__(self, db_path: str = "./knowledge_graph.db"):
        """Open the SQLite file at *db_path* if it exists.

        No data is read here; call :meth:`load` to populate the graph. A
        missing file is only logged as a warning and leaves ``self.conn`` as
        ``None``.
        """
        self.db_path = Path(db_path)
        self.graph = nx.DiGraph()
        self.conn = None

        if self.db_path.exists():
            self.conn = sqlite3.connect(str(self.db_path))
        else:
            logger.warning(f"⚠️ Graph database not found: {db_path}")

    def load(self):
        """Load graph from SQLite database.

        Best-effort: any failure is logged and swallowed, which can leave the
        graph partially populated. Does nothing (beyond logging an error)
        when there is no open connection.
        """
        if not self.conn:
            logger.error("❌ No database connection for knowledge graph")
            return

        try:
            # Nodes: the row's `type` column is stored under 'entity_type'
            # (not 'type'), and 'label' is stored alongside it. Both override
            # same-named keys coming from the JSON metadata.
            cursor = self.conn.execute("SELECT id, type, label, metadata FROM nodes")
            for node_id, node_type, label, metadata in cursor:
                meta = json.loads(metadata) if metadata else {}
                meta['entity_type'] = node_type
                meta['label'] = label
                self.graph.add_node(node_id, **meta)

            # Edges: directed source -> target with relation and weight attributes.
            cursor = self.conn.execute("SELECT source, target, relation, weight FROM edges")
            for source, target, relation, weight in cursor:
                self.graph.add_edge(source, target, relation=relation, weight=weight)

            logger.info(f"✅ Graph loaded: {self.graph.number_of_nodes()} nodes, {self.graph.number_of_edges()} edges")
        except Exception as e:
            logger.error(f"❌ Failed to load graph: {e}")

    def find_related(self, entity: str, max_depth: int = 2, max_results: int = 20) -> list:
        """Find entities related to *entity* via breadth-first traversal.

        Only outgoing edges are followed (``graph.neighbors``). Nodes whose id
        starts with ``"doc:"`` are traversed but never returned.

        Args:
            entity: Id of the start node.
            max_depth: Maximum number of hops from *entity*.
            max_results: Maximum number of ids to return.

        Returns:
            Related node ids in BFS order, nearest first; empty when *entity*
            is not in the graph.
        """
        from collections import deque

        if entity not in self.graph:
            return []

        related = []
        # Nodes are marked when enqueued, so each one enters the queue at most
        # once; the visit order is the same as marking on dequeue.
        visited = {entity}
        queue = deque([(entity, 0)])

        while queue and len(related) < max_results:
            node, depth = queue.popleft()
            if depth > max_depth:
                # Only reachable for a negative max_depth.
                continue

            # str() keeps non-string node ids (e.g. integer SQLite ids) from
            # raising on .startswith.
            if node != entity and not str(node).startswith("doc:"):
                related.append(node)

            if depth < max_depth:
                for neighbor in self.graph.neighbors(node):
                    if neighbor not in visited:
                        visited.add(neighbor)
                        queue.append((neighbor, depth + 1))

        return related[:max_results]

    def get_entity_info(self, entity: str) -> dict:
        """Return a copy of the node attributes for *entity* ({} if unknown)."""
        if entity not in self.graph:
            return {}
        return dict(self.graph.nodes[entity])

    def close(self):
        """Close the SQLite connection, if any. Safe to call more than once."""
        if self.conn:
            self.conn.close()
            # Drop the closed handle so a later load() reports "no connection"
            # instead of failing on a closed database.
            self.conn = None

# --- INITIALIZE LOCAL EMBEDDINGS ---
def initialize_local_embeddings():
    """Build the local HuggingFace embedding model described by ``Config``.

    Embeddings are computed locally with the configured sentence-transformers
    model on ``Config.EMBEDDING_DEVICE``; vectors are normalised and encoded
    in batches of 32.

    Returns:
        The ready ``HuggingFaceEmbeddings`` instance.

    Raises:
        Exception: whatever the model loading raised, after it has been logged.
    """
    try:
        logger.info("🔄 Initializing local embedding model...")
        model_name = Config.LOCAL_EMBEDDING_MODEL
        logger.info(f"📦 Model: {model_name}")
        device = Config.EMBEDDING_DEVICE
        logger.info(f"💻 Device: {device}")

        if device == "mps":
            # NOTE(review): torch is already imported when this runs; confirm
            # that PyTorch still honours this variable when it is set after
            # import.
            os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
            logger.info("🍎 MPS fallback enabled for unsupported operations")

        model_options = {'device': device}
        encode_options = {'normalize_embeddings': True, 'batch_size': 32}
        embedder = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs=model_options,
            encode_kwargs=encode_options,
            show_progress=False,
        )

        logger.info("✅ Local embedding model loaded successfully")
        logger.info("🎯 No API calls - fully offline embeddings")
        return embedder

    except Exception as e:
        # Log for visibility, then let the caller decide how to degrade.
        logger.error(f"❌ Failed to initialize local embeddings: {e}")
        raise
        
# Initialize global embeddings instance
# Runs at import time. One query is embedded as a smoke test so a broken model
# is detected here rather than on the first retrieval. On any failure
# `local_embeddings` is set to None; LocalRetrievalAgent falls back to this
# global and logs an error when it is None.
try:
    local_embeddings = initialize_local_embeddings()
    test_embedding = local_embeddings.embed_query("Test query for GST portal")
    logger.info(f"✅ Embedding test successful! Dimension: {len(test_embedding)}")
except Exception as e:
    logger.error(f"❌ Embedding initialization failed: {e}")
    local_embeddings = None
    
# --- LLM INITIALIZATION ---
def initialize_llms():
    """Create the three Gemini chat clients used by the pipeline.

    Every client shares the Google API key from ``Config`` and requests JSON
    responses; they differ only in model name and temperature.

    Returns:
        Tuple ``(preprocessor_llm, classifier_llm, resolver_llm)``.
    """

    def _json_chat_model(model_name, temperature):
        # Single place for the settings common to all three clients.
        return ChatGoogleGenerativeAI(
            model=model_name,
            temperature=temperature,
            google_api_key=Config.GOOGLE_API_KEY,
            response_mime_type="application/json"
        )

    preprocessor = _json_chat_model(Config.PREPROCESSOR_MODEL, Config.PREPROCESSOR_TEMPERATURE)
    classifier = _json_chat_model(Config.CLASSIFIER_MODEL, Config.CLASSIFIER_TEMPERATURE)
    resolver = _json_chat_model(Config.RESOLVER_MODEL, Config.RESOLVER_TEMPERATURE)

    logger.info("✅ LLMs initialized:")
    logger.info(f"   Preprocessor: {Config.PREPROCESSOR_MODEL}")
    logger.info(f"   Classifier: {Config.CLASSIFIER_MODEL}")
    logger.info(f"   Resolver: {Config.RESOLVER_MODEL}")
    return preprocessor, classifier, resolver

# Module-level LLM clients, created once at import time.
preprocessor_llm, classifier_llm, resolver_llm = initialize_llms()

# --- PREPROCESSING & CLASSIFICATION AGENTS ---
class PreprocessingAgent:
    """Agent 1: turns the raw user query into a structured PreprocessingOutput."""

    def __init__(self, llm):
        """Store the LLM and prepare the JSON parser and prompt template."""
        self.llm = llm
        self.parser = JsonOutputParser(pydantic_object=PreprocessingOutput)
        self.prompt = ChatPromptTemplate.from_template("""
You are an expert at preprocessing GST-related grievance queries.

**User Query:** {query}

Extract: intent, core issues, entities (GSTIN, GST forms like GSTR-3A/GSTR-4, dates, amounts), language.

{format_instructions}
""")

    def process(self, state: AgentState) -> AgentState:
        """Preprocess ``state["user_query"]`` and store the result in the state.

        On success ``state["preprocessing_output"]`` is set. On any failure
        the error text is appended to ``state["errors"]`` and the output slot
        is left as it was. The same state dict is returned in both cases.
        """
        try:
            logger.info("🔄 Agent 1: Preprocessing...")
            rendered_prompt = self.prompt.format(
                query=state["user_query"],
                format_instructions=self.parser.get_format_instructions(),
            )
            llm_reply = self.llm.invoke(rendered_prompt)
            parsed_payload = self.parser.parse(llm_reply.content)
            result = PreprocessingOutput(**parsed_payload)
            state["preprocessing_output"] = result
            logger.info(f"✅ Found {len(result.core_issues)} issues")
        except Exception as e:
            # Failures are recorded on the state rather than raised, so the
            # caller always gets the state back.
            logger.error(f"❌ Preprocessing error: {e}")
            state["errors"].append(str(e))
        return state

class ClassificationAgent:
    """Agent 2: assigns grievance categories based on the preprocessing result."""

    def __init__(self, llm):
        """Store the LLM and prepare the JSON parser and prompt template."""
        self.llm = llm
        self.parser = JsonOutputParser(pydantic_object=ClassificationOutput)
        self.prompt = ChatPromptTemplate.from_template("""
Classify GST grievance. Categories: registration, gstr_filing, penalty_notice, portal_error, refund, itc_mismatch, eway_bill, etc.

Intent: {intent}
Issues: {core_issues}

{format_instructions}
""")

    def process(self, state: AgentState) -> AgentState:
        """Classify the issues in ``state["preprocessing_output"]``.

        On success ``state["classification_output"]`` is set. On any failure
        (including a missing preprocessing result) the error text is appended
        to ``state["errors"]``. The same state dict is returned in both cases.
        """
        try:
            logger.info("🔄 Agent 2: Classifying...")
            preprocessing = state["preprocessing_output"]
            # Intent is read before the issues so that a missing preprocessing
            # result fails on the same attribute as before.
            intent = preprocessing.detected_intent
            issues_json = json.dumps([issue.model_dump() for issue in preprocessing.core_issues])
            rendered_prompt = self.prompt.format(
                intent=intent,
                core_issues=issues_json,
                format_instructions=self.parser.get_format_instructions(),
            )
            llm_reply = self.llm.invoke(rendered_prompt)
            parsed_payload = self.parser.parse(llm_reply.content)
            result = ClassificationOutput(**parsed_payload)
            state["classification_output"] = result
            logger.info(f"✅ Primary: {result.primary_category}")
        except Exception as e:
            # Failures are recorded on the state rather than raised.
            logger.error(f"❌ Classification error: {e}")
            state["errors"].append(str(e))
        return state

# --- NEW: LOCAL RETRIEVAL AGENT (Persistent FAISS) ---
class LocalRetrievalAgent:
    """
    Retrieval over a pre-built, on-disk knowledge base (FAISS + optional graph).

    Expected layout directly under ``kb_folder``:
        faiss_index/        - FAISS vector store
        kb_metadata.json    - ingestion metadata (optional)
        knowledge_graph.db  - SQLite knowledge graph (optional)

    Loading is best-effort: missing or unreadable artifacts are logged and
    leave ``vector_store`` / ``graph_retriever`` as ``None`` instead of raising.
    """

    # Artifact names expected directly under ``kb_folder``.
    _FAISS_DIRNAME = "faiss_index"
    _METADATA_FILENAME = "kb_metadata.json"
    _GRAPH_FILENAME = "knowledge_graph.db"

    # Patterns used to spot graph entities in a query, compiled once.
    _FORM_RE = re.compile(r'GSTR-?(\d+[A-Z]?)', re.IGNORECASE)
    _SECTION_RE = re.compile(r'Section\s+\d+[A-Z]?', re.IGNORECASE)
    _NOTIFICATION_RE = re.compile(r'Notification\s+(?:No\.?\s*)?\d+/\d{4}', re.IGNORECASE)

    # Graph boost: added per related entity mentioned in a chunk, with a cap.
    _BOOST_PER_ENTITY = 0.05
    _MAX_BOOST = 0.2

    def __init__(self, embeddings=None, kb_folder: str = "./", enable_graph: bool = True):
        """
        Initialize with pre-built knowledge base

        Args:
            embeddings: Pre-initialized local embeddings (must match ingestion model);
                falls back to the module-level ``local_embeddings`` when omitted
            kb_folder: Path to processed knowledge base folder (contains faiss_index/, kb_metadata.json, etc.)
            enable_graph: Enable knowledge graph retrieval for entity relationships
        """
        self.embeddings = embeddings or local_embeddings
        self.kb_folder = Path(kb_folder)
        self.enable_graph = enable_graph
        self.vector_store = None
        self.kb_metadata = {}
        self.graph_retriever = None

        if not self.embeddings:
            # Without embeddings nothing can be loaded; the agent stays empty
            # and retrieve() will return no results.
            logger.error("❌ No local embeddings provided!")
            return

        # Load pre-built knowledge base
        self._load_knowledge_base()

        # Load knowledge graph if enabled
        if self.enable_graph:
            self._load_knowledge_graph()

    def _load_knowledge_base(self):
        """Load the FAISS index and optional metadata from ``kb_folder``.

        On any failure ``self.vector_store`` is left as ``None``.
        """
        faiss_path = self.kb_folder / self._FAISS_DIRNAME
        metadata_path = self.kb_folder / self._METADATA_FILENAME

        if not faiss_path.exists():
            logger.error(f"❌ No processed knowledge base found at {faiss_path}")
            logger.info("💡 Expected structure:")
            logger.info("   ./faiss_index/ - FAISS vector store")
            logger.info("   ./kb_metadata.json - Metadata")
            logger.info("   ./knowledge_graph.db - Graph database")
            logger.info("\n💡 Run your PDF ingestion pipeline first!")
            return

        try:
            logger.info(f"🔄 Loading knowledge base from {self.kb_folder}")

            # SECURITY: allow_dangerous_deserialization=True means the index's
            # docstore is unpickled. Only point kb_folder at index files you
            # produced yourself or otherwise trust.
            self.vector_store = FAISS.load_local(
                str(faiss_path),
                self.embeddings,
                allow_dangerous_deserialization=True
            )

            # Metadata is optional; explicit UTF-8 avoids depending on the
            # platform's default encoding.
            if metadata_path.exists():
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    self.kb_metadata = json.load(f)

            # Log stats
            total_docs = self.kb_metadata.get('total_files', self.kb_metadata.get('total_pdfs', 'Unknown'))
            total_chunks = self.kb_metadata.get('total_chunks', 'Unknown')
            embedding_model = self.kb_metadata.get('embedding_model', 'Unknown')

            logger.info("✅ Loaded knowledge base:")
            logger.info(f"   📄 Documents: {total_docs}")
            logger.info(f"   📊 Chunks: {total_chunks}")
            logger.info(f"   🧠 Model: {embedding_model}")
            logger.info(f"   💾 Size: {self._get_index_size():.2f} MB")

        except Exception as e:
            logger.error(f"❌ Failed to load knowledge base: {e}")
            logger.error(f"   Check that embedding model matches: {Config.LOCAL_EMBEDDING_MODEL}")
            self.vector_store = None

    def _load_knowledge_graph(self):
        """Load the knowledge graph for relationship-based retrieval.

        On a missing file or any failure ``self.graph_retriever`` stays ``None``
        and graph-enhanced retrieval is simply skipped.
        """
        graph_path = self.kb_folder / self._GRAPH_FILENAME

        if not graph_path.exists():
            logger.warning(f"⚠️ Knowledge graph not found at {graph_path}")
            logger.info("   Graph-enhanced retrieval disabled")
            return

        try:
            self.graph_retriever = LightweightKnowledgeGraph(db_path=str(graph_path))
            self.graph_retriever.load()
            logger.info("✅ Loaded knowledge graph:")
            logger.info(f"   🕸️ Nodes: {self.graph_retriever.graph.number_of_nodes()}")
            logger.info(f"   🔗 Edges: {self.graph_retriever.graph.number_of_edges()}")

        except Exception as e:
            logger.error(f"❌ Failed to load knowledge graph: {e}")
            self.graph_retriever = None

    def _get_index_size(self) -> float:
        """Return the combined size of the knowledge base artifacts in MB.

        Only the FAISS directory, the metadata file and the graph database
        are counted. Unrelated files that share ``kb_folder`` (which defaults
        to the current directory) are ignored. Returns 0.0 when none of the
        artifacts exist.
        """
        artifacts = (
            self.kb_folder / self._FAISS_DIRNAME,
            self.kb_folder / self._METADATA_FILENAME,
            self.kb_folder / self._GRAPH_FILENAME,
        )
        total_bytes = 0
        for artifact in artifacts:
            if artifact.is_dir():
                total_bytes += sum(f.stat().st_size for f in artifact.rglob('*') if f.is_file())
            elif artifact.is_file():
                total_bytes += artifact.stat().st_size
        return total_bytes / (1024 * 1024)

    def retrieve(self, query: str, k: int = 5, filter_category: Optional[str] = None,
                 use_graph: bool = True) -> "List[RetrievalSource]":
        """
        Retrieve relevant chunks using hybrid vector + graph search

        Args:
            query: Search query
            k: Number of results
            filter_category: Optional category filter, compared against each
                chunk's ``category`` metadata after the vector search
            use_graph: Enable graph-enhanced retrieval (if available)

        Returns:
            List of RetrievalSource objects, best first. Empty when the vector
            store is unavailable or the search fails (errors are logged).
        """
        results = []

        if not self.vector_store:
            logger.warning("⚠️ Vector store not available")
            return results

        graph_active = bool(use_graph and self.graph_retriever)

        try:
            # --- PHASE 1: VECTOR SEARCH ---
            # Over-fetch so that filtering and re-ranking still leave k results.
            fetch_k = k * 3 if graph_active else k * 2

            docs = self.vector_store.similarity_search_with_score(query, k=fetch_k)

            vector_results = []
            for doc, score in docs:
                # Apply category filter if specified
                if filter_category and doc.metadata.get('category') != filter_category:
                    continue

                # Build citation from metadata
                filename = doc.metadata.get('filename', doc.metadata.get('source', 'Unknown'))
                page_info = doc.metadata.get('page_num', doc.metadata.get('slide_num', ''))
                citation = f"{filename} (Page {page_info})" if page_info else filename

                # NOTE(review): assumes the FAISS score is a distance (smaller
                # is better); 1 - score can be negative for distances above 1.
                # Confirm against the index's distance metric.
                similarity = float(1 - score)

                vector_results.append({
                    'source': RetrievalSource(
                        source_type="local_kb",
                        content=doc.page_content,
                        citation=citation,
                        relevance_score=similarity,
                        date=None
                    ),
                    'metadata': doc.metadata,
                    'vector_score': similarity
                })

            # --- PHASE 2: GRAPH ENHANCEMENT (Optional) ---
            if graph_active:
                vector_results = self._apply_graph_boost(query, vector_results)

            # --- PHASE 3: RETURN TOP-K ---
            vector_results.sort(key=lambda x: x['vector_score'], reverse=True)
            results.extend(item['source'] for item in vector_results[:k])

            logger.info(f"   📖 Retrieved {len(results)} chunks from knowledge base")
            if graph_active:
                logger.info("      🕸️ Graph-enhanced retrieval active")

        except Exception as e:
            logger.error(f"❌ Retrieval error: {e}")

        return results

    @classmethod
    def _extract_query_entities(cls, query: str) -> set:
        """Return candidate graph node ids ("forms:...", "sections:...", ...) found in *query*."""
        entities = set()

        # GST forms: keep the text as typed and also add the canonical
        # "GSTR-<n>" spelling, so "gstr4" can match a "forms:GSTR-4" node.
        for match in cls._FORM_RE.finditer(query):
            entities.add(f"forms:{match.group(0)}")
            entities.add(f"forms:GSTR-{match.group(1).upper()}")

        # Sections and notifications are used exactly as they appear in the query.
        entities.update(f"sections:{s}" for s in cls._SECTION_RE.findall(query))
        entities.update(f"notifications:{n}" for n in cls._NOTIFICATION_RE.findall(query))
        return entities

    def _apply_graph_boost(self, query: str, vector_results: List[Dict]) -> List[Dict]:
        """
        Boost relevance scores using knowledge graph relationships

        Strategy:
        1. Extract entities from query
        2. Find related entities in graph
        3. Boost chunks that mention related entities

        The boosted score is written both to ``result['vector_score']`` (used
        for ranking) and to ``result['source'].relevance_score`` (what callers
        receive), so the two never disagree. Best-effort: on any error the
        results are returned as they are at that point.
        """
        try:
            query_entities = self._extract_query_entities(query)
            if not query_entities:
                return vector_results  # No boost if no entities found

            # Find related entities via graph
            graph = self.graph_retriever.graph
            related_entities = set()
            for entity in query_entities:
                if entity in graph:
                    related_entities.update(self.graph_retriever.find_related(entity, max_depth=2))

            # Entity label is the part after the prefix, e.g. "forms:GSTR-4" -> "gstr-4".
            # Computed once for all chunks; empty labels are dropped because
            # "" is a substring of every chunk.
            related_labels = [
                label
                for label in (str(e).split(":", 1)[-1].lower() for e in related_entities)
                if label
            ]

            for result in vector_results:
                content = result['source'].content.lower()
                hits = sum(1 for label in related_labels if label in content)
                boost = min(hits * self._BOOST_PER_ENTITY, self._MAX_BOOST)

                boosted = min(result['vector_score'] + boost, 1.0)
                result['vector_score'] = boosted
                result['source'].relevance_score = boosted

            logger.info(f"      🎯 Graph boost applied ({len(related_entities)} related entities)")

        except Exception as e:
            logger.warning(f"⚠️ Graph boost failed: {e}")

        return vector_results

    def get_stats(self) -> dict:
        """Return statistics about the knowledge base and, if loaded, the graph.

        Counts come from ``kb_metadata.json`` and default to 0 / 'Unknown'
        when the metadata is missing.
        """
        stats = {
            "document_count": self.kb_metadata.get('total_chunks', 0),
            "source_files": self.kb_metadata.get('total_files', self.kb_metadata.get('total_pdfs', 0)),
            "embedding_model": self.kb_metadata.get('embedding_model', 'Unknown'),
            "chunk_size": self.kb_metadata.get('chunk_size', 'Unknown'),
            "kb_folder": str(self.kb_folder),
            "index_size_mb": self._get_index_size(),
            "graph_enabled": self.graph_retriever is not None
        }

        if self.graph_retriever:
            stats.update({
                "graph_nodes": self.graph_retriever.graph.number_of_nodes(),
                "graph_edges": self.graph_retriever.graph.number_of_edges()
            })

        return stats

    def search_with_graph(self, entity: str, max_related: int = 10) -> List[str]:
        """
        Direct graph search: Find entities related to a given entity

        Args:
            entity: Entity to search (e.g., "forms:GSTR-4")
            max_related: Maximum number of related entities to return

        Returns:
            List of related entity IDs (empty if the graph is not loaded or
            the search fails)
        """
        if not self.graph_retriever:
            logger.warning("⚠️ Knowledge graph not loaded")
            return []

        try:
            # Forward the limit so requests above find_related's default cap
            # are honoured.
            related = self.graph_retriever.find_related(
                entity, max_depth=2, max_results=max_related
            )
            return related[:max_related]
        except Exception as e:
            logger.error(f"❌ Graph search error: {e}")
            return []


# --- NEW: WEB RETRIEVAL AGENT (Internet Search, Enhanced with LLM Query Optimization) ---
class WebRetrievalAgent:
    """
    Agent for web-based retrieval using internet search.

    Features:
    - Tavily API (preferred when the package and an API key are available)
    - DuckDuckGo (free fallback)
    - LLM-based query optimization for long queries
    - Automatic fallback to regex extraction

    Provider errors are caught and logged; the affected call returns an
    empty list instead of raising.
    """

    # Upper bound for optimized queries, kept under the 400-char limit this
    # module assumes for Tavily.
    _MAX_QUERY_CHARS = 350

    # Issue keywords looked for in the grievance text, in priority order.
    # All-uppercase entries are treated as acronyms (whole-word match).
    _IMPORTANT_KEYWORDS = (
        'late fee', 'penalty', 'notice', 'filing', 'refund',
        'portal error', 'ITC', 'composition', 'registration', 'mismatch',
    )

    # Search context prepended per category when domain filtering is enabled.
    _CATEGORY_CONTEXT = {
        "gstr_filing": "GST return filing",
        "penalty_notice": "GST penalty late fee notice",
        "refund": "GST refund process",
        "registration": "GST registration",
        "itc_mismatch": "GST ITC mismatch",
        "eway_bill": "GST e-way bill",
        "portal_error": "GST portal error",
    }

    def __init__(self, provider: str = "auto", unrestricted: bool = True, llm=None):
        """
        Initialize web retrieval agent

        Args:
            provider: "tavily", "duckduckgo", or "auto"
            unrestricted: True for general web search, False for domain filtering
            llm: Optional LLM for query optimization (reuses classifier_llm)
        """
        self.provider = None
        self.client = None
        self.unrestricted = unrestricted
        self.llm = llm

        if provider == "auto":
            self.provider = self._auto_detect_provider()
        else:
            self.provider = provider

        self._initialize_provider()

    def _auto_detect_provider(self) -> Optional[str]:
        """Pick the first usable provider, or None when neither is installed/configured."""
        if TAVILY_AVAILABLE and Config.TAVILY_API_KEY:
            logger.info("✅ Tavily API key found, using Tavily")
            return "tavily"

        if DUCKDUCKGO_AVAILABLE:
            logger.info("✅ Using DuckDuckGo (free, no API key)")
            return "duckduckgo"

        logger.warning("⚠️ No search provider available - install tavily-python or duckduckgo-search")
        return None

    def _initialize_provider(self):
        """Initialize the selected search provider"""
        if self.provider == "tavily":
            self._initialize_tavily()
        elif self.provider == "duckduckgo":
            self._initialize_duckduckgo()
        elif self.provider is not None:
            # An unrecognised name used to be ignored silently, leaving the
            # agent disabled with no hint why.
            logger.warning(f"⚠️ Unknown search provider '{self.provider}' - web search disabled")

    def _initialize_tavily(self):
        """Initialize Tavily client"""
        try:
            self.client = TavilyClient(api_key=Config.TAVILY_API_KEY)
            logger.info("✅ Tavily client initialized")
            if self.unrestricted:
                logger.info("🌐 Unrestricted web search enabled (no domain filtering)")
        except Exception as e:
            logger.error(f"❌ Tavily initialization failed: {e}")
            logger.info("💡 Get API key at: https://app.tavily.com/")
            self.client = None

    def _initialize_duckduckgo(self):
        """Initialize DuckDuckGo search"""
        try:
            wrapper = DuckDuckGoSearchAPIWrapper(
                region="en-in",
                time="y",
                max_results=15,
                safesearch="moderate"
            )
            self.client = DuckDuckGoSearchRun(api_wrapper=wrapper)
            logger.info("✅ DuckDuckGo search initialized")
            if self.unrestricted:
                logger.info("🌐 Unrestricted web search enabled (general internet search)")
        except Exception as e:
            logger.error(f"❌ DuckDuckGo initialization failed: {e}")
            logger.info("💡 Install with: pip install duckduckgo-search")
            self.client = None

    # --- QUERY OPTIMIZATION METHODS ---

    @staticmethod
    def _truncate_at_word(text: str, max_length: int) -> str:
        """Collapse runs of whitespace and cut `text` to `max_length` chars on a word boundary."""
        collapsed = " ".join(text.split())
        if len(collapsed) <= max_length:
            return collapsed
        return collapsed[:max_length].rsplit(' ', 1)[0]

    def _extract_key_terms(self, text: str, max_terms: int = 5) -> List[str]:
        """
        Extract key terms from text using regex (fallback method)

        Priority:
        1. GST form numbers (GSTR-4, GSTR-3A, etc.) - up to 3
        2. Notification numbers - up to 2
        3. Important keywords
        4. Financial years - up to 1

        Terms are de-duplicated case-insensitively so a form mentioned
        several times occupies a single slot.

        Args:
            text: Free-form grievance text
            max_terms: Maximum number of terms to return

        Returns:
            Ordered list of at most `max_terms` distinct terms
        """
        if not text:
            return []

        key_terms: List[str] = []
        seen = set()

        def add_unique(candidates, limit):
            """Append up to `limit` not-yet-seen candidates, preserving order."""
            added = 0
            for term in candidates:
                if added >= limit:
                    break
                normalized = " ".join(term.split()).casefold()
                if normalized in seen:
                    continue
                seen.add(normalized)
                key_terms.append(term)
                added += 1

        # 1. GST Forms (highest priority)
        add_unique(re.findall(r'GSTR-?\d+[A-Z]?', text, re.IGNORECASE), 3)

        # 2. Notification numbers
        add_unique(
            re.findall(r'Notification\s+(?:No\.?\s*)?\d+/\d{4}', text, re.IGNORECASE), 2
        )

        # 3. Important keywords
        text_lower = text.lower()
        for keyword in self._IMPORTANT_KEYWORDS:
            if len(key_terms) >= max_terms:
                break
            if keyword.isupper():
                # Acronyms: case-insensitive whole-word match, so 'ITC' is
                # found in any casing but not inside words such as 'switch'.
                matched = re.search(
                    rf'\b{re.escape(keyword)}\b', text, re.IGNORECASE
                ) is not None
            else:
                matched = keyword in text_lower
            if matched:
                add_unique([keyword], 1)

        # 4. Financial years
        if len(key_terms) < max_terms:
            add_unique(re.findall(r'FY\s*\d{4}-?\d{2,4}', text, re.IGNORECASE), 1)

        return key_terms[:max_terms]

    def _build_focused_query_with_llm(self, query: str, category: str = None,
                                       keywords: Optional[List[str]] = None) -> str:
        """
        Use classifier LLM to extract focused search query from long text
        Reuses existing classifier_llm - no additional model initialization

        Falls back to regex extraction (forwarding `keywords`) when no LLM is
        configured, the LLM call fails, or the LLM returns a blank answer.

        Args:
            query: Original, possibly long, grievance text
            category: GST category, forwarded to the regex fallback
            keywords: Caller-supplied keywords, forwarded to the regex fallback

        Returns:
            Search query of at most `_MAX_QUERY_CHARS` characters
        """
        if not self.llm:
            logger.info("   ℹ️ No LLM provided, using regex extraction")
            return self._build_focused_query_regex(query, category, keywords)

        try:
            prompt = f"""Extract a concise web search query (max 30 words) from this GST grievance:

Query: {query[:800]}

Focus on:
- GST form numbers (GSTR-4, GSTR-3A, etc.)
- Key issues (late fee, filing, penalty, notice, error)
- Notification/section numbers
- Important entities only

Return ONLY the search query, nothing else. No explanations."""

            response = self.llm.invoke(prompt)
            # Drop surrounding quotes and stray newlines models tend to add.
            raw_answer = (response.content or "").strip().strip('"\'')
            focused_query = self._truncate_at_word(raw_answer, self._MAX_QUERY_CHARS)

            if not focused_query:
                # A blank query would only produce a provider error.
                logger.warning("⚠️ LLM returned an empty query, using regex fallback")
                return self._build_focused_query_regex(query, category, keywords)

            logger.info(f"   🤖 LLM-optimized query ({len(focused_query)} chars): {focused_query[:80]}...")
            return focused_query

        except Exception as e:
            logger.warning(f"⚠️ LLM query generation failed: {e}, using regex fallback")
            return self._build_focused_query_regex(query, category, keywords)

    def _build_focused_query_regex(self, query: str, category: str = None,
                                    keywords: Optional[List[str]] = None,
                                    max_length: int = _MAX_QUERY_CHARS) -> str:
        """
        Build a focused query using regex extraction (fallback method)

        Uses the first five `keywords` when given, otherwise terms extracted
        from `query`. When neither yields anything, the raw query itself is
        used so the provider never receives an empty search string.

        Args:
            query: Original grievance text
            category: GST category; adds a context phrase when domain
                filtering is enabled
            keywords: Optional caller-supplied keywords
            max_length: Maximum length of the returned query

        Returns:
            Query truncated on a word boundary to `max_length` characters
        """
        # Extract key terms
        if keywords:
            query_parts = list(keywords[:5])
        else:
            query_parts = self._extract_key_terms(query, max_terms=5)

        # Add category context if available
        if category and not self.unrestricted:
            query_parts.insert(0, self._CATEGORY_CONTEXT.get(category, "GST"))

        # Build query
        focused_query = " ".join(str(part) for part in query_parts)

        if not focused_query.strip():
            # Nothing recognisable was extracted; search the raw text instead.
            focused_query = query or ""

        # Truncate if still too long
        focused_query = self._truncate_at_word(focused_query, max_length)

        logger.info(f"   🔍 Regex-extracted query ({len(focused_query)} chars): {focused_query[:80]}...")
        return focused_query

    def _build_focused_query_tavily(self, query: str, category: str = None,
                                     keywords: Optional[List[str]] = None) -> str:
        """
        Build query for Tavily with 400-char limit
        Strategy: Use LLM if available, else fall back to regex
        """
        # Try LLM-based extraction first (preferred)
        if self.llm:
            return self._build_focused_query_with_llm(query, category, keywords)

        # Fall back to regex extraction
        return self._build_focused_query_regex(query, category, keywords)

    def _build_query_duckduckgo(self, query: str, category: str = None,
                                keywords: Optional[List[str]] = None) -> str:
        """Build query for DuckDuckGo (more lenient than Tavily)"""
        if self.unrestricted:
            # General web search - truncate to reasonable length
            enhanced_query = query[:500]
            if keywords:
                enhanced_query += " " + " ".join(keywords[:3])
        else:
            # Restricted to official sites
            enhanced_query = f"{query[:300]} GST India site:gst.gov.in OR site:cbic-gst.gov.in"

        return enhanced_query

    # --- RETRIEVAL METHODS ---

    def retrieve_tavily(self, query: str, max_results: int = 10) -> "List[RetrievalSource]":
        """
        Retrieve using Tavily API

        Args:
            query: Optimized search query (already under 400 chars)
            max_results: Number of results (default 10)

        Returns:
            Web results, preceded by Tavily's direct answer when one is
            returned; empty list on error.
        """
        results = []

        try:
            search_params = {
                "query": query,
                "search_depth": "advanced",
                "max_results": max_results,
                "include_answer": True,
                "include_raw_content": False
            }

            # Only add domain restrictions if NOT unrestricted
            if not self.unrestricted:
                search_params["include_domains"] = [
                    "gst.gov.in",
                    "cbic-gst.gov.in",
                    "tutorial.gst.gov.in",
                    "cleartax.in",
                    "taxguru.in"
                ]
                search_params["exclude_domains"] = ["quora.com", "reddit.com"]

            response = self.client.search(**search_params)

            # Process results
            for result in response.get('results', []):
                score = result.get('score')
                results.append(RetrievalSource(
                    source_type="web_search",
                    content=f"**{result.get('title', 'Untitled')}**\n\n{result.get('content', '')}",
                    citation=result.get('url', 'Unknown source'),
                    # Default also covers an explicit null score.
                    relevance_score=0.8 if score is None else score,
                    date=None
                ))

            # Add direct answer if available
            if response.get('answer'):
                results.insert(0, RetrievalSource(
                    source_type="web_answer",
                    content=f"**AI-Generated Summary:**\n\n{response['answer']}",
                    citation="Tavily AI Direct Answer",
                    relevance_score=1.0,
                    date=None
                ))

            logger.info(f"   🌐 Retrieved {len(results)} results from Tavily")

        except Exception as e:
            logger.error(f"❌ Tavily search error: {e}")

        return results

    def retrieve_duckduckgo(self, query: str, max_results: int = 10) -> "List[RetrievalSource]":
        """
        Retrieve using DuckDuckGo

        Args:
            query: Search query
            max_results: Number of results (default 10)

        Returns:
            One source per blank-line-separated chunk of the tool output,
            scored by rank; empty list on error.
        """
        results = []

        try:
            # Execute search
            search_results = self.client.run(query)

            # Parse results
            # NOTE(review): assumes the tool separates hits with blank lines
            # and wraps URLs in [brackets] - confirm against the installed
            # langchain_community version.
            if isinstance(search_results, str):
                chunks = [chunk.strip() for chunk in search_results.split('\n\n') if chunk.strip()]

                for idx, chunk in enumerate(chunks[:max_results]):
                    # Extract URL if present
                    url_match = re.search(r'\[(https?://[^\]]+)\]', chunk)
                    url = url_match.group(1) if url_match else "DuckDuckGo Search Result"

                    # Clean content
                    content = re.sub(r'\[https?://[^\]]+\]', '', chunk).strip()

                    if content:
                        results.append(RetrievalSource(
                            source_type="web_search",
                            content=content,
                            citation=url,
                            # Rank-based decay, floored so deep ranks never
                            # produce a negative score.
                            relevance_score=max(0.0, 0.8 - (idx * 0.05)),
                            date=None
                        ))

            logger.info(f"   🌐 Retrieved {len(results)} results from DuckDuckGo")

        except Exception as e:
            logger.error(f"❌ DuckDuckGo search error: {e}")

        return results

    def retrieve(self, query: str, category: str = None,
                 keywords: Optional[List[str]] = None,
                 max_results: int = 10) -> "List[RetrievalSource]":
        """
        Main retrieval method with intelligent query optimization

        Args:
            query: Search query (may be long)
            category: GST category (optional)
            keywords: Additional keywords (optional)
            max_results: Number of results (default 10)

        Returns:
            List of RetrievalSource objects (empty when no provider is
            available or the query is blank)
        """
        if not self.client:
            logger.warning("⚠️ No search provider available")
            return []

        if not query or not query.strip():
            logger.warning("⚠️ Empty search query - skipping web search")
            return []

        if self.provider == "tavily":
            # Tavily has 400-char limit - use intelligent optimization
            enhanced_query = self._build_focused_query_tavily(query, category, keywords)
            return self.retrieve_tavily(enhanced_query, max_results)

        if self.provider == "duckduckgo":
            # DuckDuckGo is more lenient
            enhanced_query = self._build_query_duckduckgo(query, category, keywords)
            return self.retrieve_duckduckgo(enhanced_query, max_results)

        return []

# --- TWITTER RETRIEVAL AGENT ---
class TwitterRetrievalAgent:
    """
    Retrieves recent tweets from the account configured in
    ``Config.GSTN_TWITTER`` that mention any of the supplied keywords.

    The agent is inert (``client`` stays None) when no bearer token is
    configured, tweepy is unavailable, or client construction fails.
    """

    # Bounds the recent-search endpoint accepts for its max_results parameter.
    _API_MIN_RESULTS = 10
    _API_MAX_RESULTS = 100

    # Query length cap of the recent-search endpoint at the standard access
    # tier. NOTE(review): higher tiers allow longer queries - confirm.
    _MAX_QUERY_CHARS = 512

    def __init__(self):
        """Create the tweepy client when a bearer token is configured."""
        self.bearer_token = Config.TWITTER_BEARER_TOKEN
        self.client = None
        if self.bearer_token and tweepy:
            try:
                self.client = tweepy.Client(bearer_token=self.bearer_token)
            except Exception as e:
                # Keep the agent disabled, but leave a trace of why.
                logger.warning(f"⚠️ Twitter client initialization failed: {e}")

    def retrieve(self, keywords: List[str], max_results: int = 10) -> List[RetrievalSource]:
        """
        Search recent tweets from the configured account.

        Args:
            keywords: Terms combined with OR; blank entries are ignored
            max_results: Maximum number of tweets to return

        Returns:
            List of RetrievalSource objects (empty when the client is not
            configured, no usable keyword is given, or the API call fails)
        """
        results = []
        if not self.client:
            logger.warning("⚠️ Twitter API not configured")
            return results

        terms = [kw.strip() for kw in (keywords or []) if kw and kw.strip()]
        if not terms or max_results <= 0:
            # Without terms the query would be the invalid "from:x ()".
            logger.info("   📱 No keywords or results requested - skipping Twitter search")
            return results

        try:
            handle = Config.GSTN_TWITTER.replace('@', '')
            prefix = f"from:{handle} ("

            # Add terms in order until the query would exceed the length cap.
            selected = []
            for term in terms:
                candidate = " OR ".join(selected + [term])
                if len(prefix) + len(candidate) + 1 > self._MAX_QUERY_CHARS:
                    break
                selected.append(term)
            if not selected:
                logger.warning("⚠️ Twitter keywords too long for a valid query")
                return results

            query = f"{prefix}{' OR '.join(selected)})"

            # The endpoint rejects page sizes outside 10..100; request a
            # valid size and trim to the caller's limit afterwards.
            page_size = max(self._API_MIN_RESULTS, min(max_results, self._API_MAX_RESULTS))
            tweets = self.client.search_recent_tweets(
                query=query, max_results=page_size,
                tweet_fields=["created_at", "text"]
            )
            if tweets.data:
                for tweet in tweets.data[:max_results]:
                    posted_on = tweet.created_at.strftime('%Y-%m-%d')
                    results.append(RetrievalSource(
                        source_type="twitter",
                        content=tweet.text,
                        # Cite the account actually queried, not a fixed handle.
                        citation=f"@{handle} tweet from {posted_on}",
                        relevance_score=0.8,
                        date=posted_on
                    ))
            logger.info(f"   📱 Retrieved {len(results)} tweets")
        except Exception as e:
            logger.error(f"❌ Twitter error: {e}")
        return results

# --- NEW: SUB-AGENT 3.4: LLM REASONING RETRIEVAL AGENT ---
class LLMReasoningAgent:
    """
    Retrieval agent that asks a DeepSeek chat model for a legal analysis of
    each core issue and returns the answers as retrieval sources.
    """

    def __init__(self):
        """Read the DeepSeek settings from Config and build the client when a key is set."""
        self.api_key = Config.DEEPSEEK_API_KEY
        self.model = Config.DEEPSEEK_MODEL
        self.client = None

        if not self.api_key:
            logger.warning("⚠️ DEEPSEEK_API_KEY not found. LLM reasoning disabled.")
            return

        try:
            self.client = OpenAI(api_key=self.api_key, base_url="https://api.deepseek.com")
            logger.info("✅ DeepSeek LLM client initialized")
        except Exception as e:
            logger.error(f"❌ DeepSeek initialization failed: {e}")
            self.client = None

    def retrieve(self, core_issues: List[CoreIssue], entities: List[ExtractedEntity]) -> List[RetrievalSource]:
        """
        Ask DeepSeek for an analysis of every core issue.

        Args:
            core_issues: Core issues produced by preprocessing
            entities: Extracted entities, rendered into a shared context block

        Returns:
            One RetrievalSource per issue processed. A failed API call yields
            a placeholder source with relevance 0.0 describing the error.
        """
        results = []

        if not self.client:
            logger.warning("⚠️ DeepSeek client not available. Skipping LLM reasoning.")
            return results

        try:
            # The entity context is identical for all issues, so build it once.
            entity_context = self._build_entity_context(entities)

            for position, issue in enumerate(core_issues, start=1):
                logger.info(f"   🤖 Querying DeepSeek for issue {position}/{len(core_issues)}...")

                prompt = self._build_issue_prompt(issue, entity_context)

                try:
                    conversation = [
                        {"role": "system", "content": self._get_system_prompt()},
                        {"role": "user", "content": prompt},
                    ]
                    response = self.client.chat.completions.create(
                        model=self.model,
                        messages=conversation,
                        temperature=0.3,  # Balanced between creativity and consistency
                        max_tokens=1000,
                        stream=False,
                    )
                    reasoning = response.choices[0].message.content

                    source = RetrievalSource(
                        source_type="llm_reasoning",
                        content=f"**Issue {position}: {issue.issue_text}**\n\n{reasoning}",
                        citation=f"DeepSeek Legal Analysis (Issue {position})",
                        relevance_score=0.95,  # High relevance as it's targeted analysis
                        date=None,
                    )
                    results.append(source)

                except Exception as e:
                    logger.error(f"❌ DeepSeek API error for issue {position}: {e}")
                    # Keep a zero-relevance placeholder so the issue stays visible.
                    placeholder = RetrievalSource(
                        source_type="llm_reasoning",
                        content=f"**Issue {position}: {issue.issue_text}**\n\nUnable to retrieve reasoning: {str(e)}",
                        citation="DeepSeek Analysis (Error)",
                        relevance_score=0.0,
                        date=None,
                    )
                    results.append(placeholder)

            logger.info(f"   🤖 Retrieved {len(results)} LLM reasoning results from DeepSeek")

        except Exception as e:
            logger.error(f"❌ LLM reasoning retrieval error: {e}")

        return results

    def _get_system_prompt(self) -> str:
        """Return the fixed system prompt sent with every DeepSeek request."""
        system_prompt = """You are an expert GST (Goods and Services Tax) legal advisor for India with deep knowledge of:
- Central GST Act, 2017 and State GST Acts
- GST Rules and Regulations
- CBIC (Central Board of Indirect Taxes and Customs) notifications, circulars, and orders
- GST Council decisions and press releases
- GSTN (Goods and Services Tax Network) guidelines
- Case laws and tribunal decisions related to GST

Your role is to provide accurate, legally sound analysis of GST grievances based on:
1. Relevant sections of CGST/SGST/IGST Acts
2. Applicable rules and notifications
3. Official circulars and clarifications
4. Precedents from case law (if relevant)
5. If user quotes a statute or notificaton, do check if the same has been amended in future.

Provide clear, actionable guidance while citing specific legal provisions. If unsure about any aspect, explicitly state the uncertainty and recommend verification from official sources or professional consultation."""
        return system_prompt

    def _build_entity_context(self, entities: List[ExtractedEntity]) -> str:
        """Render entities as a bullet list grouped by entity type, in first-seen order."""
        if not entities:
            return "No specific entities identified."

        grouped = {}
        for item in entities:
            grouped.setdefault(item.entity_type, []).append(item.value)

        bullet_lines = [
            f"- {kind}: {', '.join(members)}" for kind, members in grouped.items()
        ]
        return "**Context:**\n" + "\n".join(bullet_lines)

    def _build_issue_prompt(self, issue: CoreIssue, entity_context: str) -> str:
        """Compose the user prompt for one issue from its text, keywords and the entity context."""
        issue_text = issue.issue_text
        keyword_line = ', '.join(issue.keywords)
        return f"""Analyze the following GST grievance issue and provide expert legal guidance.

**Issue:**
{issue_text}

**Keywords:** {keyword_line}

{entity_context}

**Instructions:**
1. Identify applicable GST provisions (Acts, Rules, Notifications, Circulars)
2. Provide legal interpretation based on current GST law
3. Explain the correct procedure or resolution
4. Cite specific sections, rules, or notification numbers where relevant
5. Highlight any recent updates or clarifications from CBIC/GSTN
6. If the issue involves misconceptions, clarify them with legal basis
7. Keep the response concise (max 300 words) but comprehensive

**Response Format:**
- Legal Basis: [Cite specific provisions]
- Analysis: [Your legal interpretation]
- Resolution: [Clear guidance on what the taxpayer should do]
- Important Notes: [Any warnings, deadlines, or critical points]"""

# --- RETRIEVAL ORCHESTRATOR ---
class RetrievalOrchestrator:
    """Orchestrates retrieval from all sources (4 agents)"""

    def __init__(self):
        """Construct the four retrieval agents and log a summary of what is available."""
        self.twitter_agent = TwitterRetrievalAgent()

        self.local_agent = LocalRetrievalAgent(
            embeddings=local_embeddings,
            kb_folder="./",
            enable_graph=True
        )

        # Pass classifier_llm for query optimization
        self.web_agent = WebRetrievalAgent(
            provider="auto",
            unrestricted=Config.WEB_SEARCH_UNRESTRICTED,
            llm=classifier_llm  # Reuse classifier LLM
        )

        self.llm_agent = LLMReasoningAgent()

        logger.info("✅ Retrieval orchestrator initialized")
        stats = self.local_agent.get_stats()
        logger.info(f"   📊 Local KB:")
        logger.info(f"      Files: {stats['source_files']}")
        logger.info(f"      Chunks: {stats['document_count']}")
        logger.info(f"      Size: {stats['index_size_mb']:.2f} MB")
        if stats['graph_enabled']:
            # .get: the node/edge counts are only present when the retriever
            # is truthy, which is a stricter test than 'graph_enabled'.
            logger.info(f"      Graph: {stats.get('graph_nodes', 0)} nodes, {stats.get('graph_edges', 0)} edges")
        logger.info(f"   🌐 Web search: {self.web_agent.provider} (LLM optimization: {'enabled' if self.web_agent.llm else 'disabled'})")
        logger.info(f"   🤖 LLM reasoning: {'enabled' if self.llm_agent.client else 'disabled'}")

    def _collect(self, label: str, fetch, state) -> list:
        """Run one source's fetch callable; on failure log it, record it in state and return []."""
        try:
            return fetch() or []
        except Exception as e:
            logger.error(f"❌ {label} retrieval failed: {e}")
            state.setdefault("errors", []).append(f"{label} retrieval error: {str(e)}")
            return []

    def process(self, state: AgentState) -> AgentState:
        """
        Query every retrieval source for the preprocessed issues and store
        the aggregated RetrievalOutput under ``state["retrieval_output"]``.

        A source that raises contributes an empty list and an entry in
        ``state["errors"]``; the remaining sources are still used. If the
        inputs themselves are unusable, the error is recorded and the state
        is returned without a retrieval output.

        Args:
            state: Workflow state holding ``preprocessing_output`` and
                ``classification_output``

        Returns:
            The same state object, updated in place
        """
        try:
            logger.info("🔄 Agent 3: Retrieving from multiple sources (4 agents)...")
            start_time = time.time()

            preprocessing = state["preprocessing_output"]
            classification = state["classification_output"]

            # De-duplicate keywords but keep first-seen order. A set has no
            # stable order, so the [:5] slice below picked arbitrary keywords.
            all_keywords = list(dict.fromkeys(
                kw for issue in preprocessing.core_issues for kw in issue.keywords
            ))

            # Build combined query
            combined_query = " ".join([issue.issue_text for issue in preprocessing.core_issues])

            # 1. Twitter (optional, real-time)
            twitter_results = self._collect(
                "Twitter",
                lambda: self.twitter_agent.retrieve(all_keywords, Config.MAX_TWITTER_RESULTS),
                state,
            )

            # 2. Local KB (FAISS vector store)
            local_results = self._collect(
                "Local KB",
                lambda: self.local_agent.retrieve(
                    query=combined_query,
                    k=Config.MAX_LOCAL_RESULTS,
                    filter_category=None,
                    use_graph=True  # Enable graph-enhanced retrieval
                ),
                state,
            )

            # 3. Web search
            web_results = self._collect(
                "Web",
                lambda: self.web_agent.retrieve(
                    query=combined_query,
                    category=classification.primary_category if not Config.WEB_SEARCH_UNRESTRICTED else None,
                    keywords=all_keywords[:5],
                    max_results=Config.MAX_WEB_RESULTS
                ),
                state,
            )

            # 4. LLM Reasoning (DeepSeek)
            llm_reasoning = self._collect(
                "LLM reasoning",
                lambda: self.llm_agent.retrieve(
                    core_issues=preprocessing.core_issues,
                    entities=preprocessing.entities
                ),
                state,
            )

            # Aggregate results
            retrieval_output = RetrievalOutput(
                twitter_results=twitter_results,
                local_results=local_results,
                web_results=web_results,
                llm_reasoning=llm_reasoning,
                total_sources=len(twitter_results) + len(local_results) + len(web_results) + len(llm_reasoning),
                retrieval_time=time.time() - start_time
            )

            state["retrieval_output"] = retrieval_output

            logger.info(f"✅ Retrieved {retrieval_output.total_sources} sources")
            logger.info(f"   📱 Twitter: {len(twitter_results)}")
            logger.info(f"   📖 Local KB: {len(local_results)}")
            logger.info(f"   🌐 Web: {len(web_results)}")
            logger.info(f"   🤖 LLM Reasoning: {len(llm_reasoning)}")

            return state

        except Exception as e:
            logger.error(f"❌ Retrieval error: {e}")
            state.setdefault("errors", []).append(f"Retrieval error: {str(e)}")
            return state

# --- RESOLVER AND RESPONSE AGENTS ---
class ResolverAgent:
    """Asks the resolver LLM to turn the retrieved sources into per-issue resolutions."""

    def __init__(self, llm):
        """
        Args:
            llm: Chat model exposing ``invoke(prompt)`` and returning a
                message with a ``content`` attribute
        """
        self.llm = llm
        self.parser = JsonOutputParser(pydantic_object=ResolverOutput)

    def process(self, state: AgentState) -> AgentState:
        """
        Build the resolution prompt, call the LLM and store the parsed
        ResolverOutput under ``state["resolver_output"]``.

        On any failure (missing inputs, LLM error, unparsable answer) the
        error is logged and recorded in ``state["errors"]``, and an empty
        ResolverOutput flagged for escalation is stored instead.

        Args:
            state: Workflow state holding ``preprocessing_output`` and
                ``retrieval_output``

        Returns:
            The same state object, updated in place
        """
        try:
            logger.info("🔄 Agent 4: Resolving with Gemini 2.5 Flash...")
            preprocessing = state["preprocessing_output"]
            retrieval = state["retrieval_output"]

            # UPDATED: Include LLM reasoning in prompt
            prompt = f"""
Resolve these GST issues using provided sources:

Core Issues: {json.dumps([i.model_dump() for i in preprocessing.core_issues])}

**Retrieved Sources:**

Twitter Updates: {[s.content for s in retrieval.twitter_results[:3]]}

Local Knowledge Base: {[s.content for s in retrieval.local_results[:5]]}

Web Search Results: {[s.content for s in retrieval.web_results[:5]]}

DeepSeek Legal Reasoning: {[s.content for s in retrieval.llm_reasoning]}

**Instructions:**
- Provide accurate, actionable resolutions for each issue
- Cross-validate information across all sources
- Give priority to LLM reasoning for legal interpretation
- Cite sources appropriately
- Min confidence: 95. If uncertain, set resolution to null with reason
- Return JSON matching ResolverOutput schema

{self.parser.get_format_instructions()}
"""

            response = self.llm.invoke(prompt)
            # Parse through the JSON output parser rather than json.loads:
            # it also accepts answers wrapped in a Markdown code fence, which
            # raw json.loads rejects.
            parsed = self.parser.parse(response.content)
            resolver_output = ResolverOutput(**parsed)
            state["resolver_output"] = resolver_output
            logger.info(f"✅ Confidence: {resolver_output.overall_confidence}%")
            return state
        except Exception as e:
            logger.error(f"❌ Resolver error: {e}")
            # Record the failure like the retrieval step does, so callers can
            # tell a failed resolution from a genuinely low-confidence one.
            state.setdefault("errors", []).append(f"Resolver error: {str(e)}")
            state["resolver_output"] = ResolverOutput(
                resolutions=[], overall_confidence=0, requires_escalation=True
            )
            return state


class ResponseGenerationAgent:
    """Formats the resolver output and recent tweets into the user-facing FinalResponse."""

    def process(self, state: AgentState) -> AgentState:
        """
        Build ``state["final_response"]`` from the resolver and retrieval outputs.

        On failure the error is logged and recorded in ``state["errors"]``
        and ``state["escalation_requested"]`` is set, so the workflow routes
        to escalation instead of completing without a response.

        Args:
            state: Workflow state holding ``resolver_output`` and
                ``retrieval_output``

        Returns:
            The same state object, updated in place
        """
        try:
            logger.info("🔄 Agent 5: Generating response...")
            resolver = state["resolver_output"]
            retrieval = state["retrieval_output"]

            parts = []
            for res in resolver.resolutions:
                if res.resolution:
                    parts.append(f"**{res.issue}**\n{res.resolution}")
                else:
                    # Avoid printing the literal "None" when no reason was given.
                    reason = res.reason_for_null or "No further details available."
                    parts.append(f"**{res.issue}**\n⚠️ Unable to answer. {reason}")

            direct_answer = "\n\n".join(parts) if parts else "Manual review recommended."

            recent_updates = None
            if retrieval.twitter_results:
                recent_updates = "Recent GSTN Updates:\n" + "\n".join([
                    f"- {s.content} ({s.citation})" for s in retrieval.twitter_results[:3]
                ])

            final_response = FinalResponse(
                direct_answer=direct_answer,
                detailed_explanation=None,
                legal_basis=None,
                recent_updates=recent_updates,
                additional_resources=[
                    "GST Portal: https://www.gst.gov.in",
                    "GSTN Tutorials: https://tutorial.gst.gov.in",
                    "CBIC GST: https://cbic-gst.gov.in"
                ],
                confidence_score=resolver.overall_confidence,
                requires_manual_review=resolver.requires_escalation
            )

            state["final_response"] = final_response
            return state
        except Exception as e:
            logger.error(f"❌ Response generation error: {e}")
            state.setdefault("errors", []).append(f"Response generation error: {str(e)}")
            # With no final response, the router would otherwise report
            # "complete"; request escalation explicitly.
            state["escalation_requested"] = True
            return state

# --- WORKFLOW ---
def should_escalate(state: AgentState) -> str:
    """
    Routing function for the step after response generation.

    Returns "escalate" when escalation was requested in the state or the
    final response asks for manual review, otherwise "complete".
    """
    if state.get("escalation_requested"):
        return "escalate"

    # A missing final response cannot request manual review.
    if state.get("final_response") and state["final_response"].requires_manual_review:
        return "escalate"

    return "complete"

def handle_escalation(state: AgentState) -> AgentState:
    """
    Escalation node: prefix the final answer with an escalation notice.

    The state is returned unchanged when there is no final response.
    """
    logger.info("🔄 Agent 6: Escalating...")
    response = state["final_response"]
    if not response:
        return state

    notice = "Escalated to GST expert team.\n\n"
    response.direct_answer = notice + response.direct_answer
    return state

def create_workflow():
    """
    Assemble and compile the grievance-resolution graph.

    The pipeline runs preprocessing -> classification -> retrieval ->
    resolver -> response_generation, then either ends or passes through the
    escalation node, as decided by ``should_escalate``.

    Returns:
        The compiled StateGraph
    """
    # Dict literal entries are evaluated top to bottom, so the agents are
    # constructed in pipeline order.
    handlers = {
        "preprocessing": PreprocessingAgent(preprocessor_llm).process,
        "classification": ClassificationAgent(classifier_llm).process,
        "retrieval": RetrievalOrchestrator().process,
        "resolver": ResolverAgent(resolver_llm).process,
        "response_generation": ResponseGenerationAgent().process,
        "escalation": handle_escalation,
    }

    graph = StateGraph(AgentState)
    for node_name, handler in handlers.items():
        graph.add_node(node_name, handler)

    # Linear part of the pipeline: each stage feeds the next.
    pipeline = ["preprocessing", "classification", "retrieval", "resolver", "response_generation"]
    graph.set_entry_point(pipeline[0])
    for upstream, downstream in zip(pipeline, pipeline[1:]):
        graph.add_edge(upstream, downstream)

    # Branch after response generation, then terminate.
    routes = {"escalate": "escalation", "complete": END}
    graph.add_conditional_edges("response_generation", should_escalate, routes)
    graph.add_edge("escalation", END)

    return graph.compile()

# --- MAIN FUNCTION ---
def process_gst_grievance(query: str, session_id: str = None) -> Dict[str, Any]:
    """Run one query through the workflow and summarise the final state.

    A new workflow is built with ``create_workflow()`` on every call and
    invoked with a freshly initialised state.

    Args:
        query: The user's grievance text.
        session_id: Optional session identifier; when falsy, one is derived
            from the current Unix time as ``session_<seconds>``.

    Returns:
        A dict with the keys "session_id", "query", "response",
        "preprocessing", "classification", "retrieval_stats",
        "resolution_stats", "processing_time" and "errors". The three
        model entries are ``model_dump()`` dicts, or None when the
        corresponding state entry is falsy.
    """
    seed_state = AgentState(
        user_query=query,
        session_id=session_id or f"session_{int(time.time())}",
        conversation_history=[],
        preprocessing_output=None,
        classification_output=None,
        retrieval_output=None,
        resolver_output=None,
        final_response=None,
        timestamp=datetime.now().isoformat(),
        processing_time=0.0,
        errors=[],
        feedback_received=None,
        iteration_count=0,
        escalation_requested=False,
    )

    pipeline = create_workflow()
    rule = "=" * 80
    logger.info(rule)
    logger.info("🚀 Starting GST Grievance Resolution")
    logger.info(f"📝 Query: {query}")
    logger.info(rule)

    outcome = pipeline.invoke(seed_state)

    # Serialise the per-stage models; a stage that produced nothing maps to None.
    final_response = outcome["final_response"]
    response_dump = final_response.model_dump() if final_response else None

    preprocessing = outcome["preprocessing_output"]
    preprocessing_dump = preprocessing.model_dump() if preprocessing else None

    classification = outcome["classification_output"]
    classification_dump = classification.model_dump() if classification else None

    # Retrieval counters default to zero when retrieval produced no output.
    retrieval = outcome["retrieval_output"]
    if retrieval:
        retrieval_stats = {
            "total_sources": retrieval.total_sources,
            "retrieval_time": retrieval.retrieval_time,
            "local_count": len(retrieval.local_results),
            "web_count": len(retrieval.web_results),
            "twitter_count": len(retrieval.twitter_results),
            "llm_count": len(retrieval.llm_reasoning),
        }
    else:
        retrieval_stats = {
            "total_sources": 0,
            "retrieval_time": 0,
            "local_count": 0,
            "web_count": 0,
            "twitter_count": 0,
            "llm_count": 0,
        }

    # Without resolver output, report zero confidence and require escalation.
    resolver = outcome["resolver_output"]
    if resolver:
        resolution_stats = {
            "overall_confidence": resolver.overall_confidence,
            "requires_escalation": resolver.requires_escalation,
        }
    else:
        resolution_stats = {
            "overall_confidence": 0,
            "requires_escalation": True,
        }

    elapsed = outcome["processing_time"]
    result = {
        "session_id": outcome["session_id"],
        "query": query,
        "response": response_dump,
        "preprocessing": preprocessing_dump,
        "classification": classification_dump,
        "retrieval_stats": retrieval_stats,
        "resolution_stats": resolution_stats,
        "processing_time": elapsed,
        "errors": outcome["errors"],
    }

    confidence = resolution_stats["overall_confidence"]
    logger.info(rule)
    logger.info(f"✅ Complete. Time: {elapsed:.2f}s")
    logger.info(f"🎯 Confidence: {confidence}%")
    logger.info(rule)

    return result

# --- DISPLAY FUNCTION ---
def display_result(result: Dict[str, Any]):
    """Render a result dict as Markdown using IPython's display machinery.

    Args:
        result: A dict with "response", "session_id", "processing_time" and
            "retrieval_stats" entries, such as the one built by
            ``process_gst_grievance``. When "response" is falsy only an
            error heading is shown.
    """

    def show(markdown_text):
        # Every piece of output goes through the same display(Markdown(...)) call.
        display(Markdown(markdown_text))

    response = result["response"]
    if not response:
        show("## ❌ Error: No response generated")
        return

    # Header: session, timing, confidence and review status.
    show("# 🎯 GST Grievance Response")
    show(f"**Session:** {result['session_id']}")
    show(f"**Time:** {result['processing_time']:.2f}s")
    show(f"**Confidence:** {response['confidence_score']}%")
    if response['requires_manual_review']:
        status_label = '⚠️ Manual review'
    else:
        status_label = '✅ Resolved'
    show(f"**Status:** {status_label}")

    # Per-source retrieval counts; "llm_count" is optional and defaults to 0.
    counts = result['retrieval_stats']
    show(
        f"**Sources:** {counts['total_sources']} "
        f"(Local: {counts['local_count']}, Web: {counts['web_count']}, "
        f"Twitter: {counts['twitter_count']}, LLM: {counts.get('llm_count', 0)})"
    )

    show("---")
    show("## 📝 Answer")
    show(response["direct_answer"])

    # The updates section is omitted entirely when there is nothing to report.
    updates = response["recent_updates"]
    if updates:
        show("## 📱 Recent Updates")
        show(updates)

    show("## 🔗 Resources")
    for link in response["additional_resources"]:
        show(f"- {link}")


# --- UTILITY FUNCTIONS ---
def add_to_knowledge_base(documents: List[dict]):
    """Add documents to the local knowledge base.

    A new ``RetrievalOrchestrator`` is created on each call and the
    documents are handed to its ``local_agent.add_documents``.

    Args:
        documents: The document dicts to add.
    """
    local_store = RetrievalOrchestrator().local_agent
    local_store.add_documents(documents)

def get_kb_stats():
    """Return statistics for the local knowledge base.

    A new ``RetrievalOrchestrator`` is created on each call; the value
    returned is whatever its ``local_agent.get_stats()`` reports.
    """
    local_store = RetrievalOrchestrator().local_agent
    kb_summary = local_store.get_stats()
    return kb_summary

# --- USAGE ---
# Startup banner: one entry per printed line; a leading "\n" inside an entry
# yields the blank line that separates sections.
print(
    "✅ GST Grievance Resolution System Loaded!",
    "\n📚 System Features:",
    "   • Local embeddings (MPS/CUDA accelerated)",
    "   • Persistent FAISS vector store",
    "   • Web search (Tavily/DuckDuckGo)",
    "   • Twitter real-time updates",
    "\n💡 Usage:",
    '   result = process_gst_grievance("Your GST query here")',
    '   display_result(result)',
    '\n🔧 Utilities:',
    '   stats = get_kb_stats()  # View knowledge base stats',
    '   add_to_knowledge_base([{...}])  # Add documents',
    sep="\n",
)


INFO: ✅ Using Apple Silicon GPU (MPS)
INFO: 📊 Device: mps
INFO: 🔧 PyTorch version: 2.8.0
INFO: 🍎 MPS built: True
INFO: 🔄 Initializing local embedding model...
INFO: 📦 Model: sentence-transformers/all-MiniLM-L6-v2
INFO: 💻 Device: mps
INFO: 🍎 MPS fallback enabled for unsupported operations
INFO: Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
INFO: ✅ Local embedding model loaded successfully
INFO: 🎯 No API calls - fully offline embeddings
INFO: ✅ Embedding test successful! Dimension: 384
E0000 00:00:1760100288.589762 21847708 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.
E0000 00:00:1760100288.603345 21847708 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.
INFO: ✅ LLMs initialized:
E0000 00:00:1760100288.604636 21847708 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.
INFO:    Preprocessor: gemini-2.5-pro
INFO: 

✅ GST Grievance Resolution System Loaded!

📚 System Features:
   • Local embeddings (MPS/CUDA accelerated)
   • Persistent FAISS vector store
   • Web search (Tavily/DuckDuckGo)
   • Twitter real-time updates

💡 Usage:
   result = process_gst_grievance("Your GST query here")
   display_result(result)

🔧 Utilities:
   stats = get_kb_stats()  # View knowledge base stats
   add_to_knowledge_base([{...}])  # Add documents
