In [None]:
import numpy as np
import matplotlib.pyplot as plt
import re
import os
import imageio
import glob
import argparse
from pprint import pprint
from warnings import warn
from scipy.stats import zscore

from IPython.display import Image
from IPython.display import HTML
from IPython.display import Markdown
from IPython.display import display

# Architectures under comparison, in the order they are plotted.
ARCHS = ['mcnet', 'prednet', 'drnet']

# Per-architecture directory holding the quantitative (.npz) results,
# laid out as <root>/data=<test set>/model=<training set>/.
RESULTS_ROOTS = dict(
    mcnet=os.path.join('results', 'quantitative', 'MNIST'),
    prednet=os.path.join('prednet', 'mnist_results', 'quantitative', 'MNIST'),
    drnet=os.path.join('drnet', 'results', 'MNIST', 'quantitative'),
)

# Per-architecture directory holding the rendered .gif videos, same layout as above.
IMAGE_ROOTS = dict(
    mcnet=os.path.join('results', 'images', 'MNIST'),
    prednet=os.path.join('prednet', 'mnist_results', 'images', 'MNIST'),
    drnet=os.path.join('drnet', 'results', 'MNIST', 'images'),
)

In [None]:
'''
Display utils
'''

def get_gif_path(arch, dataset_label, model_label, video_idx):
    '''
    Return the path of the rendered gif for one test video of one experiment.

    The file lives at <IMAGE_ROOTS[arch]>/data=<dataset_label>/model=<model_label>/
    and is named both_<video_idx as 4 zero-padded digits>.gif.
    '''
    experiment_dir = os.path.join(IMAGE_ROOTS[arch],
                                  'data=%s' % dataset_label,
                                  'model=%s' % model_label)
    gif_name = 'both_%04d.gif' % video_idx
    return os.path.join(experiment_dir, gif_name)

def show_image_row(image_paths, image_titles):
    '''
    Render images side by side as a one-row HTML table, with a caption above each image.

    Args:
        image_paths: sequence of image paths, used verbatim as each <img src>.
        image_titles: sequence of captions, same length as image_paths; each is
            converted with str() (e.g. numpy integer video indices).
    '''
    assert len(image_paths) == len(image_titles)
    # Generator joins instead of index loops over the Python-2-only xrange().
    title_cells = ''.join('<td style="text-align:center">%s</td>' % str(title)
                          for title in image_titles)
    image_cells = ''.join('<td><img src="%s"></td>' % path for path in image_paths)
    display(HTML('<table><tr>' + title_cells + '</tr><tr>' + image_cells + '</tr></table>'))

def visualize_videos(arch, dataset_label, model_label, sorted_idx, num_videos=5):
    '''
    Display the worst, middle, and best videos of one experiment as rows of gifs.

    Args:
        arch: architecture key into IMAGE_ROOTS.
        dataset_label: test set label.
        model_label: training set label.
        sorted_idx: video indices ordered so the first entries are the worst videos and
            the last entries the best (for a higher-is-better metric such as SSIM, this is
            the ascending order returned by get_sorted_videos_by_metric). Each index is
            used as the caption of its gif.
        num_videos: positive number of videos shown per row (default 5).
    '''
    total = len(sorted_idx)
    # Floor division keeps the slice bounds integral under Python 3 as well as Python 2,
    # and the clamp keeps short lists from producing a negative (wrap-around) start index.
    mid_lower = max(total // 2 - num_videos // 2, 0)
    rows = [
        ('Worst videos:', sorted_idx[:num_videos]),
        ('Middle videos:', sorted_idx[mid_lower:mid_lower + num_videos]),
        ('Best videos:', sorted_idx[-num_videos:]),
    ]
    for heading, video_ids in rows:
        print(heading)
        image_paths = [get_gif_path(arch, dataset_label, model_label, i) for i in video_ids]
        show_image_row(image_paths, video_ids)

In [None]:
def load_experiments():
    '''
    Prepare a dictionary listing all the combinations of architecture, training set, and
    test set found in the result paths. Also prints warnings for inconsistencies such as
    missing results roots, missing or duplicated .npz result files, missing output videos,
    or experiments only found for some architectures.

    Returns:
        dict mapping arch -> {dataset_label: sorted list of model_labels}. Every arch in
        ARCHS is present as a key, possibly with an empty dict.
    '''
    ret = {arch: {} for arch in ARCHS}
    for arch in ARCHS:
        arch_root = RESULTS_ROOTS[arch]
        if not os.path.isdir(arch_root):
            warn('Results root %s does not exist for architecture %s, skipping' % (arch_root, arch))
            continue
        dataset_labels = _list_labels(arch_root, 'data=')
        if not dataset_labels:
            warn('No datasets found for architecture %s, skipping' % arch)
            continue
        for dataset_label in dataset_labels:
            model_labels = _list_labels(os.path.join(arch_root, 'data=%s' % dataset_label), 'model=')
            if not model_labels:
                warn('No models found for architecture %s, data %s. Skipping' % (arch, dataset_label))
            else:
                ret[arch][dataset_label] = model_labels

    # Check that each pair of consecutive architectures has the same experiments
    for first, second in zip(ARCHS, ARCHS[1:]):
        print('Comparing experiments of architecture %s with architecture %s' % (first, second))
        compare_experiment_structure(ret[first], ret[second])
    # Check that result files and images exist
    for arch in ARCHS:
        print('Checking result files for architecture %s' % arch)
        check_result_files(arch, ret[arch])
        print('Checking images for architecture %s' % arch)
        check_images(arch, ret[arch])

    return ret


def _list_labels(parent_dir, prefix):
    '''Return the sorted labels of the subdirectories of parent_dir, with prefix removed.'''
    # Only directories count, so stray files cannot be mistaken for experiments; sorting
    # gives the same plot order on every run (os.listdir order is arbitrary).
    return sorted(name.replace(prefix, '') for name in os.listdir(parent_dir)
                  if os.path.isdir(os.path.join(parent_dir, name)))


def compare_experiment_structure(a, b):
    '''
    Warn about every (dataset, model) experiment that appears in only one of a and b.

    Args:
        a, b: dicts mapping dataset_label -> list of model_labels, as built by load_experiments.
    '''
    for dataset_label, model_labels in a.items():
        for model_label in model_labels:
            if model_label not in b.get(dataset_label, ()):
                warn('Found experiment (data: %s, model: %s) in first dictionary but not second'
                     % (dataset_label, model_label))
    for dataset_label, model_labels in b.items():
        for model_label in model_labels:
            if model_label not in a.get(dataset_label, ()):
                warn('Found experiment (data: %s, model: %s) in second dictionary but not first'
                     % (dataset_label, model_label))


def check_result_files(arch, a):
    '''
    Warn unless each experiment in a has exactly one .npz file in its result directory.

    Args:
        arch: architecture key into RESULTS_ROOTS.
        a: dict mapping dataset_label -> list of model_labels.
    '''
    for dataset_label, model_labels in a.items():
        for model_label in model_labels:
            result_root = os.path.join(RESULTS_ROOTS[arch], 'data=%s' % dataset_label, 'model=%s' % model_label)
            experiment = '(arch: %s, data: %s, model: %s)' % (arch, dataset_label, model_label)
            if not os.path.isdir(result_root):
                warn('Result directory does not exist for %s' % experiment)
                continue
            npz_files = [x for x in os.listdir(result_root) if x.endswith('.npz')]
            if not npz_files:
                warn('No .npz files found for %s' % experiment)
            elif len(npz_files) > 1:
                warn('Too many .npz files found for %s' % experiment)
                
def check_images(arch, a):
    '''
    Warn about experiments in a whose image directory is missing or contains no .gif files.

    Args:
        arch: architecture key into IMAGE_ROOTS.
        a: dict mapping dataset_label -> list of model_labels.
    '''
    # items() instead of the Python-2-only iterkeys() so this runs under either interpreter.
    for dataset_label, model_labels in a.items():
        for model_label in model_labels:
            image_root = os.path.join(IMAGE_ROOTS[arch], 'data=%s' % dataset_label, 'model=%s' % model_label)
            if not os.path.isdir(image_root):
                warn('Image directory does not exist for (arch: %s, data: %s, model: %s)' \
                     % (arch, dataset_label, model_label))
            elif not any(name.endswith('.gif') for name in os.listdir(image_root)):
                warn('No images found for (arch: %s, data: %s, model: %s)' \
                     % (arch, dataset_label, model_label))

                    
def get_results(arch, dataset_label, model_label):
    '''
    Load the single .npz results file of one experiment.

    Args:
        arch: architecture key into RESULTS_ROOTS.
        dataset_label: test set label.
        model_label: training set label.
    Returns:
        The object returned by np.load for the .npz file, indexed by callers with array
        names such as a metric ('ssim') or 'disc_output'. The caller owns the open file
        and should close it, e.g. with a `with` statement.
    Raises:
        RuntimeError: if the experiment directory does not exist, or it does not contain
            exactly one .npz file.
    '''
    results_dir = os.path.join(RESULTS_ROOTS[arch], 'data=%s' % dataset_label, 'model=%s' % model_label)
    if not os.path.isdir(results_dir):
        raise RuntimeError('Path does not exist: %s' % results_dir)
    # A list (not filter()) so len() also works under Python 3.
    results_file_arr = [x for x in os.listdir(results_dir) if x.endswith('.npz')]
    if len(results_file_arr) == 0:
        raise RuntimeError('No results file found in path %s' % results_dir)
    elif len(results_file_arr) > 1:
        raise RuntimeError('Too many results files found in path %s' % results_dir)
    results_file_path = os.path.join(results_dir, results_file_arr[0])
    # NOTE(review): numpy does not memory-map members of .npz archives, so mmap_mode is
    # likely a no-op here; kept for backward compatibility.
    return np.load(results_file_path, mmap_mode='r')


def get_sorted_videos_by_metric(arch, dataset_label, model_label, metric):
    '''
    Rank an experiment's test videos by the sum of the given metric over all time steps.

    Returns:
        (sorted_idx, sorted_aucs): video indices in ascending order of summed metric, and
        the summed metric values in that same order.
    '''
    # Read the array, then close the .npz archive so the many calls made by the plotting
    # cells do not accumulate open file handles.
    with get_results(arch, dataset_label, model_label) as results:
        metric_vals = results[metric]
    # assumes metric_vals is (num_videos, num_time_steps); axis 1 is summed per video
    metric_aucs = np.sum(metric_vals, axis=1)
    sorted_idx = np.argsort(metric_aucs)
    return sorted_idx, metric_aucs[sorted_idx]


def plot_sorted_auc(ax, arch, dataset_label, model_label, metric, label, max_value=1.0, linestyle=None, color=None):
    '''
    Plot one experiment's per-video AUCs, sorted ascending, as a single line on ax.

    Every AUC is divided by max_value, so passing the maximum possible area under the
    curve plots normalized AUC.

    Args:
        ax: matplotlib Axes to draw on.
        arch, dataset_label, model_label: identify the experiment (see get_results).
        metric: name of the metric array in the results file.
        label: legend label for the line.
        max_value: normalizer for the AUC values.
        linestyle, color: forwarded to ax.plot.
    '''
    _, sorted_metric_aucs = get_sorted_videos_by_metric(arch, dataset_label, model_label, metric)
    normalized_aucs = sorted_metric_aucs / float(max_value)
    # range() instead of the Python-2-only xrange() so this runs under either interpreter.
    ax.plot(range(len(normalized_aucs)), normalized_aucs, label=label, linestyle=linestyle, color=color)


def plot_arch_comparison(dataset_label, model_label, metric, max_value=1.0):
    '''
    Compare every architecture on a single train/test combination.

    Draws one ranked-AUC curve per architecture in ARCHS on a shared axis, then shows the
    worst/middle/best videos of each architecture ranked by metric.

    Args:
        dataset_label: test set label.
        model_label: training set label.
        metric: name of the metric array in the results files (e.g. 'ssim').
        max_value: AUC normalizer forwarded to plot_sorted_auc.
    '''
    display(HTML('<hr style="height:2px">'))
    display(HTML('<h1>Train:%s<br>Test:%s</h1>' % (model_label, dataset_label)))

    _, ax = plt.subplots()
    for arch in ARCHS:
        plot_sorted_auc(ax, arch, dataset_label, model_label, metric, arch, max_value=max_value)
    ax.set_xlabel('Instance #')
    ax.set_ylabel('%s AUC' % metric.upper())
    ax.legend(loc='lower left', bbox_to_anchor=(1, 0))
    ax.set_title('Architecture comparison\nTraining set: %s\nTesting set: %s' % (model_label, dataset_label))
    plt.show()

    for arch in ARCHS:
        display(Markdown('## Architecture: %s' % arch))
        ranked_videos, _ = get_sorted_videos_by_metric(arch, dataset_label, model_label, metric)
        visualize_videos(arch, dataset_label, model_label, ranked_videos)


def plot_training_set_comparison(exps, dataset_label, metric, max_value=1.0):
    '''
    Plot performance of all models (of every architecture) that were tested on the given
    dataset on one graph, then show worst/middle/best videos for each of them.

    Args:
        exps: dict from load_experiments (arch -> dataset_label -> list of model_labels).
        dataset_label: test set label.
        metric: name of the metric array in the results files (e.g. 'ssim').
        max_value: AUC normalizer forwarded to plot_sorted_auc.
    '''
    display(HTML('<hr style="height:2px">'))
    display(HTML('<h1>Test:%s</h1>' % (dataset_label)))

    # Collect the (arch, model_label) runs once, so a missing dataset is reported only once
    # and the plot and the video section cover exactly the same runs.
    runs = []
    for arch in ARCHS:
        if dataset_label not in exps[arch]:
            warn('Did not find dataset %s for architecture %s, skipping' % (dataset_label, arch))
            continue
        runs.extend((arch, model_label) for model_label in exps[arch][dataset_label])

    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    for arch, model_label in runs:
        plot_sorted_auc(ax1, arch, dataset_label, model_label, metric,
                        'arch=%s,train=%s' % (arch, model_label), max_value=max_value)

    # Format plot (heading follows the metric actually plotted, not always SSIM)
    display(HTML('<h2>Sorted %s AUC graph</h2>' % metric.upper()))
    ax1.set_xlabel('Instance #')
    ax1.set_ylabel('%s AUC' % metric.upper())
    ax1.legend(loc='lower left', bbox_to_anchor=(1, 0))
    ax1.set_title('Training set comparison\nTesting set: %s' % dataset_label)
    plt.show()

    for arch, model_label in runs:
        # Show videos
        display(HTML('<h2>Architecture: %s<br>Train: %s</h2>' % (arch, model_label)))
        sorted_idx, _ = get_sorted_videos_by_metric(arch, dataset_label, model_label, metric)
        visualize_videos(arch, dataset_label, model_label, sorted_idx)


def summarize_experiment(arch, dataset_label, model_label, metric, max_value=1.0):
    '''
    Show the ranked AUC plot and the worst/middle/best videos for one experiment.

    Args:
        arch: architecture name.
        dataset_label: test set label.
        model_label: training set label.
        metric: name of the metric array in the results file (e.g. 'ssim').
        max_value: AUC normalizer forwarded to plot_sorted_auc.
    Raises:
        RuntimeError: propagated from get_results when the results file is missing or ambiguous.
    '''
    header = '<h2>Architecture: %s<br>Train:%s<br>Test:%s</h2>' % (arch, model_label, dataset_label)
    display(HTML('<hr style="height:2px">'))
    display(HTML(header))

    _, ax = plt.subplots()
    plot_sorted_auc(ax, arch, dataset_label, model_label, metric, None, max_value=max_value)
    ax.set_xlabel('Instance #')
    ax.set_ylabel('%s AUC' % metric.upper())
    plt.show()

    ranked_videos, _ = get_sorted_videos_by_metric(arch, dataset_label, model_label, metric)
    visualize_videos(arch, dataset_label, model_label, ranked_videos)


def summarize_discriminator_output(dataset_label, model_label, use_zscore=False):
    '''
    Plot the discriminator score of MCNet over time for the given training and test set.

    Plots the five lowest-SSIM videos (solid), the five highest-SSIM videos (dotted) and
    the mean over all videos (dashed), sampling every 10th frame, then shows example videos.
    Only valid for long-term video datasets.

    Args:
        dataset_label: test set label; must end with '_long'.
        model_label: training set label.
        use_zscore: if True, standardize the outputs with `zscore` before plotting.
    '''
    assert(dataset_label.endswith('_long'))

    display(HTML('<hr style="height:2px">'))
    display(HTML('<h2>Train:%s<br>Test:%s</h2>' \
                % (model_label, dataset_label)))

    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    # Read the array and close the .npz archive right away so file handles do not
    # accumulate across the many calls made by the cells below.
    with get_results('mcnet', dataset_label, model_label) as results:
        # Skip first column, which is zeros for some reason
        disc_output = results['disc_output'][:, 1:]
    display(Markdown('*Mean output: %.05f *' % np.mean(disc_output)))
    if use_zscore:
        # Resolves at call time to the module-level `zscore` defined further down in this
        # cell (one mean/std over the whole array), which shadows scipy.stats.zscore.
        disc_output = zscore(disc_output)
    # Frame index of each column; the +15 +1 offset presumably accounts for the 15-frame
    # discriminator window and the skipped first column -- TODO confirm.
    frame_indexes = np.arange(disc_output.shape[1]) + 15 + 1
    sorted_idx, _ = get_sorted_videos_by_metric('mcnet', dataset_label, model_label, 'ssim')
    num_videos = len(sorted_idx)

    def plot_video(position, **line_kwargs):
        # position indexes sorted_idx (0 = lowest SSIM); rank 1 is the highest-SSIM video.
        video_idx = sorted_idx[position]
        rank = num_videos - position
        ax1.plot(frame_indexes[::10], disc_output[video_idx, ::10],
                 label='%04d (rank %d)' % (video_idx, rank), **line_kwargs)

    # The bounds keep positions non-negative when there are fewer than five videos.
    for position in range(min(5, num_videos)):
        plot_video(position)
    for position in range(max(num_videos - 5, 0), num_videos):
        plot_video(position, linestyle=':')

    # Plot mean discriminator score over time (mean over videos at each frame)
    mean_output = np.mean(disc_output, axis=0)
    ax1.plot(frame_indexes[::10], mean_output[::10], label='mean', linestyle='--')

    # Format plot
    ax1.legend(loc='lower left', bbox_to_anchor=(1, 0))
    ax1.set_xlim([16, 500])
    ax1.set_xlabel('Frame index (0 = first GT frame)')
    if use_zscore:
        ax1.set_ylabel('Discriminator output (z-scored)')
        ax1.set_title('Discriminator output for long-term videos (z-scored)\nTraining set: %s\nTesting set: %s' \
                     % (model_label, dataset_label) )
    else:
        ax1.set_ylabel('Discriminator output')
        ax1.set_title('Discriminator output for long-term videos\nTraining set: %s\nTesting set: %s' \
                     % (model_label, dataset_label) )
    plt.show()

    # Show examples
    visualize_videos('mcnet', dataset_label, model_label, sorted_idx)

def zscore(a):
    '''
    Standardize a with a single mean and standard deviation computed over all of its
    elements (not per axis).

    NOTE: this module-level definition shadows `scipy.stats.zscore` imported at the top
    of the notebook, whose default standardizes along axis 0.
    '''
    mu, sigma = np.mean(a), np.std(a)
    centered = a - mu
    return centered / sigma


def plot_arbitrary_experiments(archs, dataset_labels, model_labels, metric, ylim=None, linestyles=None, colors=None, max_value=5):
    '''
    Plot the ranked AUC curves of arbitrary experiments on one graph, then show their videos.

    Experiment i is (archs[i], dataset_labels[i], model_labels[i]); all list arguments
    must have the same length.

    Args:
        archs: architecture of each experiment.
        dataset_labels: test set label of each experiment.
        model_labels: training set label of each experiment.
        metric: name of the metric array to rank by (e.g. 'ssim').
        ylim: optional [ymin, ymax] for the AUC axis.
        linestyles: optional matplotlib line style per experiment.
        colors: optional matplotlib color per experiment.
        max_value: AUC normalizer forwarded to plot_sorted_auc; the default of 5 matches
            the 5-frame short-term prediction horizon described in this notebook.
    '''
    num_experiments = len(archs)
    assert (len(dataset_labels) == num_experiments
            and len(model_labels) == num_experiments
            and (linestyles is None or len(linestyles) == num_experiments)
            and (colors is None or len(colors) == num_experiments))
    if linestyles is None:
        linestyles = [None] * num_experiments
    if colors is None:
        colors = [None] * num_experiments
    runs = list(zip(archs, dataset_labels, model_labels))

    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    for (arch, dataset_label, model_label), linestyle, color in zip(runs, linestyles, colors):
        label = 'arch=%s,data=%s,model=%s' % (arch, dataset_label, model_label)
        plot_sorted_auc(ax1, arch, dataset_label, model_label, metric, label,
                        max_value=max_value, linestyle=linestyle, color=color)

    if ylim is not None:
        ax1.set_ylim(ylim)
    ax1.set_xlabel('Instance #')
    ax1.set_ylabel('%s AUC' % metric.upper())
    ax1.legend(loc='lower left', bbox_to_anchor=(1, 0))
    plt.show()

    for arch, dataset_label, model_label in runs:
        display(HTML('<h3>Arch: %s<br>Train: %s<br>Test: %s</h3>' % (arch, model_label, dataset_label)))
        sorted_idx, _ = get_sorted_videos_by_metric(arch, dataset_label, model_label, metric)
        visualize_videos(arch, dataset_label, model_label, sorted_idx)

# Read and validate the experiments on disk (MUST RUN BEFORE PLOTTING)

In [None]:
# Scan the result/image directories, print consistency warnings, and keep the
# arch -> dataset_label -> [model_label] mapping used by the cells below.
experiments = load_experiments()

# Debug individual plots

In [None]:
# Scratch cell for debugging a single plot; swap in one of the commented calls as needed.
plot_arch_comparison(dataset_label='flashing=async_b+num_digits=2',
                     model_label='flashing=async_a+num_digits=2',
                     metric='ssim')
# plot_training_set_comparison(experiments, 'flashing=async_b+num_digits=2', 'ssim')
# summarize_experiment('mcnet', 'translation=on+num_digits=2', 'translation=on', 'ssim')
# summarize_discriminator_output('rotation=no_limit+num_digits=2_long', 'rotation=no_limit', use_zscore=True)

# Short-term evaluation

The cell below shows the results of short-term prediction (i.e. given the first 10 frames, predict the next 5) for each test set. Note that this evaluates over the same time-steps that the model was trained on. Each plot corresponds to one test dataset, and shows the ranked AUC plot for every model that was evaluated on that dataset.

In [None]:
# Short-term test sets: every non-'_long' dataset evaluated by any architecture (not only
# MCNet), sorted so the plots appear in the same order on every run.
test_dataset_labels = sorted(set(label for arch in ARCHS for label in experiments[arch]
                                 if not label.endswith('_long')))
for dataset_label in test_dataset_labels:
    # 5 predicted frames, so dividing by 5 normalizes the SSIM AUC to at most 1.
    plot_training_set_comparison(experiments, dataset_label, 'ssim', max_value=5)

# Long-term evaluation

The cell below shows the results of long-term prediction (i.e. given the first 10 frames, predict the next 490 frames) for each test set. Since the models were only trained to predict the next 5 frames, this explores how well the learned short-term dynamics fit the long-term dynamics. Each plot corresponds to one test dataset, and shows the ranked AUC plot for every model that was evaluated on that dataset.

In [None]:
# Long-term test sets: every '_long' dataset evaluated by any architecture (not only
# MCNet), sorted so the plots appear in the same order on every run.
test_dataset_labels = sorted(set(label for arch in ARCHS for label in experiments[arch]
                                 if label.endswith('_long')))
for dataset_label in test_dataset_labels:
    # NOTE(review): the text above says 490 frames are predicted, but the AUC is normalized
    # by 485 -- confirm this matches the number of scored time steps per video.
    plot_training_set_comparison(experiments, dataset_label, 'ssim', max_value=485)

# Discriminator scores for long-term prediction

The MCNet discriminator takes as input a 15-frame sequence, and outputs a value in [0, 1] corresponding to the confidence that the input is a real (not-generated) sequence (i.e. a real sequence should have confidence 1). To generate the plots below, a sliding window provides input for the discriminator. The x-axis indicates the index of the last frame in the input sequence. The y-axis is the raw discriminator output; with `use_zscore=True` it is instead the z-score of each output relative to all discriminator scores in the video set. Each plot shows the output for the five worst videos (solid lines) and the five best videos (dotted lines), ranked by SSIM and sampled every 10th frame. The dashed curve labeled "mean" shows the average discriminator response across all videos at a given time step.

In [None]:
# Discriminator plots for every long-term MCNet experiment.
for dataset_label, model_labels in experiments['mcnet'].items():
    if dataset_label.endswith('_long'):
        for model_label in model_labels:
            summarize_discriminator_output(dataset_label, model_label, use_zscore=False)

# Discriminator scores for long-term prediction (diagonal experiments only)

Discriminator scores might be less useful for the extrapolation experiments. The plots below only show experiments where the training and testing sets were generated from the same parameters.

In [None]:
# Only "diagonal" experiments: the long-term test set is the training set's parameters + '_long'
# (which already implies the label ends with '_long').
for dataset_label, model_labels in experiments['mcnet'].items():
    for model_label in model_labels:
        if dataset_label == model_label + '_long':
            summarize_discriminator_output(dataset_label, model_label, use_zscore=False)

# Summaries of "diagonal" experiments

These plots summarize the experiments where the training and testing sets were generated from the same parameters.

In [None]:
for arch in ARCHS:
    # Labels used both as a test set and as the training set of a model tested on it
    # ("diagonal" experiments), sorted for a stable order.
    diagonal_labels = sorted(dataset_label for dataset_label, model_labels in experiments[arch].items()
                             if dataset_label in model_labels)

    for label in diagonal_labels:
        try:
            summarize_experiment(arch, label, label, 'ssim', max_value=5)
        except RuntimeError as err:
            # A missing/ambiguous results file should not stop the remaining summaries,
            # but it must not disappear silently either.
            warn('Could not summarize (arch: %s, data: %s, model: %s): %s' % (arch, label, label, err))

# Experiment sandbox

The area below is for plotting arbitrary experiments on the same plot. Content will change frequently.

In [None]:
metric = 'ssim'

# Learning: 1 digit
display(Markdown('# Learning dynamics: 1 digit'))

# Each architecture is trained and tested on the same six 1-digit datasets;
# line style identifies the architecture (solid = MCNet, '--' = PredNet, '-.' = DrNet).
one_digit_labels = ['translation=on', 'rotation=limit_a', 'rotation=no_limit', 'rotation=cw', 'scale=limit_a', 'flashing=sync_b']
archs = [arch for arch in ['mcnet', 'prednet', 'drnet'] for _ in one_digit_labels]
dataset_labels = 3 * one_digit_labels
model_labels = list(dataset_labels)
linestyles = [style for style in [None, '--', '-.'] for _ in one_digit_labels]
plot_arbitrary_experiments(archs, dataset_labels, model_labels, metric, ylim=[0.94, 1.005], linestyles=linestyles)

In [None]:
# Extrapolation: 1 digit
display(Markdown('# Extrapolating dynamics: 1 digit'))

# Two experiments per architecture: rotation (train limit_a, test limit_b) and
# scale (train limit_b, test limit_a); line style identifies the architecture.
archs = [arch for arch in ['mcnet', 'prednet', 'drnet'] for _ in range(2)]
dataset_labels = 3 * ['rotation=limit_b', 'scale=limit_a']
model_labels = 3 * ['rotation=limit_a', 'scale=limit_b']
linestyles = [style for style in [None, '--', '-.'] for _ in range(2)]
plot_arbitrary_experiments(archs, dataset_labels, model_labels, metric, linestyles=linestyles)

In [None]:
# Learning: 2 digits
display(Markdown('# Learning dynamics: 2 digits'))

# All three architectures share this plot, so the subheading names them by line style
# instead of claiming the plot is MCNet only.
display(Markdown('## MCNet (solid), PredNet (dashed), DrNet (dash-dot)'))
two_digit_labels = [x + '+num_digits=2' for x in ['translation=on', 'rotation=limit_a', 'rotation=no_limit', 'rotation=cw', 'scale=limit_a', 'flashing=sync_b', 'flashing=async_b']]
archs = [arch for arch in ['mcnet', 'prednet', 'drnet'] for _ in two_digit_labels]
dataset_labels = 3 * two_digit_labels
model_labels = dataset_labels
linestyles = [style for style in [None, '--', '-.'] for _ in two_digit_labels]
plot_arbitrary_experiments(archs, dataset_labels, model_labels, metric, ylim=[.94, 1.005], linestyles=linestyles)

In [None]:
# Learning: 3 digits
display(Markdown('# Learning dynamics: 3 digits'))

# All three architectures share this plot, so the subheading names them by line style
# instead of claiming the plot is MCNet only.
display(Markdown('## MCNet (solid), PredNet (dashed), DrNet (dash-dot)'))
three_digit_labels = [x + '+num_digits=3' for x in ['translation=on', 'rotation=limit_a', 'rotation=no_limit', 'rotation=cw', 'scale=limit_a', 'flashing=sync_b', 'flashing=async_b']]
archs = [arch for arch in ['mcnet', 'prednet', 'drnet'] for _ in three_digit_labels]
dataset_labels = 3 * three_digit_labels
model_labels = dataset_labels
linestyles = [style for style in [None, '--', '-.'] for _ in three_digit_labels]
plot_arbitrary_experiments(archs, dataset_labels, model_labels, metric, ylim=[.94, 1.005], linestyles=linestyles)

In [None]:
display(Markdown('# Extrapolating dynamics: 1 digit'))

# Both plots use two curves per architecture; line style identifies the architecture.
archs = [arch for arch in ['mcnet', 'prednet', 'drnet'] for _ in range(2)]
linestyles = [style for style in [None, '--', '-.'] for _ in range(2)]

display(Markdown('## Rotation'))
dataset_labels = 3 * ['rotation=limit_b', 'rotation=limit_a']
model_labels = 6 * ['rotation=limit_a']
plot_arbitrary_experiments(archs, dataset_labels, model_labels, metric, linestyles=linestyles)

display(Markdown('## Scale'))
dataset_labels = 6 * ['scale=limit_a']
model_labels = 3 * ['scale=limit_a', 'scale=limit_b']
plot_arbitrary_experiments(archs, dataset_labels, model_labels, metric, linestyles=linestyles)

In [None]:
# display(Markdown('# Extrapolating dynamics: 2 digits'))

In [None]:
display(Markdown('# Extrapolating number of digits: 1 digit -> 2 digits'))

# For each motion type, a 1-digit model (dotted) and a 2-digit model (solid) are both
# tested on the 2-digit dataset; each motion type gets its own color.
for header, arch, motions in [
        ('## MCNet', 'mcnet', ['translation=on', 'rotation=limit_a', 'rotation=no_limit', 'scale=limit_a']),
        ('## PredNet', 'prednet', ['rotation=limit_a', 'rotation=no_limit', 'scale=limit_a']),
        ('## DrNet', 'drnet', ['translation=on', 'rotation=limit_a', 'rotation=no_limit', 'scale=limit_a'])]:
    display(Markdown(header))
    archs = 2 * len(motions) * [arch]
    dataset_labels = [motion + '+num_digits=2' for motion in motions for _ in range(2)]
    model_labels = [label for motion in motions for label in (motion, motion + '+num_digits=2')]
    linestyles = len(motions) * [':', None]
    colors = ['C%d' % i for i in range(len(motions)) for _ in range(2)]
    plot_arbitrary_experiments(archs, dataset_labels, model_labels, metric, ylim=[.88, 1.005], linestyles=linestyles, colors=colors)

In [None]:
display(Markdown('# Extrapolating number of digits: 2 digits -> 1 digit'))

# For each motion type, a 1-digit model (dotted) and a 2-digit model (solid) are both
# tested on the 1-digit dataset; each motion type gets its own color.
for header, arch, motions, ylim in [
        ('## MCNet', 'mcnet', ['translation=on', 'rotation=limit_a', 'rotation=no_limit', 'scale=limit_a'], [.88, 1.005]),
        ('## PredNet', 'prednet', ['rotation=limit_a', 'rotation=no_limit', 'scale=limit_a'], [.93, 1.005]),
        ('## DrNet', 'drnet', ['translation=on', 'rotation=limit_a', 'rotation=no_limit', 'scale=limit_a'], [.88, 1.005])]:
    display(Markdown(header))
    archs = 2 * len(motions) * [arch]
    dataset_labels = [motion for motion in motions for _ in range(2)]
    model_labels = [label for motion in motions for label in (motion, motion + '+num_digits=2')]
    linestyles = len(motions) * [':', None]
    colors = ['C%d' % i for i in range(len(motions)) for _ in range(2)]
    plot_arbitrary_experiments(archs, dataset_labels, model_labels, metric, ylim=ylim, linestyles=linestyles, colors=colors)

In [None]:
display(Markdown('# Extrapolating identity: one digit'))

# Each motion type contributes two curves in one color: a model trained and tested on the
# full dataset (solid), and a model trained on digit subset a and tested on subset b (dotted).
for header, arch, motions in [
        ('## MCnet translation/scale/flashing', 'mcnet', ['translation=on', 'scale=limit_a', 'flashing=sync_b']),
        ('## MCNet rotation', 'mcnet', ['rotation=limit_a', 'rotation=no_limit', 'rotation=cw']),
        ('## PredNet rotation/scale', 'prednet', ['scale=limit_a', 'rotation=limit_a', 'rotation=no_limit', 'rotation=cw']),
        ('## DrNet translation/scale/flashing', 'drnet', ['translation=on', 'scale=limit_a', 'flashing=sync_b']),
        ('## DrNet rotation', 'drnet', ['rotation=limit_a', 'rotation=no_limit', 'rotation=cw'])]:
    display(Markdown(header))
    dataset_labels = [label for motion in motions for label in (motion, motion + '+image=label_subset_b')]
    model_labels = [x.replace('label_subset_b', 'label_subset_a') for x in dataset_labels]
    archs = len(dataset_labels) * [arch]
    linestyles = len(motions) * [None, ':']
    colors = ['C%d' % i for i in range(len(motions)) for _ in range(2)]
    plot_arbitrary_experiments(archs, dataset_labels, model_labels, metric, ylim=[.94, 1.005], linestyles=linestyles, colors=colors)

In [None]:
display(Markdown('# Extrapolating identity: two digits'))

# Same layout as the one-digit identity cell, on the '+num_digits=2' variant of each motion:
# solid = trained and tested on the full dataset, dotted = trained on subset a, tested on subset b.
for header, arch, motions in [
        ('## MCnet translation/scale/flashing', 'mcnet', ['translation=on', 'scale=limit_a', 'flashing=sync_b']),
        ('## MCNet rotation', 'mcnet', ['rotation=limit_a', 'rotation=no_limit', 'rotation=cw']),
        ('## PredNet rotation/scale', 'prednet', ['scale=limit_a', 'rotation=limit_a', 'rotation=no_limit', 'rotation=cw']),
        ('## DrNet translation/scale/flashing', 'drnet', ['translation=on', 'scale=limit_a', 'flashing=sync_b']),
        ('## DrNet rotation', 'drnet', ['rotation=limit_a', 'rotation=no_limit', 'rotation=cw'])]:
    display(Markdown(header))
    bases = [motion + '+num_digits=2' for motion in motions]
    dataset_labels = [label for base in bases for label in (base, base + '+image=label_subset_b')]
    model_labels = [x.replace('label_subset_b', 'label_subset_a') for x in dataset_labels]
    archs = len(dataset_labels) * [arch]
    linestyles = len(motions) * [None, ':']
    colors = ['C%d' % i for i in range(len(motions)) for _ in range(2)]
    plot_arbitrary_experiments(archs, dataset_labels, model_labels, metric, ylim=[.94, 1.005], linestyles=linestyles, colors=colors)