# **JAX Based ANN for Fraud Detection**

In [None]:
# Standard Imports
import os

# Add path for accessing internal imports
os.sys.path.append("../src")

In [None]:
# Standard Imports
from typing import TypedDict
from pprint import pprint

# Third Party imports
import jax
import optax
import pandas as pd
import polars as pl
import jax.numpy as jnp
from omegaconf import OmegaConf
from tqdm.notebook import tqdm
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.calibration import calibration_curve
from sklearn.metrics import (
    roc_curve,
    precision_recall_curve,
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
)

# Internal Imports
from fraud_detection.config import load_config, Config

# **Config & Validation**

In [None]:
# Load the config: `load_config` returns a pair. The first element is used
# below through attribute access (e.g. `config.data.input_path`); the second
# is the object that is printed here and later saved with `OmegaConf.save`.
config, cfg_raw = load_config("../config/base.yaml")

# Display the raw config for a quick visual check of the run settings
pprint(cfg_raw)

# **Load Data**

In [None]:
def get_data(
    config: Config,
) -> tuple[
    tuple[jax.Array, jax.Array],
    tuple[jax.Array, jax.Array],
    tuple[jax.Array, jax.Array],
    list[str],
]:
    """Load the CSV dataset and split it into train / validation / test sets.

    The ``"Class"`` column is the target; every other column is a feature.
    Both splits are stratified on the target: the first separates the test set
    (``config.split.test_split``), the second carves the validation set out of
    the remaining training data (``config.split.valid_split``).

    Args:
        config: Run configuration providing ``data.input_path``, the
            ``read_csv`` options and the two ``split`` definitions.

    Returns:
        ``(train, valid, test, feature_names)`` where each split is an
        ``(X, y)`` pair. ``X`` is a float32 array and ``y`` is an int32 array
        reshaped to a single column.
    """

    def _as_jax_pair(
        features: pd.DataFrame, target: pd.Series
    ) -> tuple[jax.Array, jax.Array]:
        # Features become float32; labels become an int32 column vector.
        return (
            jnp.array(features.values, dtype=jnp.float32),
            jnp.array(target.values, dtype=jnp.int32).reshape(-1, 1),
        )

    # Read the raw table with polars
    frame = pl.read_csv(
        source=config.data.input_path,
        ignore_errors=config.read_csv.ignore_errors,
        infer_schema_length=config.read_csv.infer_schema_length,
    )

    # Separate features from the target and hand both over to pandas
    features = frame.select(pl.exclude("Class")).to_pandas()
    target = frame.select("Class").to_series().to_pandas()

    # First split: hold out the test set
    test_cfg = config.split.test_split
    X_rest, X_test, y_rest, y_test = train_test_split(
        features,
        target,
        stratify=target,
        train_size=test_cfg.train_size,
        random_state=test_cfg.random_state,
    )

    # Second split: hold out the validation set from what is left
    valid_cfg = config.split.valid_split
    X_train, X_valid, y_train, y_valid = train_test_split(
        X_rest,
        y_rest,
        stratify=y_rest,
        train_size=valid_cfg.train_size,
        random_state=valid_cfg.random_state,
    )

    # Column names of the feature matrix, as plain strings
    feature_names = list(map(str, X_train.columns))

    return (
        _as_jax_pair(X_train, y_train),
        _as_jax_pair(X_valid, y_valid),
        _as_jax_pair(X_test, y_test),
        feature_names,
    )

# **Model & Gradient Structures**

In [None]:
class JaxStdScaler:
    """Standardise features column-wise using statistics held as JAX arrays.

    ``fit`` stores the per-column mean and standard deviation (``axis=0``) and
    ``transform`` applies ``(X - mean) / std`` with them. Columns whose standard
    deviation is zero (constant features) are only centred, not divided, so
    the output never contains NaN/inf caused by a division by zero.
    """

    def __init__(self) -> None:
        self.mean: None | jax.Array = None
        self.std: None | jax.Array = None

    def fit(self, X: jax.Array) -> None:
        """Compute and store the per-column mean and standard deviation of ``X``."""
        self.mean = jnp.mean(X, axis=0)
        self.std = jnp.std(X, axis=0)

    def transform(self, X: jax.Array) -> jax.Array:
        """Standardise ``X`` with the fitted statistics.

        Raises:
            ValueError: If ``fit`` has not been called yet.
        """
        if self.mean is None or self.std is None:
            raise ValueError("The JaxStdScaler has not been fitted yet.")
        return self._transform(X=X, mean=self.mean, std=self.std)

    def fit_transform(self, X: jax.Array) -> jax.Array:
        """Fit on ``X`` and return the standardised ``X``."""
        self.fit(X)
        return self.transform(X)

    @staticmethod
    @jax.jit
    def _transform(X: jax.Array, mean: jax.Array, std: jax.Array) -> jax.Array:
        """Jitted kernel: centre by ``mean`` and divide by a zero-safe ``std``."""
        # A constant column has std == 0; dividing by it would yield NaN (0/0).
        # Use a divisor of 1 there so the centred value (0) is returned instead.
        safe_std = jnp.where(std > 0, std, 1.0)
        return (X - mean) / safe_std

In [None]:
class ModelParams(TypedDict):
    """Trainable parameters of the two-layer network (all keys required)."""

    # First dense layer: weight matrix and bias vector
    weights_1: jax.Array
    bias_1: jax.Array
    # Second dense layer: weight matrix and bias vector
    weights_2: jax.Array
    bias_2: jax.Array


class ModelSize(TypedDict):
    """Widths of the three layers of the network (all keys required)."""

    l1_size: int  # Layer 1 width
    l2_size: int  # Layer 2 width
    l3_size: int  # Layer 3 width


def init_model_params(model_size: ModelSize, seed: int = 1729) -> ModelParams:
    """Draw every weight and bias from a standard normal distribution.

    Args:
        model_size: Layer widths; ``l1_size`` x ``l2_size`` shapes the first
            weight matrix and ``l2_size`` x ``l3_size`` the second.
        seed: Seed of the root PRNG key.

    Returns:
        Dictionary holding ``weights_1``, ``bias_1``, ``weights_2`` and ``bias_2``.
    """
    n_in = model_size["l1_size"]
    n_hidden = model_size["l2_size"]
    n_out = model_size["l3_size"]

    # One independent sub-key per parameter tensor, derived from a single root key
    root_key = jax.random.key(seed=seed)
    w1_key, w2_key, b1_key, b2_key = jax.random.split(key=root_key, num=4)

    return {
        "weights_1": jax.random.normal(key=w1_key, shape=(n_in, n_hidden)),
        "bias_1": jax.random.normal(key=b1_key, shape=(n_hidden,)),
        "weights_2": jax.random.normal(key=w2_key, shape=(n_hidden, n_out)),
        "bias_2": jax.random.normal(key=b2_key, shape=(n_out,)),
    }


@jax.jit
def predict_logits(params: ModelParams, X: jax.Array) -> jax.Array:
    """Run the two-layer network and return raw logits (log-odds).

    Args:
        params: Weights and biases of both dense layers.
        X: Feature matrix; its last dimension must match ``weights_1`` rows.

    Returns:
        The output of the second dense layer, without any activation.
    """
    # Hidden layer: affine map followed by ReLU
    hidden = jax.nn.relu(jnp.matmul(X, params["weights_1"]) + params["bias_1"])
    # Output layer: affine map only. Applying ReLU here would clamp the logits
    # to >= 0, so sigmoid(logit) could never fall below 0.5 and the gradient
    # would vanish for every sample whose pre-activation is negative.
    return jnp.matmul(hidden, params["weights_2"]) + params["bias_2"]


@jax.jit
def predict_proba(params: ModelParams, X: jax.Array) -> jax.Array:
    """Return the sigmoid of the network's logits for ``X``."""
    # Log-odds from the network, squashed into (0, 1) by the sigmoid
    return jax.nn.sigmoid(predict_logits(params=params, X=X))


@jax.jit
def weighted_sigmoid_bce(
    logits: jax.Array, labels: jax.Array, pos_weight: float, neg_weight: float
) -> jax.Array:
    """Class-weighted sigmoid binary cross-entropy, averaged over all elements.

    Args:
        logits: Raw model outputs.
        labels: Targets; the weighting formula treats 1 as positive and 0 as
            negative.
        pos_weight: Multiplier applied to the loss of positive samples.
        neg_weight: Multiplier applied to the loss of negative samples.

    Returns:
        Scalar mean of the per-element weighted losses.
    """
    # Unweighted per-element loss computed directly from the logits
    per_element = optax.sigmoid_binary_cross_entropy(logits, labels)

    # Per-element weight: pos_weight where label == 1, neg_weight where label == 0
    element_weight = (pos_weight * labels) + (neg_weight * (1 - labels))

    return jnp.mean(per_element * element_weight)

In [None]:
@jax.jit
def forward_pass(
    params: ModelParams,
    X: jax.Array,
    y: jax.Array,
    pos_weight: float,
    neg_weight: float,
):
    """Compute the class-weighted BCE loss of the model on ``(X, y)``.

    ``params`` is the first argument so that ``jax.value_and_grad`` with its
    default settings differentiates with respect to the model parameters.
    """
    # Logits from the network feed straight into the weighted loss
    return weighted_sigmoid_bce(
        logits=predict_logits(params=params, X=X),
        labels=y,
        pos_weight=pos_weight,
        neg_weight=neg_weight,
    )


# The gradient function: returns (loss, gradients) of `forward_pass`, with the
# gradients taken with respect to its first argument (the model parameters).
grad_func = jax.jit(jax.value_and_grad(forward_pass))

# **Training Models**

In [None]:
# Read and process data: three (features, labels) splits plus the feature names
(
    (X_train_jnp, y_train_jnp),
    (X_valid_jnp, y_valid_jnp),
    (X_test_jnp, y_test_jnp),
    feature_names,
) = get_data(config=config)

# Scale the data: the statistics are fitted on the training split only and then
# reused unchanged for the validation and test splits
feature_scaler = JaxStdScaler()
X_train_sc = feature_scaler.fit_transform(X_train_jnp)
X_valid_sc = feature_scaler.transform(X_valid_jnp)
X_test_sc = feature_scaler.transform(X_test_jnp)

# Define the layer sizes: input width taken from the data, 10 hidden units,
# and a single output unit
model_size: ModelSize = {
    "l1_size": X_train_sc.shape[1],
    "l2_size": 10,
    "l3_size": 1,
}

In [None]:
class ResultRecord(TypedDict):
    """One entry of a training history (all keys required)."""

    epoch: int  # Index of the epoch the loss belongs to
    loss: float  # Loss value recorded for that epoch

In [None]:
# Write config to the runs folder: the raw config as YAML and the validated
# config as JSON. Create the folder first so a fresh checkout does not fail
# with FileNotFoundError.
os.makedirs("../runs", exist_ok=True)
OmegaConf.save(cfg_raw, "../runs/config.yaml")
# Explicit encoding keeps the JSON dump independent of the platform default
with open("../runs/config_resolved.json", "w", encoding="utf-8") as f:
    f.write(config.model_dump_json(indent=2))

In [None]:
# Init parameters
params = init_model_params(model_size=model_size)

# Learning-rate schedule: piecewise constant, configured from the run config
schedule = optax.schedules.piecewise_constant_schedule(
    init_value=config.lr_schedule.init_value,
    boundaries_and_scales=config.lr_schedule.boundaries_and_scales,
)

# Set optimizer
optimizer = optax.adam(learning_rate=schedule)
# Initialize the optimizer state for the current parameters
opt_state = optimizer.init(params)

# Containers
loss: float | None = None
history: list[ResultRecord] = []

# Keep a handle on the progress bar: a `desc=` argument is evaluated only once,
# before the first epoch, so the running loss has to be pushed to the bar
# explicitly after every step.
progress = tqdm(range(config.train.epochs), desc="Training Epoch Started")

# Iterate through epochs (full-batch training: one update per epoch)
for epoch in progress:
    # Compute loss and grads
    loss, grads = grad_func(
        params,
        X=X_train_sc,
        y=y_train_jnp,
        pos_weight=config.class_weights.pos_weight,
        neg_weight=config.class_weights.neg_weight,
    )

    # Update params
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    # Log the loss and show it on the progress bar
    epoch_loss = float(loss)
    history.append({"epoch": epoch, "loss": epoch_loss})
    progress.set_description(f"Last Epoch Loss: {epoch_loss}")

# Plot the recorded training loss against the epoch index
fig = go.Figure(
    data=[
        go.Scatter(
            x=[record["epoch"] for record in history],
            y=[record["loss"] for record in history],
            mode="lines+markers",
            name="Training Loss",
            hovertemplate="Epoch: %{x}<br>Loss: %{y:.4f}<extra></extra>",
            line={"width": 2},
            marker={"size": 6},
        )
    ]
)

# Titles, theme and figure size
fig.update_layout(
    title={
        "text": "<b>Loss Tracking Through Epochs</b>",
        "x": 0.5,
        "font": {"size": 20},
    },
    xaxis_title="<b>Epoch</b>",
    yaxis_title="<b>Loss</b>",
    hovermode="x unified",
    template="plotly_dark",
    width=1400,
    height=500,
    showlegend=True,
)

# Render the figure
fig.show()

# **Validation**

In [None]:
# Get predictions as probabilities for each split
y_pred_train = predict_proba(params=params, X=X_train_sc)
y_pred_valid = predict_proba(params=params, X=X_valid_sc)
y_pred_test = predict_proba(params=params, X=X_test_sc)

# Get predictions as logits (log-odds) for each split; the validation logits
# are reused further down for temperature scaling
y_pred_logits_train = predict_logits(params=params, X=X_train_sc)
y_pred_logits_valid = predict_logits(params=params, X=X_valid_sc)
y_pred_logits_test = predict_logits(params=params, X=X_test_sc)

In [None]:
# Histogram of the raw validation logits
fig = go.Figure(data=[go.Histogram(x=y_pred_logits_valid.ravel())])
fig.update_layout(template="plotly_dark")
fig.show()

In [None]:
# Precision-recall curve on the validation set
precision, recall, threshold = precision_recall_curve(
    y_valid_jnp.reshape(-1), y_pred_valid.reshape(-1)
)

# F-beta score: beta > 1 weights recall more heavily than precision.
# The variable and column keep the historical name "f1" even though beta = 3.
beta = 3
f1 = (1 + beta**2) * precision * recall / ((beta**2 * precision) + recall)

# `precision` and `recall` have one more entry than `threshold`; dropping their
# last element aligns all four columns. This is the only trim that is needed.
prft = (
    pd.DataFrame(
        dict(
            precision=precision[:-1],
            recall=recall[:-1],
            f1=f1[:-1],
            threshold=threshold,
        )
    )
    # precision == recall == 0 evaluates to 0/0 = NaN; score such points as 0
    .fillna({"f1": 0.0})
    .sort_values(["recall", "precision"], ascending=[True, False])
    .reset_index(drop=True)
)

# Plot the precision-recall curve. x, y and customdata all use every row of
# `prft`, so each hover label belongs to the point it is shown on.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=prft["recall"],
        y=prft["precision"],
        name="<b>Precision Recall Curve</b>",
        customdata=prft[["f1", "threshold"]].values,
        hovertemplate=(
            "Precision: %{y:.2%}<br>"
            "Recall: %{x:.2%}<br>"
            f"F{beta}: "
            "%{customdata[0]:.2%}<br>"
            "Threshold: %{customdata[1]:%}"
        ),
    )
)
fig.update_layout(
    title=dict(text="Precision-Recall-Curve", x=0.5, font=dict(size=25)),
    xaxis_title="<b>Recall</b>",
    yaxis_title="<b>Precision</b>",
    template="plotly_dark",
)

# Mark the operating point with the highest F-beta score
max_f1_idx = prft["f1"].argmax()
max_f1_point = prft.iloc[max_f1_idx]
fig.add_trace(
    go.Scatter(
        x=[max_f1_point["recall"]],
        y=[max_f1_point["precision"]],
        mode="markers",
        name=f"<b>Maximum F{beta}</b>",
        marker=dict(size=10, symbol="star", color="red"),
        hovertemplate=(
            "Precision: %{y:.2%}<br>"
            "Recall: %{x:.2%}<br>"
            f"F{beta}: {max_f1_point['f1']:.2%}<br>"
            f"Threshold: {max_f1_point['threshold']:%}"
        ),
    )
)
fig.show()

In [None]:
# ROC curve values on the validation set
fpr, tpr, thresholds = roc_curve(y_valid_jnp.reshape(-1), y_pred_valid.reshape(-1))

# Collect the curve in a DataFrame for plotting
roc_df = pd.DataFrame(dict(fpr=fpr, tpr=tpr, threshold=thresholds))

# Figure with the ROC curve first and the chance diagonal second
fig = go.Figure(
    data=[
        go.Scatter(
            x=roc_df["fpr"],
            y=roc_df["tpr"],
            name="<b>ROC Curve</b>",
            customdata=roc_df["threshold"],
            hovertemplate=(
                "True Positive Rate: %{y:.2%}<br>"
                "False Positive Rate: %{x:.2%}<br>"
                "Threshold: %{customdata}"
            ),
        ),
        go.Scatter(
            x=[0, 1],
            y=[0, 1],
            line={"dash": "dash", "color": "gray"},
            name="<b>Random Classifier</b>",
        ),
    ]
)

# Titles and theme
fig.update_layout(
    title={
        "text": "Receiver Operating Characteristic (ROC) Curve",
        "x": 0.5,
        "font": {"size": 25},
    },
    xaxis_title="<b>False Positive Rate</b>",
    yaxis_title="<b>True Positive Rate</b>",
    template="plotly_dark",
)

fig.show()

In [None]:
# Get class predictions by thresholding the predicted probabilities
# NOTE(review): 0.9989 is a hard-coded cut-off — presumably read off the
# precision-recall plot above; confirm, and re-derive it whenever the model is
# retrained.
y_pred_valid_class = (y_pred_valid.ravel() > 0.9989).astype(int)
# Classification report on the validation set
print(classification_report(y_valid_jnp.ravel(), y_pred_valid_class))
# Confusion matrix on the validation set
ConfusionMatrixDisplay(confusion_matrix(y_valid_jnp.ravel(), y_pred_valid_class)).plot()
plt.show()

# **Calibration**

In [None]:
# Reliability curve of the uncalibrated validation probabilities (10 uniform bins)
prob_true, prob_pred = calibration_curve(
    y_valid_jnp.reshape(-1), y_pred_valid.reshape(-1), n_bins=10, strategy="uniform"
)

# Model curve first, perfect-calibration diagonal second
fig = go.Figure(
    data=[
        go.Scatter(
            x=prob_pred,
            y=prob_true,
            mode="lines+markers",
            name="Model",
            marker={"size": 8},
            line={"width": 2},
            hovertemplate="Pred: %{x:.3f}<br>True: %{y:.3f}<extra></extra>",
        ),
        go.Scatter(
            x=[0, 1],
            y=[0, 1],
            mode="lines",
            name="Perfect Calibration",
            line={"dash": "dash", "color": "gray"},
            hoverinfo="skip",
        ),
    ]
)

# Titles, theme, square canvas and legend placement
fig.update_layout(
    title={"text": "<b>Calibration Plot</b>", "x": 0.5, "font": {"size": 18}},
    xaxis_title="Predicted probability",
    yaxis_title="True frequency of fraud",
    template="plotly_dark",
    width=500,
    height=500,
    legend={"yanchor": "top", "y": 0.99, "xanchor": "left", "x": 0.01},
)
fig.show()

In [None]:
@jax.jit
def nll_from_logits(logits: jax.Array, targets: jax.Array) -> jax.Array:
    """Mean binary negative log-likelihood computed directly from logits.

    Uses the form ``max(z, 0) - z * t + log1p(exp(-|z|))``, which only ever
    exponentiates a non-positive number and therefore cannot overflow.

    Args:
        logits: Raw scores ``z``.
        targets: Binary labels ``t``; cast to float32 before use.

    Returns:
        Scalar mean of the per-element losses.
    """
    labels = targets.astype(jnp.float32)
    positive_part = jnp.maximum(0, logits)
    log_term = jnp.log1p(jnp.exp(-jnp.abs(logits)))
    return jnp.mean(positive_part - logits * labels + log_term)


@jax.jit
def temp_forward_pass(u: jax.Array, logits: jax.Array, targets: jax.Array) -> jax.Array:
    """Negative log-likelihood of ``logits`` after temperature scaling.

    The temperature is parameterised as ``T = exp(u)``, which keeps it strictly
    positive for every real ``u``.
    """
    temperature = jnp.exp(u)
    # Divide the logits by the temperature, then score them against the targets
    return nll_from_logits(logits=(logits / temperature), targets=targets)


# Gradient function: returns (loss, gradient) of `temp_forward_pass`, with the
# gradient taken with respect to its first argument `u` (the log-temperature).
grad_func_temp = jax.jit(jax.value_and_grad(temp_forward_pass))

In [None]:
# Init the temperature parameter: `u` is the log-temperature, so the starting
# temperature is exp(1.0)
u = jnp.float32(1.0)

# Set up the optax optimizer with a piecewise-constant learning-rate schedule
# NOTE(review): optax documents `boundaries_and_scales` values as multiplicative
# scale factors applied cumulatively to the initial value, not as absolute
# learning rates — confirm that {20: 1e-2, 30: 1e-2, 40: 1e-4} is intended.
schedule = optax.schedules.piecewise_constant_schedule(
    1e-1, {20: 1e-2, 30: 1e-2, 40: 1e-4}
)
temp_optimizer = optax.adam(learning_rate=schedule)

# Get the optimizer state for the single scalar parameter
temp_optimizer_state = temp_optimizer.init(u)

# For recording (re-binds the `history` list used by the model-training cell)
history: list[ResultRecord] = []

# Training loop: the temperature is fitted on the validation logits and labels
epochs = 50
for epoch in tqdm(range(epochs)):
    # Get the loss value and its gradient with respect to `u`
    loss, grad = grad_func_temp(u, y_pred_logits_valid, y_valid_jnp)

    # Apply the optimizer update to `u`
    updates, temp_optimizer_state = temp_optimizer.update(grad, temp_optimizer_state)
    u = optax.apply_updates(u, updates)

    # Record the loss for this epoch
    history.append({"epoch": epoch, "loss": float(loss)})


# Plot the loss recorded while fitting the temperature against the epoch index
fig = go.Figure(
    data=[
        go.Scatter(
            x=[record["epoch"] for record in history],
            y=[record["loss"] for record in history],
            mode="lines+markers",
            name="Training Loss",
            hovertemplate="Epoch: %{x}<br>Loss: %{y:.4f}<extra></extra>",
            line={"width": 2},
            marker={"size": 6},
        )
    ]
)

# Titles, theme and figure size
fig.update_layout(
    title={
        "text": "<b>Loss Tracking Through Epochs</b>",
        "x": 0.5,
        "font": {"size": 20},
    },
    xaxis_title="<b>Epoch</b>",
    yaxis_title="<b>Loss</b>",
    hovermode="x unified",
    template="plotly_dark",
    width=1400,
    height=500,
    showlegend=True,
)

# Render the figure
fig.show()

In [None]:
# Fitted temperature, recovered from the log-temperature `u`
Temperature = jnp.exp(u)
# Calibrated probabilities: sigmoid of the temperature-scaled validation logits
y_pred_valid_calib = jax.nn.sigmoid(y_pred_logits_valid / Temperature)

In [None]:
# Reliability curve of the temperature-scaled validation probabilities
# (10 uniform bins)
prob_true, prob_pred = calibration_curve(
    y_valid_jnp.reshape(-1),
    y_pred_valid_calib.reshape(-1),
    n_bins=10,
    strategy="uniform",
)

# Model curve first, perfect-calibration diagonal second
fig = go.Figure(
    data=[
        go.Scatter(
            x=prob_pred,
            y=prob_true,
            mode="lines+markers",
            name="Model",
            marker={"size": 8},
            line={"width": 2},
            hovertemplate="Pred: %{x:.3f}<br>True: %{y:.3f}<extra></extra>",
        ),
        go.Scatter(
            x=[0, 1],
            y=[0, 1],
            mode="lines",
            name="Perfect Calibration",
            line={"dash": "dash", "color": "gray"},
            hoverinfo="skip",
        ),
    ]
)

# Titles, theme, square canvas and legend placement
fig.update_layout(
    title={"text": "<b>Calibration Plot</b>", "x": 0.5, "font": {"size": 18}},
    xaxis_title="Predicted probability",
    yaxis_title="True frequency of fraud",
    template="plotly_dark",
    width=500,
    height=500,
    legend={"yanchor": "top", "y": 0.99, "xanchor": "left", "x": 0.01},
)
fig.show()

In [None]:
# Precision-recall curve on the validation set, using calibrated probabilities
precision, recall, threshold = precision_recall_curve(
    y_valid_jnp.reshape(-1), y_pred_valid_calib.reshape(-1)
)

# F-beta score: beta > 1 weights recall more heavily than precision.
# The variable and column keep the historical name "f1" even though beta = 3.
beta = 3
f1 = (1 + beta**2) * precision * recall / ((beta**2 * precision) + recall)

# `precision` and `recall` have one more entry than `threshold`; dropping their
# last element aligns all four columns. This is the only trim that is needed.
prft = (
    pd.DataFrame(
        dict(
            precision=precision[:-1],
            recall=recall[:-1],
            f1=f1[:-1],
            threshold=threshold,
        )
    )
    # precision == recall == 0 evaluates to 0/0 = NaN; score such points as 0
    .fillna({"f1": 0.0})
    .sort_values(["recall", "precision"], ascending=[True, False])
    .reset_index(drop=True)
)

# Plot the precision-recall curve. x, y and customdata all use every row of
# `prft`, so each hover label belongs to the point it is shown on.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=prft["recall"],
        y=prft["precision"],
        name="<b>Precision Recall Curve</b>",
        customdata=prft[["f1", "threshold"]].values,
        hovertemplate=(
            "Precision: %{y:.2%}<br>"
            "Recall: %{x:.2%}<br>"
            f"F{beta}: "
            "%{customdata[0]:.2%}<br>"
            "Threshold: %{customdata[1]:%}"
        ),
    )
)
fig.update_layout(
    title=dict(text="Precision-Recall-Curve", x=0.5, font=dict(size=25)),
    xaxis_title="<b>Recall</b>",
    yaxis_title="<b>Precision</b>",
    template="plotly_dark",
)

# Mark the operating point with the highest F-beta score
max_f1_idx = prft["f1"].argmax()
max_f1_point = prft.iloc[max_f1_idx]
fig.add_trace(
    go.Scatter(
        x=[max_f1_point["recall"]],
        y=[max_f1_point["precision"]],
        mode="markers",
        name=f"<b>Maximum F{beta}</b>",
        marker=dict(size=10, symbol="star", color="red"),
        hovertemplate=(
            "Precision: %{y:.2%}<br>"
            "Recall: %{x:.2%}<br>"
            f"F{beta}: {max_f1_point['f1']:.2%}<br>"
            f"Threshold: {max_f1_point['threshold']:%}"
        ),
    )
)
fig.show()

In [None]:
# ROC curve values on the validation set, using calibrated probabilities
fpr, tpr, thresholds = roc_curve(
    y_valid_jnp.reshape(-1), y_pred_valid_calib.reshape(-1)
)

# Collect the curve in a DataFrame for plotting
roc_df = pd.DataFrame(dict(fpr=fpr, tpr=tpr, threshold=thresholds))

# Figure with the ROC curve first and the chance diagonal second
fig = go.Figure(
    data=[
        go.Scatter(
            x=roc_df["fpr"],
            y=roc_df["tpr"],
            name="<b>ROC Curve</b>",
            customdata=roc_df["threshold"],
            hovertemplate=(
                "True Positive Rate: %{y:.2%}<br>"
                "False Positive Rate: %{x:.2%}<br>"
                "Threshold: %{customdata}"
            ),
        ),
        go.Scatter(
            x=[0, 1],
            y=[0, 1],
            line={"dash": "dash", "color": "gray"},
            name="<b>Random Classifier</b>",
        ),
    ]
)

# Titles and theme
fig.update_layout(
    title={
        "text": "Receiver Operating Characteristic (ROC) Curve",
        "x": 0.5,
        "font": {"size": 25},
    },
    xaxis_title="<b>False Positive Rate</b>",
    yaxis_title="<b>True Positive Rate</b>",
    template="plotly_dark",
)

fig.show()

In [None]:
# Get class predictions by thresholding the calibrated probabilities
# NOTE(review): 0.6198 is a hard-coded cut-off — presumably read off the
# calibrated precision-recall plot above; confirm, and re-derive it whenever the
# model or the temperature is refitted.
y_pred_valid_class = (y_pred_valid_calib.ravel() > 0.6198).astype(int)
# Classification report on the validation set
print(classification_report(y_valid_jnp.ravel(), y_pred_valid_class))
# Confusion matrix on the validation set
ConfusionMatrixDisplay(confusion_matrix(y_valid_jnp.ravel(), y_pred_valid_class)).plot()
plt.show()