In [1]:
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt

import argparse
import time
import tensorflow as tf

from tuner import HyperparameterTuner

# Switch between TPU execution (True) and a plain tf.Session (False).
# The same flag also selects the storage root and the session set-up in later cells.
use_tpu = False

if use_tpu:
    # TPU-only helpers; imported conditionally so the non-TPU path never needs them.
    from tensorflow.contrib import tpu
    from tensorflow.contrib.cluster_resolver import TPUClusterResolver
    
%load_ext autoreload
%autoreload 2

In [5]:
# Network size and hyperparameter-search budget handed to HyperparameterTuner.
hidden_layers, hidden_units = 2, 500
trials, epochs = 10, 10

# Root for every artifact: a gs:// bucket when running on TPU, the working directory otherwise.
task_home = 'gs://continual_learning/permMNIST_EWC/' if use_tpu else './'

# Derived locations (plain concatenation so the gs:// prefix is kept verbatim).
checkpoint_path = task_home + 'unequal_split_logs/checkpoints/'
summaries_path = task_home + 'unequal_split_logs/summaries/'
data_path = task_home + 'MNIST_data/'
split_path = './split.txt'

# The TPU name is only defined (and only needed) on the TPU path.
if use_tpu:
    tpu_name = 'gectpu'

In [3]:
# Open the TensorFlow session shared by the rest of the notebook.
if not use_tpu:
    sess = tf.Session()
else:
    # Look up the TPU master from its name, connect to it, then initialize the TPU system.
    tpu_cluster = TPUClusterResolver(tpu=[tpu_name]).get_master()
    sess = tf.Session(tpu_cluster)
    sess.run(tpu.initialize_system())

In [9]:
# Build the tuner from the configuration cell; it receives the shared session
# plus the architecture, search budget and all artifact locations.
tuner = HyperparameterTuner(
    sess=sess,
    hidden_layers=hidden_layers,
    hidden_units=hidden_units,
    trials=trials,
    epochs=epochs,
    checkpoint_path=checkpoint_path,
    summaries_path=summaries_path,
    data_path=data_path,
    split_path=split_path,
)

Extracting ./MNIST_data/train-images-idx3-ubyte.gz
Extracting ./MNIST_data/train-labels-idx1-ubyte.gz
Extracting ./MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See @{tf.nn.softmax_cross_entropy_with_logits_v2}.



In [10]:
from queue import PriorityQueue

In [None]:
# task 0
# Train task 0 once per candidate learning rate, then record one queue entry
# as this task's best parameters.
t = 0
# NOTE(review): train_on_task presumably pushes each trial's result onto `queue`
# so that queue.get() yields the best-ranked trial -- confirm against tuner.py.
queue = PriorityQueue()
for learning_rate in tuner.trial_learning_rates:
    tuner.train_on_task(t, learning_rate, queue)
tuner.best_parameters.append(queue.get())

training layers=2,hidden=500,lr=0.00094,multiplier=1067.63,mbsize=250,epochs=10,perm=0 with weights initialized at None


In [10]:
# accuracy on validation sets
# Score the model recorded for task t - 1 on the validation split of tasks 0 and 1.
t = 1
for task_idx in (0, 1):
    validation = tuner.task_list[task_idx].validation
    print("task%d" % task_idx,
          tuner.classifier.test(sess=tuner.sess,
                                model_name=tuner.best_parameters[t - 1][1],
                                batch_xs=validation.images,
                                batch_ys=validation.labels))

INFO:tensorflow:Restoring parameters from gs://continual_learning/permMNIST_EWC/logs/checkpoints/layers=2,hidden=500,lr=0.00077,multiplier=1291.96,mbsize=250,epochs=10,perm=0.ckpt-1119
task0 0.9910086
INFO:tensorflow:Restoring parameters from gs://continual_learning/permMNIST_EWC/logs/checkpoints/layers=2,hidden=500,lr=0.00077,multiplier=1291.96,mbsize=250,epochs=10,perm=0.ckpt-1119
task1 0.0


In [11]:
# task1
# Set-up for tuning on task t, starting from the model recorded for task t - 1.
t = 1

# Number of tasks evaluated so far (tasks 0..t). The evaluation loops in the cells
# below iterate over range(num_perms); it has to be defined here so that a fresh
# kernel ("Restart & Run All") does not fail with a NameError.
num_perms = t + 1

# Initial coarse grids; the search cell below overrides both lists.
N_learning_rates = 10
learning_rates = 10.0 ** np.arange(-9, 2)
# lr = 0.00094
# 1e19 works good - [62, 72] for 15000
#1e22 - [69, 68] for 15000
# fisher_multiplier = 1e24
fisher_multipliers = 10.0 ** np.arange(18, 28)

# Training data for the current task and, when there is a previous task, its
# training data and the model name to start from.
dataset_train = tuner.task_list[t].train
dataset_lagged = tuner.task_list[t - 1].train if t > 0 else None
model_init_name = tuner.best_parameters[t - 1][1] if t > 0 else None

MINI_BATCH_SIZE = 256
LOG_FREQUENCY = 100
dataset_train.initialize_iterator(MINI_BATCH_SIZE)
if (dataset_lagged is not None):
    dataset_lagged.initialize_iterator(MINI_BATCH_SIZE)

# Validation metrics are computed every eval_frequency updates.
eval_frequency = 100
num_updates = 3000

# result maps (dropout, fisher_multiplier, lr) -> recorded curves for that run.
result = {}
# Passed as the first argument of set_dropout in later cells.
# NOTE(review): presumably the keep-probability of the input layer (1.0 = no dropout) -- confirm.
dropout_input = 1.0

In [22]:
# Grid search over (dropout, fisher_multiplier, lr) for task t.
# Every configuration is prepared from model_init_name, trained with early stopping,
# and scored by the mean validation accuracy over all num_perms tasks. The best
# configuration and the update count at which it peaked are kept in
# best_params / best_num_updates for the cells below.
num_updates = 5000  # hard upper bound on SGD updates per configuration
dropouts = [0.5, 0.4, 0.3, 0.2, 0.1]
# fisher_multipliers = np.logspace(1, 6, 50)
fisher_multipliers = [0.0]
learning_rates = [5e-6]
best_avg = 0.0
best_params = -1
best_num_updates = -1
for dropout in dropouts:
    for fisher_multiplier in fisher_multipliers:
        for lr in learning_rates:
            print("dropout: %f, fisher_multiplier: %e, lr: %e" % (dropout, fisher_multiplier, lr))
            start_time = time.time()
            model_name = tuner.file_name(lr, t)
            tuner.classifier.set_dropout(dropout_input, dropout)
            tuner.classifier.prepare_for_training(sess=tuner.sess,
                                                model_name=model_name,
                                                model_init_name=model_init_name,
                                                fisher_multiplier=fisher_multiplier,
                                                learning_rate=lr)
            # One list of validation measurements per task.
            val_acc = [[] for _ in range(num_perms)]
            val_loss = [[] for _ in range(num_perms)]
            loss = []
            loss_with_penalty = []
            cur_best_avg = 0.0
            cur_best_avg_num_updates = -1
            i = 0
            count_not_improving = 0
            # Bounded by num_updates: the patience counter below only advances under
            # a specific condition, so without the bound the loop might never end.
            while i < num_updates:
                cur_loss, cur_loss_with_penalty = tuner.classifier.minibatch_sgd(tuner.sess, i, dataset_train, MINI_BATCH_SIZE, LOG_FREQUENCY)
                loss.append(cur_loss)
                loss_with_penalty.append(cur_loss_with_penalty)
                if (i % eval_frequency == 0):
                    # Evaluate on every task's validation set with keep probabilities of 1.0.
                    cur_iter_avg = 0.0
                    for j in range(num_perms):
                        val_data = tuner.task_list[j].validation
                        feed_dict = tuner.classifier.create_feed_dict(val_data.images, val_data.labels, keep_input=1.0, keep_hidden=1.0)
                        accuracy = sess.run([tuner.classifier.loss, tuner.classifier.accuracy], feed_dict=feed_dict)
                        val_loss[j].append(accuracy[0])
                        val_acc[j].append(accuracy[1])
                        cur_iter_avg += accuracy[1]
                    cur_iter_avg /= num_perms

                    # Patience is only counted while task 0's validation accuracy is
                    # below task 1's; otherwise the counter is left untouched.
                    if (val_acc[0][-1] < val_acc[1][-1]):
                        if (cur_best_avg >= cur_iter_avg):
                            count_not_improving += 1
                        else:
                            count_not_improving = 0

                    if (cur_iter_avg > cur_best_avg):
                        cur_best_avg = cur_iter_avg
                        cur_best_avg_num_updates = i

                    # Stop after 5 consecutive counted evaluations without a new best average.
                    if (count_not_improving >= 5):
                        break
                i += 1

            # Record this configuration. This block must live inside the lr loop;
            # one level further out, only the last learning rate would be recorded.
            cur_params = (dropout, fisher_multiplier, lr)
            result[cur_params] = {}
            result[cur_params]["val_acc"] = val_acc
            result[cur_params]["val_loss"] = val_loss
            result[cur_params]["loss"] = loss
            result[cur_params]["loss_with_penalty"] = loss_with_penalty
            result[cur_params]["best_avg"] = (cur_best_avg, cur_best_avg_num_updates)
            if (best_avg < cur_best_avg):
                best_avg = cur_best_avg
                best_params = cur_params
                best_num_updates = cur_best_avg_num_updates
            print("time taken: %f" % (time.time() - start_time))
            # Accuracies are reported at the evaluation step where the average peaked.
            print("loss with penalty: %f, loss: %f, val0 accuracy: %f, val1 accuracy: %f"
                  % (loss_with_penalty[-1], loss[-1],
                     val_acc[0][cur_best_avg_num_updates // eval_frequency], val_acc[1][cur_best_avg_num_updates // eval_frequency]))

dropout: 0.500000, fisher_multiplier: 0.000000e+00, lr: 5.000000e-06
INFO:tensorflow:Restoring parameters from gs://continual_learning/permMNIST_EWC/logs/checkpoints/layers=2,hidden=500,lr=0.00077,multiplier=1291.96,mbsize=250,epochs=10,perm=0.ckpt-1119
time taken: 53.323640
loss with penalty: 0.792711, loss: 0.792711, val0 accuracy: 0.910868, val1 accuracy: 0.791564
dropout: 0.400000, fisher_multiplier: 0.000000e+00, lr: 5.000000e-06
INFO:tensorflow:Restoring parameters from gs://continual_learning/permMNIST_EWC/logs/checkpoints/layers=2,hidden=500,lr=0.00077,multiplier=1291.96,mbsize=250,epochs=10,perm=0.ckpt-1119
time taken: 57.301655
loss with penalty: 1.191719, loss: 1.191719, val0 accuracy: 0.901095, val1 accuracy: 0.791155
dropout: 0.300000, fisher_multiplier: 0.000000e+00, lr: 5.000000e-06
INFO:tensorflow:Restoring parameters from gs://continual_learning/permMNIST_EWC/logs/checkpoints/layers=2,hidden=500,lr=0.00077,multiplier=1291.96,mbsize=250,epochs=10,perm=0.ckpt-1119
time t

In [None]:
print("best_avg: %e, best_params: %s" % (best_avg, str(best_params)))

# One report per recorded configuration: a text summary, then the loss curves and
# the validation-accuracy curves, each as a labelled figure.
for k, cur_res in result.items():
    # Validation metrics were recorded every eval_frequency updates; x places them
    # on the same per-update axis as the training losses.
    x = np.arange(0, len(cur_res['loss']), eval_frequency)
    cur_best_avg = cur_res['best_avg']  # (best mean accuracy, update index of the best)
    best_idx = cur_best_avg[1] // eval_frequency
    run_title = "dropout=%g, fisher_multiplier=%g, lr=%g" % (k[0], k[1], k[2])
    print("dropout: %f, fisher_multiplier: %e, lr: %e" % (k[0], k[1], k[2]))
    print("cur_best_avg: %e, num_updates: %e" % (cur_best_avg[0], cur_best_avg[1]))
    print("val0_acc: %e, val1_acc: %e" %
        (cur_res['val_acc'][0][best_idx], cur_res['val_acc'][1][best_idx]))

    # Loss curves.
    plt.plot(cur_res['loss_with_penalty'], color='g', label='training loss with penalty')
    plt.plot(cur_res['loss'], color='m', label='training loss')
    plt.plot(x, cur_res['val_loss'][1], color='b', label='task 1 validation loss')
    plt.xlabel('update')
    plt.ylabel('loss')
    plt.title(run_title)
    plt.legend()
    plt.show()

    # Validation accuracy per task, on the same update axis as the loss figure.
    plt.plot(x, cur_res['val_acc'][0], color='b', label='task 0 validation accuracy')
    plt.plot(x, cur_res['val_acc'][1], color='g', label='task 1 validation accuracy')
    plt.xlabel('update')
    plt.ylabel('accuracy')
    plt.title(run_title)
    plt.legend()
    plt.show()

In [12]:
def get_confusion_matrix(tuner):
    """Count (actual, predicted) label pairs over the validation sets of all tasks.

    Reads the notebook-level names ``num_perms`` (number of tasks to include) and
    ``sess`` (session passed to the classifier). For each task the validation
    images and labels are fed with keep probabilities of 1.0, the class with the
    highest score is taken as the prediction, and the actual labels of that task
    are printed.

    Args:
        tuner: object exposing ``task_list[j].validation`` (with ``images`` and
            ``labels``) and ``classifier`` (with ``create_feed_dict`` and
            ``get_predictions``).

    Returns:
        A 10x10 ``int64`` array whose entry ``[a, p]`` is the number of validation
        examples with actual label ``a`` that were predicted as ``p``.
    """
    num_labels = 10
    actual = np.array([])
    pred = np.array([])
    for j in range(num_perms):
        val_data = tuner.task_list[j].validation
        feed_dict = tuner.classifier.create_feed_dict(
            val_data.images, val_data.labels, keep_input=1.0, keep_hidden=1.0)
        cur_scores, cur_y = tuner.classifier.get_predictions(sess, feed_dict)
        # Labels are one row per example, so argmax over axis 1 recovers the class id.
        cur_actual = np.argmax(cur_y, 1)
        print(cur_actual)
        actual = np.concatenate([actual, cur_actual])
        pred = np.concatenate([pred, np.argmax(cur_scores, 1)])

    confusion_matrix = np.zeros((num_labels, num_labels), dtype=np.int64)
    # Unbuffered add so that repeated (actual, pred) pairs are each counted.
    np.add.at(confusion_matrix, (actual.astype(np.int64), pred.astype(np.int64)), 1)
    return confusion_matrix

def print_confusion_matrix(confusion_matrix):
    """Print a 2-D matrix as a fixed-width table with row and column indices.

    The first line holds a ``0`` in the corner cell followed by the column
    indices; each following line holds the row index followed by that row's
    entries. Every cell is formatted with ``%3d`` and followed by one space.

    Args:
        confusion_matrix: 2-D array-like exposing ``shape`` and ``[i][j]`` indexing.
    """
    n_rows = confusion_matrix.shape[0]
    n_cols = confusion_matrix.shape[1]

    def render(cells):
        # Each cell is right-aligned to width 3 and terminated by a single space.
        return ''.join("%3d " % (cell, ) for cell in cells)

    # Header: corner cell, then one label per column.
    print(render([0] + list(range(n_cols))))
    for i in range(n_rows):
        row_cells = [confusion_matrix[i][j] for j in range(n_cols)]
        print(render([i]) + render(row_cells))

In [13]:
# confusion matrix before training
# Prepare the classifier for the best configuration (no updates are run here) and
# show which digits it confuses at that point.

# Manual override of the search result:
# best_params = (1.0, 10 ** (18 / 19 + 4), 5e-6)
# best_num_updates = 3600

dropout, fisher_multiplier, lr = best_params
model_name = tuner.file_name(lr, t)

tuner.classifier.prepare_for_training(
    sess=tuner.sess,
    model_name=model_name,
    model_init_name=model_init_name,
    fisher_multiplier=fisher_multiplier,
    learning_rate=lr,
)

confusion_matrix = get_confusion_matrix(tuner)
print_confusion_matrix(confusion_matrix)


INFO:tensorflow:Restoring parameters from gs://continual_learning/permMNIST_EWC/logs/checkpoints/layers=2,hidden=500,lr=0.00077,multiplier=1291.96,mbsize=250,epochs=10,perm=0.ckpt-1119
[0 4 1 ... 2 1 2]
[5 9 5 ... 7 8 7]
  0   0   1   2   3   4   5   6   7   8   9 
  0 477   0   0   1   1   0   0   0   0   0 
  1   0 558   2   1   2   0   0   0   0   0 
  2   4   3 480   1   0   0   0   0   0   0 
  3   0   0   4 488   1   0   0   0   0   0 
  4   0   2   1   0 532   0   0   0   0   0 
  5  46  20   4 292  72   0   0   0   0   0 
  6  53  35 171  13 229   0   0   0   0   0 
  7  42  27 100 213 168   0   0   0   0   0 
  8  13  81  68 266  34   0   0   0   0   0 
  9   4   5   2  14 470   0   0   0   0   0 


In [14]:
# train on best hyperparameters
# Train with the selected configuration for exactly best_num_updates updates,
# recording training losses every update and per-task validation loss/accuracy
# every eval_frequency updates.
dropout, fisher_multiplier, lr = best_params
print("dropout: %f, fisher_multiplier: %e, lr: %e" % (dropout, fisher_multiplier, lr))
start_time = time.time()
model_name = tuner.file_name(lr, t)
tuner.classifier.set_dropout(dropout_input, dropout)
tuner.classifier.prepare_for_training(
    sess=tuner.sess,
    model_name=model_name,
    model_init_name=model_init_name,
    fisher_multiplier=fisher_multiplier,
    learning_rate=lr,
)

val_acc, val_loss = [[], []], [[], []]
loss, loss_with_penalty = [], []
cur_best_avg, cur_best_avg_num_updates = 0.0, -1

for i in range(best_num_updates):
    cur_loss, cur_loss_with_penalty = tuner.classifier.minibatch_sgd(
        tuner.sess, i, dataset_train, MINI_BATCH_SIZE, LOG_FREQUENCY)
    loss.append(cur_loss)
    loss_with_penalty.append(cur_loss_with_penalty)

    # Validation only happens on every eval_frequency-th update.
    if i % eval_frequency != 0:
        continue

    cur_iter_avg = 0.0
    for j in range(num_perms):
        val_data = tuner.task_list[j].validation
        feed_dict = tuner.classifier.create_feed_dict(
            val_data.images, val_data.labels, keep_input=1.0, keep_hidden=1.0)
        accuracy = sess.run([tuner.classifier.loss, tuner.classifier.accuracy], feed_dict=feed_dict)
        task_loss, task_acc = accuracy
        val_loss[j].append(task_loss)
        val_acc[j].append(task_acc)
        cur_iter_avg += task_acc
    cur_iter_avg /= num_perms

    # Track the update at which the mean validation accuracy peaked.
    if cur_iter_avg > cur_best_avg:
        cur_best_avg, cur_best_avg_num_updates = cur_iter_avg, i

# Store the curves under the configuration key (replacing any search-time entry).
cur_params = (dropout, fisher_multiplier, lr)
result[cur_params] = {
    "val_acc": val_acc,
    "val_loss": val_loss,
    "loss": loss,
    "loss_with_penalty": loss_with_penalty,
    "best_avg": (cur_best_avg, cur_best_avg_num_updates),
}

print("time taken: %f" % (time.time() - start_time))
best_idx = cur_best_avg_num_updates // eval_frequency
print("loss with penalty: %f, loss: %f, val0 accuracy: %f, val1 accuracy: %f"
      % (loss_with_penalty[-1], loss[-1], val_acc[0][best_idx], val_acc[1][best_idx]))

dropout: 1.000000, fisher_multiplier: 8.858668e+04, lr: 5.000000e-06
INFO:tensorflow:Restoring parameters from gs://continual_learning/permMNIST_EWC/logs/checkpoints/layers=2,hidden=500,lr=0.00077,multiplier=1291.96,mbsize=250,epochs=10,perm=0.ckpt-1119
time taken: 56.365972
loss with penalty: 1.133778, loss: 0.486593, val0 accuracy: 0.919077, val1 accuracy: 0.860770


In [15]:
# confusion matrix 
# finding which digits are being confused by classifier
confusion_matrix = get_confusion_matrix(tuner)
print_confusion_matrix(confusion_matrix)

[0 4 1 ... 2 1 2]
[5 9 5 ... 7 8 7]
  0   0   1   2   3   4   5   6   7   8   9 
  0 460   0   0   0   0  10   7   1   1   0 
  1   0 555   1   0   1   1   0   1   3   1 
  2   2   4 463   1   0   1   7   1   9   0 
  3   0   1   4 419   0  42   1   4  19   3 
  4   0   2   2   0 432   0  13   0   0  86 
  5   5   6   2  20   1 355  13   3  18  11 
  6   4   6   5   0   2   9 470   0   4   1 
  7   9   6  14   0   6   6   6 473   7  23 
  8   1   9   5   0   0  33   8   3 384  19 
  9   3   5   1   3   3   7   5  15   6 447 


In [16]:
# test accuracy
# Print [loss, accuracy] on the test split of every task, fed with keep
# probabilities of 1.0.
for j in range(num_perms):
    test_data = tuner.task_list[j].test
    feed_dict = tuner.classifier.create_feed_dict(
        test_data.images, test_data.labels, keep_input=1.0, keep_hidden=1.0)
    fetches = [tuner.classifier.loss, tuner.classifier.accuracy]
    accuracy = sess.run(fetches, feed_dict=feed_dict)
    print(accuracy)

[0.30288878, 0.89920217]
[0.53332376, 0.85969967]


In [17]:
# Shut the TPU system down when one was initialized, and close the session in
# every case -- including when the shutdown op raises -- so it is not leaked.
try:
    if use_tpu:
        sess.run(tpu.shutdown_system())
finally:
    sess.close()