In [None]:
from __future__ import annotations
import os
import json
import time
from typing import TypedDict, List, Dict, Any, Optional, Literal

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

from pydantic import BaseModel

# LangChain / LangGraph
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, END

from dotenv import load_dotenv
import os

In [None]:
# ----------------------------
# 1) Spark Setup + Mock DataFrames
# ----------------------------

# A single Spark session backs the three in-memory lookup tables below.
spark = (
    SparkSession.builder
    .appName("FraudInvestigationMock")
    .getOrCreate()
)

# Payments: one row per alert. `txn_id` is the value the pipeline later uses
# as the case id when it looks a payment up.
payments_df = spark.createDataFrame(
    [
        (
            "ALERT-001", 18250.0, "INR", "STELLAR GADGETS HK", 5732, "ecom", "HKG",
            "CUST-99231", "D-ABCD-123", "203.0.113.54", "2025-08-18T22:44:21Z",
            ["velocity_alert", "geo_anomaly"],
        ),
        (
            "ALERT-002", 50000.0, "INR", "LOCAL MART IND", 5411, "pos", "IND",
            "CUST-99231", "D-ABCD-123", "203.0.113.54", "2025-08-19T10:14:00Z",
            [],
        ),
    ],
    schema=[
        "txn_id", "amount", "currency", "merchant", "mcc", "channel", "country",
        "customer_id", "device_id", "ip", "timestamp", "flags",
    ],
)

# Customers: keyed by `customer_id`; `home_country` feeds the geo-mismatch rule.
customers_df = spark.createDataFrame(
    [
        (
            "CUST-99231", "FULL", "Retail-Premium", 4.2, "k.k@example.com",
            "+91-98xxxxxx", "IND", False, 1500, 45000, 0,
        ),
    ],
    schema=[
        "customer_id", "kyc_level", "risk_segment", "tenure_years", "email",
        "phone", "home_country", "vip", "account_age_days",
        "avg_monthly_spend_inr", "chargeback_count_12m",
    ],
)

# Devices: keyed by `device_id`; `seen_before` feeds the new-device rule.
devices_df = spark.createDataFrame(
    [
        ("D-ABCD-123", "low", True, False, 5, "8.3.1"),
    ],
    schema=[
        "device_id", "reputation", "seen_before", "rooted",
        "distance_km_from_home", "app_version",
    ],
)


In [None]:
# ----------------------------
# 2) Shared State
# ----------------------------
class CaseState(TypedDict):
    """Shared state dictionary passed from node to node in the investigation graph.

    Each agent in this notebook reads the keys it needs, writes its own
    result key(s), and returns the same dictionary.
    """

    # Identifier of the case; used as the `txn_id` when the payment is fetched.
    case_id: str
    # Payment record for the case (filled by the triage step).
    input: Dict[str, Any]
    # LLM-written summary of the suspicious aspects of the payment.
    triage_summary: str
    # Lookup results stored under the "customer" and "device" keys.
    enrichment: Dict[str, Any]
    # Labels of the static rules that fired, e.g. "RULE:NEW_DEVICE".
    rule_hits: List[str]
    # Fraud score parsed from the LLM reply by the model step.
    model_score: float
    # LLM-written investigation narrative.
    investigation_notes: str
    # Outcome of the case. The batch runner initialises it to "unknown"; the
    # decision step in this notebook writes "approve", "deny" or "escalate".
    decision: Literal["approve", "deny", "escalate", "monitor", "unknown"]
    # Escalation target. NOTE(review): initialised to None and never written
    # by any agent visible in this notebook — confirm the intended use.
    escalate_to: Optional[str]
    # Chronological list of audit entries appended by add_audit().
    audit_log: List[Dict[str, Any]]


def add_audit(state: CaseState, event: str, payload: Dict[str, Any] | None = None) -> CaseState:
    """Record an audit entry on ``state`` and hand the same ``state`` back.

    The entry is a dict with three keys: ``"ts"`` (current local time,
    formatted ``YYYY-MM-DD HH:MM:SS``), ``"event"`` and ``"payload"``.
    A missing or empty ``payload`` is stored as ``{}``.

    Args:
        state: Case state; mutated in place. An ``"audit_log"`` list is
            created on it if the key is absent.
        event: Short name of the step that just finished.
        payload: Optional details to store alongside the event.

    Returns:
        The same ``state`` object that was passed in.
    """
    record = {
        "ts": time.strftime("%Y-%m-%d %H:%M:%S"),
        "event": event,
        "payload": payload if payload else {},
    }
    trail = state.setdefault("audit_log", [])
    trail.append(record)
    return state


In [None]:
# ----------------------------
# 3) Spark-based fetch functions
# ----------------------------
def fetch_payment(case_id: str) -> Dict[str, Any]:
    """Return the first payment whose ``txn_id`` equals ``case_id`` as a dict.

    An empty dict is returned when no payment row matches.
    """
    match = payments_df.filter(col("txn_id") == case_id).first()
    if not match:
        return {}
    return match.asDict()

def fetch_customer(customer_id: str) -> Dict[str, Any]:
    """Return the first customer row with the given ``customer_id`` as a dict.

    An empty dict is returned when no customer row matches.
    """
    match = customers_df.filter(col("customer_id") == customer_id).first()
    if not match:
        return {}
    return match.asDict()

def fetch_device(device_id: str) -> Dict[str, Any]:
    """Return the first device row with the given ``device_id`` as a dict.

    An empty dict is returned when no device row matches.
    """
    match = devices_df.filter(col("device_id") == device_id).first()
    if not match:
        return {}
    return match.asDict()

In [None]:
# Load variables from .env into environment
# Copy key/value pairs from a local .env file into the process environment.
load_dotenv()

# The LLM credential is read from GOOGLE_API_KEY; this is None when unset.
LLM_API_KEY = os.environ.get("GOOGLE_API_KEY")

# Report only whether a key was found — the secret itself is never printed.
print("Loaded API Key:", LLM_API_KEY is not None)

In [None]:
# ----------------------------
# 4) LLM Setup
# ----------------------------
def make_llm(temperature: float = 0.2):
    """Build a Gemini chat client configured with the requested temperature.

    Args:
        temperature: Sampling temperature forwarded to the model. It used to
            be ignored (the client was always created with 0), so callers
            asking for 0.2 or 0.0 now actually get that value.

    Returns:
        A ``ChatGoogleGenerativeAI`` instance for the "gemini-2.5-flash"
        model, authenticated with the module-level ``LLM_API_KEY``.
    """
    return ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        google_api_key=LLM_API_KEY,
        temperature=temperature,
    )

# System prompt placed in front of every LLM request made by call_llm().
SYSTEM_BASE = (
    "You are a bank fraud investigation assistant. "
    "Be precise, cautious, and explain reasoning succinctly."
)

def call_llm(messages: List[Dict[str, str]], temperature: float = 0.2) -> str:
    """Send ``messages`` to the LLM behind the shared system prompt.

    Only each message's ``"content"`` is read: every entry is sent as a
    human message regardless of its ``"role"`` value.

    Args:
        messages: Dicts that each carry a ``"content"`` string.
        temperature: Value passed to ``make_llm`` when the client is built.

    Returns:
        The ``content`` attribute of the model's reply.
    """
    client = make_llm(temperature)
    conversation = [SystemMessage(content=SYSTEM_BASE)]
    for message in messages:
        conversation.append(HumanMessage(content=message["content"]))
    reply = client.invoke(conversation)
    return reply.content


In [None]:
# ----------------------------
# 5) Agents
# ----------------------------
def triage_agent(state: CaseState) -> CaseState:
    """Fetch the case's payment and ask the LLM what looks suspicious about it.

    Reads ``state["case_id"]``. Writes ``state["triage_summary"]`` (the LLM
    reply) and ``state["input"]`` (the fetched payment dict), then logs a
    ``"triage_complete"`` audit event carrying the summary.
    """
    txn = fetch_payment(state["case_id"])
    request = {
        "role": "user",
        "content": f"Summarize suspicious aspects of payment {json.dumps(txn)}",
    }
    verdict = call_llm([request])
    state["triage_summary"] = verdict
    state["input"] = txn
    return add_audit(state, "triage_complete", {"summary": verdict})

def enrichment_agent(state: CaseState) -> CaseState:
    """Attach the customer and device records referenced by the payment.

    Reads ``customer_id`` and ``device_id`` from ``state["input"]``, stores
    the two lookups under ``state["enrichment"]`` as ``{"customer": ...,
    "device": ...}`` and logs an ``"enrichment_complete"`` audit event with
    that same dict as payload.
    """
    payment = state["input"]
    lookups = {
        "customer": fetch_customer(payment["customer_id"]),
        "device": fetch_device(payment["device_id"]),
    }
    state["enrichment"] = lookups
    return add_audit(state, "enrichment_complete", lookups)

def rules_agent(state: CaseState) -> CaseState:
    """Evaluate the static fraud rules and record which ones fired.

    Rules, in evaluation order:
      * ``RULE:HIGH_AMOUNT_ECOM`` – amount above 15000 on the "ecom" channel.
      * ``RULE:NEW_DEVICE``       – the device is not marked ``seen_before``.
      * ``RULE:GEO_MISMATCH``     – payment country differs from the
        customer's ``home_country``.

    A customer or device with no record is enriched as ``{}``. Such a
    record is treated exactly like one whose field is null: the device
    counts as not seen before and the home country as not matching. This
    replaces the KeyError that used to abort the whole run.

    Reads ``state["input"]`` and ``state["enrichment"]``; writes
    ``state["rule_hits"]`` and logs a ``"rules_complete"`` audit event.
    """
    p, e = state["input"], state["enrichment"]
    # fetch_customer / fetch_device return {} when nothing matches, so the
    # enrichment fields are read with .get() rather than indexed.
    device = e.get("device") or {}
    customer = e.get("customer") or {}

    hits = []
    if p["amount"] > 15000 and p["channel"] == "ecom":
        hits.append("RULE:HIGH_AMOUNT_ECOM")
    if not device.get("seen_before"):
        hits.append("RULE:NEW_DEVICE")
    if p["country"] != customer.get("home_country"):
        hits.append("RULE:GEO_MISMATCH")
    state["rule_hits"] = hits
    return add_audit(state, "rules_complete", {"rule_hits": hits})

def _parse_fraud_score(text: Any, default: float = 0.5) -> float:
    """Turn an LLM reply into a probability in [0, 1], or return ``default``."""
    try:
        # Only the first token is considered. Markdown emphasis, quotes and
        # trailing punctuation are removed so "**0.9**" or "0.9." still parse.
        token = text.strip().split()[0].strip("*`\"'").rstrip(".,;:")
        score = float(token)
    except (AttributeError, IndexError, TypeError, ValueError):
        return default
    # float() accepts "nan"/"inf", and a reply like "85" is not a probability.
    # NaN fails every comparison, so this range check rejects it as well.
    if not 0.0 <= score <= 1.0:
        return default
    return score


def model_agent(state: CaseState) -> CaseState:
    """Ask the LLM for a fraud probability and store it as ``model_score``.

    The features sent are the payment amount, the channel, and whether the
    device is new. A device with no record (enriched as ``{}``) counts as
    new instead of raising KeyError.

    The prompt asks for a bare number because only the first token of the
    reply is parsed. Replies that are not a finite number in [0, 1] fall
    back to 0.5.

    Reads ``state["input"]`` and ``state["enrichment"]``; writes
    ``state["model_score"]`` and logs a ``"model_scored"`` audit event.
    """
    p, e = state["input"], state["enrichment"]
    device = e.get("device") or {}
    feats = {
        "amount": p["amount"],
        "channel": p["channel"],
        "new_device": not device.get("seen_before"),
    }
    prompt = (
        f"Estimate fraud probability (0-1) for {json.dumps(feats)}. "
        "Respond with only the number, no explanation."
    )
    score_text = call_llm([{"role": "user", "content": prompt}], temperature=0.0)
    score = _parse_fraud_score(score_text)
    state["model_score"] = score
    return add_audit(state, "model_scored", {"score": score})

def investigator_agent(state: CaseState) -> CaseState:
    """Ask the LLM for a narrative that ties the case evidence together.

    The prompt embeds the payment dict, the rule hits and the model score.
    Writes ``state["investigation_notes"]`` and logs an
    ``"investigation_complete"`` audit event carrying the notes.
    """
    payment = state["input"]
    fired = state["rule_hits"]
    score = state["model_score"]
    question = {
        "role": "user",
        "content": f"Summarize investigation for {payment}, rules {fired}, score {score}",
    }
    narrative = call_llm([question])
    state["investigation_notes"] = narrative
    return add_audit(state, "investigation_complete", {"notes": narrative})

def decision_agent(state: CaseState) -> CaseState:
    """Map the model score to a case decision.

    * score > 0.75            -> "deny"
    * score < 0.45            -> "approve"
    * anything else           -> "escalate"

    For ordinary scores this is the same banding as before. The branches
    are ordered so that "escalate" is the fallback: a NaN score fails every
    comparison and is now escalated instead of silently approved.

    Reads ``state["model_score"]``; writes ``state["decision"]`` and logs a
    ``"decision_ready"`` audit event with the decision and the score.
    """
    s = state["model_score"]
    if s > 0.75:
        decision = "deny"
    elif s < 0.45:
        decision = "approve"
    else:
        # 0.45 <= s <= 0.75, or a score that cannot be ordered (NaN).
        decision = "escalate"
    state["decision"] = decision
    return add_audit(state, "decision_ready", {"decision": decision, "score": s})


In [None]:
# ----------------------------
# 6) Graph Wiring
# ----------------------------
def build_graph():
    """Assemble and compile the linear investigation graph.

    The six agents run strictly in sequence:
    triage -> enrichment -> rules -> model -> investigator -> decision -> END.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    graph = StateGraph(CaseState)

    # Listed in execution order; the order drives both node registration
    # and the edges created below.
    pipeline = [
        ("triage", triage_agent),
        ("enrichment", enrichment_agent),
        ("rules", rules_agent),
        ("model", model_agent),
        ("investigator", investigator_agent),
        ("decision", decision_agent),
    ]
    for node_name, agent in pipeline:
        graph.add_node(node_name, agent)

    graph.set_entry_point(pipeline[0][0])

    # Chain each step to its successor; the last step is wired to END.
    stops = [node_name for node_name, _ in pipeline] + [END]
    for source, target in zip(stops, stops[1:]):
        graph.add_edge(source, target)
    return graph.compile()

In [None]:
# ----------------------------
# 7) Batch Runner
# ----------------------------
def run_batch():
    """Run the investigation graph once for every payment in ``payments_df``.

    Each case starts from a fresh state dict with empty or neutral values
    (``decision`` is "unknown", ``escalate_to`` is None).

    Returns:
        A list with the final state returned by the graph for each case, in
        the order the transaction ids were collected.
    """
    app = build_graph()
    # Only the ids are needed on the driver, so project to `txn_id` before
    # collecting instead of pulling every column of every payment row.
    case_ids = [row["txn_id"] for row in payments_df.select("txn_id").collect()]

    results = []
    for cid in case_ids:
        init: CaseState = {
            "case_id": cid,
            "input": {},
            "triage_summary": "",
            "enrichment": {},
            "rule_hits": [],
            "model_score": 0.0,
            "investigation_notes": "",
            "decision": "unknown",
            "escalate_to": None,
            "audit_log": [],
        }
        results.append(app.invoke(init))
    return results


In [None]:
# Start from an empty list so that `batch_results` is always defined after
# this cell: if the run aborts with a RuntimeError, the display cell that
# follows shows [] rather than failing with NameError or showing stale
# results from an earlier execution.
batch_results = []
try:
    batch_results = run_batch()
    # One short report per case: decision and score, fired rules, then notes.
    for res in batch_results:
        print("=== CASE:", res["case_id"], "DECISION:", res.get("decision"), "SCORE:", res.get("model_score"))
        print("RULES:", res.get("rule_hits"))
        print("NOTES:\n", res.get("investigation_notes"))
        print("-"*50)
except RuntimeError as e:
    # Only RuntimeError is handled here; any other exception propagates.
    print("Setup error:", e)

In [None]:
# Show the list returned by run_batch() (one final state per case) as the cell output.
batch_results