In [79]:
import numpy as np

In [80]:
def loadSimple():
    """Return the five-sample toy data set used to exercise the boosting code.

    Returns:
        tuple: ``(dataMat, classLabels)`` where ``dataMat`` is a 5x2
        ``np.matrix`` (one sample per row, two features per sample) and
        ``classLabels`` is a plain list holding the matching +1.0 / -1.0
        label of each row.
    """
    samples = [
        [1.0, 2.1],
        [2.0, 1.1],
        [1.3, 1.0],
        [1.0, 1.0],
        [2.0, 1.0],
    ]
    labels = [1.0, 1.0, -1.0, -1.0, 1.0]
    return np.matrix(samples), labels

In [81]:
# Classify samples by comparing one feature against a threshold
def stumpClassify(dataMatirx,dimen,threshVal,threshIneq):
    """Label every row +1.0 or -1.0 with a single-feature threshold test.

    Args:
        dataMatirx: 2-D sample container (one row per sample) supporting
            ``[:, dimen]`` column indexing.
        dimen: index of the feature column to test.
        threshVal: threshold the column is compared against.
        threshIneq: ``'lt'`` marks rows whose feature is <= ``threshVal`` as
            -1.0; any other value marks rows whose feature is > ``threshVal``
            as -1.0.

    Returns:
        An ``(m, 1)`` float ndarray of +1.0 / -1.0 predictions, where ``m`` is
        the number of rows in ``dataMatirx``.
    """
    # Start from "everything is +1" and flip only the rows selected below.
    predictions = np.ones((np.shape(dataMatirx)[0], 1))
    column = dataMatirx[:, dimen]

    negativeMask = (column <= threshVal) if threshIneq == 'lt' else (column > threshVal)
    predictions[negativeMask] = -1.0

    return predictions

In [137]:
# Build the best single-level decision tree (decision stump, the weak learner)
def buildStump(dataArr,classLable,D,numStep=10.0):
    """Search for the decision stump with the lowest weighted error.

    For every feature column, evenly spaced thresholds from one step below the
    column minimum up to the column maximum are tried with both inequality
    directions ('lt' and 'gt'). Each candidate is scored by the sum of the
    weights in ``D`` of the samples it misclassifies, and one progress line
    per candidate is printed.

    Args:
        dataArr: samples, one row per sample; converted with ``np.asmatrix``.
        classLable: one label per sample. Predictions from ``stumpClassify``
            are +1.0 / -1.0, so labels must use the same encoding to match.
        D: per-sample weights, used as ``D.T * errArr`` against an ``(m, 1)``
            error column.
        numStep: number of intervals the value range of each feature is
            divided into (default 10.0, the previously hard-coded value).
            Must be non-zero.

    Returns:
        tuple: ``(bestStump, minError, bestClassEst)``.
        ``bestStump`` is a dict with keys ``'dim'``, ``'thresh'`` and
        ``'ineq'`` (empty if no candidate beat the initial bound).
        ``minError`` is the weighted error of that stump as produced by
        ``D.T * errArr`` (the initial bound ``1000000`` if nothing was found).
        ``bestClassEst`` holds the predictions of the best stump (an all-zero
        ``(m, 1)`` matrix if nothing was found).
    """
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    dataMatrix = np.asmatrix(dataArr)
    classLabelMat = np.asmatrix(classLable).T
    m,n = np.shape(dataMatrix)
    bestStump = {}
    bestClassEst = np.asmatrix(np.zeros((m,1)))
    minError = 1000000
    # Plain-float mirror of minError so comparisons and formatting never rely
    # on implicit 1x1-matrix -> scalar conversion (deprecated in NumPy 1.25).
    minErrorValue = float(minError)

    # Same text the original backslash-continued literal produced:
    # 20 spaces separate the two halves of the message.
    logFormat = ("split: dim%d,thresh %.2f,thresh inqueal:%s"
                 + " " * 20
                 + "the weightError is %.3f")

    # Scan every feature.
    for i in range(n):
        rangeMin = dataMatrix[:,i].min()
        rangeMax = dataMatrix[:,i].max()

        # Distance between consecutive candidate thresholds.
        stepSize = (rangeMax - rangeMin) / numStep

        for j in range(-1,int(numStep)+1):
            # The threshold does not depend on the inequality direction, so
            # it is computed once per step instead of once per direction.
            threshVal = rangeMin + float(j)*stepSize

            for inequal in ['lt','gt']:
                # Predict with this candidate stump.
                predictedVals = stumpClassify(dataMatrix,i,threshVal,inequal)

                # 1 where the prediction is wrong, 0 where it is right.
                errArr = np.asmatrix(np.ones((m,1)))
                errArr[predictedVals==classLabelMat] = 0

                # Weighted error: total weight of the misclassified samples.
                weightError = D.T * errArr
                weightErrorValue = float(weightError.item())

                print(logFormat%(i,threshVal,inequal,weightErrorValue))

                if weightErrorValue < minErrorValue:
                    minError = weightError
                    minErrorValue = weightErrorValue
                    bestClassEst = predictedVals.copy()
                    bestStump['dim'] = i
                    bestStump['thresh'] = threshVal
                    bestStump['ineq'] = inequal
    return bestStump,minError,bestClassEst

In [134]:
# Complete AdaBoost training loop built on decision stumps
def adaBoostTrain(dataArr,classLabels,numIt=40):
    """Train an AdaBoost ensemble of decision stumps.

    Each round fits the best stump for the current sample weights with
    ``buildStump``, gives it the coefficient
    ``alpha = 0.5 * ln((1 - err) / max(err, 1e-16))``, re-weights the samples
    and adds the alpha-weighted predictions to a running vote. Training stops
    early once the sign of the running vote matches every label. Progress is
    printed every round.

    Args:
        dataArr: samples, one row per sample.
        classLabels: one label per sample; the sign-based error check and the
            weight update assume a +1 / -1 encoding.
        numIt: maximum number of boosting rounds (default 40).

    Returns:
        list: one dict per trained stump, in training order, holding the keys
        set by ``buildStump`` (``'dim'``, ``'thresh'``, ``'ineq'``) plus
        ``'alpha'``.
    """
    m = np.shape(dataArr)[0]  # number of samples
    # Label column vector; built once instead of twice per round.
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    labelMat = np.asmatrix(classLabels).T
    bestTree = []  # the ensemble: every weak classifier trained so far
    aggClassEst = np.asmatrix(np.zeros((m,1)))  # running alpha-weighted vote
    D = np.asmatrix(np.ones((m,1))/m)  # start from uniform sample weights
    for _ in range(numIt):
        bestStump,minError,bestClassEst = buildStump(dataArr,classLabels,D)
        print("bestStump",bestStump,"minError",minError)
        print("D：",D.T)

        # buildStump reports the error as a 1x1 matrix; pull out the scalar
        # explicitly because float(<ndim > 0 array>) is deprecated since
        # NumPy 1.25.
        errValue = float(np.asarray(minError).item())

        # Coefficient of this weak classifier; max() avoids dividing by zero
        # when the stump makes no weighted error.
        alpha = float(0.5*np.log((1-errValue)/max(errValue,1e-16)))

        # Add this classifier to the ensemble.
        bestStump['alpha'] = alpha
        bestTree.append(bestStump)
        print("classEst",bestClassEst)

        # Update the weights: exp(-alpha * y * h(x)) shrinks the weight of
        # correctly classified samples and grows the weight of the mistakes;
        # then renormalise so the weights sum to 1.
        expon = np.multiply(-1*alpha*labelMat,bestClassEst)
        D = np.multiply(D,np.exp(expon))
        D = D/D.sum()

        # Accumulate the weighted vote of all classifiers so far.
        aggClassEst += alpha*bestClassEst
        print("aggClassEst",aggClassEst)

        # Training error of the combined classifier sign(aggClassEst).
        aggErrors = np.multiply(np.sign(aggClassEst)!=labelMat,np.ones((m,1)))
        errorRate = aggErrors.sum()/m
        print("errorRate",errorRate,"\n")
        if errorRate == 0.0:
            break
    return bestTree

In [138]:
# Fit AdaBoost on the toy data set (at most 9 rounds) and show the trained stumps.
dataArr, classLabels = loadSimple()
print(adaBoostTrain(dataArr, classLabels, numIt=9))


split: dim0,thresh 0.90,thresh inqueal:lt                    the weightError is 0.400
split: dim0,thresh 0.90,thresh inqueal:gt                    the weightError is 0.600
split: dim0,thresh 1.00,thresh inqueal:lt                    the weightError is 0.400
split: dim0,thresh 1.00,thresh inqueal:gt                    the weightError is 0.600
split: dim0,thresh 1.10,thresh inqueal:lt                    the weightError is 0.400
split: dim0,thresh 1.10,thresh inqueal:gt                    the weightError is 0.600
split: dim0,thresh 1.20,thresh inqueal:lt                    the weightError is 0.400
split: dim0,thresh 1.20,thresh inqueal:gt                    the weightError is 0.600
split: dim0,thresh 1.30,thresh inqueal:lt                    the weightError is 0.200
split: dim0,thresh 1.30,thresh inqueal:gt                    the weightError is 0.800
split: dim0,thresh 1.40,thresh inqueal:lt                    the weightError is 0.200
split: dim0,thresh 1.40,thresh inqueal:gt             