In [1]:
import math
import sys
import pandas as pd
import numpy as np

In [2]:
def entropy(elements, base):
    length = float(len(elements))
    probs = [elements.count(element)/length for element in set(elements)]
    return -sum([p * math.log(p, base) for p in probs])

In [3]:
def split_dataframe(data, header):
    unique_values = data[header].unique()
    result_dict = {elem : pd.DataFrame for elem in unique_values}
    for key in result_dict.keys():
        result_dict[key] = data[:][data[header] == key]
    return result_dict

In [4]:
raw_data = pd.DataFrame(pd.read_csv("data.csv"))
headers = list(raw_data)[1:]
target_header = "os"
base = 2

In [5]:
min_value, min_header = np.inf, None
for header in headers[1:-1]:
    h = entropy(raw_data[header].tolist(), base)
    if h < min_value:
        min_value, min_header = h, header
print(min_value, min_header)

0.9927744539878084 drink


In [6]:
def tree_split(data):
    length = float(len(data))
    max_value, max_header = -np.inf, None
    max_splited = None
    H = entropy(data[target_header].tolist(), base)
    for header in list(data)[:-1]:
        splited_set = split_dataframe(data, header)
        print(splited_set)
        IS = 0
        for subset_header, subset in splited_set.items():
            subset_length = float(len(subset))
            subset_h = entropy(subset[target_header].tolist(), base)
            print(header, subset_header, subset_h)
            IS += subset_length/length * subset_h
        IG = H - IS
        print(header, H, IS)
        if IG > max_value:
            max_value, max_header = IG, header
            max_splited = splited_set

    return max_value, max_header, max_splited


In [7]:
max_value, max_header, max_splited = tree_split(raw_data[headers])
new_headers = [header for header in headers if header != max_header]

for split_value, split_data in max_splited.items():
    max_value, max_header, max_splited = tree_split(split_data[new_headers])

{'Python':       lang   drink       os
0   Python     tea  windows
2   Python     tea      mac
3   Python     tea      mac
4   Python  coffee      mac
5   Python  coffee      mac
6   Python  coffee  windows
7   Python     tea     unix
8   Python  coffee  windows
10  Python     tea  windows
11  Python  coffee  windows
13  Python     tea     unix, 'Kotlin':      lang   drink       os
1  Kotlin  coffee  windows
9  Kotlin     tea  windows, 'matlab':       lang   drink       os
12  matlab  coffee  windows, 'Java':     lang drink       os
14  Java   tea  windows, 'c++':    lang   drink       os
15  c++     tea  windows
17  c++  coffee  windows
18  c++     tea     unix, 'c#':    lang   drink    os
16   c#  coffee  unix, 'c':    lang drink       os
19    c   tea  windows}
lang Python 1.4949188482339508
lang Kotlin -0.0
lang matlab -0.0
lang Java -0.0
lang c++ 0.9182958340544896
lang c# -0.0
lang c -0.0
lang 1.3709505944546687 0.9599497416368464
{'tea':       lang drink       os
0   Python   te

In [8]:
class ID3Tree(object):
    class Node(object):
        def __init__(self, name):
            self.name = name
            self.connections = {}
        
        def connect(self, label, node):
            self.connections[label] = node
        
    
    def __init__(self, data, target_header, base=2):
        self.headers = list(data)[1:]
        self.data = data
        self.target_header = target_header
        self.base = base
        self.root = self.Node("Root")
        
    def build(self):
        self.step(self.root, "", self.data, self.headers)
        
        
    def step(self, parent_node, parent_connection_label, input_data, headers):
        max_value, max_header, max_splited = tree_split(input_data[headers])
        
        if not max_header:
            return

        node = self.Node(max_header)
        parent_node.connect(parent_connection_label, node)
        
        new_headers = [header for header in headers if header != max_header]
         
        for splited_value, splited_data in max_splited.items():
            self.step(node, splited_value, splited_data, new_headers)
        
        
        

In [9]:
tree = ID3Tree(raw_data, "os")
tree.build()

{'Python':       lang   drink       os
0   Python     tea  windows
2   Python     tea      mac
3   Python     tea      mac
4   Python  coffee      mac
5   Python  coffee      mac
6   Python  coffee  windows
7   Python     tea     unix
8   Python  coffee  windows
10  Python     tea  windows
11  Python  coffee  windows
13  Python     tea     unix, 'Kotlin':      lang   drink       os
1  Kotlin  coffee  windows
9  Kotlin     tea  windows, 'matlab':       lang   drink       os
12  matlab  coffee  windows, 'Java':     lang drink       os
14  Java   tea  windows, 'c++':    lang   drink       os
15  c++     tea  windows
17  c++  coffee  windows
18  c++     tea     unix, 'c#':    lang   drink    os
16   c#  coffee  unix, 'c':    lang drink       os
19    c   tea  windows}
lang Python 1.4949188482339508
lang Kotlin -0.0
lang matlab -0.0
lang Java -0.0
lang c++ 0.9182958340544896
lang c# -0.0
lang c -0.0
lang 1.3709505944546687 0.9599497416368464
{'tea':       lang drink       os
0   Python   te

In [10]:
def print_tree(node, tabs):
    print(tabs + node.name)
    for connection, child_node in node.connections.items():
        print_tree(child_node, tabs+"\t")

In [11]:
print_tree(tree.root, "")

Root
	lang
		drink
		drink
		drink
		drink
		drink
		drink
		drink
