In [29]:
## Written by: Wai Zin Linn
## Attribution: Hugh Liu's solutions for CS540 2021 Epic and Hongtao Hao

import numpy as np

threshold_list = range(1, 11)

with open('breast-cancer-wisconsin.data', 'r') as f:
    data_raw = [l.strip('\n').split(',') for l in f if '?' not in l]
data = np.array(data_raw).astype(int)   # training data

# Adjust the following parameters by yourself
part_one_feature = [7]
feature_list = [8, 10, 5, 4, 3, 6]
target_depth = np.inf

def entropy(data):
    entropy = 0
    count = len(data) # total number of instances
    n2 = np.sum(data[:, -1] == 2) # number of k1
    n4 = np.sum(data[:, -1] == 4) # number of k2
    print(n2,',',n4)
    if n2 == 0 or n4 == 0: 
        return 0
    else: 
        for n in [n2, n4]:
            p = n/count
            entropy += - (p * np.log2(p))
        return entropy

def infogain(data, feature, threshold):
    count = len(data)
    d1 = data[data[:, feature - 1] <= threshold]
    d2 = data[data[:, feature - 1] > threshold]
    print(d1)
    print(d2)
    proportion_d1 = len(d1) / count
    proportion_d2 = len(d2) / count
    return entropy(data) - proportion_d1 * entropy(d1) - proportion_d2 * entropy(d2)

def get_best_split(data, feature_list, threshold_list):
    c = len(data)
    c0 = sum(b[-1] == 2 for b in data)
    if c0 == c: return 2, None, None, None
    if c0 == 0: return 4, None, None, None
    ig = [[infogain(
        data, feature, threshold) for threshold in threshold_list] for feature in feature_list]
    ig = np.array(ig)
    max_ig = max(max(i) for i in ig)
    if max_ig == 0:
        if c0 >= c - c0:
            return 2, None, None, None
        else:
            return 4, None, None, None
    
    idx = np.unravel_index(np.argmax(ig, axis=None), ig.shape)
    feature, threshold = feature_list[idx[0]], threshold_list[idx[1]]
    dl = data[data[:, feature - 1] <= threshold]
    dl_n2 = np.sum(dl[:,-1] == 2)
    dl_n4 = np.sum(dl[:,-1] == 4)
    
    if dl_n2 >= dl_n4:
        dl_prediction = 2
    else:
        dl_prediction = 4
    dr = data[data[:, feature - 1] > threshold]
    dr_n2 = np.sum(dr[:,-1] == 2)
    dr_n4 = np.sum(dr[:,-1] == 4)
    
    print("q3:",dl_n2,dr_n2,dl_n4,dr_n4)
    
    if dr_n2 >= dl_n4:
        dr_prediction = 2
    else:
        dr_prediction = 4
    return feature, threshold, dl_prediction, dr_prediction


class Node:
    def __init__(self, feature = None, threshold = None, l_prediction = None, r_prediction = None):
        self.feature = feature
        self.threshold = threshold
        self.l_prediction = l_prediction
        self.r_prediction = r_prediction
        self.l = None
        self.r = None
        self.correct = 0

def split(data, node):
    # split the data into two parts
    feature, threshold = node.feature, node.threshold
    d1 = data[data[:,feature-1] <= threshold]
    d2 = data[data[:,feature-1] > threshold]
    return (d1,d2)

def create_tree(data, node, feature_list):
    d1,d2 = split(data, node)
    f1, t1, l1_prediction, r1_prediction = get_best_split(d1, feature_list, threshold_list)
    f2, t2, l2_prediction, r2_prediction = get_best_split(d2, feature_list, threshold_list)
    if t1 == None: 
        node.l_pre = f1
    else:
        node.l = Node(f1, t1, l1_prediction, r1_prediction)
        create_tree(d1, node.l, feature_list)
    if t2 == None: 
        node.r_pre = f2
    else:
        node.r = Node(f2, t2, l2_prediction, r2_prediction)
        create_tree(d2, node.r, feature_list)  
    
def maxDepth(node):
    if node is None:
        return 0 ;
    else :
        left_depth = maxDepth(node.l)
        right_depth = maxDepth(node.r)
        return max(left_depth, right_depth) + 1

def expand_root(data, feature_list, threshold_list):
    feature, threshold, dl, dr = get_best_split(
        data, feature_list, threshold_list)
    root = Node(feature, threshold)
    # first split
    data1, data2 = split(data, root)
    create_tree(data, root, feature_list)
    return root

feature, threshold, dl, dr = get_best_split(
    data, feature_list, threshold_list)

root = expand_root(data, feature_list, threshold_list)

# For Q5 & Q8
def print_tree(node, f, prefix=''):
    feature = node.feature
    threshold = node.threshold
    l_prediction = node.l_prediction
    r_prediction = node.r_prediction
    l = node.l
    r = node.r
    if l == None:
        f.write(prefix+'if (x'+str(feature)+') <= '+str(threshold)+') return '+str(l_prediction)+'\n')
    else:
        f.write(prefix+'if (x'+str(feature)+') <= '+str(threshold)+')\n')
        print_tree(l, f, prefix+' ')
    if r == None:
        f.write(prefix+'else return '+str(r_prediction)+'\n')
    else:
        f.write(prefix+'else\n')
        print_tree(r, f, prefix+' ')

with open('./test.txt', 'r') as f:
    test_data = [l.strip('\n').split(',') for l in f if '?' not in l]

test_data = np.array(test_data).astype(int)   # test

# For Q7 & Q9
def tree_prediction(node, x):
	# to predict for test data
    feature = node.feature
    threshold = node.threshold
    l_prediction = node.l_prediction
    r_prediction = node.r_prediction
    l = node.l
    r = node.r
    if x[feature-1] <= threshold:
        if l_prediction == x[-1]:
            node.correct += 1
        if l == None:
            return l_prediction
        else:
            return tree_prediction(l, x)
    else:
        if r_prediction == x[-1]:
            node.correct += 1
        if r == None:
            return r_prediction
        else:
            return tree_prediction(r, x)

# For Q8
def prune(node, depth):
    if depth == 1:
        node.l = None
        node.r = None
    else:
        if node.l != None:
            prune(node.l, depth-1)
        if node.r != None:
            prune(node.r, depth-1)
prune(root, depth=target_depth)

[[1033078       2       1 ...       1       5       2]
 [1070935       1       1 ...       1       1       2]
 [1156948       3       1 ...       1       1       2]
 ...
 [ 714039       3       1 ...       1       1       2]
 [ 776715       3       1 ...       1       1       2]
 [ 841769       2       1 ...       1       1       2]]
[[1000025       5       1 ...       1       1       2]
 [1002945       5       4 ...       2       1       2]
 [1015425       3       1 ...       1       1       2]
 ...
 [ 888820       5      10 ...      10       2       4]
 [ 897471       4       8 ...       6       1       4]
 [ 897471       4       8 ...       4       1       4]]
444 , 239
148 , 2
296 , 237
[[1033078       2       1 ...       1       5       2]
 [1033078       4       2 ...       1       1       2]
 [1036172       2       1 ...       1       1       2]
 ...
 [ 763235       3       1 ...       1       2       2]
 [ 776715       3       1 ...       1       1       2]
 [ 841769       2   

[[1000025       5       1 ...       1       1       2]
 [1015425       3       1 ...       1       1       2]
 [1017023       4       1 ...       1       1       2]
 ...
 [ 763235       3       1 ...       1       2       2]
 [ 776715       3       1 ...       1       1       2]
 [ 841769       2       1 ...       1       1       2]]
[[1113483       5       2       3       1       6      10       5       1
        1       4]
 [1168359       8       2       3       1       6       3       7       1
        1       4]
 [1177512       1       1       1       1      10       1       1       1
        1       2]
 [1187457       3       1       1       3       8       1       5       8
        1       2]
 [1189266       7       2       4       1       6      10       5       4
        3       4]]
406 , 12
404 , 9
2 , 3
[[1000025       5       1 ...       1       1       2]
 [1015425       3       1 ...       1       1       2]
 [1017023       4       1 ...       1       1       2]
 ...
 [ 76

6 , 8
0 , 0
[[1081791       6       2       1       1       1       1       7       1
        1       2]
 [1113038       8       2       4       1       5       1       5       4
        4       4]
 [1113483       5       2       3       1       6      10       5       1
        1       4]
 [1137156       2       2       2       1       1       1       7       1
        1       2]
 [1147044       3       1       1       1       2       2       7       1
        1       2]
 [1168359       8       2       3       1       6       3       7       1
        1       4]
 [1187457       3       1       1       3       8       1       5       8
        1       2]
 [1189266       7       2       4       1       6      10       5       4
        3       4]
 [ 128059       1       1       1       1       2       5       5       1
        1       2]
 [ 543558       6       1       3       1       4       5       5      10
        1       4]
 [  63375       9       1       2       6       4      10 

26 , 1
26 , 1
0 , 0
[[1017023       4       1       1       3       2       1       3       1
        1       2]
 [1106095       4       1       1       3       2       1       3       1
        1       2]
 [1213383       5       1       1       4       2       1       3       1
        1       2]
 [1226012       4       1       1       3       1       5       2       1
        1       4]
 [1270479       5       1       3       3       2       2       2       3
        1       2]
 [1276091       3       1       1       3       1       1       3       1
        1       2]
 [ 836433       5       1       1       3       2       1       1       1
        1       2]
 [ 657753       3       1       1       4       3       1       2       2
        1       2]
 [1276091       5       1       1       3       4       1       3       2
        1       2]
 [ 558538       4       1       3       3       2       1       1       1
        1       2]
 [ 734111       1       1       1       3       2 

        1       2]]
[]
2 , 1
2 , 1
0 , 0
[[1226012       4       1       1       3       1       5       2       1
        1       4]
 [1276091       3       1       1       3       1       1       3       1
        1       2]
 [1201870       4       1       1       3       1       1       2       1
        1       2]]
[]
2 , 1
2 , 1
0 , 0
[[1226012       4       1       1       3       1       5       2       1
        1       4]
 [1276091       3       1       1       3       1       1       3       1
        1       2]
 [1201870       4       1       1       3       1       1       2       1
        1       2]]
[]
2 , 1
2 , 1
0 , 0
[[1226012       4       1       1       3       1       5       2       1
        1       4]
 [1276091       3       1       1       3       1       1       3       1
        1       2]
 [1201870       4       1       1       3       1       1       2       1
        1       2]]
[]
2 , 1
2 , 1
0 , 0
[[1226012       4       1       1       3       1       

6 , 2
4 , 1
2 , 1
[[1065726       5       2       3       4       2       7       3       6
        1       4]
 [1208301       1       2       3       1       2       1       3       1
        1       2]
 [ 428903       7       2       4       1       3       4       3       3
        1       4]
 [1158405       1       2       3       1       2       1       2       1
        1       2]
 [1206314       1       2       3       1       2       1       1       1
        1       2]
 [1334659       5       2       4       1       1       1       1       1
        1       2]
 [1225382       6       2       3       1       2       1       1       1
        1       2]
 [1275807       4       2       4       3       2       2       2       1
        1       2]]
[]
6 , 2
6 , 2
0 , 0
[[1065726       5       2       3       4       2       7       3       6
        1       4]
 [1208301       1       2       3       1       2       1       3       1
        1       2]
 [ 428903       7       2     

6 , 2
0 , 0
6 , 2
[[1187457       3       1       1       3       8       1       5       8
        1       2]
 [ 128059       1       1       1       1       2       5       5       1
        1       2]
 [ 673637       3       1       1       1       2       5       5       1
        1       2]
 [ 752904      10       1       1       1       2      10       5       4
        1       4]]
[[1081791       6       2       1       1       1       1       7       1
        1       2]
 [1137156       2       2       2       1       1       1       7       1
        1       2]
 [1147044       3       1       1       1       2       2       7       1
        1       2]
 [  63375       9       1       2       6       4      10       7       7
        2       4]]
6 , 2
3 , 1
3 , 1
[[1187457       3       1       1       3       8       1       5       8
        1       2]
 [ 128059       1       1       1       1       2       5       5       1
        1       2]
 [ 673637       3       1       

[]
2 , 1
2 , 1
0 , 0
[[128059      1      1      1      1      2      5      5      1      1
       2]
 [673637      3      1      1      1      2      5      5      1      1
       2]
 [752904     10      1      1      1      2     10      5      4      1
       4]]
[]
2 , 1
2 , 1
0 , 0
[[128059      1      1      1      1      2      5      5      1      1
       2]
 [673637      3      1      1      1      2      5      5      1      1
       2]
 [752904     10      1      1      1      2     10      5      4      1
       4]]
[]
2 , 1
2 , 1
0 , 0
[[128059      1      1      1      1      2      5      5      1      1
       2]
 [673637      3      1      1      1      2      5      5      1      1
       2]
 [752904     10      1      1      1      2     10      5      4      1
       4]]
[]
2 , 1
2 , 1
0 , 0
[[128059      1      1      1      1      2      5      5      1      1
       2]
 [673637      3      1      1      1      2      5      5      1      1
       2]
 [752904   

[]
[[1016277       6       8 ...       7       1       2]
 [1017122       8      10 ...       7       1       4]
 [1044572       8       7 ...       5       4       4]
 ...
 [ 888820       5      10 ...      10       2       4]
 [ 897471       4       8 ...       6       1       4]
 [ 897471       4       8 ...       4       1       4]]
3 , 172
0 , 0
3 , 172
[[1123061       6      10       2       8      10       2       7       8
       10       4]]
[[1016277       6       8 ...       7       1       2]
 [1017122       8      10 ...       7       1       4]
 [1044572       8       7 ...       5       4       4]
 ...
 [ 888820       5      10 ...      10       2       4]
 [ 897471       4       8 ...       6       1       4]
 [ 897471       4       8 ...       4       1       4]]
3 , 172
0 , 1
3 , 171
[[1091262       2       5       3       3       6       7       7       5
        1       4]
 [1123061       6      10       2       8      10       2       7       8
       10       4]
 

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



3 , 2
3 , 2
0 , 0
[[1016277       6       8       8       1       3       4       3       7
        1       2]
 [1105257       3       7       7       4       4       9       4       8
        1       4]
 [ 242970       5       7       7       1       5       8       3       4
        1       2]
 [1293439       6       9       7       5       5       8       4       2
        1       2]
 [1268275       9       8       8       9       6       3       4       1
        1       4]]
[]
3 , 2
3 , 2
0 , 0
[[1016277       6       8       8       1       3       4       3       7
        1       2]
 [1105257       3       7       7       4       4       9       4       8
        1       4]
 [ 242970       5       7       7       1       5       8       3       4
        1       2]
 [1293439       6       9       7       5       5       8       4       2
        1       2]
 [1268275       9       8       8       9       6       3       4       1
        1       4]]
[]
3 , 2
3 , 2
0 , 0
[[101627

In [22]:
print(entropy(data))

444 , 239
0.9340026588217948


In [30]:
get_best_split(data, feature_list, threshold_list)

[[1033078       2       1 ...       1       5       2]
 [1070935       1       1 ...       1       1       2]
 [1156948       3       1 ...       1       1       2]
 ...
 [ 714039       3       1 ...       1       1       2]
 [ 776715       3       1 ...       1       1       2]
 [ 841769       2       1 ...       1       1       2]]
[[1000025       5       1 ...       1       1       2]
 [1002945       5       4 ...       2       1       2]
 [1015425       3       1 ...       1       1       2]
 ...
 [ 888820       5      10 ...      10       2       4]
 [ 897471       4       8 ...       6       1       4]
 [ 897471       4       8 ...       4       1       4]]
444 , 239
148 , 2
296 , 237
[[1033078       2       1 ...       1       5       2]
 [1033078       4       2 ...       1       1       2]
 [1036172       2       1 ...       1       1       2]
 ...
 [ 763235       3       1 ...       1       2       2]
 [ 776715       3       1 ...       1       1       2]
 [ 841769       2   

(3, 2, 2, 2)

In [41]:
print(infogain(data, part_one_feature[-1], 9))

[[1000025       5       1 ...       1       1       2]
 [1015425       3       1 ...       1       1       2]
 [1016277       6       8 ...       7       1       2]
 ...
 [ 888820       5      10 ...      10       2       4]
 [ 897471       4       8 ...       6       1       4]
 [ 897471       4       8 ...       4       1       4]]
[[1002945       5       4 ...       2       1       2]
 [1017122       8      10 ...       7       1       4]
 [1018099       1       1 ...       1       1       2]
 ...
 [1334015       7       8 ...       2       3       4]
 [1369821      10      10 ...      10       7       4]
 [1371026       5      10 ...       6       3       4]]
444 , 239
441 , 110
3 , 129
0.32193988700966986


In [25]:
with open('q5.txt', 'w') as f:
    print_tree(root, f)

In [26]:
maxDepth(root)

13

In [50]:
with open('q7.txt', 'w') as f:
    for j in test_data:
        result = str(tree_prediction(root, j))+","
        f.write(result)

In [52]:
with open('q8.txt', 'w') as f:
    print_tree(root, f)

In [53]:
with open('q9.txt', 'w') as f:
    for j in test_data:
        result = str(tree_prediction(root, j))+","
        f.write(result)