In [None]:
import numpy as np
import matplotlib.pyplot as plt
import re
import os
import imageio
import glob
import argparse
from pprint import pprint
from warnings import warn
from scipy.stats import zscore

from IPython.display import Image
from IPython.display import HTML
from IPython.display import Markdown
from IPython.display import display

# Architectures compared throughout this notebook.
ARCHS = ['mcnet', 'prednet']

# Per-architecture root directory holding the quantitative results, organized as
# <root>/data=<test set>/model=<training set>/ (see load_experiments / get_results).
RESULTS_ROOTS = dict(
    mcnet=os.path.join('results', 'quantitative', 'MNIST'),
    prednet=os.path.join('prednet', 'mnist_results', 'quantitative', 'MNIST'),
)

# Per-architecture root directory holding the rendered GIFs, with the same
# data=<...>/model=<...> layout (see get_gif_path / check_images).
IMAGE_ROOTS = dict(
    mcnet=os.path.join('results', 'images', 'MNIST'),
    prednet=os.path.join('prednet', 'mnist_results', 'images', 'MNIST'),
)

In [None]:
'''
Display utils
'''

def get_gif_path(arch, dataset_label, model_label, video_idx):
    '''
    Return the path of the GIF stored for test video number `video_idx` of one experiment.

    Layout: <IMAGE_ROOTS[arch]>/data=<dataset_label>/model=<model_label>/both_<NNNN>.gif,
    where NNNN is `video_idx` zero-padded to four digits.
    '''
    image_root = IMAGE_ROOTS[arch]
    experiment_dir = os.path.join(image_root,
                                  'data=%s' % dataset_label,
                                  'model=%s' % model_label)
    gif_name = 'both_%04d.gif' % video_idx
    return os.path.join(experiment_dir, gif_name)

def show_image_row(image_paths, image_titles):
    '''
    Render a two-row HTML table: a row of titles above a row of images.

    image_paths: sequence of image paths/URLs, each used as the `src` of an <img> tag.
    image_titles: sequence of titles (str() is applied to each); must have the same
        length as image_paths.
    Raises AssertionError if the two sequences differ in length.
    '''
    assert len(image_paths) == len(image_titles)
    # Iterate the sequences directly (no xrange) so this also runs on Python 3.
    title_cells = ''.join('<td style="text-align:center">%s</td>' % str(title)
                          for title in image_titles)
    image_cells = ''.join('<td><img src="%s"></td>' % (path,) for path in image_paths)
    html = '<table><tr>' + title_cells + '</tr><tr>' + image_cells + '</tr></table>'

    display(HTML(html))

def visualize_videos(arch, dataset_label, model_label, sorted_idx, num_videos=5):
    '''
    Show three rows of GIFs for one experiment: worst, middle, and best videos.

    sorted_idx: video indexes in ascending order of the ranking metric (as returned by
        get_sorted_videos_by_metric). The first entries are labelled "Worst" and the last
        "Best", which matches metrics where higher is better (e.g. SSIM).
    num_videos: number of videos per row (default 5). Rows are shorter if there are fewer
        videos than that.
    '''
    num_total = len(sorted_idx)
    # Use integer division (//) so the slice bounds stay ints on Python 3. Clamp at 0 so a
    # short list does not produce a negative start index, which would wrap to the end.
    mid_lower = max(0, num_total // 2 - num_videos // 2)
    mid_upper = mid_lower + num_videos
    best_lower = max(0, num_total - num_videos)

    rows = [
        ('Worst videos:', sorted_idx[:num_videos]),
        ('Middle videos:', sorted_idx[mid_lower:mid_upper]),
        ('Best videos:', sorted_idx[best_lower:]),
    ]
    for heading, video_ids in rows:
        print(heading)
        image_paths = [get_gif_path(arch, dataset_label, model_label, i) for i in video_ids]
        show_image_row(image_paths, video_ids)

In [None]:
def _list_labels(directory, prefix):
    '''
    Return the sorted names of the subdirectories of `directory` that start with `prefix`,
    with the prefix removed. Plain files (e.g. .DS_Store) are ignored.
    '''
    return sorted(name[len(prefix):] for name in os.listdir(directory)
                  if name.startswith(prefix) and os.path.isdir(os.path.join(directory, name)))


def load_experiments():
    '''
    Prepare a dictionary listing all the combinations of architecture, training set, and test set found
    in the result paths. Also prints warnings for inconsistencies like missing output videos or experi-
    ments only found for one architecture.

    Returns a dict: architecture -> {dataset (test set) label -> sorted list of model (training set) labels}.
    Architectures whose results root is missing or empty map to an empty dict (with a warning).
    '''
    ret = {arch: {} for arch in ARCHS}
    for arch in ARCHS:
        results_root = RESULTS_ROOTS[arch]
        if not os.path.isdir(results_root):
            warn('Results root %s does not exist for architecture %s, skipping' % (results_root, arch))
            continue
        dataset_labels = _list_labels(results_root, 'data=')
        if len(dataset_labels) == 0:
            warn('No datasets found for architecture %s, skipping' % arch)
            continue
        for dataset_label in dataset_labels:
            model_labels = _list_labels(os.path.join(results_root, 'data=%s' % dataset_label), 'model=')
            if len(model_labels) == 0:
                warn('No models found for architecture %s, data %s. Skipping' % (arch, dataset_label))
            else:
                ret[arch][dataset_label] = model_labels

    # Check that every architecture has the same experiments as the first one
    for other_arch in ARCHS[1:]:
        print('Comparing architecture %s with architecture %s' % (ARCHS[0], other_arch))
        compare_experiment_structure(ret[ARCHS[0]], ret[other_arch])
    # Check that result files and images exist
    for arch, arch_dict in ret.items():
        print('Checking result files for architecture %s' % arch)
        check_result_files(arch, arch_dict)
        print('Checking images for architecture %s' % arch)
        check_images(arch, arch_dict)

    return ret


def compare_experiment_structure(a, b):
    '''
    Check that the experiments in dictionary a match those in b, and report inconsistencies.

    a, b: dicts mapping dataset label -> list of model labels.
    Each mismatch is reported with warn(); nothing is returned or raised.
    '''
    def report_missing(src, dst, message):
        # Warn for every (dataset, model) pair present in src but absent from dst.
        for dataset_label, model_labels in src.items():
            dst_models = dst.get(dataset_label, ())
            for model_label in model_labels:
                if model_label not in dst_models:
                    warn(message % (dataset_label, model_label))

    report_missing(a, b, 'Found experiment (data: %s, model: %s) in first dictionary but not second')
    report_missing(b, a, 'Found experiment (data: %s, model: %s) in second dictionary but not first')


def check_result_files(arch, a):
    '''
    Check that exactly one npz file exists in each experiment in dictionary a.

    a: dict mapping dataset label -> list of model labels for architecture `arch`.
    Problems are reported with warn(); nothing is returned or raised.
    '''
    for dataset_label, model_labels in a.items():
        for model_label in model_labels:
            result_root = os.path.join(RESULTS_ROOTS[arch], 'data=%s' % dataset_label, 'model=%s' % model_label)
            if not os.path.isdir(result_root):
                warn('Result directory does not exist for (arch: %s, data: %s, model: %s)' \
                     % (arch, dataset_label, model_label))
                continue
            npz_files = [x for x in os.listdir(result_root) if x.endswith('.npz')]
            if len(npz_files) == 0:
                warn('No .npz files found for (arch: %s, data: %s, model: %s)' \
                     % (arch, dataset_label, model_label))
            elif len(npz_files) > 1:
                warn('Too many .npz files found for (arch: %s, data: %s, model: %s)' \
                     % (arch, dataset_label, model_label))
                
def check_images(arch, a):
    '''
    Check that the images for experiments in dictionary a exist.

    a: dict mapping dataset label -> list of model labels for architecture `arch`.
    Warns (does not raise) when an experiment's directory under IMAGE_ROOTS[arch] is missing
    or contains no .gif files.
    '''
    # Iterate with items() rather than iterkeys(), which does not exist on Python 3.
    for dataset_label, model_labels in a.items():
        for model_label in model_labels:
            image_root = os.path.join(IMAGE_ROOTS[arch], 'data=%s' % dataset_label, 'model=%s' % model_label)
            if not os.path.isdir(image_root):
                warn('Image directory does not exist for (arch: %s, data: %s, model: %s)' \
                     % (arch, dataset_label, model_label))
            elif not any(x.endswith('.gif') for x in os.listdir(image_root)):
                warn('No images found for (arch: %s, data: %s, model: %s)' \
                     % (arch, dataset_label, model_label))

                    
def get_results(arch, dataset_label, model_label):
    '''
    Retrieve the results archive for one experiment.

    Looks in <RESULTS_ROOTS[arch]>/data=<dataset_label>/model=<model_label>, which must contain
    exactly one .npz file, and returns the result of np.load on it (indexable by array name,
    e.g. results['ssim']).

    Raises RuntimeError if the directory is missing or does not hold exactly one .npz file.
    '''
    results_dir = os.path.join(RESULTS_ROOTS[arch], 'data=%s' % dataset_label, 'model=%s' % model_label)
    if not os.path.isdir(results_dir):
        raise RuntimeError('Path does not exist: %s' % results_dir)
    # Use a list comprehension, not filter(): on Python 3 filter() returns an iterator with no len().
    results_file_arr = [x for x in os.listdir(results_dir) if x.endswith('.npz')]
    if len(results_file_arr) == 0:
        raise RuntimeError('No results file found in path %s' % results_dir)
    elif len(results_file_arr) > 1:
        raise RuntimeError('Too many results files found in path %s' % results_dir)
    results_file_path = os.path.join(results_dir, results_file_arr[0])
    # NOTE: np.load only memory-maps .npy files; for .npz archives mmap_mode has no effect and
    # each array is read into memory when it is indexed.
    return np.load(results_file_path, mmap_mode='r')


def get_sorted_videos_by_metric(arch, dataset_label, model_label, metric):
    '''
    Rank the videos of one experiment by their metric summed over all time steps.

    Reads the array named `metric` from the experiment's results (assumed to be
    videos x time steps, since it is summed over axis 1) and returns a pair:
    (video indexes in ascending order of the summed value, the summed values in that order).
    '''
    per_step_values = get_results(arch, dataset_label, model_label)[metric]
    per_video_totals = per_step_values.sum(axis=1)
    order = per_video_totals.argsort()
    return order, per_video_totals[order]


def plot_sorted_auc(ax, arch, dataset_label, model_label, metric, label, max_value=1.0):
    '''
    Plot ranked AUC plot. max_value specifies the maximum area under the plot if plotting normalized
    AUC is desired.

    ax: matplotlib Axes to draw on.
    label: legend label for the curve (may be None).
    Each video's metric is summed over time, the sums are sorted ascending, divided by
    max_value, and plotted against their rank (0, 1, 2, ...).
    '''
    _, sorted_metric_aucs = get_sorted_videos_by_metric(arch, dataset_label, model_label, metric)
    # Use np.arange for the x positions; xrange does not exist on Python 3.
    instance_numbers = np.arange(len(sorted_metric_aucs))
    ax.plot(instance_numbers, sorted_metric_aucs / float(max_value), label=label)


def plot_arch_comparison(dataset_label, model_label, metric, max_value=1.0):
    '''
    Compare every architecture in ARCHS on one (training set, test set) pair.

    Draws one ranked-AUC curve per architecture on a shared axis (normalized by max_value),
    then shows the worst/middle/best example videos for each architecture.
    '''
    separator = HTML('<hr style="height:2px">')
    heading = HTML('<h1>Train:%s<br>Test:%s</h1>' % (model_label, dataset_label))
    display(separator)
    display(heading)

    axes = plt.figure().add_subplot(111)
    for arch_name in ARCHS:
        plot_sorted_auc(axes, arch_name, dataset_label, model_label, metric, arch_name,
                        max_value=max_value)
    axes.set_xlabel('Instance #')
    axes.set_ylabel('%s AUC' % metric.upper())
    axes.legend(loc='lower left', bbox_to_anchor=(1, 0))
    plot_title = 'Architecture comparison\nTraining set: %s\nTesting set: %s' % (model_label, dataset_label)
    axes.set_title(plot_title)
    plt.show()

    for arch_name in ARCHS:
        display(Markdown('## Architecture: %s' % arch_name))
        ranking, _ = get_sorted_videos_by_metric(arch_name, dataset_label, model_label, metric)
        visualize_videos(arch_name, dataset_label, model_label, ranking)


def plot_training_set_comparison(exps, dataset_label, metric, max_value=1.0):
    '''
    Plot performance of all models that were tested on the given dataset on one graph.

    exps: dict as returned by load_experiments(): arch -> {dataset label -> [model labels]}.
    metric: name of the per-video metric array in the results file (e.g. 'ssim').
    max_value: divisor used to normalize the summed-metric AUC values.

    Architectures with no results for dataset_label are skipped, with one warning each.
    After the plot, example videos are shown for every plotted (architecture, model) pair.
    '''
    display(HTML('<hr style="height:2px">'))
    display(HTML('<h1>Test:%s</h1>' % (dataset_label)))

    fig = plt.figure()
    ax1 = fig.add_subplot(111)

    # Resolve the (arch, model_labels) pairs once, so the plot and the video section use the
    # same experiments and each missing dataset is warned about only once.
    arch_models = []
    for arch in ARCHS:
        arch_exps = exps.get(arch, {})
        if dataset_label not in arch_exps:
            warn('Did not find dataset %s for architecture %s, skipping' % (dataset_label, arch))
            continue
        arch_models.append((arch, arch_exps[dataset_label]))

    # Plot
    for arch, model_labels in arch_models:
        for model_label in model_labels:
            plot_sorted_auc(ax1, arch, dataset_label, model_label, metric,
                            'arch=%s,train=%s' % (arch, model_label), max_value=max_value)

    # Format plot (the header names the metric actually plotted)
    display(HTML('<h2>Sorted %s AUC graph</h2>' % metric.upper()))
    ax1.set_xlabel('Instance #')
    ax1.set_ylabel('%s AUC' % metric.upper())
    ax1.legend(loc='lower left', bbox_to_anchor=(1, 0))
    ax1.set_title('Training set comparison\nTesting set: %s' % dataset_label)
    plt.show()

    # Show videos
    for arch, model_labels in arch_models:
        for model_label in model_labels:
            display(HTML('<h2>Architecture: %s<br>Train: %s</h2>' % (arch, model_label)))
            sorted_idx, _ = get_sorted_videos_by_metric(arch, dataset_label, model_label, metric)
            visualize_videos(arch, dataset_label, model_label, sorted_idx)


def summarize_experiment(arch, dataset_label, model_label, metric, max_value=1.0):
    '''
    Summarize a single experiment (one architecture, training set, and test set).

    Draws the ranked AUC curve of `metric` (normalized by max_value, no legend label) and
    then the worst/middle/best example videos.
    '''
    header = '<h2>Architecture: %s<br>Train:%s<br>Test:%s</h2>' % (arch, model_label, dataset_label)
    display(HTML('<hr style="height:2px">'))
    display(HTML(header))

    axes = plt.figure().add_subplot(111)
    plot_sorted_auc(axes, arch, dataset_label, model_label, metric, None, max_value=max_value)
    axes.set_xlabel('Instance #')
    axes.set_ylabel('%s AUC' % metric.upper())
    plt.show()

    ranking = get_sorted_videos_by_metric(arch, dataset_label, model_label, metric)[0]
    visualize_videos(arch, dataset_label, model_label, ranking)


def summarize_discriminator_output(dataset_label, model_label, use_zscore=False, num_videos=3):
    '''
    Plot the discriminator score of MCNet for the given training and test set.
    Only valid for long-term video datasets (asserts that dataset_label ends with '_long').

    At every 10th frame, the plot shows the discriminator output for the `num_videos` videos
    with the lowest summed SSIM (solid lines), the `num_videos` videos with the highest
    (dotted lines), and the mean over all videos (dashed line). Example GIFs follow the plot.

    use_zscore: if True, pass the discriminator outputs through zscore() before plotting.
    num_videos: number of lowest/highest-ranked videos to plot (default 3).
    '''
    assert(dataset_label.endswith('_long'))

    display(HTML('<hr style="height:2px">'))
    display(HTML('<h2>Train:%s<br>Test:%s</h2>' \
                % (model_label, dataset_label)))

    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    results = get_results('mcnet', dataset_label, model_label)
    # Skip first column, which is zeros for some reason
    disc_output = results['disc_output'][:, 1:]
    if use_zscore:
        disc_output = zscore(disc_output)
    # Define zero-indexed frame indexes for discriminator output.
    # NOTE(review): the +15 +1 offset presumably accounts for the 15-frame discriminator window
    # and the skipped first column -- confirm against the evaluation script.
    frame_indexes = np.arange(disc_output.shape[1]) + 15 + 1
    sorted_idx, _ = get_sorted_videos_by_metric('mcnet', dataset_label, model_label, 'ssim')

    frame_step = 10  # plot every 10th frame to keep the curves readable
    num_total = len(sorted_idx)
    # Clamp both ranges so that having fewer than 2 * num_videos videos neither raises
    # IndexError nor plots the same video twice.
    num_worst = min(num_videos, num_total)
    best_start = max(num_total - num_videos, num_worst)

    def plot_video(inv_rank, **line_kwargs):
        # inv_rank indexes sorted_idx (ascending SSIM); rank 1 is the highest-SSIM video.
        video_idx = sorted_idx[inv_rank]
        rank = num_total - inv_rank
        ax1.plot(frame_indexes[::frame_step], disc_output[video_idx, ::frame_step],
                 label='%04d (rank %d)' % (video_idx, rank), **line_kwargs)

    for inv_rank in range(num_worst):
        plot_video(inv_rank)
    for inv_rank in range(best_start, num_total):
        plot_video(inv_rank, linestyle=':')

    # Plot mean discriminator score over time
    mean_output = np.mean(disc_output, axis=0)
    ax1.plot(frame_indexes[::frame_step], mean_output[::frame_step], label='mean', linestyle='--')

    # Format plot
    ax1.legend(loc='lower left', bbox_to_anchor=(1, 0))
    ax1.set_xlim([16, 500])
    ax1.set_xlabel('Frame index (0 = first GT frame)')
    if use_zscore:
        ax1.set_ylabel('Discriminator output (z-scored)')
        ax1.set_title('Discriminator output for long-term videos (z-scored)\nTraining set: %s\nTesting set: %s' \
                     % (model_label, dataset_label) )
    else:
        ax1.set_ylabel('Discriminator output')
        ax1.set_title('Discriminator output for long-term videos\nTraining set: %s\nTesting set: %s' \
                     % (model_label, dataset_label) )
    plt.show()

    # Show examples
    visualize_videos('mcnet', dataset_label, model_label, sorted_idx)

def zscore(a):
    '''
    Standardize `a` using one mean and one (population) standard deviation computed over
    ALL of its elements, not per axis.

    Note: this definition shadows `scipy.stats.zscore` imported at the top of the notebook,
    which by default standardizes along axis 0 instead.
    '''
    centered = a - np.mean(a)
    return centered / np.std(a)

# Read and validate the experiments on disk (MUST RUN BEFORE PLOTTING)

In [None]:
# Scan RESULTS_ROOTS for (architecture, test set, training set) combinations and warn about
# inconsistencies (e.g. missing image directories/GIFs); the plotting cells below read `experiments`.
experiments = load_experiments()

# Debug individual plots

In [None]:
# Example calls for inspecting a single plot while debugging; uncomment one to run it.
# plot_arch_comparison('flashing=async_b+num_digits=2', 'flashing=async_a+num_digits=2', 'ssim')
# plot_training_set_comparison(experiments, 'flashing=async_b+num_digits=2', 'ssim')
# summarize_experiment('mcnet', 'translation=on+num_digits=2', 'translation=on', 'ssim')
# summarize_discriminator_output('rotation=no_limit+num_digits=2_long', 'rotation=no_limit', use_zscore=True)

# Short-term evaluation

The cell below shows the results of short-term prediction (i.e. given the first 10 frames, predict the next 5) for each test set. Note that this evaluates over the same time-steps that the model was trained on. Each plot corresponds to one test dataset, and shows the ranked AUC plot for every model that was evaluated on that dataset.

In [None]:
# Short-term test sets: MCNet results directories (data=<label>) whose label does not end in
# '_long'. Sorted so the plots appear in the same order on every machine (os.listdir order is arbitrary).
test_dataset_labels = sorted(x.replace('data=', '') for x in os.listdir(RESULTS_ROOTS['mcnet'])
                             if x.startswith('data=') and not x.endswith('_long'))
for dataset_label in test_dataset_labels:
    # Normalize by 5, the number of predicted frames (assuming per-frame SSIM <= 1).
    plot_training_set_comparison(experiments, dataset_label, 'ssim', max_value=5)

# Long-term evaluation

The cell below shows the results of long-term prediction (i.e. given the first 10 frames, predict the next 490 frames) for each test set. Since the models were only trained to predict the next 5 frames, this explores how well the learned short-term dynamics fit the long-term dynamics. Each plot corresponds to one test dataset, and shows the ranked AUC plot for every model that was evaluated on that dataset.

In [None]:
# Long-term test sets: MCNet results directories (data=<label>) whose label ends in '_long'.
# Sorted so the plots appear in the same order on every machine (os.listdir order is arbitrary).
test_dataset_labels = sorted(x.replace('data=', '') for x in os.listdir(RESULTS_ROOTS['mcnet'])
                             if x.startswith('data=') and x.endswith('_long'))
for dataset_label in test_dataset_labels:
    # Normalize the summed SSIM by 485.
    # NOTE(review): the markdown above says 490 frames are predicted; confirm that 485 is the
    # intended maximum AUC.
    plot_training_set_comparison(experiments, dataset_label, 'ssim', max_value=485)

# Discriminator scores for long-term prediction

The MCNet discriminator takes as input a 15-frame sequence and outputs a value in [0, 1] corresponding to its confidence that the input is a real (not generated) sequence (i.e. a real sequence should have confidence 1). To generate the plots below, a sliding window provides input to the discriminator. The x-axis indicates the index of the last frame in the input sequence, and the y-axis is the z-score of the current output relative to all discriminator scores in the video set. Each plot shows the output for the three worst videos (solid lines) and the three best videos (dotted lines), ranked by SSIM AUC, along with the mean output over all videos (dashed line).

In [None]:
# Discriminator plots for every long-term test set MCNet was evaluated on, once per training set.
for dataset_label, model_labels in experiments['mcnet'].items():
    if dataset_label.endswith('_long'):
        for model_label in model_labels:
            summarize_discriminator_output(dataset_label, model_label, use_zscore=True)