# 绘制树形图

In [2]:
import matplotlib.pyplot as plt
%matplotlib
import numpy as np

Using matplotlib backend: Qt5Agg


In [2]:
# Goal: draw this tree with matplotlib.
# Sample input used by every cell below. The tree is a nested dict of the form
# {node_name: {branch_label: child, ...}} where each child is either another
# dict of the same form (a subtree) or a plain string (a leaf label).
treeinfo={
    'tearRate':{
        'reduced':'no lenses',
        'normal':{
            'astigmatic':{
                'yes':{
                    'prescript':{
                        'hyper':{
                            'age':{
                                'pre':'no lenses',
                                'presbyopic':'no lenses',
                                'young':'hard'
                            }
                        },
                        'myope':'hard'
                    }
                },
                'no':{
                    'age':{
                        'pre':'soft',
                        'presbyopic':{
                            'prescript':{
                                'hyper':'soft',
                                'myope':'no lenses'
                            }
                        },
                        'young':'soft'
                    }
                }
            }
        }
    }
}

In [24]:
# 打印版本
def detree(tree,depth=0):
    '''Print a nested-dict tree to stdout as an indented outline.

    ``tree`` must look like ``{node_name: {branch_label: child, ...}}`` where
    each ``child`` is either another dict of the same shape (a subtree) or any
    non-dict value (a leaf). Only the first key of ``tree`` is used as the
    node name.

    ``depth`` is the current nesting level and drives the indentation (six
    spaces per level). It is managed by the recursion; callers should leave
    it at its default.

    Returns None.
    '''
    indent='      '*depth
    treename=list(tree.keys())[0]
    # str() allows non-string node names (e.g. integer feature indices);
    # plain concatenation with ':' would raise TypeError for them.
    print(indent+'♟',str(treename)+':') # root of the current subtree
    for branch,child in tree[treename].items():
        print(indent+'·',branch)
        if isinstance(child,dict): # subtree: recurse one level deeper
            detree(child,depth+1)
        else: # leaf node
            print('      '*(depth+1)+'♀',child)
    
# Demo: print the sample tree with the indentation-based printer defined above.
detree(treeinfo)

♟ tearRate:
· reduced
      ♀ no lenses
· normal
      ♟ astigmatic:
      · yes
            ♟ prescript:
            · hyper
                  ♟ age:
                  · pre
                        ♀ no lenses
                  · presbyopic
                        ♀ no lenses
                  · young
                        ♀ hard
            · myope
                  ♀ hard
      · no
            ♟ age:
            · pre
                  ♀ soft
            · presbyopic
                  ♟ prescript:
                  · hyper
                        ♀ soft
                  · myope
                        ♀ no lenses
            · young
                  ♀ soft


In [27]:
# 更新版本
def detree(tree,prefix=''):
    '''Print a nested-dict tree to stdout using '|' guide lines.

    ``tree`` must look like ``{node_name: {branch_label: child, ...}}`` where
    each ``child`` is either another dict of the same shape (a subtree) or any
    non-dict value (a leaf). Only the first key of ``tree`` is used as the
    node name.

    ``prefix`` is the guide-line text printed in front of every line of the
    current subtree. It grows by ``'|    '`` per level and is managed by the
    recursion; callers should leave it at its default.

    Returns None.
    '''
    treename=list(tree.keys())[0]
    # str() allows non-string node names (e.g. integer feature indices);
    # plain concatenation with ':' would raise TypeError for them.
    label=str(treename)+':'
    if prefix!='':
        # A nested root is printed as the child of the branch line above it:
        # keep everything up to the last '|' of the prefix and replace the
        # padding after it with the arrow marker.
        stem=prefix.rpartition('|')[0]+'|'
        print(stem+'  ↳ ♟',label) # root of a nested subtree
    else:
        print(prefix+'♟',label) # root of the whole tree
    for branch,child in tree[treename].items():
        print(prefix+'|-',branch)
        if isinstance(child,dict): # subtree: recurse with a longer prefix
            detree(child,prefix+'|    ')
        else: # leaf node
            print(prefix+'|  ↳ ♀',child)
    
detree(treeinfo)
# See also ./DB/demo/dostree.py

♟ tearRate:
|- reduced
|  ↳ ♀ no lenses
|- normal
|  ↳ ♟ astigmatic:
|    |- yes
|    |  ↳ ♟ prescript:
|    |    |- hyper
|    |    |  ↳ ♟ age:
|    |    |    |- pre
|    |    |    |  ↳ ♀ no lenses
|    |    |    |- presbyopic
|    |    |    |  ↳ ♀ no lenses
|    |    |    |- young
|    |    |    |  ↳ ♀ hard
|    |    |- myope
|    |    |  ↳ ♀ hard
|    |- no
|    |  ↳ ♟ age:
|    |    |- pre
|    |    |  ↳ ♀ soft
|    |    |- presbyopic
|    |    |  ↳ ♟ prescript:
|    |    |    |- hyper
|    |    |    |  ↳ ♀ soft
|    |    |    |- myope
|    |    |    |  ↳ ♀ no lenses
|    |    |- young
|    |    |  ↳ ♀ soft


给定整棵树的根节点坐标 (0,0) 以及画布的宽高，需要计算出每一个节点确切的位置坐标。原本希望只遍历一次，实际实现用了两次遍历：先由 analysetree 统计叶子数与最大深度，再由 calcposition 计算坐标。最终的效果是这棵树应能在给定的画布中充分伸展。已知根节点总是位于画布正中央垂直轴线的顶端，并且所有节点都满足同一原则：每个节点位于其所在子树全部叶子节点的中轴线上。

In [18]:
def analysetree(tree,treeanalysis,depth=0,ifprint=False,father='ROOT'):
    '''Count the leaves and measure the depth of a nested-dict tree.

    The analysis is written into ``treeanalysis``, which the caller passes in
    as an empty dict. It ends up mirroring the structure of ``tree`` with
    extra bookkeeping keys:

    * every internal node ``treeanalysis[name]`` gets ``'__leaves'``, the
      number of leaves in the subtree rooted at that node;
    * every leaf becomes ``{'__leaves': 1, '__leave': <leaf value>}``.

    Parameters:
        tree: ``{node_name: {branch_label: child, ...}}`` or a non-dict leaf.
        treeanalysis: dict filled in place with the analysis.
        depth: current depth, tracked by the recursion; leave it at 0.
        ifprint: when true, print information about each visited node.
        father: name of the parent node, used only in the printed
            information. For the root of the whole tree it is an arbitrary
            placeholder because no real parent exists.

    Returns:
        A 2-tuple ``(number of leaves, maximum depth)``.
    '''
    if ifprint:
        print('depth:',depth,'\n#\nfather:',father)

    if not isinstance(tree,dict):
        # Anything that is not a dict is a leaf; record its value as its name.
        if ifprint:
            print('leaf node:',tree,'\n#')
        treeanalysis['__leaves']=1
        treeanalysis['__leave']=tree
        return 1,depth

    rootname=list(tree.keys())[0] # name of the root of the current subtree
    children=tree[rootname]
    branchnames=list(children.keys()) # each branch leads to one child node
    nodeinfo=treeanalysis[rootname]={}

    if ifprint:
        print('root node of current subtree:',rootname,'\nbranches of the root node:',branchnames,'\nbranch info:')
        for name in branchnames:
            print('    ',name,':',children[name])
        print('#')

    # Analyse every child into its own fresh sub-dict and keep the
    # (leaves, depth) pair each recursive call reports.
    results=[]
    for name in branchnames:
        childinfo=nodeinfo[name]={}
        results.append(analysetree(children[name],childinfo,depth+1,ifprint,rootname))

    leafcount=sum(n for n,_ in results)
    nodeinfo['__leaves']=leafcount # stored after the branches, as the last key
    return leafcount,max(d for _,d in results)

# Alternative small tree for quick experiments (uncomment to use it instead):
#treeinfo={'no surfacing':{0:'no',1:{'flipper':{0:'no',1:'yes'}}}} # test
treeanalysis={} # filled in place by analysetree
ret=analysetree(treeinfo,treeanalysis,ifprint=False) # ret is (leaf count, max depth)
print('leaves of the whole tree:',ret[0])
print('depth of the whole tree:',ret[1])
treeanalysis

leaves of the whole tree: 9
depth of the whole tree: 4


{'tearRate': {'__leaves': 9,
  'normal': {'astigmatic': {'__leaves': 8,
    'no': {'age': {'__leaves': 4,
      'pre': {'__leave': 'soft', '__leaves': 1},
      'presbyopic': {'prescript': {'__leaves': 2,
        'hyper': {'__leave': 'soft', '__leaves': 1},
        'myope': {'__leave': 'no lenses', '__leaves': 1}}},
      'young': {'__leave': 'soft', '__leaves': 1}}},
    'yes': {'prescript': {'__leaves': 4,
      'hyper': {'age': {'__leaves': 3,
        'pre': {'__leave': 'no lenses', '__leaves': 1},
        'presbyopic': {'__leave': 'no lenses', '__leaves': 1},
        'young': {'__leave': 'hard', '__leaves': 1}}},
      'myope': {'__leave': 'hard', '__leaves': 1}}}}},
  'reduced': {'__leave': 'no lenses', '__leaves': 1}}}

In [19]:
def calcposition(tree,depth,wh=(2,1)):
    '''Compute the drawing position of every node of an analysed tree.

    Parameters:
        tree: the annotated tree produced by ``analysetree`` (the dict passed
            as its ``treeanalysis`` argument). Internal nodes carry
            ``'__leaves'``; leaves carry ``'__leaves'`` and ``'__leave'``.
        depth: maximum depth of the tree (second value returned by
            ``analysetree``). It is used as a divisor for the vertical
            spacing, so it must not be zero.
        wh: ``(width, height)`` of the canvas the tree should spread over.

    Returns:
        A new dict with the same nesting as ``tree`` in which every internal
        node and every leaf has a ``'__pos'`` entry ``(x, y)``; leaves also
        keep their ``'__leave'`` value.

    Layout rules implemented here: the root sits at ``(0, 0)`` and owns the
    horizontal interval ``[-wh[0]/2, wh[0]/2]``. A node is centred on its
    interval, and the interval is divided among the children in proportion to
    the number of leaves below each child. Each level is ``wh[1]/depth`` lower
    than the previous one.
    '''
    vincrease=-wh[1]/depth # vertical step between levels (negative: downwards)

    def place(node,out,currdepth,xbounds):
        # Fill `out` with the positions for `node`, which owns the horizontal
        # interval `xbounds` and lives at level `currdepth`.
        if '__leave' in node: # leaf: its '__pos' was already stored by the parent
            out['__leave']=node['__leave']
            return
        treename=list(node.keys())[0]
        info=node[treename]
        branches=[b for b in info if b!='__leaves'] # '__leaves' is bookkeeping, not a branch
        y=vincrease*currdepth
        # Branch entries are created first and '__pos' afterwards so that the
        # key order of the result is: branches..., '__pos'.
        placed=out[treename]={b:{} for b in branches}
        placed['__pos']=(sum(xbounds)/2,y)
        scale=(xbounds[1]-xbounds[0])/info['__leaves'] # width available per leaf
        start=xbounds[0]
        for branch in branches:
            child=info[branch]
            # One leaf test ('__leave' key present) is used everywhere, so a
            # leaf whose label is None is still recognised as a leaf.
            isleaf='__leave' in child
            if isleaf:
                nleaves=child['__leaves']
            else:
                nleaves=child[list(child.keys())[0]]['__leaves']
            end=start+nleaves*scale
            # Bounds are handed to the child directly instead of going through
            # a dict keyed by "branch_name" strings, which could collide
            # (e.g. branches 1 and '1' leading to equally named children).
            if isleaf:
                placed[branch]['__pos']=(sum((start,end))/2,y+vincrease)
            place(child,placed[branch],currdepth+1,(start,end))
            start=end

    treeposanalysis={}
    place(tree,treeposanalysis,0,(0-wh[0]/2,0+wh[0]/2))
    return treeposanalysis

# Author's note: by the time this was finished the exact derivation was mostly
# forgotten; the recursion was painful to write.
postree=calcposition(treeanalysis,ret[1],wh=(2,2)) # ret[1] is the max depth from analysetree
postree

{'tearRate': {'__pos': (0.0, -0.0),
  'normal': {'astigmatic': {'__pos': (0.11111111111111105, -0.5),
    'no': {'age': {'__pos': (0.5555555555555555, -1.0),
      'pre': {'__leave': 'soft', '__pos': (0.22222222222222215, -1.5)},
      'presbyopic': {'prescript': {'__pos': (0.5555555555555555, -1.5),
        'hyper': {'__leave': 'soft', '__pos': (0.44444444444444436, -2.0)},
        'myope': {'__leave': 'no lenses',
         '__pos': (0.6666666666666665, -2.0)}}},
      'young': {'__leave': 'soft', '__pos': (0.8888888888888888, -1.5)}}},
    'yes': {'prescript': {'__pos': (-0.33333333333333337, -1.0),
      'hyper': {'age': {'__pos': (-0.4444444444444445, -1.5),
        'pre': {'__leave': 'no lenses', '__pos': (-0.6666666666666667, -2.0)},
        'presbyopic': {'__leave': 'no lenses',
         '__pos': (-0.4444444444444445, -2.0)},
        'young': {'__leave': 'hard', '__pos': (-0.22222222222222227, -2.0)}}},
      'myope': {'__leave': 'hard',
       '__pos': (-5.551115123125783e-17, 

In [24]:
def plotarrownode(text,front,rear,nodetype=dict(boxstyle='square',facecolor='gray'),arrowtype=dict(arrowstyle='<|-',alpha=0.5)):
    '''Draw one boxed label connected to another point by an arrow.

    Thin wrapper around ``plt.annotate``: ``text`` is placed at ``rear``
    inside a box styled by ``nodetype`` and an arrow styled by ``arrowtype``
    joins it to the point ``front``.

    Parameters:
        text: label to draw.
        front: ``(x, y)`` of the annotated point (passed as ``xy``).
        rear: ``(x, y)`` where the label itself is placed (passed as
            ``xytext``).
        nodetype: bbox properties of the label.
        arrowtype: arrow properties.

    The default dicts are created once at definition time and shared by all
    calls; this function only passes them on and does not modify them.
    '''
    plt.annotate(text,xy=front,xytext=rear,arrowprops=arrowtype,bbox=nodetype)

def plotthetree(tree,wh=None,title=None,axis=False,delay=0):
    '''Draw a positioned tree on the current matplotlib axes.

    Parameters:
        tree: the position tree returned by ``calcposition``; every internal
            node and every leaf carries a ``'__pos'`` entry.
        wh: ``(width, height)`` used to set the axis limits (with some margin
            on every side). Together with the ``wh`` given to
            ``calcposition`` it controls how stretched the drawing looks; if
            labels overlap on a narrow tree, widen it. When ``None`` the axis
            limits are left untouched.
        title: optional title of the figure.
        axis: when false, the frame (spines) and the ticks are hidden.
        delay: a number >= 0. When non-zero, ``plt.pause(delay)`` is called
            after each node is drawn so the tree builds up step by step; 0 or
            ``False`` draws without pauses.

    Returns None; everything is drawn through the pyplot state machine.
    '''
    notleafnode=dict(boxstyle='square',fc='green',alpha=0.6) # bbox of internal nodes
    leafnode=dict(boxstyle='round4',facecolor='pink',alpha=0.8) # bbox of leaf nodes
    # `is not None` instead of `!=None`: the comparison stays a plain bool even
    # when wh is a NumPy array (where `!=` is element-wise and cannot be used
    # as an `if` condition).
    if wh is not None:
        # plt.axis('equal') is deliberately not used: it made the drawing fall apart.
        plt.axis([-wh[0]/2*(4/3),wh[0]/2*(5/4),-wh[1]*(4/3),0+wh[1]*0.25])
    if not axis:
        ax=plt.gca()
        for side in ('bottom','top','left','right'):
            ax.spines[side].set_visible(False)
        plt.xticks([])
        plt.yticks([])
    if title is not None:
        plt.title(title,bbox=dict(boxstyle='Roundtooth',alpha=0.4,facecolor='gray'),color='red',size=16)

    def plottree(subtree):
        # Draw all children of the root of `subtree` (the root itself has
        # already been drawn by the caller) and recurse into child subtrees.
        treename=list(subtree.keys())[0]
        node=subtree[treename]
        treepos=node['__pos'] # position of the current node (root of this subtree)
        for branch,child in node.items():
            if branch=='__pos': # bookkeeping entry, not a branch
                continue
            if '__leave' in child: # leaf child
                childpos=child['__pos']
                plotarrownode(child['__leave'],treepos,childpos,leafnode)
                if delay:
                    plt.pause(delay)
            else: # internal child: draw it, then its own subtree
                subtreename=list(child.keys())[0]
                childpos=child[subtreename]['__pos']
                plotarrownode(subtreename,treepos,childpos,notleafnode)
                if delay:
                    plt.pause(delay)
                plottree(child)
            # Branch label at the midpoint of the edge between parent and child.
            plt.annotate('{0}'.format(branch),((treepos[0]+childpos[0])/2,(treepos[1]+childpos[1])/2),rotation=0,bbox=dict(boxstyle='round',facecolor='gray',alpha=0.2))

    # The root has no incoming arrow, so it is drawn once here as plain text,
    # at its stored position rather than at a hard-coded origin.
    rootname=list(tree.keys())[0]
    rootpos=tree[rootname]['__pos']
    plt.text(rootpos[0],rootpos[1],rootname,ha="center",va="center",size=15,bbox=notleafnode)
    plottree(tree)

# wh here only sets the axis limits; the node coordinates in postree were
# computed with wh=(2,2) above.
plotthetree(postree,(1.5,2),title='decision tree',delay=0.2) # try other wh values, e.g. (2,2), (4,2), (1.5,2), and compare the result

![](./DB/image/1.png)

muggledy 2019.8.15