Using the bnl environment
https://medium.com/@ayush4002gupta/building-an-llm-agent-to-directly-interact-with-a-database-0c0dd96b8196

pip uninstall -y langchain langchain-core langchain-openai langgraph langgraph-prebuilt langgraph-checkpoint langgraph-sdk langsmith langchain-community langchain-google-genai langchain-text-splitters

pip install langchain==1.2.0 langchain-core==1.2.2 langchain-openai==1.1.4 langgraph==1.0.5 langgraph-prebuilt==1.0.5 langgraph-checkpoint==3.0.1 langgraph-sdk==0.3.0 langsmith==0.5.0

In [1]:
# only needed for finding available models
# import google.generativeai as genai


In [2]:
# https://medium.com/@ayush4002gupta/building-an-llm-agent-to-directly-interact-with-a-database-0c0dd96b8196

import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from sql_utils import *


load_dotenv()  # This looks for the .env file and loads it into os.environ

# Chat model shared by every LLM-backed node below.
llm = ChatOpenAI(
    model="gpt-4o-mini",  # recommended for tools + cost
    api_key=os.environ["API_KEY"],  # raises KeyError if API_KEY is not set
    temperature=0  # minimise sampling randomness for reproducible SQL plans
)

# Smoke test: confirms the key and model work before the graph is built.
response = llm.invoke([
    HumanMessage(content="Reply with exactly: OK")
])

print(response.content)

# SQLite database under analysis. The relative path is built with
# os.path.join so it resolves on Windows and POSIX alike (a literal "\"
# separator is only a path separator on Windows).
# DB_PATH = r"msgstore.db"
# DB_PATH = r"users4.db"
DB_PATH = os.path.join("Agent_Evidence_Discovery_Github", "selectedDBs", "ChatStorage.sqlite")
# DB_PATH = r"F:\mobile_images\Cellebriate_2024\Cellebrite_CTF_File1\CellebriteCTF24_Sharon\Sharon\EXTRACTION_FFS 01\EXTRACTION_FFS\Dump\data\data\com.whatsapp\databases\stickers.db"
# DB_PATH = r"F:\mobile_images\Cellebriate_2024\Cellebrite_CTF_File1\CellebriteCTF24_Sharon\Sharon\EXTRACTION_FFS 01\EXTRACTION_FFS\Dump\data\data\com.android.vending\databases\localappstate.db"

# Per-entity search configuration, keyed by the TARGET / target_entity name.
#   regex: pattern the discovery prompt tells the planner to use with REGEXP.
#   desc : human-readable description of the entity (not referenced by the
#          code visible in this notebook).
ENTITY_CONFIG = {
    "EMAIL": {
        "regex": r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
        "desc": "valid electronic mail formats used for account registration or contact"
    },
    "PHONE": {
        "regex": r"\+?[0-9]{1,4}[- .]?\(?[0-9]{1,3}?\)?[- .]?[0-9]{1,4}[- .]?[0-9]{1,4}[- .]?[0-9]{1,9}",
        "desc": "international or local telephone numbers"
    },
    "USERNAME": {
        "regex": r"\b[a-zA-Z][a-zA-Z0-9._-]{2,31}\b",
        "desc": "application-specific usernames, handles, or account identifiers"
    },
    "PERSON_NAME": {
        # Deliberately loose: matches most alphabetic text; the LLM narrows it later.
        "regex": r"[A-Za-z][A-Za-z\s\.\-]{1,50}",
        "desc": (
            "loosely structured human name-like strings used only for discovery "
            "and column pre-filtering; final identification is performed during extraction"
        )
    }
}



OK


In [3]:
# Core Python
import sqlite3
import re
import json
from typing import TypedDict, Optional, List, Annotated
from langgraph.graph.message import add_messages

# LangChain / LangGraph
from langchain_core.tools import tool
from langchain_core.messages import (
    HumanMessage,
    AIMessage,
    SystemMessage
)
from langchain.agents import create_agent
from langgraph.graph import StateGraph, END
from langgraph.graph.message import MessagesState


@tool
def list_tables() -> str:
    """
    List all table names in the SQLite database.
    """
    # The docstring above is the tool description the agent sees; keep it short.
    # Returns a comma-separated string of table names from sqlite_master.
    conn = sqlite3.connect(DB_PATH)
    try:
        rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
    finally:
        # Close even if the query fails (e.g. the file is not a SQLite database).
        conn.close()
    return ", ".join(r[0] for r in rows)


@tool
def get_schema(table: str) -> str:
    """
    Return column names and types for a table.
    """
    # The docstring above is the tool description the agent sees.
    # `table` is chosen by the LLM agent, so treat it as untrusted. PRAGMA does
    # not accept bound parameters, so double any single quotes to keep the
    # string literal well-formed instead of interpolating the raw name.
    safe_table = table.replace("'", "''")
    conn = sqlite3.connect(DB_PATH)
    try:
        cols = conn.execute(f"PRAGMA table_info('{safe_table}')").fetchall()
    finally:
        conn.close()
    # table_info rows are (cid, name, type, notnull, dflt_value, pk).
    return ", ".join(f"{c[1]} {c[2]}" for c in cols)




@tool
def exec_sql(query: str) -> dict:
    """Execute one SQL query against DB_PATH and return its rows and column names.

    Column names are parsed from every SELECT of a UNION chain. If the query
    fails, the error is printed, "rows" is empty and "error" holds the message
    ("error" is None on success).
    """
    query_text = normalize_sql(query)

    # 1. Parse column names from ALL SELECTs (first-seen order, no duplicates)
    column_names = []
    for select_sql in split_union_selects(query_text):
        for col in extract_select_columns(select_sql):
            if col not in column_names:
                column_names.append(col)

    # 2. Execute once
    rows = []
    error = None
    conn = sqlite3.connect(DB_PATH)
    try:
        # SQLite defines no regexp() function by default; register the helper
        # so "col REGEXP pattern" works. Done inside try so the connection is
        # closed even if registration fails.
        conn.create_function("REGEXP", 2, regexp)
        print("[EXECUTE] Running query")
        rows = conn.execute(query_text).fetchall()
    except Exception as e:  # best-effort: callers treat a failed query as "no rows"
        print(f"[SQL ERROR]: {e}")
        error = str(e)
    finally:
        conn.close()

    return {
        "rows": rows,
        "columns": column_names,
        "error": error,
    }




class EmailEvidenceState(TypedDict):
    """Shared LangGraph state for the evidence-discovery workflow.

    Despite the "Email" prefix, the same schema is used for every entity type
    in ENTITY_CONFIG; the entity being searched for is held in target_entity.
    """
    # Running message log; the add_messages reducer merges each node's returned
    # messages into this list instead of replacing it.
    messages: Annotated[list, add_messages]
    # Planning attempt counter; updated by planner and compared against
    # max_attempts in next_step (which then routes to "stop_limit").
    attempt: int
    max_attempts: int
    phase: str  # "exploration" | "extraction"

    # SQL separation
    # exploration_sql: discovery query from the planner agent (prompted to use LIMIT 15).
    # extraction_sql: rewrite of exploration_sql the LLM is prompted to strip of LIMITs.
    exploration_sql: Optional[str]
    extraction_sql: Optional[str]

    # Rows returned by the most recent query run in sql_execute.
    rows: Optional[List]
    # classify's parsed JSON verdict; next_step reads its "found" and "confidence" keys.
    classification: Optional[dict]
    # Values produced by the extract node.
    evidence: Optional[List[str]]

    # Key into ENTITY_CONFIG, e.g. "EMAIL" or "PERSON_NAME".
    target_entity: str
    # Column names parsed from the executed SQL; sql_execute sets this only in the extraction phase.
    source_columns: Optional[List[str]]


def get_discovery_system(target, regex):
    """Return the system prompt asking the planner agent for a discovery query.

    target -- entity label interpolated into the goal line (e.g. "EMAIL").
    regex  -- pattern shown in the example REGEXP query.
    """
    intro = (
        "You are a SQL planner. You are provided databases that are extracted "
        "from Android or iOS devices."
    )
    goal = f"Goal: discover if any column contains {target}."
    rules = [
        "- Use 'REGEXP' for pattern matching.",
        f"- Example: SELECT col FROM table WHERE col REGEXP '{regex}' LIMIT 15",
        "- Validate your SQL and make sure all tables and columns do exist.",
        "- If multiple SQL statements are provided, combine them using UNION ALL and LIMIT 15.",
        "- Return ONLY SQL.",
    ]
    # The empty entry yields the blank line between the goal and the rules.
    prompt = "\n".join([intro, goal, "", "Rules:", *rules])
    return SystemMessage(content=prompt)

    
def get_sql_upgrade_system(target):
    """Return the system prompt that turns exploration SQL into extraction SQL.

    The model is told to keep the existing tables, columns and filters and
    only drop LIMIT clauses, so the query returns every match for `target`.
    """
    prompt_lines = (
        "You are a SQL refiner.",
        f"Goal: modify previously successful SQL to extract ALL {target}.",
        "",
        "Rules:",
        "- Do NOT invent new tables or columns.",
        "- Remove LIMIT clauses.",
        "- Preserve WHERE conditions and REGEXP filters.",
        "- Return ONLY SQL.",
    )
    return SystemMessage(content="\n".join(prompt_lines))


def _plan_extraction_upgrade(state: EmailEvidenceState):
    # Rewrite the successful exploration SQL into an extraction query via the LLM.
    original_sql = state["exploration_sql"]

    result = llm.invoke([
        get_sql_upgrade_system(state["target_entity"]),
        HumanMessage(content=f"Original SQL:\n{original_sql}")
    ])

    extraction_sql = normalize_sql(result.content)

    print("[PLANNER] Upgraded SQL for extraction")

    return {
        # Log the SQL that will actually run in the extraction phase.
        "messages": [AIMessage(content=extraction_sql)],
        "extraction_sql": extraction_sql
    }


def _plan_discovery(state: EmailEvidenceState):
    # Let a tool-using agent draft a discovery query grounded in the real tables.
    tables = list_tables.invoke({})
    config = ENTITY_CONFIG[state["target_entity"]]

    base_system = get_discovery_system(
        state["target_entity"],
        config["regex"]
    )

    grounded_content = (
        f"{base_system.content}\n\n"
        f"EXISTING TABLES: {tables}\n"
        f"CURRENT PHASE: {state['phase']}\n"
        "CRITICAL: Do not query non-existent tables."
    )

    agent = create_agent(llm, [list_tables, get_schema])

    result = agent.invoke({
        "messages": [
            SystemMessage(content=grounded_content),
            state["messages"][0]  # original user request only
        ]
    })

    exploration_sql = normalize_sql(result["messages"][-1].content)

    # `attempt` numbers the exploration plans. The first plan (no previous
    # exploration_sql) keeps the current value; each exploration re-plan bumps
    # it so next_step's max_attempts limit is actually reached.
    is_replan = (
        state["phase"] == "exploration"
        and state.get("exploration_sql") is not None
    )
    attempt = state["attempt"] + 1 if is_replan else state["attempt"]

    return {
        "messages": [AIMessage(content=exploration_sql)],
        "exploration_sql": exploration_sql,
        "attempt": attempt
    }


def planner(state: EmailEvidenceState):
    """Produce the SQL for the current phase.

    Extraction phase (with a prior exploration query): ask the LLM to upgrade
    that query for full extraction and store it in extraction_sql.
    Otherwise: run the discovery agent and store its query in exploration_sql.
    """
    if state["phase"] == "extraction" and state.get("exploration_sql"):
        return _plan_extraction_upgrade(state)
    return _plan_discovery(state)

def sql_execute(state: EmailEvidenceState):
    """Run the SQL belonging to the current phase and store the returned rows.

    In the extraction phase the parsed column names are also saved as
    source_columns for provenance.
    """
    in_extraction = state["phase"] == "extraction"
    sql_key = "extraction_sql" if in_extraction else "exploration_sql"
    sql_to_run = state.get(sql_key)

    # Guard: nothing planned for this phase.
    if not sql_to_run:
        print("[SQL EXEC] No SQL provided for this phase")
        return {"rows": [], "messages": [AIMessage(content="No SQL to execute")]}

    outcome = exec_sql.invoke(sql_to_run)
    fetched = outcome.get("rows", [])
    columns = outcome.get("columns", [])

    print(f"[SQL EXEC] Retrieved {len(fetched)} rows")

    updates = {
        "rows": fetched,
        "messages": [AIMessage(content=f"Retrieved {len(fetched)} rows")],
    }

    if in_extraction:
        updates["source_columns"] = columns
        print(f"[TRACKING] Saved source columns: {columns}")

    return updates

    

def get_classify_system(target: str):
    """Build the system prompt used by classify.

    The reply must be a JSON object with "found", "confidence" and "reason".
    next_step compares "confidence" against 0.6, so the prompt pins it to the
    0-1 scale; without that the model may answer on a 0-100 scale.
    """
    return SystemMessage(
        content=(
            f"Decide whether the text contains {target}.\n"
            "Return ONLY a JSON object with these keys:\n"
            "{ \"found\": true/false, \"confidence\": number, \"reason\": \"string\" }\n"
            "\"confidence\" must be a number between 0 and 1."
        )
    )

def classify(state: EmailEvidenceState):
    """Ask the LLM whether the sampled rows contain the target entity.

    Returns {"classification": decision}, where decision always has "found"
    (bool), "confidence" (float) and "reason" keys; next_step routes on the
    first two.
    """
    # 1. Prepare the text sample for the LLM (limit=15 presumably caps the
    #    number of rows rendered -- see rows_to_text in sql_utils)
    text = rows_to_text(state["rows"], limit=15)

    # 2. Get the target-specific system message
    system_message = get_classify_system(state["target_entity"])

    # 3. Invoke the LLM
    result = llm.invoke([
        system_message,
        HumanMessage(content=f"Data to analyze:\n{text}")
    ]).content

    # 4. Parse the decision
    fallback = {"found": False, "confidence": 0.0, "reason": "parse failure"}
    decision = safe_json_loads(result, default=fallback)

    # 5. Normalise: valid JSON can still be a non-object, miss keys, or carry
    #    strings such as "true" / "0.8", and next_step indexes and compares them.
    if not isinstance(decision, dict):
        decision = dict(fallback)
    found = decision.get("found", False)
    if isinstance(found, str):
        found = found.strip().lower() == "true"
    try:
        confidence = float(decision.get("confidence", 0.0))
    except (TypeError, ValueError):
        confidence = 0.0
    decision["found"] = bool(found)
    decision["confidence"] = confidence
    decision.setdefault("reason", "")

    return {"classification": decision}


def switch_to_extraction(state: EmailEvidenceState):
    """Move the workflow from the exploration phase to the extraction phase."""
    # Log uses the same phase names that are stored in state["phase"].
    print("[PHASE] exploration → extraction")
    return {"phase": "extraction"}




def extract(state: EmailEvidenceState):
    """Ask the LLM to extract and normalise target_entity values from the rows.

    Returns {"evidence": [...]}: non-empty strings, de-duplicated in first-seen
    order. The list is empty when there are no rows or the reply is not a JSON
    array.
    """
    rows = state.get("rows") or []
    if not rows:
        # Nothing to extract from; skip the LLM call.
        return {"evidence": []}

    text = rows_to_text(rows)
    # Preview the tail of the text sent to the model.
    print(f"Check last 100 characters : {text[-100:]}")
    system = f"Extract and normalize {state['target_entity']} from text. Return ONLY a JSON array of strings."
    result = llm.invoke([SystemMessage(content=system), HumanMessage(content=text)]).content

    parsed = safe_json_loads(result, default=[])
    if not isinstance(parsed, list):
        return {"evidence": []}
    # Stringify so downstream sorting never compares mixed types; drop blanks/dupes.
    cleaned = (str(item).strip() for item in parsed if item is not None)
    return {"evidence": list(dict.fromkeys(item for item in cleaned if item))}


def next_step(state: EmailEvidenceState):
    """Choose the edge to follow after the observe_classify checkpoint.

    Returns one of "do_extract", "to_extraction", "stop_none", "stop_limit"
    or "replan". A classification that is missing, lacks keys, or holds string
    values is tolerated: missing values count as not found / zero confidence.
    """
    # Once in extraction phase, extract and stop
    if state["phase"] == "extraction":
        return "do_extract"

    c = state.get("classification") or {}

    found = c.get("found", False)
    if isinstance(found, str):
        # "false" is a truthy string; interpret it explicitly.
        found = found.strip().lower() == "true"
    try:
        confidence = float(c.get("confidence", 0.0))
    except (TypeError, ValueError):
        confidence = 0.0

    confident = confidence >= 0.6
    if found and confident:
        return "to_extraction"
    if not found and confident:
        return "stop_none"

    if state["attempt"] >= state["max_attempts"]:
        return "stop_limit"

    return "replan"

In [4]:
def observe(state: EmailEvidenceState):
    """
    Debug / inspection node.
    Does NOT modify state.

    Prints every message plus the main state fields. Row and evidence
    listings are truncated to the same short sample so large result sets
    do not flood the log.
    """
    sample_size = 10  # shared cap for rows_sample and evidence_sample

    print("\n=== STATE SNAPSHOT ===")

    # Messages
    print("\n--- MESSAGES ---")
    for i, m in enumerate(state["messages"]):
        print(f"{i}: {m.type.upper()} -> {m.content}")

    # Metadata
    print("\n--- BEGIN METADATA ---")
    print(f"attempt         : {state['attempt']}")
    print(f"max_attempts    : {state['max_attempts']}")
    print(f"phase           : {state['phase']}")
    print(f"target_entity   : {state.get('target_entity')}")

    # SQL separation
    print(f"exploration_sql : {state.get('exploration_sql')}")
    print(f"extraction_sql  : {state.get('extraction_sql')}")

    # Outputs
    rows = state.get("rows") or []
    print(f"rows_count      : {len(rows)}")
    print(f"rows_sample     : {rows[:sample_size]}")  # small sample to avoid huge logs

    print(f"classification  : {state.get('classification')}")
    evidence = state.get("evidence") or []
    print(f"evidence_count  : {len(evidence)}")
    print(f"evidence_sample : {evidence[:sample_size]}")

    print(f"source_columns  : {state.get('source_columns')}")
    print("\n--- END METADATA ---")

    # IMPORTANT: do not return state, return no-op update
    return {}



from langgraph.graph import StateGraph, END

graph = StateGraph(EmailEvidenceState)

# Nodes. The observe_* entries are read-only debug checkpoints that all reuse `observe`.
for _node_name, _node_fn in (
    ("planner", planner),
    ("observe_plan", observe),          # Checkpoint 1: the SQL plan
    ("execute", sql_execute),
    ("observe_execution", observe),     # Post-execution checkpoint
    ("classify", classify),
    ("observe_classify", observe),      # Checkpoint 2: post-classify
    ("switch_phase", switch_to_extraction),
    ("extract", extract),
    ("observe_final", observe),         # Checkpoint 3: final results
):
    graph.add_node(_node_name, _node_fn)

graph.set_entry_point("planner")

# Linear flow: plan -> inspect -> execute -> inspect -> classify -> inspect
for _src, _dst in (
    ("planner", "observe_plan"),
    ("observe_plan", "execute"),
    ("execute", "observe_execution"),
    ("observe_execution", "classify"),
    ("classify", "observe_classify"),
):
    graph.add_edge(_src, _dst)

# Routing decided by next_step after the post-classify checkpoint.
graph.add_conditional_edges(
    "observe_classify",
    next_step,
    {
        "to_extraction": "switch_phase",
        "do_extract": "extract",
        "replan": "planner",
        "stop_none": END,
        "stop_limit": END,
    }
)

# Extraction loops back through the planner; extraction output ends the run.
for _src, _dst in (
    ("switch_phase", "planner"),
    ("extract", "observe_final"),
    ("observe_final", END),
):
    graph.add_edge(_src, _dst)

del _node_name, _node_fn, _src, _dst

app = graph.compile()


In [5]:

# Set your target here once (must be a key of ENTITY_CONFIG)
# TARGET = "EMAIL" 
# TARGET = "PHONE"
# TARGET = "USERNAME"
TARGET = "PERSON_NAME"

# Fail fast with a clear message instead of a KeyError deep inside planner.
if TARGET not in ENTITY_CONFIG:
    raise ValueError(f"Unknown TARGET {TARGET!r}; expected one of {sorted(ENTITY_CONFIG)}")

result = app.invoke({
    "messages": [HumanMessage(content=f"Find {TARGET} in the database")],
    "attempt": 1,
    "max_attempts": 3,
    "phase": "exploration",
    "target_entity": TARGET,          # tells the planner what to look for

    # SQL separation
    "exploration_sql": None,
    "extraction_sql": None,

    # Runtime outputs
    "rows": None,
    "classification": None,
    "evidence": [],

    # Provenance
    "source_columns": []
})

# Use the generic 'evidence' key we defined in the state
final_evidence = result.get("evidence") or []
target_label = result.get("target_entity", "items")
# The report promises unique items: de-duplicate, and stringify so sorting
# never fails on mixed types.
unique_evidence = sorted({str(item) for item in final_evidence})

print("\n" + "="*40)
print(f"       🏁 FORENSIC REPORT: {target_label.upper()}       ")
print("="*40)

if unique_evidence:
    print(f"✅ Success! Found {len(unique_evidence)} unique {target_label}:")
    for i, item in enumerate(unique_evidence, 1):
        print(f"  {i}. {item}")

    # Also print the source columns we tracked!
    sources = result.get("source_columns")
    if sources:
        print(f"\nSource Columns: {', '.join(map(str, sources))}")
else:
    print(f"❌ No {target_label} were extracted.")
    print(f"Last Phase : {result.get('phase')}")
    print(f"Attempts   : {result.get('attempt')}")

print("="*40)



=== STATE SNAPSHOT ===

--- MESSAGES ---
0: HUMAN -> Find PERSON_NAME in the database
1: AI -> SELECT ZCONTACTNAME FROM ZWAGROUPMEMBER WHERE ZCONTACTNAME REGEXP '[A-Za-z][A-Za-z\s\.\-]{1,50}'
UNION ALL
SELECT ZPARTNERNAME FROM ZWACHATSESSION WHERE ZPARTNERNAME REGEXP '[A-Za-z][A-Za-z\s\.\-]{1,50}'
UNION ALL
SELECT ZAUTHORNAME FROM ZWAMEDIAITEM WHERE ZAUTHORNAME REGEXP '[A-Za-z][A-Za-z\s\.\-]{1,50}'
UNION ALL
SELECT ZTEXT FROM ZWAMESSAGE WHERE ZTEXT REGEXP '[A-Za-z][A-Za-z\s\.\-]{1,50}'
UNION ALL
SELECT ZPUSHNAME FROM ZWAPROFILEPUSHNAME WHERE ZPUSHNAME REGEXP '[A-Za-z][A-Za-z\s\.\-]{1,50}'
LIMIT 15;

--- BEGIN METADATA ---
attempt         : 1
max_attempts    : 3
phase           : exploration
target_entity   : PERSON_NAME
exploration_sql : SELECT ZCONTACTNAME FROM ZWAGROUPMEMBER WHERE ZCONTACTNAME REGEXP '[A-Za-z][A-Za-z\s\.\-]{1,50}'
UNION ALL
SELECT ZPARTNERNAME FROM ZWACHATSESSION WHERE ZPARTNERNAME REGEXP '[A-Za-z][A-Za-z\s\.\-]{1,50}'
UNION ALL
SELECT ZAUTHORNAME FROM ZWAMEDIAITEM 