In [165]:
import numpy as np
from math import log
import time

In [166]:
class Node:
    def __init__(self, data, level):
        self.level = level
        self.parent = None
        self.le_yes_child = None
        self.gt_no_child = None
        self.data = data
        self.is_leaf = None
        self.pure = None
        self.f = None
        self.t = None
        
    def is_pure(self):
        if self.pure == None:
            if len(set(self.data[:,-1])) != 1:
                self.pure = False
                return False
            else:
                self.pure = True
                return True
        else:
            return self.pure
        
    def set_le_yes_child(self, child):
        self.le_yes_child = child
        
    def set_gt_no_child(self, child):
        self.gt_no_child = child
        
    

In [182]:
def read_file(filename):
    f = open(filename, 'r')
    ls = []
    lines = f.readlines()
    for line in lines:
        if line:
            content = line.strip().split()
            ls.append([float(x) for x in content])
        
    f.close()
    ls = np.asarray(ls)
    return ls


def find_splitting_rule(node):
    num_of_features = node.data.shape[1] - 1
    max_IG = 0.0
    max_t = None
    max_f = None
    for i in range(num_of_features):
        feature_vals = list(set(node.data[:,i]))
        feature_vals_sorted = sorted(feature_vals)
        #print(feature_vals_sorted)
        thresholds = [(feature_vals_sorted[j] + feature_vals_sorted[j+1]) / 2 for j in range(len(feature_vals_sorted)-1)]
        #print(thresholds)
        X_entropy = entropy(node.data)
        for t in thresholds:
            X_Z_entropy = conditional_entropy(node.data, i, t)
            IG = X_entropy - X_Z_entropy
            if IG > max_IG:
                max_IG = IG
                max_f = i
                max_t = t
    node.is_leaf = False
    node.f = max_f
    node.t = max_t
    yes_data, no_data = split_data(node.data, node.f, node.t)
#     print('rule is: feature_{} <= {} (feature index from 0).'.format(max_f, max_t))
#     print(yes_data.shape, no_data.shape)
    
    node.set_le_yes_child(Node(yes_data, node.level+1))
    node.set_gt_no_child(Node(no_data, node.level+1))
    
def split_data(data, f, t):
    features = data[:,f]
    yes_index = np.where(features <= t)
    no_index = np.where(features > t)
    yes_branch = data[yes_index]
    no_branch = data[no_index]
    
    return yes_branch, no_branch
    
            
        
        
def entropy(data, labels = None):
    if labels is None:
        labels = data[:,-1]
    total = labels.shape[0] + 0.0
    zeros = np.count_nonzero(labels) + 0.0
    if zeros == 0.0:
        return 0.0
    ones = total - zeros
    if ones == 0.0:
        return 0.0
    p_zero = zeros / total
    p_one = ones / total
    ent = -p_zero * log(p_zero) - p_one * log(p_one)

    return ent

def conditional_entropy(data, i, t):
    labels = data[:,-1]
    features = data[:,i]
    yes_branch = np.where(features <= t)
    no_branch = np.where(features > t)
#     print(features)
#     print(t)
#     print(yes_branch[0].shape)
#     print(no_branch[0].shape)
    yes_labels = labels[yes_branch]
    no_labels = labels[no_branch]
    num_yes = yes_labels.shape[0]
    num_no = no_labels.shape[0]
#     print(num_yes)
#     print(num_no)
    p_yes = float(num_yes) / float(num_yes + num_no)
    p_no = float(num_no) / float(num_yes + num_no)
    #print(p_yes, p_no)
    X_yes_entropy = entropy(data, yes_labels)
    X_no_entropy = entropy(data, no_labels)
    
    rs = p_yes * X_yes_entropy + p_no * X_no_entropy
#     print('conditional entropy: {}'.format(rs))
    return rs

def preorderTraversal(node):
        res = []
        if node:
            res.append(node.level)
            res = preorderTraversal(node.le_yes_child)
            res = preorderTraversal(node.gt_no_child)
        return res
    

In [183]:
train_data = read_file('pa2train.txt')
num_of_samples = train_data.shape[0]
num_of_features = train_data.shape[1] - 1

In [184]:
start = time.time()
impure_nodes = []
root = Node(train_data, 0)
if root.is_pure():
    print('root is pure. No need to continue.')
else:
    impure_nodes.append(root)
    
while(impure_nodes):
    cur = impure_nodes.pop()
    if cur.is_pure():
        #print(cur.data[:,-1])
        continue
    else:
        find_splitting_rule(cur)
        impure_nodes.append(cur.le_yes_child)
        impure_nodes.append(cur.gt_no_child)
    
end = time.time()
total = end - start
print('time used: {} seconds.'.format(total))

time used: 3.1704294681549072 seconds.


In [186]:
rs = preorderTraversal(root)
print(rs)

[]
