### 实验要求：实现决策树ID3算法，该算法在特征选择时使用的是信息增益。不要求对决策树进行剪枝。

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu May  3 20:03:55 2018

@author: ronething
"""

from math import log
import operator


In [None]:
"""
创建一个简单的数据集。这个数据集根据两个属性来判断一个海洋生物是否属于鱼类，
第一个属性是不浮出水面是否可以生存，第二个属性是是否有鳍。数据集中的第三列是分类结果。
"""
def createDataSet():
    """Return a small toy data set together with its feature names.

    Each row describes a sea creature by two binary features (encoded as
    1/0) followed by its class label:

    * column 0 -- ``'no surfacing'``: can it survive without surfacing?
    * column 1 -- ``'flippers'``: does it have flippers?
    * column 2 -- class label (``'yes'`` / ``'no'``): is it a fish?

    Returns:
        tuple: ``(dataSet, labels)`` where ``dataSet`` is a list of five
        ``[int, int, str]`` rows and ``labels`` holds the names of the two
        feature columns.  New lists are built on every call, so callers may
        mutate them freely.
    """
    featureNames = ['no surfacing', 'flippers']
    rows = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return rows, featureNames

In [None]:
# Shannon entropy of the class column.
def calcEntropy(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in *dataSet*.

    The class label of a row is its last element.  An empty data set has
    entropy ``0.0``.

    Args:
        dataSet: sequence of rows; only ``row[-1]`` is inspected.

    Returns:
        float: ``-sum(p * log2(p))`` over the relative frequency ``p`` of each
        distinct label.
    """
    total = len(dataSet)
    counts = {}
    for row in dataSet:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1

    result = 0.0
    for count in counts.values():
        share = float(count) / total
        result -= share * log(share, 2)
    return result

In [None]:
# Partition the data set on one feature value.
def splitDataSet(dataSet, axis, value):
    """Return the rows of *dataSet* whose column *axis* equals *value*.

    The tested column is dropped from every returned row, so each result row
    is one element shorter than its source row.  New row lists are created;
    the input rows are left unchanged.

    Args:
        dataSet: list of rows, each row a list.
        axis: index of the feature column to test and remove.
        value: feature value a row must have in order to be kept.

    Returns:
        list: reduced copies of the matching rows, in their original order.
    """
    return [row[:axis] + row[axis + 1:]
            for row in dataSet
            if row[axis] == value]

In [None]:
# Feature selection for ID3: try every feature, split on it with
# splitDataSet() and keep the one with the largest information gain.
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    ID3 criterion: for feature ``A`` the information gain is
    ``g(D, A) = H(D) - sum_v |D_v| / |D| * H(D_v)``, where ``D_v`` is the set
    of rows whose value of ``A`` is ``v`` and ``H`` is calcEntropy().

    Ties are broken in favour of the lowest feature index.  When no feature
    reduces the entropy (every gain is zero) the first feature, index 0, is
    returned.  Returning ``-1`` in that case would be harmful: callers use the
    result as an index, and ``-1`` selects the class-label column.

    Args:
        dataSet: list of rows; every column except the last is a feature,
            and the last column is the class label.

    Returns:
        int: index of the chosen feature, or ``-1`` if *dataSet* is empty or
        has no feature columns.
    """
    if not dataSet:
        return -1
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    numRows = float(len(dataSet))
    baseEntropy = calcEntropy(dataSet)

    # Start below any achievable gain so the first feature is always a
    # candidate, even when all gains are exactly zero.
    bestInfoGain = float('-inf')
    bestFeature = -1
    for i in range(numFeatures):
        # Conditional entropy H(D | A_i): weighted entropy of each branch.
        conditionalEntropy = 0.0
        for value in set(row[i] for row in dataSet):
            subDataSet = splitDataSet(dataSet, i, value)
            weight = len(subDataSet) / numRows
            conditionalEntropy += weight * calcEntropy(subDataSet)

        infoGain = baseEntropy - conditionalEntropy
        if infoGain > bestInfoGain:  # strict: earlier feature wins ties
            bestInfoGain = infoGain
            bestFeature = i

    return bestFeature

In [None]:
"""
决策树创建过程中会采用递归的原则处理数据集。递归的终止条件为：程序遍历完所有划分数据集的属性；
或者每一个分支下的所有实例都具有相同的分类。如果数据集已经处理了所有属性，
但是类标签依然不是唯一的，此时我们需要决定如何定义该叶子节点，
在这种情况下，通常会采用多数表决的方法决定分类
"""
def majorityCnt(classList):
    """Return the most frequent label in *classList* (majority vote).

    Used to label a leaf whose rows still carry mixed classes after every
    feature has been consumed.  Ties go to the label that appears first in
    *classList*, because the descending sort below is stable.

    Args:
        classList: sequence of hashable class labels.

    Returns:
        The label with the highest count.

    Raises:
        IndexError: if *classList* is empty.
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    ranked = sorted(tally.items(), key=lambda pair: pair[1], reverse=True)
    winner, _count = ranked[0]
    return winner

In [None]:
# Build the ID3 decision tree recursively.
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Recursion stops when every row of a branch has the same class, or when
    only the class column is left.  In the second case the leaf is labelled
    by majority vote.

    Args:
        dataSet: non-empty list of rows; the last column of every row is the
            class label and the other columns are features.
        labels: feature names, one per feature column of *dataSet*.  The list
            is NOT modified; each recursive call receives a trimmed copy.

    Returns:
        Either a class label (a leaf) or a nested dict
        ``{featureName: {featureValue: subtree, ...}}``.
    """
    classList = [row[-1] for row in dataSet]
    # Stop 1: every row in this branch has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop 2: all features used up (only the class column remains).  Vote
    # over the class labels -- not over the rows, which are unhashable lists.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat < 0:
        # chooseBestFeatureToSplit() can return -1 (no feature chosen);
        # used as an index it would select the class column, so fall back
        # to a majority vote instead.
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # Labels for the sub-tables, which no longer contain column bestFeat.
    # Built as a new list so the caller's *labels* stays intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    myTree = {bestFeatLabel: {}}
    for value in set(row[bestFeat] for row in dataSet):
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

In [None]:
# Classify a sample with a tree produced by createTree().
def classify(inputTree, featLabels, testVec):
    """Return the class label that *inputTree* predicts for *testVec*.

    Starting at the root, read the feature name stored at the current node,
    look up that feature's value in *testVec* (its position comes from
    *featLabels*), and follow the matching branch.  Repeat until a leaf (any
    non-dict value) is reached.  A tree that is itself a single leaf label
    is returned unchanged.

    Args:
        inputTree: tree returned by createTree() -- either a leaf label or a
            dict ``{featureName: {featureValue: subtree, ...}}``.
        featLabels: the complete list of feature names, in the column order
            of *testVec*.
        testVec: feature values of the sample to classify.

    Returns:
        The class label stored at the reached leaf.

    Raises:
        KeyError: if *testVec* holds a feature value for which the tree has
            no branch (a value unseen during training at that node).
        ValueError: if a feature name in the tree is missing from
            *featLabels* (raised by ``list.index``).
    """
    node = inputTree
    while isinstance(node, dict):
        featName = next(iter(node))  # each internal node has a single key
        branches = node[featName]
        featValue = testVec[featLabels.index(featName)]
        try:
            node = branches[featValue]
        except KeyError:
            raise KeyError(
                "no branch for value %r of feature %r" % (featValue, featName)
            ) from None
    return node

In [3]:
# Smoke test: build a tree from the toy data set and classify two samples.
if __name__=="__main__":
    data,labels = createDataSet()
    # Give createTree() a copy of the feature names.  Python passes the list
    # object itself (no implicit copy), and createTree() may delete the chosen
    # feature's name from the list it receives.  That is why 'no surfacing'
    # vanished from `labels` when it was passed in directly.  classify() below
    # needs the complete, untouched list to map feature names to columns.
    templabels = labels[:]
    myTree = createTree(data,templabels)
    print(classify(myTree,labels,[1,0]))
    print(classify(myTree,labels,[1,1]))

no
yes
