In [None]:
# 04 - Tree Creation

In [None]:
import numpy as np
import numpy.random as rn
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import tree
import xgboost as xgb
import dtreeviz
#also, must install graphviz for plotting to work with xgb

In [None]:
pos_center = 12
pos_count = 100
neg_center = 7
neg_count = 1000
rs = rn.RandomState(rn.MT19937(rn.SeedSequence(5)))

# Synthetic 1-D, two-class dataset: unit-variance normal noise around each class
# center. The positive sample is drawn before the negative one; that draw order
# decides which random numbers land in which class.
pos_values = pos_center + rs.randn(pos_count)
neg_values = neg_center + rs.randn(neg_count)
gini = pd.DataFrame({
    'value': np.concatenate([pos_values, neg_values]),
    'label': ['pos'] * pos_count + ['neg'] * neg_count,
})

In [None]:
# Preview a few rows of each class rather than dumping all 1100 rows
# (the first 100 rows are all 'pos', so a plain .head() would hide 'neg').
gini.groupby('label').head(3)

In [None]:
# Overlaid histograms of `value` for each class. Each group is drawn with its
# own explicit legend label, so the legend cannot get out of step with the
# order in which groupby yields the groups.
legend_names = {'neg': 'Negative', 'pos': 'Positive'}
fig, ax = plt.subplots(figsize=(8, 4))
for label, group in gini.groupby('label'):
    ax.hist(group['value'], bins=30, alpha=.5, edgecolor='black',
            label=legend_names.get(label, label))
ax.legend()
ax.set(title='Distribution of value by class', xlabel='value', ylabel='Frequency');

In [None]:
def calc_gini(df: pd.DataFrame, val_col: str, label_col: str, pos_val, split_point: float, debug: bool = False) -> float:
    """
    Calculate the weighted Gini impurity of splitting a dataset at a threshold.

    Rows with ``df[val_col] >= split_point`` form the "ge" side of the split; all
    other rows (including rows whose value is NaN, since NaN >= x is False) form
    the "lt" side. The Gini impurity of one side is the probability that a row
    drawn at random from that side is classified incorrectly when it is labelled
    at random according to that side's class distribution. The result is the
    average of the two sides' impurities, weighted by how many rows each side
    holds. The lower the impurity, the better the split.

    Parameters:
    df (pd.DataFrame): The dataframe containing the data
    val_col (str): The column name of the feature used to split the data
    label_col (str): The column name of the target variable
    pos_val (str or int): The value of the target variable that represents the
        positive class; every other label value is treated as negative
    split_point (float): The threshold used to split the data
    debug (bool): optional, when set to True, prints the Gini impurity of each
        side and the final weighted average, each to 3 significant figures

    Returns:
    float: the weighted Gini impurity, between 0.0 (both sides pure) and 0.5
        (both sides evenly mixed). An empty dataframe returns 0.0.
    """
    ge_split = df[val_col] >= split_point
    eq_pos = df[label_col] == pos_val
    # 2x2 contingency counts: side of the split (ge / lt) x class (pos / other).
    tp = int((ge_split & eq_pos).sum())    # ge side, positive class
    fp = int((ge_split & ~eq_pos).sum())   # ge side, other classes
    tn = int((~ge_split & ~eq_pos).sum())  # lt side, other classes
    fn = int((~ge_split & eq_pos).sum())   # lt side, positive class
    pos_size = tp + fp
    neg_size = tn + fn
    total_size = len(df)
    # An empty side is pure. Use float 0.0 rather than int 0: the `.3` precision
    # in the debug f-string raises ValueError when applied to an int.
    gini_pos = 0.0 if pos_size == 0 else 1 - (tp/pos_size)**2 - (fp/pos_size)**2
    gini_neg = 0.0 if neg_size == 0 else 1 - (tn/neg_size)**2 - (fn/neg_size)**2
    if total_size == 0:
        # No rows: nothing can be misclassified, and the size weights below
        # would divide by zero.
        weighted_avg = 0.0
    else:
        weighted_avg = (gini_pos * (pos_size/total_size)
                        + gini_neg * (neg_size/total_size))
    if debug:
        # `=` echoes the variable name; `.3` formats to 3 significant figures.
        print(f'{gini_pos=:.3} {gini_neg=:.3} {weighted_avg=:.3}')
    return weighted_avg


In [None]:
# Weighted Gini impurity for one hand-picked threshold lying between the two
# class centers; debug=True also prints the impurity of each side of the split.
candidate_split = 9.24
calc_gini(gini, val_col='value', label_col='label', pos_val='pos',
          split_point=candidate_split, debug=True)

In [None]:
# The `=:.3` spec used in calc_gini's debug output: `=` echoes the expression
# text, and `.3` (with no type letter) rounds a float to 3 significant figures
# in general format, so 100.01 prints as `1e+02`. Using the same spec on an int
# (e.g. blah = 100) raises ValueError, because precision is not allowed in an
# integer format specifier.
blah = 100.01
message = f'{blah=:.3}'
print(message)

In [None]:

# Candidate split points 5.0, 5.1, ..., 14.9 (100 values; as with range(), the
# stop value 15 is excluded). Dividing an integer range by 10 makes each point
# the float nearest to k/10, so e.g. 9.5 is exactly 9.5. A float step such as
# np.arange(5, 15, .1) accumulates rounding error and yields a value a hair
# below 9.5, which breaks later label lookups like .loc[9.5:...].
values = np.arange(50, 150) / 10
values

In [None]:

# Weighted Gini impurity at every candidate split point in `values`.
ginis = [calc_gini(gini, val_col='value', label_col='label', pos_val='pos', split_point=v)
         for v in values]
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(values, ginis)
# calc_gini returns Gini *impurity* (a node-purity measure), not the Gini
# coefficient used to measure inequality, so label the plot accordingly.
ax.set(title='Weighted Gini Impurity by Split Point',
       ylabel='Weighted Gini Impurity',
       xlabel='Split Point');


In [None]:
# Gini impurity for split points from 9.5 to 10.5. Label-based .loc slicing
# includes both endpoints, but only if the index labels equal them exactly.
# Float split points can drift slightly off the decimal grid (a float-step
# np.arange gives a value just below 9.5, which silently drops that endpoint),
# so round the index to one decimal before slicing.
pd.Series(ginis, index=np.round(values, 1)).loc[9.5:10.5]

In [None]:
# Best split point(s): the row(s) whose weighted Gini impurity equals the minimum.
# Equivalent one-liner: .query('gini <= gini.min()'). Inside a query string,
# column names such as `gini` are resolved directly, so no f-string is needed.
split_scores = pd.DataFrame({'gini': ginis, 'split': values})
is_best = split_scores['gini'] <= split_scores['gini'].min()
split_scores[is_best]

In [None]:
# A depth-1 sklearn decision tree (a "stump"): one split on `value`, trained on
# the string class labels. fit() returns the estimator itself.
stump = tree.DecisionTreeClassifier(max_depth=1).fit(gini[['value']], gini['label'])
stump

In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
# tree.plot_tree is sklearn's matplotlib renderer for a fitted tree. It returns
# a list of annotation artists, so the trailing `;` keeps that repr out of the output.
tree.plot_tree(stump, feature_names=['value'], filled=True,
               class_names=list(stump.classes_), ax=ax);
# Compare the root node's threshold with the minimum of the hand-computed
# weighted Gini curve (it was noted near 9.6 when this notebook was written).

In [None]:
# A single-tree, depth-1 xgboost model on the same data. The target is passed as
# a boolean "is positive" mask (False/True standing in for classes 0/1) rather
# than the string labels.
is_pos = gini['label'] == 'pos'
xg_stump = xgb.XGBClassifier(n_estimators=1, max_depth=1)
xg_stump.fit(gini[['value']], is_pos)

In [None]:
# Look up xgb.plot_tree's arguments (e.g. how to choose which tree to draw)
# before calling it below. This is exploratory and can be dropped from a finished notebook.
help(xgb.plot_tree)

In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
# num_trees is the index of the tree to draw. xg_stump was fit with
# n_estimators=1, so index 0 is its only tree. A model with more boosting rounds
# holds one tree per round (per class, for multiclass), each with its own index.
# NOTE(review): newer xgboost releases rename this argument to tree_idx, so check
# the installed version. Drawing also requires graphviz to be installed.
xgb.plot_tree(xg_stump, num_trees=0, ax=ax);

In [None]:
# Render the xgboost stump's only tree (index 0) with dtreeviz. This shows the
# distribution of `value` for each class at the split.
is_positive = gini['label'] == 'pos'
viz = dtreeviz.model(
    xg_stump,
    X_train=gini[['value']],
    y_train=is_positive,
    target_name='positive',
    feature_names=['value'],
    class_names=['negative', 'positive'],
    tree_index=0,
)
viz.view()

In [None]:
# coef_ is only defined for xgboost's linear booster ('gblinear'). xg_stump leaves
# `booster` unset, so it uses the tree booster, and accessing coef_ raises
# AttributeError. Show that error explicitly instead of letting the cell fail,
# then display what a tree model does expose: per-feature importances.
try:
    xg_stump.coef_
except AttributeError as err:
    print(f'coef_ unavailable: {err}')
xg_stump.feature_importances_

In [None]:
# Public attributes and methods of the fitted model. dir() also lists private and
# dunder names, which are filtered out here to keep the output readable.
[name for name in dir(xg_stump) if not name.startswith('_')]