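The linear regression introduced in the previous chapter fits a single model to all of the samples. When the data has many features with complex relationships among them, building one global model becomes very difficult. Moreover, many real-world problems are nonlinear, so a global linear model cannot fit all of the data.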

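One workable approach is to split the dataset into many subsets that are each easy to model, and then apply linear regression to each subset.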

This chapter introduces a tree-building algorithm called CART (Classification And Regression Trees), which can be used for both classification and regression.

## 1. Locally Modeling Complex Data

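Decision tree construction is greedy: it makes the best choice available at each step and does not care whether the result is globally optimal.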

> Pros: can model complex and nonlinear data

> Cons: the results are not easy to interpret

> Applicable data types: numeric and nominal

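Unlike ID3, CART uses binary splits: samples whose feature value is greater than a threshold go to one subtree and the rest go to the other. Binary splits handle continuous features directly, so with small modifications the algorithm can be applied to regression, and they also shorten tree construction time.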

## 2. Building Trees with Continuous and Discrete Features

In [49]:
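# Define a tree node
class treeNode():
    def __init__(self, feat, val, left, right):
        # Store the values on self; plain local assignments would be discarded
        self.featureToSplitOn = feat
        self.valueOfSplit = val
        self.leftBranch = left
        self.rightBranch = right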

In [50]:
import numpy as np

In [51]:
def loadDataSet(fileName):
    dataMat = []
    with open(fileName) as f:
        for line in f.readlines():
            curLine = line.strip().split('\t')
            # Convert each str field to float
            fltLine = list(map(float, curLine))
            dataMat.append(fltLine)
    return dataMat

In [52]:
def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

In [11]:
testMat = np.mat(np.eye(4))
testMat

matrix([[ 1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  1.]])

In [22]:
mat0, mat1 = binSplitDataSet(testMat, 1, 0.5)
mat0

matrix([[ 0.,  1.,  0.,  0.]])

In [23]:
mat1

matrix([[ 1.,  0.,  0.,  0.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  1.]])
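
The same split can be written with plain `ndarray` boolean masks instead of `np.matrix` and `np.nonzero`. The following is a minimal sketch, not the notebook's own code; the name `binSplitArray` is hypothetical, and it assumes the input is a 2-D array.

```python
import numpy as np

def binSplitArray(data, feature, value):
    """Split the rows of a 2-D ndarray on one feature using a boolean mask."""
    mask = data[:, feature] > value      # True for rows that go to the first subset
    return data[mask], data[~mask]       # rows above the threshold, then the rest

demo = np.eye(4)
upper, lower = binSplitArray(demo, 1, 0.5)
print(upper)   # rows whose feature 1 is greater than 0.5
print(lower)   # all remaining rows
```

On the 4 x 4 identity matrix this yields the same partition as the `binSplitDataSet` call above: one row in the first subset and three in the second.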

## 3. Using CART for Regression

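To model complex relationships in the data, we use a tree structure to partition it. To build a tree whose leaf nodes are piecewise constants, we need a way to measure how consistent the data is, that is, the disorder of a set of continuous values. First compute the mean of all the values, then compute each value's deviation from that mean, usually as an absolute value or a square. Summing the squared deviations gives the total squared error, which is the variance before averaging, that is, the variance multiplied by the number of samples.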
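
As a small worked example of this measure, the sketch below computes the total squared error directly and via the variance, and also shows the absolute-value variant. The four target values are made up purely for illustration.

```python
import numpy as np

y = np.array([1.0, 2.0, 3.0, 6.0])          # illustrative target values
mean = y.mean()                              # 3.0

totalSqErr = np.sum((y - mean) ** 2)         # 4 + 1 + 0 + 9 = 14
viaVariance = np.var(y) * len(y)             # variance 3.5 times 4 samples = 14
totalAbsErr = np.sum(np.abs(y - mean))       # 2 + 1 + 0 + 3 = 6

print(totalSqErr, viaVariance, totalAbsErr)
```

The first two quantities agree, which is why an error function for a regression tree can be written as the variance of the target column times the number of rows.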

### 3.1 Building the Tree

In [53]:
def regLeaf(dataSet):
    # Leaf value: the mean of the target column, returned as a plain scalar
    return np.mean(dataSet[:, -1])

In [54]:
def regErr(dataSet):
    m, n = dataSet.shape
    return np.var(dataSet[:, -1]) * m

In [55]:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]
    tolN = ops[1]
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1: # Exit if all target values are equal
        return None, leafType(dataSet)
    m, n = dataSet.shape
    S = errType(dataSet)
    bestS = np.inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set(dataSet[:, featIndex].T.A.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue

In [56]:
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

In [50]:
myDat = loadDataSet('ex00.txt')
myMat = np.mat(myDat)
myTree = createTree(myMat)
myTree

{'left': 1.0180967672413792,
 'right': -0.044650285714285719,
 'spInd': 0,
 'spVal': 0.48813}

In [51]:
myDat1 = loadDataSet('ex0.txt')
myDat1 = np.mat(myDat1)
myTree1 = createTree(myDat1)
myTree1

{'left': {'left': {'left': 3.9871631999999999,
   'right': 2.9836209534883724,
   'spInd': 1,
   'spVal': 0.797583},
  'right': 1.980035071428571,
  'spInd': 1,
  'spVal': 0.582002},
 'right': {'left': 1.0289583666666666,
  'right': -0.023838155555555553,
  'spInd': 1,
  'spVal': 0.197834},
 'spInd': 1,
 'spVal': 0.39435}
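
Nested dictionaries like the ones above get hard to read as the tree grows. The sketch below shows one way to summarize such a tree by counting its leaves and measuring its depth. The helper names `countLeaves` and `treeDepth` are hypothetical, and the sample tree uses made-up values that only mimic the `spInd`/`spVal`/`left`/`right` layout.

```python
def countLeaves(tree):
    """Count leaves in a nested-dict tree; anything that is not a dict is a leaf."""
    if not isinstance(tree, dict):
        return 1
    return countLeaves(tree['left']) + countLeaves(tree['right'])

def treeDepth(tree):
    """Number of split levels along the longest path from the root to a leaf."""
    if not isinstance(tree, dict):
        return 0
    return 1 + max(treeDepth(tree['left']), treeDepth(tree['right']))

# Illustrative tree (values are made up, not taken from the data files)
sample = {'spInd': 0, 'spVal': 0.5,
          'left': {'spInd': 0, 'spVal': 0.8, 'left': 4.0, 'right': 3.0},
          'right': 1.0}
print(countLeaves(sample), treeDepth(sample))  # 3 2
```

Comparing the leaf count across different `ops` settings is a quick way to see how strongly the stopping conditions shape the tree.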

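## 4. Tree Pruning

A tree with too many nodes may be overfitting the data. Overfitting can be detected with some form of cross-validation on a test set.

Reducing the complexity of a decision tree to avoid overfitting is called pruning. The early-stopping conditions in chooseBestSplit() are in fact a form of prepruning. The other form, postpruning, requires both a training set and a test set.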

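### 4.1 Prepruning

The algorithm built above is very sensitive to the input parameters tolS and tolN, and finding good values takes many rounds of trial and error. This tuning is a prepruning process, but a tedious one.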
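
One reason tolS is hard to set is that it is expressed in units of squared error, so it depends on the scale of the target variable. The sketch below illustrates this on synthetic data (an assumption for illustration, not one of the notebook's datasets): multiplying the targets by 100 multiplies the total squared error by 10,000, so a threshold that worked before would stop almost no splits afterward.

```python
import numpy as np

def totalSquaredError(y):
    """Variance times sample count, the same quantity regErr computes."""
    return np.var(y) * len(y)

rng = np.random.default_rng(0)
y = rng.normal(size=200)                 # synthetic targets

errSmall = totalSquaredError(y)
errLarge = totalSquaredError(100 * y)    # same data, targets scaled by 100
ratio = errLarge / errSmall
print(ratio)                             # close to 10000
```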

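In practice we often do not even know what kind of result we are looking for. That is exactly what machine learning is concerned with: the computer should be able to give us the overall picture.

Alternatively, the tree can be pruned against a test set. Because it does not require user-specified parameters, postpruning is the more desirable approach.

### 4.2 Postpruning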

In [57]:
def isTree(obj):
    # Subtrees are stored as dicts; leaves are plain values
    return isinstance(obj, dict)

In [58]:
def getMean(tree):
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    return (tree['left']+tree['right']) / 2

In [59]:
def prune(tree, testData):
    if testData.shape[0] == 0:
        return getMean(tree)
    if (isTree(tree['left']) or isTree(tree['right'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = np.sum(np.power(lSet[:, -1]-tree['left'], 2)) + \
                        np.sum(np.power(rSet[:, -1]-tree['right'], 2))
        treeMean = (tree['left']+tree['right'])/2
        errorMerge = np.sum(np.power(testData[:, -1]-treeMean, 2))
        if errorMerge < errorNoMerge:
            print('merging')
            return treeMean
        else:
            return tree
    else:
        return tree
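
The heart of `prune` is the comparison between the test error with the two leaves kept and the test error after collapsing them into their average. The sketch below isolates that decision for a single node. The function name `shouldMerge` and the numbers are hypothetical and chosen only so the arithmetic is easy to follow.

```python
import numpy as np

def shouldMerge(leftLeaf, rightLeaf, leftY, rightY):
    """Return True if replacing two leaves by their average lowers the test error."""
    errorNoMerge = np.sum((leftY - leftLeaf) ** 2) + np.sum((rightY - rightLeaf) ** 2)
    merged = (leftLeaf + rightLeaf) / 2.0
    allY = np.concatenate([leftY, rightY])
    errorMerge = np.sum((allY - merged) ** 2)
    return bool(errorMerge < errorNoMerge)

# Test targets sit near 5 on both sides, so the split is not supported: merge
mergeCase = shouldMerge(10.0, 0.0, np.array([4.0, 5.0]), np.array([5.0, 6.0]))
# Test targets agree with the two leaves, so the split is kept
keepCase = shouldMerge(10.0, 0.0, np.array([9.0, 11.0]), np.array([-1.0, 1.0]))
print(mergeCase, keepCase)  # True False
```

In the first case the unmerged error is 122 against 2 after merging; in the second it is 4 against 104.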

In [62]:
myDat2 = loadDataSet('ex2.txt')
myMat2 = np.mat(myDat2)
myTree2 = createTree(myMat2, ops=(0, 1))
myTree2

{'left': {'left': {'left': {'left': {'left': 86.399636999999998,
     'right': 98.648346000000004,
     'spInd': 0,
     'spVal': 0.968621},
    'right': {'left': {'left': {'left': 112.386764,
       'right': 123.559747,
       'spInd': 0,
       'spVal': 0.960398},
      'right': 135.83701300000001,
      'spInd': 0,
      'spVal': 0.958512},
     'right': {'left': {'left': 82.016541000000004,
       'right': 100.935789,
       'spInd': 0,
       'spVal': 0.954711},
      'right': 130.92648,
      'spInd': 0,
      'spVal': 0.953902},
     'spInd': 0,
     'spVal': 0.956951},
    'spInd': 0,
    'spVal': 0.965969},
   'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 100.649591,
                'right': 73.520802000000003,
                'spInd': 0,
                'spVal': 0.952377},
               'right': 105.75250800000001,
               'spInd': 0,
               'spVal': 0.949198},
              'right

In [67]:
myDatTest = loadDataSet('ex2test.txt')
myMat2Test = np.mat(myDatTest)
myTree2Pruned = prune(myTree2, myMat2Test)
myTree2Pruned

merging
... (`merging` is printed 43 times in total)


{'left': {'left': {'left': {'left': 92.523991499999994,
    'right': {'left': {'left': {'left': 112.386764,
       'right': 123.559747,
       'spInd': 0,
       'spVal': 0.960398},
      'right': 135.83701300000001,
      'spInd': 0,
      'spVal': 0.958512},
     'right': 111.2013225,
     'spInd': 0,
     'spVal': 0.956951},
    'spInd': 0,
    'spVal': 0.965969},
   'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 96.41885225,
              'right': 69.318648999999994,
              'spInd': 0,
              'spVal': 0.948822},
             'right': {'left': {'left': 110.03503850000001,
               'right': {'left': 65.548417999999998,
                'right': {'left': 115.75399400000001,
                 'right': {'left': {'left': 94.396114499999996,
                   'right': 85.005351000000005,
                   'spInd': 0,
                   'spVal': 0.912161},
                  'right': {'left': {'left': 106.81466

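## 5. Model Trees

When modeling data with a tree, besides simply setting each leaf node to a constant value, another approach is to set each leaf node to a linear function, so that the tree as a whole is piecewise linear.


One advantage of decision trees over other machine learning algorithms is that their results are easier to interpret, and better interpretability is one of the ways a model tree improves on a regression tree. Model trees can also deliver higher prediction accuracy.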

In [170]:
def linearSolve(dataSet):
    m, n = dataSet.shape
    # Design matrix: a leading column of ones for the intercept, then the features
    X = np.mat(np.ones((m, n)))
    X[:, 1:n] = dataSet[:, 0:n-1]
    Y = dataSet[:, -1]
    xTx = X.T * X
    if np.linalg.det(xTx) == 0.0:
        raise ValueError('This matrix is singular, cannot do inverse; '
                         'try increasing the second value of ops')
    # Ordinary least squares via the normal equations
    ws = xTx.I * X.T * Y
    return ws, X, Y
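
As an alternative to inverting `X.T * X` explicitly, the leaf fit can be delegated to `np.linalg.lstsq`, which also copes with rank-deficient inputs. This is a sketch of a swapped-in technique, not the notebook's method; the name `linearSolveLstsq` is hypothetical, and it works on plain ndarrays rather than `np.matrix`.

```python
import numpy as np

def linearSolveLstsq(data):
    """Fit y = w0 + w1*x1 + ... on a 2-D ndarray whose last column is the target."""
    m, n = data.shape
    X = np.ones((m, n))            # column 0 stays at 1.0 for the intercept
    X[:, 1:] = data[:, :-1]        # remaining columns hold the features
    y = data[:, -1]
    ws, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
    return ws, X, y

# Noise-free line y = 3 + 2x, used only to check the fit
x = np.linspace(0.0, 1.0, 20)
data = np.column_stack([x, 3.0 + 2.0 * x])
ws, X, y = linearSolveLstsq(data)
print(ws)   # close to [3, 2]
```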

In [61]:
def modelLeaf(dataSet):
    ws, X, Y = linearSolve(dataSet)
    return ws

In [62]:
def modelErr(dataSet):
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return np.sum(np.power(Y-yHat, 2))

In [76]:
myRegTree = createTree(myMat2, modelLeaf, modelErr, (1, 10))
myRegTree

{'left': {'left': {'left': matrix([[ 43180.40883751],
           [ 41233.6722789 ]]), 'right': matrix([[ 40685.1778718],
           [ 36707.246158 ]]), 'spInd': 0, 'spVal': 0.934853},
  'right': {'left': matrix([[ 45468.09210422],
           [ 37942.95526797]]), 'right': matrix([[ 49960.0410645 ],
           [ 37567.45600755]]), 'spInd': 0, 'spVal': 0.798198},
  'spInd': 0,
  'spVal': 0.872199},
 'right': {'left': {'left': matrix([[ 39499.77774367],
           [ 26281.14622539]]), 'right': matrix([[ 41822.03111756],
           [ 24545.31742533]]), 'spInd': 0, 'spVal': 0.628061},
  'right': {'left': {'left': matrix([[ 11038.79915287],
            [  5723.22062811]]), 'right': matrix([[-409.61642917],
            [-189.54785021]]), 'spInd': 0, 'spVal': 0.487381},
   'right': {'left': {'left': matrix([[-656.61820174],
             [-254.8836606 ]]), 'right': matrix([[-1985.13428453],
             [ -640.08274006]]), 'spInd': 0, 'spVal': 0.342761},
    'right': {'left': matrix([[-942.51893

## 6. Example: Comparing Tree Regression with Standard Regression

In [63]:
def regTreeEval(model, inDat):
    return float(model)

In [64]:
def modelTreeEval(model, inDat):
    m, n = inDat.shape
    X = np.mat(np.ones((1, n+1)))
    X[:, 1:n+1] = inDat
    return float(X*model)

In [65]:
def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree):
        return modelEval(tree, inData)
    if inData[0, tree['spInd']] > tree['spVal']:  # inData is a 1 x n row matrix
        if isTree(tree['left']):
            return treeForeCast(tree['left'], inData, modelEval)
        else:
            return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']):
            return treeForeCast(tree['right'], inData, modelEval)
        else:
            return modelEval(tree['right'], inData)

In [66]:
def createForeCast(tree, testData, modelEval=regTreeEval):
    m = len(testData)
    yHat = np.mat(np.zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, np.mat(testData[i]), modelEval)
    return yHat

In [85]:
trainMat = np.mat(loadDataSet('bikeSpeedVsIq_train.txt'))
testMat = np.mat(loadDataSet('bikeSpeedVsIq_test.txt'))
myTree= createTree(trainMat, ops=(1,20))
yHat = createForeCast(myTree, testMat[:, 0])
np.corrcoef(yHat, testMat[:, 1], rowvar=0)[0, 1]

0.96408523182221451

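The result is quite good: the predictions have a correlation coefficient of about 0.964 with the test targets. On complex data, tree regression tends to predict more effectively than a simple linear model.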
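
To make the comparison with a global linear model concrete without relying on the data files, the sketch below uses synthetic step-shaped data (an assumption for illustration only). It fits a straight line and a one-split piecewise-constant model, then compares their correlation coefficients with the targets. The split point is hard-coded at 0.5 rather than searched for, and both models are scored on the data they were fitted to, purely for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, 200)
# Step function with a little noise: nonlinear in x
y = np.where(x > 0.5, 1.0, 0.0) + rng.normal(0.0, 0.05, 200)

# Global linear model fitted by least squares
A = np.column_stack([np.ones_like(x), x])
ws, *_ = np.linalg.lstsq(A, y, rcond=None)
yHatLinear = A @ ws

# One-split regression tree: each side predicts its own mean
mask = x > 0.5
yHatTree = np.where(mask, y[mask].mean(), y[~mask].mean())

corrLinear = np.corrcoef(yHatLinear, y)[0, 1]
corrTree = np.corrcoef(yHatTree, y)[0, 1]
print(corrLinear, corrTree)
```

On data shaped like this the piecewise-constant model tracks the targets more closely than the straight line, which is the effect the correlation comparison is meant to expose.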

## 7. Building a GUI with Python's Tkinter Library

### 7.1 Creating a GUI with Tkinter

In [108]:
from tkinter import *

In [159]:
root = Tk()

In [9]:
myLabel = Label(root, text='Hello World')
myLabel.grid()

In [10]:
root.mainloop()

Now for the real application:

In [12]:
def reDraw(tolS, tolN):
    pass

In [13]:
def drawNewTree():
    pass

In [171]:
root = Tk()
Label(root, text='Plot Place Holder').grid(row=0, columnspan=3)
Label(root, text='tolN').grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
Label(root, text='tolS').grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
Button(root, text='ReDraw', command=drawNewTree).grid(row=1, column=2, rowspan=3)
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text='Model Tree', variable=chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)
reDraw.rawDat = np.mat(loadDataSet('sine.txt'))
reDraw.testDat = np.arange(reDraw.rawDat[:, 0].min(), reDraw.rawDat[:, 0].max(), 0.01)

reDraw(1.0, 10)
#root.mainloop()

TclError: this isn't a Tk application

### 7.2 Integrating Matplotlib with Tkinter

Setting Matplotlib's backend to TkAgg lets Matplotlib draw onto the GUI framework.

In [147]:
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
reDraw.f = Figure(figsize=(5, 4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
#reDraw.canvas.show()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
#root.mainloop()

In [153]:
def reDraw(tolS, tolN):
    reDraw.f.clf()
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        if tolN < 2:
            tolN = 2
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
    else:
        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    reDraw.a.scatter(reDraw.rawDat[:, 0].flatten().A, reDraw.rawDat[:, 1].flatten().A, s=5)
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
    reDraw.canvas.draw()  # canvas.show() was deprecated in favor of draw() in newer Matplotlib

In [154]:
def getInputs():
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        # Fall back to the default and reset the entry box
        tolN = 10
        print('enter Integer for tolN')
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print('enter Float for tolS')
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS
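
The two try/except blocks in `getInputs` follow the same pattern: parse the text, and fall back to a default if parsing fails. The sketch below factors that pattern into a pure function that can be checked without a Tk window. The name `parseOrDefault` is hypothetical and not part of Tkinter.

```python
def parseOrDefault(text, cast, default):
    """Convert text with cast (for example int or float); return default on failure."""
    try:
        return cast(text)
    except ValueError:
        return default

print(parseOrDefault('10', int, 10))      # 10
print(parseOrDefault('abc', int, 10))     # 10, because 'abc' is not an integer
print(parseOrDefault('0.5', float, 1.0))  # 0.5
```

Resetting the entry widget after a failed parse would still have to happen in the GUI code, since this helper knows nothing about the widgets.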

In [155]:
def drawNewTree():
    tolN, tolS = getInputs()
    reDraw(tolS, tolN)

In [172]:
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
reDraw.f = Figure(figsize=(5, 4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
root.mainloop()

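> Summary: Datasets often contain complex interactions, and the relationship between the input data and the target variable is nonlinear. In such cases a tree can be used to make piecewise predictions. If the leaf nodes use piecewise constants, the tree is called a regression tree; if the leaf nodes use linear regression equations, it is called a model tree. The CART algorithm builds binary trees and can handle both discrete and continuous data. Trees built this way tend to overfit, so a pruning strategy is needed.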