In [1]:
# -*- coding: utf-8 -*
import numpy as np
import pandas as pd

In [2]:
df = pd.read_csv('watermelon3_0_En.csv', encoding = "utf-8")
df = df.drop(['No.','Density','SugerRatio'],axis=1)
print(df)

    Color       Root   Knocks   Texture Umbilicus Touch  Label
0   green       curl  heavily  distinct    sunken  hard      1
1   black       curl     dull  distinct    sunken  hard      1
2   black       curl  heavily  distinct    sunken  hard      1
3   green       curl     dull  distinct    sunken  hard      1
4   white       curl  heavily  distinct    sunken  hard      1
5   green  lightCurl  heavily  distinct    dimple  soft      1
6   black  lightCurl  heavily      blur    dimple  soft      1
7   black  lightCurl  heavily  distinct    dimple  hard      1
8   black  lightCurl     dull      blur    dimple  hard      0
9   green      stiff    clear  distinct    smooth  soft      0
10  white      stiff    clear     fuzzy    smooth  hard      0
11  white       curl  heavily     fuzzy    smooth  soft      0
12  green  lightCurl  heavily      blur    sunken  hard      0
13  white  lightCurl     dull      blur    sunken  hard      0
14  black  lightCurl  heavily  distinct    dimple  soft

In [3]:
'''
结点类：

@para attribute：结点要划分的属性
@para label：结点的类别
@para attribute_next：下一个结点
'''
class Node(object):
    def __init__(self, attribute=None, label=None, node_next={}):
        self.attribute = attribute
        self.label = label
        self.node_next = node_next
        self.leaf = False
    

In [4]:
'''
统计数据集中样例各个类别的数量，在西瓜数据3.0中即为是和否的数量：

@para   df：数据集
@return label_count：各个类别的数量：label_count[0]:正例数量，label_count[1]反例数量
'''
def LabelCount(df):
    label_count=[]
    label_count.append(df[df['Label']==1]['Label'].count())
    label_count.append(df[df['Label']==0]['Label'].count())
    return label_count

'''
样本数据比较

@para    df：数据集
@return  True：表示数据集中所有样本在所有属性上取值相同
         False：数据集中又不同取值
'''
def DataCompare(df, A):
    if len(df.shape)==0 :
        return True
    a = list(attr_set.keys())
    for row in range(df.shape[0]-1): 
        com  = df[a].iloc[0] == df[a].iloc[row+1]#iloc:索引第i行；loc：索引行标签为i的行
        if(com[com==True].count()<com.count()):
            return False
    return True
'''
创建属性集

@para    df：数据集
@return  attr_set:属性集
                 字典格式如：{'敲声': ['浊响', '沉闷', '清脆']}
'''
def CreateAttrSet(df):
    attr_set = {}
    attr_name = df.columns
    for i in attr_name[:6]:
        attr_set[i] = df[i].unique()
    return attr_set

In [5]:
'''
计算基尼值:西瓜书公式4.5

@para    df：数据集
@return  gini:数据集df的基尼值
'''
def GetGini(df):
    label_count = LabelCount(df)
    gini_temp = 0.0
    for k in label_count:
        if len(df)!=0:
            gini_temp += np.square(k/len(df))
    gini = 1-gini_temp
    return gini
'''
计算属性attr的基尼指数：西瓜书公式4.6

@para    df：数据集
@para    attr:属性
@return  gini_index:属性attr的基尼指数
'''
def GetGini_index(df,attr):
    gini_index = 0.0
    for v in attr[1]:
        df_v = df[df[attr[0]]==v]
        gini_index += len(df_v)/len(df)*GetGini(df_v)
    return gini_index
'''
基于基尼指数属性选择

@para    df：数据集
@para    attr_set:属性集
@return  best_attr:属性集中基尼指数最小的属性
'''
def AttrSelectBaseGainIndex(df,attr_set):
    best_attr = None
    best_gini_index = 100
    for attr in attr_set.items():
        gini_index = GetGini_index(df,attr)
        
        if best_gini_index>=gini_index:
            best_gini_index = gini_index
            best_attr = attr
    return best_attr

In [6]:
'''
测试决策树绘图
'''

def DrawPNG(root, out_file):
    '''
    visualization of decision tree from root.
    @param root: Node, the root node for tree.
    @param out_file: str, name and path of output file
    '''
    try:
        from pydotplus import graphviz
    except ImportError:
        print("module pydotplus.graphviz not found")
        
    g = graphviz.Dot()  # generation of new dot   

    TreeToGraph(0, g, root)
    g2 = graphviz.graph_from_dot_data( g.to_string() )
                                                                                            
    g2.write_png(out_file) 

def TreeToGraph(i, g, root):
    '''
    build a graph from root on
    @param i: node number in this tree
    @param g: pydotplus.graphviz.Dot() object
    @param root: the root node
    
    @return i: node number after modified  
#   @return g: pydotplus.graphviz.Dot() object after modified
    @return g_node: the current root node in graphviz
    '''
    try:
        from pydotplus import graphviz
    except ImportError:
        print("module pydotplus.graphviz not found")
    
    if root.attribute == None:
        g_node_label = "Node:%d\nLeaf:%d\nAttr:%s\nLabel:%s" % (i, root.leaf,root.attribute,root.label)
    else:
        g_node_label = "Node:%d\nLeaf:%d\nAttr:%s\nLabel:%s" % (i, root.leaf,root.attribute,root.label)
    g_node = i
    g.add_node( graphviz.Node( g_node, label = g_node_label ))
    
    for value in list(root.node_next):
        i, g_child = TreeToGraph(i+1, g, root.node_next[value])
        g.add_edge( graphviz.Edge(g_node, g_child, label = value) ) 

    return i, g_node

In [7]:
'''
创建决策树

@para    data_set：数据集
@para    attr_set:属性集
@return  node:树的根节点

'''
def CreateDecisionTree(data_set, attr_set):
    node = Node(None, None, {})
    label_count = LabelCount(data_set)
    #情形1：数据集中包含的样本全属于同一类时，递归结束，返回
    if(label_count[0]==data_set['Label'].count()):
        node.label = 1
        node.leaf = True
        return node
    if(label_count[1]==data_set['Label'].count()):
        node.label = 0
        node.leaf = True
        return node
    #情形2：属性集为空或样本在所有属性上取值相同，无法划分
    if(len(attr_set)==0 or DataCompare(data_set, attr_set)==True):
        if(label_count[0]>label_count[1]):
            node.label = 1
            node.leaf = True
        else:
            node.label = 0
            node.leaf = True
        return node
    
    #划分选择，挑选最优的属性作为当前结点
    best_attr = AttrSelectBaseGainIndex(data_set,attr_set)
    node.attribute = best_attr[0]
    for i in best_attr[1]:
        node.node_next[i]=Node(None, Node, {})
        data_set_a = data_set[data_set[best_attr[0]]==i]#挑选数据集中best_attr[0]属性中包含i值的样例
        
        if(data_set_a.empty):
            if(label_count[0]>label_count[1]):
                node.node_next[i].label = 1
                node.node_next[i].leaf = True
            else:
                node.node_next[i].label = 0
                node.node_next[i].leaf = True
        else:
            attr_set.pop(best_attr[0],best_attr[1])
            node.node_next[i] = CreateDecisionTree(data_set_a, attr_set)  
    return node


In [8]:
attr_set = CreateAttrSet(df)
root = CreateDecisionTree(df, attr_set)


In [9]:
DrawPNG(root, "decision_tree_ID3.png")

后剪枝处理

In [10]:
'''
剪枝后结点的标签

@para    df：数据集
@return  label:标签
'''
def SelectLabel(df):
    #print(df[attr])
    count = LabelCount(df)
    if count[0]>count[1]:
        return 1
    else:
        return 0

In [11]:
'''
预测一组数据的标签

@para    predict_data:测试数据
@para    node:决策树的跟结点
@return  label:决策树对一组数据进行验证的标签
'''
def PredictOneData(predict_data,node):
    if(node.attribute==None):
        return node.label
    else:
        attr_val = predict_data[node.attribute]
        return PredictOneData(predict_data,node.node_next[attr_val])

'''
对测试集进行预测，并计算精度

@para    test_set:测试集
@para    root:决策树的跟
@return  acc:决策树对测试集进行验证的精度
'''
def PredictTree(test_set,root):
    predict_label = []
    for i in range(len(test_set)):
        #print(test_set.iloc[i])
        label = PredictOneData(test_set.iloc[i],root)
        predict_label.append(label)
    test_label = test_set['Label']
    acc = 1 - np.abs(test_label-predict_label).sum()/len(predict_label)
    return acc


In [12]:
PredictTree(df,root)

0.9411764705882353

In [16]:
'''
对测试集进行预测，并计算精度

@para    test_set:测试集
@para    root:决策树的跟
@para    node:决策树的跟
'''
def AfterPrune(test_set,root,node):
    for i in node.node_next:
        AfterPrune(test_set,root,node.node_next[i])
    if not node.leaf:
        a = PredictTree(test_set,root)
        print(a)
        node_temp = node.node_next
        attr_temp = node.attribute
        node.node_next = {}
        node.attribute = None
        node.leaf = True
        node.label = SelectLabel(test_set)
        print('Prune:',PredictTree(test_set,root))
        if a>PredictTree(test_set,root):
            node.attribute = attr_temp
            node.node_next = node_temp
            node.label = None
            node.leaf = False

In [19]:
AfterPrune(df,root,root)

0.9411764705882353
Prune: 0.8823529411764706
0.9411764705882353
Prune: 0.8823529411764706
0.9411764705882353
Prune: 0.5294117647058824
0.9411764705882353
Prune: 0.5294117647058824


In [20]:
DrawPNG(root, "decision_tree_Prune.png")