In [1]:
# Library imports.
import numpy as np
import matplotlib.pyplot as plt
import sys
import torch
import torchvision.datasets
sys.path.append('../')
# The 'seaborn' style was renamed 'seaborn-v0_8' in Matplotlib 3.6 and the old
# name was removed in 3.8; prefer the new name when it exists so this cell
# still runs on current Matplotlib, and fall back for older versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

# Repository imports.
from FFBrainNet import FFBrainNet
from LocalNetBase import Options, UpdateScheme
from DataGenerator import random_halfspace_data
from train import *

## Methodology: for every model, meta-learn the plasticity rules on a 4-dimensional halfspace task, then transfer the learned rules to a second network and train / test it on an 8-dimensional halfspace task.

In [2]:
def _halfspace_split(dim, n_test):
    """Sample a random ``dim``-dimensional halfspace task and split it.

    Draws ``3 * n_test`` labelled points via ``random_halfspace_data``; the
    first ``n_test`` points form the test set and the remaining ``2 * n_test``
    form the training set.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    X, y = random_halfspace_data(dim=dim, n=3 * n_test)
    return X[n_test:], y[n_test:], X[:n_test], y[:n_test]


def _transfer_rules(brain_src, brain_dst):
    """Copy the learned plasticity rules from ``brain_src`` into ``brain_dst``.

    Models exposing ``get_hidden_layer_rule`` have their hidden-layer rule
    copied; models without one (presumably the RNN ``LocalNet``) have their
    RNN rule copied instead. The output rule is copied in both cases.

    Only a missing/unimplemented hidden-layer getter triggers the RNN fallback,
    so genuine errors raised while setting a rule are no longer masked by
    silently switching to the RNN path.
    """
    try:
        hidden_rule = brain_src.get_hidden_layer_rule()
    except (AttributeError, NotImplementedError):
        # No hidden-layer rule on this model type; fall back to the RNN rule.
        brain_dst.set_rnn_rule(brain_src.get_rnn_rule())
    else:
        brain_dst.set_hidden_layer_rule(hidden_rule)
    # The output rule is transferred identically for every model type.
    brain_dst.set_output_rule(brain_src.get_output_rule())


def evaluate_simple(brain_up, brain_down, n_up, n_down, n_samples=1000):
    """Meta-learn rules upstream, transfer them, then train downstream.

    ``brain_up`` meta-learns its plasticity rules on a random
    ``n_up``-dimensional halfspace task. The learned rules are copied into
    ``brain_down``, which is then trained on a fresh random
    ``n_down``-dimensional halfspace task.

    Args:
        brain_up: Network whose rules are meta-learned.
        brain_down: Network that receives the rules and is trained downstream.
        n_up: Input dimension of the upstream halfspace task.
        n_down: Input dimension of the downstream halfspace task.
        n_samples: Test-set size for each task; each task draws
            ``3 * n_samples`` points and trains on ``2 * n_samples`` of them.
            Defaults to 1000.

    Returns:
        ``(data_up, data_down)``: the results of ``metalearn_rules`` and
        ``train_downstream``; presumably ``(losses, train_acc, test_acc,
        sample_counts, stats)`` tuples, as ``plot_curves`` unpacks them.
    """
    # Upstream: meta-learn the plasticity rules on an n_up-dim task.
    X_train, y_train, X_test, y_test = _halfspace_split(n_up, n_samples)
    print('Meta-learning...')
    data_up = metalearn_rules(
        X_train, y_train, brain_up, num_rule_epochs=20, num_epochs=2, batch_size=100, learn_rate=1e-2,
        X_test=X_test, y_test=y_test, verbose=False)

    # Transfer the learned rules to the downstream network.
    _transfer_rules(brain_up, brain_down)

    # Downstream: train on a fresh n_down-dim task using the transferred rules.
    X_train, y_train, X_test, y_test = _halfspace_split(n_down, n_samples)
    print('Training...')
    data_down = train_downstream(
        X_train, y_train, brain_down, num_epochs=5, batch_size=100, vanilla=False, learn_rate=1e-2,
        X_test=X_test, y_test=y_test, verbose=False, stats_interval=500)

    return data_up, data_down

In [3]:
def _plot_phase(ax, data, title, skip_first_loss):
    """Draw one training phase's loss and train/test accuracy curves onto ``ax``.

    ``data`` is a ``(losses, train_acc, test_acc, sample_counts, stats)`` tuple;
    ``stats`` is not plotted.
    """
    losses, train_acc, test_acc, sample_counts, _stats = data
    # The downstream loss curve has always dropped its first recorded point.
    # NOTE(review): presumably that first entry is a placeholder or an
    # outlier that squashes the axis — confirm against train_downstream.
    loss_start = 1 if skip_first_loss else 0
    ax.plot(sample_counts[loss_start:], losses[loss_start:], label='loss')
    ax.plot(sample_counts, train_acc, label='train')
    ax.plot(sample_counts, test_acc, label='test')
    ax.set_xlabel('Cumulative number of training samples')
    # Loss and accuracy share this axis, so label it as both.
    ax.set_ylabel('Loss / accuracy')
    ax.set_title(title)
    ax.legend()


def plot_curves(data_up, data_down, n_up=4, n_down=8):
    """Plot loss and train/test accuracy for the upstream and downstream phases.

    Args:
        data_up: ``(losses, train_acc, test_acc, sample_counts, stats)`` from
            upstream meta-learning.
        data_down: Tuple with the same layout from downstream training.
        n_up: Upstream input dimension; only used in the left panel's title.
        n_down: Downstream input dimension; only used in the right panel's title.
    """
    _, (ax_up, ax_down) = plt.subplots(1, 2, figsize=(14, 5))
    _plot_phase(ax_up, data_up,
                f'Upstream meta-learning on {n_up}-dim half-space',
                skip_first_loss=False)
    _plot_phase(ax_down, data_down,
                f'Downstream training on {n_down}-dim half-space',
                skip_first_loss=True)
    plt.show()

In [4]:
def plot_compare_models(datas_up, datas_down, labels, n_up=4, n_down=8):
    """Overlay the test-accuracy curves of several models in two panels.

    The left panel shows upstream meta-learning and the right panel shows
    downstream training; each model contributes one curve per panel.

    Args:
        datas_up: Per-model upstream result tuples laid out as
            ``(losses, train_acc, test_acc, sample_counts, stats)``.
        datas_down: Per-model downstream result tuples with the same layout.
        labels: Legend label for each model.
        n_up: Upstream input dimension; only used in the left panel's title.
        n_down: Downstream input dimension; only used in the right panel's title.

    Raises:
        ValueError: If ``datas_up``, ``datas_down`` and ``labels`` differ in length.
    """
    # Positions of the plotted fields inside each result tuple.
    test_acc_idx, sample_counts_idx = 2, 3

    num_models = len(datas_up)
    if len(datas_down) != num_models or len(labels) != num_models:
        raise ValueError(
            f'Expected equal lengths, got {num_models} upstream results, '
            f'{len(datas_down)} downstream results and {len(labels)} labels.')

    _, (ax_up, ax_down) = plt.subplots(1, 2, figsize=(14, 5))
    for data_up, data_down, label in zip(datas_up, datas_down, labels):
        ax_up.plot(data_up[sample_counts_idx], data_up[test_acc_idx], label=label)
        ax_down.plot(data_down[sample_counts_idx], data_down[test_acc_idx], label=label)

    panels = (
        (ax_up, f'Upstream meta-learning on {n_up}-dim half-space'),
        (ax_down, f'Downstream training on {n_down}-dim half-space'),
    )
    for ax, title in panels:
        ax.set_xlabel('Cumulative number of training samples')
        ax.set_ylabel('Test accuracy')
        ax.set_title(title)
        ax.legend()
    plt.show()

## First, test the original RNN as a sanity check.

In [5]:
# IMPORTANT: Henceforth, we use GD directly on inputs but use plasticity rules in the output and hidden layers.
# Upstream (meta-learning) options: the graph and output rules are trainable by GD.
opts_up = Options(gd_input=True,
                  use_graph_rule=True,
                  gd_graph_rule=True,
                  use_output_rule=True,
                  gd_output_rule=True,
                  gd_output=False)
# Downstream options: same flags, except the rules are no longer trainable by GD.
opts_down = Options(gd_input=True,
                    use_graph_rule=True,
                    gd_graph_rule=False,  # Not meta-trainable anymore!
                    use_output_rule=True,
                    gd_output_rule=False,  # Not meta-trainable anymore!
                    gd_output=False)
scheme = UpdateScheme(cross_entropy_loss=True,
                      mse_loss=False,
                      update_misclassified_only=False,
                      update_all_edges=True)
n_up = 4  # Input layer size for meta-learning.
n_down = 8  # Input layer size for desired task training.
m = 2  # Output layer size.
l = 2  # Number of hidden layers.
w = 10  # Width of hidden layers.
p = 0.5  # Connectivity probability.
cap = 5  # Number of nodes firing per layer.

# Seed the global RNGs so data generation and model initialization are
# reproducible under Restart & Run All.
# NOTE(review): assumes random_halfspace_data and the network constructors
# draw from NumPy's / torch's global RNGs — confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [6]:
from network import *
# The upstream and downstream RNNs differ only in input size and options;
# they are built in that order (upstream first).
# NOTE(review): the positional arguments 100, 50, 3 are LocalNet-specific
# hyperparameters whose meaning is not visible here — confirm in network.py.
brain_rnn_up, brain_rnn_down = [
    LocalNet(n_in, m, 100, p, 50, 3, options=net_opts, update_scheme=scheme)
    for n_in, net_opts in ((n_up, opts_up), (n_down, opts_down))
]

In [None]:
print('==== Original RNN (very different from all the rest) ====')
# evaluate_simple returns an (upstream, downstream) pair; plot both phases.
data_rnn = evaluate_simple(brain_up=brain_rnn_up, brain_down=brain_rnn_down,
                           n_up=n_up, n_down=n_down)
plot_curves(data_rnn[0], data_rnn[1])

  0%|                  | 0/20 [00:00<?, ?it/s]

==== Original RNN (very different from all the rest) ====
Meta-learning...


 15%|█▌        | 3/20 [00:12<01:10,  4.13s/it]

## Then, test all of Brett's table-based feed-forward networks.

In [None]:
from FFLocalTableRules.FFLocalTable_PrePost import FFLocalTable_PrePost
from FFLocalTableRules.FFLocalTable_PrePostCount import FFLocalTable_PrePostCount
from FFLocalTableRules.FFLocalTable_PrePostPercent import FFLocalTable_PrePostPercent
from FFLocalTableRules.FFLocalTable_PostCount import FFLocalTable_PostCount

In [None]:
# Initialize models: for each table-based rule interpretation, build an
# upstream network (n_up inputs, opts_up) and a downstream network
# (n_down inputs, opts_down) sharing every other hyperparameter.
def _build_up_down_pair(model_cls):
    """Construct and return ``(upstream, downstream)`` instances, upstream first."""
    upstream = model_cls(n_up, m, l, w, p, cap, options=opts_up, update_scheme=scheme)
    downstream = model_cls(n_down, m, l, w, p, cap, options=opts_down, update_scheme=scheme)
    return upstream, downstream

brain_prepost_up, brain_prepost_down = _build_up_down_pair(FFLocalTable_PrePost)
brain_prepostcount_up, brain_prepostcount_down = _build_up_down_pair(FFLocalTable_PrePostCount)
brain_prepostpercent_up, brain_prepostpercent_down = _build_up_down_pair(FFLocalTable_PrePostPercent)
brain_postcount_up, brain_postcount_down = _build_up_down_pair(FFLocalTable_PostCount)

In [None]:
# Evaluate models: each run prints a banner, meta-learns upstream, transfers
# the rules downstream, trains, and plots both phases.
def _evaluate_and_plot(banner, brain_up, brain_down):
    """Print ``banner``, evaluate the (up, down) pair, plot it, and return the data."""
    print(banner)
    result = evaluate_simple(brain_up, brain_down, n_up, n_down)
    plot_curves(*result)
    return result

data_prepost = _evaluate_and_plot(
    '==== Interpretation: Pre and Post ====',
    brain_prepost_up, brain_prepost_down)
data_prepostcount = _evaluate_and_plot(
    '==== Interpretation: Pre and Post and Incoming Count ====',
    brain_prepostcount_up, brain_prepostcount_down)
data_prepostpercent = _evaluate_and_plot(
    '==== Interpretation: Pre and Post and Binned Incoming Fraction ====',
    brain_prepostpercent_up, brain_prepostpercent_down)
data_postcount = _evaluate_and_plot(
    '==== Interpretation: Post and Incoming Count ====',
    brain_postcount_up, brain_postcount_down)

In [None]:
# Compare every model's upstream and downstream test-accuracy curves.
datas = [data_rnn, data_prepost, data_prepostcount, data_prepostpercent, data_postcount]
labels = ['RNN', 'PrePost', 'PrePostCount', 'PrePostPercent', 'PostCount']
plot_compare_models(
    [up for up, _ in datas],
    [down for _, down in datas],
    labels,
)