## 1. 预测注册用户

In [1]:
# 网站来源，位置，是否读过FAQ，浏览网页数，选择服务类型
data=[['slashdot','USA','yes',18,'None'],
        ['google','France','yes',23,'Premium'],
        ['digg','USA','yes',24,'Basic'],
        ['kiwitobes','France','yes',23,'Basic'],
        ['google','UK','no',21,'Premium'],
        ['(direct)','New Zealand','no',12,'None'],
        ['(direct)','UK','no',21,'Basic'],
        ['google','USA','no',24,'Premium'],
        ['slashdot','France','yes',19,'None'],
        ['digg','USA','no',18,'None'],
        ['google','UK','no',18,'None'],
        ['kiwitobes','UK','no',19,'None'],
        ['digg','New Zealand','yes',12,'Basic'],
        ['slashdot','UK','no',21,'None'],
        ['google','UK','yes',18,'Basic'],
        ['kiwitobes','France','yes',19,'Basic']]

### 1.1 引入决策树

In [2]:
class DecisionNode:
  def __init__(self,col=-1,value=None,leaf=None,tb=None,fb=None):
    self.col=col
    self.value=value
    self.leaf=leaf
    self.tb=tb
    self.fb=fb

### 1.2 拆分方案

In [3]:
def divideSet(rows, column, value):
    split_function = None
    if isinstance(value, int) or isinstance(value, float):
        split_function = lambda row : row[column] >= value
    else:
        split_function = lambda row : row[column] == value
    
    set1 = [row for row in rows if split_function(row)]
    set2 = [row for row in rows if not split_function(row)]
    return [set1, set2]

In [4]:
divideSet(data, 2, 'yes')

[[['slashdot', 'USA', 'yes', 18, 'None'],
  ['google', 'France', 'yes', 23, 'Premium'],
  ['digg', 'USA', 'yes', 24, 'Basic'],
  ['kiwitobes', 'France', 'yes', 23, 'Basic'],
  ['slashdot', 'France', 'yes', 19, 'None'],
  ['digg', 'New Zealand', 'yes', 12, 'Basic'],
  ['google', 'UK', 'yes', 18, 'Basic'],
  ['kiwitobes', 'France', 'yes', 19, 'Basic']],
 [['google', 'UK', 'no', 21, 'Premium'],
  ['(direct)', 'New Zealand', 'no', 12, 'None'],
  ['(direct)', 'UK', 'no', 21, 'Basic'],
  ['google', 'USA', 'no', 24, 'Premium'],
  ['digg', 'USA', 'no', 18, 'None'],
  ['google', 'UK', 'no', 18, 'None'],
  ['kiwitobes', 'UK', 'no', 19, 'None'],
  ['slashdot', 'UK', 'no', 21, 'None']]]

In [5]:
# 统计label的种类分布，据此计算不纯度
def uniqueCounts(rows, column):
    values_count = {}
    for row in rows:
        value = row[column]    # label = row[len(row)-1]
        if value not in values_count:
            values_count[value] = 0
        values_count[value] += 1
    return values_count

In [6]:
# 基尼不纯度,是指将来自集合中的某种结果随机应用在集合中，某一数据项的预期误差率。
def giniImpurity(rows, column=-1):
    total = len(rows)
    values_count = uniqueCounts(rows,column)
    impurity = 0.0
    for value1 in values_count:
        p1 = float(values_count[value1]) / total
        p2 = 0.0
        for value2 in values_count:
            if value1 == value2:
                continue
            p2 += values_count[value2] / total
        impurity += (p1 * p2)
    return impurity

In [7]:
def entropy(rows,column=-1):
    from math import log
    log2 = lambda x:log(x)/log(2)
    values_count = uniqueCounts(rows,column)
    ent = 0.0
    for value in values_count.keys():
        p = float(values_count[value]) / len(rows)
        ent -= p * log2(p)
    return ent

In [8]:
giniImpurity(data,-1)

0.6328125

In [9]:
entropy(data,-1)

1.5052408149441479

In [10]:
set1, set2 = divideSet(data, 2, 'yes')

In [11]:
entropy(set1,-1)

1.2987949406953985

In [12]:
giniImpurity(set1,-1)

0.53125

### 1.3 以递归方式构造树

In [13]:
def buildTree(rows, impurity=entropy, gainRatio=False):
    if len(rows) == 0:
        return DecisionNode()
    currentImpurity = impurity(rows)
    bestGain = 0.0
    bestCriteria = None
    bestSets = None
    columnCount = len(rows[0]) - 1
    for col in range(0, columnCount):
        columnEntropy = 1.0
        if gainRatio:
            columnEntropy = impurity(rows, col)
            print(str(col) + ' 属性熵:' +str(columnEntropy))
        colValues = set()
        for row in rows:
            colValues.add(row[col])
        for value in colValues:
            (set1, set2) = divideSet(rows, col, value)
            p = float(len(set1)) / len(rows)
            nextImpurity = (p * impurity(set1)) + ((1 - p) * impurity(set2))
            gain = currentImpurity - nextImpurity
            if gainRatio:
                gain = gain / columnEntropy                
            if gain > bestGain and len(set1) > 0 and len(set2) > 0:
                bestGain = gain
                bestCriteria = (col, value)
                bestSets = (set1, set2)
    if bestGain > 0:
        trueBranch = buildTree(rows=bestSets[0], impurity=impurity)
        falseBranch = buildTree(rows=bestSets[1], impurity=impurity)
        return DecisionNode(col=bestCriteria[0], value=bestCriteria[1], tb=trueBranch, fb=falseBranch)
    else:
        return DecisionNode(leaf=uniqueCounts(rows,column=-1)) 

### 1.4 显示树

In [14]:
def printTree(node, indent=' '):
    if node.leaf != None:
        print(str(node.leaf))
    else:
        print(indent + str(node.col) + ' : ' + str(node.value) + ' ?')
        print(indent + 'T ->', printTree(node.tb, indent + ' '))
        print(indent + 'F ->', printTree(node.fb, indent + ' '))

In [15]:
printTree(buildTree(data,gainRatio=True))

0 属性熵:2.257856063692049
1 属性熵:1.9056390622295662
2 属性熵:1.0
3 属性熵:2.5306390622295662
 2 : yes ?
  0 : slashdot ?
{'None': 2}
  T -> None
   0 : google ?
    1 : UK ?
{'Basic': 1}
    T -> None
{'Premium': 1}
    F -> None
   T -> None
{'Basic': 4}
   F -> None
  F -> None
 T -> None
  3 : 21 ?
   0 : google ?
{'Premium': 2}
   T -> None
    0 : slashdot ?
{'None': 1}
    T -> None
{'Basic': 1}
    F -> None
   F -> None
  T -> None
{'None': 4}
  F -> None
 F -> None


In [16]:
# 一个分支的总宽度等于所有子宽度之和，没有子分支的话，宽度为1
def getWidth(tree):
    if tree.tb == None and tree.fb == None:
        return 1
    return getWidth(tree.tb) + getWidth(tree.fb)
def getHeight(tree):
    if tree.tb == None and tree.tb == None:
        return 0
    return max(getHeight(tree.tb), getHeight(tree.fb)) + 1

In [22]:
from PIL import Image, ImageDraw

def drawNode(draw, node, x, y):
    if node.leaf == None:
        w1 = getWidth(node.fb) * 100
        w2 = getWidth(node.tb) * 100
        left = x - (w1+w2)/2
        right = x + (w1+w2)/2
        draw.text((x-20,y-20), str(node.col)+':'+str(node.value), (0,0,0))
        draw.line((x,y, left+w1/2,y+100), fill=(255,0,0))
        draw.line((x,y, right-w2/2,y+100), fill=(0,255,0))
        drawNode(draw, node.fb, left+w1/2, y+100)
        drawNode(draw, node.tb, right-w2/2, y+100)
    else:
        for item in node.leaf.items():
            print('%s:%d'%item)
        txt=' \n'.join(['%s:%d'%v for v in node.leaf.items()])
        draw.text((x-20,y),txt,(0,0,0))
def drawTree(tree, jpeg='tree.jpg'):
    w = getWidth(tree) * 100
    h = getHeight(tree) * 100 + 120
    image = Image.new('RGB',(w,h),(255,255,255))
    draw = ImageDraw.Draw(image)
    drawNode(draw, tree, w/2, 20)
    image.save(jpeg, 'JPEG')

In [24]:
drawTree(buildTree(data,gainRatio=False), jpeg='treeview.jpg')

None:3
Basic:1
Basic:4
None:3
None:1
Basic:1
Premium:3
