In [None]:
# Standard Imports
from typing import TypedDict

# Third Party imports
import jax
import optax
import pandas as pd
import polars as pl
import jax.numpy as jnp
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    precision_recall_curve,
    roc_curve,
)
from tqdm.notebook import tqdm
import plotly.graph_objects as go

In [None]:
def get_data(
    data_path: str = r"D:\Codebase\fraud-detection\data\input\creditcard.csv",
    test_size: float = 0.2,
    random_state: int = 42,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """Load the credit-card CSV and return a stratified train/test split as JAX arrays.

    Args:
        data_path: Location of the CSV file. It must contain a ``Class`` column,
            which is used as the target; every other column is used as a feature.
            The default is the path this notebook was originally written against.
        test_size: Fraction of the rows held out for the test split.
        random_state: Seed passed to ``train_test_split`` so the split is
            reproducible.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. The targets are reshaped to column
        vectors of shape ``(n_rows, 1)``.

    Side effects:
        Displays the absolute and the percentage class counts of the training
        targets in the notebook output.
    """
    # Read data; the large infer_schema_length lets polars look at up to one
    # million rows before it fixes the column dtypes
    df = pl.read_csv(
        data_path,
        ignore_errors=False,
        infer_schema_length=1000_000,
    )

    # Split data into features (everything except "Class") and target ("Class")
    x = df.select(pl.exclude("Class")).to_pandas()
    y = df.select("Class").to_series().to_pandas()

    # Set type of the splits
    X_train: pd.DataFrame
    y_train: pd.Series
    X_test: pd.DataFrame
    y_test: pd.Series

    # Stratified split keeps the class ratio the same in both parts
    X_train, X_test, y_train, y_test = train_test_split(
        x, y, test_size=test_size, random_state=random_state, stratify=y
    )

    # The class distribution in the training set (counts, then percentages)
    display(y_train.value_counts(normalize=False))
    display(y_train.value_counts(normalize=True).round(4) * 100)

    # Convert DataFrame to Jax arrays; targets become column vectors
    X_train_jnp = jnp.array(X_train.values)
    y_train_jnp = jnp.array(y_train.values).reshape(-1, 1)
    X_test_jnp = jnp.array(X_test.values)
    y_test_jnp = jnp.array(y_test.values).reshape(-1, 1)

    return X_train_jnp, y_train_jnp, X_test_jnp, y_test_jnp

In [None]:
class JaxStdScaler:
    """Standardise features column-wise with statistics held as JAX arrays.

    ``fit`` stores the per-column mean and standard deviation (``axis=0``) and
    ``transform`` applies ``(X - mean) / std`` with those stored values, so a
    scaler fitted on one split can be re-used on another.
    """

    def __init__(self) -> None:
        # Both stay None until ``fit`` has been called; ``transform`` checks this.
        self.mean: None | jax.Array = None
        self.std: None | jax.Array = None

    def fit(self, X: jax.Array) -> None:
        """Compute and store the per-column mean and standard deviation of ``X``."""
        self.mean = jnp.mean(X, axis=0)
        std = jnp.std(X, axis=0)
        # A constant column has std == 0, which would turn the whole column into
        # NaN (0 / 0) when scaling. Use 1 instead so such a column maps to 0.
        self.std = jnp.where(std == 0, 1.0, std)

    def transform(self, X: jax.Array) -> jax.Array:
        """Scale ``X`` with the fitted statistics.

        Raises:
            ValueError: If the scaler has not been fitted yet.
        """
        if self.mean is None or self.std is None:
            raise ValueError("The JaxStdScaler has not been fitted yet.")
        return self._transform(X=X, mean=self.mean, std=self.std)

    def fit_transform(self, X: jax.Array) -> jax.Array:
        """Fit on ``X`` and return the scaled ``X``."""
        self.fit(X)
        return self.transform(X)

    @staticmethod
    @jax.jit
    def _transform(X: jax.Array, mean: jax.Array, std: jax.Array) -> jax.Array:
        """Jit-compiled scaling kernel; statistics are passed in explicitly."""
        return (X - mean) / std

In [None]:
class ModelParams(TypedDict, total=True):
    """Trainable parameters of the linear model used in this notebook."""

    # Weight matrix; init_model_params creates it with shape (n_features, n_targets)
    weights: jnp.ndarray
    # Bias vector; init_model_params creates it with shape (n_targets,)
    bias: jnp.ndarray


class DataFeatures(TypedDict, total=True):
    """Dimensions of the data, used to size the model parameters."""

    # Number of input columns (first dimension of the weight matrix)
    n_features: int
    # Number of output columns (second dimension of the weights, length of the bias)
    n_targets: int


def init_model_params(data_features: DataFeatures, seed: int = 1729) -> ModelParams:
    """Draw the initial weights and bias from a standard normal distribution.

    Args:
        data_features: Provides ``n_features`` and ``n_targets``, which fix the
            parameter shapes.
        seed: Seed for the JAX random key; the same seed gives the same parameters.

    Returns:
        A dict with ``weights`` of shape ``(n_features, n_targets)`` and ``bias``
        of shape ``(n_targets,)``.
    """
    n_in = data_features["n_features"]
    n_out = data_features["n_targets"]

    # One independent sub-key per parameter, both derived from the seed key
    weight_key, bias_key = jax.random.split(key=jax.random.key(seed=seed), num=2)

    return {
        "weights": jax.random.normal(key=weight_key, shape=(n_in, n_out)),
        "bias": jax.random.normal(key=bias_key, shape=(n_out,)),
    }


@jax.jit
def predict_logits(params: ModelParams, X: jax.Array) -> jax.Array:
    """Return the linear scores ``X @ weights + bias`` (no sigmoid applied)."""
    weights, bias = params["weights"], params["bias"]
    return X @ weights + bias


@jax.jit
def predict_proba(params: ModelParams, X: jax.Array) -> jax.Array:
    """Return probabilities: the sigmoid of the log-odds from ``predict_logits``."""
    # Log-odds -> probabilities
    return jax.nn.sigmoid(predict_logits(params=params, X=X))


@jax.jit
def weighted_sigmoid_bce(
    logits: jax.Array, labels: jax.Array, pos_weight: float, neg_weight: float
) -> jax.Array:
    """Class-weighted sigmoid binary cross-entropy, averaged over all elements.

    Args:
        logits: Raw model scores (before the sigmoid).
        labels: Targets; an element equal to 1 receives ``pos_weight``, an element
            equal to 0 receives ``neg_weight``.
        pos_weight: Multiplier applied to the loss of positive elements.
        neg_weight: Multiplier applied to the loss of negative elements.

    Returns:
        The mean of the weighted element-wise losses.
    """
    # Unweighted element-wise loss
    per_element_loss = optax.sigmoid_binary_cross_entropy(logits, labels)

    # Element-wise weight selected by the label value
    element_weights = (pos_weight * labels) + (neg_weight * (1 - labels))

    return jnp.mean(per_element_loss * element_weights)

In [None]:
@jax.jit
def forward_pass(
    params: ModelParams,
    X: jax.Array,
    y: jax.Array,
    pos_weight: float,
    neg_weight: float,
):
    """Run the model on ``X`` and return the class-weighted BCE loss against ``y``.

    ``params`` is deliberately the first argument so the function can be
    differentiated with respect to it.
    """
    return weighted_sigmoid_bce(
        logits=predict_logits(params=params, X=X),
        labels=y,
        pos_weight=pos_weight,
        neg_weight=neg_weight,
    )


# The gradient function: one call returns (loss, grads) for forward_pass.
# jax.value_and_grad differentiates with respect to the first positional
# argument by default, which for forward_pass is `params`.
grad_func = jax.jit(jax.value_and_grad(forward_pass))

In [None]:
# Load the train/test splits as JAX arrays
X_train_jnp, y_train_jnp, X_test_jnp, y_test_jnp = get_data()

# Standardise the features: the statistics come from the training split only
# and are re-used for the test split
feature_scaler = JaxStdScaler()
feature_scaler.fit(X_train_jnp)
X_train_sc = feature_scaler.transform(X_train_jnp)
X_test_sc = feature_scaler.transform(X_test_jnp)

# Define data features (dimensions needed to initialise the parameters)
data_features: DataFeatures = DataFeatures(
    n_features=X_train_sc.shape[1],
    n_targets=1,
)

In [None]:
class ResultRecord(TypedDict, total=True):
    """One entry of the training history."""

    # Zero-based index of the epoch the loss belongs to
    epoch: int
    # Training loss of that epoch, converted to a plain Python float
    loss: float

In [None]:
# Init parameters
params = init_model_params(data_features=data_features)

# Weights for the imbalanced classes: positives are up-weighted, negatives keep 1.
# NOTE(review): 99.83 / 0.17 presumably mirrors the rounded class shares (in %)
# of the training targets — confirm, or derive the ratio from y_train_jnp.
pos_weight: float = float(99.83 / 0.17)
neg_weight: float = float(1)

# Logits (log-odds) of the freshly initialised, untrained model
logits_train = predict_logits(params, X_train_sc)

# Learning-rate schedule.
# optax's piecewise-constant schedule takes *multiplicative scales*, not absolute
# learning rates: at every boundary the current rate is multiplied by the mapped
# factor. A factor of 0.1 at steps 10, 30 and 40 gives 1e-1 -> 1e-2 -> 1e-3 -> 1e-4.
# Passing the target rates themselves (1e-2, 1e-3, 1e-4) as factors would compound
# to 1e-3 -> 1e-6 -> 1e-10 and effectively stop training after step 30.
start_lr = 1e-1
schedule = optax.schedules.piecewise_constant_schedule(
    start_lr, {10: 0.1, 30: 0.1, 40: 0.1}
)

# Set Optimizer
optimizer = optax.adam(learning_rate=schedule)
# Initialize the optimizer state for the current parameters
opt_state = optimizer.init(params)

# Number of epochs; each epoch is one full-batch optimizer step
epochs: int = 50
loss = None  # holds the latest scalar loss returned by grad_func
history: list[ResultRecord] = list()

# Keep a handle on the progress bar: a `desc=` expression is evaluated only once,
# before the first epoch, so the loss has to be pushed in from inside the loop.
progress_bar = tqdm(range(epochs), desc="Training Epoch Started")

# Iterate through Epochs
for epoch in progress_bar:
    # Compute Loss and Grads
    loss, grads = grad_func(
        params,
        X=X_train_sc,
        y=y_train_jnp,
        pos_weight=pos_weight,
        neg_weight=neg_weight,
    )

    # Update params
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    # Log the loss and show it on the progress bar
    loss_value = float(loss)
    history.append({"epoch": epoch, "loss": loss_value})
    progress_bar.set_description(f"Last Epoch Loss: {loss_value:.6f}")

# Loss curve: one point per logged epoch
fig = go.Figure(
    data=[
        go.Scatter(
            x=[record["epoch"] for record in history],
            y=[record["loss"] for record in history],
            mode="lines+markers",
            name="Training Loss",
            hovertemplate="Epoch: %{x}<br>Loss: %{y:.4f}<extra></extra>",
            line={"width": 2},
            marker={"size": 6},
        )
    ]
)

# Titles, theme and size of the figure
fig.update_layout(
    title={
        "text": "<b>Loss Tracking Through Epochs</b>",
        "x": 0.5,
        "font": {"size": 20},
    },
    xaxis_title="<b>Epoch</b>",
    yaxis_title="<b>Loss</b>",
    hovermode="x unified",
    template="plotly_dark",
    width=1400,
    height=500,
    showlegend=True,
)

# Show plot
fig.show()

In [None]:
# Predicted positive-class probabilities for the training set.
# NOTE(review): the evaluation cells that follow use these training-set
# predictions; X_test_sc is not scored anywhere in the visible notebook.
y_pred_train = predict_proba(params=params, X=X_train_sc)

In [None]:
# Precision-recall curve for the training predictions
precision, recall, threshold = precision_recall_curve(
    y_train_jnp.reshape(-1), y_pred_train.reshape(-1)
)

# F-beta score. With beta = 15 recall counts far more than precision, so despite
# the `f1` variable/column name (kept so existing references keep working) this
# is an F15 score, not F1.
beta = 15
score_label = f"F{beta}"
# Where precision and recall are both 0 the ratio is 0 / 0; report 0, not NaN.
f1 = (
    pd.Series((1 + beta**2) * precision * recall)
    .div(pd.Series((beta**2 * precision) + recall))
    .fillna(0.0)
    .to_numpy()
)

prft = (
    pd.DataFrame(
        dict(
            # precision/recall carry one trailing point (precision=1, recall=0)
            # without a threshold, hence the [:-1] to align with `threshold`
            precision=precision[:-1],
            recall=recall[:-1],
            f1=f1[:-1],
            threshold=threshold,
        )
    )
    .sort_values(["recall", "precision"], ascending=[True, False])
    .reset_index(drop=True)
)

# Plot the precision-recall curve. The frame is already trimmed above, so every
# row is plotted; trimming again would drop the highest-recall point and leave
# x/y one row shorter than `customdata`.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=prft["recall"],
        y=prft["precision"],
        name="<b>Precision Recall Curve</b>",
        customdata=prft[["f1", "threshold"]].values,
        hovertemplate=(
            "Precision: %{y:.2%}<br>"
            "Recall: %{x:.2%}<br>"
            f"{score_label}: "
            "%{customdata[0]:.2%}<br>"
            "Threshold: %{customdata[1]:%}"
        ),
    )
)
fig.update_layout(
    title=dict(text="Precision-Recall-Curve", x=0.5, font=dict(size=25)),
    xaxis_title="<b>Recall</b>",
    yaxis_title="<b>Precision</b>",
    template="plotly_dark",
)

# Mark the point with the highest F-beta score
max_f1_idx = prft["f1"].argmax()
max_f1_point = prft.iloc[max_f1_idx]
fig.add_trace(
    go.Scatter(
        x=[max_f1_point["recall"]],
        y=[max_f1_point["precision"]],
        mode="markers",
        name=f"<b>Maximum {score_label}</b>",
        marker=dict(size=10, symbol="star", color="red"),
        hovertemplate=(
            "Precision: %{y:.2%}<br>"
            "Recall: %{x:.2%}<br>"
            f"{score_label}: {max_f1_point['f1']:.2%}<br>"
            f"Threshold: {max_f1_point['threshold']:%}"
        ),
    )
)
fig.show()

In [None]:
# ROC points for the training predictions
fpr, tpr, thresholds = roc_curve(y_train_jnp.reshape(-1), y_pred_train.reshape(-1))

# One frame so that every curve point keeps its threshold
roc_df = pd.DataFrame(dict(fpr=fpr, tpr=tpr, threshold=thresholds))

# ROC curve plus the diagonal of a random classifier as reference
fig = go.Figure(
    data=[
        go.Scatter(
            x=roc_df["fpr"],
            y=roc_df["tpr"],
            name="<b>ROC Curve</b>",
            customdata=roc_df["threshold"],
            hovertemplate=(
                "True Positive Rate: %{y:.2%}<br>"
                "False Positive Rate: %{x:.2%}<br>"
                "Threshold: %{customdata}"
            ),
        ),
        go.Scatter(
            x=[0, 1],
            y=[0, 1],
            line={"dash": "dash", "color": "gray"},
            name="<b>Random Classifier</b>",
        ),
    ]
)

# Titles and theme
fig.update_layout(
    title={
        "text": "Receiver Operating Characteristic (ROC) Curve",
        "x": 0.5,
        "font": {"size": 25},
    },
    xaxis_title="<b>False Positive Rate</b>",
    yaxis_title="<b>True Positive Rate</b>",
    template="plotly_dark",
)

fig.show()