# Dataset generalization

## For every considered model, meta-learn plasticity rules on a 16-dimensional halfspace dataset, then transfer the learned rules to a new network and train / test it on MNIST.

Created by Basile Van Hoorick, Fall 2020.

In [27]:
%run FF_common.ipynb

In [28]:
# IMPORTANT: From here on, GD is applied directly to the input weights, while
# the hidden and output layers are updated through plasticity rules.

# Flags that are identical in the upstream and the downstream phase.
_shared_opts = dict(
    gd_input=True,
    use_graph_rule=True,
    use_output_rule=True,
    gd_output=False,
)

# Upstream (meta-learning): the graph and output rules are meta-trainable.
opts_up = Options(gd_graph_rule=True, gd_output_rule=True, **_shared_opts)

# Downstream (transfer): the rules are not meta-trainable anymore!
opts_down = Options(gd_graph_rule=False, gd_output_rule=False, **_shared_opts)

# Update scheme shared by every brain in this notebook: cross-entropy loss
# (not MSE), updates on all samples (not only misclassified ones), all edges.
scheme = UpdateScheme(
    cross_entropy_loss=True,
    mse_loss=False,
    update_misclassified_only=False,
    update_all_edges=True,
)

# Feed-forward brain config.
# The *_up values describe the upstream (meta-learning) task, the *_down
# values the downstream task the learned rules are transferred to.
n_up = 16  # Input layer size for meta-learning (16-dimensional halfspace data).
n_down = 28 * 28  # Input layer size for desired task training (784; presumably flattened 28x28 MNIST images -- confirm).
m_up = 2  # Output layer size for meta-learning.
m_down = 10  # Output layer size for desired task training.
l = 4  # Number of hidden layers.
w = 64  # Width of hidden layers.
p = 0.5  # Connectivity probability.
cap = 32  # Number of nodes firing per layer.

# Training config.
# All of these are forwarded as keyword arguments to evaluate_generalization().
num_runs = 10  # Presumably the number of independent repetitions (log prints "Run 1 / 10").
num_rule_epochs = 50  # Presumably the number of meta-learning epochs for the rules -- confirm.
num_epochs_upstream = 1
num_epochs_downstream = 1  # Also used for the vanilla baseline's train_downstream() call.
dataset_up = 'halfspace'  # Dataset used for meta-learning.
dataset_down = 'mnist'  # Dataset used for downstream training / testing.

In [29]:
# Instantiate brain factories.
# Every factory is a zero-argument callable that builds a fresh brain. The
# *_up variants use the upstream sizes / options (rules are meta-trainable),
# the *_down variants use the downstream sizes / options (rules are fixed).


def _build_ff_brain(rule_cls, upstream, num_separate_rules=None):
    """Build an FFLocalNet whose hidden and output rules are `rule_cls` instances.

    rule_cls: rule class to instantiate (e.g. TableRule_PrePost).
    upstream: True -> (n_up, m_up, opts_up); False -> (n_down, m_down, opts_down).
    num_separate_rules: None -> a single rule instance is passed as hl_rules;
        otherwise a list with that many independent instances is passed.

    Globals are read at call time, exactly as in a bare lambda.
    """
    if upstream:
        n_in, m_out, options = n_up, m_up, opts_up
    else:
        n_in, m_out, options = n_down, m_down, opts_down

    # Hidden-layer rule(s) are created before the output rule, in that order.
    if num_separate_rules is None:
        hidden_rules = rule_cls()
    else:
        hidden_rules = [rule_cls() for _ in range(num_separate_rules)]

    return FFLocalNet(
        n_in, m_out, l, w, p, cap, hl_rules=hidden_rules,
        output_rule=rule_cls(), options=options, update_scheme=scheme)


# Original RNN model.
# NOTE(review): the literals 64 and 32 equal w and cap; the meaning of the
# trailing 3 is not visible here -- confirm against LocalNet's signature.
brain_rnn_up_fact = lambda: LocalNet(
    n_up, m_up, 64, p, 32, 3, options=opts_up, update_scheme=scheme)
brain_rnn_down_fact = lambda: LocalNet(
    n_down, m_down, 64, p, 32, 3, options=opts_down, update_scheme=scheme)

# "Universal" feed-forward variants: one rule object for all hidden layers.
brain_prepost_up_fact = lambda: _build_ff_brain(TableRule_PrePost, upstream=True)
brain_prepost_down_fact = lambda: _build_ff_brain(TableRule_PrePost, upstream=False)
brain_prepostcount_up_fact = lambda: _build_ff_brain(TableRule_PrePostCount, upstream=True)
brain_prepostcount_down_fact = lambda: _build_ff_brain(TableRule_PrePostCount, upstream=False)

# "NOT universal" variants: a list of three independent rule objects.
# NOTE(review): three presumably corresponds to l - 1 hidden-to-hidden
# connections -- confirm against FFLocalNet before changing l.
brain_prepost_nonuni_up_fact = lambda: _build_ff_brain(
    TableRule_PrePost, upstream=True, num_separate_rules=3)
brain_prepost_nonuni_down_fact = lambda: _build_ff_brain(
    TableRule_PrePost, upstream=False, num_separate_rules=3)
brain_prepostcount_nonuni_up_fact = lambda: _build_ff_brain(
    TableRule_PrePostCount, upstream=True, num_separate_rules=3)
brain_prepostcount_nonuni_down_fact = lambda: _build_ff_brain(
    TableRule_PrePostCount, upstream=False, num_separate_rules=3)

In [None]:
# Evaluate models.
# Each call takes an (upstream, downstream) pair of brain factories plus the
# two input sizes, and yields a pair of stats objects (upstream, downstream).

# Keyword arguments common to every evaluate_generalization() call below.
_eval_kwargs = dict(
    dataset_up=dataset_up,
    dataset_down=dataset_down,
    num_runs=num_runs,
    num_rule_epochs=num_rule_epochs,
    num_epochs_upstream=num_epochs_upstream,
    num_epochs_downstream=num_epochs_downstream,
)

print('==== Original RNN (very different from all the rest) ====')
stats_rnn_up, stats_rnn_down = evaluate_generalization(
    brain_rnn_up_fact, brain_rnn_down_fact, n_up, n_down, **_eval_kwargs)

print('==== Interpretation: PrePost (universal) ====')
stats_prepost_up, stats_prepost_down = evaluate_generalization(
    brain_prepost_up_fact, brain_prepost_down_fact, n_up, n_down, **_eval_kwargs)
print('==== Interpretation: PrePostCount (universal) ====')
stats_prepostcount_up, stats_prepostcount_down = evaluate_generalization(
    brain_prepostcount_up_fact, brain_prepostcount_down_fact, n_up, n_down, **_eval_kwargs)

print('==== Interpretation: PrePost (NOT universal) ====')
stats_prepost_nonuni_up, stats_prepost_nonuni_down = evaluate_generalization(
    brain_prepost_nonuni_up_fact, brain_prepost_nonuni_down_fact, n_up, n_down, **_eval_kwargs)
print('==== Interpretation: PrePostCount (NOT universal) ====')
stats_prepostcount_nonuni_up, stats_prepostcount_nonuni_down = evaluate_generalization(
    brain_prepostcount_nonuni_up_fact, brain_prepostcount_nonuni_down_fact, n_up, n_down, **_eval_kwargs)

  0%|          | 0/50 [00:00<?, ?it/s]

==== Original RNN (very different from all the rest) ====

Run 1 / 10...
Meta-learning on halfspace...


 24%|██▍       | 12/50 [00:41<02:11,  3.46s/it]

In [None]:
# Plot aggregated stats.
# For every model: aggregate the per-run stats with
# convert_multi_stats_uncertainty(), then plot the upstream and downstream
# curves and save them under figs/generalization_<model>_<up>_<down>.
# Fix: the non-universal models used to be titled "[...-Uni]", which made
# their figures indistinguishable from the universal ones; they are now
# titled "[...-NonUni]", matching their stats variables and file names.

# Original RNN.
agg_stats_rnn_up = convert_multi_stats_uncertainty(stats_rnn_up)
agg_stats_rnn_down = convert_multi_stats_uncertainty(stats_rnn_down)
plot_curves(agg_stats_rnn_up, agg_stats_rnn_down,
            '[RNN] Upstream meta-learning on ' + dataset_up,
            '[RNN] Downstream training on ' + dataset_down,
            'figs/generalization_rnn_' + dataset_up + '_' + dataset_down,
            no_downstream_loss=True)

# Universal rules (one rule shared by all hidden layers).
agg_stats_prepost_up = convert_multi_stats_uncertainty(stats_prepost_up)
agg_stats_prepost_down = convert_multi_stats_uncertainty(stats_prepost_down)
plot_curves(agg_stats_prepost_up, agg_stats_prepost_down,
            '[PrePost-Uni] Upstream meta-learning on ' + dataset_up,
            '[PrePost-Uni] Downstream training on ' + dataset_down,
            'figs/generalization_prepost_uni_' + dataset_up + '_' + dataset_down,
            no_downstream_loss=True)
agg_stats_prepostcount_up = convert_multi_stats_uncertainty(stats_prepostcount_up)
agg_stats_prepostcount_down = convert_multi_stats_uncertainty(stats_prepostcount_down)
plot_curves(agg_stats_prepostcount_up, agg_stats_prepostcount_down,
            '[PrePostCount-Uni] Upstream meta-learning on ' + dataset_up,
            '[PrePostCount-Uni] Downstream training on ' + dataset_down,
            'figs/generalization_prepostcount_uni_' + dataset_up + '_' + dataset_down,
            no_downstream_loss=True)

# Non-universal rules (a separate rule per entry of hl_rules).
agg_stats_prepost_nonuni_up = convert_multi_stats_uncertainty(stats_prepost_nonuni_up)
agg_stats_prepost_nonuni_down = convert_multi_stats_uncertainty(stats_prepost_nonuni_down)
plot_curves(agg_stats_prepost_nonuni_up, agg_stats_prepost_nonuni_down,
            '[PrePost-NonUni] Upstream meta-learning on ' + dataset_up,
            '[PrePost-NonUni] Downstream training on ' + dataset_down,
            'figs/generalization_prepost_nonuni_' + dataset_up + '_' + dataset_down,
            no_downstream_loss=True)
agg_stats_prepostcount_nonuni_up = convert_multi_stats_uncertainty(stats_prepostcount_nonuni_up)
agg_stats_prepostcount_nonuni_down = convert_multi_stats_uncertainty(stats_prepostcount_nonuni_down)
plot_curves(agg_stats_prepostcount_nonuni_up, agg_stats_prepostcount_nonuni_down,
            '[PrePostCount-NonUni] Upstream meta-learning on ' + dataset_up,
            '[PrePostCount-NonUni] Downstream training on ' + dataset_down,
            'figs/generalization_prepostcount_nonuni_' + dataset_up + '_' + dataset_down,
            no_downstream_loss=True)

## Train vanilla net

In [None]:
# Instantiate model.
# Baseline brain with the same sizes as the downstream FFLocalNet brains
# (n_down, m_down, l, w, p, cap). It is trained "WITH backprop" below;
# full_gd=True presumably enables gradient descent on all weights instead of
# plasticity rules -- confirm in FFBrainNet.
brain_vanilla = FFBrainNet(
    n_down, m_down, l, w, p, cap, full_gd=True)

In [None]:
# Evaluate model.
# Train the vanilla (backprop) baseline on the downstream dataset.
print('==== Vanilla ====')
# Use the shared downstream config rather than a hard-coded 'mnist' / 28 * 28:
# brain_vanilla is built with n_down inputs and the plot below is titled with
# dataset_down, so the data loaded here must follow the same settings.
X_train, y_train, X_test, y_test = quick_get_data(dataset_down, n_down)
print('Training VANILLA brain instance (WITH backprop) on ' + dataset_down + '...')
# vanilla=True selects the backprop training path (per the message above);
# test data is supplied so stats can be reported every `stats_interval`
# steps -- presumably counted in samples or batches, confirm in train_downstream.
stats_vanilla = train_downstream(
    X_train, y_train, brain_vanilla, num_epochs=num_epochs_downstream,
    batch_size=100, vanilla=True, learn_rate=5e-3,
    X_test=X_test, y_test=y_test, verbose=False, stats_interval=500)

In [None]:
# Plot aggregated stats.
# The vanilla baseline has no upstream phase, so the upstream stats and the
# upstream title are passed as None (argument order as in the other
# plot_curves calls of this notebook: up stats, down stats, up title,
# down title, save path).
# NOTE(review): stats_vanilla comes from a single train_downstream() run and
# is not passed through convert_multi_stats_uncertainty() like the other
# models -- confirm that plot_curves accepts it in this form.
plot_curves(None, stats_vanilla, None,
            '[Vanilla] Downstream training on ' + dataset_down,
            'figs/generalization_vanilla_' + dataset_down,
            no_downstream_loss=True)