In [261]:
import numpy as np
from math import log
import time
import random
from copy import deepcopy

In [262]:
class Node:
    """A node of a binary decision tree over numeric features.

    ``data`` holds the rows that reach this node; the last column is the
    label. An internal node carries the rule ``feature_f <= t``: rows that
    satisfy it belong to ``le_yes_child``, the rest to ``gt_no_child``.
    """

    def __init__(self, data, level):
        self.level = level          # depth in the tree (callers use 0 for the root)
        self.parent = None
        self.le_yes_child = None    # subtree for rows with feature_f <= t
        self.gt_no_child = None     # subtree for rows with feature_f > t
        self.data = None
        self.is_leaf = None
        self.pure = None
        self.f = None               # index of the splitting feature
        self.t = None               # splitting threshold
        # Single source of truth for the purity / leaf computation.
        self.set_data(data)

    def set_data(self, data):
        """Replace the node's rows and recompute ``pure`` / ``is_leaf``.

        A node is pure, and therefore a leaf, when exactly one distinct
        label appears in the last column of ``data``.
        """
        self.data = data
        self.pure = len(set(self.data[:, -1])) == 1
        self.is_leaf = self.pure

    def majority(self):
        """Return the most common label (1.0 or 0.0) among this node's rows.

        Any nonzero label counts as 1. Ties are broken with
        ``random.choice``, so the result is non-deterministic unless the
        ``random`` module has been seeded.
        """
        labels = self.data[:, -1]
        ones = np.count_nonzero(labels)
        zeros = labels.shape[0] - ones
        if ones > zeros:
            return 1.0
        elif ones < zeros:
            return 0.0
        else:
            return random.choice([0.0, 1.0])

    def is_pure(self):
        """Return True when all rows in this node share one label."""
        return self.pure

    def set_le_yes_child(self, child):
        """Attach the subtree for rows satisfying ``feature_f <= t``."""
        self.le_yes_child = child

    def set_gt_no_child(self, child):
        """Attach the subtree for rows with ``feature_f > t``."""
        self.gt_no_child = child

    def print_node(self, outfile, level=10000, print_data=False):
        """Write this subtree to ``outfile`` in post-order (children first).

        Only nodes whose depth is strictly less than ``level`` are written;
        deeper nodes (and everything below them) are skipped. When
        ``print_data`` is True the node's labels are written as well.
        """
        if self.level >= level:
            return
        # Each child is checked independently so a node with only one child
        # (on either side) is still printed correctly.
        for child in (self.le_yes_child, self.gt_no_child):
            if child is not None:
                child.print_node(outfile, level, print_data)
        outfile.write('Node level: {}, is_leaf: {}, rule: feature_{} <= {}.\n'.format(self.level, self.is_leaf, self.f, self.t))
        if print_data:
            outfile.write('data in node has labels: {}\n'.format(self.data[:,-1]))
        
    

In [263]:
def read_file(filename):
    """Load a whitespace-separated numeric table into a float numpy array.

    Each non-blank line becomes one row of floats; blank or
    whitespace-only lines are skipped. Rows are expected to have the same
    number of columns.
    """
    rows = []
    with open(filename, 'r') as f:
        for line in f:
            # Test the split fields, not `line`: a raw line still carries its
            # '\n' and is therefore never falsy, even when visually empty.
            fields = line.split()
            if fields:
                rows.append([float(x) for x in fields])
    return np.asarray(rows)


def find_splitting_rule(node):
    """Pick the information-gain-maximizing rule for ``node`` and split it.

    Candidate rules are ``feature_i <= t``, where ``t`` is the midpoint of
    two consecutive distinct values of feature ``i`` among the node's rows.
    On success ``node.f`` / ``node.t`` hold the best rule, ``node.is_leaf``
    is False, and the node gets two new children at ``node.level + 1``
    holding the rows on each side of the rule.

    Raises:
        ValueError: if no candidate rule has positive information gain
            (e.g. every feature is constant across the node's rows). The
            node is left unchanged in that case.
    """
    num_of_features = node.data.shape[1] - 1
    # The parent entropy does not depend on the candidate rule; compute it once.
    X_entropy = entropy(node.data)
    max_IG = 0.0
    max_t = None
    max_f = None
    for i in range(num_of_features):
        values = np.unique(node.data[:, i])       # sorted distinct values
        thresholds = (values[:-1] + values[1:]) / 2
        for t in thresholds:
            IG = X_entropy - conditional_entropy(node.data, i, t)
            if IG > max_IG:
                max_IG = IG
                max_f = i
                max_t = t
    if max_f is None:
        raise ValueError(
            'cannot split node at level {}: no rule has positive information '
            'gain'.format(node.level))

    node.is_leaf = False
    node.f = max_f
    node.t = max_t
    yes_data, no_data = split_data(node.data, node.f, node.t)

    l_node = Node(yes_data, node.level + 1)
    r_node = Node(no_data, node.level + 1)
    l_node.parent = node
    r_node.parent = node

    node.set_le_yes_child(l_node)
    node.set_gt_no_child(r_node)
    
def split_data(data, f, t):
    """Partition the rows of ``data`` by the rule ``data[:, f] <= t``.

    Returns ``(yes_rows, no_rows)``: the rows satisfying the rule and the
    rows with ``data[:, f] > t``, each in their original order.
    """
    column = data[:, f]
    return data[column <= t], data[column > t]
    
            
        
        
def entropy(data, labels=None):
    """Binary label entropy in nats (natural log).

    Uses ``labels`` when given, otherwise the last column of ``data``. Any
    nonzero label counts as class 1. Empty or single-class input has
    entropy 0.0.
    """
    y = data[:, -1] if labels is None else labels
    n = y.shape[0] + 0.0
    n_one = np.count_nonzero(y) + 0.0
    n_zero = n - n_one
    if n_one == 0.0 or n_zero == 0.0:
        return 0.0
    p0 = n_zero / n
    p1 = n_one / n
    return -p0 * log(p0) - p1 * log(p1)

def conditional_entropy(data, i, t):
    """Weighted label entropy after splitting ``data`` on ``data[:, i] <= t``.

    Returns ``p_yes * H(yes) + p_no * H(no)``, where the weights are the
    fractions of counted rows on each side and ``H`` is ``entropy``. Rows
    whose feature value satisfies neither ``<= t`` nor ``> t`` (NaN) fall
    on neither side.
    """
    column = data[:, i]
    y = data[:, -1]
    left = y[column <= t]
    right = y[column > t]
    n_left = left.shape[0]
    n_right = right.shape[0]
    n = float(n_left + n_right)
    return (n_left / n) * entropy(data, left) + (n_right / n) * entropy(data, right)

def preorderTraversal(node):
    """Return the node levels of a subtree in pre-order.

    Visits the node, then its ``le_yes_child`` subtree, then its
    ``gt_no_child`` subtree, printing each level as it is visited.
    Returns ``[]`` for ``None``.
    """
    if not node:
        return []
    print(node.level)
    res = [node.level]
    # Accumulate the child results instead of overwriting `res` with them.
    res.extend(preorderTraversal(node.le_yes_child))
    res.extend(preorderTraversal(node.gt_no_child))
    return res
    
def predict(root, sample):
    """Classify ``sample`` by walking the tree from ``root`` down to a leaf.

    At each internal node the sample moves to ``le_yes_child`` when
    ``sample[f] <= t`` and to ``gt_no_child`` otherwise. The prediction is
    the label (last column) of the first row stored in the leaf reached.
    """
    node = root
    if not node.is_pure():
        while not node.is_leaf:
            go_yes = sample[node.f] <= node.t
            node = node.le_yes_child if go_yes else node.gt_no_child
    return node.data[:, -1][0]

def evaluate(root, eval_data):
    """Return the fraction of rows in ``eval_data`` the tree misclassifies.

    The true label of each row is its last entry; predictions come from
    ``predict(root, row)``. An empty ``eval_data`` raises ZeroDivisionError.
    """
    mistakes = sum(1.0 for row in eval_data if predict(root, row) != row[-1])
    return mistakes / eval_data.shape[0]
    
def prune(root, val_data):
    """Collapse the first subtree whose removal lowers validation error.

    Nodes are visited depth-first from ``root`` using a LIFO stack, so a
    node's ``gt_no_child`` is explored before its ``le_yes_child``. Each
    visited node is temporarily turned into a leaf holding only its
    majority label. If that strictly reduces the error on ``val_data``, the
    change is kept, the node's children are dropped, and the search stops.
    Otherwise the node's data is restored.

    Returns:
        The validation error after the call: the reduced error if a node
        was pruned, otherwise the unchanged error of the current tree.
    """
    stack = [root]
    min_error = evaluate(root, val_data)
    while stack:
        cur = stack.pop()
        # set_data only rebinds `cur.data`, so keeping the old reference is
        # enough to restore it; no copy is needed.
        old_data = cur.data
        cur.set_data(np.array([[cur.majority()]]))
        val_error = evaluate(root, val_data)
        if val_error < min_error:
            cur.le_yes_child = None
            cur.gt_no_child = None
            return val_error
        cur.set_data(old_data)
        if cur.le_yes_child is not None:
            stack.append(cur.le_yes_child)
        if cur.gt_no_child is not None:
            stack.append(cur.gt_no_child)
    # Nothing helped: report the current error rather than falling off the
    # end and returning None.
    return min_error
        
    

In [264]:
# Load the training set: one sample per row, label in the last column.
train_data = read_file('./data/pa2train.txt')
num_of_samples = len(train_data)
num_of_features = train_data.shape[1] - 1

In [265]:
# Grow the tree: keep splitting impure nodes until every leaf is pure.
start = time.time()
impure_nodes = []
count = 0  # number of child nodes created by splits
root = Node(train_data, 0)
if root.is_pure():
    print('root is pure. No need to continue.')
else:
    impure_nodes.append(root)

while impure_nodes:
    cur = impure_nodes.pop()
    if cur.is_pure():
        continue
    find_splitting_rule(cur)
    impure_nodes.append(cur.le_yes_child)
    impure_nodes.append(cur.gt_no_child)
    count += 2

end = time.time()
total = end - start
print('time used: {} seconds.'.format(total))

time used: 2.724031448364258 seconds.


In [266]:
# Resubstitution error: the fully grown tree evaluated on its own training data.
train_error = evaluate(root,
                       train_data)
print('training error: {}'.format(train_error))



training error: 0.0


In [267]:
# pruning
# Each prune() call collapses at most one subtree (the first one, in its
# search order, that strictly lowers validation error), so run a fixed
# number of rounds and report the error after each.
N_PRUNE_ROUNDS = 2
val_data = read_file('./data/pa2validation.txt')
val_error = evaluate(root, val_data)
print('before pruning the val error is: {}.'.format(val_error))
for prune_round in range(1, N_PRUNE_ROUNDS + 1):
    val_error = prune(root, val_data)
    print('after {} pruning the val error is: {}.'.format(prune_round, val_error))



before pruning the val error is: 0.178.
after 1 pruning the val error is: 0.163.
after 2 pruning the val error is: 0.107.


In [268]:
# Held-out test error of the (pruned) tree.
test_data = read_file('./data/pa2test.txt')
test_error = evaluate(root,
                      test_data)
print('testing error: {}'.format(test_error))

testing error: 0.103


In [269]:
# Dump the tree to the file 'out' (post-order: children before their parent).
# The context manager closes the file even if print_node raises.
with open('out', 'w') as out:
    root.print_node(out)

In [197]:
# Scratch check: a child is stored by reference, so rebinding b.data is
# visible through a.le_yes_child.
# NOTE(review): exploratory, out-of-order cell; consider removing it from the
# final notebook.
a = Node(np.array([[0.0]]), 1)
b = Node(np.array([[1.0]]), 2)
a.set_le_yes_child(b)
b.data = np.array([[2.0]])
print(a.le_yes_child.data)

[[ 2.]]
