# 寻找网站中新用户有可能成为付费用户的模型

In [1]:
my_data=[['slashdot','USA','yes',18,'None'],
        ['google','France','yes',23,'Premium'],
        ['digg','USA','yes',24,'Basic'],
        ['kiwitobes','France','yes',23,'Basic'],
        ['google','UK','no',21,'Premium'],
        ['(direct)','New Zealand','no',12,'None'],
        ['(direct)','UK','no',21,'Basic'],
        ['google','USA','no',24,'Premium'],
        ['slashdot','France','yes',19,'None'],
        ['digg','USA','no',18,'None'],
        ['google','UK','no',18,'None'],
        ['kiwitobes','UK','no',19,'None'],
        ['digg','New Zealand','yes',12,'Basic'],
        ['slashdot','UK','no',21,'None'],
        ['google','UK','yes',18,'Basic'],
        ['kiwitobes','France','yes',19,'Basic']]

In [2]:
#决策树的节点    在上一个案例中是用字典表示
class decisionnode:
    def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
        self.col=col  #待检验的判断条件所对应的列索引值
        self.value=value  #为了使结果变为 TRUE  当前列必须匹配的值
        self.results=results  #针对当前分枝的结果  它是一个字典，除叶节点外，其他结果上该值皆为None
        self.tb=tb
        self.fb=fb
        

In [3]:
#  这里采用cart 算法（分类，回归） 它首先创建了一个根节点  然后通过评估表中所有待测变量从中选取最合适的变量对数据进行拆分

In [4]:
def divideset(rows,column,value):
    split_function=None   #对于数值型数据及离散型数据  采用不同的回调来进行分隔数据操作
    if isinstance(value,int) or isinstance(value,float):
        split_function=lambda row:row[column] >=value   # split_function 实际上就是一个匿名函数
    else:
        split_function=lambda row:row[column]==value
        
    #使用这个函数来完成对rows中的数据的分隔
    set1=[row for row in rows if split_function(row)]
    set2=[row for row in rows if not split_function(row)]
    return (set1,set2)

In [5]:
#测试划分
set1,set2=divideset(my_data,0,'google')

print(set1)
print('---------------')
print(set2)

[['google', 'France', 'yes', 23, 'Premium'], ['google', 'UK', 'no', 21, 'Premium'], ['google', 'USA', 'no', 24, 'Premium'], ['google', 'UK', 'no', 18, 'None'], ['google', 'UK', 'yes', 18, 'Basic']]
---------------
[['slashdot', 'USA', 'yes', 18, 'None'], ['digg', 'USA', 'yes', 24, 'Basic'], ['kiwitobes', 'France', 'yes', 23, 'Basic'], ['(direct)', 'New Zealand', 'no', 12, 'None'], ['(direct)', 'UK', 'no', 21, 'Basic'], ['slashdot', 'France', 'yes', 19, 'None'], ['digg', 'USA', 'no', 18, 'None'], ['kiwitobes', 'UK', 'no', 19, 'None'], ['digg', 'New Zealand', 'yes', 12, 'Basic'], ['slashdot', 'UK', 'no', 21, 'None'], ['kiwitobes', 'France', 'yes', 19, 'Basic']]


In [6]:
# 对每种可能的结果标签进行计数
def uniquecount(rows):
    result={}
    for row in rows:
        label=row[len(row)-1]
        result[label]=result.get(label,0)+1
    return result

In [7]:
#测试
print(uniquecount(my_data))
print(uniquecount(divideset(my_data,0,'google')[0] ) )

{'None': 7, 'Premium': 3, 'Basic': 6}
{'Premium': 3, 'None': 1, 'Basic': 1}


In [8]:
#熵的算法
from math import log
def entropy(rows):
    result=uniquecount(rows)
    shannonEnt=0
    #计算熵
    for key in result.keys():
        p=float(result[key]/len(rows) )
        shannonEnt-=p*log(p,2)
    return shannonEnt

In [9]:
set1,set2=divideset(my_data,0,'digg')
print(entropy(set1))
print(entropy(set2))

0.9182958340544896
1.5262349099495225


In [10]:
#基尼系数 随机放置的数据出现于错误分类的概率
#步骤：一结果出现的次数除以集合总行数来计算相应概率，然后把所有的概率值的乘积累加起来
def giniimpurity(rows):
    totalrows=len(rows)
    result=uniquecount(rows)
    impurity=0
    for k1 in result.keys():
        p1=float(result[k1])/totalrows
        for k2 in result.keys():
            if k1==k2:
                continue
            p2=float(result[k2])/totalrows
            impurity+=p1*p2
    return impurity

In [11]:
set1,set2=divideset(my_data,0,'digg')
print(giniimpurity(set1))
print(giniimpurity(set2))

0.4444444444444444
0.6390532544378699


In [13]:
#构建树
def buildTree2(rows,scoref=giniimpurity):
    if len(rows)==0:
        return decisionnode()
    
    current_score=scoref(rows)  #计算熵或者基尼的不纯度
    
    #用于划分最优特征的参数
    best_gain=0.0
    best_criteria=None
    best_sets=None
    
    column_count=len(rows[0])-1  #列数
    
    #循环每个列号
    for col in range(0,column_count):
        column_values=set()   # [当前第col列的取值:1,当前第col列的取值:1,]
        for row in rows:   #循环每一行  取出第 col 列的值 设置到字典中 将 value 设置为1
            column_values.add(row[col])
            
        #按照第col列value切分数据集
        for value in column_values:
            (set1,set2)=divideset(rows,col,value)  #按照 col 列的value取值切分数据
            
            #计算信息增益
            p=float(len(set1))/len(rows)

            #current_score 没有切分的情况下  数据的混乱程度
            gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
            if gain>best_gain and len(set1)>0 and len(set2)>0:
                best_gain=gain   #最佳信息增益
                best_criteria=(col,value)  # 最优切分的列和值
                best_sets=(set1,set2)   # 切分好的数据集
                
    #根据上面找到的最优列来递归切分数据集
    if best_gain>0:
        trueBranch=buildTree2(best_sets[0])
        falseBranch=buildTree2(best_sets[1])
        return decisionnode(col=best_criteria[0],value=best_criteria[1],tb=trueBranch,fb=falseBranch)
    else:
        return decisionnode(results=uniquecount(rows))

In [14]:
tree=buildTree2(my_data)
tree

<__main__.decisionnode at 0x24f9fcd98d0>

In [15]:
#前序输出显示树
def printtree(tree,indent=' '):
    if tree.results !=None:
        print( indent,str(tree.results) )
    else:
        print(indent,str(tree.col)+':'+str(tree.value)+'?')
        
        print(indent+'T->',)
        printtree(tree.tb,indent+' ')
        print(indent+'F->')
        printtree(tree.fb,indent+' ')

In [16]:
printtree(tree)

  0:google?
 T->
   3:21?
  T->
    {'Premium': 3}
  F->
    2:no?
   T->
     {'None': 1}
   F->
     {'Basic': 1}
 F->
   0:slashdot?
  T->
    {'None': 3}
  F->
    2:yes?
   T->
     {'Basic': 4}
   F->
     3:21?
    T->
      {'Basic': 1}
    F->
      {'None': 3}


In [17]:
#分类
def classify(observation,tree):
    #如果有result 说明是叶节点直接取值
    if tree.results !=None:
        return tree.results
    else:
        v=observation[tree.col]
        branch=None
        if isinstance(v,int) or isinstance(v,float):
            if v>=tree.value:
                branch=tree.tb
            else:
                branch=tree.fb
        else:
            if v==tree.value:
                branch=tree.tb
            else:
                branch=tree.fb
        return classify(observation,branch)

In [18]:
classify(['(direct)','USA','yes',5],tree)

{'Basic': 4}