Using the `bnl` environment
https://medium.com/@ayush4002gupta/building-an-llm-agent-to-directly-interact-with-a-database-0c0dd96b8196

pip uninstall -y langchain langchain-core langchain-openai langgraph langgraph-prebuilt langgraph-checkpoint langgraph-sdk langsmith langchain-community langchain-google-genai langchain-text-splitters

pip install langchain==1.2.0 langchain-core==1.2.2 langchain-openai==1.1.4 langgraph==1.0.5 langgraph-prebuilt==1.0.5 langgraph-checkpoint==3.0.1 langgraph-sdk==0.3.0 langsmith==0.5.0

In [13]:
# only needed to list the available models
# import google.generativeai as genai


In [14]:
# https://medium.com/@ayush4002gupta/building-an-llm-agent-to-directly-interact-with-a-database-0c0dd96b8196

import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from sql_utils import *
from datetime import datetime, timezone



load_dotenv()  # This looks for the .env file and loads it into os.environ

# Prefer the project-specific API_KEY; fall back to the standard OPENAI_API_KEY.
# Fail with an explicit message instead of a bare KeyError when neither is set.
_api_key = os.environ.get("API_KEY") or os.environ.get("OPENAI_API_KEY")
if not _api_key:
    raise RuntimeError(
        "No OpenAI API key found: set API_KEY (or OPENAI_API_KEY) in the environment or .env file."
    )

llm = ChatOpenAI(
    model="gpt-4o-mini",  # recommended for tools + cost
    api_key=_api_key,
    temperature=0,
    seed=100,  # with temperature=0, keeps planner/classifier output as repeatable as the API allows
)

# Smoke test: one cheap call confirms the key and model work before the batch run.
response = llm.invoke([
    HumanMessage(content="Reply with exactly: OK")
])

print(response.content)

# Per-entity search configuration, keyed by the target_entity names used in
# PII_TARGETS. The planner reads only "regex", which is interpolated into the
# exploration prompt's example query inside single quotes, so a pattern must
# never contain a single quote.
# NOTE(review): "desc" is not read by any code visible in this notebook; confirm
# whether something else uses it.
ENTITY_CONFIG = {
    "EMAIL": {
        "regex": r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
        "desc": "valid electronic mail formats used for account registration or contact"
    },
    # NOTE(review): `[0-9]{1,3}?` is a *lazy* quantifier (still needs >= 1 digit),
    # not an optional group. As written, any run of five or more digits matches
    # (timestamps, row ids, ...) if REGEXP performs a search. Confirm intent.
    "PHONE": {
        "regex": r"\+?[0-9]{1,4}[- .]?\(?[0-9]{1,3}?\)?[- .]?[0-9]{1,4}[- .]?[0-9]{1,4}[- .]?[0-9]{1,9}",
        "desc": "international or local telephone numbers"
    },
    # Very loose: matches almost any word of 3-32 characters that starts with a letter.
    "USERNAME": {
        "regex": r"\b[a-zA-Z][a-zA-Z0-9._-]{2,31}\b",
        "desc": "application-specific usernames, handles, or account identifiers"
    },
    # Very loose by design: a letter followed by at least one letter, whitespace, '.'
    # or '-'. Used only to pre-filter columns; extraction decides what is a name.
    "PERSON_NAME": {
        "regex": r"[A-Za-z][A-Za-z\s\.\-]{1,50}",
        "desc": (
            "loosely structured human name-like strings used only for discovery "
            "and column pre-filtering; final identification is performed during extraction"
        )
    }
}



OK


In [15]:
# Core Python
import sqlite3
import re
import json
from typing import TypedDict, Optional, List, Annotated
from langgraph.graph.message import add_messages

# LangChain / LangGraph
from langchain_core.tools import tool
from langchain_core.messages import (
    HumanMessage,
    AIMessage,
    SystemMessage
)
from langchain.agents import create_agent
from langgraph.graph import StateGraph, END
from langgraph.graph.message import MessagesState


@tool
def list_tables() -> str:
    """
    List non-empty user tables in the SQLite database.
    """
    # The docstring above is the tool description the LLM agent receives;
    # edit it only if you mean to change the agent's prompt.
    #
    # Returns a comma-separated list of table names (alphabetical, internal
    # sqlite_* tables excluded) that contain at least one row. Names that are
    # not plain identifiers are skipped, because the probe query embeds the
    # name in double quotes without escaping.
    plain_identifier = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
    connection = sqlite3.connect(DB_PATH)
    try:
        cursor = connection.cursor()
        cursor.execute("""
            SELECT name
            FROM sqlite_master
            WHERE type='table' AND name NOT LIKE 'sqlite_%'
            ORDER BY name
        """)
        candidates = [record[0] for record in cursor.fetchall()]

        def has_any_row(table_name):
            # A table that cannot be read (sqlite3.Error) counts as empty.
            try:
                cursor.execute(f'SELECT 1 FROM "{table_name}" LIMIT 1;')
                return cursor.fetchone() is not None
            except sqlite3.Error:
                return False

        populated = [
            table_name
            for table_name in candidates
            if plain_identifier.match(table_name) and has_any_row(table_name)
        ]
        return ", ".join(populated)
    finally:
        connection.close()


@tool
def get_schema(table: str) -> str:
    """

    Return column names and types for a table.
    """
    # The docstring above is the tool description the LLM agent receives.
    #
    # Returns "col1 TYPE1, col2 TYPE2, ..." from PRAGMA table_info. The table
    # name comes from the LLM agent, so treat it as untrusted. PRAGMA arguments
    # cannot be bound as parameters, so embed the name as a SQL string literal
    # with single quotes doubled. An unescaped quote would make this tool raise
    # and abort the agent run.
    quoted_table = "'" + table.replace("'", "''") + "'"
    conn = sqlite3.connect(DB_PATH)
    try:
        cur = conn.cursor()
        cur.execute(f"PRAGMA table_info({quoted_table})")
        cols = cur.fetchall()
    finally:
        # Close even if the PRAGMA fails, so connections are not leaked.
        conn.close()
    if not cols:
        # PRAGMA table_info yields no rows for an unknown table. Say so
        # explicitly instead of returning "", so the agent can correct itself.
        return f"No columns found for table {table!r} (it may not exist)."
    return ", ".join(f"{c[1]} {c[2]}" for c in cols)


@tool
def exec_sql(query: str) -> dict:
    """Execute one SQL query against the current database and return its rows.

    The query is first passed through ``normalize_sql`` and then run as a
    single statement: sqlite3's ``Cursor.execute`` does not run multi-statement
    scripts, so nothing is "skipped and continued".

    Returns a dict with:
    - ``rows``: the fetched rows, or ``[]`` if execution failed;
    - ``columns``: column names parsed from every SELECT of a UNION chain,
      de-duplicated in first-seen order (used for provenance);
    - ``error``: the error message if execution failed, else ``None``.

    A ``REGEXP`` SQL function is registered on the connection so queries can
    use ``col REGEXP pattern``.
    """
    query_text = normalize_sql(query)

    # 1. Parse column names from ALL SELECTs so provenance covers every UNION branch.
    column_names = []
    for select_sql in split_union_selects(query_text):
        for col in extract_select_columns(select_sql):
            if col not in column_names:
                column_names.append(col)

    # 2. Execute once. This is best-effort: a failing query is reported and
    # yields no rows, so the graph keeps running. The error is also returned
    # so callers can tell "no matches" apart from "bad SQL".
    rows = []
    error = None
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.create_function("REGEXP", 2, regexp)
        cur = conn.cursor()
        print("[EXECUTE] Running query")
        cur.execute(query_text)
        rows = cur.fetchall()
    except Exception as e:
        print(f"[SQL ERROR]: {e}")
        error = str(e)
    finally:
        conn.close()

    return {
        "rows": rows,
        "columns": column_names,
        "error": error,
    }




class EmailEvidenceState(TypedDict):
    """LangGraph state for one (database, target entity) search run.

    Despite the name, the same state type is used for every entity type
    selected via ``target_entity``, not only EMAIL.
    """
    # Conversation log. The add_messages reducer merges each node's returned
    # messages into the existing list instead of replacing it.
    messages: Annotated[list, add_messages]
    # Exploration plans made so far: planner increments it once per exploration
    # plan, and next_step stops with "stop_limit" once attempt >= max_attempts.
    attempt: int
    max_attempts: int
    phase: str  # "exploration" | "extraction" | "done" (planner's attempt-limit stop)

    # SQL separation
    exploration_sql: Optional[str]  # LLM-written discovery query
    extraction_sql: Optional[str]  # exploration_sql with its LIMIT clauses removed

    rows: Optional[List]  # rows fetched by the most recent execute step
    classification: Optional[dict]  # classifier JSON: found / confidence / reason
    evidence: Optional[List[str]]  # entity values returned by the extract step

    target_entity: str  # key into ENTITY_CONFIG, e.g. "EMAIL"
    source_columns: Optional[List[str]]  # columns parsed from the extraction SQL (provenance)


def get_explore_system(target, regex):
    """Build the system prompt for the SQL-planning (exploration) agent.

    Args:
        target: entity label placed in the goal line, e.g. "EMAIL".
        regex: pattern shown in the example ``REGEXP`` query.

    Returns:
        A SystemMessage whose content is the newline-joined prompt.
    """
    prompt_lines = [
        "You are a SQL planner. You are provided app databases that are extracted from Android or iPhone devices.",
        "apps include Android Whatsapp, Snapchat, Telegram, Google Map, Samsung Internet, iPhone Contacts, Messages, Safari, and Calendar.",
        f"Goal: discover if any column of databases contains {target}.",
        "",  # blank line between the goal and the rules
        "Rules:",
        "- Use 'REGEXP' for pattern matching.",
        f"- Example: SELECT col FROM table WHERE col REGEXP '{regex}' LIMIT 10",
        "- Pay attention to messages, chats, or other text fields.",
        "- Validate your SQL and make sure all tables and columns do exist.",
        "- If multiple SQL statements are provided, combine them using UNION ALL and LIMIT 200.",
        "- Return ONLY SQL.",
    ]
    return SystemMessage(content="\n".join(prompt_lines))

    
def upgrade_sql_remove_limit(sql: str) -> str:
    """Strip every LIMIT clause from ``sql`` so extraction sees all matching rows.

    Handles the three SQLite forms: ``LIMIT n``, ``LIMIT n OFFSET m`` and
    ``LIMIT m, n``. Removing only ``LIMIT n`` from the latter two would leave an
    orphan ``OFFSET m`` / ``, n`` and produce invalid SQL. Whitespace left
    behind is then collapsed and the result is stripped.

    Limitation: a literal "LIMIT <digits>" inside a quoted SQL string would
    also be removed.
    """
    limit_clause = re.compile(
        r"\bLIMIT\s+\d+(?:\s*,\s*\d+|\s+OFFSET\s+\d+)?\b",
        re.IGNORECASE,
    )
    upgraded = limit_clause.sub("", sql)
    # Clean up extra whitespace left where clauses were removed
    upgraded = re.sub(r"\s+\n", "\n", upgraded)
    upgraded = re.sub(r"\n\s+\n", "\n", upgraded)
    upgraded = re.sub(r"\s{2,}", " ", upgraded).strip()
    return upgraded



def planner(state: EmailEvidenceState):
    """Decide which SQL the next execute step should run.

    Paths, checked in order:

    1. Extraction phase with an exploration query available: reuse that query
       with its LIMIT clauses removed (no LLM call) and store it as
       ``extraction_sql``.
    2. Exploration phase with ``attempt >= max_attempts``: set phase to
       ``"done"`` and emit a STOP message.
    3. Otherwise: give a tool-using agent (list_tables, get_schema) a system
       prompt grounded with the non-empty table list, plus the original user
       request. Store its normalized answer as ``exploration_sql`` and bump
       ``attempt`` when in the exploration phase.
    """
    current_phase = state["phase"]

    # 1. Extraction: widen the exploration query instead of asking the LLM again.
    if current_phase == "extraction" and state.get("exploration_sql"):
        widened_sql = upgrade_sql_remove_limit(state["exploration_sql"])
        return {
            "messages": [AIMessage(content=widened_sql)],
            "extraction_sql": widened_sql,
        }

    # 2. Safety stop (next_step normally enforces the limit before replanning).
    if current_phase == "exploration" and state.get("attempt", 0) >= state.get("max_attempts", 0):
        return {
            "phase": "done",
            "messages": [AIMessage(content="STOP: max attempts reached in planner.")],
        }

    # 3. Discovery: ask the agent for an exploration query.
    table_listing = list_tables.invoke({})
    entity_regex = ENTITY_CONFIG[state["target_entity"]]["regex"]
    explore_prompt = get_explore_system(state["target_entity"], entity_regex)

    grounded_prompt = "".join([
        f"{explore_prompt.content}\n\n",
        f"EXISTING TABLES: {table_listing}\n",
        f"CURRENT PHASE: {current_phase}\n",
        "CRITICAL: Do not query non-existent tables.",
    ])

    sql_agent = create_agent(llm, [list_tables, get_schema])
    agent_output = sql_agent.invoke({
        "messages": [
            SystemMessage(content=grounded_prompt),
            state["messages"][0],  # the original user request only
        ]
    })
    proposed_sql = normalize_sql(agent_output["messages"][-1].content)

    # Only exploration rounds count against the attempt budget.
    next_attempt = state["attempt"]
    if current_phase == "exploration":
        next_attempt += 1

    return {
        "messages": [AIMessage(content=proposed_sql)],
        "exploration_sql": proposed_sql,
        "attempt": next_attempt,
    }

def sql_execute(state: EmailEvidenceState):
    """Run the SQL for the current phase and store the fetched rows.

    The extraction phase runs ``extraction_sql``; every other phase runs
    ``exploration_sql``. During extraction, the column names reported by
    ``exec_sql`` are also saved as ``source_columns`` for provenance.
    """
    in_extraction = state["phase"] == "extraction"
    sql_key = "extraction_sql" if in_extraction else "exploration_sql"
    sql_to_run = state.get(sql_key)

    # Guard: nothing planned for this phase.
    if not sql_to_run:
        print("[SQL EXEC] No SQL provided for this phase")
        return {
            "rows": [],
            "messages": [AIMessage(content="No SQL to execute")],
        }

    outcome = exec_sql.invoke(sql_to_run)
    fetched = outcome.get("rows", [])
    columns = outcome.get("columns", [])

    print(f"[SQL EXEC] Retrieved {len(fetched)} rows")

    updates = {
        "rows": fetched,
        "messages": [AIMessage(content=f"Retrieved {len(fetched)} rows")],
    }

    if in_extraction:
        updates["source_columns"] = columns
        print(f"[TRACKING] Saved source columns: {columns}")

    return updates

    

def get_classify_system(target: str):
    """Build the system prompt for the found / not-found classifier.

    next_step compares ``confidence`` against 0.6, so the prompt pins the scale
    to 0..1. Without that, the model may answer on a 0..100 scale, and a
    confident "not found" (e.g. 0.3 meant as 30%) would be misrouted.
    """
    return SystemMessage(
        content=(
            f"Decide whether the text contains {target}.\n"
            "Return ONLY a JSON object with these keys:\n"
            "{ \"found\": true/false, \"confidence\": number, \"reason\": \"string\" }\n"
            "\"confidence\" must be a number between 0 and 1."
        )
    )

def _normalize_decision(decision, default: dict) -> dict:
    """Coerce raw classifier output into a dict with bool ``found`` and float ``confidence``."""
    if not isinstance(decision, dict):
        decision = dict(default)
    found = decision.get("found", False)
    if isinstance(found, str):
        # bool("false") is True, so parse string booleans explicitly.
        found = found.strip().lower() == "true"
    try:
        confidence = float(decision.get("confidence", 0.0))
    except (TypeError, ValueError):
        confidence = 0.0
    return {**decision, "found": bool(found), "confidence": confidence}


def classify(state: EmailEvidenceState):
    """Ask the LLM whether the fetched rows contain the target entity.

    Returns ``{"classification": {...}}`` with at least bool ``found`` and
    float ``confidence``. Malformed JSON falls back to a "parse failure"
    decision. In the extraction phase, next_step routes straight to extract
    without reading the classification, so the LLM call is skipped and the
    previous classification is kept.
    """
    if state["phase"] == "extraction":
        return {}

    # 1. Prepare the text sample for the LLM
    text = rows_to_text(state["rows"], limit=15)

    # 2. Get the target-specific system message
    system_message = get_classify_system(state["target_entity"])

    # 3. Invoke the LLM
    result = llm.invoke([
        system_message,
        HumanMessage(content=f"Data to analyze:\n{text}")
    ]).content

    # 4. Parse the decision and normalize types, so next_step's comparisons
    # cannot fail on strings or missing keys.
    default = {"found": False, "confidence": 0.0, "reason": "parse failure"}
    decision = safe_json_loads(result, default=default)

    return {"classification": _normalize_decision(decision, default)}


def switch_to_extraction(state: EmailEvidenceState):
    """Move the run from discovery to extraction; the planner then widens the SQL."""
    # The log arrow was mojibake ("‚Üí", a mis-decoded "→"). Use plain ASCII,
    # consistent with the "->" used by observe.
    print("[PHASE] discovery -> extraction")
    return {"phase": "extraction"}


def _clean_evidence(parsed) -> list:
    """Keep string/int items from LLM output, stripped, non-empty, de-duplicated in order."""
    if not isinstance(parsed, list):
        return []
    seen = set()
    cleaned = []
    for item in parsed:
        # bool is a subclass of int; it is never a valid entity value.
        if isinstance(item, bool) or not isinstance(item, (str, int)):
            continue
        value = str(item).strip()
        if value and value not in seen:
            seen.add(value)
            cleaned.append(value)
    return cleaned


def extract(state: EmailEvidenceState):
    """Ask the LLM to pull normalized target-entity values out of the fetched rows.

    Returns ``{"evidence": [...]}``, a list of unique non-empty strings. With
    no rows, returns an empty list without calling the LLM, which could
    otherwise invent values.
    """
    rows = state.get("rows") or []
    if not rows:
        return {"evidence": []}

    text = rows_to_text(rows)
    system = f"Identify real {state['target_entity']} from text and normalize them. Return ONLY a JSON array of strings.\n"
    result = llm.invoke([SystemMessage(content=system), HumanMessage(content=text)]).content
    return {"evidence": _clean_evidence(safe_json_loads(result, default=[]))}


def next_step(state: EmailEvidenceState):
    """Route after classification.

    Returns one of:
    - "do_extract": already in the extraction phase; extract and stop.
    - "to_extraction": found with confidence >= 0.6.
    - "stop_none": confidently (>= 0.6) not found.
    - "stop_limit": inconclusive and attempt >= max_attempts.
    - "replan": inconclusive; plan another exploration query.

    A missing or malformed classification is treated as found=False with
    confidence 0.0, i.e. inconclusive, instead of raising.
    """
    # Once in extraction phase, extract and stop
    if state["phase"] == "extraction":
        return "do_extract"

    threshold = 0.6
    decision = state.get("classification")
    if not isinstance(decision, dict):
        decision = {}

    found = decision.get("found", False)
    if isinstance(found, str):
        # bool("false") is True, so parse string booleans explicitly.
        found = found.strip().lower() == "true"
    try:
        confidence = float(decision.get("confidence", 0.0))
    except (TypeError, ValueError):
        confidence = 0.0

    if found and confidence >= threshold:
        return "to_extraction"

    if not found and confidence >= threshold:
        return "stop_none"

    if state["attempt"] >= state["max_attempts"]:
        return "stop_limit"

    return "replan"

In [16]:
def observe(state: EmailEvidenceState):
    """
    Debug / inspection node: print a snapshot of the current graph state.

    Does NOT modify state (returns an empty update). Rows and evidence are
    truncated to ``sample_size`` items, so unlimited extraction results do not
    flood the notebook output. Note that the printed values may contain the
    very PII being searched for.
    """
    sample_size = 10

    print("\n=== STATE SNAPSHOT ===")

    # Messages
    print("\n--- MESSAGES ---")
    for i, m in enumerate(state["messages"]):
        print(f"{i}: {m.type.upper()} -> {m.content}")

    # Metadata
    print("\n--- BEGIN METADATA ---")
    print(f"attempt         : {state['attempt']}")
    print(f"max_attempts    : {state['max_attempts']}")
    print(f"phase           : {state['phase']}")
    print(f"target_entity   : {state.get('target_entity')}")

    # SQL separation
    print(f"exploration_sql : {state.get('exploration_sql')}")
    print(f"extraction_sql  : {state.get('extraction_sql')}")

    # Outputs (truncated samples only)
    rows = state.get("rows") or []
    print(f"rows_count      : {len(rows)}")
    print(f"rows_sample     : {rows[:sample_size]}")

    evidence = state.get("evidence") or []
    print(f"classification  : {state.get('classification')}")
    print(f"evidence_count  : {len(evidence)}")
    print(f"evidence_sample : {evidence[:sample_size]}")

    print(f"source_columns  : {state.get('source_columns')}")
    print("\n--- END METADATA ---")

    # IMPORTANT: do not return state, return no-op update
    return {}



from langgraph.graph import StateGraph, END

# Build the discovery -> extraction workflow over the shared state type.
graph = StateGraph(EmailEvidenceState)

# Nodes. Every observe_* node runs `observe`, which only prints the state and
# returns an empty update, so these are read-only debug checkpoints.
graph.add_node("planner", planner)
graph.add_node("observe_plan", observe)         # Checkpoint 1: the SQL plan
graph.add_node("execute", sql_execute)
graph.add_node("observe_execution", observe)    # Checkpoint 2: post-execution (rows fetched)
graph.add_node("classify", classify)
graph.add_node("observe_classify", observe)     # Checkpoint 3: post-classify
graph.add_node("switch_phase", switch_to_extraction)
graph.add_node("extract", extract)
graph.add_node("observe_final", observe)        # Checkpoint 4: final results

graph.set_entry_point("planner")

# --- FLOW ---
# planner -> observe_plan -> execute -> observe_execution -> classify -> observe_classify
graph.add_edge("planner", "observe_plan")
graph.add_edge("observe_plan", "execute")

# Observe after execution, before classify
graph.add_edge("execute", "observe_execution")
graph.add_edge("observe_execution", "classify")

graph.add_edge("classify", "observe_classify")

# Routing after classification (decided by next_step):
# - found during exploration -> switch_phase -> planner builds the extraction SQL,
#   and the plan/execute/classify chain runs again in the extraction phase;
# - in the extraction phase -> extract -> observe_final -> END;
# - confidently not found, or attempt limit reached -> END;
# - otherwise -> planner for another exploration attempt.
graph.add_conditional_edges(
    "observe_classify",
    next_step,
    {
        "to_extraction": "switch_phase",
        "do_extract": "extract",
        "replan": "planner",
        "stop_none": END,
        "stop_limit": END,
    }
)

graph.add_edge("switch_phase", "planner")

graph.add_edge("extract", "observe_final")
graph.add_edge("observe_final", END)

# Compiled runnable; invoked once per (database, target) in the batch cell.
app = graph.compile()


In [17]:
from pathlib import Path
import json

# Folder searched recursively for SQLite databases (relative to the notebook's
# working directory).
DB_DIR = Path(r"selectedDBs_test")   # change if needed
# DB_DIR = Path(r"Agent_Evidence_Discovery\selectedDBs")   # change if needed
# Batch results are written here as a timestamped JSONL file.
OUT_DIR = Path("batch_results")
OUT_DIR.mkdir(exist_ok=True)

# Entity types to search for; each must be a key of ENTITY_CONFIG (the planner indexes it).
PII_TARGETS = ["EMAIL", "PHONE", "USERNAME", "PERSON_NAME"]  # pick what you need
# PII_TARGETS = ["PHONE"]  # pick what you need

def is_sqlite_file(p: Path) -> bool:
    """Return True if ``p`` starts with the 16-byte SQLite 3 file header.

    Any error while opening or reading (missing file, directory, permissions)
    is treated as "not SQLite", so odd files are skipped rather than crashing
    the scan.
    """
    sqlite_magic = b"SQLite format 3\x00"
    try:
        with p.open("rb") as handle:
            header = handle.read(len(sqlite_magic))
    except Exception:
        return False
    return header == sqlite_magic

# Every file under DB_DIR with a SQLite-like extension AND a real SQLite header,
# in sorted path order.
db_paths = sorted(
    filter(
        lambda path: path.suffix.lower() in {".db", ".sqlite", ".sqlite3"} and is_sqlite_file(path),
        DB_DIR.rglob("*"),
    )
)

print(f"Found {len(db_paths)} sqlite files")

Found 1 sqlite files


In [18]:
all_results = []

for p in db_paths:
    DB_PATH = str(p)  # updates the global read by the @tool functions and exec_sql
    print(f"\nProcessing: {DB_PATH}")

    for target in PII_TARGETS:
        # Isolate each (db, target) run: one API error or unexpected LLM output
        # must not abort the batch and discard every result collected so far.
        error = None
        try:
            result = app.invoke({
                "messages": [HumanMessage(content=f"Find {target} in the database")],
                # NOTE(review): planner increments attempt once per exploration
                # plan, and next_step stops once attempt >= max_attempts. Starting
                # at 1 with max_attempts=2 therefore allows a single exploration
                # round (no replan). Start at 0 if two rounds are intended.
                "attempt": 1,
                "max_attempts": 2,
                "phase": "exploration",
                "target_entity": target,
                "exploration_sql": None,
                "extraction_sql": None,
                "rows": None,
                "classification": None,
                "evidence": [],
                "source_columns": []
            })
        except Exception as exc:
            error = f"{type(exc).__name__}: {exc}"
            print(f"[ERROR] {DB_PATH} / {target}: {error}")
            result = {}

        evidence = result.get("evidence") or []
        source_columns = result.get("source_columns") or []
        record = {
            "db_path": DB_PATH,
            "PII_type": target,
            "PII": evidence,
            "Num_of_PII": len(evidence),
            "source_columns": source_columns,
            "Num_of_source_columns": len(source_columns)
        }
        if error is not None:
            record["error"] = error  # only present for failed runs
        all_results.append(record)

# Save one JSONL for easy grep/filter later
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
out_path = OUT_DIR / f"evidence_{ts}.jsonl"

with out_path.open("w", encoding="utf-8") as f:
    for r in all_results:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

print(f"Wrote: {out_path.resolve()}")



Processing: selectedDBs_test\A1_msgstore.db

=== STATE SNAPSHOT ===

--- MESSAGES ---
0: HUMAN -> Find EMAIL in the database
1: AI -> SELECT text_data FROM message WHERE text_data REGEXP '[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
UNION ALL
SELECT description FROM message_text WHERE description REGEXP '[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
UNION ALL
SELECT vcard FROM message_vcard WHERE vcard REGEXP '[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
LIMIT 200;

--- BEGIN METADATA ---
attempt         : 2
max_attempts    : 2
phase           : exploration
target_entity   : EMAIL
exploration_sql : SELECT text_data FROM message WHERE text_data REGEXP '[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
UNION ALL
SELECT description FROM message_text WHERE description REGEXP '[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
UNION ALL
SELECT vcard FROM message_vcard WHERE vcard REGEXP '[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
LIMIT 200;
extraction_sql  : None
rows_count      : 