# Dataset generalization

## For every considered model, meta-learn plasticity rules on a 16-dimensional halfspace dataset, then transfer the learned rules to a new network and train / test it on MNIST.

Created by Basile Van Hoorick, Fall 2020.

In [1]:
%run FF_common.ipynb

In [2]:
# IMPORTANT: From here on, input weights are trained directly with GD, while the
# hidden (graph) and output layers are updated through plasticity rules.
_common_opts = dict(gd_input=True,
                    use_graph_rule=True,
                    use_output_rule=True,
                    gd_output=False)
# Upstream (meta-learning): both plasticity rules are meta-trainable.
opts_up = Options(gd_graph_rule=True, gd_output_rule=True, **_common_opts)
# Downstream (transfer): the rules are no longer meta-trainable.
opts_down = Options(gd_graph_rule=False, gd_output_rule=False, **_common_opts)

# True  -> update all edges, for every sample.
# False -> update only on misclassified samples and not all edges (same as paper).
update_every_edge = True
scheme = UpdateScheme(cross_entropy_loss=True,
                      mse_loss=False,
                      update_misclassified_only=not update_every_edge,
                      update_all_edges=update_every_edge)

# Feed-forward brain config.
n_up, n_down = 16, 28 * 28  # Input layer size: meta-learning / desired task training.
m_up, m_down = 2, 10  # Output layer size: meta-learning / desired task training.
w = 100  # Width of hidden layers.
p = 0.5  # Connectivity probability.
cap = 50  # Number of nodes firing per layer.

# Training config.
num_runs = 5
num_rule_epochs = 50
num_epochs_upstream = num_epochs_downstream = 1
downstream_backprop = False
dataset_up, dataset_down = 'halfspace', 'mnist'

In [3]:
# Instantiate brain factories.
# Every factory is a zero-argument callable that builds a fresh network. The
# helpers below replace fourteen near-identical lambdas; the public factory names
# are unchanged so the evaluation cells keep working.


def _stage_config(upstream):
    """Return (input size, output size, options) for the requested stage.

    The notebook globals are looked up at call time (not when the factory is
    created), so re-running the config cell is still picked up by existing
    factories, exactly as with the former lambdas.

    Args:
        upstream: True for the meta-learning stage, False for the transfer stage.
    """
    if upstream:
        return n_up, m_up, opts_up
    return n_down, m_down, opts_down


def _rnn_factory(upstream):
    """Return a zero-argument factory for the LocalNet (RNN) model of one stage."""
    def build():
        n_in, n_out, opts = _stage_config(upstream)
        # The positional 2 is carried over unchanged from the original call.
        return LocalNet(n_in, n_out, w, p, cap, 2, options=opts, update_scheme=scheme)
    return build


def _ff_factory(upstream, num_layers, rule_cls):
    """Return a zero-argument factory for an FFLocalNet of one stage.

    Args:
        upstream: True for the meta-learning stage, False for the transfer stage.
        num_layers: Number of hidden layers passed to FFLocalNet.
        rule_cls: Rule class; instantiated separately for the hidden layers and
            for the output layer on every build.
    """
    def build():
        n_in, n_out, opts = _stage_config(upstream)
        return FFLocalNet(
            n_in, n_out, num_layers, w, p, cap, hl_rules=rule_cls(),
            output_rule=rule_cls(), options=opts, update_scheme=scheme)
    return build


brain_rnn_up_fact = _rnn_factory(upstream=True)
brain_rnn_down_fact = _rnn_factory(upstream=False)

brain_prepost_l2_up_fact = _ff_factory(True, 2, TableRule_PrePost)
brain_prepost_l2_down_fact = _ff_factory(False, 2, TableRule_PrePost)
brain_prepost_l3_up_fact = _ff_factory(True, 3, TableRule_PrePost)
brain_prepost_l3_down_fact = _ff_factory(False, 3, TableRule_PrePost)
brain_prepost_l4_up_fact = _ff_factory(True, 4, TableRule_PrePost)
brain_prepost_l4_down_fact = _ff_factory(False, 4, TableRule_PrePost)

brain_prepostcount_l2_up_fact = _ff_factory(True, 2, TableRule_PrePostCount)
brain_prepostcount_l2_down_fact = _ff_factory(False, 2, TableRule_PrePostCount)
brain_prepostcount_l3_up_fact = _ff_factory(True, 3, TableRule_PrePostCount)
brain_prepostcount_l3_down_fact = _ff_factory(False, 3, TableRule_PrePostCount)
brain_prepostcount_l4_up_fact = _ff_factory(True, 4, TableRule_PrePostCount)
brain_prepostcount_l4_down_fact = _ff_factory(False, 4, TableRule_PrePostCount)

In [4]:
# Evaluate models.
print('==== Original RNN (very different from all the rest) ====')
# Keyword settings taken from the config cell, gathered once for readability.
rnn_eval_settings = dict(
    dataset_up=dataset_up,
    dataset_down=dataset_down,
    downstream_backprop=downstream_backprop,
    num_runs=num_runs,
    num_rule_epochs=num_rule_epochs,
    num_epochs_upstream=num_epochs_upstream,
    num_epochs_downstream=num_epochs_downstream,
)
stats_rnn_up, stats_rnn_down = evaluate_up_down(
    brain_rnn_up_fact, brain_rnn_down_fact, n_up, n_down, **rnn_eval_settings)

  0%|          | 0/50 [00:00<?, ?it/s]

==== Original RNN (very different from all the rest) ====

Run 1 / 5...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:22<00:00,  2.86s/it]
  self.rnn_rule = torch.tensor(rule).flatten().double()


Last loss: 0.3737
Last train accuracy: 0.9500
Last test accuracy: 0.9220
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance on mnist...
===> This is NOT recommended by Basile!
INITIAL train accuracy: 0.1113
INITIAL test accuracy: 0.1136
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [10:40<00:00, 93.68it/s] 
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3552
Last train accuracy: 0.1096
Last test accuracy: 0.1102


Run 2 / 5...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:37<00:00,  3.15s/it]


Last loss: 0.4966
Last train accuracy: 0.8687
Last test accuracy: 0.8480
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance on mnist...
===> This is NOT recommended by Basile!
INITIAL train accuracy: 0.0903


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0891
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [10:34<00:00, 94.59it/s] 
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3052
Last train accuracy: 0.1800
Last test accuracy: 0.1874


Run 3 / 5...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:50<00:00,  3.41s/it]


Last loss: 0.3617
Last train accuracy: 0.9453
Last test accuracy: 0.9320
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance on mnist...
===> This is NOT recommended by Basile!
INITIAL train accuracy: 0.0984
INITIAL test accuracy: 0.0945
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [10:26<00:00, 95.81it/s] 
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.2828
Last train accuracy: 0.2294
Last test accuracy: 0.2321


Run 4 / 5...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:48<00:00,  3.37s/it]


Last loss: 0.3694
Last train accuracy: 0.9487
Last test accuracy: 0.9240
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance on mnist...
===> This is NOT recommended by Basile!
INITIAL train accuracy: 0.0948


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0919
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [07:56<00:00, 125.92it/s]
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.2036
Last train accuracy: 0.2223
Last test accuracy: 0.2290


Run 5 / 5...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:23<00:00,  2.87s/it]
  0%|          | 0/50 [00:00<?, ?it/s]

Last loss: 0.6628
Last train accuracy: 0.5533
Last test accuracy: 0.5540
Final upstream test acc 0.5540 not high enough, retrying...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:16<00:00,  2.72s/it]
  0%|          | 0/50 [00:00<?, ?it/s]

Last loss: 0.5692
Last train accuracy: 0.7133
Last test accuracy: 0.6980
Final upstream test acc 0.6980 not high enough, retrying...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:32<00:00,  3.06s/it]


Last loss: 0.3678
Last train accuracy: 0.9480
Last test accuracy: 0.9300
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance on mnist...
===> This is NOT recommended by Basile!
INITIAL train accuracy: 0.1082
INITIAL test accuracy: 0.1050
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [08:21<00:00, 119.71it/s]


Last loss: 2.2651
Last train accuracy: 0.1814
Last test accuracy: 0.1785






In [None]:
def _evaluate_prepost(up_fact, down_fact):
    """Run evaluate_up_down for one PrePost factory pair with the notebook-wide settings."""
    return evaluate_up_down(
        up_fact, down_fact, n_up, n_down,
        dataset_up=dataset_up,
        dataset_down=dataset_down,
        downstream_backprop=downstream_backprop,
        num_runs=num_runs,
        num_rule_epochs=num_rule_epochs,
        num_epochs_upstream=num_epochs_upstream,
        num_epochs_downstream=num_epochs_downstream)


print('==== Interpretation: PrePost, 2 hidden layers (universal) ====')
stats_prepost_l2_up, stats_prepost_l2_down = _evaluate_prepost(
    brain_prepost_l2_up_fact, brain_prepost_l2_down_fact)

print('==== Interpretation: PrePost, 3 hidden layers (universal) ====')
stats_prepost_l3_up, stats_prepost_l3_down = _evaluate_prepost(
    brain_prepost_l3_up_fact, brain_prepost_l3_down_fact)

print('==== Interpretation: PrePost, 4 hidden layers (universal) ====')
stats_prepost_l4_up, stats_prepost_l4_down = _evaluate_prepost(
    brain_prepost_l4_up_fact, brain_prepost_l4_down_fact)

  0%|          | 0/50 [00:00<?, ?it/s]

==== Interpretation: PrePost, 2 hidden layers (universal) ====

Run 1 / 5...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:02<00:00,  2.44s/it]


Last loss: 0.4573
Last train accuracy: 0.8407
Last test accuracy: 0.8320
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance on mnist...
===> This is NOT recommended by Basile!
INITIAL train accuracy: 0.0987
INITIAL test accuracy: 0.0980
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [04:23<00:00, 227.79it/s]
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 1.8632
Last train accuracy: 0.6143
Last test accuracy: 0.6232


Run 2 / 5...
Meta-learning on halfspace...


100%|██████████| 50/50 [01:56<00:00,  2.32s/it]


Last loss: 0.4400
Last train accuracy: 0.8467
Last test accuracy: 0.8420
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance on mnist...
===> This is NOT recommended by Basile!
INITIAL train accuracy: 0.0987
INITIAL test accuracy: 0.0980
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [03:49<00:00, 261.20it/s]
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 1.9512
Last train accuracy: 0.5207
Last test accuracy: 0.5276


Run 3 / 5...
Meta-learning on halfspace...


100%|██████████| 50/50 [01:56<00:00,  2.32s/it]


Last loss: 0.4595
Last train accuracy: 0.8240
Last test accuracy: 0.8240
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance on mnist...
===> This is NOT recommended by Basile!
INITIAL train accuracy: 0.0987
INITIAL test accuracy: 0.0980
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [05:25<00:00, 184.51it/s]
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 1.8332
Last train accuracy: 0.6139
Last test accuracy: 0.6251


Run 4 / 5...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:11<00:00,  2.62s/it]


Last loss: 0.4623
Last train accuracy: 0.8313
Last test accuracy: 0.8240
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance on mnist...
===> This is NOT recommended by Basile!
INITIAL train accuracy: 0.0987
INITIAL test accuracy: 0.0980
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [03:34<00:00, 279.60it/s]
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 1.8532
Last train accuracy: 0.5917
Last test accuracy: 0.5962


Run 5 / 5...
Meta-learning on halfspace...


100%|██████████| 50/50 [01:56<00:00,  2.33s/it]


Last loss: 0.4566
Last train accuracy: 0.8267
Last test accuracy: 0.7960
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance on mnist...
===> This is NOT recommended by Basile!
INITIAL train accuracy: 0.0987
INITIAL test accuracy: 0.0980
Epoch 1 / 1 ...


 14%|█▎        | 8183/60000 [00:29<02:51, 302.70it/s]

In [None]:
def _evaluate_prepostcount(up_fact, down_fact):
    """Run evaluate_up_down for one PrePostCount factory pair with the notebook-wide settings."""
    return evaluate_up_down(
        up_fact, down_fact, n_up, n_down,
        dataset_up=dataset_up,
        dataset_down=dataset_down,
        downstream_backprop=downstream_backprop,
        num_runs=num_runs,
        num_rule_epochs=num_rule_epochs,
        num_epochs_upstream=num_epochs_upstream,
        num_epochs_downstream=num_epochs_downstream)


print('==== Interpretation: PrePostCount, 2 hidden layers (universal) ====')
stats_prepostcount_l2_up, stats_prepostcount_l2_down = _evaluate_prepostcount(
    brain_prepostcount_l2_up_fact, brain_prepostcount_l2_down_fact)

print('==== Interpretation: PrePostCount, 3 hidden layers (universal) ====')
stats_prepostcount_l3_up, stats_prepostcount_l3_down = _evaluate_prepostcount(
    brain_prepostcount_l3_up_fact, brain_prepostcount_l3_down_fact)

print('==== Interpretation: PrePostCount, 4 hidden layers (universal) ====')
stats_prepostcount_l4_up, stats_prepostcount_l4_down = _evaluate_prepostcount(
    brain_prepostcount_l4_up_fact, brain_prepostcount_l4_down_fact)

In [None]:
# Aggregate the per-run RNN stats, then plot upstream and downstream curves.
agg_stats_rnn_up = convert_multi_stats_uncertainty(stats_rnn_up)
agg_stats_rnn_down = convert_multi_stats_uncertainty(stats_rnn_down)

rnn_title_up = '[RNN] Upstream meta-learning on ' + dataset_up
rnn_title_down = '[RNN] Downstream training on ' + dataset_down
rnn_fig_path = 'figs/generalization_rnn_' + dataset_up + '_' + dataset_down
plot_curves(agg_stats_rnn_up, agg_stats_rnn_down,
            rnn_title_up, rnn_title_down, rnn_fig_path,
            no_downstream_loss=True)

In [None]:
def _plot_prepost(agg_up, agg_down, title_up_prefix, title_down_prefix, fig_prefix):
    """Plot one PrePost model; dataset names are appended to the titles and figure path."""
    plot_curves(agg_up, agg_down,
                title_up_prefix + dataset_up,
                title_down_prefix + dataset_down,
                fig_prefix + dataset_up + '_' + dataset_down,
                no_downstream_loss=True)


# Aggregate and plot each depth in turn (2, 3 and 4 hidden layers).
agg_stats_prepost_l2_up = convert_multi_stats_uncertainty(stats_prepost_l2_up)
agg_stats_prepost_l2_down = convert_multi_stats_uncertainty(stats_prepost_l2_down)
_plot_prepost(agg_stats_prepost_l2_up, agg_stats_prepost_l2_down,
              '[PrePost-L2-Uni] Upstream meta-learning on ',
              '[PrePost-L2-Uni] Downstream training on ',
              'figs/generalization_prepost_l2_uni_')

agg_stats_prepost_l3_up = convert_multi_stats_uncertainty(stats_prepost_l3_up)
agg_stats_prepost_l3_down = convert_multi_stats_uncertainty(stats_prepost_l3_down)
_plot_prepost(agg_stats_prepost_l3_up, agg_stats_prepost_l3_down,
              '[PrePost-L3-Uni] Upstream meta-learning on ',
              '[PrePost-L3-Uni] Downstream training on ',
              'figs/generalization_prepost_l3_uni_')

agg_stats_prepost_l4_up = convert_multi_stats_uncertainty(stats_prepost_l4_up)
agg_stats_prepost_l4_down = convert_multi_stats_uncertainty(stats_prepost_l4_down)
_plot_prepost(agg_stats_prepost_l4_up, agg_stats_prepost_l4_down,
              '[PrePost-L4-Uni] Upstream meta-learning on ',
              '[PrePost-L4-Uni] Downstream training on ',
              'figs/generalization_prepost_l4_uni_')

In [None]:
def _plot_prepostcount(agg_up, agg_down, title_up_prefix, title_down_prefix, fig_prefix):
    """Plot one PrePostCount model; dataset names are appended to the titles and figure path."""
    plot_curves(agg_up, agg_down,
                title_up_prefix + dataset_up,
                title_down_prefix + dataset_down,
                fig_prefix + dataset_up + '_' + dataset_down,
                no_downstream_loss=True)


# Aggregate and plot each depth in turn (2, 3 and 4 hidden layers).
agg_stats_prepostcount_l2_up = convert_multi_stats_uncertainty(stats_prepostcount_l2_up)
agg_stats_prepostcount_l2_down = convert_multi_stats_uncertainty(stats_prepostcount_l2_down)
_plot_prepostcount(agg_stats_prepostcount_l2_up, agg_stats_prepostcount_l2_down,
                   '[PrePostCount-L2-Uni] Upstream meta-learning on ',
                   '[PrePostCount-L2-Uni] Downstream training on ',
                   'figs/generalization_prepostcount_l2_uni_')

agg_stats_prepostcount_l3_up = convert_multi_stats_uncertainty(stats_prepostcount_l3_up)
agg_stats_prepostcount_l3_down = convert_multi_stats_uncertainty(stats_prepostcount_l3_down)
_plot_prepostcount(agg_stats_prepostcount_l3_up, agg_stats_prepostcount_l3_down,
                   '[PrePostCount-L3-Uni] Upstream meta-learning on ',
                   '[PrePostCount-L3-Uni] Downstream training on ',
                   'figs/generalization_prepostcount_l3_uni_')

agg_stats_prepostcount_l4_up = convert_multi_stats_uncertainty(stats_prepostcount_l4_up)
agg_stats_prepostcount_l4_down = convert_multi_stats_uncertainty(stats_prepostcount_l4_down)
_plot_prepostcount(agg_stats_prepostcount_l4_up, agg_stats_prepostcount_l4_down,
                   '[PrePostCount-L4-Uni] Upstream meta-learning on ',
                   '[PrePostCount-L4-Uni] Downstream training on ',
                   'figs/generalization_prepostcount_l4_uni_')

## Train vanilla net

In [None]:
# Instantiate model.
# The original call passed `l`, which is never assigned in this notebook, so on a
# fresh kernel the cell raises NameError unless FF_common.ipynb happens to define it.
# The number of hidden layers is therefore made explicit here.
# NOTE(review): 2 matches the shallowest FFLocalNet compared above -- confirm this is
# the intended depth for the vanilla baseline.
num_layers_vanilla = 2
brain_vanilla = FFBrainNet(
    n_down, m_down, num_layers_vanilla, w, p, cap, full_gd=True)

In [None]:
# Evaluate model.
print('==== Vanilla ====')
# Use the configured downstream dataset and input size rather than hard-coded
# 'mnist' / 28 * 28, so this baseline stays consistent with brain_vanilla (built with
# n_down inputs) and with the plot titles (which use dataset_down) if the config changes.
X_train, y_train, X_test, y_test = quick_get_data(dataset_down, n_down)
print('Training VANILLA brain instance (WITH backprop) on ' + dataset_down + '...')
stats_vanilla = train_downstream(
    X_train, y_train, brain_vanilla, num_epochs=num_epochs_downstream,
    batch_size=256, vanilla=True, learn_rate=5e-3,
    X_test=X_test, y_test=y_test, verbose=False, stats_interval=500)

In [None]:
# Plot the vanilla training stats. The upstream stats and upstream title slots
# are passed as None, so only the downstream arguments are supplied.
vanilla_title_down = '[Vanilla] Downstream training on ' + dataset_down
vanilla_fig_path = 'figs/generalization_vanilla_' + dataset_down
plot_curves(None, stats_vanilla,
            None, vanilla_title_down, vanilla_fig_path,
            no_downstream_loss=True)

## Final plot

In [None]:
# Plot to compare all.
# One row per model: (label, upstream stats, downstream stats). The vanilla
# net has no upstream entry, hence None.
comparison_rows = [
    ('RNN', agg_stats_rnn_up, agg_stats_rnn_down),
    ('PrePost-L2', agg_stats_prepost_l2_up, agg_stats_prepost_l2_down),
    ('PrePost-L3', agg_stats_prepost_l3_up, agg_stats_prepost_l3_down),
    ('PrePost-L4', agg_stats_prepost_l4_up, agg_stats_prepost_l4_down),
    ('PrePostCount-L2', agg_stats_prepostcount_l2_up, agg_stats_prepostcount_l2_down),
    ('PrePostCount-L3', agg_stats_prepostcount_l3_up, agg_stats_prepostcount_l3_down),
    ('PrePostCount-L4', agg_stats_prepostcount_l4_up, agg_stats_prepostcount_l4_down),
    ('Vanilla', None, stats_vanilla),
]
all_stats_up = [row[1] for row in comparison_rows]
all_stats_down = [row[2] for row in comparison_rows]
labels = [row[0] for row in comparison_rows]

compare_title_up = 'Upstream meta-learning on ' + dataset_up
compare_title_down = 'Downstream training on ' + dataset_down
compare_fig_path = 'figs/generalization_all_' + dataset_up + '_' + dataset_down
plot_compare_models(all_stats_up, all_stats_down, labels,
                    compare_title_up, compare_title_down, compare_fig_path)