# 构建ID3 决策树
## 目录
* ID3算法简介
* 用自己构建的数据集验证ID3算法
* ID3算法的局限性

## ID3算法简介
构建的树是多叉树，按照某一特征进行切分  
在选取切分特征时，采用最大信息增益准则


其中信息增益=父节点的信息熵-各子节点信息熵的加权平均（权重为子节点样本数占父节点样本数的比例）  
信息熵的计算公式
$$H = -\sum\limits_{i=1}^k p_i \log_2(p_i)$$

基尼系数的计算公式
$$G = 1- \sum\limits_{i=1}^k p_i^2$$

算法过程：
* 按最大信息增益选取最佳的切分列
* 将划分的子集作为其子节点
* 递归结束的标志在于 数据子集只含一个类别 或者 全部特征都已用完

局限：
* 往往会偏向选取分支度较高（取值较多）的特征
* 不能直接处理连续型变量
* 对缺失值敏感

In [566]:
import numpy as np
import pandas as pd
from collections import Counter
from graphviz import Digraph
import graphviz
class Decision_tree():
    """ID3-style decision tree for categorical features, stored as nested dicts.

    After fitting, ``self.tree`` is either a bare class label (when the
    training labels are already pure, or there is no feature to split on) or
    a dict of the form ``{feature_name: {feature_value: subtree_or_label}}``.

    Parameters
    ----------
    cal : str
        Impurity measure. ``"Entropy"`` selects Shannon entropy; any other
        value selects the Gini index.
    """

    def __init__(self,cal = "Entropy"):
        self.tree = None  # fitted tree; stays None until fit_ID3 is called
        self.cal = cal
        self.columns = None  # feature names in the column order seen by fit_ID3

    def calEnt(self,y):
        """Return the base-2 Shannon entropy of the labels in ``y``."""
        counts = Counter(y)
        p = np.array(list(counts.values())) / len(y)
        return (-p * np.log2(p)).sum()

    def calGini(self,y):
        """Return the Gini index of the labels in ``y``."""
        counts = Counter(y)
        p = np.array(list(counts.values())) / len(y)
        return 1 - (p ** 2).sum()

    def calimpurity(self,y):
        """Return the impurity of ``y`` using the measure chosen by ``self.cal``."""
        if self.cal == "Entropy":
            return self.calEnt(y)
        return self.calGini(y)

    def _best_split(self, X, y):
        """Return the column index of ``X`` with the largest impurity reduction."""
        best_feature = -1
        best_gain = -1
        parent_impurity = self.calimpurity(y)

        for i in range(X.shape[1]):
            column = X[:, i]
            weighted_children = 0
            for value in Counter(column):
                subset = y[column == value]
                weighted_children += self.calimpurity(subset) * len(subset) / len(y)
            gain = parent_impurity - weighted_children
            # Strict ">" keeps the left-most column when gains tie.
            if gain > best_gain:
                best_gain = gain
                best_feature = i
        return best_feature

    @staticmethod
    def _split_data(X, y, feature, label):
        """Return the rows where column ``feature`` equals ``label``, with that column removed."""
        mask = X[:, feature] == label
        return np.delete(X[mask], feature, axis=1), y[mask]

    def _create_tree(self, X, y, featurename):
        """Recursively build the nested-dict tree for ``X``/``y``.

        ``featurename`` lists the names of the columns still present in ``X``
        and is never mutated here.
        """
        label_counts = Counter(y)
        # Stop when every feature has been consumed or the subset is pure.
        # (Stopping at one remaining column would leave that feature unused.)
        if X.shape[1] == 0 or len(label_counts) == 1:
            return label_counts.most_common(1)[0][0]

        best = self._best_split(X, y)
        # The same reduced name list serves every branch, so build it once.
        remaining = featurename.copy()
        del remaining[best]

        branches = {}
        for value in set(X[:, best]):
            sub_X, sub_y = self._split_data(X, y, best, value)
            branches[value] = self._create_tree(sub_X, sub_y, remaining)
        return {featurename[best]: branches}

    def fit_ID3(self,X,y,featurename):
        """Fit the tree on categorical training data.

        Parameters
        ----------
        X : numpy.ndarray
            2-D array of feature values, one row per sample.
        y : numpy.ndarray
            1-D array of class labels aligned with the rows of ``X``.
        featurename : sequence
            Names of the columns of ``X``, in order.

        Returns
        -------
        Decision_tree
            ``self``, so calls can be chained.
        """
        names = list(featurename)  # private copy; also accepts non-list sequences
        self.columns = names
        self.tree = self._create_tree(X, y, names)
        return self

    def _predict(self,test):
        """Predict the label of a single sample.

        ``test`` is indexed with the column positions seen during fitting.
        Returns ``None`` when the sample carries a feature value that has no
        branch in the tree.

        Raises
        ------
        AssertionError
            If the model has not been fitted.
        """
        # Explicit raise instead of ``assert`` so the guard survives ``python -O``.
        if self.tree is None:
            raise AssertionError("fit before predict")

        node = self.tree
        # A non-dict node is a leaf label; this also covers a tree whose root
        # is itself a bare label.
        while isinstance(node, dict):
            feature = next(iter(node))
            branches = node[feature]
            value = test[self.columns.index(feature)]
            if value not in branches:
                return None
            node = branches[value]
        return node

    def predict(self,X_test):
        """Return an array with one predicted label per row of ``X_test``."""
        return np.array([self._predict(test) for test in X_test])

    def score(self,X_test,y_test):
        """Return the accuracy of the model on ``X_test`` / ``y_test``."""
        y_predict = self.predict(X_test)
        return (y_test == y_predict).mean()

    def draw_tree(self):
        """Render the fitted tree with graphviz and open it in a viewer.

        Nodes are numbered from 1 in depth-first order; each edge is labelled
        with the feature value that leads to the child.

        Raises
        ------
        AssertionError
            If the model has not been fitted.
        """
        from graphviz import Digraph

        if self.tree is None:
            raise AssertionError("fit before draw_tree")

        nodes = []
        edges = []
        next_id = 1

        def add_subtree(tree, node_id):
            nonlocal next_id
            if not isinstance(tree, dict):
                # Leaf: the node text is the class label.
                nodes.append((str(node_id), str(tree)))
                return
            feature = next(iter(tree))
            nodes.append((str(node_id), str(feature)))
            for value, child in tree[feature].items():
                next_id += 1
                child_id = next_id
                edges.append((str(node_id), str(child_id), str(value)))
                add_subtree(child, child_id)

        add_subtree(self.tree, next_id)

        dot = Digraph()
        for node_id, text in nodes:
            dot.node(node_id, text)
        for parent_id, child_id, text in edges:
            dot.edge(parent_id, child_id, text)

        dot.view()

## 使用自己建造的数据集验证实现的ID3tree

In [567]:
# Toy "buys computer" data set: four categorical features plus the Class label.
data = pd.DataFrame(
    data=[
        ("<=30", "high", "no", "fair", "no"),
        ("<=30", "high", "no", "excellent", "no"),
        ("31~40", "high", "no", "fair", "yes"),
        (">40", "medium", "no", "fair", "yes"),
        (">40", "low", "yes", "fair", "yes"),
        (">40", "low", "yes", "excellent", "no"),
        ("31~40", "low", "yes", "excellent", "yes"),
        ("<=30", "medium", "no", "fair", "no"),
        ("<=30", "low", "yes", "fair", "yes"),
        (">40", "medium", "yes", "fair", "yes"),
        ("<=30", "medium", "yes", "excellent", "yes"),
        ("31~40", "medium", "no", "excellent", "yes"),
        ("31~40", "high", "yes", "fair", "yes"),
        (">40", "medium", "no", "excellent", "no"),
    ],
    columns=["age", "income", "student", "credit_rating", "Class"],
)

In [568]:
# Every column except the last is a feature; the last column is the label.
X, y = data.iloc[:, :-1], data.iloc[:, -1]

In [569]:
# Display the feature table (the notebook renders the last expression of a cell).
X

Unnamed: 0,age,income,student,credit_rating
0,<=30,high,no,fair
1,<=30,high,no,excellent
2,31~40,high,no,fair
3,>40,medium,no,fair
4,>40,low,yes,fair
5,>40,low,yes,excellent
6,31~40,low,yes,excellent
7,<=30,medium,no,fair
8,<=30,low,yes,fair
9,>40,medium,yes,fair


In [570]:
# Read the column names while X is still a DataFrame, then switch X and y to
# plain arrays for the cells that follow.
clf = Decision_tree().fit_ID3(np.array(X), np.array(y), list(X.columns))
X, y = np.array(X), np.array(y)
print(clf.tree)

{'age': {'<=30': {'student': {'yes': 'yes', 'no': 'no'}}, '>40': {'credit_rating': {'excellent': 'no', 'fair': 'yes'}}, '31~40': 'yes'}}


In [571]:
# Accuracy on the same X and y that the tree was fitted on.
clf.score(np.array(X),y)

1.0

在训练集上的准确率为100%，说明树的构建与预测过程没有问题

In [573]:
# Render the fitted tree with graphviz, then show its nested-dict form.
clf.draw_tree()
clf.tree

{'age': {'<=30': {'student': {'yes': 'yes', 'no': 'no'}},
  '>40': {'credit_rating': {'excellent': 'no', 'fair': 'yes'}},
  '31~40': 'yes'}}

## ID3的局限性

In [574]:
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
# Load scikit-learn's wine data set and hold out a random test split.
data = load_wine()
X, y = data.data, data.target
columns = data.feature_names
X_train, X_test, y_train, y_test = train_test_split(X, y)

In [576]:
# Fit on the training split and render the resulting tree.
clf = Decision_tree().fit_ID3(X_train, y_train, columns)
clf.draw_tree()

In [554]:
# Distinct labels in y, in first-seen order.
list(Counter(y))

['no', 'yes']

这就叫不能直接处理连续型变量：每个不同的数值都会被当作一个独立的类别来切分

In [577]:
# Same toy data set with an extra unique "index" column in front.
data = pd.DataFrame(
    data=[
        (1, "<=30", "high", "no", "fair", "no"),
        (2, "<=30", "high", "no", "excellent", "no"),
        (3, "31~40", "high", "no", "fair", "yes"),
        (4, ">40", "medium", "no", "fair", "yes"),
        (5, ">40", "low", "yes", "fair", "yes"),
        (6, ">40", "low", "yes", "excellent", "no"),
        (7, "31~40", "low", "yes", "excellent", "yes"),
        (8, "<=30", "medium", "no", "fair", "no"),
        (9, "<=30", "low", "yes", "fair", "yes"),
        (10, ">40", "medium", "yes", "fair", "yes"),
        (11, "<=30", "medium", "yes", "excellent", "yes"),
        (12, "31~40", "medium", "no", "excellent", "yes"),
        (13, "31~40", "high", "yes", "fair", "yes"),
        (14, ">40", "medium", "no", "excellent", "no"),
    ],
    columns=["index", "age", "income", "student", "credit_rating", "Class"],
)

In [578]:
# Every column except the last is a feature; the last column is the label.
X, y = data.iloc[:, :-1], data.iloc[:, -1]

In [579]:
# Read the column names while X is still a DataFrame, then switch X and y to
# plain arrays for the cells that follow.
clf = Decision_tree().fit_ID3(np.array(X), np.array(y), list(X.columns))
X, y = np.array(X), np.array(y)
print(clf.tree)

{'index': {1: 'no', 2: 'no', 3: 'yes', 4: 'yes', 5: 'yes', 6: 'no', 7: 'yes', 8: 'no', 9: 'yes', 10: 'yes', 11: 'yes', 12: 'yes', 13: 'yes', 14: 'no'}}


In [580]:
# Render the tree fitted on the data that includes the "index" column.
clf.draw_tree()

这就叫往往会选取分支度较高的特征：唯一编号列 index 的取值最多，因此被选为根节点，但这样的树对新样本没有泛化能力