In [1]:
from __future__ import print_function
print(__doc__)
import operator
from math import log
from collections import Counter

from matplotlib import pyplot as dtPlot

Automatically created module for IPython interactive environment


In [2]:
def createDataSet():
    dataSet = [(1, 1, 'yes'),
               (1, 1, 'yes'),
               (1, 0, 'no'),
               (0, 1, 'no'),
               (0, 1, 'no')]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

In [3]:
def calcShannonEnt(dataSet):
    label_count = Counter(data[-1] for data in dataSet)
    probs = [p[1] / len(dataSet) for p in label_count.items()]
    shannonEnt = sum([-p * log(p, 2) for p in probs])
    return shannonEnt


In [4]:
def splitDataSet(dataSet, index, value):
    retDataSet = []
    for featVec in dataSet: 
        if featVec[index] == value:
            reducedFeatVec = [featVec[:index]]
            reducedFeatVec.extend(featVec[index+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


In [5]:
def chooseBestFeatureToSplit(dataSet):
     base_entropy = calcShannonEnt(dataSet)
     best_info_gain = 0
     best_feature = -1
     for i in range(len(dataSet[0]) - 1):
         feature_count = Counter([data[i] for data in dataSet])
         new_entropy = sum(feature[1] / float(len(dataSet)) * calcShannonEnt(splitDataSet(dataSet, i, feature[0])) \
                        for feature in feature_count.items())
  
         info_gain = base_entropy - new_entropy
         print('No. {0} feature info gain is {1:.3f}'.format(i, info_gain))
         if info_gain > best_info_gain:
             best_info_gain = info_gain
             best_feature = i
     return best_feature
   

In [6]:

def majorityCnt(classList):
     major_label = Counter(classList).most_common(1)[0]
     return major_label

In [7]:
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
 
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

   
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
        
    for value in uniqueVals:
        del_bestFeat = bestFeat
        del_labels = labels[bestFeat]
        subLabels = labels[:]
        del (labels[bestFeat])
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        labels.insert(del_bestFeat, del_labels)
    return myTree

In [8]:
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    print('+++', firstStr, 'xxx', secondDict, '---', key, '>>>', valueOfFeat)
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat
    return classLabel

In [9]:
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

In [10]:
def grabTree(filename):
    import pickle
    fr = open(filename,'rb')
    return pickle.load(fr)

In [11]:
def get_tree_height(tree):
    if not isinstance(tree, dict):
        return 1

    child_trees =list( tree.values())[0].values()

    max_height = 0
    for child_tree in child_trees:
        child_tree_height = get_tree_height(child_tree)

        if child_tree_height > max_height:
            max_height = child_tree_height

    return max_height + 1

In [12]:

def fishTest():
   
    myDat, labels = createDataSet()

    import copy
    myTree = createTree(myDat, copy.deepcopy(labels))
    print(myTree)
    print(classify(myTree, labels, [1, 1]))
    print(get_tree_height(myTree))


In [13]:
if __name__ == '__main__':
   fishTest()

No. 0 feature info gain is 0.420
No. 1 feature info gain is 0.171
No. 0 feature info gain is 0.000
No. 1 feature info gain is 0.918
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
+++ no surfacing xxx {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}} --- 1 >>> {'flippers': {0: 'no', 1: 'yes'}}
+++ flippers xxx {0: 'no', 1: 'yes'} --- 1 >>> yes
yes
3
