## Using another approach where the architecture of the network is constant while learning each task. We use the method called "Elastic Weight Consolidation".

#### Get the source code from https://github.com/ganguli-lab/pathint 
Put this jupytper notebook into `fig_split_mnist` folder. Then run.

In [1]:
%load_ext autoreload
%autoreload 2
%pylab inline

import tensorflow as tf
slim = tf.contrib.slim
graph_replace = tf.contrib.graph_editor.graph_replace

import sys, os
sys.path.extend([os.path.expanduser('..')])
from pathint import utils
import seaborn as sns
sns.set_style("ticks")

from tqdm import trange, tqdm

# import operator
import matplotlib.colors as colors
import matplotlib.cm as cmx

rcParams['pdf.fonttype'] = 42
rcParams['ps.fonttype'] = 42

Populating the interactive namespace from numpy and matplotlib


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  from ._conv import register_converters as _register_converters



For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



Using TensorFlow backend.


## Parameters

In [2]:
select = tf.select if hasattr(tf, 'select') else tf.where

In [3]:
# Data params
input_dim = 784
output_dim = 10

# Network params
n_hidden_units = 50
activation_fn = tf.nn.relu

# Optimization params
batch_size = 1000
epochs_per_task = 5

n_stats = 1

# Reset optimizer after each age
reset_optimizer = True

## Construct datasets

In [4]:
task_labels = [[0,1], [2,3], [4,5], [6,7], [8,9]]
#task_labels = [[0,1], [2,3], [4,5], [6,7], [8,9], [4,6],[8,1],[0,3],[2,4],[5,7]]
#task_labels = [[0,1], [2,3], [4,5], [6,7], [8,9], [4,6],[8,1],[0,3],[2,9],[5,7]]
# task_labels = [[0,1,2,3,4], [5,6,7,8,9]]
n_tasks = len(task_labels)
training_datasets = utils.construct_split_mnist(task_labels, split='train')
validation_datasets = utils.construct_split_mnist(task_labels, split='test')
# training_datasets = utils.mk_training_validation_splits(full_datasets, split_fractions=(0.9, 0.1))

## Construct network, loss, and updates

In [5]:
tf.reset_default_graph()

In [6]:
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
sess = tf.InteractiveSession(config=config)
sess.run(tf.global_variables_initializer())

In [7]:
# tf.equal(output_mask[None, :], 1.0)

In [8]:
import keras.backend as K
import keras.activations as activations

output_mask = tf.Variable(tf.zeros(output_dim), name="mask", trainable=False)

def masked_softmax(logits):
    # logits are [batch_size, output_dim]
    x = select(tf.tile(tf.equal(output_mask[None, :], 1.0), [tf.shape(logits)[0], 1]), logits, -1e32 * tf.ones_like(logits))
    return activations.softmax(x)

def set_active_outputs(labels):
    new_mask = np.zeros(output_dim)
    for l in labels:
        new_mask[l] = 1.0
    sess.run(output_mask.assign(new_mask))
    #print(sess.run(output_mask))
    
def masked_predict(model, data, targets):
    pred = model.predict(data)
    #print(pred)
    acc = np.argmax(pred,1)==np.argmax(targets,1)
    return acc.mean()

Instructions for updating:
Colocations handled automatically by placer.


In [9]:
from keras.models import Sequential
from keras.layers import Dense

from keras.layers import Activation
from keras.utils.generic_utils import get_custom_objects
def custom_activation(x):
    A=K.relu(x)
    return(K.log(1+A))
get_custom_objects().update({'custom_activation':Activation(custom_activation)})

model = Sequential()
model.add(Dense(n_hidden_units, activation=activation_fn, input_shape=(input_dim,)))
#model.add(Dense(n_hidden_units, activation=activation_fn))
model.add(Dense(output_dim, kernel_initializer='zero', activation=masked_softmax))
#model.add(Dense(output_dim, activation=masked_softmax, input_shape=(input_dim,)))

In [10]:
from pathint import protocols
from pathint.optimizers import KOOptimizer
from keras.optimizers import Adam, RMSprop,SGD
from keras.callbacks import Callback
from pathint.keras_utils import LossHistory
from keras.callbacks import History 
from keras.callbacks import LambdaCallback

#protocol_name, protocol = protocols.PATH_INT_PROTOCOL(omega_decay='sum',xi=1e-3)
protocol_name, protocol = protocols.PATH_INT_PROTOCOL(omega_decay='sum',xi=1e-3)
#protocol_name, protocol = protocols.FISHER_PROTOCOL('sum')
opt = Adam(lr=1e-3, beta_1=0.9, beta_2=0.999)
# opt = SGD(1e-3)
# opt = RMSprop(lr=1e-3)
oopt = KOOptimizer(opt, model=model, compute_fisher=False, **protocol)
model.compile(loss='categorical_crossentropy', optimizer=oopt, metrics=['accuracy'])
model.model._make_train_function()
saved_weights = model.get_weights()

save_weights_epoch=[]
save_loss_epoch=[]
print_weights = LambdaCallback(on_epoch_end=lambda batch, logs: save_weights_epoch.append(model.get_weights()))
history = LossHistory()
callbacks = [history]
datafile_name = "split_mnist_data_%s.pkl.gz"%protocol_name

Instructions for updating:
keep_dims is deprecated, use keepdims instead
Instructions for updating:
Use tf.cast instead.


## Train!

In [11]:
import pdb
def run_fits(cvals, training_data, valid_data, eval_on_train_set=False, nstats=1):
    acc_mean = dict()
    acc_std = dict()
    model_weights_save = []   #Empty list to save the model weights aftertraining each task
    for cidx, cval_ in enumerate(tqdm(cvals)):
        runs = []
        for runid in range(nstats):
            sess.run(tf.global_variables_initializer())
            # model.set_weights(saved_weights)
            cstuffs = []
            evals = []
            print("setting cval")
            cval = cval_
            oopt.set_strength(cval)
            oopt.init_task_vars()
            print("cval is", sess.run(oopt.lam))
            for age, tidx in enumerate(range(n_tasks)):
                print("Task %i"%(age))
                set_active_outputs(task_labels[age])
                stuffs = model.fit(training_data[tidx][0], training_data[tidx][1], batch_size, epochs_per_task, callbacks=[history,print_weights], verbose=0)
                oopt.update_task_metrics(training_data[tidx][0], training_data[tidx][1], batch_size)
                oopt.update_task_vars()
                ftask = []
                model_weights_save.append(model.get_weights()) #Save the model weights aftertraining each task
                for j in range(n_tasks):
                    set_active_outputs(task_labels[j])
                    if eval_on_train_set:
                        f_ = masked_predict(model, training_data[j][0], training_data[j][1])
                    else:
                        f_ = masked_predict(model, valid_data[j][0], valid_data[j][1])
                    ftask.append(np.mean(f_))
                print("Accuracy", ftask)
                evals.append(ftask)
                cstuffs.append(stuffs)

                # Re-initialize optimizater variables
                if reset_optimizer:
                    oopt.reset_optimizer()

            evals = np.array(evals)
            runs.append(evals)
        
        runs = np.array(runs)
        acc_mean[cval_] = runs.mean(0)
        acc_std[cval_] = runs.std(0)
    return dict(mean=acc_mean, std=acc_std),model_weights_save,cstuffs

In [12]:
# cvals = np.concatenate(([0], np.logspace(-2, 2, 10)))
# cvals = np.concatenate(([0], np.logspace(-1, 2, 2)))
# cvals = np.concatenate(([0], np.logspace(-2, 0, 3)))
cvals = np.logspace(-3, 3, 7)#[0, 1.0, 2, 5, 10]
cvals = [1.0]
print(cvals)

[1.0]


In [13]:
#%%capture
recompute_data = True
if recompute_data:
    data,model_weights_save,cstuffs = run_fits(cvals, training_datasets, validation_datasets, eval_on_train_set=False, nstats=n_stats)
    utils.save_zipped_pickle(data, datafile_name)
    
for task_id in range(len(task_labels)):
    set_active_outputs(task_labels[task_id])    
    print('Task ', {task_id}, ' accuracy: ', masked_predict(model, validation_datasets[task_id][0], validation_datasets[task_id][1]))        

  0%|          | 0/1 [00:00<?, ?it/s]

setting cval
cval is 1.0
Task 0
Accuracy [0.9995271867612293, 0.5053868756121449, 0.5240128068303095, 0.4823766364551863, 0.49117498739283916]
Task 1
Accuracy [0.9995271867612293, 0.9701273261508325, 0.5240128068303095, 0.4823766364551863, 0.49117498739283916]
Task 2
Accuracy [0.9995271867612293, 0.910871694417238, 0.9754535752401281, 0.4823766364551863, 0.49117498739283916]
Task 3
Accuracy [0.9995271867612293, 0.9079333986287953, 0.971718249733191, 0.9879154078549849, 0.49117498739283916]
Task 4


100%|██████████| 1/1 [00:05<00:00,  5.42s/it]

Accuracy [0.9990543735224586, 0.9094025465230167, 0.9791889007470651, 0.9884189325276939, 0.9652042360060514]
Task  {0}  accuracy:  0.9990543735224586
Task  {1}  accuracy:  0.9094025465230167





Task  {2}  accuracy:  0.9791889007470651
Task  {3}  accuracy:  0.9884189325276939
Task  {4}  accuracy:  0.9652042360060514


#### Unlearning task 1 containing labels [2,3]. Please note that the task numbering starts from 0.

Zeroing only output layer weights.

In [14]:
#Accuracy aftering forgetting one of the task
#Assume that task 1 is forgetted.

forget_task=1
class_to_forget = task_labels[forget_task]
#Zero out the weights corresponding to this class
for task_id in range(len(task_labels)):
    set_active_outputs(task_labels[task_id])
    
    for k,layer in enumerate(model.layers):  # Exclude output layer
        if isinstance(layer, Dense):
            weights, biases = layer.get_weights()
            if layer == model.layers[-1]:
                for cl in class_to_forget:
                    weights[:, cl] = 0  # Zero out weights for the forgotten class in the output layer
                    biases[cl] = 0
            layer.set_weights([weights, biases])
    
    print('Task ', {task_id}, ' accuracy: ', masked_predict(model, validation_datasets[task_id][0], validation_datasets[task_id][1]))        

Task  {0}  accuracy:  0.9990543735224586
Task  {1}  accuracy:  0.5053868756121449
Task  {2}  accuracy:  0.9791889007470651
Task  {3}  accuracy:  0.9884189325276939
Task  {4}  accuracy:  0.9652042360060514


That shows the forgetting for task 1 as its accuracy is close to 50%. Other tasks accuracy is unimpacted. 

In [15]:
#Inspection of behavior of importance parameter big_omega after every task 
from numpy import count_nonzero
last_epoch_batchindex = []
counter = 0
for i,j in enumerate(history.batchindex):
    try:
        if history.batchindex[i+1] > j:
            pass
        else:
            counter = counter+1
            if counter%epochs_per_task == 0:
                last_epoch_batchindex.append(i)
    except:
        last_epoch_batchindex.append(i)
        

print('Batch-id for last epoch of each task:',last_epoch_batchindex)
print('*----------------------------*')
for i,epoch_id in enumerate(last_epoch_batchindex):
    key = list(history.big_omega[epoch_id].keys())
    print('Parameter importance (big_omega) for task {0} shape is '.format(i), [history.big_omega[epoch_id][ke].shape for ke in key]) 

non_zero_imp_param = []
print('*----------------------------*')
for i,epoch_id in enumerate(last_epoch_batchindex):
    key = list(history.big_omega[epoch_id].keys())
    non_zero_imp_param.append(sum([count_nonzero(history.big_omega[epoch_id][ke]) for ke in key]))
    print('Number of non-zero elements of importance (big_omega) for task {0} is '.format(i),non_zero_imp_param[i])

Batch-id for last epoch of each task: [64, 129, 189, 254, 314]
*----------------------------*
Parameter importance (big_omega) for task 0 shape is  [(50,), (10,), (50, 10), (784, 50)]
Parameter importance (big_omega) for task 1 shape is  [(50,), (10,), (50, 10), (784, 50)]
Parameter importance (big_omega) for task 2 shape is  [(50,), (10,), (50, 10), (784, 50)]
Parameter importance (big_omega) for task 3 shape is  [(50,), (10,), (50, 10), (784, 50)]
Parameter importance (big_omega) for task 4 shape is  [(50,), (10,), (50, 10), (784, 50)]
*----------------------------*
Number of non-zero elements of importance (big_omega) for task 0 is  0
Number of non-zero elements of importance (big_omega) for task 1 is  30426
Number of non-zero elements of importance (big_omega) for task 2 is  33539
Number of non-zero elements of importance (big_omega) for task 3 is  34353
Number of non-zero elements of importance (big_omega) for task 4 is  36096


In [30]:
import numpy as np

def find_zero_indices(array):
    zero_indices = []
    if isinstance(array, np.ndarray):
        zero_indices = np.argwhere(array == 0).tolist()
    else:
        if isinstance(array, (list, tuple)):
            for i in range(len(array)):
                if isinstance(array[i], (list, tuple)):
                    for j in range(len(array[i])):
                        if array[i][j] == 0:
                            zero_indices.append((i, j))
                else:
                    if array[i] == 0:
                        zero_indices.append((i,))
    return zero_indices

In [39]:
Importance_collect= {}
for i,epoch_id in enumerate(last_epoch_batchindex):
    key = list(history.big_omega[epoch_id].keys())
    Importance_collect[i] = [history.big_omega[epoch_id][ke] for ke in key]
    
print(key)

weight_indices_zero = {}
for i in range(1,len(last_epoch_batchindex)): #after frist task
    weight_indices_zero[i] = []
    for j in range(len(Importance_collect[i])):
        weight_indices_zero[i].append(find_zero_indices(Importance_collect[i][j]))

[<tf.Variable 'dense_1/bias:0' shape=(50,) dtype=float32_ref>, <tf.Variable 'dense_2/bias:0' shape=(10,) dtype=float32_ref>, <tf.Variable 'dense_2/kernel:0' shape=(50, 10) dtype=float32_ref>, <tf.Variable 'dense_1/kernel:0' shape=(784, 50) dtype=float32_ref>]


#### Working on this cell.

In [35]:
#Accuracy aftering forgetting one of the task
#Assume that task 1 is forgetted.

forget_task=1
class_to_forget = task_labels[forget_task]
#Zero out the weights corresponding to this class
for task_id in range(len(task_labels)):
    set_active_outputs(task_labels[task_id])
    
    for k,layer in enumerate(model.layers):  # Exclude output layer
        if isinstance(layer, Dense):
            weights, biases = layer.get_weights()
            if layer == model.layers[-1]:
                for cl in class_to_forget:
                    weights[:, cl] = 0  # Zero out weights for the forgotten class in the output layer
                    biases[cl] = 0
            else:
                for iids in weight_indices_zero[forget_task][k]:
                    biases[iids[0]] = 0
                for iids in weight_indices_zero[forget_task][k+2]:
                    biases[iids[0]] = 0
                    
            layer.set_weights([weights, biases])
    
    print('Task ', {task_id}, ' accuracy: ', masked_predict(model, validation_datasets[task_id][0], validation_datasets[task_id][1]))        

> <ipython-input-35-5b9edfeae5c1>(18)<module>()
-> for iids in weight_indices_zero[forget_task]:
(Pdb) p iids
[[2], [3], [4], [5], [6], [7], [8], [9]]
(Pdb) q


BdbQuit: 

In [16]:
#Inspection of losses after every task 
loss_after_task=[]
print('Batch-id for last epoch of each task:',last_epoch_batchindex)
print('*----------------------------*')
for i,epoch_id in enumerate(last_epoch_batchindex):
    loss_after_task.append(history.losses[epoch_id])
    print('Loss after task {0} '.format(i), history.losses[epoch_id]) 
print('*----------------------------*')

#Inspection of surrogate loss or regularization (sum(big_omega x (theta'-theta))) after every task 
surrogate_loss_after_task=[]
for i,epoch_id in enumerate(last_epoch_batchindex):
    surrogate_loss_after_task.append(history.regs[epoch_id])
    print('Surrogate loss after task {0} '.format(i), history.regs[epoch_id]) 

Batch-id for last epoch of each task: [64, 129, 189, 254, 314]
*----------------------------*
Loss after task 0  0.009885119
Loss after task 1  0.08796781
Loss after task 2  0.09093419
Loss after task 3  0.04081106
Loss after task 4  0.113942795
*----------------------------*
Surrogate loss after task 0  0.0
Surrogate loss after task 1  0.037397042
Surrogate loss after task 2  0.06461409
Surrogate loss after task 3  0.034605086
Surrogate loss after task 4  0.029006097


In [17]:
print(model.summary())

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 50)                39250     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                510       
Total params: 39,760
Trainable params: 39,760
Non-trainable params: 0
_________________________________________________________________
None
