# Mistplay Fraud Demo – Train & Register Model

This notebook uses Feature Store to build a training set, trains a fraud classifier, tracks metrics with MLflow, and registers the model.

In [0]:
!pip install mlflow


In [0]:
!pip install databricks-feature-store

In [0]:
# Restart the Python process so the packages installed above are importable.
dbutils.library.restartPython()

In [0]:
import mlflow
import mlflow.sklearn
import pandas as pd
from databricks.feature_store import FeatureStoreClient, FeatureLookup
from pyspark.sql import functions as F
from sklearn.base import clone
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Namespace prefix shared by the demo tables and the registered models.
DB_NAME = "ramin_.mistplay_fraud_demo"
# Registered model names, derived from DB_NAME so the prefix is written once.
MODEL_LR = f"{DB_NAME}.mistplay_fraud_model_lr"
MODEL_RF = f"{DB_NAME}.mistplay_fraud_model_rf"


In [0]:

# Client used to assemble the training set here and to log models later on.
fs = FeatureStoreClient()

# Base table read from DB_NAME; it supplies the label column and the lookup keys.
training_base = spark.table(f"{DB_NAME}.training_base")

# One lookup per feature table, each joined on its own key column.
lookups = [
    FeatureLookup(table_name=f"{DB_NAME}.{feature_table}", lookup_key=key_column)
    for feature_table, key_column in (
        ("account_features", "account_id"),
        ("device_features", "device_id"),
    )
]

# The key columns are only needed for the joins, so they are excluded
# from the resulting training set.
training_set = fs.create_training_set(
    training_base,
    feature_lookups=lookups,
    label="is_fraud_label",
    exclude_columns=["account_id", "device_id"],
)


## Prepare the data and preprocessing with scikit-learn (classical ML)

In [0]:

# Materialize the joined training set as a pandas DataFrame.
training_df = training_set.load_df().toPandas()

label_col = "is_fraud_label"
categorical_cols = [
    "country",
    "platform",
    "marketing_channel",
    "device_type",
    "os_version",
]

# Everything that is neither the label nor a listed categorical column is
# scaled as numeric.
# NOTE(review): assumes all remaining columns are numeric; confirm against
# the feature tables.
numeric_cols = [
    col
    for col in training_df.columns
    if col != label_col and col not in categorical_cols
]

y = training_df[label_col]
X = training_df.drop(columns=label_col)

# Stratified 80/20 split with a fixed seed so the split is repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,
    random_state=42,
)

# One-hot encode categoricals (categories unseen during fit are ignored at
# transform time) and standardize the remaining columns.
categorical_step = ("categorical", OneHotEncoder(handle_unknown="ignore"), categorical_cols)
numeric_step = ("numeric", StandardScaler(), numeric_cols)
preprocess = ColumnTransformer(transformers=[categorical_step, numeric_step])


In [0]:
# Preview only the first rows; rendering the whole frame bloats the notebook
# output and may expose more account/device data than intended.
training_df.head()

## Track metrics and register models with MLflow

In [0]:

def train_and_log(model_name: str, estimator):
    """Fit, evaluate, and register one classifier inside its own MLflow run.

    The estimator is placed behind a fresh copy of the shared ``preprocess``
    transformer, fitted on ``X_train``/``y_train`` and scored on
    ``X_test``/``y_test``. ROC AUC and average precision (PR AUC) are logged
    as run metrics, and the fitted pipeline is logged through the Feature
    Store client and registered under ``model_name``.

    Args:
        model_name: Used as both the MLflow run name and the registered
            model name.
        estimator: Unfitted classifier exposing ``fit`` and ``predict_proba``.

    Returns:
        dict with keys ``model_name``, ``roc_auc`` and ``pr_auc``.
    """
    # Clone so each pipeline owns its fitted preprocessing state instead of
    # refitting the single module-level ColumnTransformer on every call.
    pipeline = Pipeline(
        steps=[("preprocess", clone(preprocess)), ("model", estimator)]
    )
    with mlflow.start_run(run_name=model_name):
        pipeline.fit(X_train, y_train)
        # Column 1 of predict_proba is the score for the second class in
        # classes_ (the positive class for a binary 0/1 label).
        preds = pipeline.predict_proba(X_test)[:, 1]

        roc_auc = roc_auc_score(y_test, preds)
        pr_auc = average_precision_score(y_test, preds)

        mlflow.log_metric("roc_auc", roc_auc)
        mlflow.log_metric("pr_auc", pr_auc)

        # Logging via the Feature Store client passes the training set along
        # with the model.
        fs.log_model(
            model=pipeline,
            artifact_path="model",
            flavor=mlflow.sklearn,
            training_set=training_set,
            registered_model_name=model_name,
        )

    return {
        "model_name": model_name,
        "roc_auc": roc_auc,
        "pr_auc": pr_auc,
    }

# Train and register each candidate in turn, collecting one metrics dict per
# model: logistic regression first, then the random forest.
results = []

for registered_name, candidate in (
    (
        MODEL_LR,
        LogisticRegression(max_iter=200, class_weight="balanced"),
    ),
    (
        MODEL_RF,
        RandomForestClassifier(
            n_estimators=200,
            max_depth=8,
            random_state=42,
            class_weight="balanced",
        ),
    ),
):
    results.append(train_and_log(registered_name, candidate))



In [0]:
results_df = pd.DataFrame(results)

# Rank by ROC AUC, best first; the top row is the selected model.
best_row = results_df.sort_values("roc_auc", ascending=False).iloc[0]

best_model_name = str(best_row["model_name"])
best_roc_auc = float(best_row["roc_auc"])

# Build the one-row selection record with an explicit schema. Passing a dict
# whose "selected_at" value is None leaves Spark unable to infer that column's
# type, so the timestamp column is added afterwards instead.
best_model_df = spark.createDataFrame(
    [(best_model_name, "roc_auc", best_roc_auc)],
    schema="model_name string, selection_metric string, metric_value double",
).withColumn("selected_at", F.current_timestamp())

# Overwrite so the table holds only the latest selection.
best_model_df.write.mode("overwrite").saveAsTable(f"{DB_NAME}.model_selection")

print("Registered models:")
print(f"- {MODEL_LR}")
print(f"- {MODEL_RF}")
print("Best model by ROC AUC:")
print(best_model_name)