# Dataset generalization

## For every considered model, meta-learn plasticity rules on a 16-dimensional halfspace dataset, then transfer the learned rules to a new network and train / test it on MNIST.

Created by Basile Van Hoorick, Fall 2020.

In [1]:
%run FF_common.ipynb

In [2]:
# IMPORTANT: From here on, the input layer is trained directly with GD, while
# the hidden and output layers are updated through plasticity rules.

# Options for the upstream (meta-learning) phase.
opts_up = Options(
    gd_input=True,
    use_graph_rule=True,
    gd_graph_rule=True,
    use_output_rule=True,
    gd_output_rule=True,
    gd_output=False,
)

# Options for the downstream phase: same as upstream, except that the graph
# and output rules are not meta-trainable anymore.
opts_down = Options(
    gd_input=True,
    use_graph_rule=True,
    gd_graph_rule=False,
    use_output_rule=True,
    gd_output_rule=False,
    gd_output=False,
)

# Loss / update configuration.
scheme = UpdateScheme(
    cross_entropy_loss=True,
    mse_loss=False,
    update_misclassified_only=False,
    update_all_edges=True,
)

# Feed-forward brain config.
n_up, m_up = 16, 2  # Input / output layer sizes for meta-learning.
n_down, m_down = 28 * 28, 10  # Input / output layer sizes for desired task training.
l = 4  # Number of hidden layers.
w = 64  # Width of hidden layers.
p = 0.5  # Connectivity probability.
cap = 32  # Number of nodes firing per layer.

# Training config.
num_runs = 10
num_rule_epochs = 50
num_epochs_upstream = num_epochs_downstream = 1
dataset_up, dataset_down = 'halfspace', 'mnist'

In [3]:
# Instantiate brain factories.
# Every factory is a zero-argument callable returning a freshly constructed
# network. "_up" factories use the meta-learning sizes and opts_up; "_down"
# factories use the downstream sizes and opts_down.
# NOTE(review): the non-universal factories pass three hidden-layer rules while
# l = 4; presumably one rule per hidden-to-hidden connection (l - 1) -- confirm
# against FFLocalNet before changing l.


def brain_rnn_up_fact():
    """Build the upstream RNN (LocalNet) brain."""
    return LocalNet(n_up, m_up, 64, p, 32, 2,
                    options=opts_up, update_scheme=scheme)


def brain_rnn_down_fact():
    """Build the downstream RNN (LocalNet) brain."""
    return LocalNet(n_down, m_down, 640, p, 320, 2,
                    options=opts_down, update_scheme=scheme)


def brain_prepost_up_fact():
    """Build the upstream feed-forward brain with a single PrePost rule object for the hidden layers."""
    return FFLocalNet(
        n_up, m_up, l, w, p, cap,
        hl_rules=TableRule_PrePost(),
        output_rule=TableRule_PrePost(),
        options=opts_up, update_scheme=scheme)


def brain_prepost_down_fact():
    """Build the downstream feed-forward brain with a single PrePost rule object for the hidden layers."""
    return FFLocalNet(
        n_down, m_down, l, w*10, p, cap*10,
        hl_rules=TableRule_PrePost(),
        output_rule=TableRule_PrePost(),
        options=opts_down, update_scheme=scheme)


def brain_prepostcount_up_fact():
    """Build the upstream feed-forward brain with a single PrePostCount rule object for the hidden layers."""
    return FFLocalNet(
        n_up, m_up, l, w, p, cap,
        hl_rules=TableRule_PrePostCount(),
        output_rule=TableRule_PrePostCount(),
        options=opts_up, update_scheme=scheme)


def brain_prepostcount_down_fact():
    """Build the downstream feed-forward brain with a single PrePostCount rule object for the hidden layers."""
    return FFLocalNet(
        n_down, m_down, l, w*10, p, cap*10,
        hl_rules=TableRule_PrePostCount(),
        output_rule=TableRule_PrePostCount(),
        options=opts_down, update_scheme=scheme)


def brain_prepost_nonuni_up_fact():
    """Build the upstream feed-forward brain with a list of three separate PrePost hidden-layer rules."""
    return FFLocalNet(
        n_up, m_up, l, w, p, cap,
        hl_rules=[TableRule_PrePost() for _ in range(3)],
        output_rule=TableRule_PrePost(),
        options=opts_up, update_scheme=scheme)


def brain_prepost_nonuni_down_fact():
    """Build the downstream feed-forward brain with a list of three separate PrePost hidden-layer rules."""
    return FFLocalNet(
        n_down, m_down, l, w*10, p, cap*10,
        hl_rules=[TableRule_PrePost() for _ in range(3)],
        output_rule=TableRule_PrePost(),
        options=opts_down, update_scheme=scheme)


def brain_prepostcount_nonuni_up_fact():
    """Build the upstream feed-forward brain with a list of three separate PrePostCount hidden-layer rules."""
    return FFLocalNet(
        n_up, m_up, l, w, p, cap,
        hl_rules=[TableRule_PrePostCount() for _ in range(3)],
        output_rule=TableRule_PrePostCount(),
        options=opts_up, update_scheme=scheme)


def brain_prepostcount_nonuni_down_fact():
    """Build the downstream feed-forward brain with a list of three separate PrePostCount hidden-layer rules."""
    return FFLocalNet(
        n_down, m_down, l, w*10, p, cap*10,
        hl_rules=[TableRule_PrePostCount() for _ in range(3)],
        output_rule=TableRule_PrePostCount(),
        options=opts_down, update_scheme=scheme)

In [4]:
# Evaluate models.
# The RNN factories are evaluated first; the call returns a pair that is
# unpacked into the upstream and downstream stats used by the plotting cells.
print('==== Original RNN (very different from all the rest) ====')
stats_rnn_up, stats_rnn_down = evaluate_generalization(
    brain_rnn_up_fact,
    brain_rnn_down_fact,
    n_up,
    n_down,
    dataset_up=dataset_up,
    dataset_down=dataset_down,
    num_runs=num_runs,
    num_rule_epochs=num_rule_epochs,
    num_epochs_upstream=num_epochs_upstream,
    num_epochs_downstream=num_epochs_downstream,
)

  0%|          | 0/50 [00:00<?, ?it/s]

==== Original RNN (very different from all the rest) ====

Run 1 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [03:00<00:00,  3.60s/it]
  self.rnn_rule = torch.tensor(rule).flatten().double()


Last loss: 0.3458
Last train accuracy: 0.9580
Last test accuracy: 0.9580
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0903


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0892
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [32:52<00:00, 30.41it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3612
Last train accuracy: 0.1124
Last test accuracy: 0.1135


Run 2 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [03:08<00:00,  3.77s/it]


Last loss: 0.3661
Last train accuracy: 0.9693
Last test accuracy: 0.9580
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0902


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0891
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [32:49<00:00, 30.46it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3172
Last train accuracy: 0.1124
Last test accuracy: 0.1135


Run 3 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [03:02<00:00,  3.64s/it]


Last loss: 0.3582
Last train accuracy: 0.9447
Last test accuracy: 0.9280
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0902
INITIAL test accuracy: 0.0891
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [32:43<00:00, 30.56it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3432
Last train accuracy: 0.1124
Last test accuracy: 0.1135


Run 4 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:50<00:00,  3.40s/it]


Last loss: 0.3440
Last train accuracy: 0.9613
Last test accuracy: 0.9400
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0902
INITIAL test accuracy: 0.0892
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [29:17<00:00, 34.14it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3332
Last train accuracy: 0.1124
Last test accuracy: 0.1135


Run 5 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:51<00:00,  3.43s/it]


Last loss: 0.3563
Last train accuracy: 0.9460
Last test accuracy: 0.9500
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0902


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0892
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [31:31<00:00, 31.72it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3352
Last train accuracy: 0.1124
Last test accuracy: 0.1135


Run 6 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [03:09<00:00,  3.79s/it]


Last loss: 0.3611
Last train accuracy: 0.9480
Last test accuracy: 0.9500
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0898
INITIAL test accuracy: 0.0892
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [33:45<00:00, 29.62it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3532
Last train accuracy: 0.1124
Last test accuracy: 0.1135


Run 7 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [03:06<00:00,  3.74s/it]


Last loss: 0.3543
Last train accuracy: 0.9633
Last test accuracy: 0.9420
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0903
INITIAL test accuracy: 0.0892
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [26:35<00:00, 37.60it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3592
Last train accuracy: 0.1124
Last test accuracy: 0.1135


Run 8 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:32<00:00,  3.05s/it]


Last loss: 0.3545
Last train accuracy: 0.9507
Last test accuracy: 0.9340
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0903


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0892
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [20:37<00:00, 48.49it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3552
Last train accuracy: 0.1124
Last test accuracy: 0.1135


Run 9 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:20<00:00,  2.81s/it]


Last loss: 0.3668
Last train accuracy: 0.9507
Last test accuracy: 0.9380
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0903


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0892
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [17:00<00:00, 58.79it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3452
Last train accuracy: 0.1124
Last test accuracy: 0.1135


Run 10 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [02:01<00:00,  2.44s/it]


Last loss: 0.3765
Last train accuracy: 0.9727
Last test accuracy: 0.9660
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0903


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0892
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [17:09<00:00, 58.30it/s]  


Last loss: 2.3652
Last train accuracy: 0.1124
Last test accuracy: 0.1135






In [None]:
# Evaluate the two "universal" table-rule models (one rule object passed for
# the hidden layers) with the same training configuration as the RNN.
print('==== Interpretation: PrePost (universal) ====')
stats_prepost_up, stats_prepost_down = evaluate_generalization(
    brain_prepost_up_fact,
    brain_prepost_down_fact,
    n_up,
    n_down,
    dataset_up=dataset_up,
    dataset_down=dataset_down,
    num_runs=num_runs,
    num_rule_epochs=num_rule_epochs,
    num_epochs_upstream=num_epochs_upstream,
    num_epochs_downstream=num_epochs_downstream,
)

print('==== Interpretation: PrePostCount (universal) ====')
stats_prepostcount_up, stats_prepostcount_down = evaluate_generalization(
    brain_prepostcount_up_fact,
    brain_prepostcount_down_fact,
    n_up,
    n_down,
    dataset_up=dataset_up,
    dataset_down=dataset_down,
    num_runs=num_runs,
    num_rule_epochs=num_rule_epochs,
    num_epochs_upstream=num_epochs_upstream,
    num_epochs_downstream=num_epochs_downstream,
)

  0%|          | 0/50 [00:00<?, ?it/s]

==== Interpretation: PrePost (universal) ====

Run 1 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [03:24<00:00,  4.09s/it]


Last loss: 0.6931
Last train accuracy: 0.4840
Last test accuracy: 0.5060
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0987


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0980
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [24:29<00:00, 40.82it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3026
Last train accuracy: 0.0987
Last test accuracy: 0.0980


Run 2 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [03:21<00:00,  4.03s/it]


Last loss: 0.6931
Last train accuracy: 0.4960
Last test accuracy: 0.4660
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0987


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0980
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [24:57<00:00, 40.06it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3026
Last train accuracy: 0.0987
Last test accuracy: 0.0980


Run 3 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [03:14<00:00,  3.89s/it]


Last loss: 0.3848
Last train accuracy: 0.9313
Last test accuracy: 0.9380
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0987


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0980
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [26:01<00:00, 38.42it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.3732
Last train accuracy: 0.1124
Last test accuracy: 0.1135


Run 4 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [03:24<00:00,  4.08s/it]


Last loss: 0.6931
Last train accuracy: 0.5167
Last test accuracy: 0.4580
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0987


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0980
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [25:56<00:00, 38.56it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.2142
Last train accuracy: 0.1927
Last test accuracy: 0.1965


Run 5 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [03:34<00:00,  4.30s/it]


Last loss: 0.4257
Last train accuracy: 0.8647
Last test accuracy: 0.8280
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0987


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0980
Epoch 1 / 1 ...


100%|██████████| 60000/60000 [25:55<00:00, 38.57it/s]  
  0%|          | 0/50 [00:00<?, ?it/s]


Last loss: 2.2172
Last train accuracy: 0.2613
Last test accuracy: 0.2665


Run 6 / 10...
Meta-learning on halfspace...


100%|██████████| 50/50 [03:34<00:00,  4.28s/it]


Last loss: 0.6931
Last train accuracy: 0.4967
Last test accuracy: 0.5200
mnist_train: 60000
mnist_test: 10000
Training NEW brain instance (WITH backprop) on mnist...
INITIAL train accuracy: 0.0987


  0%|          | 0/60000 [00:00<?, ?it/s]

INITIAL test accuracy: 0.0980
Epoch 1 / 1 ...


 86%|████████▌ | 51497/60000 [20:42<03:03, 46.46it/s]  

In [None]:
# Disabled: evaluation of the non-universal variants (a list of hidden-layer
# rules instead of a single rule object). Flip the condition to run it.
if False:
    print('==== Interpretation: PrePost (NOT universal) ====')
    stats_prepost_nonuni_up, stats_prepost_nonuni_down = evaluate_generalization(
        brain_prepost_nonuni_up_fact,
        brain_prepost_nonuni_down_fact,
        n_up,
        n_down,
        dataset_up=dataset_up,
        dataset_down=dataset_down,
        num_runs=num_runs,
        num_rule_epochs=num_rule_epochs,
        num_epochs_upstream=num_epochs_upstream,
        num_epochs_downstream=num_epochs_downstream,
    )

    print('==== Interpretation: PrePostCount (NOT universal) ====')
    stats_prepostcount_nonuni_up, stats_prepostcount_nonuni_down = evaluate_generalization(
        brain_prepostcount_nonuni_up_fact,
        brain_prepostcount_nonuni_down_fact,
        n_up,
        n_down,
        dataset_up=dataset_up,
        dataset_down=dataset_down,
        num_runs=num_runs,
        num_rule_epochs=num_rule_epochs,
        num_epochs_upstream=num_epochs_upstream,
        num_epochs_downstream=num_epochs_downstream,
    )

In [None]:
# Plot aggregated stats.
# For each model, the per-run upstream / downstream stats are passed through
# convert_multi_stats_uncertainty (presumably aggregating across runs -- see
# FF_common.ipynb) and plotted with plot_curves, which receives the upstream
# title, the downstream title and the output path for the figure.
agg_stats_rnn_up = convert_multi_stats_uncertainty(stats_rnn_up)
agg_stats_rnn_down = convert_multi_stats_uncertainty(stats_rnn_down)
plot_curves(agg_stats_rnn_up, agg_stats_rnn_down,
            '[RNN] Upstream meta-learning on ' + dataset_up,
            '[RNN] Downstream training on ' + dataset_down,
            'figs/generalization_rnn_' + dataset_up + '_' + dataset_down,
            no_downstream_loss=True)

agg_stats_prepost_up = convert_multi_stats_uncertainty(stats_prepost_up)
agg_stats_prepost_down = convert_multi_stats_uncertainty(stats_prepost_down)
plot_curves(agg_stats_prepost_up, agg_stats_prepost_down,
            '[PrePost-Uni] Upstream meta-learning on ' + dataset_up,
            '[PrePost-Uni] Downstream training on ' + dataset_down,
            'figs/generalization_prepost_uni_' + dataset_up + '_' + dataset_down,
            no_downstream_loss=True)
agg_stats_prepostcount_up = convert_multi_stats_uncertainty(stats_prepostcount_up)
agg_stats_prepostcount_down = convert_multi_stats_uncertainty(stats_prepostcount_down)
plot_curves(agg_stats_prepostcount_up, agg_stats_prepostcount_down,
            '[PrePostCount-Uni] Upstream meta-learning on ' + dataset_up,
            '[PrePostCount-Uni] Downstream training on ' + dataset_down,
            'figs/generalization_prepostcount_uni_' + dataset_up + '_' + dataset_down,
            no_downstream_loss=True)

# Disabled together with the non-universal evaluation cell. The titles are
# labelled "NonUni" so these figures cannot be mistaken for the universal ones
# (the file names already say "nonuni").
if 0:
    agg_stats_prepost_nonuni_up = convert_multi_stats_uncertainty(stats_prepost_nonuni_up)
    agg_stats_prepost_nonuni_down = convert_multi_stats_uncertainty(stats_prepost_nonuni_down)
    plot_curves(agg_stats_prepost_nonuni_up, agg_stats_prepost_nonuni_down,
                '[PrePost-NonUni] Upstream meta-learning on ' + dataset_up,
                '[PrePost-NonUni] Downstream training on ' + dataset_down,
                'figs/generalization_prepost_nonuni_' + dataset_up + '_' + dataset_down,
                no_downstream_loss=True)
    agg_stats_prepostcount_nonuni_up = convert_multi_stats_uncertainty(stats_prepostcount_nonuni_up)
    agg_stats_prepostcount_nonuni_down = convert_multi_stats_uncertainty(stats_prepostcount_nonuni_down)
    plot_curves(agg_stats_prepostcount_nonuni_up, agg_stats_prepostcount_nonuni_down,
                '[PrePostCount-NonUni] Upstream meta-learning on ' + dataset_up,
                '[PrePostCount-NonUni] Downstream training on ' + dataset_down,
                'figs/generalization_prepostcount_nonuni_' + dataset_up + '_' + dataset_down,
                no_downstream_loss=True)

## Train vanilla net

In [None]:
# Instantiate model.
# Baseline network built with the downstream sizes (10x the upstream width and
# cap, matching the "_down" factories) and full_gd=True.
brain_vanilla = FFBrainNet(n_down, m_down, l, w*10, p, cap*10,
                           full_gd=True)

In [None]:
# Evaluate model.
# The dataset name and input size come from the shared config (dataset_down,
# n_down) rather than hard-coded literals, so this baseline always trains on
# the same downstream task that brain_vanilla was sized for and that the plot
# titles report.
print('==== Vanilla ====')
X_train, y_train, X_test, y_test = quick_get_data(dataset_down, n_down)
print('Training VANILLA brain instance (WITH backprop) on ' + dataset_down + '...')
# NOTE(review): batch_size, learn_rate and stats_interval are fixed here and
# are not part of the config cell.
stats_vanilla = train_downstream(
    X_train, y_train, brain_vanilla, num_epochs=num_epochs_downstream,
    batch_size=100, vanilla=True, learn_rate=5e-3,
    X_test=X_test, y_test=y_test, verbose=False, stats_interval=500)

In [None]:
# Plot the vanilla baseline.
# There is no upstream phase for this model, so the upstream stats and the
# upstream title are both passed as None.
plot_curves(
    None,
    stats_vanilla,
    None,
    '[Vanilla] Downstream training on ' + dataset_down,
    'figs/generalization_vanilla_' + dataset_down,
    no_downstream_loss=True,
)

In [None]:
# Plot to compare some.
# The three lists are index-aligned: entry i of each belongs to labels[i].
# Vanilla has no upstream stats, hence the trailing None.
all_stats_up = [
    agg_stats_rnn_up,
    agg_stats_prepost_up,
    agg_stats_prepostcount_up,
    None,
]
all_stats_down = [
    agg_stats_rnn_down,
    agg_stats_prepost_down,
    agg_stats_prepostcount_down,
    stats_vanilla,
]
labels = ['RNN', 'PrePost', 'PrePostCount', 'Vanilla']
plot_compare_models(
    all_stats_up,
    all_stats_down,
    labels,
    'Upstream meta-learning on ' + dataset_up,
    'Downstream training on ' + dataset_down,
    'figs/generalization_all_' + dataset_up + '_' + dataset_down,
)

In [None]:
# Plot to compare all.
# Disabled: needs the non-universal stats, whose evaluation is disabled too.
# NOTE(review): this passes the same 'figs/generalization_all_...' path as the
# "compare some" cell -- confirm that reusing it is intended.
if False:
    all_stats_up = [
        agg_stats_rnn_up,
        agg_stats_prepost_up,
        agg_stats_prepostcount_up,
        agg_stats_prepost_nonuni_up,
        agg_stats_prepostcount_nonuni_up,
        None,
    ]
    all_stats_down = [
        agg_stats_rnn_down,
        agg_stats_prepost_down,
        agg_stats_prepostcount_down,
        agg_stats_prepost_nonuni_down,
        agg_stats_prepostcount_nonuni_down,
        stats_vanilla,
    ]
    labels = ['RNN', 'PrePostUni', 'PrePostCountUni', 'PrePostNonUni', 'PrePostCountNonUni', 'Vanilla']
    plot_compare_models(
        all_stats_up,
        all_stats_down,
        labels,
        'Upstream meta-learning on ' + dataset_up,
        'Downstream training on ' + dataset_down,
        'figs/generalization_all_' + dataset_up + '_' + dataset_down,
    )