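# Monotonic trend validation for an xgboost model

This notebook trains an xgboost model on the sklearn diabetes dataset with monotonic constraints (decreasing in `age` and `bmi`, increasing in `bp` and `s6`), then uses `tabular_trees` to check that the trained trees actually respect those trends.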

In [1]:
import xgboost as xgb
import pandas as pd
from sklearn.datasets import load_diabetes

In [2]:
import tabular_trees

# Build example xgboost model

## Load data from sklearn

In [3]:
# sklearn Bunch; the 'data', 'target' and 'feature_names' entries are used below
data = load_diabetes(return_X_y=False)

## Create xgboost matrix

In [4]:
# Feature names are attached so constraints and importances can be read by name.
xgb_data = xgb.DMatrix(
    data['data'],
    label=data['target'],
    feature_names=data['feature_names'],
)

In [5]:
# Required trend per feature: -1 = non-increasing, 1 = non-decreasing.
# Features not listed stay unconstrained (0).
MONOTONIC_TRENDS = {'age': -1, 'bmi': -1, 'bp': 1, 's6': 1}

# Fail loudly on a typo instead of silently leaving a feature unconstrained.
assert set(MONOTONIC_TRENDS) <= set(data['feature_names']), (
    f"constraints given for unknown features: "
    f"{sorted(set(MONOTONIC_TRENDS) - set(data['feature_names']))}"
)

# One entry per feature, in the same order as data['feature_names'].
monotonic_constraints = pd.Series(
    [MONOTONIC_TRENDS.get(name, 0) for name in data['feature_names']],
    index=data['feature_names'],
)
monotonic_constraints

age   -1
sex    0
bmi   -1
bp     1
s1     0
s2     0
s3     0
s4     0
s5     0
s6     1
dtype: int64

## Build model

In [6]:
# xgboost reads a tuple of monotone constraints positionally,
# so it must follow the DMatrix column order.
assert list(monotonic_constraints.index) == list(xgb_data.feature_names), (
    "monotonic_constraints must be ordered like the DMatrix feature names"
)

model = xgb.train(
    params={
        'max_depth': 3,
        # .tolist() gives plain Python ints. xgboost stringifies parameter
        # values, and under NumPy >= 2 a tuple of np.int64 renders as
        # "(np.int64(-1), ...)", which xgboost cannot parse.
        'monotone_constraints': tuple(monotonic_constraints.tolist()),
        'tree_method': 'exact',
    },
    dtrain=xgb_data,
    num_boost_round=100,
)

In [7]:
# Number of splits using each feature ('weight' is xgboost's default importance type)
model.get_score(importance_type='weight')

{'age': 35.0,
 'sex': 19.0,
 'bmi': 7.0,
 'bp': 53.0,
 's1': 113.0,
 's2': 133.0,
 's3': 80.0,
 's4': 35.0,
 's5': 144.0,
 's6': 22.0}

# Check monotonic trends have been implemented correctly in the model

## Convert to tabular trees object

In [8]:
# Export the trained booster's tree structure, then convert it to the tabular_trees
# representation that the validation helpers below take as input.
xgboost_model_trees = tabular_trees.trees.export_tree_data(model)
model_trees = xgboost_model_trees.to_tabular_trees()

## Check monotonic trends

In [9]:
monotonic_constraint_check = tabular_trees.validate.validate_monotonic_constraints(
    tabular_trees=model_trees,
    # Only features with a non-zero trend (-1 or 1) have anything to check.
    constraints=monotonic_constraints[monotonic_constraints.ne(0)].to_dict(),
)

In [10]:
# summary maps each constrained feature to whether its trend holds in every tree.
# Assert it so that Restart & Run All fails if any constraint is violated.
assert all(monotonic_constraint_check.summary.values()), (
    f"monotonic constraint violated: {monotonic_constraint_check.summary}"
)
monotonic_constraint_check.summary

{'age': True, 'bmi': True, 'bp': True, 's6': True}

In [11]:
# Per-node detail behind the summary: the first few checked splits
monotonic_constraint_check.results.head(n=5)

Unnamed: 0,variable,tree,node,monotonic_trend,monotonic,child_nodes_left_max_prediction,child_nodes_right_min_prediction,child_nodes_left,child_nodes_right
0,age,3,4,-1,True,10.657875,-3.318347,[9.0],[10.0]
1,age,5,1,-1,True,2.655762,-12.411342,"[3.0, 7.0, 8.0]","[4.0, 9.0, 10.0]"
2,age,5,6,-1,True,27.366072,-4.267937,[13.0],[14.0]
3,age,6,3,-1,True,11.539597,2.654022,[7.0],[8.0]
4,age,7,5,-1,True,10.23708,-4.882254,[9.0],[10.0]
