<a href="https://colab.research.google.com/github/pavi251503-cyber/Rag-based-sql-project/blob/main/text%20to%20sql%20.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
Human-in-the-Loop SQL Agent System - RAG-Based Backend
Includes: Authentication, RBAC, Audit Logging, RAG Pipeline, Two-Agent System
"""

import os
import json
import re
import sqlite3
import secrets
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from enum import Enum

import pandas as pd
import requests
from fastapi import FastAPI, Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator
from passlib.context import CryptContext
import jwt

# ============================================================================
# CONFIGURATION
# ============================================================================

# Key used to sign and verify JWTs. When the SECRET_KEY environment variable
# is unset, a fresh random key is generated at import time, so tokens signed
# by one process cannot be verified by another process or after a restart.
SECRET_KEY = os.getenv("SECRET_KEY", secrets.token_urlsafe(32))
ALGORITHM = "HS256"  # JWT algorithm passed to jwt.encode / jwt.decode
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # Intended access-token lifetime, in minutes
# Empty string when the variable is unset; openrouter_chat() refuses to run then.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = "openai/gpt-4o-mini"  # Default chat model for openrouter_chat()
EMBEDDING_MODEL = "text-embedding-3-small"  # NOTE(review): presumably for a real embedding API; confirm usage
DB_PATH = "text_to_sql.db"  # NOTE(review): presumably the application database queried by users; confirm
AUDIT_DB_PATH = "audit_logs.db"  # Default SQLite file used by AuditLogger
VECTOR_DB_PATH = "vector_store.db"  # Default SQLite file used by VectorStore
MAX_ROWS = 1000  # Upper bound on rows per query (LLM prompt and executor)
MAX_QUERY_COMPLEXITY = 5  # Maximum number of JOINs accepted by SQL validation
TOP_K_EXAMPLES = 3  # Number of similar examples to retrieve

# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Bearer-token scheme used as a FastAPI dependency by get_current_user()
security = HTTPBearer()

# ============================================================================
# ENUMS
# ============================================================================

class UserRole(str, Enum):
    """Roles a user can hold; each member is a key of RBAC_PERMISSIONS."""

    SALES_MANAGER = "sales_manager"
    HR_PERSONNEL = "hr_personnel"
    FINANCE_ANALYST = "finance_analyst"
    INVENTORY_MANAGER = "inventory_manager"
    ADMIN = "admin"

class ActionType(str, Enum):
    """Kinds of events recorded by AuditLogger.log (stored via ``.value``)."""

    LOGIN = "LOGIN"
    LOGOUT = "LOGOUT"
    QUERY_REQUEST = "QUERY_REQUEST"
    QUERY_GENERATED = "QUERY_GENERATED"
    QUERY_APPROVED = "QUERY_APPROVED"
    QUERY_REJECTED = "QUERY_REJECTED"
    QUERY_EXECUTION = "QUERY_EXECUTION"
    QUERY_FAILED = "QUERY_FAILED"

# ============================================================================
# PYDANTIC MODELS
# ============================================================================

class LoginRequest(BaseModel):
    """Credentials submitted by a client to obtain an access token."""

    username: str
    password: str  # Plain-text password as sent by the client

class LoginResponse(BaseModel):
    """Payload describing an authenticated session and the user's permissions."""

    access_token: str
    token_type: str = "bearer"
    user_id: str
    username: str
    role: UserRole
    # NOTE(review): RBAC_PERMISSIONS entries also contain non-list values
    # ("*" and a dict of forbidden columns); confirm the caller reshapes them
    # to match this type before building the response.
    permissions: Dict[str, List[str]]

class QueryRequest(BaseModel):
    """A natural-language question to be translated into SQL."""

    natural_language: str

    @validator('natural_language')
    def validate_query(cls, v):
        """Reject too-short or too-long questions and return the trimmed text.

        The minimum length is measured on the trimmed text; the maximum
        length is measured on the text as submitted.
        """
        trimmed = v.strip() if v else ""
        if len(trimmed) < 3:
            raise ValueError("Query must be at least 3 characters")
        if len(v) > 500:
            raise ValueError("Query too long (max 500 characters)")
        return trimmed

class GeneratedSQLResponse(BaseModel):
    """Result of SQL generation that is shown to the user for approval."""

    sql: str
    explanation: str
    confidence_score: float
    warnings: List[str]
    tables_accessed: List[str]
    query_id: str
    # RAG retrieved examples. Values are typed Any because each entry built by
    # SQLGeneratorAgent.generate_sql holds two strings ("question", "sql") and
    # a float ("similarity"); Dict[str, str] would reject or stringify the float.
    similar_examples: List[Dict[str, Any]] = []

class QueryApprovalRequest(BaseModel):
    """Human decision on a previously generated query."""

    query_id: str
    approved: bool
    feedback: Optional[str] = None  # Optional free-text comment from the reviewer

class QueryExecutionResponse(BaseModel):
    """Outcome of executing a query; mirrors the dict built by QueryExecutorAgent.execute."""

    query_id: str
    status: str  # "SUCCESS" or "FAILED" as produced by QueryExecutorAgent.execute
    rows_returned: int
    execution_time_ms: float
    results: List[Dict[str, Any]]  # One dict per result row
    warnings: List[str] = []

# ============================================================================
# RBAC CONFIGURATION
# ============================================================================

# Per-role access rules.
#   "tables":            allow-list of table names, or "*" for unrestricted access.
#   "forbidden_tables":  names that check_table_permission() rejects before
#                        consulting the allow-list.
#   "forbidden_columns": table -> column names that filter_schema_by_role()
#                        removes from the schema; only tables that are also in
#                        the allow-list are processed there.
RBAC_PERMISSIONS = {
    UserRole.SALES_MANAGER: {
        "tables": ["customers", "orders", "sales", "sales_team"],
        "forbidden_tables": ["employees", "salaries", "financial_statements"],
        "forbidden_columns": {"employees": ["salary", "ssn"]}
    },
    UserRole.HR_PERSONNEL: {
        "tables": ["employees", "departments", "attendance", "leave_requests"],
        "forbidden_tables": ["customers", "orders", "revenue", "financial_statements"],
        "forbidden_columns": {"employees": ["bank_account"]}
    },
    UserRole.FINANCE_ANALYST: {
        "tables": ["revenue", "expenses", "financial_statements", "budgets"],
        "forbidden_tables": ["employees"],
        "forbidden_columns": {"customers": ["email", "phone"]}
    },
    UserRole.INVENTORY_MANAGER: {
        "tables": ["products", "inventory", "suppliers", "warehouses"],
        "forbidden_tables": ["revenue", "salaries", "employees"],
        "forbidden_columns": {}
    },
    UserRole.ADMIN: {
        "tables": "*",
        "forbidden_tables": [],
        "forbidden_columns": {}
    }
}

# ============================================================================
# USER DATABASE (Mock - Replace with real DB in production)
# ============================================================================

# In-memory user table keyed by username. The demo passwords are written in
# plain text here and hashed with pwd_context at import time; get_current_user()
# looks users up in this dict and checks the "active" flag.
USERS_DB = {
    "john.doe": {
        "user_id": "user_001",
        "username": "john.doe",
        "password_hash": pwd_context.hash("password123"),
        "role": UserRole.SALES_MANAGER,
        "email": "john.doe@company.com",
        "active": True
    },
    "jane.smith": {
        "user_id": "user_002",
        "username": "jane.smith",
        "password_hash": pwd_context.hash("password123"),
        "role": UserRole.HR_PERSONNEL,
        "email": "jane.smith@company.com",
        "active": True
    },
    "admin": {
        "user_id": "user_003",
        "username": "admin",
        "password_hash": pwd_context.hash("admin123"),
        "role": UserRole.ADMIN,
        "email": "admin@company.com",
        "active": True
    }
}

# ============================================================================
# RAG VECTOR STORE
# ============================================================================

class VectorStore:
    """Vector database for RAG - stores question-SQL pairs with embeddings.

    Examples live in a SQLite table (``query_examples``); embeddings are
    stored as JSON-encoded lists of floats. Every method opens its own
    connection and always closes it, even when a statement fails.
    """

    # Vocabulary of the placeholder bag-of-words embedding (one dimension per word).
    _VOCAB = ("customer", "order", "employee", "sales", "revenue", "expense",
              "product", "inventory", "supplier", "department", "leave", "total",
              "show", "list", "get", "find", "count", "sum", "group", "filter")

    def __init__(self, db_path: str = VECTOR_DB_PATH):
        """Create the table if needed and seed it when it is empty."""
        self.db_path = db_path
        self._init_db()
        self._seed_examples()

    def _init_db(self):
        """Initialize vector store database"""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS query_examples (
                    example_id TEXT PRIMARY KEY,
                    question TEXT NOT NULL,
                    sql_query TEXT NOT NULL,
                    tables_used TEXT NOT NULL,
                    user_role TEXT,
                    embedding TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    usage_count INTEGER DEFAULT 0,
                    success_rate REAL DEFAULT 1.0
                )
            """)
            conn.commit()
        finally:
            conn.close()

    def _seed_examples(self):
        """Seed database with example queries for RAG (no-op if rows already exist)."""
        conn = sqlite3.connect(self.db_path)
        try:
            cur = conn.cursor()

            # Check if already seeded
            cur.execute("SELECT COUNT(*) FROM query_examples")
            if cur.fetchone()[0] > 0:
                return

            # Seed examples for different roles
            examples = [
                # Sales Manager examples
                {
                    "question": "Show all customers who placed orders in the last month",
                    "sql": "SELECT c.customer_id, c.name, c.email FROM customers c JOIN orders o ON c.customer_id = o.customer_id WHERE o.order_date >= date('now', '-1 month') GROUP BY c.customer_id LIMIT 1000",
                    "tables": ["customers", "orders"],
                    "role": "sales_manager"
                },
                {
                    "question": "What are the top 10 customers by total order value?",
                    "sql": "SELECT c.customer_id, c.name, SUM(o.total_amount) as total_spent FROM customers c JOIN orders o ON c.customer_id = o.customer_id GROUP BY c.customer_id ORDER BY total_spent DESC LIMIT 10",
                    "tables": ["customers", "orders"],
                    "role": "sales_manager"
                },
                {
                    "question": "List all sales team members and their performance",
                    "sql": "SELECT employee_id, name, sales_target, sales_achieved, (sales_achieved * 100.0 / sales_target) as achievement_pct FROM sales_team ORDER BY achievement_pct DESC LIMIT 1000",
                    "tables": ["sales_team"],
                    "role": "sales_manager"
                },
                # HR Personnel examples
                {
                    "question": "Show employees hired in 2024",
                    "sql": "SELECT employee_id, name, department, hire_date FROM employees WHERE hire_date >= '2024-01-01' ORDER BY hire_date DESC LIMIT 1000",
                    "tables": ["employees"],
                    "role": "hr_personnel"
                },
                {
                    "question": "List pending leave requests",
                    "sql": "SELECT lr.request_id, e.name, lr.leave_type, lr.start_date, lr.end_date, lr.status FROM leave_requests lr JOIN employees e ON lr.employee_id = e.employee_id WHERE lr.status = 'pending' ORDER BY lr.start_date LIMIT 1000",
                    "tables": ["leave_requests", "employees"],
                    "role": "hr_personnel"
                },
                # Finance Analyst examples
                {
                    "question": "Show monthly revenue for this year",
                    "sql": "SELECT strftime('%Y-%m', revenue_date) as month, SUM(amount) as total_revenue FROM revenue WHERE revenue_date >= '2024-01-01' GROUP BY month ORDER BY month LIMIT 1000",
                    "tables": ["revenue"],
                    "role": "finance_analyst"
                },
                {
                    "question": "Compare expenses by category",
                    "sql": "SELECT category, SUM(amount) as total_expenses FROM expenses WHERE expense_date >= date('now', '-1 year') GROUP BY category ORDER BY total_expenses DESC LIMIT 1000",
                    "tables": ["expenses"],
                    "role": "finance_analyst"
                },
                # Inventory Manager examples
                {
                    "question": "Show products with low stock levels",
                    "sql": "SELECT p.product_id, p.name, i.quantity, p.reorder_level FROM products p JOIN inventory i ON p.product_id = i.product_id WHERE i.quantity < p.reorder_level ORDER BY i.quantity LIMIT 1000",
                    "tables": ["products", "inventory"],
                    "role": "inventory_manager"
                },
                {
                    "question": "List all suppliers and their products",
                    "sql": "SELECT s.supplier_id, s.name, COUNT(p.product_id) as product_count FROM suppliers s LEFT JOIN products p ON s.supplier_id = p.supplier_id GROUP BY s.supplier_id ORDER BY product_count DESC LIMIT 1000",
                    "tables": ["suppliers", "products"],
                    "role": "inventory_manager"
                }
            ]

            for ex in examples:
                example_id = f"ex_{secrets.token_hex(8)}"
                # Create simple embedding (in production, use actual embedding API)
                embedding = self._create_simple_embedding(ex["question"])

                cur.execute("""
                    INSERT INTO query_examples (
                        example_id, question, sql_query, tables_used, user_role,
                        embedding, created_at, usage_count, success_rate
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    example_id,
                    ex["question"],
                    ex["sql"],
                    json.dumps(ex["tables"]),
                    ex["role"],
                    json.dumps(embedding),
                    datetime.utcnow().isoformat(),
                    0,
                    1.0
                ))

            conn.commit()
        finally:
            conn.close()

    def _create_simple_embedding(self, text: str) -> List[float]:
        """Create simple word-based embedding (in production, use OpenAI/Cohere API).

        The text is lower-cased and split into alphanumeric tokens, so
        punctuation does not hide a word ("orders?" -> "orders"). Naive
        singular forms are added ("customers" -> "customer", "inventories" ->
        "inventory") because the vocabulary is singular. The result is a
        unit-length vector with one dimension per vocabulary word, or an
        all-zero vector when no vocabulary word occurs.
        """
        # This is a placeholder - in production, use actual embedding API
        tokens = set(re.findall(r"[a-z0-9_]+", text.lower()))
        for token in list(tokens):
            if len(token) > 3 and token.endswith("ies"):
                tokens.add(token[:-3] + "y")
            if len(token) > 1 and token.endswith("s"):
                tokens.add(token[:-1])

        # Simple bag-of-words style embedding
        embedding = [1.0 if word in tokens else 0.0 for word in self._VOCAB]

        # Normalize
        magnitude = sum(x * x for x in embedding) ** 0.5
        if magnitude > 0:
            embedding = [x / magnitude for x in embedding]

        return embedding

    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """Calculate cosine similarity between two vectors.

        Divides the dot product by both norms, so the result is correct even
        for vectors that are not unit length. Returns 0.0 when either vector
        has zero magnitude.
        """
        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        norm1 = sum(a * a for a in vec1) ** 0.5
        norm2 = sum(b * b for b in vec2) ** 0.5
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return dot_product / (norm1 * norm2)

    def retrieve_similar_examples(self, question: str, user_role: str, top_k: int = TOP_K_EXAMPLES) -> List[Dict]:
        """RAG: Retrieve similar examples based on question embedding.

        Only examples stored for ``user_role`` (or with no role) are ranked.
        Returns at most ``top_k`` dicts sorted by descending similarity; a
        non-positive ``top_k`` yields an empty list.
        """
        if top_k <= 0:
            return []

        query_embedding = self._create_simple_embedding(question)

        conn = sqlite3.connect(self.db_path)
        try:
            cur = conn.cursor()
            # Get examples for user's role or general examples
            cur.execute("""
                SELECT example_id, question, sql_query, tables_used, embedding, usage_count, success_rate
                FROM query_examples
                WHERE user_role = ? OR user_role IS NULL
            """, (user_role,))
            rows = cur.fetchall()
        finally:
            conn.close()

        examples = [
            {
                "example_id": row[0],
                "question": row[1],
                "sql_query": row[2],
                "tables_used": json.loads(row[3]),
                "usage_count": row[5],
                "success_rate": row[6],
                "similarity": self._cosine_similarity(query_embedding, json.loads(row[4])),
            }
            for row in rows
        ]

        # Sort by similarity and return top-k
        examples.sort(key=lambda x: x["similarity"], reverse=True)
        return examples[:top_k]

    def add_successful_query(self, question: str, sql: str, tables: List[str], user_role: str):
        """Add a successful query to the knowledge base"""
        example_id = f"ex_{secrets.token_hex(8)}"
        embedding = self._create_simple_embedding(question)

        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("""
                INSERT INTO query_examples (
                    example_id, question, sql_query, tables_used, user_role,
                    embedding, created_at, usage_count, success_rate
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                example_id, question, sql, json.dumps(tables), user_role,
                json.dumps(embedding), datetime.utcnow().isoformat(), 1, 1.0
            ))
            conn.commit()
        finally:
            conn.close()

    def update_example_stats(self, example_id: str, success: bool):
        """Update usage statistics for an example.

        Increments ``usage_count`` and folds the new outcome (1.0 or 0.0) into
        the running average kept in ``success_rate``.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("""
                UPDATE query_examples
                SET usage_count = usage_count + 1,
                    success_rate = (success_rate * usage_count + ?) / (usage_count + 1)
                WHERE example_id = ?
            """, (1.0 if success else 0.0, example_id))
            conn.commit()
        finally:
            conn.close()

# ============================================================================
# AUDIT LOGGING
# ============================================================================

class AuditLogger:
    """Complete audit trail system for compliance.

    Each call to :meth:`log` appends one row to the ``audit_logs`` table of a
    SQLite database. Connections are opened per call and always closed.
    """

    def __init__(self, db_path: str = AUDIT_DB_PATH):
        """Remember the database path and create the table if it is missing."""
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Initialize audit log database"""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS audit_logs (
                    log_id TEXT PRIMARY KEY,
                    timestamp TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    username TEXT NOT NULL,
                    user_role TEXT NOT NULL,
                    user_ip TEXT,
                    action TEXT NOT NULL,
                    natural_language TEXT,
                    generated_sql TEXT,
                    rag_examples_used TEXT,
                    approved INTEGER,
                    approval_time_seconds REAL,
                    execution_status TEXT,
                    execution_time_ms REAL,
                    rows_returned INTEGER,
                    data_accessed TEXT,
                    compliance_flags TEXT,
                    error_message TEXT,
                    additional_data TEXT
                )
            """)
            conn.commit()
        finally:
            conn.close()

    def log(self, action: ActionType, user: Dict, request: Optional[Request] = None, **kwargs):
        """Log an action to audit trail.

        Args:
            action: The kind of event; its ``.value`` is stored.
            user: Dict with at least ``user_id``, ``username`` and ``role``.
            request: Incoming request used to record the client IP. When it is
                missing, or carries no client address, "unknown" is stored.
            **kwargs: Optional column values (``natural_language``,
                ``generated_sql``, ``rag_examples_used``, ``approved``,
                ``approval_time_seconds``, ``execution_status``,
                ``execution_time_ms``, ``rows_returned``, ``data_accessed``,
                ``compliance_flags``, ``error_message``, ``additional_data``).
                List/dict valued ones are stored as JSON text.

        Returns:
            The generated ``log_id`` of the inserted row.
        """
        log_id = f"audit_{secrets.token_hex(8)}"
        timestamp = datetime.utcnow().isoformat() + "Z"
        # request.client may be None, so guard both levels.
        client = request.client if request else None
        user_ip = client.host if client else "unknown"

        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("""
                INSERT INTO audit_logs (
                    log_id, timestamp, user_id, username, user_role, user_ip,
                    action, natural_language, generated_sql, rag_examples_used, approved,
                    approval_time_seconds, execution_status, execution_time_ms,
                    rows_returned, data_accessed, compliance_flags, error_message,
                    additional_data
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                log_id, timestamp, user["user_id"], user["username"], user["role"],
                user_ip, action.value,
                kwargs.get("natural_language"),
                kwargs.get("generated_sql"),
                json.dumps(kwargs.get("rag_examples_used", [])),
                kwargs.get("approved"),
                kwargs.get("approval_time_seconds"),
                kwargs.get("execution_status"),
                kwargs.get("execution_time_ms"),
                kwargs.get("rows_returned"),
                json.dumps(kwargs.get("data_accessed", [])),
                json.dumps(kwargs.get("compliance_flags", [])),
                kwargs.get("error_message"),
                json.dumps(kwargs.get("additional_data", {}))
            ))
            conn.commit()
        finally:
            conn.close()
        return log_id

# ============================================================================
# AUTHENTICATION & AUTHORIZATION
# ============================================================================

def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return whether ``plain_password`` matches ``hashed_password``.

    The comparison is delegated to the module-level ``pwd_context``.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Create JWT access token.

    Args:
        data: Claims to embed in the token; the dict is copied, not mutated.
        expires_delta: Lifetime of the token. When omitted, the configured
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` is used. An explicit zero or
            negative delta is honoured rather than replaced by the default.

    Returns:
        The token encoded with ``SECRET_KEY`` and ``ALGORITHM``.
    """
    to_encode = data.copy()
    # Compare against None: timedelta(0) is falsy and must not fall back.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode["exp"] = datetime.utcnow() + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

def decode_token(token: str) -> Dict:
    """Decode and verify JWT token.

    Args:
        token: Encoded JWT as received from the client.

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 401 with "Token expired" for an expired token, or 401
            with "Invalid token" for any other validation/decoding failure.
    """
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token expired"
        )
    # PyJWT has no ``JWTError``; ``InvalidTokenError`` is its base class for
    # decode/validation failures. ExpiredSignatureError is a subclass, so the
    # more specific handler above must come first.
    except jwt.InvalidTokenError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token"
        )

def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict:
    """Resolve the bearer token to a user record.

    Raises HTTPException 401 when the token carries no known username and
    403 when the matching account is flagged inactive.
    """
    claims = decode_token(credentials.credentials)
    username = claims.get("sub")

    account = USERS_DB.get(username) if username else None
    if account is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials"
        )

    if account["active"]:
        return account

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="User account is inactive"
    )

def check_table_permission(user: Dict, table: str) -> bool:
    """Check if user has permission to access table.

    Args:
        user: User record; its ``role`` selects the RBAC_PERMISSIONS entry.
        table: Table name as written in the SQL. The comparison is
            case-insensitive because SQLite resolves table names without
            regard to case, while the permission lists are lower-case.

    Returns:
        True when the role may read the table. Roles without an RBAC entry
        are denied.
    """
    permissions = RBAC_PERMISSIONS.get(user["role"])
    if permissions is None:
        # Unknown role: fail closed instead of raising KeyError.
        return False

    if permissions["tables"] == "*":
        return True

    name = table.lower()
    if name in permissions["forbidden_tables"]:
        return False

    return name in permissions["tables"]

# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def get_schema(db_path: str) -> Dict[str, List[Dict[str, str]]]:
    """Extract database schema.

    Args:
        db_path: Path of the SQLite database file.

    Returns:
        Mapping of table name to a list of ``{'name': ..., 'type': ...}``
        dicts, one per column, for every table whose name does not match
        ``sqlite_%``.
    """
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
        tables = [row[0] for row in cur.fetchall()]
        schema = {}
        for table in tables:
            # PRAGMA arguments cannot be bound as parameters, so quote the
            # identifier; this keeps names with spaces or quotes working.
            quoted = '"' + table.replace('"', '""') + '"'
            cur.execute(f"PRAGMA table_info({quoted})")
            schema[table] = [{'name': r[1], 'type': r[2]} for r in cur.fetchall()]
        return schema
    finally:
        conn.close()

def filter_schema_by_role(schema: Dict, user: Dict) -> Dict:
    """Return the part of ``schema`` the user's role is allowed to see.

    A role whose table list is "*" gets the schema back unchanged. Otherwise
    only allow-listed tables are kept, and each kept table loses the columns
    named in the role's ``forbidden_columns`` entry for that table.
    """
    rules = RBAC_PERMISSIONS[user["role"]]
    allowed_tables = rules["tables"]

    if allowed_tables == "*":
        return schema

    hidden_columns = rules["forbidden_columns"]
    return {
        table: [col for col in columns if col["name"] not in hidden_columns.get(table, [])]
        for table, columns in schema.items()
        if table in allowed_tables
    }

def schema_to_text(schema: Dict[str, List[Dict[str, str]]]) -> str:
    """Render a schema dict as one ``table: col (type), ...`` line per table."""

    def describe(table: str, columns: List[Dict[str, str]]) -> str:
        rendered = [f"{col['name']} ({col['type']})" for col in columns]
        return f"{table}: " + ", ".join(rendered)

    return "\n".join(describe(table, columns) for table, columns in schema.items())

def strip_codeblocks(text: str) -> str:
    """Remove a surrounding markdown code fence.

    Text that is not wrapped in a triple-backtick fence is returned stripped
    but otherwise unchanged; empty input is returned as-is. For a multi-line
    fence the opening line (including any language tag) is dropped. Content
    sharing a line with the closing fence, or a fence opened and closed on a
    single line, is kept rather than discarded.
    """
    if not text:
        return text
    text = text.strip()
    if not (text.startswith("```") and text.endswith("```")):
        return text

    lines = text.splitlines()
    if len(lines) == 1:
        # Fence opened and closed on the same line: only peel the backticks.
        return text.strip("`").strip()

    body = lines[1:-1]
    # Keep anything that precedes the closing fence on the last line.
    tail = lines[-1].rstrip("`").strip()
    if tail:
        body.append(tail)
    return "\n".join(body).strip()

def parse_json_like(response_text: str) -> Optional[Dict]:
    """Parse a JSON object from an LLM response.

    Code fences are removed first. The outermost ``{...}`` span is tried,
    then the whole text. Only a JSON *object* is accepted: arrays, strings
    and numbers yield None, so callers can safely index the result by key.

    Returns:
        The parsed dict, or None when no JSON object could be decoded.
    """
    if not response_text:
        return None
    text = strip_codeblocks(response_text)

    candidates = []
    start = text.find('{')
    end = text.rfind('}')
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])
    candidates.append(text)

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):
            # Not valid JSON (or nested too deeply); try the next candidate.
            continue
        if isinstance(parsed, dict):
            return parsed
    return None

def openrouter_chat(messages: List[Dict], model: str = OPENROUTER_MODEL, temperature: float = 0.1) -> str:
    """Call OpenRouter API with anti-hallucination settings.

    Args:
        messages: Chat messages in ``{"role": ..., "content": ...}`` form.
        model: Model identifier sent to the API.
        temperature: Sampling temperature sent to the API.

    Returns:
        The content of the first choice's message.

    Raises:
        RuntimeError: If the API key is missing, the API answers with a
            non-200 status, the body is not JSON, or the JSON lacks the
            expected ``choices[0].message.content`` path. Network-level
            errors from ``requests`` propagate unchanged.
    """
    if not OPENROUTER_API_KEY:
        raise RuntimeError("OpenRouter API Key missing!")

    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": messages,
        "temperature": temperature
    }

    res = requests.post(url, headers=headers, json=payload, timeout=60)
    if res.status_code != 200:
        raise RuntimeError(f"OpenRouter Error {res.status_code}: {res.text}")

    try:
        data = res.json()
    except ValueError as err:
        # A 200 response with a non-JSON body is reported like any other
        # malformed response instead of leaking a decoder error.
        raise RuntimeError("Bad OpenRouter response format:\n" + res.text) from err

    try:
        return data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as err:
        raise RuntimeError("Bad OpenRouter response format:\n" + str(data)) from err

# ============================================================================
# AGENT 1: RAG-ENHANCED SQL GENERATOR
# ============================================================================

class ValidationResult:
    """Outcome of validating one SQL string.

    Carries the pass/fail flag, the blocking errors, the non-blocking
    warnings and the confidence score computed by the validator.
    """

    def __init__(self, is_valid: bool, errors: List[str], warnings: List[str], confidence_score: float):
        self.is_valid, self.errors = is_valid, errors
        self.warnings, self.confidence_score = warnings, confidence_score

class SQLGeneratorAgent:
    """Agent 1: RAG-enhanced SQL generation from natural language.

    Retrieves similar question/SQL pairs, asks the LLM for a query, validates
    the result against safety and RBAC rules, and returns it for approval.
    """

    # FROM / JOIN keywords introduce table references.
    _CLAUSE_RE = re.compile(r'\b(?:FROM|JOIN)\b', re.IGNORECASE)
    # An identifier, bare or quoted with "...", `...` or [...].
    _IDENT = r'(?:"[^"]+"|`[^`]+`|\[[^\]]+\]|\w+)'
    # Optional "schema." prefix followed by the table name (captured).
    _TABLE_REF_RE = re.compile(r'\s*(?:' + _IDENT + r'\s*\.\s*)?(' + _IDENT + r')')
    # Keywords that end the table list of a FROM clause.
    _FROM_END_RE = re.compile(
        r'(?:WHERE|GROUP|ORDER|HAVING|LIMIT|UNION|EXCEPT|INTERSECT|WINDOW)\b',
        re.IGNORECASE,
    )

    def __init__(self, schema: Dict, user: Dict, vector_store: "VectorStore"):
        """Store the (role-filtered) schema, the user record and the example store."""
        self.schema = schema
        self.user = user
        self.vector_store = vector_store
        self.table_names = set(schema.keys())

    def generate_sql(self, natural_language: str) -> Dict:
        """Generate SQL using RAG pipeline.

        Returns:
            Dict with ``sql``, ``explanation``, ``confidence_score``,
            ``warnings``, ``tables_accessed`` and ``similar_examples``.

        Raises:
            ValueError: If the LLM response is malformed or the generated SQL
                fails validation.
        """

        # Step 1: RAG Retrieval - Get similar examples
        similar_examples = self.vector_store.retrieve_similar_examples(
            natural_language,
            self.user["role"]
        )

        # Step 2: Translate to SQL using retrieved examples
        sql = self._translate_to_sql_with_rag(natural_language, similar_examples)

        # Step 3: Validate SQL
        validation = self._validate_sql(sql)

        if not validation.is_valid:
            raise ValueError(f"Generated invalid SQL: {', '.join(validation.errors)}")

        # Validation only warns about a missing LIMIT; apply it here so the
        # warning's "adding LIMIT" statement is actually true.
        sql = self._ensure_limit(sql)

        # Step 4: Extract tables accessed
        tables_accessed = self._extract_tables(sql)

        # Step 5: Generate explanation
        explanation = self._generate_explanation(natural_language, sql)

        return {
            "sql": sql,
            "explanation": explanation,
            "confidence_score": validation.confidence_score,
            "warnings": validation.warnings,
            "tables_accessed": tables_accessed,
            "similar_examples": [
                {
                    "question": ex["question"],
                    "sql": ex["sql_query"],
                    "similarity": round(ex["similarity"], 3)
                }
                for ex in similar_examples
            ]
        }

    def _translate_to_sql_with_rag(self, question: str, examples: List[Dict]) -> str:
        """RAG-enhanced translation: Use retrieved examples in prompt.

        Raises:
            ValueError: If the LLM answer is not a JSON object with a string
                ``sql`` field.
        """
        schema_text = schema_to_text(self.schema)
        # Sorted so the prompt is identical across runs for the same schema.
        authorized_tables = ', '.join(sorted(self.table_names))

        # Format retrieved examples for few-shot learning
        examples_text = ""
        if examples:
            examples_text = "\n\nSIMILAR EXAMPLES FROM KNOWLEDGE BASE:\n"
            for i, ex in enumerate(examples, 1):
                examples_text += f"\nExample {i} (similarity: {ex['similarity']:.2f}):\n"
                examples_text += f"Question: {ex['question']}\n"
                examples_text += f"SQL: {ex['sql_query']}\n"

        system_prompt = f"""You are an expert SQL translator with STRICT safety rules and RAG-enhanced context.

User Role: {self.user['role']}
Authorized Tables ONLY: {authorized_tables}

ABSOLUTE RULES:
1. Generate ONLY valid JSON: {{"sql": "<query>", "explanation": "<explanation>"}}
2. Use ONLY SELECT statements (NO DROP, DELETE, UPDATE, ALTER, INSERT, TRUNCATE)
3. Use ONLY tables from the authorized list above
4. SQLite syntax only
5. Add LIMIT clause (max {MAX_ROWS} rows)
6. No comments in SQL
7. Learn from the similar examples provided below

ANTI-HALLUCINATION MEASURES:
- Temperature: 0.1 (low randomness)
- Schema-aware: Only use provided tables/columns
- RAG-enhanced: Use similar examples as reference
- If unsure, follow the pattern of similar examples

Any violation will be automatically BLOCKED.{examples_text}"""

        user_prompt = f"""Schema (Authorized Tables Only):
{schema_text}

User question:
{question}

Generate SQL query as JSON with explanation. Use the similar examples above as reference."""

        response = openrouter_chat([
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ])

        parsed = parse_json_like(response)
        if not isinstance(parsed, dict) or not isinstance(parsed.get("sql"), str):
            raise ValueError("Invalid LLM response format")

        return parsed["sql"].strip()

    def _validate_sql(self, sql: str) -> "ValidationResult":
        """Comprehensive SQL validation.

        Errors (non-SELECT statement, disallowed keyword, unauthorized table,
        too many JOINs) make the result invalid. Suspicious patterns and a
        missing LIMIT only add warnings and lower the confidence score.
        """
        errors = []
        warnings = []
        confidence = 1.0

        # 1. Must be SELECT
        if not sql.strip().upper().startswith('SELECT'):
            errors.append("Only SELECT statements allowed")
            confidence = 0.0

        # 2. Check for disallowed keywords
        disallowed = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'CREATE',
                      'REPLACE', 'TRUNCATE', 'EXEC', 'EXECUTE']
        for keyword in disallowed:
            if re.search(r'\b' + keyword + r'\b', sql, re.IGNORECASE):
                errors.append(f"BLOCKED: Disallowed keyword '{keyword}'")
                confidence = 0.0

        # 3. Check table permissions
        referenced_tables = self._extract_tables(sql)
        for table in referenced_tables:
            if not check_table_permission(self.user, table):
                errors.append(f"UNAUTHORIZED: No permission for table '{table}'")
                confidence = 0.0

        # 4. Check for SQL injection patterns
        injection_patterns = [r'--', r'/\*', r'xp_', r';\s*DROP', r';\s*DELETE']
        for pattern in injection_patterns:
            if re.search(pattern, sql, re.IGNORECASE):
                warnings.append(f"Security warning: Suspicious pattern '{pattern}'")
                confidence *= 0.8

        # 5. Check for LIMIT clause
        if not re.search(r'\bLIMIT\b', sql, re.IGNORECASE):
            warnings.append(f"No LIMIT clause - adding LIMIT {MAX_ROWS}")
            confidence *= 0.9

        # 6. Check query complexity
        join_count = len(re.findall(r'\bJOIN\b', sql, re.IGNORECASE))
        if join_count > MAX_QUERY_COMPLEXITY:
            errors.append(f"Too complex: {join_count} JOINs (max: {MAX_QUERY_COMPLEXITY})")
            confidence *= 0.5

        is_valid = len(errors) == 0
        return ValidationResult(is_valid, errors, warnings, confidence)

    def _ensure_limit(self, sql: str) -> str:
        """Append ``LIMIT MAX_ROWS`` when the query contains no LIMIT keyword."""
        if re.search(r'\bLIMIT\b', sql, re.IGNORECASE):
            return sql
        # Drop a trailing semicolon so the clause stays inside the statement.
        return f"{sql.rstrip().rstrip(';').rstrip()} LIMIT {MAX_ROWS}"

    def _table_item_offsets(self, sql: str):
        """Yield every offset at which a table reference may start.

        That is the position right after each FROM/JOIN keyword, plus the
        position after each top-level comma inside a FROM clause (comma
        joins). Scanning of a FROM clause stops at its closing parenthesis,
        a semicolon, or a keyword that ends the table list.
        """
        length = len(sql)
        for keyword in self._CLAUSE_RE.finditer(sql):
            yield keyword.end()
            if keyword.group(0).upper() != "FROM":
                continue

            depth = 0
            i = keyword.end()
            while i < length:
                ch = sql[i]
                if ch in "'\"`":
                    # Skip quoted text so commas/parens inside it are ignored.
                    closing = sql.find(ch, i + 1)
                    if closing == -1:
                        break
                    i = closing
                elif ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                    if depth < 0:
                        break  # end of the enclosing subquery
                elif ch == ";":
                    break
                elif depth == 0:
                    if ch == ",":
                        yield i + 1
                    elif (ch.isalpha()
                          and not (sql[i - 1].isalnum() or sql[i - 1] == "_")
                          and self._FROM_END_RE.match(sql, i)):
                        break
                i += 1

    def _extract_tables(self, sql: str) -> List[str]:
        """Extract table names from SQL.

        Handles ``FROM a, b`` comma joins, quoted identifiers and
        ``schema.table`` references (the schema part is dropped). Names are
        returned as written, without quotes, de-duplicated in order of first
        appearance. This is a heuristic scan, not a SQL parser: it errs on
        the side of reporting too many names rather than too few.
        """
        found: Dict[str, None] = {}
        for offset in self._table_item_offsets(sql):
            ref = self._TABLE_REF_RE.match(sql, offset)
            if ref is None:
                continue  # e.g. a parenthesised subquery; its own FROM is scanned separately
            name = ref.group(1)
            if name[0] in '"`[':
                name = name[1:-1]
            found.setdefault(name, None)
        return list(found)

    def _generate_explanation(self, question: str, sql: str) -> str:
        """Generate human-readable explanation of SQL"""
        return f"This query retrieves data to answer: '{question}'"

# ============================================================================
# AGENT 2: QUERY EXECUTOR
# ============================================================================

class QueryExecutorAgent:
    """Agent 2: Executes approved SQL queries safely.

    Safety measures applied on every call to :meth:`execute`:

    * a keyword screen that rejects data-modifying statements;
    * ``PRAGMA query_only = ON`` on the connection, so SQLite itself refuses
      writes that slip past the keyword screen (for example
      ``CREATE TABLE ... AS SELECT`` or ``REPLACE INTO``);
    * at most ``max_rows`` rows are fetched from the cursor.
    """

    # Whole-word keywords rejected before execution. Matching is textual, so a
    # keyword inside a string literal or used as an identifier is rejected too.
    _DISALLOWED_KEYWORDS = ('DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'TRUNCATE')

    def __init__(self, db_path: str, max_rows: Optional[int] = None):
        """Create an executor bound to one SQLite database.

        Args:
            db_path: Path of the SQLite database to query.
            max_rows: Maximum number of rows returned per query. Defaults to
                the module-level ``MAX_ROWS``.

        Raises:
            ValueError: If the effective row limit is not positive.
        """
        self.db_path = db_path
        self.max_rows = MAX_ROWS if max_rows is None else max_rows
        if self.max_rows < 1:
            raise ValueError("max_rows must be a positive integer")

    def execute(self, sql: str, query_id: str) -> Dict:
        """Execute SQL with safety measures and timing.

        Exceptions raised by the security check or by the database are caught
        and reported in-band rather than propagated.

        Args:
            sql: The approved SQL statement.
            query_id: Identifier echoed back in the result.

        Returns:
            Dict with keys ``query_id``, ``status`` (``"SUCCESS"`` or
            ``"FAILED"``), ``rows_returned``, ``execution_time_ms``,
            ``results`` (list of row dicts) and ``warnings`` (the error text
            on failure, otherwise empty).
        """
        # Local import: a monotonic clock is the right tool for elapsed time
        # (wall-clock timestamps can jump).
        from time import perf_counter

        started = perf_counter()
        try:
            # Final security check
            self._final_security_check(sql)

            df = self._execute_with_limit(sql)
            outcome = {
                "status": "SUCCESS",
                "rows_returned": len(df),
                "results": df.to_dict(orient="records"),
                "warnings": [],
            }
        except Exception as e:
            # Boundary handler: callers expect a result dict, never a raise.
            outcome = {
                "status": "FAILED",
                "rows_returned": 0,
                "results": [],
                "warnings": [str(e)],
            }

        elapsed_ms = (perf_counter() - started) * 1000
        return {
            "query_id": query_id,
            "status": outcome["status"],
            "rows_returned": outcome["rows_returned"],
            "execution_time_ms": round(elapsed_ms, 2),
            "results": outcome["results"],
            "warnings": outcome["warnings"],
        }

    def _final_security_check(self, sql: str):
        """Reject SQL containing a disallowed keyword.

        Raises:
            ValueError: Naming the first keyword of ``_DISALLOWED_KEYWORDS``
                that appears in ``sql`` as a whole word (case-insensitive).
        """
        for keyword in self._DISALLOWED_KEYWORDS:
            if re.search(r'\b' + keyword + r'\b', sql, re.IGNORECASE):
                raise ValueError(f"Security block: {keyword} not allowed")

    def _execute_with_limit(self, sql: str) -> pd.DataFrame:
        """Run ``sql`` on a query-only connection, returning at most ``max_rows`` rows.

        Only the first chunk of ``max_rows`` rows is read from the cursor, so
        a huge result set is never materialised in memory.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            # Make SQLite refuse any statement that would modify the database.
            conn.execute("PRAGMA query_only = ON")

            chunks = pd.read_sql_query(sql, conn, chunksize=self.max_rows)
            first_chunk = next(iter(chunks), None)
            # Fall back to an empty frame if no chunk is yielded at all.
            return first_chunk if first_chunk is not None else pd.DataFrame()
        finally:
            conn.close()

# ============================================================================
# QUERY STATE MANAGEMENT
# ============================================================================

# In-memory registry of generated queries, keyed by query_id. Entries are
# added by generate_sql and removed when the query is rejected (approve_query)
# or after it has been executed (execute_query).
# NOTE(review): this is a plain process-local dict, so its contents do not
# survive a restart; no expiry of entries that are never rejected or executed
# is visible in this section -- confirm whether one exists elsewhere.
PENDING_QUERIES: Dict[str, Dict] = {}

# ============================================================================
# FASTAPI APPLICATION
# ============================================================================

app = FastAPI(
    title="RAG-Based Human-in-the-Loop SQL Agent API",
    description="Secure AI-powered database access with RAG and mandatory human approval",
    version="1.0.0"
)

# CORS Configuration
# Allowed origins come from the comma-separated CORS_ALLOWED_ORIGINS
# environment variable; when it is unset or empty every origin is allowed.
CORS_ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOWED_ORIGINS,
    # The API authenticates with a Bearer token sent in the Authorization
    # header (HTTPBearer), not with cookies, so credentialed CORS is not
    # needed. Wildcard origins combined with allow_credentials=True is not a
    # valid CORS configuration and would expose credentialed requests to any
    # site.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize components
audit_logger = AuditLogger()
vector_store = VectorStore()

# ============================================================================
# API ENDPOINTS
# ============================================================================

@app.post("/api/auth/login", response_model=LoginResponse)
async def login(request: Request, login_data: LoginRequest):
    """Authenticate user and return JWT token.

    Both outcomes are written to the audit log. A failed attempt is logged
    with the submitted username and placeholder identity fields.

    Args:
        request: Incoming HTTP request, forwarded to the audit logger.
        login_data: Submitted username and password.

    Returns:
        LoginResponse carrying the access token, the user's identity fields
        and the permissions looked up for the user's role.

    Raises:
        HTTPException: 401 when the username is unknown or the password does
            not verify.
    """
    user = USERS_DB.get(login_data.username)

    if not user:
        # Spend about as long as a real hash verification would, so response
        # timing does not reveal whether the username exists.
        pwd_context.dummy_verify()
        authenticated = False
    else:
        authenticated = verify_password(login_data.password, user["password_hash"])

    if not authenticated:
        audit_logger.log(
            ActionType.LOGIN,
            {"user_id": "unknown", "username": login_data.username, "role": "none"},
            request,
            execution_status="FAILED",
            error_message="Invalid credentials"
        )
        # Same message for both failure causes: do not disclose which one.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password"
        )

    access_token = create_access_token(
        data={"sub": user["username"]},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    )

    permissions = RBAC_PERMISSIONS[user["role"]]

    audit_logger.log(
        ActionType.LOGIN,
        user,
        request,
        execution_status="SUCCESS"
    )

    return LoginResponse(
        access_token=access_token,
        user_id=user["user_id"],
        username=user["username"],
        role=user["role"],
        permissions=permissions
    )

@app.post("/api/generate-sql", response_model=GeneratedSQLResponse)
async def generate_sql(
    request: Request,
    query_request: QueryRequest,
    current_user: Dict = Depends(get_current_user)
):
    """AGENT 1: RAG-enhanced SQL generation from natural language.

    Generates SQL against the schema filtered for the caller's role, stores
    the result in PENDING_QUERIES under a fresh query_id (with ``approved``
    set to None) and returns it to the caller.

    Args:
        request: Incoming HTTP request, forwarded to the audit logger.
        query_request: Carries the natural-language question.
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        GeneratedSQLResponse built from the new query_id and the generator's
        result.

    Raises:
        HTTPException: 403 when the role has no authorized tables; 500 when
            generation fails for any other reason.
    """

    start_time = datetime.utcnow()

    audit_logger.log(
        ActionType.QUERY_REQUEST,
        current_user,
        request,
        natural_language=query_request.natural_language
    )

    try:
        full_schema = get_schema(DB_PATH)
        filtered_schema = filter_schema_by_role(full_schema, current_user)

        if not filtered_schema:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="No authorized tables for your role"
            )

        # RAG-enhanced SQL generation
        agent = SQLGeneratorAgent(filtered_schema, current_user, vector_store)
        result = agent.generate_sql(query_request.natural_language)

        query_id = f"query_{secrets.token_hex(8)}"

        PENDING_QUERIES[query_id] = {
            "query_id": query_id,
            "user": current_user,
            "natural_language": query_request.natural_language,
            "sql": result["sql"],
            "explanation": result["explanation"],
            "confidence_score": result["confidence_score"],
            "warnings": result["warnings"],
            "tables_accessed": result["tables_accessed"],
            "similar_examples": result["similar_examples"],
            "created_at": datetime.utcnow(),
            "approved": None
        }

        # Time spent generating; recorded in the audit log's
        # approval_time_seconds field for this QUERY_GENERATED event.
        generation_seconds = (datetime.utcnow() - start_time).total_seconds()
        audit_logger.log(
            ActionType.QUERY_GENERATED,
            current_user,
            request,
            natural_language=query_request.natural_language,
            generated_sql=result["sql"],
            rag_examples_used=result["similar_examples"],
            data_accessed=result["tables_accessed"],
            approval_time_seconds=generation_seconds
        )

        return GeneratedSQLResponse(
            query_id=query_id,
            **result
        )

    except Exception as e:
        audit_logger.log(
            ActionType.QUERY_FAILED,
            current_user,
            request,
            natural_language=query_request.natural_language,
            execution_status="FAILED",
            error_message=str(e)
        )
        if isinstance(e, HTTPException):
            # Deliberate HTTP errors (such as the 403 above) must reach the
            # client with their own status code, not be rewrapped as a 500.
            raise
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"SQL generation failed: {str(e)}"
        ) from e

@app.post("/api/approve-query")
async def approve_query(
    request: Request,
    approval: QueryApprovalRequest,
    current_user: Dict = Depends(get_current_user)
):
    """Human verification checkpoint - approve or reject generated SQL.

    Records the decision, its latency and the feedback on the pending entry,
    writes an audit record, and drops the entry when the query is rejected.

    Raises:
        HTTPException: 404 if the query_id is not pending; 403 if the caller
            is not the user who generated the query.
    """
    pending = PENDING_QUERIES.get(approval.query_id)
    if pending is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Query not found or expired"
        )

    # Only the user who requested the query may decide on it.
    owner_id = pending["user"]["user_id"]
    if owner_id != current_user["user_id"]:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authorized to approve this query"
        )

    # Seconds between generation and this decision.
    seconds_to_decision = (datetime.utcnow() - pending["created_at"]).total_seconds()

    pending.update(
        approved=approval.approved,
        approval_time_seconds=seconds_to_decision,
        feedback=approval.feedback,
    )

    if approval.approved:
        audit_action = ActionType.QUERY_APPROVED
    else:
        audit_action = ActionType.QUERY_REJECTED

    audit_logger.log(
        audit_action,
        current_user,
        request,
        natural_language=pending["natural_language"],
        generated_sql=pending["sql"],
        approved=approval.approved,
        approval_time_seconds=seconds_to_decision,
        additional_data={"feedback": approval.feedback}
    )

    if not approval.approved:
        # A rejected query can never be executed, so forget it.
        del PENDING_QUERIES[approval.query_id]
        return {
            "status": "rejected",
            "message": "Query rejected by user.",
            "query_id": approval.query_id
        }

    return {
        "status": "approved",
        "message": "Query approved. Proceeding to execution.",
        "query_id": approval.query_id
    }

@app.post("/api/execute-query", response_model=QueryExecutionResponse)
async def execute_query(
    request: Request,
    query_id: str,
    current_user: Dict = Depends(get_current_user)
):
    """AGENT 2: Execute approved SQL query.

    The pending entry is consumed by the attempt, whether or not execution
    succeeds. The attempt is audited before the RAG knowledge base is updated,
    and a failure of that update is reported as a warning instead of failing
    a query that has already run.

    Args:
        request: Incoming HTTP request, forwarded to the audit logger.
        query_id: Identifier of a pending, approved query.
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        QueryExecutionResponse built from the executor's result dict.

    Raises:
        HTTPException: 404 if the query_id is not pending; 403 if the caller
            did not generate the query or it has not been approved.
    """

    if query_id not in PENDING_QUERIES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Query not found"
        )

    query_data = PENDING_QUERIES[query_id]

    if query_data["user"]["user_id"] != current_user["user_id"]:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authorized to execute this query"
        )

    if not query_data.get("approved"):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Query not approved yet"
        )

    executor = QueryExecutorAgent(DB_PATH)
    result = executor.execute(query_data["sql"], query_id)

    # An approval is single-use: consume it as soon as execution has been
    # attempted, so a later failure below cannot leave it replayable.
    del PENDING_QUERIES[query_id]

    audit_logger.log(
        ActionType.QUERY_EXECUTION,
        current_user,
        request,
        natural_language=query_data["natural_language"],
        generated_sql=query_data["sql"],
        execution_status=result["status"],
        execution_time_ms=result["execution_time_ms"],
        rows_returned=result["rows_returned"],
        data_accessed=query_data["tables_accessed"]
    )

    # Add successful query to RAG knowledge base (best effort).
    if result["status"] == "SUCCESS":
        try:
            vector_store.add_successful_query(
                query_data["natural_language"],
                query_data["sql"],
                query_data["tables_accessed"],
                current_user["role"]
            )
        except Exception as rag_error:
            # The query itself succeeded; surface the bookkeeping failure
            # to the caller without discarding the results.
            result["warnings"].append(
                f"RAG knowledge base update failed: {rag_error}"
            )

    return QueryExecutionResponse(**result)

@app.get("/api/health")
async def health_check():
    """Report service status together with the current UTC timestamp."""
    checked_at = datetime.utcnow().isoformat()
    payload = {
        "status": "healthy",
        "timestamp": checked_at,
        "version": "1.0.0",
        "rag_enabled": True,
    }
    return payload

@app.get("/api/user/permissions")
async def get_user_permissions(current_user: Dict = Depends(get_current_user)):
    """Return the caller's identity fields and the permissions of their role."""
    granted = RBAC_PERMISSIONS[current_user["role"]]

    profile = {
        field: current_user[field]
        for field in ("user_id", "username", "role")
    }
    profile["permissions"] = granted
    return profile

@app.get("/api/rag/stats")
async def get_rag_stats(current_user: Dict = Depends(get_current_user)):
    """Get RAG knowledge base statistics.

    Returns:
        Dict with the total number of stored examples, the number stored for
        the caller's role, and the average usage count and success rate
        (each 0 when the store is empty), rounded to two decimals.
    """
    conn = sqlite3.connect(VECTOR_DB_PATH)
    try:
        cur = conn.cursor()

        cur.execute("SELECT COUNT(*) FROM query_examples")
        total_examples = cur.fetchone()[0]

        cur.execute("SELECT COUNT(*) FROM query_examples WHERE user_role = ?", (current_user["role"],))
        role_examples = cur.fetchone()[0]

        cur.execute("SELECT AVG(usage_count), AVG(success_rate) FROM query_examples")
        avg_stats = cur.fetchone()
    finally:
        # Close the connection even when one of the queries raises.
        conn.close()

    return {
        "total_examples": total_examples,
        "role_specific_examples": role_examples,
        # AVG() yields NULL (None) on an empty table; report 0 instead.
        "avg_usage_count": round(avg_stats[0] or 0, 2),
        "avg_success_rate": round(avg_stats[1] or 0, 2)
    }

# ============================================================================
# STARTUP EVENT
# ============================================================================

@app.on_event("startup")
async def startup_event():
    """Print the startup banner when the application starts.

    The emoji are written as ``\\N{...}`` escapes so the source stays ASCII
    and cannot be corrupted by an encoding round-trip, which is what had
    turned the original banner characters into mojibake.
    """
    print("\N{ROCKET} RAG-Based Human-in-the-Loop SQL Agent System Starting...")
    print("\N{WHITE HEAVY CHECK MARK} Authentication system initialized")
    print("\N{WHITE HEAVY CHECK MARK} Audit logger initialized")
    print("\N{WHITE HEAVY CHECK MARK} RAG vector store initialized")
    print("\N{WHITE HEAVY CHECK MARK} System ready!")

# Script entry point: serve the app with uvicorn when this file is run
# directly rather than imported.
if __name__ == "__main__":
    # Imported here so uvicorn is only required for direct execution.
    import uvicorn
    # host="0.0.0.0" listens on all network interfaces; port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)

ModuleNotFoundError: No module named 'passlib'