<a href="https://colab.research.google.com/github/yeswhos/MSc-Project/blob/master/Control%20System/Decision%20Tree/Decision_Tree_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
def createDataSet():
    """Build the small toy data set used throughout this notebook.

    Returns:
        A ``(dataSet, featNames)`` pair. ``dataSet`` is a list of records,
        each ``[feature0, feature1, label]``; ``featNames`` holds the names
        of the two feature columns, in column order. Fresh lists are
        created on every call.
    """
    featNames = ['no surfacing', 'flippers']
    records = (
        (1, 1, 'YES'),
        (1, 1, 'YES'),
        (1, 0, 'NO'),
        (0, 1, 'NO'),
        (0, 1, 'NO'),
    )
    # Materialise each record as its own list so callers get mutable rows.
    dataSet = [list(record) for record in records]
    return dataSet, featNames


In [2]:
from math import log
 
def calcShannonEnt(dataSet):
    """Return the Shannon entropy (base 2) of the labels in ``dataSet``.

    The label of each record is its last element. An empty ``dataSet``
    yields 0.0 because there are no labels to sum over.
    """
    total = len(dataSet)
    tally = {}
    for record in dataSet:
        label = record[-1]
        tally[label] = tally.get(label, 0) + 1
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy


In [3]:
def splitDataSet(dataSet, axis, value):
    """Select the records whose column ``axis`` equals ``value``.

    Each selected record is returned as a new list with column ``axis``
    removed; the records in ``dataSet`` are left untouched.
    """
    def dropAxis(record):
        # Everything before the split column, then everything after it.
        kept = record[:axis]
        kept.extend(record[axis + 1:])
        return kept

    return [dropAxis(record) for record in dataSet if record[axis] == value]


In [7]:

# Demonstrate splitDataSet on the sample data: split on feature 0 for each of its values.
dataSet, feats = createDataSet()
print(splitDataSet(dataSet, 0, 1))  # split on feature 0, keeping records whose value is 1
print(splitDataSet(dataSet, 0, 0))  # split on feature 0, keeping records whose value is 0

[[1, 'YES'], [1, 'YES'], [0, 'NO']]
[[1, 'NO'], [1, 'NO']]


In [8]:
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    Every column except the last (the label) is treated as a feature.
    For each feature the data set is partitioned by that feature's
    distinct values and the weighted entropy of the partitions is
    subtracted from the entropy of the whole set.

    Returns:
        The winning feature index, or -1 when no feature has a strictly
        positive information gain.
    """
    featureCount = len(dataSet[0]) - 1  # last column is the label
    entropyBefore = calcShannonEnt(dataSet)
    topGain, topIndex = 0.0, -1
    for idx in range(featureCount):
        column = [record[idx] for record in dataSet]
        entropyAfter = 0.0
        # One partition per distinct value of this feature.
        for val in set(column):
            partition = splitDataSet(dataSet, idx, val)
            weight = len(partition) / float(len(dataSet))
            entropyAfter += weight * calcShannonEnt(partition)
        gain = entropyBefore - entropyAfter
        if gain > topGain:
            topGain, topIndex = gain, idx
    return topIndex


In [10]:

# Demonstrate chooseBestFeatureToSplit: print the index of the best feature for the sample data.
dataSet, feats = createDataSet()
print(chooseBestFeatureToSplit(dataSet))

0


In [11]:
import operator
def majorityCnt(classList):
    """Return the label that occurs most often in ``classList``.

    When several labels are tied for the highest count, the one that
    appears first in ``classList`` is returned.

    Raises:
        ValueError: if ``classList`` is empty.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # dict.iteritems() was removed in Python 3; items() is the replacement.
    # max() returns the first maximal pair, so ties go to the earliest label.
    return max(classCount.items(), key=operator.itemgetter(1))[0]


In [12]:

def createTree(dataSet, featNames):
    """Recursively build a decision tree from ``dataSet``.

    Args:
        dataSet: non-empty list of records; the last element of each
            record is its label, the preceding elements are feature values.
        featNames: names of the feature columns, in column order. The
            list is not modified.

    Returns:
        A label when the node is a leaf, otherwise a nested dict of the
        form ``{featureName: {featureValue: subtree_or_label, ...}}``.
    """
    classList = [data[-1] for data in dataSet]  # labels of the current data set
    # All labels agree: this node is a pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the label column is left: no feature to split on, so take a vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # A negative index means no feature gives any information gain. Using it
    # as an index would address the label column, so fall back to a vote.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatName = featNames[bestFeat]
    # Build the reduced name list as a new list so the caller's featNames
    # is left intact.
    remainingNames = featNames[:bestFeat] + featNames[bestFeat + 1:]
    myTree = {bestFeatName: {}}
    uniqueVals = set(data[bestFeat] for data in dataSet)
    for value in uniqueVals:
        # Subtree for the records whose best feature equals this value.
        myTree[bestFeatName][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), remainingNames)
    return myTree


In [14]:
# Demonstrate createTree: build and print the decision tree for the sample data.
dataSet, feats = createDataSet()
theTree = createTree(dataSet, feats)
print(theTree)

{'no surfacing': {0: 'NO', 1: {'flippers': {0: 'NO', 1: 'YES'}}}}


In [28]:

def classify(inputTree,featNames,testVec):
    """Classify ``testVec`` by walking ``inputTree`` from the root to a leaf.

    Args:
        inputTree: nested dict ``{featureName: {featureValue: subtree_or_label}}``.
        featNames: feature names, used to look up each feature's position
            in ``testVec``.
        testVec: feature values of the record to classify.

    Returns:
        The label stored at the leaf that ``testVec`` reaches.
    """
    node = inputTree
    while True:
        # keys() is a view in Python 3, so convert it to a list before indexing.
        featName = list(node.keys())[0]  # feature tested at this node
        branches = node[featName]  # children keyed by feature value
        position = featNames.index(featName)
        child = branches[testVec[position]]
        if not isinstance(child, dict):
            return child  # reached a leaf: this is the label
        node = child  # descend into the matching subtree


In [29]:
# Demonstrate classify: build the tree, then classify the record [1, 0].
dataSet, feats = createDataSet()
feats_copy = feats[:]  # keep a separate copy of the feature names for classify(), since createTree may modify the list it is given
theTree = createTree(dataSet, feats)
print(feats_copy)
print(theTree)
print(classify(theTree, feats_copy, [1,0]))

['no surfacing', 'flippers']
{'no surfacing': {0: 'NO', 1: {'flippers': {0: 'NO', 1: 'YES'}}}}
NO


In [30]:
def storeTree(inputTree, filename):
    """Serialise ``inputTree`` to ``filename`` with pickle.

    Any existing file at ``filename`` is overwritten.
    """
    import pickle
    # pickle writes bytes, so the file must be opened in binary mode;
    # the with-block closes it even if dump() raises.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load and return the object pickled in ``filename``.

    NOTE(review): pickle.load can execute arbitrary code while
    unpickling; only use this on files that come from a trusted source.
    """
    import pickle
    # pickle data is binary, so read in binary mode; the with-block
    # closes the file once the object has been loaded.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

In [None]:
# Round-trip the tree through a file and print what was read back.
storeTree(theTree, 'storage.txt')
print(grabTree('storage.txt'))  # print is a function in Python 3