# XAI Fraud Agent

## Phase 1: The ML & Explainability Layer

### Import Libraries for XGBoost Model

In [1]:
import xgboost as xgb
import json
import pandas as pd
import numpy as np
import shap


  from .autonotebook import tqdm as notebook_tqdm


### Preprocessing

In [2]:
def preprocess_transaction(transaction_row, preprocessor=None):
    """
    Helper function to preprocess a single transaction row.
    Returns processed DataFrame.
    """
    # Convert to DataFrame if it's a dict or Series
    if not isinstance(transaction_row, pd.DataFrame):
        df_input = pd.DataFrame([transaction_row])
    else:
        df_input = transaction_row.copy()
        
    # 1. Drop irrelevant columns
    cols_to_drop = ['month', 'device_fraud_count', 'fraud_bool']
    df_input = df_input.drop(columns=[c for c in cols_to_drop if c in df_input.columns], errors='ignore')
    
    # 2. Convert types
    for col in df_input.columns:
        if col not in ['payment_type', 'employment_status', 'housing_status', 'source', 'device_os']:
             df_input[col] = pd.to_numeric(df_input[col], errors='coerce')
                
    # 3. Handle Missing Values
    missing_cols = [
        "prev_address_months_count", "current_address_months_count",
        "bank_months_count", "session_length_in_minutes"
    ]
    for col in missing_cols:
        if col in df_input.columns:
             df_input[col] = df_input[col].replace(-1, np.nan)
    
    # 4. Apply OneHotEncoding
    if preprocessor:
        try:
            X_transformed = preprocessor.transform(df_input)
            return X_transformed
        except Exception as e:
            print(f"Preprocessing error: {e}")
            return None
    else:
        return df_input

### Predict Function

In [3]:
def predict(transaction_row, preprocessor, model_params_path='XGBoostModelParameters.json', model_path='XGBoostModel.json'):
    """
    Takes a single transaction row, preprocesses it, and returns the fraud probability.
    """
    # Load parameters
    try:
        with open(model_params_path, 'r') as file:
            loaded_params = json.load(file)
    except FileNotFoundError:
        print(f"Error: {model_params_path} not found.")
        return None

    X_transformed = preprocess_transaction(transaction_row, preprocessor)
    if X_transformed is None:
        return None
    
    X_numpy = X_transformed.to_numpy()

    # Load Model (Note: This assumes model file exists)
    try:
        model = xgb.XGBClassifier(**loaded_params)
        model.load_model(model_path)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
    
    # Inference
    try:
        probability = model.predict_proba(X_numpy)[0, 1]
        return float(probability)
    except Exception as e:
        print(f"Prediction error: {e}")
        return None


### Create SHAP Explanation Values

In [4]:
def get_shap_explanation(transaction_data, model, preprocessor):
    """
    Generates a SHAP explanation for a single transaction.
    Returns a dictionary with fraud probability and top 3 contributing features.
    """
    # Preprocess
    X_transformed = preprocess_transaction(transaction_data, preprocessor)
    if X_transformed is None:
        return {"error": "Preprocessing failed"}
    
    # Ensure we use DataFrame for column names in SHAP
    feature_names = preprocessor.get_feature_names_out() if hasattr(preprocessor, 'get_feature_names_out') else X_transformed.columns
    X_df = pd.DataFrame(X_transformed, columns=feature_names)
    
    # Calculate SHAP values
    explainer = shap.TreeExplainer(model)
    shap_values = explainer(X_df)
    
    # Get values for the first (and only) row
    # shap_values.values shape is (1, n_features)
    # Binary classification: some shap versions output values for both classes, some just one.
    # For XGBClassifier binary, it usually outputs log-odds for class 1.
    
    row_values = shap_values.values[0]
    # base_value = shap_values.base_values[0] # Not strictly needed for top 3
    data_values = X_df.iloc[0]
    
    # Calculate probability
    prob = model.predict_proba(X_df)[0, 1]
    
    # Identify top 3 features pushing score HIGHER (positive contribution to fraud class)
    # We want features that increase the probability of fraud.
    
    # Create list of (feature_name, shap_value, feature_value)
    contributions = []
    
    # Handle multi-class output shape if SHAP returns (1, n_features, 2)
    if len(row_values.shape) > 1:
        # Assuming class 1 is index 1
        row_values = row_values[:, 1]

    for name, val, feat_val in zip(X_df.columns, row_values, data_values):
        contributions.append((name, val, feat_val))
    
    # Sort by SHAP value descending (highest positive impact first)
    contributions.sort(key=lambda x: x[1], reverse=True)
    
    top_3 = contributions[:3]
    
    top_reasons = []
    for name, val, feat_val in top_3:
        # Clean up feature name (remove 'cat__' etc if present)
        clean_name = str(name).replace('cat__', '').replace('remainder__', '')
        
        # Format based on value type
        if isinstance(feat_val, (int, float)):
             reason = f"{clean_name} = {feat_val:.2f}"
        else:
             reason = f"{clean_name} = {feat_val}"
        
        top_reasons.append(reason)
        
    return {
        "score": float(prob),
        "top_reasons": top_reasons
    }

## Phase 2: The Data Environment

In [5]:
import sqlite3
import kagglehub
from faker import Faker

# Setup SQLLite connection 
connection = sqlite3.connect("Fraud_Agent.db")

# Download latest version
path = kagglehub.dataset_download("sgpjesus/bank-account-fraud-dataset-neurips-2022")

#print("Path to dataset files:", path)
# ensure we point to a .csv file (dataset_download may return a path without extension)
csv_path = str(path) + "/Base.csv"

# read the CSV into a DataFrame and setup the final test data
df_OG = pd.read_csv(csv_path)
mask = df_OG["month"] == 7

full_test_data = df_OG[mask].sample(frac=0.5).reset_index(drop=True).drop('month',axis=1) 

df = full_test_data

# Add 
fake = Faker()
num_users = len(df) // 8
user_ids = [f"USER_{i:04d}" for i in range(num_users)]

df['user_id'] = (user_ids * 9)[:len(df)]

print(df)

# Setup a Table in SQL
table_name = "transaction_history"
full_test_data.to_sql(table_name, connection, if_exists='replace', index=False)

# Verify the data was written by reading it back into a new DataFrame
query = f"SELECT * FROM {table_name}"
result_df = pd.read_sql_query(query, connection)
print("\nData read from SQLite table:")
#print(result_df)

# Close the database connection
connection.close()


       fraud_bool  income  name_email_similarity  prev_address_months_count  \
0               0     0.7               0.523318                         -1   
1               0     0.8               0.873555                         -1   
2               0     0.8               0.999981                         -1   
3               0     0.1               0.700272                        110   
4               0     0.8               0.990024                         86   
...           ...     ...                    ...                        ...   
48417           0     0.9               0.068822                         93   
48418           0     0.8               0.898829                         -1   
48419           0     0.6               0.863616                         61   
48420           0     0.9               0.843187                         -1   
48421           0     0.1               0.893419                         -1   

       current_address_months_count  customer_age  

In [6]:
def get_db_size_pragma(db_path):
    """
    Gets the size of a SQLite database in bytes using PRAGMA statements.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Get page count
    cursor.execute("PRAGMA page_count;")
    page_count = cursor.fetchone()[0]

    # Get page size
    cursor.execute("PRAGMA page_size;")
    page_size = cursor.fetchone()[0]

    conn.close()

    # Calculate total size in bytes
    size_in_bytes = page_count * page_size
    return size_in_bytes

db_file_path = "Fraud_Agent.db"
size = get_db_size_pragma(db_file_path)

print(f"The size of the database is: {size} bytes (via PRAGMA)")

The size of the database is: 13254656 bytes (via PRAGMA)


In [7]:
import os
#os.remove("Fraud_Agent.db")

### The Vector Store. 
Write a script to read that PDF, split it into chunks using LangChain's RecursiveCharacterTextSplitter, and save it into a local ChromaDB vector database.

In [8]:

import os
import chromadb
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Fix for SQLite on Mac (common issue with ChromaDB)
import sqlite3
import sys
if sys.platform.startswith('darwin'):
    try:
        __import__('pysqlite3')
        sys.modules['sqlite3'] = sys.modules['pysqlite3']
    except ImportError:
        pass

# Configuration
CHROMA_PATH = "./chroma_db"
CHROMA_COLLECTION_NAME = "bank_policies"
PDF_PATH = "Fraud_Detection_Policy.pdf"

def ingest_pdf():
    # 1. Load PDF
    if not os.path.exists(PDF_PATH):
        print(f"Error: {PDF_PATH} not found.")
        return

    print(f"Loading {PDF_PATH}...")
    loader = PyPDFLoader(PDF_PATH)
    documents = loader.load()

    # 2. Split Text
    print("Splitting text...")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunked_documents = text_splitter.split_documents(documents)
    
    print(f"Created {len(chunked_documents)} chunks.")

    # 3. Initialize Embeddings
    print("Initializing embeddings...")
    embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    # 4. Initialize Chroma Client
    print("Initializing ChromaDB client...")
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)

    # 5. Add to Chroma
    print(f"Adding documents to collection '{CHROMA_COLLECTION_NAME}'...")
    Chroma.from_documents(
        documents=chunked_documents,
        embedding=embedding_function,
        collection_name=CHROMA_COLLECTION_NAME,
        client=chroma_client,
    )
    
    print(f"Successfully added {len(chunked_documents)} chunks to ChromaDB at {CHROMA_PATH}")

if __name__ == "__main__":
    ingest_pdf()



Loading Fraud_Detection_Policy.pdf...
Splitting text...
Created 3 chunks.
Initializing embeddings...


Loading weights: 100%|██████████| 103/103 [00:00<00:00, 1395.74it/s, Materializing param=pooler.dense.weight]                             
[1mBertModel LOAD REPORT[0m from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Initializing ChromaDB client...
Adding documents to collection 'bank_policies'...
Successfully added 3 chunks to ChromaDB at ./chroma_db


## Building the Agent's "Tools" 
Step 3.1: Database Tool.

In [9]:
def get_user_transactions(user_id: str):
    db_path = "Fraud_Agent.db"

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # SELECT any 5 transactions of the user at random
    query = f"SELECT * FROM transaction_history WHERE user_id = '{user_id}' LIMIT 5"
    result_df = pd.read_sql_query(query, conn)
    print("\nData read from SQLite table:")

    conn.close()

    return result_df

user_id = 'USER_0000'

get_user_transactions(user_id)


Data read from SQLite table:


Unnamed: 0,fraud_bool,income,name_email_similarity,prev_address_months_count,current_address_months_count,customer_age,days_since_request,intended_balcon_amount,payment_type,zip_count_4w,...,has_other_cards,proposed_credit_limit,foreign_request,source,session_length_in_minutes,device_os,keep_alive_session,device_distinct_emails_8w,device_fraud_count,user_id
0,0,0.7,0.523318,-1,16,30,0.019999,-0.968113,AB,857,...,0,200.0,0,INTERNET,3.357721,windows,1,1,0,USER_0000
1,0,0.6,0.307123,60,14,30,0.022186,-1.372846,AB,1149,...,0,1500.0,0,INTERNET,2.908656,other,1,1,0,USER_0000
2,0,0.9,0.507479,-1,214,50,0.005081,-0.857296,AB,750,...,0,1500.0,0,INTERNET,2.856984,windows,1,1,0,USER_0000
3,0,0.6,0.884954,-1,145,40,0.000755,-1.502515,AC,1330,...,1,200.0,0,INTERNET,1.713412,linux,1,1,0,USER_0000
4,0,0.7,0.072035,-1,170,40,0.022907,14.367664,AA,1286,...,1,200.0,0,INTERNET,4.331581,other,1,1,0,USER_0000


## Policy RAG Tool
Write a function search_bank_policy(query: str) that queries your ChromaDB and returns the most relevant paragraph from your PDF.

In [10]:
import logging
from transformers import logging as transformers_logging
transformers_logging.set_verbosity_error()
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain.tools import tool
import chromadb
import sys

# Fix for SQLite on Mac (common issue with ChromaDB)
if sys.platform.startswith('darwin'):
    try:
        __import__('pysqlite3')
        sys.modules['sqlite3'] = sys.modules['pysqlite3']
    except ImportError:
        pass

CHROMA_PATH = "./chroma_db"
CHROMA_COLLECTION_NAME = "bank_policies"

@tool
def search_bank_policy(query: str) -> str:
    """
    Searches the official Bank Anti-Fraud Policy documentation. 
    Use this tool when you need to verify if a flagged transaction 
    violates specific banking regulations or internal risk thresholds.
    """
    # Consistency: Use same embeddings and paths as ingestion
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    # Initialize Persistent Client
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)

    vector_db = Chroma(
        client=chroma_client,
        embedding_function=embeddings,
        collection_name=CHROMA_COLLECTION_NAME
    )
    
    # Perform the Similarity Search
    # k=1 returns only the single most relevant paragraph
    docs = vector_db.similarity_search(query, k=1)
    
    if not docs:
        return "No relevant policy found for this query."
    
    # Return the text content of the best match
    return docs[0].page_content

# Test the tool manually
try:
    # result = search_bank_policy.invoke("What is the limit for overseas wire transfers?")
    result = search_bank_policy.invoke("housing status")
    print(f"Policy Found: {result}")
except Exception as e:
    print(f"Error during search: {e}")


  embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
Loading weights: 100%|██████████| 103/103 [00:00<00:00, 704.86it/s, Materializing param=pooler.dense.weight]                             


Policy Found: Housing Status (BE, BB, BC): These anonymized categories are historically associated with lower
fraud rates.
Escalation Rules


  vector_db = Chroma(


In [13]:
import logging
from transformers import logging as transformers_logging
transformers_logging.set_verbosity_error()
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
from langchain_chroma import Chroma
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
import unittest

# Fix for SQLite on Mac (common issue with ChromaDB)
if sys.platform.startswith('darwin'):
    try:
        __import__('pysqlite3')
        sys.modules['sqlite3'] = sys.modules['pysqlite3']
    except ImportError:
        pass

# --- Unit Tests ---
class TestCoreTools(unittest.TestCase):
    
    def test_get_user_transactions_output_type(self):
        """Verify that get_user_transactions returns a pandas DataFrame."""
        # Using a sample user ID from the database
        df = get_user_transactions("USER_0000")
        if df is not None:
            self.assertIsInstance(df, pd.DataFrame)
            print("\n[PASSED] get_user_transactions returns a DataFrame.")
        else:
            self.skipTest("Fraud_Agent.db not found for integration test.")

    def test_get_user_transactions_valid_user(self):
        """Verify that get_user_transactions returns data for a valid user."""
        df = get_user_transactions("USER_0000")
        if df is not None:
            # We know USER_0000 exists if the database was setup correctly
            # If the DB is empty, this might be 0, but it should still be a DF
            self.assertTrue(len(df) >= 0)
            print(f"[PASSED] get_user_transactions returned {len(df)} rows for USER_0000.")
        else:
            self.skipTest("Fraud_Agent.db not found for integration test.")

    def test_search_bank_policy_output_type(self):
        """Verify that search_bank_policy returns a string."""
        if os.path.exists("./chroma_db"):
            result = search_bank_policy("housing status")
            self.assertIsInstance(result, str)
            self.assertNotEqual(result, "No relevant policy found for this query.")
            print("[PASSED] search_bank_policy returns a valid string.")
        else:
            self.skipTest("./chroma_db not found for integration test.")

    def test_search_bank_policy_content(self):
        """Verify that search_bank_policy returns relevant content."""
        if os.path.exists("./chroma_db"):
            result = search_bank_policy("risk indicators")
            self.assertTrue(len(result) > 10)
            print(f"[PASSED] search_bank_policy returned relevant content: {result[:50]}...")
        else:
            self.skipTest("./chroma_db not found for integration test.")

unittest.main(argv=['first-arg-is-ignored'], exit=False)


..EE
ERROR: test_search_bank_policy_content (__main__.TestCoreTools.test_search_bank_policy_content)
Verify that search_bank_policy returns relevant content.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/var/folders/j_/r592xntd31b897xl6vt_hhl00000gn/T/ipykernel_11143/3919581003.py", line 54, in test_search_bank_policy_content
    result = search_bank_policy("risk indicators")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'StructuredTool' object is not callable

ERROR: test_search_bank_policy_output_type (__main__.TestCoreTools.test_search_bank_policy_output_type)
Verify that search_bank_policy returns a string.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/var/folders/j_/r592xntd31b897xl6vt_hhl00000gn/T/ipykernel_11143/3919581003.py", line 44, in test_search_bank_policy_output_type
    result = search_bank_policy("housing status")
   


Data read from SQLite table:

[PASSED] get_user_transactions returns a DataFrame.

Data read from SQLite table:
[PASSED] get_user_transactions returned 5 rows for USER_0000.


<unittest.main.TestProgram at 0x1301bc410>