### Lecture 10 - An Example for Implementing Decision Trees

In [None]:
import numpy as np
import pandas as pd

The key for constructing the decision tree is to find the spliting points where we have the minimal entropy, thus prioritizing those features that can separate the data into distinct groups with more precision. 

In [3]:
# define a function for calculating entropy, given a list of elements

from collections import Counter

def entropy(elements):
    counter = Counter(elements)
    probabilities = [counter[e] / len(elements) for e in elements]
    return -sum(p * np.log10(p) for p in probabilities)

In [4]:
entropy([1,1,1,1])

-0.0

In [5]:
entropy([1,1,1,0])

0.4316271552006655

In [6]:
mock_data = {
    'gender':['F', 'F', 'F', 'F', 'M', 'M', 'M'],
    'income': ['+10', '-10', '+10', '+10', '+10', '+10', '-10'],
    'family_size': [1, 1, 2, 1, 1, 1, 2],
    'bought': [1, 1, 1, 0, 0, 0, 1],
}

In [37]:
dataset = pd.DataFrame.from_dict(mock_data)

In [38]:
dataset

Unnamed: 0,gender,income,family_size,bought
0,F,10,1,1
1,F,-10,1,1
2,F,10,2,1
3,F,10,1,0
4,M,10,1,0
5,M,10,1,0
6,M,-10,2,1


In [10]:
# suppose we use 'family_size' as a spliter

sub_split_1 = dataset[dataset['family_size'] == 1]['bought'].tolist()
sub_split_2 = dataset[dataset['family_size'] != 1]['bought'].tolist()

In [11]:
# then the resulting entropy at this point is:

entropy(sub_split_1) + entropy(sub_split_2)

0.7176797562470717

In [12]:
# define a function to find the spliter with minimal entropy

def find_the_min_spliter(data: pd.DataFrame, target: str) -> str:
    """
    @ data = the training data used, in the format of a pandas.DataFrame
    @ target = the column name of the response variable in the data set
    """
    
    # get all the features
    x_vector = set(data.columns.tolist()) - {target}
    
    # initialize parameters to be returned, spliter and min_entropy
    spliter = None
    min_entropy = float('inf')
    
    for f in x_vector:
        elements = set(data[f]) # get the unique levels for each feature
        for e in elements:
            sub_split_1 = data[dataset[f] == e][target].tolist()
            entropy_1 = entropy(sub_split_1)
            sub_split_2 = data[dataset[f] != e][target].tolist()
            entropy_2 = entropy(sub_split_2)
            entropy_total = entropy_1 + entropy_2
            
            if entropy_total  < min_entropy:
                min_entropy = entropy_total
                spliter = (f, e)
            
        print('spliter is: {}'.format(spliter))
        print('the min entropy is :{}'.format(min_entropy))
        
        return spliter

In [13]:
find_the_min_spliter(dataset, 'bought')

spliter is: ('family_size', 1)
the min entropy is :0.7176797562470717


('family_size', 1)

In [14]:
find_the_min_spliter(dataset, 'income')

spliter is: ('family_size', 1)
the min entropy is :0.7509360381569654


('family_size', 1)

Now we could try the built-in function for training a decision-tree classifier from `scikit-learn` and visualize the tree branches.

In [15]:
from sklearn import tree

In [39]:
features = list(dataset.columns[:3])
print(features)

['gender', 'income', 'family_size']


In [40]:
# note -- need to recode gender into integers

dataset['gender'] = dataset['gender'].map({'M': 0, 'F': 1})
dataset['gender']

0    1
1    1
2    1
3    1
4    0
5    0
6    0
Name: gender, dtype: int64

In [43]:
x = dataset[features]
y = dataset['bought']

# alternatively,
# x = dataset.drop('bought', axis=1)

In [46]:
clf = tree.DecisionTreeClassifier(random_state=1)
clf.fit(x, y)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
                       max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=1, splitter='best')

In [47]:
tree.plot_tree(clf.fit(x, y))

[Text(297.6, 323.4, 'X[2] <= 1.5\nentropy = 0.49\nsamples = 7\nvalue = [3, 4]'),
 Text(198.4, 230.99999999999997, 'X[0] <= 0.5\nentropy = 0.48\nsamples = 5\nvalue = [3, 2]'),
 Text(99.2, 138.6, 'entropy = 0.0\nsamples = 2\nvalue = [2, 0]'),
 Text(297.6, 138.6, 'X[1] <= 0.0\nentropy = 0.444\nsamples = 3\nvalue = [1, 2]'),
 Text(198.4, 46.19999999999999, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1]'),
 Text(396.8, 46.19999999999999, 'entropy = 0.5\nsamples = 2\nvalue = [1, 1]'),
 Text(396.8, 230.99999999999997, 'entropy = 0.0\nsamples = 2\nvalue = [0, 2]')]