In [None]:
import numpy as np
import matplotlib.pyplot as plt
import re
import os
import imageio
import glob
import argparse
from IPython.display import Image
from IPython.display import HTML
from IPython.display import Markdown

# Base output directory of each architecture (relative path).
_ARCH_BASE_DIRS = {
    'mcnet': 'results',
    'prednet': os.path.join('prednet', 'mnist_results'),
}

# Per-architecture root of the quantitative results (.npz files) on MNIST.
RESULTS_ROOTS = dict(
    (arch, os.path.join(base_dir, 'quantitative', 'MNIST'))
    for arch, base_dir in _ARCH_BASE_DIRS.items()
)

# Per-architecture root of the image outputs on MNIST.
IMAGE_ROOTS = dict(
    (arch, os.path.join(base_dir, 'images', 'MNIST'))
    for arch, base_dir in _ARCH_BASE_DIRS.items()
)

In [None]:
'''
Display utils
'''

def show_gif(gif_path):
    '''Display the GIF located at `gif_path` inline through an HTML <img> tag.'''
    img_tag = '<img src="%s">' % gif_path
    display(HTML(img_tag))

def show_image_row(image_paths, image_titles):
    '''Display images side by side in a two-row HTML table.

    The first table row holds one centered title cell per entry of
    `image_titles` (each converted with `str`); the second row holds one
    <img> cell per entry of `image_paths`.

    Args:
        image_paths: Sequence of image paths/URLs used as <img> sources.
        image_titles: Sequence of titles, same length as `image_paths`.

    Raises:
        AssertionError: If the two sequences differ in length.
    '''
    assert len(image_paths) == len(image_titles), (
        'Got %d image paths but %d titles' % (len(image_paths), len(image_titles)))

    # Iterate the sequences directly (no xrange) so this runs on Python 2 and 3.
    title_cells = ''.join(
        '<td style="text-align:center">%s</td>' % str(title)
        for title in image_titles)
    image_cells = ''.join(
        '<td><img src="%s"></td>' % (path,)
        for path in image_paths)

    html = '<table><tr>' + title_cells + '</tr><tr>' + image_cells + '</tr></table>'
    display(HTML(html))

In [None]:
def get_results(arch, dataset_label, model_label):
    '''Load the single .npz results file of one (arch, test set, model) run.

    The file is looked up in
    RESULTS_ROOTS[arch]/data=<dataset_label>/model=<model_label>/.

    Args:
        arch: Key into RESULTS_ROOTS.
        dataset_label: Label of the test dataset.
        model_label: Label of the trained model.

    Returns:
        Whatever `np.load` returns for the .npz file found in that directory.

    Raises:
        KeyError: If `arch` is not a key of RESULTS_ROOTS.
        RuntimeError: If the directory does not exist, or if it does not
            contain exactly one file ending in '.npz'.
    '''
    results_dir = os.path.join(RESULTS_ROOTS[arch], 'data=%s' % dataset_label, 'model=%s' % model_label)
    if not os.path.isdir(results_dir):
        raise RuntimeError('Path does not exist: %s' % results_dir)

    # Build a real list: on Python 3 filter() is lazy, so len() and indexing
    # on its result would raise TypeError.
    results_file_arr = sorted(x for x in os.listdir(results_dir) if x.endswith('.npz'))
    if len(results_file_arr) == 0:
        raise RuntimeError('No results file found in path %s' % results_dir)
    elif len(results_file_arr) > 1:
        raise RuntimeError('Too many results files found in path %s' % results_dir)

    results_file_path = os.path.join(results_dir, results_file_arr[0])
    return np.load(results_file_path)


def get_sorted_videos_by_metric(arch, dataset_label, model_label, metric):
    '''Rank the entries of one run by their summed metric values.

    Loads the run's results via `get_results`, sums `results[metric]` along
    axis 1 (the "AUC" used throughout this notebook) and orders the sums
    ascending.
    NOTE(review): presumably the metric array is (num_videos, num_timesteps),
    so each sum is one score per video -- confirm against the evaluation code.

    Returns:
        A tuple `(order, sorted_sums)`: the indices that sort the per-row
        sums in ascending order, and the sums in that order.
    '''
    per_row_values = get_results(arch, dataset_label, model_label)[metric]
    auc_per_row = np.sum(per_row_values, axis=1)
    ascending_order = np.argsort(auc_per_row)
    sorted_aucs = auc_per_row[ascending_order]
    return ascending_order, sorted_aucs


def plot_sorted_auc(ax, arch, dataset_label, model_label, metric, label):
    '''Plot the ascending per-instance metric sums ("AUC") of one run on `ax`.

    Args:
        ax: Object with a matplotlib-style `plot(x, y, label=...)` method.
        arch, dataset_label, model_label, metric: Forwarded to
            `get_sorted_videos_by_metric` to select the run and metric.
        label: Legend label passed to `ax.plot` (may be None).

    Raises:
        Whatever `get_sorted_videos_by_metric` raises for this run.
    '''
    _, sorted_metric_aucs = get_sorted_videos_by_metric(arch, dataset_label, model_label, metric)
    # np.arange instead of xrange: xrange does not exist on Python 3.
    ax.plot(np.arange(len(sorted_metric_aucs)), sorted_metric_aucs, label=label)

In [None]:
# fig = plt.figure()
# ax1 = fig.add_subplot(111)
# plot_sorted_auc(ax1, 'prednet', 'flashing=async_b+num_digits=2', 'flashing=async_b+num_digits=2', 'ssim', None)
# plot_sorted_auc(ax1, 'prednet', 'flashing=async_b+num_digits=2', 'flashing=async_a+num_digits=2', 'ssim', None)
# ax1.set_xlim([0, 1000])
# plt.show()

In [None]:
def plot_arch_comparison(dataset_label, model_label, metric):
    '''Plot the sorted metric sums of every architecture on one figure.

    One curve is drawn per key of RESULTS_ROOTS (labelled with the key) for
    the given test dataset / trained model pair, then the figure is shown.

    Args:
        dataset_label: Label of the test dataset.
        model_label: Label of the trained model.
        metric: Name of the metric to plot; upper-cased in the y-axis label.

    Raises:
        Whatever `plot_sorted_auc` raises, e.g. when an architecture has no
        results for this dataset/model pair.
    '''
    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    # Iterate the dict in sorted key order: dict.iterkeys() is Python 2 only,
    # and sorting keeps curve order/colors stable between runs.
    for arch in sorted(RESULTS_ROOTS):
        plot_sorted_auc(ax1, arch, dataset_label, model_label, metric, arch)
    ax1.set_xlabel('Instance #')
    ax1.set_ylabel('%s AUC' % metric.upper())
    ax1.legend(loc='best')
    ax1.set_title('Architecture comparison\nTraining set: %s\nTesting set: %s' % (model_label, dataset_label))
    plt.show()

# plot_arch_comparison('flashing=async_b+num_digits=2', 'flashing=async_a+num_digits=2', 'ssim')

In [None]:
def plot_training_set_comparison(dataset_label, metric, ylim=(4, 5.1)):
    '''Compare every available model on one test dataset in a single figure.

    For each architecture in RESULTS_ROOTS whose `data=<dataset_label>`
    directory exists, every `model=<label>` entry in it is plotted with
    `plot_sorted_auc`, labelled 'arch=<arch>,train=<label>'. Runs for which
    plotting raises RuntimeError are reported and skipped.

    Args:
        dataset_label: Label of the test dataset.
        metric: Name of the metric to plot; upper-cased in the y-axis label.
        ylim: (low, high) limits for the y axis, or None to keep matplotlib's
            automatic limits. The default reproduces the previously
            hard-coded range (NOTE(review): looks tuned for SSIM sums --
            confirm before using it with other metrics).
    '''
    fig = plt.figure()
    ax1 = fig.add_subplot(111)

    model_prefix = 'model='
    # Sorted iteration replaces the Python 2-only dict.iterkeys() and keeps the
    # legend order stable.
    for arch in sorted(RESULTS_ROOTS):
        dataset_results_root = os.path.join(RESULTS_ROOTS[arch], 'data=%s' % dataset_label)
        if not os.path.isdir(dataset_results_root):
            continue
        # Strip only the leading 'model=' (str.replace would also remove the
        # text from inside a label) and ignore entries without the prefix.
        model_labels = sorted(
            entry[len(model_prefix):]
            for entry in os.listdir(dataset_results_root)
            if entry.startswith(model_prefix))

        for model_label in model_labels:
            try:
                plot_sorted_auc(ax1, arch, dataset_label, model_label, metric, 'arch=%s,train=%s' % (arch, model_label))
            except RuntimeError as err:
                # Best effort: keep plotting the other runs, but say what was dropped.
                print('Skipping %s: %s' % ((arch, dataset_label, model_label), err))

    # Format plot
    ax1.set_xlabel('Instance #')
    ax1.set_ylabel('%s AUC' % metric.upper())
    if ylim is not None:
        ax1.set_ylim(list(ylim))
    ax1.legend(loc='lower left', bbox_to_anchor=(1, 0))
    ax1.set_title('Training set comparison\nTesting set: %s' % dataset_label)
    plt.show()

# plot_training_set_comparison('flashing=async_b+num_digits=2', 'ssim')

In [None]:
# Plot the training-set comparison (SSIM) for every test dataset that has a
# 'data=<label>' entry under the mcnet results root.
# Only the leading 'data=' is stripped (str.replace would also remove the text
# from inside a label), entries without the prefix are ignored, and the labels
# are sorted so the figures appear in a stable order.
all_test_dataset_labels = sorted(
    x[len('data='):]
    for x in os.listdir(RESULTS_ROOTS['mcnet'])
    if x.startswith('data='))
for dataset_label in all_test_dataset_labels:
    plot_training_set_comparison(dataset_label, 'ssim')

In [None]:
def summarize_experiment(arch, dataset_label, model_label, metric, num_examples=5):
    '''Show a header, the sorted metric curve and example GIFs for one run.

    The run's entries are ranked once with `get_sorted_videos_by_metric`
    (ascending summed metric). The lowest-ranked entries are shown under
    'Worst videos:' and the highest-ranked under 'Best videos:'.
    NOTE(review): this labelling assumes a larger metric value is better, as
    for SSIM -- confirm before using a metric where lower is better.

    GIFs are read from
    IMAGE_ROOTS[arch]/data=<dataset_label>/model=<model_label>/both_<idx>.gif
    with the index zero-padded to four digits.

    Args:
        arch: Key into IMAGE_ROOTS (and the results lookup).
        dataset_label: Label of the test dataset.
        model_label: Label of the trained model.
        metric: Name of the metric; upper-cased in the y-axis label.
        num_examples: How many worst and how many best entries to show.

    Raises:
        Whatever `get_sorted_videos_by_metric` raises (e.g. RuntimeError when
        the run has no results). This happens before anything is displayed,
        so a skipped run leaves no header or empty figure behind.
    '''
    # Load and rank first: the results are read only once, and a missing run
    # fails before any output is produced.
    sorted_idx, sorted_metric_aucs = get_sorted_videos_by_metric(arch, dataset_label, model_label, metric)

    display(HTML('<hr style="height:2px">'))
    display(HTML('<span style="font-weight:bold;font-size:20px">Architecture: %s<br>Train:%s<br>Test:%s</span>' \
                % (arch, model_label, dataset_label)))

    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    ax1.plot(np.arange(len(sorted_metric_aucs)), sorted_metric_aucs)
    ax1.set_xlabel('Instance #')
    ax1.set_ylabel('%s AUC' % metric.upper())
    # No legend: the single curve is unlabelled, so a legend would be empty.
    plt.show()

    image_dir = os.path.join(IMAGE_ROOTS[arch], 'data=%s' % dataset_label, 'model=%s' % model_label)

    def _gif_paths(indices):
        # Path of the 'both_XXXX.gif' file for each index.
        return [os.path.join(image_dir, 'both_%04d.gif' % i) for i in indices]

    worst_idx = sorted_idx[:num_examples]
    # Explicit start index so num_examples=0 yields no entries (a plain [-0:]
    # slice would return everything).
    best_idx = sorted_idx[max(len(sorted_idx) - num_examples, 0):]

    print('Worst videos:')
    show_image_row(_gif_paths(worst_idx), worst_idx)
    print('Best videos:')
    show_image_row(_gif_paths(best_idx), best_idx)

# summarize_experiment('mcnet', 'translation=on+num_digits=2', 'translation=on', 'ssim')

In [None]:
# Summarize (SSIM) every architecture / test dataset / model combination found
# on disk under RESULTS_ROOTS.
# Sorted iteration replaces the Python 2-only dict.iterkeys() and keeps the
# output order stable; only the leading 'data=' / 'model=' prefix is stripped,
# and entries without it are ignored.
for arch in sorted(RESULTS_ROOTS):
    arch_results_root = RESULTS_ROOTS[arch]
    if not os.path.isdir(arch_results_root):
        print('Skipping %s: results root not found: %s' % (arch, arch_results_root))
        continue
    dataset_labels = sorted(
        x[len('data='):] for x in os.listdir(arch_results_root) if x.startswith('data='))
    for dataset_label in dataset_labels:
        dataset_results_root = os.path.join(arch_results_root, 'data=%s' % dataset_label)
        if not os.path.isdir(dataset_results_root):
            # A stray file named 'data=...' cannot contain model directories.
            continue
        model_labels = sorted(
            x[len('model='):] for x in os.listdir(dataset_results_root) if x.startswith('model='))
        for model_label in model_labels:
            try:
                summarize_experiment(arch, dataset_label, model_label, 'ssim')
            except RuntimeError:
                print('Skipping %s' % ((arch, dataset_label, model_label),))
