## Some useful code for sharing across our notebooks.

Created by Basile Van Hoorick, Fall 2020.

To use these helpers from another notebook, run the following in that notebook's first code cell:

`%run FF_common.ipynb`

In [1]:
# Library imports.
import copy
import numpy as np
import matplotlib.pyplot as plt
import sys
import torch
import torchvision
import torchvision.datasets

# Make the repository modules imported below (DataGenerator, FFBrainNet, ...) importable;
# presumably they live in the parent directory of this notebook.
sys.path.append('../')

# matplotlib 3.6 renamed the bundled seaborn style to 'seaborn-v0_8' and 3.8 removed the
# old 'seaborn' name, so prefer the new name and fall back for older matplotlib versions.
_SEABORN_STYLE = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_SEABORN_STYLE)
prop_cycle = plt.rcParams['axes.prop_cycle']
default_colors = prop_cycle.by_key()['color']  # NOTE: Only six colors with the seaborn style.

# Repository imports.
from DataGenerator import random_halfspace_data, layer_relu_data
from FFBrainNet import FFBrainNet
from FFLocalNet import FFLocalNet
from FFLocalPlasticityRules.TableRule_PrePost import TableRule_PrePost
from FFLocalPlasticityRules.TableRule_PrePostCount import TableRule_PrePostCount
from FFLocalPlasticityRules.TableRule_PrePostPercent import TableRule_PrePostPercent
from FFLocalPlasticityRules.TableRule_PostCount import TableRule_PostCount
from FFLocalPlasticityRules.OneBetaANNRule_PrePost import OneBetaANNRule_PrePost
from FFLocalPlasticityRules.OneBetaANNRule_PrePostAll import OneBetaANNRule_PrePostAll
from FFLocalPlasticityRules.OneBetaANNRule_PostAll import OneBetaANNRule_PostAll
from FFLocalPlasticityRules.AllBetasANNRule_PostAll import AllBetasANNRule_PostAll
from LocalNetBase import Options, UpdateScheme
from network import LocalNet
from train import metalearn_rules, train_downstream

In [None]:
def quick_get_data(which, dim, N=2000, split=0.75, relu_k=8):
    '''
    Quick, get some data! Load or generate a dataset and split it into train / test parts.

    Args:
        which: Dataset name, case-insensitive: 'halfspace', 'relu' or 'mnist'.
        dim: Dimensionality passed to the synthetic data generators (ignored for MNIST).
        N: Sample count for the halfspace data; also determines the train / test split
            point int(N * split) for the synthetic datasets. Ignored for MNIST.
        split: Fraction of N that goes into the training set (synthetic datasets only).
        relu_k: Passed as the third argument to layer_relu_data (ReLU dataset only).

    Returns:
        (X_train, y_train, X_test, y_test).

    Raises:
        ValueError: If `which` is not one of the supported dataset names.
    '''
    which = which.lower()

    if which == 'mnist':
        # NOTE: Arguments dim, N and split are ignored here; MNIST's own train/test split is used.
        mnist_train = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=None)
        mnist_test = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=None)
        print('mnist_train:', len(mnist_train))
        print('mnist_test:', len(mnist_test))
        # Convert every image to an array scaled by 1/255, then flatten it to one row per sample.
        X_train = np.array([np.array(image) for image, _ in mnist_train]) / 255.0
        y_train = np.array([label for _, label in mnist_train])
        X_test = np.array([np.array(image) for image, _ in mnist_test]) / 255.0
        y_test = np.array([label for _, label in mnist_test])
        X_train = X_train.reshape(X_train.shape[0], -1)
        X_test = X_test.reshape(X_test.shape[0], -1)
        return X_train, y_train, X_test, y_test

    if which == 'halfspace':
        X, y = random_halfspace_data(dim=dim, n=N)

    elif which == 'relu':
        # Previously referenced the undefined name `n_up` here (NameError); the requested
        # dimensionality is `dim`.
        # NOTE(review): 3*N samples are generated while the split point is still int(N*split),
        # so the test set ends up much larger than the training set -- confirm this is intended.
        X, y = layer_relu_data(dim, N*3, relu_k)
        print('Count of 0:', np.sum(y == 0), ' Count of 1:', np.sum(y == 1))

    else:
        raise ValueError('Unknown or unused dataset: ' + which)

    num_train = int(N*split)
    return X[:num_train], y[:num_train], X[num_train:], y[num_train:]

In [1]:
def evaluate_brain(brain_fact, n,
                   dataset_up='halfspace', dataset_down='halfspace', downstream_backprop=False,
                   num_runs=1, num_rule_epochs=50, num_epochs_upstream=1, num_epochs_downstream=1,
                   min_upstream_acc=0.7):
    '''
    Meta-learn plasticity rules on ONE network instance, then keep training that very
    same instance downstream (optionally on a newly fetched dataset).

    Args:
        brain_fact: Zero-argument factory; every call builds a new network under test.
        n: Dimensionality forwarded to quick_get_data for both the upstream and downstream data.
        dataset_up: Upstream dataset name (halfspace / relu / mnist).
        dataset_down: Downstream dataset name, or None to reuse the dataset instance
            from the final (successful) meta-learning attempt.
        downstream_backprop: Allow backprop for the direct GD layers downstream?
            Recommended False if dataset_down is None, since those layers are assumed
            to have been trained upstream already.
        num_runs: Number of independent repetitions.
        num_rule_epochs, num_epochs_upstream: Forwarded to metalearn_rules.
        num_epochs_downstream: Forwarded to train_downstream as num_epochs.
        min_upstream_acc: Meta-learning is repeated with a newly built network and
            re-fetched data until the last upstream test accuracy reaches this value.
            There is no retry limit.

    Returns:
        (multi_stats_up, multi_stats_down): two lists holding one stats entry per run.
    '''
    stats_up_per_run, stats_down_per_run = [], []

    for run_idx in range(1, num_runs + 1):
        print()
        print(f'Run {run_idx} / {num_runs}...')

        # Upstream: retry with a fresh random initialization until accurate enough.
        while True:
            brain = brain_fact()  # NOTE: Some initializations are unlucky.
            X_train, y_train, X_test, y_test = quick_get_data(dataset_up, n)
            print('Meta-learning on ' + dataset_up + '...')
            stats_up = metalearn_rules(
                X_train, y_train, brain,
                num_rule_epochs=num_rule_epochs, num_epochs=num_epochs_upstream,
                batch_size=100, learn_rate=1e-2,
                X_test=X_test, y_test=y_test, verbose=False)
            final_up_acc = stats_up[2][-1]
            if final_up_acc >= min_upstream_acc:
                break
            print(f'Final upstream test acc {final_up_acc:.4f} not high enough, retrying...')

        # Downstream: the same network keeps its learned rules, so nothing is transferred;
        # only the dataset may change.
        if dataset_down is None:
            print('Training SAME brain instance on the same dataset instance...')
        else:
            X_train, y_train, X_test, y_test = quick_get_data(dataset_down, n)
            print('Training SAME brain instance on ' + dataset_down + '...')
        stats_down = train_downstream(
            X_train, y_train, brain, num_epochs=num_epochs_downstream,
            batch_size=100, vanilla=False, learn_rate=5e-3,
            X_test=X_test, y_test=y_test, verbose=False,
            stats_interval=500, disable_backprop=not downstream_backprop)

        stats_up_per_run.append(stats_up)
        stats_down_per_run.append(stats_down)
        print()

    return (stats_up_per_run, stats_down_per_run)

In [None]:
def evaluate_up_down(brain_up_fact, brain_down_fact, n_up, n_down,
                     dataset_up='halfspace', dataset_down='halfspace', downstream_backprop=False,
                     num_runs=1, num_rule_epochs=50, num_epochs_upstream=1, num_epochs_downstream=1,
                     get_model=False, min_upstream_acc=0.7):
    '''
    Evaluates a PAIR of brains on the quality of meta-learning and rule transfer.
    Rules are meta-learned on one network, copied into a newly built second network,
    and that second network is then trained with the transferred rules.

    Args:
        brain_up_fact: Calling this will create a new instance of the network to meta-learn.
        brain_down_fact: Calling this will create a new instance of the network to train.
        n_up: Dimensionality passed to quick_get_data for the upstream dataset.
        n_down: Dimensionality for the downstream dataset; must be None exactly when
            dataset_down is None.
        dataset_up: Upstream dataset class (halfspace / relu / mnist).
        dataset_down: If None, keep the same dataset instance. Otherwise, downstream dataset class.
        downstream_backprop: Use backprop for the direct GD layers downstream?
            Recommended True, since the downstream weights will remain randomly initialized otherwise.
        num_runs: Number of independent repetitions.
        num_rule_epochs, num_epochs_upstream: Forwarded to metalearn_rules.
        num_epochs_downstream: Forwarded to train_downstream as num_epochs.
        get_model: If True, also return the downstream network of the LAST run
            (None when num_runs == 0).
        min_upstream_acc: Keep meta-learning until we find a good random initialization with
            this final test accuracy. There is no retry limit.

    Returns:
        (multi_stats_up, multi_stats_down) or ((multi_stats_up, multi_stats_down), brain_down).
        Both lists have length num_runs.

    Raises:
        ValueError: If exactly one of dataset_down and n_down is None.
    '''
    if (dataset_down is None) != (n_down is None):
        raise ValueError('The nullness of dataset_down does not agree with that of n_down.')

    multi_stats_up = []
    multi_stats_down = []
    # Defined up front so get_model=True with num_runs=0 returns None instead of
    # raising UnboundLocalError.
    brain_down = None

    for run in range(num_runs):
        print()
        print(f'Run {run+1} / {num_runs}...')

        # Upstream.
        success = False
        while not success:
            brain_up = brain_up_fact()  # NOTE: Some initializations are unlucky.

            X_train, y_train, X_test, y_test = quick_get_data(dataset_up, n_up)
            print('Meta-learning on ' + dataset_up + '...')
            stats_up = metalearn_rules(
                X_train, y_train, brain_up, num_rule_epochs=num_rule_epochs,
                num_epochs=num_epochs_upstream, batch_size=100, learn_rate=1e-2,
                X_test=X_test, y_test=y_test, verbose=False)

            success = (stats_up[2][-1] >= min_upstream_acc)
            if not success:
                print(f'Final upstream test acc {stats_up[2][-1]:.4f} not high enough, retrying...')

        # Transfer rules.
        brain_down = brain_down_fact()
        if isinstance(brain_down, FFLocalNet):
            # FF-ANN.
            brain_down.copy_rules(brain_up)
        else:
            # RNN: prefer the getter/setter API.
            try:
                if brain_down.options.use_graph_rule:
                    brain_down.set_rnn_rule(brain_up.get_rnn_rule())
                if brain_down.options.use_output_rule:
                    brain_down.set_output_rule(brain_up.get_output_rule())
            except Exception as exc:
                # Was a bare `except:`, which also swallowed KeyboardInterrupt / SystemExit.
                # NOTE(review): presumably hit when a network class lacks these accessors.
                print('FALLBACK: direct assignment of rules...')
                print(f'=> Reason: {type(exc).__name__}: {exc}')
                if downstream_backprop:
                    print('=> WARNING: Rules might still be updated by GD this way')
                brain_down.rnn_rule = brain_up.rnn_rule
                brain_down.output_rule = brain_up.output_rule

        # Downstream. The nullness check above guarantees n_down is set whenever dataset_down is.
        if dataset_down is not None:
            X_train, y_train, X_test, y_test = quick_get_data(dataset_down, n_down)
            print('Training NEW brain instance on ' + dataset_down + '...')
        else:
            print('Training NEW brain instance on the same dataset instance...')
        stats_down = train_downstream(
            X_train, y_train, brain_down, num_epochs=num_epochs_downstream,
            batch_size=100, vanilla=False, learn_rate=1e-2,
            X_test=X_test, y_test=y_test, verbose=False,
            stats_interval=300, disable_backprop=not downstream_backprop)

        # Save this run.
        multi_stats_up.append(stats_up)
        multi_stats_down.append(stats_down)

        print()

    if get_model:
        return (multi_stats_up, multi_stats_down), brain_down
    return (multi_stats_up, multi_stats_down)

In [None]:
def evaluate_generalization(brain_up_fact, brain_down_fact, n_up, n_down, **kwargs):
    '''
    Legacy wrapper around evaluate_up_down that always enables downstream backprop.

    Rules are meta-learned on one network, transferred to a different network, and that
    network is trained downstream. Any `downstream_backprop` supplied in kwargs is
    overridden with True; every other keyword argument is forwarded unchanged.
    '''
    forwarded_kwargs = dict(kwargs, downstream_backprop=True)
    return evaluate_up_down(brain_up_fact, brain_down_fact, n_up, n_down, **forwarded_kwargs)

In [None]:
def convert_multi_stats_uncertainty(multi_stats):
    '''
    Merge and summarize stats from multiple runs into one tuple that tracks
    means and standard deviations over time.

    Args:
        multi_stats: Non-empty sequence of per-run stats tuples. Element 0 holds the losses,
            1 the train accuracies, 2 the test accuracies and 3 the sample counts. The curves
            of all runs must have equal length so they can be stacked into one array.

    Returns:
        (losses_mean, losses_std, train_acc_mean, train_acc_std,
         test_acc_mean, test_acc_std, sample_counts, other_stats).
        Means and (population, ddof=0) standard deviations are taken across runs;
        sample_counts is taken from the first run and other_stats is always None.

    Raises:
        ValueError: If multi_stats is empty.
    '''
    if len(multi_stats) == 0:
        raise ValueError('convert_multi_stats_uncertainty needs stats from at least one run.')

    # Stack each metric across runs; axis 0 is the 'run' dimension.
    all_losses = np.array([s[0] for s in multi_stats])
    all_train_acc = np.array([s[1] for s in multi_stats])
    all_test_acc = np.array([s[2] for s in multi_stats])

    # Summarize by calculating things across the 'run' dimension.
    losses_mean = all_losses.mean(axis=0)
    train_acc_mean = all_train_acc.mean(axis=0)
    test_acc_mean = all_test_acc.mean(axis=0)
    losses_std = all_losses.std(axis=0)
    train_acc_std = all_train_acc.std(axis=0)
    test_acc_std = all_test_acc.std(axis=0)

    sample_counts = multi_stats[0][3]  # Assumed identical across runs; not verified.
    other_stats = None  # Per-run extra stats are not combined.

    return (losses_mean, losses_std, train_acc_mean, train_acc_std,
            test_acc_mean, test_acc_std, sample_counts, other_stats)

In [None]:
def plot_curves(agg_stats_up, agg_stats_down, title_up, title_down, save_name='figs/default', no_downstream_loss=False):
    '''
    Plot upstream (optional) and downstream (required) learning curves of ONE model.

    Each stats argument is either a single-run tuple
    (losses, train_acc, test_acc, sample_counts, other_stats) or an aggregated 8-tuple from
    convert_multi_stats_uncertainty. For an aggregated tuple, the shaded areas show
    +/- one standard deviation across runs. The two arguments may use different layouts.

    Args:
        agg_stats_up: Upstream stats tuple, or None to hide the left subplot.
        agg_stats_down: Downstream stats tuple.
        title_up, title_down: Titles of the left / right subplot.
        save_name: Output path without extension; '.pdf' and '.png' files are written.
            Missing parent directories are created.
        no_downstream_loss: If True, the downstream loss curve is not plotted.
    '''
    import os

    def _unpack(agg_stats):
        # Normalize both layouts to
        # (losses, losses_std, train_acc, train_acc_std, test_acc, test_acc_std, sample_counts);
        # the *_std entries are None for a single-run (5-element) tuple.
        if len(agg_stats) == 5:
            losses, train_acc, test_acc, sample_counts, _ = agg_stats
            return losses, None, train_acc, None, test_acc, None, sample_counts
        (losses, losses_std, train_acc, train_acc_std,
         test_acc, test_acc_std, sample_counts, _) = agg_stats
        return losses, losses_std, train_acc, train_acc_std, test_acc, test_acc_std, sample_counts

    def _draw(axis, xs, mean, std, label, color):
        # Plot one curve and, when a std is available, shade mean +/- std.
        axis.plot(xs, mean, label=label, color=color)
        if std is not None:
            axis.fill_between(xs, mean - std, mean + std, alpha=0.3, facecolor=color)

    (plas_losses, plas_losses_std, plas_train_acc, plas_train_acc_std,
     plas_test_acc, plas_test_acc_std, plas_sample_counts) = _unpack(agg_stats_down)
    if agg_stats_up is not None:
        (meta_losses, meta_losses_std, meta_train_acc, meta_train_acc_std,
         meta_test_acc, meta_test_acc_std, meta_sample_counts) = _unpack(agg_stats_up)

    fig, ax = plt.subplots(1, 2, figsize=(10, 4))

    # Left plot = upstream.
    if agg_stats_up is not None:
        _draw(ax[0], meta_sample_counts, meta_losses, meta_losses_std, 'loss', default_colors[2])
        _draw(ax[0], meta_sample_counts, meta_train_acc, meta_train_acc_std, 'train', default_colors[1])
        _draw(ax[0], meta_sample_counts, meta_test_acc, meta_test_acc_std, 'test', default_colors[0])
        ax[0].set_xlabel('Cumulative number of training samples')
        ax[0].set_ylabel('Accuracy / Loss')
        ax[0].set_xlim(0, meta_sample_counts[-1])
        ax[0].set_title(title_up)
        ax[0].legend()
    else:
        ax[0].set_visible(False)

    # Right plot = downstream.
    if not no_downstream_loss:
        # The first downstream loss entry is left out of the plot.
        # NOTE(review): presumably it is logged before any training step -- confirm.
        _draw(ax[1], plas_sample_counts[1:], plas_losses[1:],
              None if plas_losses_std is None else plas_losses_std[1:],
              'loss', default_colors[2])
    _draw(ax[1], plas_sample_counts, plas_train_acc, plas_train_acc_std, 'train', default_colors[1])
    _draw(ax[1], plas_sample_counts, plas_test_acc, plas_test_acc_std, 'test', default_colors[0])
    ax[1].set_xlabel('Cumulative number of training samples')
    ax[1].set_ylabel('Accuracy / Loss')
    ax[1].set_xlim(0, plas_sample_counts[-1])
    ax[1].set_title(title_down)
    ax[1].legend()

    # Store and display graph; create the output directory first so savefig cannot fail on it.
    save_dir = os.path.dirname(save_name)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    print('Saving figure to:', save_name)
    fig.savefig(save_name + '.pdf', dpi=192)
    fig.savefig(save_name + '.png', dpi=192)
    plt.show()

    # Print essential stats (means when several runs were aggregated).
    print('Mean essential stats across all runs:')
    if agg_stats_up is not None:
        print(f'Last upstream loss: {meta_losses[-1]:.4f}')
        print(f'Last upstream train accuracy: {meta_train_acc[-1]:.4f}')
        print(f'Last upstream test accuracy: {meta_test_acc[-1]:.4f}')
    print(f'Last downstream loss: {plas_losses[-1]:.4f}')
    print(f'Last downstream train accuracy: {plas_train_acc[-1]:.4f}')
    print(f'Last downstream test accuracy: {plas_test_acc[-1]:.4f}')
    print()

In [None]:
def get_colors_styles(labels):
    '''
    Pick a line color and a line style for every curve label.

    Returns:
        (colors, styles): `colors` is an independent deep copy of the module-level
        default color list (its length does not depend on `labels`), and `styles`
        contains 'solid' once per label.

    NOTE: Please feel free to modify this method to improve your figures. For example,
    with eight labels, related models could share one color and be told apart by
    assigning 'dashed' / 'dotted' to selected entries of `styles`.
    '''
    num_curves = len(labels)
    palette = copy.deepcopy(default_colors)
    return palette, num_curves * ['solid']

In [None]:
def plot_compare_models(all_stats_up, all_stats_down, labels, title_up, title_down, save_name='figs/default'):
    '''
    Plot upstream (optional) and downstream (required) curves of only one
    metric (test accuracy) across MANY models, one curve per model.

    Each stats tuple is either a single-run tuple (losses, train_acc, test_acc,
    sample_counts, other_stats) or an aggregated 8-tuple from convert_multi_stats_uncertainty.
    For the latter, +/- one standard deviation is shaded.

    Args:
        all_stats_up: One upstream stats tuple per model; an entry may be None to skip that
            model's upstream curve. If every entry is None, the left subplot is hidden.
        all_stats_down: One downstream stats tuple per model.
        labels: Legend label per model.
        title_up, title_down: Titles of the left / right subplot.
        save_name: Output path without extension; '.pdf' and '.png' files are written.
            Missing parent directories are created.

    Raises:
        AssertionError: If the three lists differ in length.
        ValueError: If there are more models than colors returned by get_colors_styles.
    '''
    import os

    num_models = len(all_stats_up)
    assert(num_models == len(all_stats_down) and num_models == len(labels))

    colors, styles = get_colors_styles(labels)
    # Check against the colors actually available (only six with the seaborn style).
    if len(labels) > len(colors):
        raise ValueError("Too many plots at once (we don't have that many colors)")

    def _unpack_test_acc(agg_stats):
        # Returns (sample_counts, test_acc, test_acc_std); the std is None for a 5-element tuple.
        if len(agg_stats) == 5:
            _, _, test_acc, sample_counts, _ = agg_stats
            return sample_counts, test_acc, None
        _, _, _, _, test_acc, test_acc_std, sample_counts, _ = agg_stats
        return sample_counts, test_acc, test_acc_std

    fig, ax = plt.subplots(1, 2, figsize=(10, 4))

    for i in range(num_models):
        # Upstream goes to the left panel (if present), downstream to the right one.
        panels = [(ax[1], all_stats_down[i])]
        if all_stats_up[i] is not None:
            panels.insert(0, (ax[0], all_stats_up[i]))
        for axis, agg_stats in panels:
            sample_counts, test_acc, test_acc_std = _unpack_test_acc(agg_stats)
            axis.plot(sample_counts, test_acc, label=labels[i], color=colors[i], linestyle=styles[i])
            if test_acc_std is not None:
                axis.fill_between(sample_counts, test_acc - test_acc_std, test_acc + test_acc_std,
                                  alpha=0.3, facecolor=colors[i], linestyle=styles[i])

    if any(stats is not None for stats in all_stats_up):
        ax[0].set_xlabel('Cumulative number of training samples')
        ax[0].set_ylabel('Test accuracy')
        ax[0].set_title(title_up)
        ax[0].legend()
    else:
        # Same as plot_curves: hide the upstream panel when there is nothing to show.
        ax[0].set_visible(False)
    ax[1].set_xlabel('Cumulative number of training samples')
    ax[1].set_ylabel('Test accuracy')
    ax[1].set_title(title_down)
    ax[1].legend()

    # Store and display graph; create the output directory first so savefig cannot fail on it.
    save_dir = os.path.dirname(save_name)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    print('Saving figure to:', save_name)
    fig.savefig(save_name + '.pdf', dpi=192)
    fig.savefig(save_name + '.png', dpi=192)
    plt.show()