In [103]:
%matplotlib inline
from __future__ import print_function, division
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns

In [115]:
#set what to plot
# Column read from each stats.<model>.tsv file, and the name used for it in axis labels.
metric='auprcs'
metricname='AUPRC'

# LaTeX display name for every model key that may appear in `models`.
modelnames = {
    'freeconv':r'$M_u$',
    'motifconv':r'$M_c$',
    'dumbmotifconv':r'$M_d$',
    'logreg':r'$M_\ell$',
    'peaks_freeconv':r'$M_u$ (top peaks)',
    'peaks_motifconv':r'$M_c$ (top peaks)',
    'peaks_logreg':r'$M_\ell$ (top peaks)'
}

# Label every scatter point with its id.
annotate=False

# Which comparison figure to draw (replaces the previously commented-out alternative):
#   'all' -> the four base models, saved as comparison.allpeaks
#   'top' -> the peaks_* models,     saved as comparison.toppeaks
peakset = 'all'
_comparisons = {
    'all': (['freeconv', 'motifconv', 'dumbmotifconv', 'logreg'], 'comparison.allpeaks'),
    'top': (['peaks_freeconv', 'peaks_motifconv', 'peaks_logreg'], 'comparison.toppeaks'),
}
models, figname = _comparisons[peakset]

# Fixed upper limit for both axes of every panel; set to None to scale each panel to its data.
set_lim=0.2

# Fail fast here instead of with a KeyError halfway through the plotting cell.
_missing = [m for m in models if m not in modelnames]
if _missing:
    raise ValueError('no display name in modelnames for: {}'.format(_missing))

In [116]:
#figure aesthetics
# Close whatever figure is current, then open a 6x6-inch figure laid out as an
# (n-1) x (n-1) grid for the pairwise panels, where n = len(models).
# NOTE(review): under %matplotlib inline this figure is rendered and closed when the
# cell finishes, so later pyplot calls may not draw on it -- confirm before relying on it.
plt.close()
fig = plt.figure(figsize=(6, 6))
gs = gridspec.GridSpec(len(models) - 1, len(models) - 1)

# Style of the dashed grey y = x reference line.
line_props = dict(color='gray', linestyle='--', linewidth=0.5, alpha=0.8)

# Font size for axis labels.
labelfontsize = 8

# Keyword arguments for Axes.tick_params, shared by every panel.
tickprops = dict(
    direction='out',
    length=2,
    width=0.8,
    pad=4,
    labelsize=7,
)

<matplotlib.figure.Figure at 0x1a1fa07c90>

In [117]:
#plotting code
# Pairwise comparison grid: the panel at gs[i, j-1] (for j > i) scatters, for every id
# present in both stats files, models[j]'s `metric` (x-axis) against models[i]'s (y-axis),
# with a y = x reference line. The figure is written to ../figures/<figname>.pdf.
if len(models) < 2:
    raise ValueError('need at least two models to compare, got {}'.format(models))


def _load_metric(model):
    """Return the 'id' column and the selected metric (renamed to 'metric') from stats.<model>.tsv."""
    stats = pd.read_csv('stats.' + model + '.tsv', sep='\t')
    return stats.rename(columns={metric: 'metric'})[['id', 'metric']]


def _nice_limit(value):
    """Round `value` up to two decimals so the tick placed at the axis limit is labelled exactly."""
    # round() first so float noise (e.g. 22.000000000000004) does not bump the ceiling.
    return np.ceil(round(value * 100, 6)) / 100.0


# Read each stats file once rather than once per pair it takes part in.
metric_tables = {m: _load_metric(m) for m in models}

# Create the figure in this cell: the inline backend closes every figure at the end of a
# cell, so a figure made in an earlier cell would be silently replaced by a default-sized one.
fig = plt.figure(figsize=(6, 6))
gs = gridspec.GridSpec(len(models) - 1, len(models) - 1)

for i, m1 in enumerate(models):
    for j, m2 in enumerate(models):
        if j <= i:
            continue
        ax = fig.add_subplot(gs[i, j - 1])
        print(m1, m2)
        # Inner join on id; default suffixes give metric_x = m2's values, metric_y = m1's.
        both = pd.merge(metric_tables[m2], metric_tables[m1], on='id')
        if both.empty:
            print('warning: {} and {} share no ids'.format(m1, m2))
        ax.scatter(both.metric_x, both.metric_y, s=2)

        if set_lim is not None:
            lim = set_lim
        else:
            # pandas max() skips NaN; it is NaN only when there is no usable data,
            # which set_xlim would reject, so fall back to the full [0, 1] range.
            data_max = both[['metric_x', 'metric_y']].max().max()
            if pd.notnull(data_max) and data_max > 0:
                lim = _nice_limit(data_max * 1.1)
            else:
                lim = 1.0

        ax.plot([0, lim], [0, lim], **line_props)
        ax.set_xlim(0, lim)
        ax.set_ylim(0, lim)
        ax.set_ylabel(modelnames[m1] + ' ' + metricname, fontsize=labelfontsize)
        ax.set_xlabel(modelnames[m2] + ' ' + metricname, fontsize=labelfontsize)
        ax.set_xticks([0, lim])
        ax.set_yticks([0, lim])
        # '{:g}' prints the limit as-is (0.2 -> '0.2', 0.25 -> '0.25'); '{:.1f}' mislabelled
        # any limit that is not a multiple of 0.1.
        limlabel = '{:g}'.format(lim)
        ax.set_xticklabels(['0', limlabel])
        ax.set_yticklabels(['0', limlabel])
        ax.tick_params(**tickprops)

        if annotate:
            for txt, x, y in zip(both['id'], both.metric_x, both.metric_y):
                ax.annotate(txt, (x, y), fontsize=labelfontsize - 2)
        print('==')

sns.despine(fig=fig)
fig.tight_layout()
fig.savefig('../figures/' + figname + '.pdf')
plt.close(fig)

freeconv motifconv
==
freeconv dumbmotifconv
==
freeconv logreg
==
motifconv dumbmotifconv
==
motifconv logreg
==
dumbmotifconv logreg
==
