In [1]:
# Select matplotlib's interactive "notebook" backend for this kernel session.
%matplotlib notebook

In [2]:
import matplotlib
# NOTE(review): the `%matplotlib notebook` magic in the first cell has already
# selected a backend, and the recorded output of this cell reports that the
# backend "has already been chosen". Decide which backend is actually wanted
# (interactive notebook vs. file-only Agg) and drop the other selection.
matplotlib.use('Agg')

because the backend has already been chosen;
matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
or matplotlib.backends is imported for the first time.



In [3]:
import os
import pickle as pkl
import numpy as np
from matplotlib import pyplot as plt

In [4]:
from cycler import cycler

# Default per-line styling for every axes: the three property lists are
# combined element-wise (same as adding three single-key cyclers), so the
# n-th line drawn gets the n-th color, linestyle and marker together.
plt.rc(
    'axes',
    prop_cycle=cycler(
        color=['b', 'g', 'r', 'y', 'g', 'b'],
        linestyle=['-', '-', '-', '-', '--', '-'],
        marker=['', '', '', '', 'o', '*'],
    ),
)

In [5]:
# Switch between the two experiment variants. The flag decides which cascade
# ids and result directory are read, how the per-cascade tables are summed,
# and which figure directory ('figs/digg' or 'figs/digg_small') is written.
small_cascade = False

In [6]:
# Values of q to evaluate; each one names a "<q>.pkl" result file and is
# shown on the x-axis as "prop. of reports".
qs = [0.002, 0.004, 0.008, 0.016, 0.032]

# Cascades to aggregate: a single one (id 2) for the small variant,
# ids 0..18 otherwise.
cascade_ids = list(range(2, 3)) if small_cascade else list(range(0, 19))

# Method names; each is a sub-directory of a cascade's result directory.
methods = ['no-order', 'tbfs', 'greedy', 'closure']

In [7]:
# Template for a cascade's result directory; `{}` is filled with the cascade id.
result_dir = (
    'outputs/real_cascade_experiment/small_cascade_{}/'
    if small_cascade
    else 'outputs/real_cascade_experiment/cascade_{}/'
)

In [8]:
# Aggregate the per-cascade result tables into running totals for every
# (q, method) pair, then derive the four summary measures from those totals.
# `rows` ends up with one dict per q, keyed by method name.
rows = []
for q in qs:
    row = {}
    for method in methods:
        row[method] = {}
        for i in cascade_ids:
            # gather count statistics across all cascades
            path = os.path.join(result_dir.format(i), method, "{}.pkl".format(q))
            try:
                # NOTE: unpickling can execute arbitrary code; only load result
                # files written by this project's own experiment runs.
                with open(path, 'rb') as f:
                    r = pkl.load(f)
            except FileNotFoundError:
                print(path, ": not found")
                # Skip this cascade. Falling through would either re-add the
                # table loaded for the previous cascade (double counting) or
                # fail with a NameError when the very first file is missing.
                continue

            # Lazily create the accumulators from the first table that loads.
            if len(row[method]) == 0:
                for key in r.columns:
                    row[method][key] = 0
                row[method]['count'] = 0

            if small_cascade:
                # Each table row is added as-is and counted once.
                for _, rr in r.iterrows():
                    for key in r.columns:
                        row[method][key] += rr[key]
                    row[method]['count'] += 1
            else:
                # The table is indexed by statistic name ('mean', 'count');
                # presumably a describe()-style summary -- TODO confirm.
                # mean * count turns each column back into a sum.
                count = r['n.correct_nodes']['count']
                for key in r.columns:
                    row[method][key] += r[key]['mean'] * count
                row[method]['count'] += count

        if len(row[method]) == 0:
            # No result file was found for this (q, method): record NaN so the
            # later cells still find every measure instead of raising KeyError.
            for name in ('tree-size', 'precision', 'recall', 'order-accuracy'):
                row[method][name] = float('nan')
            continue

        row[method]['tree-size'] = row[method]['n.pred_edges'] / row[method]['count']
        row[method]['precision'] = row[method]['n.correct_nodes'] / row[method]['n.pred_nodes']
        row[method]['recall'] = row[method]['n.correct_nodes'] / row[method]['n.true_nodes']
        row[method]['order-accuracy'] = row[method]['n.correct_edges'] / row[method]['n.pred_edges']
    rows.append(row)


In [9]:
# Summary measures to plot, one figure each; every name is a key stored in
# the per-method dicts of `rows` by the aggregation cell.
measures = ['tree-size', 'precision', 'recall', 'order-accuracy']

In [10]:
# measure x qs x methods
# Size the first axis from `measures` rather than a literal 4, so adding or
# removing a measure cannot silently mismatch the array shape.
data = np.zeros((len(measures), len(qs), len(methods)))
for i, row in enumerate(rows):  # rows[i] holds the totals for qs[i]
    for j, method in enumerate(methods):
        for k, measure in enumerate(measures):
            data[k, i, j] = row[method][measure]


In [11]:
# Turn the proportions into an array so they can be divided element-wise.
qs = np.array(qs)
# x-axis positions: log2 of each proportion relative to the first entry.
x = np.log2(np.divide(qs, qs[0]))


In [12]:
# `make_line_cycle` comes from the project-local `plot_utils` module.
# The plotting cell calls next(line_cycle) once per plotted line and passes the
# result as the third positional argument of ax.plot -- presumably a matplotlib
# format string; verify against plot_utils.
from plot_utils import make_line_cycle
line_cycle = make_line_cycle()

In [24]:
# Draw one small figure per measure (one line per method) and save it as a PDF.
# The line handles of the first figure are kept in `lines` for the legend cell.
lines = []

# Output directory depends on the experiment variant; create it up front
# because savefig does not create missing directories.
fig_dir = 'figs/digg_small' if small_cascade else 'figs/digg'
os.makedirs(fig_dir, exist_ok=True)

max_ticks = 3

# x holds log2(q / qs[0]), so a tick at v stands for a factor of 2^v.
# A formatter keeps labels tied to whatever ticks the locator produces, and the
# braces make multi-character exponents (e.g. -1 or 10) render as one superscript.
pow2_formatter = plt.FuncFormatter(lambda v, pos: "$2^{{{:g}}}$".format(v))

for k, measure in enumerate(measures):
    fig = plt.figure(figsize=(3, 2.6))  # 2.25
    ax = fig.add_subplot(111)
    for j in range(len(methods)):
        l, = ax.plot(x, data[k, :, j], next(line_cycle), markersize=7.5, alpha=0.75)
        if len(lines) < len(methods):
            lines.append(l)
    ax.set_title(measure, fontsize=14)
    ax.set_xlabel('prop. of reports', fontsize=14)

    ax.yaxis.set_major_locator(plt.MaxNLocator(max_ticks))
    ax.xaxis.set_major_locator(plt.MaxNLocator(max_ticks))
    ax.xaxis.set_major_formatter(pow2_formatter)

    # 'precision'
    if measure in ['order-accuracy']:
        ax.set_ylim(0.5, 1.0)
    # if not small_cascade:
    # Base multiplier for the 2^k tick labels, taken from the first q
    # instead of a hard-coded 0.002.
    ax.text(0, 0.1, r"${:g}\times$ ".format(qs[0]), transform=fig.transFigure)
    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, '{}.pdf'.format(measure)))


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [24]:
# Render the shared legend as its own wide, short figure and save it as a PDF.
# The original also called fig.add_subplot(111), which added an unused axes to
# the last measure figure instead of this one; the legend needs no axes at all.
figlegend = plt.figure(figsize=(2.2 * len(methods), 0.5))

# Display names for the methods, in the same order as the plotted lines.
label_mapping = {'closure': 'closure', 'tbfs': 'delayed-bfs', 'no-order': 'steiner', 'greedy': 'greedy'}
labels = [label_mapping[name] for name in methods]
# Pass `loc` by keyword rather than as a third positional argument.
figlegend.legend(lines, labels, loc='center', ncol=len(methods), fontsize=18)
figlegend.show()

# savefig does not create missing directories, so make sure the target exists.
legend_dir = 'figs/digg_small' if small_cascade else 'figs/digg'
os.makedirs(legend_dir, exist_ok=True)
figlegend.savefig(os.path.join(legend_dir, 'legend.pdf'))


<IPython.core.display.Javascript object>