# 第三章                          决策树

## 3.1.1 信息增益 

In [4]:
from numpy import *
import operator

####  计算香农熵

In [5]:
d ={}
d.get('dd')

In [6]:
from math import log
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featvec in dataSet:
        currentlabel = featvec[-1]
        if currentlabel not in labelCounts.keys():
            labelCounts[currentlabel] = 0
        labelCounts[currentlabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -=prob * log(prob ,2)
    return shannonEnt

In [7]:
def createDataSet():
    dataSet = [[1,1,'yes'],
               [1,1,'yes'],
               [1,0,'no'],
               [0,1,'no'],
               [0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet,labels

In [8]:
myDat,labels = createDataSet()

##  3.1.2 划分数据集

In [9]:
def splitdataset(dataset,axis,value):
    retdataset = []
    for featvec in dataset:
        if featvec[axis] == value:
            reducefeat = featvec[:axis]
            reducefeat.extend(featvec[axis+1:])
            retdataset.append(reducefeat)
    return retdataset

#### 选择最好的数据集划分方式

In [10]:
def chooseBestFeatureToSplit(dataset):
    numfeatues = len(dataset[0])-1
    baseEntropy = calcShannonEnt(dataset)
    bestInfogain = 0.0;bestfeat = -1
    for i in range(numfeatues):
        featList = [example[i] for example in dataset]
        uniquevals = set(featList)
        newEntropy = 0.0
        for value in uniquevals:
            subdataset = splitdataset(dataset,i,value)
            prob = len(subdataset)/float(len(dataset))
            newEntropy += prob * calcShannonEnt(subdataset)
        infogain = baseEntropy - newEntropy
        if infogain > bestInfogain:
            bestInfogain = infogain
            bestfeat = i
        return bestfeat

In [11]:
mydat,label = createDataSet()

In [12]:
splitdataset(mydat,0,0)

[[1, 'no'], [1, 'no']]

In [13]:
chooseBestFeatureToSplit(mydat)

0

##  3.1.3 递归构建决策树

In [14]:
def majorityCnt(classlist):
    classCount = {}
    for vote in classlist:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount += 1
    sortclassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
    return sortclassCount[0][0]

In [15]:
def createTree(dataset,labels):
    classList = [example[-1] for example in dataset]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataset)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featureValues = [example[bestFeat] for example in dataset]
    uniqueVals =set(featureValues)
    for value in uniqueVals:
        sublabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitdataset(dataset,bestFeat,value),sublabels)
    return myTree

In [16]:
myTree = createTree(mydat,label)

In [17]:
myTree

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

#  在Python中使用Matplotlib注解回执树形图

In [18]:
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle = 'sawtooth',fc ='0.8')
leafNode = dict(boxstyle= "round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,
                            xycoords='axes fraction',
                            xytext=centerPt,
                            textcoords='axes fraction',
                            va='center',
                            ha='center',
                            bbox=nodeType,
                            
                            arrowprops=arrow_args)
def createPlot(inTree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axprops = dict(xticks=[],yticks=[])

    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW = float(getNumleafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xoff = -0.5/plotTree.totalW;plotTree.yoff = 1.0
    plotTree(inTree,(0.5,1.0),'')
    plot.show()

In [32]:
def getNumleafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    seconddict = myTree[firstStr]
    for key in seconddict.keys():
        if type(seconddict[key]).__name__ == 'dict':
            numLeafs += getNumleafs(seconddict[key])
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    seconddeict = myTree[firstStr]
    for key in seconddeict.keys():
        if type(seconddeict[key]).__name__ == 'dict':
            thisDepth = 1+getTreeDepth(seconddeict[key])
        else:
            thisDepth = 1
    if thisDepth >maxDepth:
        maxDepth = thisDepth
    return maxDepth

In [33]:
getNumleafs(myTree)

3

In [34]:
getTreeDepth(myTree)

2

In [35]:
def plotMidText(cntPt, parentPt, txtString):
    xMid = (parentPt[0]-cntPt[0])/2.0+cntPt[0]
    yMid = (parentPt[1]-cntPt[1])/2.0+cntPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumleafs(myTree)
    depth = getTreeDepth(myTree)
    firststr = list(myTree.keys())[0]
    cntrPt = (plotTree.xoff + float(numLeafs)/2.0/plotTree.totalW,plotTree.yoff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firststr,cntrPt,parentPt,nodeTxt)
    secondDict = myTree[firststr]
    plotTree.yoff = plotTree.yoff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xoff = plotTree.xoff +1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xoff,plotTree.yoff),cntrPt,leafNode)
            plotMidText((plotTree.xoff,plotTree.yoff),cntrPt,str(key))
    plotTree.yoff = plotTree.yoff + 1.0/plotTree.totalD

In [36]:
def classify(inputTree,featLabels,testVect):
    firstStr = list(inputTree.keys())[0]
    seconddict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in seconddict.keys():
        if testVect[featIndex] == key:
            if type(seconddict[key]).__name__ == 'dict':
                classLabel = classify(seconddict[key],featLabels,testVect)
            else:
                classLabel = seconddict[key]
    return classLabel

In [37]:
myTree

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

In [38]:
classify(myTree,labels,[1,1])

'yes'

In [41]:
labels

['no surfacing', 'flippers']

In [42]:
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lenseLabel = ['age','prescript','astigmatic','tearRate']
lenseTree = createTree(lenses,lenseLabel)

In [43]:
lenseTree

{'age': {'pre': {'prescript': {'hyper': {'astigmatic': {'no': {'tearRate': {'normal': 'soft',
        'reduced': 'no lenses'}},
      'yes': 'no lenses'}},
    'myope': {'astigmatic': {'no': {'tearRate': {'normal': 'soft',
        'reduced': 'no lenses'}},
      'yes': {'tearRate': {'normal': 'hard', 'reduced': 'no lenses'}}}}}},
  'presbyopic': {'prescript': {'hyper': {'astigmatic': {'no': {'tearRate': {'normal': 'soft',
        'reduced': 'no lenses'}},
      'yes': 'no lenses'}},
    'myope': {'astigmatic': {'no': 'no lenses',
      'yes': {'tearRate': {'normal': 'hard', 'reduced': 'no lenses'}}}}}},
  'young': {'tearRate': {'hard': 'hard',
    'no lenses': 'no lenses',
    'soft': 'soft'}}}}

In [44]:
createPlot(lenseTree)

AttributeError: 'str' object has no attribute 'copy'