# 手写决策树

In [1]:
from math import log

## 计算香农熵

In [25]:
def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the class labels in ``dataSet``.

    Each row of ``dataSet`` is a sequence whose last element is the class
    label. An empty ``dataSet`` yields 0.0 because there is nothing to sum.
    """
    total = len(dataSet)

    # Tally how many rows carry each class label.
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1

    # H = -sum(p * log2(p)) over the observed labels.
    entropy = 0.0
    for count in tally.values():
        p = float(count / total)
        entropy -= p * log(p, 2)
    return entropy

## 按特征取值划分数据集

In [105]:
def splitDataSet(dataSet, col, value):
    """Return the rows of ``dataSet`` whose entry in column ``col`` equals ``value``.

    The matching rows are returned unchanged (the same row objects, with
    column ``col`` still present) in their original order, inside a new list.
    """
    return [row for row in dataSet if row[col] == value]

## 计算信息增益

In [113]:
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the largest information gain.

    Every column of ``dataSet`` except the last (the class label) is treated
    as a feature. For each feature the entropy of the subsets produced by
    splitting on its values is weighted by subset size and subtracted from
    the entropy of the whole data set.

    Returns -1 when no feature has a gain strictly greater than 0.
    """
    featureCount = len(dataSet[0]) - 1
    totalRows = float(len(dataSet))
    parentEntropy = calcShannonEnt(dataSet)

    topGain, topIndex = 0.0, -1
    for idx in range(featureCount):
        column = [row[idx] for row in dataSet]

        # Weighted entropy remaining after splitting on feature ``idx``.
        conditionalEntropy = 0.0
        for val in set(column):
            subset = splitDataSet(dataSet, idx, val)
            weight = len(subset) / totalRows
            conditionalEntropy += weight * calcShannonEnt(subset)

        gain = parentEntropy - conditionalEntropy
        # Strict comparison: on ties the earliest feature wins.
        if gain > topGain:
            topGain, topIndex = gain, idx
    return topIndex

## 投票表决
当所有特征都已无法继续划分，但样本仍不属于同一类别时，采用多数表决法：将出现次数最多的类别作为该叶结点的类别标签。

In [114]:
def majorityCnt(classList):
    """Return the class label that occurs most often in ``classList``.

    When several labels share the highest count, the one that appears first
    in ``classList`` is returned.

    Raises:
        ValueError: if ``classList`` is empty.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # ``max`` walks the dict in insertion order and keeps the first maximal
    # key, which gives the first-seen tie-break described above.
    return max(classCount, key=classCount.get)

## 核心程序

In [115]:
def createTree(dataSet, labels, featLabels):
    """Recursively build an ID3 decision tree as nested dicts.

    Args:
        dataSet: non-empty list of rows; each row holds the feature values
            followed by the class label in the last position.
        labels: feature names, where ``labels[i]`` names column ``i`` of
            ``dataSet``. The list is only read, never modified.
        featLabels: list that receives, in selection order, the name of each
            feature chosen as a split node (appended in place).

    Returns:
        A class label when the node is a leaf, otherwise a dict of the form
        ``{feature_name: {feature_value: subtree_or_label}}``.
    """
    classList = [example[-1] for example in dataSet]  # class label of every row

    # Stop 1: every row has the same class, so this is a pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop 2: only the class column is left, so fall back to majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    # Stop 3: no feature separates the rows any further (no positive
    # information gain). Without this guard the -1 sentinel would index the
    # last label and split on the class-label column itself.
    if bestFeat == -1:
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}

    # ``labels`` is intentionally left untouched: splitDataSet keeps every
    # column, so column indices must stay aligned with ``labels`` throughout
    # the recursion. A feature already used has a single value in each subset
    # and therefore zero gain, so it is never selected again.
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        subset = splitDataSet(dataSet, bestFeat, value)
        myTree[bestFeatLabel][value] = createTree(subset, labels, featLabels)
    return myTree

## 样本集

In [116]:
def createDataSet():
    """Return the 15-row sample data set and the names of its four features.

    Each row is ``[f0, f1, f2, f3, class]`` where the features are named by
    the returned label list (in column order) and ``class`` is ``'yes'`` or
    ``'no'``. The first feature (age) is stored as a discretised code.

    Returns:
        A ``(dataSet, labels)`` tuple; fresh lists are built on every call.
    """
    featureNames = ['年龄', '有工作', '有房', '信贷情况']
    samples = (
        (0, 0, 0, 0, 'no'),
        (0, 0, 0, 1, 'no'),
        (0, 1, 0, 0, 'yes'),
        (0, 1, 1, 0, 'yes'),
        (0, 0, 0, 0, 'no'),
        (1, 0, 0, 0, 'no'),
        (1, 0, 0, 1, 'no'),
        (1, 1, 1, 1, 'yes'),
        (1, 0, 1, 2, 'yes'),
        (1, 0, 1, 2, 'yes'),
        (2, 0, 1, 2, 'yes'),
        (2, 0, 1, 1, 'yes'),
        (2, 1, 0, 1, 'yes'),
        (2, 1, 0, 2, 'yes'),
        (2, 0, 0, 0, 'no'),
    )
    # Hand out mutable rows, as callers receive plain lists.
    return [list(sample) for sample in samples], featureNames

In [131]:
def classify(inputTree, featLabels, testVec):
    """Classify ``testVec`` by walking the nested-dict decision tree.

    Args:
        inputTree: tree of the form
            ``{feature_name: {feature_value: subtree_or_label}}``.
        featLabels: list of feature names; the position of a name in this
            list is the position of that feature's value in ``testVec``.
        testVec: feature values of the sample to classify.

    Returns:
        The class label stored at the leaf that ``testVec`` reaches.

    Raises:
        ValueError: if the tree's feature name is missing from ``featLabels``
            (raised by ``list.index``), or if the sample's value for a node's
            feature has no matching branch in the tree.
    """
    firstStr = next(iter(inputTree))  # feature tested at this node
    secondDict = inputTree[firstStr]  # branches keyed by feature value
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]

    for key, branch in secondDict.items():
        if featValue == key:
            # A dict is a subtree to descend into; anything else is a leaf.
            if isinstance(branch, dict):
                return classify(branch, featLabels, testVec)
            return branch

    # Previously this fell through to an unbound local; fail with a message
    # that names the offending feature and value instead.
    raise ValueError(
        "no branch for value %r of feature %r" % (featValue, firstStr)
    )

## 主函数

In [142]:
# Build the sample data, report the best first split, then grow and show the tree.
dataSet, labels = createDataSet()
print("最优特征索引值：%s" % chooseBestFeatureToSplit(dataSet))
featLabels = []
myTree = createTree(dataSet, labels, featLabels)
print(myTree)

# Classify one sample. Values in testVec are positioned by featLabels (the
# order in which features were selected), not by the original column order.
testVec = [0, 1]  # no house, has a job
result = classify(myTree, featLabels, testVec)
print("放贷" if result == 'yes' else "不放贷")

最优特征索引值：2
{'有房': {0: {'有工作': {0: 'no', 1: 'yes'}}, 1: 'yes'}}
放贷
