In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
import xgboost as xgb

from typing import Tuple

from utils import methods

In [2]:
train_df = pd.read_csv('../../data/gen_test_v3.csv')
test_df = pd.read_csv('../../data/gen_test_v3.csv')

In [3]:
train_df.head()

h = 1 / 15
c = 25

In [4]:
X_train = train_df[['N', 'n', 'mean_n', 'std_n', 'alpha_hat', 'beta_hat', 'u_star_hat']]
y_train = train_df['u_star']

X_test = test_df[['N', 'n', 'mean_n', 'std_n', 'alpha_hat', 'beta_hat', 'u_star_hat']]
y_test = test_df['u_star']

dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

In [5]:
def custom_cost(pred: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
    y = dtrain.get_label()
    diff = y - pred
    slope = 0
    grad = np.where(diff > 0, h, slope)
    hess = np.where(diff > 0, 0, 0)
    return grad, hess

def custom_metric(pred: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
    y = dtrain.get_label()
    diff = y - pred
    error = np.where(diff > 0, h * diff, c)
    return 'CusCost', float(np.sum(error) / len(y))


In [6]:
results = []
model = xgb.train({'tree_method': 'hist', 'seed': 1994, 'disable_default_eval_metric': 1},  # any other tree method is fine.
           dtrain=dtrain,
           num_boost_round=10,
           obj=custom_cost,
           custom_metric=custom_metric,
           evals=[(dtrain, 'dtrain'), (dtest, 'dtest')])


[0]	dtrain-CusCost:21.58849	dtest-CusCost:21.58849
[1]	dtrain-CusCost:21.58849	dtest-CusCost:21.58849
[2]	dtrain-CusCost:21.58849	dtest-CusCost:21.58849
[3]	dtrain-CusCost:21.58849	dtest-CusCost:21.58849
[4]	dtrain-CusCost:21.58849	dtest-CusCost:21.58849
[5]	dtrain-CusCost:21.58849	dtest-CusCost:21.58849
[6]	dtrain-CusCost:21.58849	dtest-CusCost:21.58849
[7]	dtrain-CusCost:21.58849	dtest-CusCost:21.58849
[8]	dtrain-CusCost:21.58849	dtest-CusCost:21.58849
[9]	dtrain-CusCost:21.58849	dtest-CusCost:21.58849


In [7]:
test_df['predicted_u_star'] = model.predict(dtest)
test_df['actual_cost'] = test_df.apply(lambda row: methods.cal_cost(row['c'], row['h'], row['u'], row['predicted_u_star']), axis=1)
print(f'Actual Mean cost: {test_df['actual_cost'].mean():.2f}, Actual Median cost: {test_df['actual_cost'].median():.2f}')
print(f'Optimal Mean cost: {test_df['optimal_cost'].mean():.2f}, Optimal Median cost: {test_df['optimal_cost'].median():.2f}')

Actual Mean cost: 25.07, Actual Median cost: 21.21
Optimal Mean cost: 4.36, Optimal Median cost: 3.17


In [8]:
test_df.head(10)

Unnamed: 0,alpha,beta,h,c,N,n,mean_n,std_n,alpha_hat,beta_hat,intervals_str,u,u_star,u_star_hat,z,optimal_cost,actual_cost,predicted_u_star
0,5,2.5,0.066667,25,29,5,12.932681,4.299833,9.046363,1.4296,9.873611453932869_12.181558790848907_20.439509...,373.970646,308.401152,330.455982,0.933259,4.3713,24.898043,0.5
1,7,1.5,0.066667,25,16,5,9.336215,3.195497,8.536206,1.09372,8.266907462181138_9.999238032738244_14.2029088...,161.389327,135.114076,121.730806,1.109942,1.751683,10.725955,0.5
2,7,1.0,0.066667,25,21,5,9.326094,4.260114,4.792446,1.945999,9.216064224518064_8.939377602897661_6.74126765...,143.176863,120.497418,157.110477,0.76696,1.511963,9.511791,0.5
3,3,2.0,0.066667,25,16,5,6.171769,2.741844,5.066802,1.21808,4.268162630400795_5.646163787763241_10.9732434...,108.552055,67.076424,74.632356,0.898758,2.765042,7.20347,0.5
4,3,1.0,0.066667,25,30,5,3.408672,3.727784,0.836121,4.07677,2.7139503189213166_1.8650227067252927_9.985086...,82.809339,68.600863,63.23327,1.084886,0.947232,5.487289,0.5
5,3,1.5,0.066667,25,20,5,3.769411,2.560751,2.166767,1.739648,3.11735382807551_1.9783134491147916_0.91413133...,87.63787,64.876472,50.806139,1.276942,1.517427,5.809191,0.5
6,3,2.5,0.066667,25,15,5,9.931571,2.022021,24.124854,0.411674,10.139845582668807_13.268196488069144_8.596625...,106.97037,78.709432,130.343916,0.60386,1.884062,7.098025,0.5
7,3,3.0,0.066667,25,36,5,7.128026,5.330372,1.788231,3.986078,3.956626898229583_1.8282626746323363_13.739765...,334.042238,268.574966,200.318648,1.340739,4.364485,22.236149,0.5
8,5,3.0,0.066667,25,16,5,18.323201,6.923162,7.004767,2.615819,21.25558763838263_19.034362899181332_8.6877522...,233.670311,190.557184,242.398413,0.786132,2.874208,15.544687,0.5
9,3,1.5,0.066667,25,15,5,5.778133,2.196956,6.917228,0.835325,8.134566061803032_6.967224733616182_2.94335748...,53.291864,45.443092,67.092193,0.677323,0.523251,3.519458,0.5
