# Decision Tree

In [2]:
!ls ~/aistudio/data

ls: /Users/afirez/aistudio/data: No such file or directory


In [3]:
!ls ./work

In [4]:
!which pip3

/usr/local/anaconda3/envs/py38/bin/pip3


In [5]:
# To make an installation persist across sessions, install the package
# into the persistent directory, as in the example below:
!mkdir -p ~/aistudio/external-libraries
!pip3 install beautifulsoup4 -t ~/aistudio/external-libraries

Collecting beautifulsoup4
  Using cached beautifulsoup4-4.10.0-py3-none-any.whl (97 kB)
Collecting soupsieve>1.2
  Using cached soupsieve-2.3.1-py3-none-any.whl (37 kB)
Installing collected packages: soupsieve, beautifulsoup4
Successfully installed beautifulsoup4-4.10.0 soupsieve-2.3.1


In [6]:
# Run this cell every time the environment (kernel) starts so that packages
# installed into the persistent directory can be imported.
import os
import sys

# sys.path entries are used verbatim: a literal '~/...' entry is never
# resolved to the home directory, so the tilde must be expanded explicitly.
_external_libs = os.path.expanduser('~/aistudio/external-libraries')
# Guard against duplicate entries when the cell is re-run in the same kernel.
if _external_libs not in sys.path:
    sys.path.append(_external_libs)

In [7]:
import operator
from math import log

In [8]:
def createDataSet():
    """Return the toy dataset and its feature names.

    Returns:
        tuple: ``(dataSet, labels)`` where ``dataSet`` is a list of rows of the
        form ``[feature_0, feature_1, class_label]`` and ``labels`` holds the
        names of the two feature columns, in column order.
    """
    labels = ['no surfacing', 'flippers']
    rows = (
        (1, 1, 'yes'),
        (1, 1, 'yes'),
        (1, 0, 'no'),
        (0, 1, 'no'),
        (0, 1, 'no'),
    )
    # Hand back fresh, mutable lists on every call.
    dataSet = [list(row) for row in rows]
    return dataSet, labels

In [10]:
def calcShannonEnt(dataSet):
    """Compute the base-2 Shannon entropy of the class labels in ``dataSet``.

    Args:
        dataSet: Sequence of rows; the last element of each row is taken as
            its class label.

    Returns:
        float: ``-sum(p * log2(p))`` over the observed class frequencies;
        ``0.0`` when ``dataSet`` is empty.
    """
    # Tally how many rows carry each class label (insertion order preserved).
    tally = {}
    for row in dataSet:
        cls = row[-1]
        tally[cls] = tally.get(cls, 0) + 1

    total = len(dataSet)
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy


def splitDataSet(dataSet, index, value):
    """Select the rows whose column ``index`` equals ``value`` and drop that column.

    Args:
        dataSet: List of rows (each row a list).
        index: Position of the feature column to test and remove.
        value: Feature value a row must have at ``index`` to be kept.

    Returns:
        list: New rows, in the original order, consisting of the elements
        before ``index`` followed by the elements after it. The input rows
        are not modified.
    """
    return [
        row[:index] + row[index + 1:]
        for row in dataSet
        if row[index] == value
    ]



def chooseBestFeatureToSplit(dataSet):
    """Pick the feature column with the largest information gain.

    For every feature column the dataset is partitioned by that column's
    distinct values; the gain is the entropy of the whole dataset minus the
    size-weighted entropy of the partitions.

    Args:
        dataSet: Non-empty list of rows whose last element is the class label.

    Returns:
        int: Index of the first column with the highest strictly positive
        gain, or ``-1`` when no column has a positive gain.
    """
    featureCount = len(dataSet[0]) - 1
    parentEntropy = calcShannonEnt(dataSet)
    rowCount = float(len(dataSet))

    winner, winnerGain = -1, 0.0
    for col in range(featureCount):
        weightedEntropy = 0.0
        for val in set([row[col] for row in dataSet]):
            part = splitDataSet(dataSet, col, val)
            weight = len(part) / rowCount
            weightedEntropy += weight * calcShannonEnt(part)
        gain = parentEntropy - weightedEntropy
        # Strict comparison: ties keep the earlier column.
        if gain > winnerGain:
            winner, winnerGain = col, gain
    return winner


def majorityCnt(classList):
    """Return the most frequent class label in ``classList``.

    Args:
        classList: Non-empty sequence of hashable class labels.

    Returns:
        The label with the highest count; on a tie, the label that appeared
        first in ``classList``.

    Raises:
        IndexError: If ``classList`` is empty.
    """
    votes = {}
    for label in classList:
        votes[label] = votes.get(label, 0) + 1
    # A stable descending sort keeps first-seen order among equal counts.
    ranked = sorted(votes.items(), key=lambda pair: pair[1], reverse=True)
    topLabel, _ = ranked[0]
    return topLabel


def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dictionaries.

    At each node the feature chosen by ``chooseBestFeatureToSplit`` (highest
    information gain) is used to partition the rows, and a subtree is built
    for every distinct value of that feature.

    Args:
        dataSet: Non-empty list of rows; the last element of each row is the
            class label and the preceding elements are feature values.
        labels: Feature names aligned with the feature columns of
            ``dataSet``. The list is not modified.

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{feature_name: {feature_value: subtree_or_label, ...}}``.

    Raises:
        IndexError: If ``dataSet`` is empty.
    """
    classList = [example[-1] for example in dataSet]
    # Every row has the same class: this node is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left: no feature to split on, take a vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat == -1:
        # No feature yields a positive information gain. Using -1 as a column
        # index would split on the class-label column itself and produce a
        # malformed tree, so fall back to the majority class.
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]  # name of the feature chosen for this node
    # Build a new list instead of deleting from the caller's list.
    remainingLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    myTree = {bestFeatLabel: {}}
    uniqueVals = set([example[bestFeat] for example in dataSet])
    for featValue in uniqueVals:
        # Recurse on the rows that take this value, with the column removed.
        subset = splitDataSet(dataSet, bestFeat, featValue)
        myTree[bestFeatLabel][featValue] = createTree(subset, remainingLabels)
    return myTree

In [11]:
def classify(inputTree, featLabels, testVec):
    """Classify one sample by walking a nested-dict decision tree.

    Args:
        inputTree: Tree of the form ``{feature_name: {feature_value: subtree}}``
            or a bare class label (a tree built from a single-class dataset is
            just that label).
        featLabels: Feature names, in the same order as the values in
            ``testVec``.
        testVec: Feature values of the sample to classify.

    Returns:
        The class label stored at the leaf reached by ``testVec``.

    Raises:
        ValueError: If a feature name in the tree is not in ``featLabels``.
        KeyError: If ``testVec`` has a feature value with no branch in the tree.
    """
    node = inputTree
    # Descend until a non-dict value (a leaf label) is reached. Checking the
    # type up front also handles a tree that is a bare leaf label.
    while isinstance(node, dict):
        featName = list(node)[0]  # each internal node has a single feature key
        featIndex = featLabels.index(featName)
        node = node[featName][testVec[featIndex]]
    return node


def fishTest():
    """Train on the toy dataset and print the predicted class for ``[1, 1]``."""
    samples, featureNames = createDataSet()
    # Give createTree its own copy of the names so the full, ordered list is
    # still available to classify afterwards.
    tree = createTree(samples, list(featureNames))
    prediction = classify(tree, featureNames, [1, 1])
    print(prediction)

In [12]:
# Run the demo when this cell/script executes as the main module.
if __name__ == "__main__":
    fishTest()

yes
