In [None]:
# https://medium.com/@ayush4002gupta/building-an-llm-agent-to-directly-interact-with-a-database-0c0dd96b8196

import os
from pathlib import Path
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from sql_utils import *


load_dotenv()  # This looks for the .env file and loads it into os.environ

# Shared chat-model client used by the LLM calls in this notebook. The API key and
# endpoint URL are read from the environment; a missing variable raises KeyError here.
llm = ChatOpenAI(
    model="Qwen2.5-72B-Instruct-AWQ",  # recommended for tools + cost
    api_key=os.environ["API_KEY"],
    base_url=os.environ["VAST_BASE_URL"],
    temperature=0
)

# Smoke test: send one trivial prompt and print the reply to confirm the endpoint responds.
response = llm.invoke([
    HumanMessage(content="Reply with exactly: OK")
])

print(response.content)

# Path of the SQLite file currently under analysis. Starts unset; process_database()
# assigns it and the SQL tools pass it to sqlite3.connect().
DB_PATH = None

# Per-entity settings, keyed by target entity name.
#   "regex": pattern the planner places in the discovery prompt's REGEXP example.
#   "desc":  description of the entity; not read anywhere in the visible notebook.
ENTITY_CONFIG = {
    "EMAIL": {
        "regex": r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b",
        "desc": "valid electronic mail formats used for account registration or contact"
    },
    "PHONE": {
        "regex": r"\+?(?:[0-9]{1,4}[- .]?)?\(?[0-9]{2,4}\)?[- .]?[0-9]{3,4}[- .]?[0-9]{3,9}",
        "desc": "international or local telephone numbers"
    },
    "USERNAME": {
        "regex": r"\b[a-zA-Z][a-zA-Z0-9._-]{2,31}\b",
        "desc": "application-specific usernames, handles, or account identifiers"
    },
    "PERSON_NAME": {
        # Earlier pattern, kept for reference (disabled):
        # "regex": r"[A-Za-z\u4e00-\u9fff][A-Za-z\u4e00-\u9fff\s\.\-]{1,50}",
        # Active pattern: 2-4 capitalised Latin words, or a run of 2-4 characters
        # in the \u4e00-\u9fff range.
        "regex": r"(?:\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+){1,3}\b|[\u4e00-\u9fff]{2,4})",
        "desc": (
            "loosely structured human name-like strings used only for discovery "
            "and column pre-filtering; final identification is performed during extraction"
        )
    }
}



OK


In [86]:
# Core Python
import sqlite3
import re
import json
from typing import TypedDict, Optional, List, Annotated
from langgraph.graph.message import add_messages

# LangChain / LangGraph
from langchain_core.tools import tool
from langchain_core.messages import (
    HumanMessage,
    AIMessage,
    SystemMessage
)
from langchain.agents import create_agent
from langgraph.graph import StateGraph, END
from langgraph.graph.message import MessagesState


@tool
def list_tables() -> str:
    """
    List all table names in the SQLite database.

    Opens the database at the module-level ``DB_PATH`` and reads every entry of
    type ``table`` from ``sqlite_master``.

    Returns:
        The table names joined with ", " (an empty string when there are none).
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # try/finally so the connection is released even when the query raises
        # (for example when the file is not a valid SQLite database).
        cur = conn.cursor()
        cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in cur.fetchall()]
    finally:
        conn.close()
    return ", ".join(tables)


@tool
def get_schema(table: str) -> str:
    """
    Return column names and types for a table.

    Args:
        table: Name of the table to inspect in the database at ``DB_PATH``.

    Returns:
        "name TYPE" pairs joined with ", " (an empty string when the table
        does not exist or has no columns).
    """
    # PRAGMA arguments cannot be bound as query parameters, so the name has to be
    # embedded as a string literal. Doubling single quotes keeps a name such as
    # "it's" (or a hostile model-supplied value) from terminating the literal.
    quoted_table = table.replace("'", "''")

    conn = sqlite3.connect(DB_PATH)
    try:
        # try/finally so the connection is released even if the PRAGMA raises.
        cur = conn.cursor()
        cur.execute(f"PRAGMA table_info('{quoted_table}')")
        cols = cur.fetchall()
    finally:
        conn.close()
    # table_info rows are (cid, name, type, notnull, dflt_value, pk).
    return ", ".join(f"{c[1]} {c[2]}" for c in cols)




@tool
def exec_sql(query: str) -> dict:
    """Execute one SQL statement against the current database and return its rows.

    The statement is run exactly once. A ``REGEXP`` SQL function backed by the
    ``regexp`` helper is registered on the connection first. If execution fails,
    the error is printed and reported in the result instead of being raised.

    Args:
        query: SQL text, typically one SELECT or several joined with UNION ALL.

    Returns:
        A dict with:
          "rows":    list of result tuples (empty when execution failed),
          "columns": distinct column names collected from every SELECT of the
                     statement, in first-seen order,
          "error":   the error message as a string, or None on success.
    """
    query_text = normalize_sql(query)

    # 1. Parse column names from ALL SELECTs (deduplicated, order preserved)
    column_names = []
    for select_sql in split_union_selects(query_text):
        for col in extract_select_columns(select_sql):
            if col not in column_names:
                column_names.append(col)

    # 2. Execute once
    rows = []
    error = None
    conn = sqlite3.connect(DB_PATH)
    try:
        # Everything after connect() sits inside try/finally so the connection
        # is closed on every path, including a failing create_function().
        conn.create_function("REGEXP", 2, regexp)
        cur = conn.cursor()
        try:
            print("[EXECUTE] Running query")
            cur.execute(query_text)
            rows = cur.fetchall()
        except Exception as e:
            # Best-effort by design: a bad query yields no rows rather than
            # aborting the graph, but the reason is surfaced under "error".
            print(f"[SQL ERROR]: {e}")
            error = str(e)
            rows = []
    finally:
        conn.close()

    return {
        "rows": rows,
        "columns": column_names,
        "error": error
    }


class EmailEvidenceState(TypedDict):
    """State dictionary passed between the nodes of the evidence-search graph.

    Despite the class name, the entity being searched for is whatever
    ``target_entity`` holds (it is used as a key into ENTITY_CONFIG).
    """
    messages: Annotated[list, add_messages]  # message log; updates are combined via add_messages
    attempt: int          # incremented by the planner each time it writes a discovery query
    max_attempts: int     # discovery budget compared against `attempt` in next_step
    phase: str            # "discovery" | "extraction"
    sql: Optional[str]   # SQL to execute
    rows: Optional[List]  # rows returned by the most recent execution
    classification: Optional[dict]  # classifier verdict; next_step reads "found" and "confidence"
    evidence: Optional[List[str]]   # values returned by the extract node
    target_entity: str     # entity type to search for; key into ENTITY_CONFIG
    # Columns reported by exec_sql for the extraction query.
    # NOTE(review): annotated as List[dict], but sql_execute stores exec_sql's
    # "columns" value unchanged - confirm the intended element type.
    source_columns: Optional[List[dict]]

    # SQL used during discovery that returned results
    discovered_sql: Optional[List[str]]

def get_discovery_system(target, regex):
    """Build the system prompt that asks the model to plan a discovery query.

    Args:
        target: Name of the entity type to look for; interpolated into the prompt.
        regex: Pattern shown inside the prompt's example REGEXP query.

    Returns:
        A SystemMessage whose content is the assembled prompt.
    """
    prompt_lines = [
        "You are a SQL planner. You are provided databases that are extracted from Android or iOS devices.",
        f"Goal: discover if any column contains {target} from databases.",
        "",
        "Rules:",
        "- Use 'REGEXP' for pattern matching.",
        f"- Example: SELECT col FROM table WHERE col REGEXP '{regex}' LIMIT 10",
        f"- pay special attention to tables and/or columns related to message/chat/text. {target} may be embedded in these text.",
        "- Validate your SQL and make sure all tables and columns do exist.",
        "- If multiple SQL statements are provided, combine them using UNION ALL.",
        "- Return ONLY SQL.",
    ]
    return SystemMessage(content="\n".join(prompt_lines))

    
def get_sql_upgrade_system(target):
    """Build the system prompt that asks the model to widen a successful query.

    Args:
        target: Name of the entity type to extract; interpolated into the prompt.

    Returns:
        A SystemMessage whose content is the assembled prompt.
    """
    prompt_lines = [
        "You are a SQL refiner.",
        f"Goal: modify previously successful SQL to extract ALL {target}.",
        "",
        "Rules:",
        "- Do NOT invent new tables or columns.",
        "- Remove LIMIT clauses.",
        "- Preserve WHERE conditions and REGEXP filters.",
        "- Return ONLY SQL.",
    ]
    return SystemMessage(content="\n".join(prompt_lines))

def select_relevant_tables(llm, all_tables: list[str], target: str) -> list[str]:
    """
    Have the model pick, from table names alone, the tables worth searching.

    The model receives the names one per line and is asked for a JSON array.
    Only names that actually occur in ``all_tables`` are kept; a reply that
    does not parse to a list yields an empty result.
    """
    instructions = "".join([
        "You are a digital forensics assistant.\n",
        f"Target evidence type: {target}.\n\n",
        "From the list of table names below, select the relevant tables ",
        "likely to contain this evidence.\n\n",
        "Return ONLY a JSON array of table names.\n",
        "If unsure, return an empty array.",
    ])
    conversation = [
        SystemMessage(content=instructions),
        HumanMessage(content="\n".join(all_tables)),
    ]

    reply = llm.invoke(conversation)
    candidates = safe_json_loads(reply.content, default=[])

    # Keep only names the database really has; anything else is discarded.
    if isinstance(candidates, list):
        return [name for name in candidates if name in all_tables]
    return []


def planner(state: EmailEvidenceState):
    """
    Produce the SQL statement for the next execution step.

    Extraction phase (when at least one discovery query was saved): the saved
    queries are joined with UNION ALL and the model is asked to rewrite them so
    they return every match.

    Otherwise (discovery): the table list is read, the model picks relevant
    tables by name, their schemas are fetched, and the model writes one
    discovery query restricted to those tables. On a retry the previously
    generated SQL is included in the prompt so the model does not repeat it.

    Returns:
        A partial state update with "messages" and "sql"; the discovery path
        also returns the incremented "attempt".
    """
    # ---------- EXTRACTION PHASE: upgrade SQL ----------
    if state["phase"] == "extraction" and state.get("discovered_sql"):
        system = get_sql_upgrade_system(state["target_entity"])
        joined_sql = "\nUNION ALL\n".join(state["discovered_sql"])

        sql = normalize_sql(
            llm.invoke([
                system,
                HumanMessage(content=f"Original SQL:\n{joined_sql}")
            ]).content
        )

        print("[PLANNER] Upgraded SQL for extraction")

        return {
            "messages": [AIMessage(content=sql)],
            "sql": sql
        }

    # ---------- DISCOVERY PHASE ----------
    # 1. List tables once. Blank entries are dropped: an empty listing splits
    #    into [""], which would otherwise be treated as a table named "".
    all_tables = [
        name.strip()
        for name in list_tables.invoke({}).split(",")
        if name.strip()
    ]

    # 2. Agent selects relevant tables by NAME ONLY
    selected_tables = select_relevant_tables(
        llm,
        all_tables,
        state["target_entity"]
    )

    # Fallback: ensure coverage
    if not selected_tables:
        selected_tables = all_tables[:10]

    # 3. Fetch schema deterministically
    schemas = {
        table: get_schema.invoke({"table": table})
        for table in selected_tables
    }

    # 4. Build grounded prompt
    config = ENTITY_CONFIG[state["target_entity"]]

    system = get_discovery_system(
        state["target_entity"],
        config["regex"]
    )

    grounded_content = (
        f"{system.content}\n\n"
        f"ALLOWED TABLES: {', '.join(selected_tables)}\n"
        f"SCHEMA:\n{json.dumps(schemas, indent=2)}\n"
        f"CURRENT PHASE: {state['phase']}\n"
        "CRITICAL: Use ONLY the tables and columns listed above."
    )

    # On a replan the inputs above are unchanged and the model runs at
    # temperature 0, so without feedback it would be asked the same question
    # again. Show it the query that did not work.
    previous_sql = state.get("sql")
    if state["attempt"] > 0 and previous_sql:
        grounded_content += (
            "\n\nPREVIOUS ATTEMPT (did not confirm any evidence):\n"
            f"{previous_sql}\n"
            "Do NOT repeat it; query different columns or tables."
        )

    # 5. Single LLM call to generate SQL
    sql = normalize_sql(
        llm.invoke([
            SystemMessage(content=grounded_content),
            state["messages"][0]  # original user request
        ]).content
    )

    attempt = state["attempt"] + 1

    return {
        "messages": [AIMessage(content=sql)],
        "sql": sql,
        "attempt": attempt
    }


def sql_execute(state: EmailEvidenceState):
    """
    Run the planned SQL through the exec_sql tool and record the outcome.

    Returns a partial state update containing:
      - "rows" and a summary message with the row count (always);
      - "discovered_sql" with the current query appended, when in the
        discovery phase and the query returned at least one row;
      - "source_columns" with the columns reported by exec_sql, when in the
        extraction phase.
    """
    # Call the tool (it returns a dict)
    result = exec_sql.invoke(state["sql"])

    rows = result.get("rows", [])
    cols = result.get("columns", [])

    # Log the count only: the rows themselves may be large and hold personal data.
    print(f"[SQL EXEC] Retrieved {len(rows)} rows")
    updates = {
        "rows": rows,
        "messages": [AIMessage(content=f"Retrieved {len(rows)} rows")]
    }

    if state["phase"] == "discovery" and rows:
        # `or []` covers both a missing key and an explicit None
        # (the field is declared Optional).
        discovered = list(state.get("discovered_sql") or [])
        discovered.append(state["sql"])
        updates["discovered_sql"] = discovered
        print("[DISCOVERY] Saved successful SQL")

    # Tracking logic: Save columns to state only during extraction
    if state["phase"] == "extraction":
        updates["source_columns"] = cols
        print(f"[TRACKING] Saved source columns: {cols}")

    return updates
    

def get_classify_system(target: str):
    """Build the system prompt asking the model whether sample text contains ``target``.

    Returns:
        A SystemMessage instructing the model to answer with a JSON object
        holding "found", "confidence" and "reason".
    """
    question = f"Decide whether the text contains {target}."
    reply_format = '{ "found": true/false, "confidence": number, "reason": "string" }'
    prompt = "\n".join([
        question,
        "Return ONLY a JSON object with these keys:",
        reply_format,
    ])
    return SystemMessage(content=prompt)

def _normalize_classification(decision):
    """Coerce the parsed classifier reply into a dict that always has found/confidence/reason."""
    if not isinstance(decision, dict):
        return {"found": False, "confidence": 0.0, "reason": "parse failure"}

    found = decision.get("found", False)
    if isinstance(found, str):
        # A model may answer "false" as a string, which would otherwise be truthy.
        found = found.strip().lower() in ("true", "yes", "1")

    try:
        confidence = float(decision.get("confidence", 0.0))
    except (TypeError, ValueError):
        confidence = 0.0

    return {
        **decision,
        "found": bool(found),
        "confidence": confidence,
        "reason": str(decision.get("reason", "")),
    }


def classify(state: EmailEvidenceState):
    """
    Ask the model whether a sample of the current rows contains the target entity.

    Up to 10 rows are rendered to text and sent with the classification prompt.
    The reply is parsed as JSON and normalised so that the returned dict always
    carries a boolean "found", a float "confidence" and a string "reason".
    A reply that cannot be parsed, or is not a JSON object, becomes a
    not-found verdict with confidence 0.0.

    Returns:
        {"classification": <normalised verdict dict>}
    """
    # 1. Prepare the text sample for the LLM
    text = rows_to_text(state["rows"], limit=10)

    # 2. Get the target-specific system message
    system_message = get_classify_system(state["target_entity"])

    # 3. Invoke the LLM
    result = llm.invoke([
        system_message,
        HumanMessage(content=f"Data to analyze:\n{text}")
    ]).content

    # 4. Parse the decision, then fill in or repair any field the routing step reads
    decision = _normalize_classification(
        safe_json_loads(
            result,
            default={"found": False, "confidence": 0.0, "reason": "parse failure"}
        )
    )

    return {"classification": decision}


def switch_to_extraction(state: EmailEvidenceState):
    """Log the phase change and return the update that moves the graph to "extraction"."""
    # The arrow was stored as mis-decoded bytes and printed as garbage; restored.
    print("[PHASE] discovery → extraction")
    return {"phase": "extraction"}




def extract(state: EmailEvidenceState):
    """
    Ask the model to pull the target entity values out of the current rows.

    All rows are rendered to text and sent with an extraction prompt. The reply
    is parsed as JSON and cleaned: a reply that is not a list yields no
    evidence; string and numeric items are converted to stripped strings;
    empty values, other item types and duplicates are dropped (first
    occurrence wins, order preserved).

    Returns:
        {"evidence": <list of unique strings>}
    """
    text = rows_to_text(state["rows"])
    system = f"Extract and normalize {state['target_entity']} from text. Return ONLY a JSON array of strings."
    result = llm.invoke([SystemMessage(content=system), HumanMessage(content=text)]).content

    parsed = safe_json_loads(result, default=[])
    if not isinstance(parsed, list):
        parsed = []

    evidence = []
    seen = set()
    for item in parsed:
        # bool is a subclass of int, so it is excluded explicitly.
        if isinstance(item, bool) or not isinstance(item, (str, int, float)):
            continue
        value = str(item).strip()
        if value and value not in seen:
            seen.add(value)
            evidence.append(value)

    return {"evidence": evidence}


def next_step(state: EmailEvidenceState):
    """
    Choose the label of the next edge after classification.

    Returns one of:
      "do_extract"    - already in the extraction phase;
      "to_extraction" - entity found with confidence >= 0.6;
      "stop_none"     - entity not found with confidence >= 0.6;
      "stop_limit"    - verdict not confident and the attempt budget is used up;
      "replan"        - verdict not confident and attempts remain.

    A missing, non-dict or incomplete classification is treated as a
    not-found verdict with confidence 0.0 instead of raising.
    """
    # Once in extraction phase, extract and stop
    if state["phase"] == "extraction":
        return "do_extract"

    verdict = state.get("classification")
    if not isinstance(verdict, dict):
        verdict = {}

    found = bool(verdict.get("found", False))
    try:
        confidence = float(verdict.get("confidence", 0.0))
    except (TypeError, ValueError):
        confidence = 0.0

    # A confident verdict ends discovery one way or the other.
    if confidence >= 0.6:
        return "to_extraction" if found else "stop_none"

    if state["attempt"] >= state["max_attempts"]:
        return "stop_limit"

    return "replan"

In [87]:
def observe(state: EmailEvidenceState):
    """
    Inspection node: print the current state and change nothing.

    Prints every message, then the tracked fields, and returns an empty
    update so the graph state is left as it was.
    """
    print("\n=== STATE SNAPSHOT ===")

    # Messages
    print("\n--- MESSAGES ---")
    for position, message in enumerate(state["messages"]):
        print(f"{position}: {message.type.upper()} -> {message.content}")

    # Metadata
    print("\n--- BEGIN METADATA ---")
    # Each label carries its own padding so the printed lines keep their layout.
    labelled_fields = (
        ("attempt       ", "attempt"),
        ("max_attempts ", "max_attempts"),
        ("phase         ", "phase"),
        ("sql           ", "sql"),
        ("discovered sql           ", "discovered_sql"),
        ("rows          ", "rows"),
        ("classification", "classification"),
        ("evidence        ", "evidence"),
    )
    for label, key in labelled_fields:
        print(f"{label}: {state[key]}")

    print(f"Source Columns: {state.get('source_columns')}")
    print("\n--- END METADATA ---")

    # An empty dict is a no-op update.
    return {}



from langgraph.graph import StateGraph, END

graph = StateGraph(EmailEvidenceState)

# Register the nodes. `observe` is attached under three names so the state is
# printed after planning, after classification and after extraction.
for node_name, handler in (
    ("planner", planner),
    ("observe_plan", observe),
    ("execute", sql_execute),
    ("classify", classify),
    ("observe_classify", observe),
    ("switch_phase", switch_to_extraction),
    ("extract", extract),
    ("observe_final", observe),
):
    graph.add_node(node_name, handler)

graph.set_entry_point("planner")

# Linear path: plan -> inspect -> execute -> classify -> inspect
for source, destination in (
    ("planner", "observe_plan"),
    ("observe_plan", "execute"),
    ("execute", "classify"),
    ("classify", "observe_classify"),
):
    graph.add_edge(source, destination)

# Branch on the label returned by next_step once the second inspection has run.
graph.add_conditional_edges(
    "observe_classify",
    next_step,
    {
        "to_extraction": "switch_phase",
        "do_extract": "extract",
        "replan": "planner",
        "stop_none": END,
        "stop_limit": END,
    }
)

# After the phase switch the planner runs again; extraction results are
# inspected once more before the graph ends.
for source, destination in (
    ("switch_phase", "planner"),
    ("extract", "observe_final"),
    ("observe_final", END),
):
    graph.add_edge(source, destination)

app = graph.compile()





In [88]:
# Root folder for the per-database JSON reports. It can be overridden through the
# FORENSIC_OUTPUT_DIR environment variable (e.g. in the .env file) so the notebook
# runs on other machines; the previous hard-coded location remains the default.
BASE_OUTPUT_DIR = os.environ.get(
    "FORENSIC_OUTPUT_DIR",
    r"C:\Users\SAfolabi\Code\results"
)

# Map targets to output folder names (sub-folders of BASE_OUTPUT_DIR)
OUTPUT_FOLDERS = {
    "EMAIL": "emails",
    "PHONE": "phones",
    "USERNAME": "usernames",
    "PERSON_NAME": "names"
}

def process_database(db_path, target_entity, app_instance):
    """
    Runs the forensic extraction on a single database file.

    Points the module-level ``DB_PATH`` at ``db_path``, invokes the graph with a
    fresh discovery-phase state, writes a JSON report to
    ``BASE_OUTPUT_DIR/<folder for target>/<db stem>_<target>.json`` and prints a
    summary. Any exception is caught, printed, and logged to
    ``ERROR_<db file name>.json`` in the current working directory, so one bad
    database does not stop a batch.

    Args:
        db_path: Path of the SQLite file to analyse.
        target_entity: Entity type to search for (e.g. "EMAIL").
        app_instance: Object exposing ``invoke(state)``; the compiled graph.
    """
    global DB_PATH
    DB_PATH = db_path
    db_name = os.path.basename(db_path)

    print(f"\n🚀 Processing: {db_name}...")

    try:
        # 1. Run the extraction using the provided app instance
        result = app_instance.invoke({
            "messages": [HumanMessage(content=f"Find {target_entity} in the database")],
            "attempt": 0,
            "max_attempts": 3,
            "phase": "discovery",
            "target_entity": target_entity,
            "sql": None,
            "rows": None,
            "classification": None,
            "evidence": [],
            "source_columns": [],
            "discovered_sql": []
        })

        # 2. Prepare Output Data
        # `or []` also covers an explicit None; a lone non-list value is wrapped
        # so len()/iteration below cannot fail.
        final_evidence = result.get("evidence") or []
        if not isinstance(final_evidence, list):
            final_evidence = [final_evidence]

        output_data = {
            "database_file": db_path,
            "target_entity": target_entity,
            "status": "success" if final_evidence else "no_evidence_found",
            "evidence_count": len(final_evidence),
            "evidence": final_evidence,
            "source_columns": result.get("source_columns", []),
            "attempts_used": result.get("attempt"),
            "final_phase": result.get("phase")
        }

        # 3. Handle File Paths safely
        db_filename = Path(db_path).stem
        safe_db = re.sub(r'[ .]', '_', db_filename)
        safe_target = target_entity.replace(" ", "_")

        # Determine output directory
        subfolder = OUTPUT_FOLDERS.get(target_entity, "items")
        output_dir = os.path.join(BASE_OUTPUT_DIR, subfolder)
        os.makedirs(output_dir, exist_ok=True)

        filename = os.path.join(output_dir, f"{safe_db}_{safe_target}.json")

        # 4. Save JSON
        # ensure_ascii=False keeps non-ASCII evidence (e.g. CJK names) readable
        # in the UTF-8 report instead of \uXXXX escapes.
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=4, ensure_ascii=False)

        # 5. Print Forensic Report
        target_label = result.get("target_entity", "items")

        print("\n" + "="*40)
        print(f"       🏁 FORENSIC REPORT: {target_label.upper()}       ")
        print("="*40)

        if final_evidence:
            print(f"✅ Found {len(final_evidence)} unique {target_label} in {db_name}")
            print(f"📁 Saved to: {filename}")
            # key=str: evidence parsed from JSON may mix strings and numbers,
            # which plain sorted() cannot compare.
            for i, item in enumerate(sorted(final_evidence, key=str), 1):
                print(f"  {i}. {item}")

            # Also print the source columns we tracked!
            sources = result.get("source_columns")
            if sources:
                print(f"\nSource Columns: {', '.join(map(str, sources))}")
        else:
            print(f"❌ No evidence found in {db_name}")
            print(f"❌ No {target_label} were extracted.")
            print(f"Last Phase : {result.get('phase')}")
            print(f"Attempts   : {result.get('attempt')}")
        print("-" * 40)

    except Exception as e:
        print(f"⚠️ Error processing {db_name}: {e}")

        # Error Logging
        error_data = {"database": db_path, "error": str(e)}
        error_file = f"ERROR_{db_name}.json"
        with open(error_file, "w", encoding="utf-8") as f:
            json.dump(error_data, f, indent=4)

In [89]:
%%capture cap

# Entity type searched for in this run; must be a key of ENTITY_CONFIG
# ("EMAIL", "PHONE", "USERNAME" or "PERSON_NAME").
TARGET = "PERSON_NAME"

# Folder that is searched recursively for database files.
SOURCE_DIRECTORY = r"D:\Temp"

# Collect every file whose name ends in a known SQLite extension (case-insensitive).
files_to_process = [
    os.path.join(folder, filename)
    for folder, _subfolders, filenames in os.walk(SOURCE_DIRECTORY)
    for filename in filenames
    if filename.lower().endswith(('.db', '.sqlite', '.sqlite3', '.sqlitedb'))
]

print(f"Found {len(files_to_process)} databases to process for {TARGET}.\n")

# Run the compiled graph (`app`, built in an earlier cell) on each database in turn.
for db_file in files_to_process:
    process_database(db_file, TARGET, app)

print("\n Batch processing complete.")

In [90]:
# Save the text printed by the batch cell (held in `cap` by %%capture) to a file.
filecapture = r"C:\Users\SAfolabi\Code\cell output\output_addressbook_name.txt"

with Path(filecapture).open("w", encoding="utf-8") as capture_file:
    capture_file.write(cap.stdout)