# XAI Fraud Agent

## Phase 1: The ML & Explainability Layer

### Import Libraries for XGBoost Model

In [11]:
import xgboost as xgb
import json
import pandas as pd
import numpy as np
import shap

### Preprocessing

In [12]:
def preprocess_transaction(transaction_row, preprocessor=None):
    """
    Helper function to preprocess a single transaction row.
    Returns processed DataFrame.
    """
    # Convert to DataFrame if it's a dict or Series
    if not isinstance(transaction_row, pd.DataFrame):
        df_input = pd.DataFrame([transaction_row])
    else:
        df_input = transaction_row.copy()
        
    # 1. Drop irrelevant columns
    cols_to_drop = ['month', 'device_fraud_count', 'fraud_bool']
    df_input = df_input.drop(columns=[c for c in cols_to_drop if c in df_input.columns], errors='ignore')
    
    # 2. Convert types
    for col in df_input.columns:
        if col not in ['payment_type', 'employment_status', 'housing_status', 'source', 'device_os']:
             df_input[col] = pd.to_numeric(df_input[col], errors='coerce')
                
    # 3. Handle Missing Values
    missing_cols = [
        "prev_address_months_count", "current_address_months_count",
        "bank_months_count", "session_length_in_minutes"
    ]
    for col in missing_cols:
        if col in df_input.columns:
             df_input[col] = df_input[col].replace(-1, np.nan)
    
    # 4. Apply OneHotEncoding
    if preprocessor:
        try:
            X_transformed = preprocessor.transform(df_input)
            return X_transformed
        except Exception as e:
            print(f"Preprocessing error: {e}")
            return None
    else:
        return df_input

### Predict Function

In [13]:
def predict(transaction_row, preprocessor, model_params_path='XGBoostModelParameters.json', model_path='XGBoostModel.json'):
    """
    Takes a single transaction row, preprocesses it, and returns the fraud probability.
    """
    # Load parameters
    try:
        with open(model_params_path, 'r') as file:
            loaded_params = json.load(file)
    except FileNotFoundError:
        print(f"Error: {model_params_path} not found.")
        return None

    X_transformed = preprocess_transaction(transaction_row, preprocessor)
    if X_transformed is None:
        return None
    
    X_numpy = X_transformed.to_numpy()

    # Load Model (Note: This assumes model file exists)
    try:
        model = xgb.XGBClassifier(**loaded_params)
        model.load_model(model_path)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
    
    # Inference
    try:
        probability = model.predict_proba(X_numpy)[0, 1]
        return float(probability)
    except Exception as e:
        print(f"Prediction error: {e}")
        return None


### Create SHAP Explanation Values

In [14]:
def get_shap_explanation(transaction_data, model, preprocessor):
    """
    Generates a SHAP explanation for a single transaction.
    Returns a dictionary with fraud probability and top 3 contributing features.
    """
    # Preprocess
    X_transformed = preprocess_transaction(transaction_data, preprocessor)
    if X_transformed is None:
        return {"error": "Preprocessing failed"}
    
    # Ensure we use DataFrame for column names in SHAP
    feature_names = preprocessor.get_feature_names_out() if hasattr(preprocessor, 'get_feature_names_out') else X_transformed.columns
    X_df = pd.DataFrame(X_transformed, columns=feature_names)
    
    # Calculate SHAP values
    explainer = shap.TreeExplainer(model)
    shap_values = explainer(X_df)
    
    # Get values for the first (and only) row
    # shap_values.values shape is (1, n_features)
    # Binary classification: some shap versions output values for both classes, some just one.
    # For XGBClassifier binary, it usually outputs log-odds for class 1.
    
    row_values = shap_values.values[0]
    # base_value = shap_values.base_values[0] # Not strictly needed for top 3
    data_values = X_df.iloc[0]
    
    # Calculate probability
    prob = model.predict_proba(X_df)[0, 1]
    
    # Identify top 3 features pushing score HIGHER (positive contribution to fraud class)
    # We want features that increase the probability of fraud.
    
    # Create list of (feature_name, shap_value, feature_value)
    contributions = []
    
    # Handle multi-class output shape if SHAP returns (1, n_features, 2)
    if len(row_values.shape) > 1:
        # Assuming class 1 is index 1
        row_values = row_values[:, 1]

    for name, val, feat_val in zip(X_df.columns, row_values, data_values):
        contributions.append((name, val, feat_val))
    
    # Sort by SHAP value descending (highest positive impact first)
    contributions.sort(key=lambda x: x[1], reverse=True)
    
    top_3 = contributions[:3]
    
    top_reasons = []
    for name, val, feat_val in top_3:
        # Clean up feature name (remove 'cat__' etc if present)
        clean_name = str(name).replace('cat__', '').replace('remainder__', '')
        
        # Format based on value type
        if isinstance(feat_val, (int, float)):
             reason = f"{clean_name} = {feat_val:.2f}"
        else:
             reason = f"{clean_name} = {feat_val}"
        
        top_reasons.append(reason)
        
    return {
        "score": float(prob),
        "top_reasons": top_reasons
    }

## Phase 2: The Data Environment

In [15]:
import sqlite3
import kagglehub
from faker import Faker

# Setup SQLLite connection 
connection = sqlite3.connect("Fraud_Agent.db")

# Download latest version
path = kagglehub.dataset_download("sgpjesus/bank-account-fraud-dataset-neurips-2022")

#print("Path to dataset files:", path)
# ensure we point to a .csv file (dataset_download may return a path without extension)
csv_path = str(path) + "/Base.csv"

# read the CSV into a DataFrame and setup the final test data
df_OG = pd.read_csv(csv_path)
mask = df_OG["month"] == 7

full_test_data = df_OG[mask].sample(frac=0.5).reset_index(drop=True).drop('month',axis=1) 

df = full_test_data

# Add 
fake = Faker()
num_users = len(df) // 8
user_ids = [f"USER_{i:04d}" for i in range(num_users)]

df['user_id'] = (user_ids * 9)[:len(df)]

print(df)

# Setup a Table in SQL
table_name = "transaction_history"
full_test_data.to_sql(table_name, connection, if_exists='replace', index=False)

# Verify the data was written by reading it back into a new DataFrame
query = f"SELECT * FROM {table_name}"
result_df = pd.read_sql_query(query, connection)
print("\nData read from SQLite table:")
#print(result_df)

# Close the database connection
connection.close()


       fraud_bool  income  name_email_similarity  prev_address_months_count  \
0               0     0.8               0.917751                         -1   
1               0     0.7               0.671499                         63   
2               0     0.9               0.856539                         -1   
3               0     0.6               0.997083                        116   
4               0     0.9               0.016758                         35   
...           ...     ...                    ...                        ...   
48417           0     0.4               0.448188                         -1   
48418           0     0.9               0.728931                         26   
48419           1     0.9               0.023149                         -1   
48420           0     0.9               0.242042                         -1   
48421           0     0.5               0.445138                         31   

       current_address_months_count  customer_age  

In [16]:
def get_db_size_pragma(db_path):
    """
    Gets the size of a SQLite database in bytes using PRAGMA statements.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Get page count
    cursor.execute("PRAGMA page_count;")
    page_count = cursor.fetchone()[0]

    # Get page size
    cursor.execute("PRAGMA page_size;")
    page_size = cursor.fetchone()[0]

    conn.close()

    # Calculate total size in bytes
    size_in_bytes = page_count * page_size
    return size_in_bytes

db_file_path = "Fraud_Agent.db"
size = get_db_size_pragma(db_file_path)

print(f"The size of the database is: {size} bytes (via PRAGMA)")

The size of the database is: 13254656 bytes (via PRAGMA)


In [17]:
import os
#os.remove("Fraud_Agent.db")

### The Vector Store. 
Write a script to read that PDF, split it into chunks using LangChain's RecursiveCharacterTextSplitter, and save it into a local ChromaDB vector database.

In [18]:

import os
import chromadb
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Fix for SQLite on Mac (common issue with ChromaDB)
import sqlite3
import sys
if sys.platform.startswith('darwin'):
    try:
        __import__('pysqlite3')
        sys.modules['sqlite3'] = sys.modules['pysqlite3']
    except ImportError:
        pass

# Configuration
CHROMA_PATH = "./chroma_db"
CHROMA_COLLECTION_NAME = "bank_policies"
PDF_PATH = "Fraud_Detection_Policy.pdf"

def ingest_pdf():
    # 1. Load PDF
    if not os.path.exists(PDF_PATH):
        print(f"Error: {PDF_PATH} not found.")
        return

    print(f"Loading {PDF_PATH}...")
    loader = PyPDFLoader(PDF_PATH)
    documents = loader.load()

    # 2. Split Text
    print("Splitting text...")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunked_documents = text_splitter.split_documents(documents)
    
    print(f"Created {len(chunked_documents)} chunks.")

    # 3. Initialize Embeddings
    print("Initializing embeddings...")
    embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    # 4. Initialize Chroma Client
    print("Initializing ChromaDB client...")
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)

    # 5. Add to Chroma
    print(f"Adding documents to collection '{CHROMA_COLLECTION_NAME}'...")
    Chroma.from_documents(
        documents=chunked_documents,
        embedding=embedding_function,
        collection_name=CHROMA_COLLECTION_NAME,
        client=chroma_client,
    )
    
    print(f"Successfully added {len(chunked_documents)} chunks to ChromaDB at {CHROMA_PATH}")

if __name__ == "__main__":
    ingest_pdf()



Loading Fraud_Detection_Policy.pdf...
Splitting text...
Created 3 chunks.
Initializing embeddings...


Loading weights: 100%|██████████| 103/103 [00:00<00:00, 1504.80it/s, Materializing param=pooler.dense.weight]                             
[1mBertModel LOAD REPORT[0m from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Initializing ChromaDB client...
Adding documents to collection 'bank_policies'...
Successfully added 3 chunks to ChromaDB at ./chroma_db


## Building the Agent's "Tools" 
Step 3.1: Database Tool.

In [19]:
@tool
def get_user_transactions(user_id: str):
    """
    Randomly select 5 transactions for this user.
    Use this tool when you need to verify if a user has a history of fraud
    transactions.
    """
    db_path = "Fraud_Agent.db"

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # SELECT any 5 transactions of the user at random
    query = f"SELECT * FROM transaction_history WHERE user_id = '{user_id}' LIMIT 5"
    result_df = pd.read_sql_query(query, conn)
    print("\nData read from SQLite table:")

    conn.close()

    return result_df


## Policy RAG Tool
Write a function search_bank_policy(query: str) that queries your ChromaDB and returns the most relevant paragraph from your PDF.

In [20]:
import logging
from transformers import logging as transformers_logging
transformers_logging.set_verbosity_error()
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain.tools import tool
import chromadb
import sys

# Fix for SQLite on Mac (common issue with ChromaDB)
if sys.platform.startswith('darwin'):
    try:
        __import__('pysqlite3')
        sys.modules['sqlite3'] = sys.modules['pysqlite3']
    except ImportError:
        pass

CHROMA_PATH = "./chroma_db"
CHROMA_COLLECTION_NAME = "bank_policies"

@tool
def search_bank_policy(query: str) -> str:
    """
    Searches the official Bank Anti-Fraud Policy documentation. 
    Use this tool when you need to verify if a flagged transaction 
    violates specific banking regulations or internal risk thresholds.
    """
    # Consistency: Use same embeddings and paths as ingestion
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    # Initialize Persistent Client
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)

    vector_db = Chroma(
        client=chroma_client,
        embedding_function=embeddings,
        collection_name=CHROMA_COLLECTION_NAME
    )
    
    # Perform the Similarity Search
    # k=1 returns only the single most relevant paragraph
    docs = vector_db.similarity_search(query, k=1)
    
    if not docs:
        return "No relevant policy found for this query."
    
    # Return the text content of the best match
    return docs[0].page_content

# Test the tool manually
try:
    # result = search_bank_policy.invoke("What is the limit for overseas wire transfers?")
    result = search_bank_policy.invoke("housing status")
    print(f"Policy Found: {result}")
except Exception as e:
    print(f"Error during search: {e}")


Loading weights: 100%|██████████| 103/103 [00:00<00:00, 1630.89it/s, Materializing param=pooler.dense.weight]                             


Policy Found: Housing Status (BE, BB, BC): These anonymized categories are historically associated with lower
fraud rates.
Escalation Rules


### Testing tools and functions
Please comment out @tools from above functions to test

In [21]:
import logging
from transformers import logging as transformers_logging
transformers_logging.set_verbosity_error()
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
from langchain_chroma import Chroma
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import unittest

# Fix for SQLite on Mac (common issue with ChromaDB)
if sys.platform.startswith('darwin'):
    try:
        __import__('pysqlite3')
        sys.modules['sqlite3'] = sys.modules['pysqlite3']
    except ImportError:
        pass

# --- Unit Tests ---
class TestCoreTools(unittest.TestCase):
    
    def test_get_user_transactions_output_type(self):
        """Verify that get_user_transactions returns a pandas DataFrame."""
        # Using a sample user ID from the database
        df = get_user_transactions("USER_0000")
        if df is not None:
            self.assertIsInstance(df, pd.DataFrame)
            print("\n[PASSED] get_user_transactions returns a DataFrame.")
        else:
            self.skipTest("Fraud_Agent.db not found for integration test.")

    def test_get_user_transactions_valid_user(self):
        """Verify that get_user_transactions returns data for a valid user."""
        df = get_user_transactions("USER_0000")
        if df is not None:
            # We know USER_0000 exists if the database was setup correctly
            # If the DB is empty, this might be 0, but it should still be a DF
            self.assertTrue(len(df) >= 0)
            print(f"[PASSED] get_user_transactions returned {len(df)} rows for USER_0000.")
        else:
            self.skipTest("Fraud_Agent.db not found for integration test.")

    def test_search_bank_policy_output_type(self):
        """Verify that search_bank_policy returns a string."""
        if os.path.exists("./chroma_db"):
            result = search_bank_policy("housing status")
            self.assertIsInstance(result, str)
            self.assertNotEqual(result, "No relevant policy found for this query.")
            print("[PASSED] search_bank_policy returns a valid string.")
        else:
            self.skipTest("./chroma_db not found for integration test.")

    def test_search_bank_policy_content(self):
        """Verify that search_bank_policy returns relevant content."""
        if os.path.exists("./chroma_db"):
            result = search_bank_policy("risk indicators")
            self.assertTrue(len(result) > 10)
            print(f"[PASSED] search_bank_policy returned relevant content: {result[:50]}...")
        else:
            self.skipTest("./chroma_db not found for integration test.")

unittest.main(argv=['first-arg-is-ignored'], exit=False)


EEEE
ERROR: test_get_user_transactions_output_type (__main__.TestCoreTools.test_get_user_transactions_output_type)
Verify that get_user_transactions returns a pandas DataFrame.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/var/folders/j_/r592xntd31b897xl6vt_hhl00000gn/T/ipykernel_52199/3919581003.py", line 23, in test_get_user_transactions_output_type
    df = get_user_transactions("USER_0000")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'StructuredTool' object is not callable

ERROR: test_get_user_transactions_valid_user (__main__.TestCoreTools.test_get_user_transactions_valid_user)
Verify that get_user_transactions returns data for a valid user.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/var/folders/j_/r592xntd31b897xl6vt_hhl00000gn/T/ipykernel_52199/3919581003.py", line 32, in test_get_user_transactions_valid_user
    df = get_user_t

<unittest.main.TestProgram at 0x12c926f90>

## Phase 4

In [22]:
import joblib
import xgboost as xgb

# 1. Load Preprocessor and Model
preprocessor = joblib.load('preprocessor.joblib')
model = xgb.XGBClassifier()
model.load_model('XGBoostModel.json')

# 2. Get a sample high-risk transaction
# We'll use a sample from the dataframe created in Phase 2
sample_row = df.iloc[0].to_dict()

# 3. Generate SHAP explanation
explanation = get_shap_explanation(sample_row, model, preprocessor)
print(f"Fraud Score: {explanation['score']}")
print(f"Top Reasons: {explanation['top_reasons']}")

# 4. Store user_id for the agent
investigated_user_id = sample_row.get('user_id', 'USER_0000')


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Fraud Score: 0.09384600818157196
Top Reasons: ['payment_type_AC = 1.00', 'velocity_4w = 3158.28', 'bank_branch_count_8w = 0.00']


In [23]:
#!pip install -qU langchain-google-genai langgraph

In [29]:
from langchain_core.tools import tool 
from langchain.agents import create_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import os
from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.prebuilt import create_react_agent

# Configure your API key
load_dotenv("GoogleAPI.env") 
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
print(GOOGLE_API_KEY)

# Prompt
Prompt = """You are a Senior Fraud Investigator.
You will be given a transaction and its SHAP explanations.
Use your tools to check the user's history and bank policy.
Then, write a 3-sentence memo concluding if it is fraud or not."""

# 1. Define Tools
tools = [get_user_transactions, search_bank_policy]

# 2. Initialize the model
llm = ChatGoogleGenerativeAI(
    model="gemini-3-flash-preview",
    api_key = GOOGLE_API_KEY,
    temperature=0
)

# Initialize the agent
agent = create_agent(llm, tools, system_prompt= Prompt)


# 4. Run the agent with dynamic query from SHAP
reasons_str = "\n".join([f"- {r}" for r in explanation['top_reasons']])
query = f"""
Investigate transaction for {investigated_user_id}.
Model Fraud Probability: {explanation['score']:.2f}
Top SHAP Contributing Factors:
{reasons_str}

Please check the user's transaction history and cross-reference with bank policy to generate a final memo.
"""
agent.invoke({"messages": [("user", query)]})


AIzaSyC3oAFLC1AvkrhWvSwmxuIsY_7bgT8XBY0

Data read from SQLite table:


Loading weights: 100%|██████████| 103/103 [00:00<00:00, 1291.04it/s, Materializing param=pooler.dense.weight]                             
Loading weights: 100%|██████████| 103/103 [00:00<00:00, 1214.77it/s, Materializing param=pooler.dense.weight]                             


{'messages': [HumanMessage(content="\nInvestigate transaction for USER_0000.\nModel Fraud Probability: 0.09\nTop SHAP Contributing Factors:\n- payment_type_AC = 1.00\n- velocity_4w = 3158.28\n- bank_branch_count_8w = 0.00\n\nPlease check the user's transaction history and cross-reference with bank policy to generate a final memo.\n", additional_kwargs={}, response_metadata={}, id='af124f9b-aa1a-4db9-87d0-c86208e4ba5a'),
  AIMessage(content=[], additional_kwargs={'function_call': {'name': 'get_user_transactions', 'arguments': '{"user_id": "USER_0000"}'}, '__gemini_function_call_thought_signatures__': {'375c97be-88da-4a01-ab5c-e41a88b61292': 'EtAECs0EAb4+9vuOyNXAsjs+eXaW7PIVCznuCi/0CidzQ5fBj+6bafawZXpwWHNGh3oMkSK5YUNCPbczJSg31A2zSTvFJEr6fZyKhdewooRfph0FMq0EhjT3L//ZKL2t9A+spTRQBifrSeLa72TER0z10PgtieZnomybsuWIedrubJjAdB7WnrXwnKIYZLxSobvzC12QtlUiV7AEaK5SmKd5gre7FPdVrcRB773nFcqAEvkcy3tIsYNgchwvH9ZMPzbgCAuvFR8KarH61vdGbJcZcuf3uMyeAki2AYM8AxnDeVD3r4AEt9fw2aXVfGCt8LgjM9QEKqyic+cz/HQK5FP4Y+EpUDY