#### 决策树算法  
#### 信息增益：   
**输入**：训练数据集$D$和特征$A$；  
**输出**：特征$A$对训练数据集$D$的信息增益$g(D,A)$。  
（1）计算数据集$D$的经验熵$H(D)$  
&emsp;&emsp;&emsp;$H(D)=-\sum_{k=1}^K\frac{|C_k|}{|D|}\log_2\frac{|C_k|}{|D|}$  
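（2）计算特征$A$对数据集$D$的经验条件熵$H(D|A)$  
&emsp;&emsp;&emsp;$H(D|A)=\sum_{i=1}^n\frac{|D_i|}{|D|}H(D_i)=-\sum_{i=1}^n\frac{|D_i|}{|D|}\sum_{k=1}^K\frac{|D_{ik}|}{|D_i|}\log_2\frac{|D_{ik}|}{|D_i|}$  
（3）计算信息增益  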
&emsp;&emsp;&emsp;$g(D,A)=H(D)-H(D|A)$

#### 加载文件  
$fileName$: 要加载的文件路径  
*return*: 数据集和标签集

In [11]:
import numpy as np

In [12]:
def loadData(fileName):
    """Load a label-first CSV file and binarize its feature columns.

    Every non-blank line is split on commas. The first field is parsed as
    the integer label; each remaining field is parsed as an integer and
    mapped to 1 when it is greater than 128, otherwise to 0.

    fileName: path of the file to load.
    return: (dataArr, labelArr) - a list of binarized feature rows and the
            list of integer labels, in file order.
    """
    dataArr = []
    labelArr = []
    # The context manager guarantees the handle is closed even when a
    # malformed field makes int() raise halfway through the file.
    with open(fileName) as fr:
        # Iterate lazily instead of materializing every line with readlines().
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Blank lines carry no sample; skip them rather than crash
                # on int('').
                continue
            curLine = stripped.split(',')
            dataArr.append([int(int(num) > 128) for num in curLine[1:]])
            labelArr.append(int(curLine[0]))
    return dataArr, labelArr

#### 找到当前标签集中占数目最大的标签  
$labelArr$: 标签集  
*return*: 出现次数最多的标签

In [13]:
def majorClass(labelArr):
    """Return the label that occurs most often in labelArr.

    labelArr: sequence of hashable labels.
    return: the label with the highest count. When several labels share the
            highest count, the one that appears first in labelArr wins
            (the sort is stable and the dict keeps first-seen order).
    """
    # Tally how many times each label occurs.
    counts = {}
    for label in labelArr:
        counts[label] = counts.get(label, 0) + 1
    # Rank (label, count) pairs by count, largest first.
    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    topLabel, _ = ranked[0]
    return topLabel

#### 计算数据集D的经验熵  
$trainLabelArr$: 当前数据集的标签集  
*return*: 经验熵

In [14]:
def calc_H_D(trainLabelArr):
    """Compute the empirical entropy H(D) of a label array.

    trainLabelArr: numpy array of labels (boolean-mask indexing and .size
                   are used on it).
    return: -sum_k p_k * log2(p_k), where p_k is the fraction of entries
            equal to the k-th distinct label; 0 when there are no labels.
    """
    entropy = 0
    # Visit every distinct label exactly once.
    for label in set(trainLabelArr):
        # p = |C_k| / |D|
        prob = trainLabelArr[trainLabelArr == label].size / trainLabelArr.size
        # Accumulate this label's -p * log2(p) term.
        entropy -= prob * np.log2(prob)
    return entropy

#### 计算经验条件熵  
$trainDataArr\_DevFeature$:切割后只有feature那列数据的数组  
$trainLabelArr$: 标签集数组  
*return*: 经验条件熵

In [15]:
def calcH_D_A(trainDataArr_DevFeature, trainLabelArr):
    """Compute the empirical conditional entropy H(D|A) for one feature.

    trainDataArr_DevFeature: numpy array holding only the column of the
                             feature under consideration.
    trainLabelArr: numpy array of labels aligned with that column.
    return: sum over each distinct feature value of
            (share of samples with that value) * H(labels of those samples).
    """
    condEntropy = 0
    # The distinct values present in the column are the partitions of D.
    for value in set(trainDataArr_DevFeature):
        # Boolean mask selecting the samples whose feature equals `value`.
        mask = trainDataArr_DevFeature == value
        weight = trainDataArr_DevFeature[mask].size / trainDataArr_DevFeature.size
        condEntropy += weight * calc_H_D(trainLabelArr[mask])
    return condEntropy

#### 计算信息增益最大的特征
$trainDataList$: 当前数据集  
$trainLabelList$: 当前标签集  
*return*: 信息增益最大的特征及最大信息增益值

In [16]:
def calcBestFeature(trainDataList, trainLabelList):
    """Find the feature with the largest information gain.

    trainDataList: current data set, one row of feature values per sample.
    trainLabelList: labels aligned with the rows of trainDataList.
    return: (feature index, information gain) of the best feature. The first
            feature reaching the maximum is kept because the comparison is
            strict; (-1, -1) is returned when there are no feature columns.
    """
    dataMatrix = np.array(trainDataList)
    labelVector = np.array(trainLabelList)

    columnCount = dataMatrix.shape[1]

    # Step 1: empirical entropy H(D) of the whole label set.
    baseEntropy = calc_H_D(labelVector)

    bestFeature, bestGain = -1, -1
    for col in range(columnCount):
        # Step 2: take a flat copy of this feature's column for H(D|A).
        column = np.array(dataMatrix[:, col].flat)
        # Step 3: information gain g(D, A) = H(D) - H(D|A).
        gain = baseEntropy - calcH_D_A(column, labelVector)
        if gain > bestGain:
            bestFeature, bestGain = col, gain

    return bestFeature, bestGain

#### 更新数据集和标签集  
$trainDataArr$: 要更新的数据集  
$trainLabelArr$: 要更新的标签集  
$A$: 要去除的特征索引  
$a$: 当data[A] == a时，说明该行样本是要保留的  
*return*: 新的数据集和标签集

In [17]:
def getSubDataArr(trainDataArr, trainLabelArr, A, a):
    """Keep the samples whose feature A equals a, and drop column A from them.

    trainDataArr: data set to filter, one row per sample.
    trainLabelArr: labels aligned with the rows of trainDataArr.
    A: index of the feature to test and then remove.
    a: feature value a sample must have at index A to be kept.
    return: (new data set, new label set) containing only the kept samples.
    """
    keptRows, keptLabels = [], []
    for idx, row in enumerate(trainDataArr):
        # Samples with a different value for feature A belong to another branch.
        if row[A] != a:
            continue
        # Join the parts before and after column A, removing that column.
        keptRows.append(row[0:A] + row[A+1:])
        keptLabels.append(trainLabelArr[idx])
    return keptRows, keptLabels

#### 递归创建决策树  
$dataSet$: (trainDataList， trainLabelList) <<-- 元组形式  
*return*: 新的子节点或该叶子节点的值

In [18]:
def createTree(*dataSet):
    """Recursively build a decision tree from binary features.

    dataSet: variadic arguments; only the first is used and it must be a
             (trainDataList, trainLabelList) pair.
    return: a leaf label, or a dict {feature index: {0: subtree, 1: subtree}}.
            Feature indices are relative to the rows at that node, i.e. after
            the features chosen higher up have been removed.
    """
    # Minimum information gain required to keep splitting.
    minGain = 0.1
    samples = dataSet[0][0]
    labels = dataSet[0][1]

    print('start a node', len(samples[0]), len(labels))

    # Only one class left: this node is a leaf carrying that class.
    if len(set(labels)) == 1:
        return labels[0]

    # No features left to split on: fall back to the majority class.
    if len(samples[0]) == 0:
        return majorClass(labels)

    bestFeature, bestGain = calcBestFeature(samples, labels)

    # The best split gains too little: fall back to the majority class.
    if bestGain < minGain:
        return majorClass(labels)

    # Grow one child per feature value, branch 0 first and then branch 1.
    branches = {}
    for value in (0, 1):
        branches[value] = createTree(getSubDataArr(samples, labels, bestFeature, value))

    return {bestFeature: branches}

#### 预测标签  
$testDataList$: 样本  
$tree$: 决策树  
*return*: 预测结果

In [19]:
def predict(testDataList, tree):
    """Classify one sample by walking the decision tree.

    testDataList: feature values of the sample. It is not modified.
    tree: either a leaf value, or a single-key dict
          {feature index: {feature value: subtree}}. Each feature index is
          interpreted relative to the sample with every previously visited
          feature removed.
    return: the leaf value reached, i.e. the predicted label.
    """
    # Work on a copy: the walk deletes each visited feature so the indices
    # stored deeper in the tree line up, and the caller's sample must stay
    # intact so it can be predicted again.
    sample = list(testDataList)
    # Anything that is not a dict is a leaf. This also covers a tree that is
    # a bare label and leaves that are not plain int (e.g. numpy integers).
    while isinstance(tree, dict):
        (feature, branches), = tree.items()
        if not isinstance(branches, dict):
            # A node whose value is not a branch table is itself the answer.
            return branches
        featureValue = sample[feature]
        # Drop the consumed feature so later indices refer to the reduced row.
        del sample[feature]
        tree = branches[featureValue]
    return tree

#### 测试准确率  
$testDataList$:待测试数据集  
$testLabelList$: 待测试标签集  
$tree$: 训练集生成的树  
*return*: 准确率

In [20]:
def model_test(testDataList, testLabelList, tree):
    """Measure the accuracy of the tree on a labelled test set.

    testDataList: samples to classify.
    testLabelList: expected labels aligned with testDataList.
    tree: decision tree handed to predict for every sample.
    return: 1 - (number of wrong predictions / number of samples).
    """
    # Count the samples whose prediction disagrees with the expected label.
    mistakes = sum(
        1
        for idx, sample in enumerate(testDataList)
        if testLabelList[idx] != predict(sample, tree)
    )
    return 1 - mistakes / len(testDataList)

#### 开始实验

In [21]:
# Load the training set (binarized features and integer labels).
trainDataList, trainLabelList = loadData('Mnist/mnist_train.csv')

# Load the test set in the same format.
testDataList, testLabelList = loadData('Mnist/mnist_test.csv')

# Build the decision tree; createTree takes the data and labels as one tuple.
print('create decision tree')
tree = createTree((trainDataList, trainLabelList))
print('tree is:', tree)

# Evaluate the tree on the test set and report its accuracy.
print('start test')
accuracy = model_test(testDataList, testLabelList, tree)
print('the accuracy is:', accuracy)

create decision tree
start a node 784 60000
start a node 783 33587
start a node 782 23938
start a node 781 18700
start a node 780 11336
start a node 779 8724
start a node 778 7677
start a node 777 6831
start a node 776 6242
start a node 775 5840
start a node 775 402
start a node 774 220
start a node 773 171
start a node 772 142
start a node 771 133
start a node 770 128
start a node 769 125
start a node 768 123
start a node 768 2
start a node 767 1
start a node 767 1
start a node 769 3
start a node 768 1
start a node 768 2
start a node 770 5
start a node 769 2
start a node 769 3
start a node 768 2
start a node 768 1
start a node 771 9
start a node 770 6
start a node 769 5
start a node 769 1
start a node 770 3
start a node 769 2
start a node 768 1
start a node 768 1
start a node 769 1
start a node 772 29
start a node 771 8
start a node 770 3
start a node 769 1
start a node 769 2
start a node 770 5
start a node 771 21
start a node 770 7
start a node 769 6
start a node 769 1
start a node 7

start a node 774 81
start a node 773 76
start a node 772 73
start a node 771 71
start a node 771 2
start a node 772 3
start a node 771 2
start a node 770 1
start a node 770 1
start a node 771 1
start a node 773 5
start a node 774 29
start a node 773 23
start a node 772 17
start a node 771 14
start a node 770 12
start a node 770 2
start a node 771 3
start a node 770 2
start a node 770 1
start a node 772 6
start a node 771 5
start a node 771 1
start a node 773 6
start a node 772 4
start a node 772 2
start a node 771 1
start a node 771 1
start a node 776 367
start a node 775 140
start a node 774 46
start a node 773 34
start a node 772 27
start a node 771 23
start a node 770 21
start a node 769 20
start a node 769 1
start a node 770 2
start a node 771 4
start a node 770 2
start a node 770 2
start a node 769 1
start a node 769 1
start a node 772 7
start a node 771 5
start a node 770 2
start a node 770 3
start a node 771 2
start a node 773 12
start a node 772 8
start a node 771 7
start a nod

start a node 775 169
start a node 774 115
start a node 773 98
start a node 772 96
start a node 772 2
start a node 773 17
start a node 772 10
start a node 771 5
start a node 770 2
start a node 770 3
start a node 769 2
start a node 769 1
start a node 771 5
start a node 772 7
start a node 771 6
start a node 771 1
start a node 774 54
start a node 773 18
start a node 772 11
start a node 771 5
start a node 770 2
start a node 769 1
start a node 769 1
start a node 770 3
start a node 771 6
start a node 770 4
start a node 770 2
start a node 772 7
start a node 773 36
start a node 772 28
start a node 771 18
start a node 770 1
start a node 770 17
start a node 771 10
start a node 770 5
start a node 770 5
start a node 772 8
start a node 771 6
start a node 771 2
start a node 770 1
start a node 770 1
start a node 775 35
start a node 774 19
start a node 773 8
start a node 772 1
start a node 772 7
start a node 773 11
start a node 772 10
start a node 772 1
start a node 774 16
start a node 773 10
start a n

start a node 772 3
start a node 772 4
start a node 771 2
start a node 771 2
start a node 773 12
start a node 772 3
start a node 771 2
start a node 770 1
start a node 770 1
start a node 771 1
start a node 772 9
start a node 774 12
start a node 773 6
start a node 772 4
start a node 772 2
start a node 773 6
start a node 772 3
start a node 771 2
start a node 770 1
start a node 770 1
start a node 771 1
start a node 772 3
start a node 776 65
start a node 775 25
start a node 774 18
start a node 773 8
start a node 772 7
start a node 772 1
start a node 773 10
start a node 772 8
start a node 771 7
start a node 771 1
start a node 772 2
start a node 774 7
start a node 773 3
start a node 772 2
start a node 772 1
start a node 773 4
start a node 775 40
start a node 774 21
start a node 773 4
start a node 772 3
start a node 772 1
start a node 773 17
start a node 772 13
start a node 772 4
start a node 771 2
start a node 771 2
start a node 770 1
start a node 770 1
start a node 774 19
start a node 773 16


start a node 773 3
start a node 772 2
start a node 772 1
start a node 773 2
start a node 775 20
start a node 774 12
start a node 773 5
start a node 772 3
start a node 772 2
start a node 773 7
start a node 774 8
start a node 773 3
start a node 772 2
start a node 771 1
start a node 771 1
start a node 772 1
start a node 773 5
start a node 778 2191
start a node 777 1960
start a node 776 796
start a node 775 569
start a node 774 377
start a node 773 245
start a node 772 211
start a node 771 180
start a node 770 39
start a node 769 31
start a node 768 16
start a node 767 9
start a node 766 5
start a node 765 3
start a node 764 2
start a node 764 1
start a node 765 2
start a node 766 4
start a node 765 3
start a node 765 1
start a node 767 7
start a node 766 6
start a node 766 1
start a node 768 15
start a node 767 13
start a node 767 2
start a node 766 1
start a node 766 1
start a node 769 8
start a node 770 141
start a node 769 23
start a node 768 8
start a node 767 2
start a node 767 6
sta

start a node 776 178
start a node 775 165
start a node 774 152
start a node 773 145
start a node 773 7
start a node 772 6
start a node 772 1
start a node 774 13
start a node 773 7
start a node 772 4
start a node 772 3
start a node 771 2
start a node 771 1
start a node 773 6
start a node 775 13
start a node 774 10
start a node 773 9
start a node 773 1
start a node 774 3
start a node 773 1
start a node 773 2
start a node 776 48
start a node 775 21
start a node 774 14
start a node 773 12
start a node 773 2
start a node 774 7
start a node 773 4
start a node 772 2
start a node 772 2
start a node 771 1
start a node 771 1
start a node 773 3
start a node 772 1
start a node 772 2
start a node 775 27
start a node 774 13
start a node 773 1
start a node 773 12
start a node 774 14
start a node 773 12
start a node 772 11
start a node 772 1
start a node 773 2
start a node 779 555
start a node 778 284
start a node 777 198
start a node 776 161
start a node 775 149
start a node 774 42
start a node 773 3

start a node 774 37
start a node 773 4
start a node 772 2
start a node 771 1
start a node 771 1
start a node 772 2
start a node 773 33
start a node 774 18
start a node 773 8
start a node 772 3
start a node 772 5
start a node 771 2
start a node 771 3
start a node 770 2
start a node 770 1
start a node 773 10
start a node 772 5
start a node 772 5
start a node 771 3
start a node 771 2
start a node 770 1
start a node 770 1
start a node 775 72
start a node 774 29
start a node 773 16
start a node 772 9
start a node 771 8
start a node 771 1
start a node 772 7
start a node 771 5
start a node 771 2
start a node 770 1
start a node 770 1
start a node 773 13
start a node 772 10
start a node 771 2
start a node 770 1
start a node 770 1
start a node 771 8
start a node 772 3
start a node 774 43
start a node 773 36
start a node 772 35
start a node 772 1
start a node 773 7
start a node 772 4
start a node 771 3
start a node 771 1
start a node 772 3
start a node 777 190
start a node 776 163
start a node 77

start a node 774 136
start a node 773 128
start a node 772 105
start a node 771 96
start a node 771 9
start a node 770 4
start a node 769 3
start a node 769 1
start a node 770 5
start a node 772 23
start a node 771 21
start a node 770 11
start a node 769 7
start a node 768 2
start a node 768 5
start a node 769 4
start a node 770 10
start a node 771 2
start a node 770 1
start a node 770 1
start a node 773 8
start a node 772 7
start a node 772 1
start a node 774 13
start a node 773 12
start a node 773 1
start a node 775 43
start a node 774 12
start a node 773 8
start a node 773 4
start a node 772 3
start a node 772 1
start a node 774 31
start a node 773 28
start a node 773 3
start a node 772 2
start a node 771 1
start a node 771 1
start a node 772 1
start a node 776 144
start a node 775 31
start a node 774 10
start a node 773 3
start a node 772 2
start a node 771 1
start a node 771 1
start a node 772 1
start a node 773 7
start a node 774 21
start a node 773 20
start a node 773 1
start a 

start a node 783 26413
start a node 782 15142
start a node 781 11780
start a node 780 6873
start a node 779 3386
start a node 778 2344
start a node 777 1339
start a node 776 1178
start a node 775 1089
start a node 774 789
start a node 773 706
start a node 772 658
start a node 772 48
start a node 771 30
start a node 770 12
start a node 769 7
start a node 769 5
start a node 768 2
start a node 767 1
start a node 767 1
start a node 768 3
start a node 770 18
start a node 769 16
start a node 768 15
start a node 768 1
start a node 769 2
start a node 771 18
start a node 770 6
start a node 769 2
start a node 769 4
start a node 770 12
start a node 769 11
start a node 769 1
start a node 773 83
start a node 772 61
start a node 771 54
start a node 770 31
start a node 769 23
start a node 768 18
start a node 767 15
start a node 767 3
start a node 768 5
start a node 767 1
start a node 767 4
start a node 769 8
start a node 768 6
start a node 768 2
start a node 770 23
start a node 771 7
start a node 770

start a node 771 3
start a node 770 2
start a node 770 1
start a node 771 3
start a node 774 20
start a node 773 2
start a node 772 1
start a node 772 1
start a node 773 18
start a node 772 16
start a node 772 2
start a node 771 1
start a node 771 1
start a node 776 72
start a node 775 33
start a node 774 12
start a node 773 6
start a node 772 2
start a node 772 4
start a node 773 6
start a node 772 3
start a node 771 2
start a node 771 1
start a node 772 3
start a node 774 21
start a node 773 17
start a node 772 12
start a node 772 5
start a node 771 3
start a node 771 2
start a node 773 4
start a node 772 2
start a node 771 1
start a node 771 1
start a node 772 2
start a node 771 1
start a node 771 1
start a node 775 39
start a node 774 29
start a node 773 24
start a node 772 16
start a node 771 10
start a node 771 6
start a node 770 4
start a node 769 1
start a node 769 3
start a node 770 2
start a node 772 8
start a node 771 7
start a node 771 1
start a node 773 5
start a node 772 

start a node 770 1
start a node 770 2
start a node 771 3
start a node 772 6
start a node 771 5
start a node 771 1
start a node 774 114
start a node 773 99
start a node 773 15
start a node 772 3
start a node 772 12
start a node 771 3
start a node 771 9
start a node 770 1
start a node 770 8
start a node 776 36
start a node 775 32
start a node 774 1
start a node 774 31
start a node 775 4
start a node 774 3
start a node 774 1
start a node 777 117
start a node 776 88
start a node 775 50
start a node 774 40
start a node 773 36
start a node 772 28
start a node 772 8
start a node 771 4
start a node 770 3
start a node 770 1
start a node 771 4
start a node 773 4
start a node 772 3
start a node 772 1
start a node 774 10
start a node 773 6
start a node 773 4
start a node 775 38
start a node 774 33
start a node 773 19
start a node 772 3
start a node 771 1
start a node 771 2
start a node 772 16
start a node 773 14
start a node 772 7
start a node 771 3
start a node 771 4
start a node 772 7
start a no

start a node 773 11
start a node 773 12
start a node 772 6
start a node 771 3
start a node 770 2
start a node 770 1
start a node 771 3
start a node 772 6
start a node 776 532
start a node 777 230
start a node 776 156
start a node 775 119
start a node 774 98
start a node 773 91
start a node 772 83
start a node 771 66
start a node 771 17
start a node 770 4
start a node 769 2
start a node 769 2
start a node 768 1
start a node 768 1
start a node 770 13
start a node 769 12
start a node 769 1
start a node 772 8
start a node 771 4
start a node 770 2
start a node 769 1
start a node 769 1
start a node 770 2
start a node 771 4
start a node 770 1
start a node 770 3
start a node 773 7
start a node 772 3
start a node 771 2
start a node 771 1
start a node 772 4
start a node 771 3
start a node 771 1
start a node 774 21
start a node 773 10
start a node 772 8
start a node 771 7
start a node 771 1
start a node 772 2
start a node 773 11
start a node 772 10
start a node 772 1
start a node 775 37
start a n

start a node 780 7620
start a node 779 6749
start a node 778 5818
start a node 777 5524
start a node 777 294
start a node 776 105
start a node 775 81
start a node 774 70
start a node 773 23
start a node 772 17
start a node 772 6
start a node 771 5
start a node 771 1
start a node 773 47
start a node 772 39
start a node 771 25
start a node 770 15
start a node 769 11
start a node 768 9
start a node 768 2
start a node 767 1
start a node 767 1
start a node 769 4
start a node 770 10
start a node 769 5
start a node 768 1
start a node 768 4
start a node 769 5
start a node 768 4
start a node 768 1
start a node 771 14
start a node 770 1
start a node 770 13
start a node 772 8
start a node 771 7
start a node 771 1
start a node 774 11
start a node 773 10
start a node 773 1
start a node 775 24
start a node 774 23
start a node 774 1
start a node 776 189
start a node 775 175
start a node 774 171
start a node 773 28
start a node 772 10
start a node 771 2
start a node 771 8
start a node 770 4
start a no

start a node 774 40
start a node 773 38
start a node 773 2
start a node 772 1
start a node 772 1
start a node 774 1
start a node 775 6
start a node 774 3
start a node 773 2
start a node 773 1
start a node 774 3
start a node 773 1
start a node 773 2
start a node 778 323
start a node 777 146
start a node 776 42
start a node 775 35
start a node 774 33
start a node 773 31
start a node 773 2
start a node 772 1
start a node 772 1
start a node 774 2
start a node 773 1
start a node 773 1
start a node 775 7
start a node 774 4
start a node 773 3
start a node 773 1
start a node 774 3
start a node 773 1
start a node 773 2
start a node 776 104
start a node 775 85
start a node 774 80
start a node 773 3
start a node 773 77
start a node 772 76
start a node 772 1
start a node 774 5
start a node 773 2
start a node 773 3
start a node 772 2
start a node 771 1
start a node 771 1
start a node 772 1
start a node 775 19
start a node 774 7
start a node 773 2
start a node 773 5
start a node 774 12
start a node 

start a node 777 517
start a node 776 458
start a node 775 49
start a node 774 36
start a node 773 8
start a node 772 5
start a node 772 3
start a node 771 1
start a node 771 2
start a node 773 28
start a node 772 25
start a node 772 3
start a node 771 2
start a node 770 1
start a node 770 1
start a node 771 1
start a node 774 13
start a node 773 8
start a node 772 2
start a node 771 1
start a node 771 1
start a node 772 6
start a node 773 5
start a node 772 4
start a node 772 1
start a node 775 409
start a node 776 59
start a node 775 37
start a node 774 16
start a node 773 10
start a node 772 6
start a node 772 4
start a node 771 3
start a node 771 1
start a node 773 6
start a node 772 3
start a node 771 2
start a node 771 1
start a node 772 3
start a node 771 2
start a node 770 1
start a node 770 1
start a node 771 1
start a node 774 21
start a node 773 2
start a node 773 19
start a node 775 22
start a node 774 13
start a node 774 9
start a node 773 2
start a node 772 1
start a node