In [1]:
import pandas as pd
import plotly_express as px
import numpy as np

In [8]:
raw_data = pd.read_csv('./decision_tree_example_data.csv')
print(raw_data)

     Outlook Temperature Humidity    Wind Play Tennis
0      Sunny         Hot     High    Weak          No
1      Sunny         Hot     High  Strong          No
2   Overcast         Hot     High    Weak         Yes
3       Rain        Mild     High    Weak         Yes
4       Rain        Cool   Normal    Weak         Yes
5       Rain        Cool   Normal  Strong          No
6   Overcast        Cool   Normal    Weak         Yes
7      Sunny        Mild     High    Weak          No
8      Sunny        Cool   Normal    Weak         Yes
9       Rain        Mild   Normal  Strong         Yes
10     Sunny        Mild   Normal  Strong         Yes
11  Overcast        Mild     High  Strong         Yes
12  Overcast         Hot   Normal    Weak         Yes
13      Rain        Mild     High  Strong          No


In [42]:
# raw_data[raw_data['Outlook'] == 'Sunny']
idx = raw_data['Outlook'] == 'Sunny'
raw_data[idx].groupby(raw_data[idx]['Temperature']).count().index.values

array(['Cool', 'Hot', 'Mild'], dtype=object)

In [27]:
disc_count = raw_data['Play Tennis'].groupby(raw_data['Play Tennis']).count().values
disc_count/sum(disc_count)
entropy_arr = (-1 * (disc_count/sum(disc_count)) * np.log(disc_count/sum(disc_count)))/np.log(2)
sum(entropy_arr)

0.9402859586706309

In [28]:
def entropy(raw_data, target_var_name='Play Tennis'):
    disc_count = raw_data[target_var_name].groupby(raw_data[target_var_name]).count().values
    entropy_arr = (-1 * (disc_count/sum(disc_count)) * np.log(disc_count/sum(disc_count)))/np.log(2)
    return sum(entropy_arr)

In [43]:
entropy(raw_data[raw_data[attrib] == 'Rain'])

0

In [46]:
raw_data['Outlook'].unique().tolist()

['Sunny', 'Overcast', 'Rain']

In [55]:
def information_gain_by_attrib(raw_data, attrib='Outlook', target_var_name='Play Tennis'):
    entropy_value_wise = []
    length_value_wise = []
    parent_entropy = entropy(raw_data, target_var_name)
    for attrib_value in raw_data[attrib].unique():
        filtered_df = raw_data[raw_data[attrib] == str(attrib_value)]
        entropy_value_wise.append(entropy(filtered_df, target_var_name))
        length_value_wise.append(len(filtered_df))
    return parent_entropy - np.sum(np.array(length_value_wise) / sum(length_value_wise) * np.array(entropy_value_wise))

In [67]:
information_gain_by_attrib(raw_data, 'Outlook')
raw_data['Outlook'].unique()

array(['Sunny', 'Overcast', 'Rain'], dtype=object)

In [129]:
decision_tree = {}
visited = [False]*(len(raw_data.columns))
def build_decision_tree(raw_data, parent='', parent_branch='', visited=[False]*(len(raw_data.columns)), target_var_name='Play Tennis'):
    max_inf_gain = -1
    max_inf_gain_column = ''
    max_inf_gain_column_idx = 0
    for i, column in enumerate(raw_data.columns.values):
        if not visited[i] and column != target_var_name:
            inf_gain = information_gain_by_attrib(raw_data, column, target_var_name)
            if inf_gain > max_inf_gain:
                max_inf_gain = inf_gain
                max_inf_gain_column = column
                max_inf_gain_column_idx = i
    print(f'max_ig_column = {max_inf_gain_column}, parent = {parent}, branch = {parent_branch}, max_ig = {round(max_inf_gain,2)}, visited = {visited}')
    if not decision_tree.get(parent) and parent != '':
        decision_tree[parent] = []
    if parent_branch in [edge.get('branch_name') for edge in decision_tree.get(parent, []) if parent != '']:
        branch_idx = [edge.get('branch_name') for edge in decision_tree[parent]].index(parent_branch)
    else:
        branch_idx = -1
    if branch_idx >= 0:
        decision_tree[parent][branch_idx]['leads_to'] = max_inf_gain_column
    if max_inf_gain <= 0:
        #leaf node to be pushed in parent, if not parent, set root equal to max_inf_gain_column 
        #and parent array in decision tree
        if not parent:
            print('Empty decision tree, i.e, the data has zero entropy(uncertainity)')
            return
        if branch_idx >= 0:
            #write result
            decision_tree[parent][branch_idx]['result'] = raw_data[target_var_name].values[0]
        else:
            decision_tree[parent].append({'branch_name': parent_branch, 'result': str(raw_data[target_var_name].values[0]), 'parent': parent})
        return
    for disc_value in raw_data[max_inf_gain_column].unique():
        visited[max_inf_gain_column_idx] = True
        if disc_value in [edge.get('branch_name') for edge in decision_tree.get(max_inf_gain_column, [])]:
            continue
        if not decision_tree.get(max_inf_gain_column):
            decision_tree[max_inf_gain_column] = []
        decision_tree[max_inf_gain_column].append({'branch_name': disc_value, 'result': None, 'parent': max_inf_gain_column})
        build_decision_tree(raw_data[raw_data[max_inf_gain_column]==disc_value], 
                            parent=max_inf_gain_column,parent_branch=disc_value,
                            visited=visited,target_var_name=target_var_name)
        visited[max_inf_gain_column_idx] = False
        

In [130]:
build_decision_tree(raw_data, visited=visited)

max_ig_column = Outlook, parent = , branch = , max_ig = 0.25, visited = [False, False, False, False, False]
max_ig_column = Humidity, parent = Outlook, branch = Sunny, max_ig = 0.97, visited = [True, False, False, False, False]
max_ig_column = Temperature, parent = Humidity, branch = High, max_ig = 0.0, visited = [True, False, True, False, False]
max_ig_column = Temperature, parent = Humidity, branch = Normal, max_ig = 0.0, visited = [True, False, True, False, False]
max_ig_column = Temperature, parent = Outlook, branch = Overcast, max_ig = 0.0, visited = [True, False, False, False, False]
max_ig_column = Wind, parent = Outlook, branch = Rain, max_ig = 0.42, visited = [True, False, False, False, False]
max_ig_column = Temperature, parent = Wind, branch = Weak, max_ig = 0.0, visited = [True, False, False, True, False]
max_ig_column = Temperature, parent = Wind, branch = Strong, max_ig = 0.25, visited = [True, False, False, True, False]
max_ig_column = Humidity, parent = Temperature, bra

In [131]:
decision_tree

{'Outlook': [{'branch_name': 'Sunny',
   'result': None,
   'parent': 'Outlook',
   'leads_to': 'Humidity'},
  {'branch_name': 'Overcast',
   'result': 'Yes',
   'parent': 'Outlook',
   'leads_to': 'Temperature'},
  {'branch_name': 'Rain',
   'result': None,
   'parent': 'Outlook',
   'leads_to': 'Wind'}],
 'Humidity': [{'branch_name': 'High',
   'result': 'No',
   'parent': 'Humidity',
   'leads_to': 'Temperature'},
  {'branch_name': 'Normal',
   'result': 'Yes',
   'parent': 'Humidity',
   'leads_to': 'Temperature'}],
 'Wind': [{'branch_name': 'Weak',
   'result': 'Yes',
   'parent': 'Wind',
   'leads_to': 'Temperature'},
  {'branch_name': 'Strong',
   'result': None,
   'parent': 'Wind',
   'leads_to': 'Temperature'}],
 'Temperature': [{'branch_name': 'Cool',
   'result': 'No',
   'parent': 'Temperature',
   'leads_to': 'Humidity'},
  {'branch_name': 'Mild',
   'result': None,
   'parent': 'Temperature',
   'leads_to': 'Humidity'}]}

In [132]:
def find_root(decision_tree):
    nodes_lead_to = set()
    for edges in decision_tree.values():
        for edge in edges:
            if edge.get('leads_to'):
                nodes_lead_to.add(edge.get('leads_to'))
    for key in decision_tree.keys():
        if key not in nodes_lead_to:
            return key
    return ''

print(f'root = {find_root(decision_tree)}')

root = Outlook


In [138]:
example = {'Outlook': 'Sunny', 'Wind': 'Strong', 'Temperature': 'Mild', 'Humidity': 'Normal'}
root = find_root(decision_tree)
def predict_for_one_example(decision_tree, root, example={}):
    for branch_info in decision_tree[root]:
        if branch_info.get('branch_name') == example[root]:
            if branch_info.get('result'):
                return branch_info.get('result')
            root = branch_info.get('leads_to')
    return predict_for_one_example(decision_tree, root, example)
    
print(predict_for_one_example(decision_tree, root, example))

Yes
