In [1]:
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from sql_utils import *
from datetime import datetime, timezone
from pathlib import Path

load_dotenv()  # This looks for the .env file and loads it into os.environ

llm = ChatOpenAI(
    model="gpt-4o-mini",  # recommended for tools + cost
    api_key=os.environ["API_KEY"],
    temperature=0,
    seed=100,
)

response = llm.invoke([
    HumanMessage(content="Reply with exactly: OK")
])

print(response.content)

OK


In [None]:
# Core Python
import sqlite3
import re
import json
from typing import TypedDict, Optional, List, Annotated
from langgraph.graph.message import add_messages

# LangChain / LangGraph
from langchain_core.tools import tool
from langchain_core.messages import (
    HumanMessage,
    AIMessage,
    SystemMessage
)
from langchain.agents import create_agent
from langgraph.graph import StateGraph, END
from langgraph.graph.message import MessagesState
from sql_utils import *


@tool
def list_tables() -> str:
    """
    List non-empty user tables in the SQLite database.
    """
    IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
    conn = sqlite3.connect(DB_PATH)
    try:
        cur = conn.cursor()
        cur.execute("""
            SELECT name
            FROM sqlite_master
            WHERE type='table' AND name NOT LIKE 'sqlite_%'
            ORDER BY name
        """)
        tables = [r[0] for r in cur.fetchall()]

        nonempty = []
        for t in tables:
            # If your DB has weird table names, remove this guard,
            # but keep the quoting below.
            if not IDENT_RE.match(t):
                continue
            try:
                cur.execute(f'SELECT 1 FROM "{t}" LIMIT 1;')
                if cur.fetchone() is not None:
                    nonempty.append(t)
            except sqlite3.Error:
                continue

        return ", ".join(nonempty)
    finally:
        conn.close()


@tool
def get_schema(table: str) -> str:
    """

    Return column names and types for a table.
    """
    conn = sqlite3.connect(DB_PATH)
    cur = conn.cursor()
    cur.execute(f"PRAGMA table_info('{table}')")
    cols = cur.fetchall()
    conn.close()
    return ", ".join(f"{c[1]} {c[2]}" for c in cols)


@tool
def exec_sql(
    query: str,
    db_path: str,
    top_n: int = 10,
    verbose: bool = True,
) -> dict:
    """
    Execute a UNION ALL query by splitting into individual SELECT statements.
    Runs each SELECT with LIMIT top_n, skipping any SELECT that errors.

    Returns:
        rows_all: list of rows (combined from all successful chunks)
        column_names: list of column names (deduped), prefixed as table.column when possible
    """
    
    query_text = normalize_sql(query)
    selects = split_union_selects(query_text)

    rows_all = []
    column_names = []

    conn = sqlite3.connect(db_path)
    conn.create_function("REGEXP", 2, regexp)
    cur = conn.cursor()

    try:
        for i, select_sql in enumerate(selects, 1):
            select_sql_clean = select_sql.rstrip().rstrip(";")
            select_sql_run = f"{select_sql_clean}\nLIMIT {top_n};"

            if verbose:
                print(f"[EXECUTE] chunk {i}/{len(selects)} LIMIT {top_n}")
                # print(select_sql_run)  # uncomment to print full SQL

            try:
                cur.execute(select_sql_run)
                chunk = cur.fetchall()
                rows_all.extend(chunk)

                # collect columns only if chunk succeeded
                tbl = extract_single_table(select_sql_clean)
                for col in extract_select_columns(select_sql_clean):
                    name = f"{tbl}.{col}" if (tbl and "." not in col) else col
                    if name not in column_names:
                        column_names.append(name)

            except Exception as e:
                if verbose:
                    print(f"[SQL ERROR] Skipping chunk {i}: {e}")

    finally:
        conn.close()

    return {
        "rows": rows_all,
        "columns": column_names
    }


from typing import Any, TypedDict
class EvidenceState(TypedDict):
    database_name: str
    messages: Annotated[list, add_messages]
    attempt: int
    max_attempts: int
    phase: str  # "exploration" | "extraction"

    # SQL separation
    exploration_sql: Optional[str]
    extraction_sql: Optional[str]

    rows: Optional[List]
    classification: Optional[dict]
    evidence: Optional[List[str]]

    source_columns: Optional[List[str]]  
    entity_config: dict[str, Any]


def get_explore_system(type, regex):
    return SystemMessage(
        content=(
            "You are a SQL planner. You are provided app databases that are extracted from Android or iPhone devices.\n"
            "apps include Android Whatsapp, Snapchat, Telegram, Google Map, Samsung Internet, iPhone Contacts, Messages, Safari, and Calendar.\n"
            f"Goal: discover if any column of databases contains possible {type}.\n\n"
            "Rules:\n"
            "- Use 'REGEXP' for pattern matching.\n"
            f"- Example: SELECT col FROM table WHERE col REGEXP '{regex}' \n"
            "- Table and col names can be used as hints to find solutions. \n"
            "- Include the tables and columns even there is a small possility of containing solutions.\n"
            "- Pay attention to messages, chats, or other text fields.\n"
            "- Validate your SQL and make sure all tables and columns do exist.\n"
            "- If multiple SQL statements are provided, combine them using UNION ALL. \n"
            f"- Example: SELECT col1 FROM table1 WHERE col1 REGEXP '{regex}' UNION ALL SELECT col2 FROM table2 WHERE col2 REGEXP '{regex}'\n"
            "- Make sure all tables and columns do exist before return SQL. \n"
            "- Return ONLY SQL."
        )
    )

def planner(state: EvidenceState):
    # Extraction upgrade path
    if state["phase"] == "extraction" and state.get("exploration_sql"):
        extraction_sql = upgrade_sql_remove_limit(state["exploration_sql"])
        return {
            "messages": [AIMessage(content=extraction_sql)],
            "extraction_sql": extraction_sql
        }

    # Optional safety stop inside planner too
    if state.get("phase") == "exploration" and state.get("attempt", 0) >= state.get("max_attempts", 0):
        return {
            "phase": "done",
            "messages": [AIMessage(content="STOP: max attempts reached in planner.")]
        }
    # Original discovery logic
    tables = list_tables.invoke({})
    config = state["entity_config"]

    base_system = get_explore_system(
        f"{config.get('type','')}:{config.get('desc','')}".strip(),
        config["regex"]
    )

    grounded_content = (
        f"{base_system.content}\n\n"
        f"EXISTING TABLES: {tables}\n"
        f"CURRENT PHASE: {state['phase']}\n"
        "CRITICAL: Do not query non-existent tables."
    )

    agent = create_agent(llm, [list_tables,get_schema])
    
    result = agent.invoke({
        "messages": [
            SystemMessage(content=grounded_content),
            state["messages"][0]  # original user request only
        ]
    })

    exploration_sql  = normalize_sql(result["messages"][-1].content)

    attempt = state["attempt"] + 1 if state["phase"] == "exploration" else state["attempt"]

    return {
        "messages": [AIMessage(content=exploration_sql)],
        "exploration_sql": exploration_sql,
        "attempt": attempt
    }

def sql_execute(state: EvidenceState):
    top_n=10
    # Choose SQL based on phase
    if state["phase"] == "extraction":
        sql_to_run = state.get("extraction_sql")
        top_n=10000
    else:  # "exploration"
        sql_to_run = state.get("exploration_sql")
        top_n=10

    if not sql_to_run:
        print("[SQL EXEC] No SQL provided for this phase")
        return {
            "rows": [],
            "messages": [AIMessage(content="No SQL to execute")]
        }

    # Execute
    result = result = exec_sql.invoke({
    "query": sql_to_run,
    "db_path": state["database_name"],
    "top_n": top_n,
    "verbose": False
})


    rows = result.get("rows", [])
    cols = result.get("columns", [])

    print(f"[SQL EXEC] Retrieved {len(rows)} rows")
    
    # for i, r in enumerate(rows, 1):
    #     print(f"  row[{i}]: {r}")

    updates = {
        "rows": rows,
        "messages": [AIMessage(content=f"Retrieved {len(rows)} rows")]
    }

    # Track columns only during extraction (provenance)
    if state["phase"] == "extraction":
        updates["source_columns"] = cols
        print(f"[TRACKING] Saved source columns: {cols}")

    return updates

    

def classify(state: EvidenceState):
    # 1. Prepare the text sample for the LLM
    text = rows_to_text(state["rows"], limit=15)
    
    # 2. Get the pii-specific system message
    config= state["entity_config"]
    pii_desc= f"{config.get('type','')}:{config.get('desc','')}".strip()
    system_message = SystemMessage(
        content=(
            f"Decide whether the text contains {pii_desc}.\n"
            "Return ONLY a JSON object with these keys:\n"
            "{ \"found\": true/false, \"confidence\": number, \"reason\": \"string\" }"
        )
    )

    # 3. Invoke the LLM
    result = llm.invoke([
        system_message,
        HumanMessage(content=f"Data to analyze:\n{text}")
    ]).content
    
# 4. Parse the decision
    decision = safe_json_loads(
        result,
        default={"found": False, "confidence": 0.0, "reason": "parse failure"}
    )

    # print("[CLASSIFY]", decision)
    return {"classification": decision}


def switch_to_extraction(state: EvidenceState):
    print("[PHASE] discovery ‚Üí extraction")
    return {"phase": "extraction"}


def extract(state: EvidenceState):
    text = rows_to_text(state["rows"])
    # print(f"Check last 100 characts : {text[:-100]}")
    desc = state["entity_config"].get("desc", "PII")
    system = f"Identify real {desc} from text and normalize them. Return ONLY a JSON array of strings.\n"

    result = llm.invoke([SystemMessage(content=system), HumanMessage(content=text)]).content
    return {"evidence": safe_json_loads(result, default=[])}


def next_step(state: EvidenceState):
    # Once in extraction phase, extract and stop
    if state["phase"] == "extraction":
        return "do_extract"

    c = state["classification"]

    if c["found"] and c["confidence"] >= 0.6:
        return "to_extraction"

    if not c["found"] and c["confidence"] >= 0.6:
        return "stop_none"

    if state["attempt"] >= state["max_attempts"]:
        return "stop_limit"

    return "replan"

In [3]:
def observe(state: EvidenceState):
    """
    Debug / inspection node.
    Does NOT modify state.
    """
    print("\n=== STATE SNAPSHOT ===")

    # Messages
    print("\n--- MESSAGES ---")
    for i, m in enumerate(state["messages"]):
        print(f"{i}: {m.type.upper()} -> {m.content}")

    # Metadata
    print("\n--- BEGIN METADATA ---")
    print(f"attempt         : {state['attempt']}")
    print(f"max_attempts    : {state['max_attempts']}")
    print(f"phase           : {state['phase']}")
    print(f"PII type        : {state['entity_config'].get('type')}")

    # SQL separation
    print(f"exploration_sql : {state.get('exploration_sql')}")
    print(f"extraction_sql  : {state.get('extraction_sql')}")

    # Outputs
    rows = state.get("rows") or []
    print(f"rows_count      : {len(rows)}")
    print(f"rows_sample     : {rows[:1000] if rows else []}")  # small sample to avoid huge logs

    print(f"classification  : {state.get('classification')}")
    print(f"evidence_count  : {len(state.get('evidence') or [])}")
    print(f"evidence_sample : {(state.get('evidence') or [])[:10]}")

    print(f"source_columns  : {state.get('source_columns')}")
    print("\n--- END METADATA ---")

    # IMPORTANT: do not return state, return no-op update
    return {}



from langgraph.graph import StateGraph, END

graph = StateGraph(EvidenceState)

# Nodes
graph.add_node("planner", planner)
graph.add_node("observe_plan", observe)         # Checkpoint 1: The SQL Plan
graph.add_node("execute", sql_execute)
graph.add_node("observe_execution", observe)    # NEW Checkpoint: Post-execution
graph.add_node("classify", classify)
graph.add_node("observe_classify", observe)     # Checkpoint 2: Post-classify
graph.add_node("switch_phase", switch_to_extraction)
graph.add_node("extract", extract)
graph.add_node("observe_final", observe)        # Checkpoint 3: Final results

graph.set_entry_point("planner")

# --- FLOW ---
graph.add_edge("planner", "observe_plan")
graph.add_edge("observe_plan", "execute")

# NEW: observe after execution, before classify
graph.add_edge("execute", "observe_execution")
graph.add_edge("observe_execution", "classify")

graph.add_edge("classify", "observe_classify")

graph.add_conditional_edges(
    "observe_classify",
    next_step,
    {
        "to_extraction": "switch_phase",
        "do_extract": "extract",
        "replan": "planner",
        "stop_none": END,
        "stop_limit": END,
    }
)

graph.add_edge("switch_phase", "planner")

graph.add_edge("extract", "observe_final")
graph.add_edge("observe_final", END)

app = graph.compile()


In [4]:
def run_batch(db_paths, pii_targets, pii_config, app):
    all_results = []

    for p in db_paths:
        db_path = str(p)

        # If your tools rely on global DB_PATH, keep this line.
        # If you refactor tools to use state["database_name"], you can remove it.
        global DB_PATH
        DB_PATH = db_path

        print(f"\nProcessing: {db_path}")

        for target in pii_targets:
            entity_config = pii_config[target]
            print(f"  Processing: {target}")

            result = app.invoke({
                "database_name": db_path,
                "messages": [HumanMessage(content=f"Find {entity_config['desc'].strip()} in the database")],
                "attempt": 1,
                "max_attempts": 2,
                "phase": "exploration",
                "entity_config": entity_config,
                "exploration_sql": None,
                "extraction_sql": None,
                "rows": None,
                "classification": None,
                "evidence": [],
                "source_columns": []
            })

            evidence = result.get("evidence", [])
            source_columns = result.get("source_columns", [])
            raw_rows = result.get("rows", [])

            all_results.append({
                "db_path": db_path,
                "PII_type": target,
                "PII": evidence,
                "Num_of_PII": len(evidence),
                "source_columns": source_columns,
                "Raw_rows_first_100": raw_rows[:100],
                "Total_raw_rows": len(raw_rows),
                "Exploration_sql": result.get("exploration_sql", ""),
                "Extraction_sql": result.get("extraction_sql", "")
            })

    return all_results


def run_batch(db_paths, pii_targets, pii_config, app, out_dir: Path):
    """
    Process databases one-by-one and write one output file per database.
    """
    for p in db_paths:
        db_path = str(p)

        # If your tools rely on global DB_PATH, keep this line.
        global DB_PATH
        DB_PATH = db_path

        print(f"\nProcessing DB: {db_path}")

        db_results = []  # reset per-database

        for target in pii_targets:
            entity_config = pii_config[target]
            print(f"  Processing: {target}")

            result = app.invoke({
                "database_name": db_path,
                "messages": [HumanMessage(content=f"Find {entity_config['desc'].strip()} in the database")],
                "attempt": 1,
                "max_attempts": 2,
                "phase": "exploration",
                "entity_config": entity_config,
                "exploration_sql": None,
                "extraction_sql": None,
                "rows": None,
                "classification": None,
                "evidence": [],
                "source_columns": []
            })

            evidence = result.get("evidence", [])
            source_columns = result.get("source_columns", [])
            raw_rows = result.get("rows", [])

            db_results.append({
                "db_path": db_path,
                "PII_type": target,
                "PII": evidence,
                "Num_of_PII": len(evidence),
                "source_columns": source_columns,
                "Raw_rows_first_100": raw_rows[:100],
                "Total_raw_rows": len(raw_rows),
                "Exploration_sql": result.get("exploration_sql", ""),
                "Extraction_sql": result.get("extraction_sql", "")
            })

        # Save per-database output (includes db name + timestamp)
        save_jsonl(db_results, out_dir, db_path)


def main():
    cfg = load_config_yaml(Path("config.yaml"))

    DB_DIR = Path(cfg.get("db_dir", "selectedDBs"))
    OUT_DIR = Path(cfg.get("out_dir", "batch_results"))
    OUT_DIR.mkdir(exist_ok=True)

    CONFIG_PY = Path(cfg.get("config_py", "my_run_config.py"))
    vars_ = load_vars_from_py(CONFIG_PY, "db_files", "PII_CONFIG")
    db_files = vars_["db_files"]
    PII_CONFIG = vars_["PII_CONFIG"]

    PII_TARGETS = cfg.get("pii_targets", list(PII_CONFIG.keys()))

    db_paths, missing, not_sqlite = build_db_paths(DB_DIR, db_files, is_sqlite_file)
    print_db_path_report(db_paths, missing, not_sqlite)

    # Now run and save one file per DB (no global aggregation)
    run_batch(db_paths, PII_TARGETS, PII_CONFIG, app, OUT_DIR)


if __name__ == "__main__":
    main()


Will process 24 databases (from db_files list).

Processing DB: selectedDBs\A1_commerce.db
  Processing: EMAIL

=== STATE SNAPSHOT ===

--- MESSAGES ---
0: HUMAN -> Find a unique identifier for a destination to which electronic mail (email) can be sent and received over the internet or a private network in the database
1: AI -> SELECT locale FROM android_metadata WHERE locale REGEXP '[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'

--- BEGIN METADATA ---
attempt         : 2
max_attempts    : 2
phase           : exploration
PII type        : email address
exploration_sql : SELECT locale FROM android_metadata WHERE locale REGEXP '[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
extraction_sql  : None
rows_count      : 0
rows_sample     : []
classification  : None
evidence_count  : 0
evidence_sample : []
source_columns  : []

--- END METADATA ---
[SQL EXEC] Retrieved 0 rows

=== STATE SNAPSHOT ===

--- MESSAGES ---
0: HUMAN -> Find a unique identifier for a destination to which electronic 

TypeError: Object of type bytes is not JSON serializable