# JS - PLS XGB - Utility

In [None]:
import joblib
import json
import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import xgboost as xgb

In [None]:
# Artifacts written by the "js-pls-xgb-training" notebook live in this folder.
model_folder = os.path.join(os.pardir, "input", "js-pls-xgb-training")

# Fitted feature preprocessor; its `transform` is applied to the raw features
# before they are passed to the classifier.
# NOTE(review): joblib.load unpickles the file, so only load trusted artifacts.
pp = joblib.load(os.path.join(model_folder, "preprocessor.pkl"))

# Restore the trained XGBoost classifier from its native model file.
model = xgb.XGBClassifier()
model.load_model(os.path.join(model_folder, "model.xgb"))

In [None]:
# Read every column as a 32-bit float to keep memory usage down; the column
# names are taken from a one-row peek at the file header.
file = os.path.join(os.pardir, "input", "jane-street-market-prediction", "train.csv")
dtype = {c: np.float32 for c in pd.read_csv(file, nrows=1).columns}
full_df = pd.read_csv(file, engine="c", dtype=dtype)

# Keep only the validation window (dates 425-500 inclusive); this notebook
# only evaluates the already-trained model, so no training split is built.
valid_df = full_df[full_df["date"].between(425, 500)]

# Build the feature matrix and the per-row quantities used by the utility.
features = [c for c in valid_df.columns if "feature" in c]
X = valid_df[features].to_numpy()

resp = valid_df["resp"].to_numpy()
weight = valid_df["weight"].to_numpy()
# `np.int` was removed in NumPy 1.24, so cast to an explicit integer dtype.
# Integer day indices are required by the `np.bincount` call in `utility`.
date = valid_df["date"].astype(np.int64).to_numpy()

# Model predictions: keep only the probability of the positive class.
X = pp.transform(X)
probs = model.predict_proba(X)[:, 1]

In [None]:
def utility(threshold, *, probs=None, date=None, weight=None, resp=None):
    """Return the utility score obtained by acting wherever ``probs > threshold``.

    With ``action = 1`` where ``probs > threshold`` (0 otherwise, ties included):

        p_i = sum(weight * resp * action) over the rows of day i
        t   = sum(p) / sqrt(sum(p**2)) * sqrt(250 / |days|)
        u   = clip(t, 0, 6) * sum(p)

    This is presumably the Jane Street Market Prediction evaluation metric.

    Args:
        threshold: Probability cut-off above which a row is acted on.
        probs, date, weight, resp: Optional per-row arrays of equal length.
            Any argument left as ``None`` falls back to the notebook-level
            global of the same name. ``date`` must hold non-negative integers.

    Returns:
        The utility ``u``. Returns 0.0 when every daily sum is zero, e.g. when
        no row is acted on.
    """
    env = globals()
    if probs is None:
        probs = env["probs"]
    if date is None:
        date = env["date"]
    if weight is None:
        weight = env["weight"]
    if resp is None:
        resp = env["resp"]

    action = np.heaviside(probs - threshold, 0.0)
    daily = np.bincount(date, weights=weight * resp * action)
    # np.bincount creates a bin for every integer 0..max(date). Keep only the
    # days that actually have rows so that |days| below is the true number of
    # trading days rather than max(date) + 1.
    daily = daily[np.bincount(date) > 0]

    sum_sq = np.sum(daily ** 2)
    if sum_sq == 0.0:
        return 0.0

    total = np.sum(daily)
    t = total / np.sqrt(sum_sq) * np.sqrt(250 / daily.size)
    return min(max(t, 0), 6) * total

# Evaluate the utility on a fine grid of thresholds (step 1e-5 from 0.01).
thresholds = np.arange(0.01, 1.0, 1e-5)
utilities = np.array([utility(thresh) for thresh in thresholds])

# Select the best threshold from the raw utilities; min-max scaling does not
# change the argmax, but on a constant curve it would produce NaNs first.
best = np.round(thresholds[np.argmax(utilities)], 5)

# Min-max scale to [0, 1] for plotting, guarding against a constant curve
# (which would otherwise divide by zero).
span = utilities.max() - utilities.min()
if span > 0:
    utilities = (utilities - utilities.min()) / span
else:
    utilities = np.zeros_like(utilities)

# Save the best threshold (stored as a string, as before).
with open(os.path.join(os.curdir, "threshold.json"), "w") as fh:
    json.dump({"threshold": str(best)}, fh)

# Plot the normalized utility curve and report the best threshold.
print(f"Optimal threshold at p = {best}")

plt.figure(figsize=(8, 5))
plt.plot(thresholds, utilities, "tab:blue")
plt.xlabel("Threshold")
plt.ylabel("Utility")
plt.title("Utility at threshold")
plt.axis([0, 1, 0, 1])
plt.grid(True)
plt.show()