In [18]:
"""Utility functions for benchmarking online learning"""
from __future__ import division
import numpy as np
import keras
from keras.utils import np_utils

from keras.datasets import mnist, cifar10, cifar100
from keras.optimizers import Adam, RMSprop, SGD
import keras.backend as K

import pickle
import gzip

import tensorflow as tf
from keras.layers.core import Dense
from keras.layers import Conv2D, AveragePooling2D, MaxPool2D, Flatten, InputLayer
from keras.datasets import mnist
from keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt
#from util import computer_fisher, ewc_reg




import numpy as np
import keras.backend as K
from keras.regularizers import Regularizer


def computer_fisher(model, imgset, num_sample=30):
    """Estimate a diagonal Fisher information term for every weight of ``model``.

    ``num_sample`` inputs are drawn uniformly at random (with replacement)
    from ``imgset``. For each draw, the gradient of ``K.log(model.output)``
    with respect to every entry of ``model.weights`` is squared and
    accumulated; the accumulated sums are finally divided by ``num_sample``.

    Args:
        model: Keras model exposing ``weights``, ``input`` and ``output``.
        imgset: array of model inputs, indexed along its first axis.
        num_sample: number of random draws to average over.

    Returns:
        A 1-D object array with one entry per element of ``model.weights``;
        each entry is an array shaped like the corresponding weight.
    """
    weights = model.weights

    # Build the symbolic gradients and the backend function exactly once.
    # Creating them inside the sampling loop adds new ops to the graph on
    # every iteration and dominates the running time.
    # NOTE(review): the gradient is taken of log(model.output) as a whole,
    # not of a single (sampled or true) class - confirm this is the intended
    # Fisher estimate.
    grads = K.gradients(K.log(model.output), weights)
    grad_fn = K.function([model.input], grads)

    f_accum = [np.zeros(K.int_shape(w)) for w in weights]
    for _ in range(num_sample):
        img_index = np.random.randint(imgset.shape[0])
        sample = np.expand_dims(imgset[img_index], 0)
        for acc, grad in zip(f_accum, grad_fn([sample])):
            acc += np.square(grad)

    # The per-weight arrays have different shapes, so fill an explicit object
    # array; np.array() on such a ragged list is rejected by recent NumPy.
    fisher = np.empty(len(f_accum), dtype=object)
    for i, acc in enumerate(f_accum):
        fisher[i] = acc / num_sample
    return fisher


class ewc_reg(Regularizer):
    """Quadratic penalty that anchors a weight to a previously learned value.

    The penalty returned for a weight tensor ``x`` is
    ``Lambda * sum(fisher * (x - prior_weights) ** 2)``.

    Args:
        fisher: per-element importance of the weight (same shape as the
            weight it regularises).
        prior_weights: the weight values to stay close to.
        Lambda: scalar strength of the penalty.
    """

    def __init__(self, fisher, prior_weights, Lambda=0.99):
        self.fisher = fisher
        self.prior_weights = prior_weights
        self.Lambda = Lambda

    def __call__(self, x):
        """Return the penalty tensor for weight tensor ``x``."""
        # Use the strength given to this instance; a bare ``Lambda`` would
        # silently read a module-level global (or raise NameError without one).
        drift = K.square(x - self.prior_weights)
        return self.Lambda * K.sum(self.fisher * drift)

    def get_config(self):
        """Return the scalar strength; ``fisher``/``prior_weights`` are not included."""
        return {'Lambda': float(self.Lambda)}





def split_dataset_by_labels(X, y, task_labels, nb_classes=None, multihead=False):
    """Partition a labelled dataset into one sub-dataset per task.

    Args:
        X: data, indexed along its first axis
        y: integer labels, one per row of ``X``
        task_labels: list of list of labels, one inner list per task
        nb_classes: number of classes (used to convert to one-hot); defaults
            to the number of distinct values in ``y``
        multihead: if True, each task's labels are renumbered to
            ``0 .. len(labels) - 1`` and one-hot encoded over ``len(labels)``
            columns; otherwise the original labels are one-hot encoded over
            ``nb_classes`` columns
    Returns:
        List of (X, y) tuples, one per entry of ``task_labels``
    """
    n_total = len(np.unique(y)) if nb_classes is None else nb_classes

    def _task_subset(labels):
        # Rows whose label belongs to this task.
        mask = np.in1d(y, labels)
        if not multihead:
            return X[mask], np_utils.to_categorical(y[mask], n_total)
        # Map each original label of the task to its position in ``labels``.
        remap = np.arange(n_total)
        remap[labels] = np.arange(len(labels))
        return X[mask], np_utils.to_categorical(remap[y[mask]], len(labels))

    return [_task_subset(labels) for labels in task_labels]


def load_mnist(split='train'):
    """Load one MNIST split as scaled images with one-hot labels.

    Args:
        split: ``'train'`` selects the training set; any other value selects
            the test set.

    Returns:
        Tuple ``(X, y)``: ``X`` holds the images as float32 divided by 255,
        reshaped to 28x28 with a single channel axis placed according to
        ``K.image_data_format()``; ``y`` holds the labels one-hot encoded
        over 10 classes.
    """
    n_rows, n_cols, n_classes = 28, 28, 10

    train_pair, test_pair = mnist.load_data()
    images, labels = train_pair if split == 'train' else test_pair

    # Put the single grey channel where the backend expects it.
    if K.image_data_format() == 'channels_first':
        target_shape = (images.shape[0], 1, n_rows, n_cols)
    else:
        target_shape = (images.shape[0], n_rows, n_cols, 1)

    X = images.reshape(target_shape).astype('float32') / 255
    y = np_utils.to_categorical(labels, n_classes)
    return X, y

def construct_split_mnist(task_labels,  split='train', multihead=False):
    """Split MNIST dataset by labels.

    The images are converted to float32, divided by 255 and reshaped to
    28x28 with a single channel axis placed according to
    ``K.image_data_format()`` before being split.

    Args:
        task_labels: list of list of labels, one for each dataset
        split: ``'train'`` uses the training data; any other value uses the
            test data
        multihead: forwarded to ``split_dataset_by_labels``

    Returns:
        List of (X, y) tuples representing each dataset
    """
    nb_classes = 10
    n_rows, n_cols = 28, 28

    train_pair, test_pair = mnist.load_data()
    images, labels = train_pair if split == 'train' else test_pair

    # Put the single grey channel where the backend expects it.
    if K.image_data_format() == 'channels_first':
        target_shape = (images.shape[0], 1, n_rows, n_cols)
    else:
        target_shape = (images.shape[0], n_rows, n_cols, 1)

    X = images.reshape(target_shape).astype('float32') / 255

    # Labels stay as integers here; the one-hot conversion happens per task.
    return split_dataset_by_labels(X, labels, task_labels, nb_classes, multihead)
  
  
class TestCallback(keras.callbacks.Callback):
    """Rescale the module-level ``Lambda`` from held-out accuracy after each epoch.

    After every epoch the model is evaluated on ``test_data``. The first
    accuracy threshold in the schedule that is exceeded selects a multiplier,
    and ``Lambda`` is set to ``multiplier * S_lam`` and printed. If the
    accuracy is 0.65 or lower, ``Lambda`` is left unchanged.

    Args:
        test_data: ``(x, y)`` pair passed to ``model.evaluate``.
    """

    # (accuracy threshold, multiplier of S_lam), highest threshold first.
    _LAMBDA_SCHEDULE = (
        (0.95, 1000),
        (0.90, 800),
        (0.85, 700),
        (0.80, 600),
        (0.75, 300),
        (0.70, 200),
        (0.65, 100),
    )

    def __init__(self, test_data):
        super(TestCallback, self).__init__()
        self.test_data = test_data

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on the held-out data and update ``Lambda`` accordingly."""
        # Without this declaration the assignment below only binds a local
        # name and the module-level ``Lambda`` never changes.
        # NOTE(review): a regulariser that has already baked its strength into
        # the loss will not pick up the new value - confirm how Lambda is
        # consumed during training.
        global Lambda
        x, y = self.test_data
        loss, acc = self.model.evaluate(x, y, verbose=0)
        for threshold, factor in self._LAMBDA_SCHEDULE:
            if acc > threshold:
                Lambda = factor * S_lam
                print("lam=", Lambda)
                break

# Fix NumPy's random state so the random draws made through np.random repeat.
np.random.seed(104)

# Batch size and number of epochs handed to every ``fit`` call below.
Batch_size, Epochs = 65536, 100

# Base regularisation strength, and the current strength derived from it.
S_lam = 0.1
Lambda = S_lam

# One inner list of digit labels per task. Alternatives tried earlier:
#   [[0,1], [2,3], [4,5], [6,7], [8,9]]
#   [[0,1,2,3,4], [5,6,7,8,9]]
task_labels = [[0, 1], [2, 3]]
n_tasks = len(task_labels)

# Per-task (X, y) pairs for the training and test splits.
training_datasets, validation_datasets = [
    construct_split_mnist(task_labels, split=part) for part in ('train', 'test')
]

print(validation_datasets[0][1].shape)
#
# #### Display one sample image from each of the two tasks
# plt.figure()
# plt.subplot(1, 2, 1)
# plt.imshow(training_datasets[0][0][0], cmap='gray')
# plt.title('Task A')
# plt.axis('off')
# plt.subplot(1, 2, 2)
# plt.imshow(training_datasets[1][0][0], cmap='gray')
# plt.title('Task B')
# plt.axis('off')
# plt.show()


##### Task A: train the base network and save it so Task B can start from it
task_a_train_x, task_a_train_y = training_datasets[0]

model = Sequential([
    InputLayer(input_shape=(28, 28, 1)),
    Conv2D(8, (5, 5), padding="same", activation="relu"),
    AveragePooling2D(pool_size=(2, 2)),
    Conv2D(16, (5, 5), padding="same", activation="relu"),
    AveragePooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(10, activation='softmax'),
])
model.summary()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(
    task_a_train_x,
    task_a_train_y,
    batch_size=Batch_size,
    epochs=Epochs,
    validation_data=validation_datasets[0],
)
# The saved file is reloaded below to initialise the Task B network.
model.save('MNISTA.h5')



##### Compute the Fisher Information for each parameter in Task A
print('Processing Fisher Information...')
I = computer_fisher(model, training_datasets[0][0])
print('Processing Finish!')


##### Task B EWC training
def _ewc_penalty(index):
    """Build the EWC regulariser tied to Task A's weight number ``index``."""
    # Pass the strength explicitly so the penalty does not depend on the
    # regulariser looking up a module-level global by itself.
    return ewc_reg(I[index], model.weights[index], Lambda=Lambda)


# Same architecture as the Task A network; weight i of that network supplies
# both the anchor value and (through I[i]) the importance for regulariser i.
model_ewcB = Sequential()
model_ewcB.add(InputLayer(input_shape=(28, 28, 1)))
print("lam_check=", Lambda)
model_ewcB.add(Conv2D(8, (5, 5), padding="same", activation="relu",
                      kernel_regularizer=_ewc_penalty(0),
                      bias_regularizer=_ewc_penalty(1)))
# Sanity check: show the strength a regulariser actually reports.
lam = _ewc_penalty(1).get_config()
print("lam_check=", lam)
model_ewcB.add(AveragePooling2D(pool_size=(2, 2)))
model_ewcB.add(Conv2D(16, (5, 5), padding="same", activation="relu",
                      kernel_regularizer=_ewc_penalty(2),
                      bias_regularizer=_ewc_penalty(3)))
model_ewcB.add(AveragePooling2D(pool_size=(2, 2)))
model_ewcB.add(Flatten())
model_ewcB.add(Dense(10, activation='softmax',
                     kernel_regularizer=_ewc_penalty(4),
                     bias_regularizer=_ewc_penalty(5)))
model_ewcB.compile('adam', 'categorical_crossentropy', metrics=['accuracy'])
# Start Task B from the weights learned on Task A.
model_ewcB.load_weights('MNISTA.h5')
model_ewcB.fit(training_datasets[1][0], training_datasets[1][1], Batch_size, Epochs,
               validation_data=(validation_datasets[1][0], validation_datasets[1][1]),
               callbacks=[TestCallback((validation_datasets[1][0], validation_datasets[1][1]))])

# # Task B no penalty training
# model_NoP_B = Sequential()
# model_NoP_B.add(Dense(10, activation='relu', input_dim=784))
# model_NoP_B.add(Dense(10, activation='softmax'))
# model_NoP_B.compile('adam', 'categorical_crossentropy', metrics=['accuracy'])
# model_NoP_B.load_weights('MNISTA.h5')
# model_NoP_B.fit(training_datasets[1][0], training_datasets[1][1], 100, 10, validation_data=(validation_datasets[1][0],validation_datasets[1][1]))

# Accuracy (in percent) of the EWC-trained network on each task's test split.
# The no-penalty baseline (model_NoP_B) is disabled, so only EWC is reported.
task_a_val_x, task_a_val_y = validation_datasets[0]
task_b_val_x, task_b_val_y = validation_datasets[1]

# Current task (B)
B_EWC = 100 * model_ewcB.evaluate(task_b_val_x, task_b_val_y, verbose=0)[1]
# Previous task (A)
A_EWC = 100 * model_ewcB.evaluate(task_a_val_x, task_a_val_y, verbose=0)[1]

# Reference point: the Task A network evaluated on Task A.
task_a_original = 100 * model.evaluate(task_a_val_x, task_a_val_y)[1]

print("Task A Original Accuracy: %.2f%%" % task_a_original)
print("Task B EWC method penalty Accuracy: %.2f%%" % B_EWC)
print("Task A EWC method penalty Accuracy: %.2f%%" % A_EWC)

# Bar chart of the EWC accuracies on the current (B) and previous (A) task.
# The layout reserves room for two bars per group; the SGD bars that would
# sit at x + width and x + 4.5 * width are currently disabled.
x = 0
total_width, n = 0.1, 2
width = total_width / n
x = x - (total_width - width) / 2
plt.style.use('ggplot')
# 'w' is not a hatch character, so the pattern is spelled as plain '/'.
plt.bar(x, B_EWC, width=width, label='EWC Task B', hatch='/', ec='w')
plt.bar(x + 3.5 * width, A_EWC, width=width, label='EWC Task A', hatch='/', ec='w')
plt.legend(facecolor='white')
plt.xticks(np.array([0., 3.5 * width]), ('Current', 'Previous'))
plt.title('EWC method vs SGD method on \n Current task and Previous task')
plt.xlim(-0.15, 0.35)
plt.ylim(0., 105.)
plt.show()


(2115, 10)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_42 (Conv2D)           (None, 28, 28, 8)         208       
_________________________________________________________________
average_pooling2d_39 (Averag (None, 14, 14, 8)         0         
_________________________________________________________________
conv2d_43 (Conv2D)           (None, 14, 14, 16)        3216      
_________________________________________________________________
average_pooling2d_40 (Averag (None, 7, 7, 16)          0         
_________________________________________________________________
flatten_20 (Flatten)         (None, 784)               0         
_________________________________________________________________
dense_20 (Dense)             (None, 10)                7850      
Total params: 11,274
Trainable params: 11,274
Non-trainable params: 0
_____________________________________________________________

KeyboardInterrupt: ignored